# Training of Text-to-Image network translation

In [3]:
import torch
import lightning as L
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer

from datasets import EncodedMNIST
from cINN import ConditionalRealNVP
from autoencoder import AutoencoderSimple
from utils import get_best_device

## Load pretrained models

In [4]:
embedding_model = SentenceTransformer("intfloat/multilingual-e5-small")

ae_path = "./models/ae_100.pth"
autoencoder = AutoencoderSimple()
autoencoder.load_state_dict(torch.load(ae_path, map_location=get_best_device()))

<All keys matched successfully>

## Load data

In [7]:
train_data = EncodedMNIST(autoencoder=autoencoder, embedding_model=embedding_model, train=True)

TRAIN_SIZE = int(0.8 * len(train_data))
VAL_SIZE = len(train_data) - TRAIN_SIZE

train_data, val_data = torch.utils.data.random_split(train_data, [TRAIN_SIZE, VAL_SIZE])

train_loader = DataLoader(train_data, batch_size=100, shuffle=True)
val_loader = DataLoader(val_data, batch_size=100, shuffle=True)

Encoding:   0%|          | 0/600 [00:00<?, ?it/s]

Encoding: 100%|██████████| 600/600 [01:03<00:00,  9.40it/s]


## Train cINN

In [8]:
device = get_best_device()
cinn = ConditionalRealNVP(input_size=64, hidden_size=128, n_blocks=20, condition_size=384).to(device)

trainer = L.Trainer(max_epochs=100)
trainer.fit(model=cinn, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type       | Params
-----------------------------------------------
0 | coupling_blocks | ModuleList | 1.6 M 
-----------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.252     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\luke\anaconda3\envs\gnn\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:492: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
c:\Users\luke\anaconda3\envs\gnn\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!