# Training the Text-to-Image Translation Network

In [2]:
import torch
import lightning as L
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer

from datasets import CustomMNIST
from cINN import ConditionalRealNVP
from autoencoder import AutoencoderSimple
from utils import get_best_device

## Load pretrained models

In [3]:
# Sentence encoder that later turns label strings into conditioning vectors.
embedding_model = SentenceTransformer("intfloat/multilingual-e5-small")

# Pretrained autoencoder; only its encoder is used further down, as a frozen
# feature extractor, so put it in inference mode right away.
ae_path = "./models/ae_100.pth"
autoencoder = AutoencoderSimple()
autoencoder.eval()

# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs and avoids executing arbitrary pickled code.
# map_location only decides where the checkpoint tensors are deserialized;
# this cell does not move the module itself to another device.
ae_state = torch.load(ae_path, map_location=get_best_device(), weights_only=True)
autoencoder.load_state_dict(ae_state)

<All keys matched successfully>

## Load data

In [4]:
# Number of samples the loader hands out per iteration.
BATCH_SIZE = 100

# Training split of the project's MNIST wrapper, served in shuffled mini-batches.
train_data = CustomMNIST(train=True)
train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

### Encode dataset

In [5]:
# Pre-compute (image latent, label embedding) pairs once, so the frozen
# autoencoder and sentence encoder do not have to run again during training.
encoded_dataset = []

# The autoencoder is used for inference only in this cell.
autoencoder.eval()

# Without no_grad every stored latent would keep its autograd graph alive,
# steadily growing memory over the whole pass.
with torch.no_grad():
    for batch_imgs, batch_labels in tqdm(train_loader, desc='Encoding'):
        encoded_imgs = autoencoder.encoder(batch_imgs)

        # If the loader yields tensor labels, str() on a scalar tensor gives
        # "tensor(5)" rather than "5"; unwrap scalars before building the text.
        label_texts = [
            str(label.item()) if torch.is_tensor(label) and label.numel() == 1 else str(label)
            for label in batch_labels
        ]
        encoded_labels = embedding_model.encode(label_texts, convert_to_tensor=True)

        # The sentence encoder may return tensors on a different device than the
        # autoencoder output; store both halves of each pair on the CPU.
        encoded_dataset.extend(zip(encoded_imgs.cpu(), encoded_labels.cpu()))

Encoding: 100%|██████████| 600/600 [01:19<00:00,  7.53it/s]


## Train cINN

In [10]:
# Conditional RealNVP with 64-dim inputs, conditioned on 384-dim vectors.
cinn = ConditionalRealNVP(input_size=64, hidden_size=128, n_blocks=20, condition_size=384)

# Feed the pre-encoded (image latent, label embedding) pairs to the trainer;
# the original call passed no data to fit() and left encoded_dataset unused.
# NOTE(review): assumes ConditionalRealNVP.training_step accepts a
# (latent, condition) batch and defines no train_dataloader() of its own — confirm.
encoded_loader = DataLoader(encoded_dataset, batch_size=100, shuffle=True)

trainer = L.Trainer(max_epochs=100)
trainer.fit(model=cinn, train_dataloaders=encoded_loader)

torch.Size([64])