# Training the Text-to-Image Translation Network

In [1]:
import torch
import lightning as L
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer

from datasets import EncodedMNIST
from cINN import ConditionalRealNVP
from autoencoder import AutoencoderSimple
from utils import get_best_device

## Load pretrained models

In [2]:
# Sentence encoder whose outputs are handed to EncodedMNIST below.
embedding_model = SentenceTransformer("intfloat/multilingual-e5-small")

ae_path = "./models/ae_100.pth"
autoencoder = AutoencoderSimple()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
ae_state_dict = torch.load(ae_path, map_location=get_best_device(), weights_only=True)
autoencoder.load_state_dict(ae_state_dict)

# The autoencoder is pretrained and is not passed to the trainer in this
# notebook, so freeze it and switch it to inference mode. Freezing also keeps
# anything it encodes from carrying autograd history into the cINN training
# loop. NOTE(review): tensors with such history are a likely cause of
# "Trying to backward through the graph a second time" -- confirm against
# the EncodedMNIST implementation.
autoencoder.requires_grad_(False)
autoencoder.eval();  # trailing ';' suppresses the module repr in the cell output

<All keys matched successfully>

## Load data

In [3]:
SEED = 42             # fixes the train/validation split across kernel restarts
BATCH_SIZE = 100
TRAIN_FRACTION = 0.8  # share of the encoded dataset used for training

# The "Encoding" progress bar printed by this cell suggests the dataset runs
# the encoders while it is constructed. Building it under no_grad keeps the
# stored tensors free of autograd history, so every training step backpropagates
# through the cINN only. NOTE(review): verify this is where encoding happens.
with torch.no_grad():
    full_data = EncodedMNIST(autoencoder=autoencoder, embedding_model=embedding_model, train=True)

TRAIN_SIZE = int(TRAIN_FRACTION * len(full_data))
VAL_SIZE = len(full_data) - TRAIN_SIZE

# A dedicated, seeded generator makes the split reproducible without touching
# the global RNG state.
train_data, val_data = torch.utils.data.random_split(
    full_data,
    [TRAIN_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Validation metrics do not depend on sample order; keep it deterministic.
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

Encoding: 100%|██████████| 600/600 [02:21<00:00,  4.24it/s]


## Train cINN

In [5]:
# Hyperparameters of the conditional RealNVP, gathered in one place.
# NOTE(review): input_size presumably equals the autoencoder's latent size and
# condition_size the sentence-embedding width -- confirm against those models.
cinn_kwargs = dict(
    input_size=64,
    hidden_size=128,
    n_blocks=20,
    condition_size=384,
)
cinn = ConditionalRealNVP(**cinn_kwargs)

# Fit the cINN for a fixed number of epochs, validating on the held-out split.
trainer = L.Trainer(max_epochs=100)
trainer.fit(
    model=cinn,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type          | Params
------------------------------------------------------
0 | coupling_blocks     | ModuleList    | 1.6 M 
1 | orthogonal_matrices | ParameterList | 77.8 K
------------------------------------------------------
1.6 M     Trainable params
77.8 K    Non-trainable params
1.6 M     Total params
6.563     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.