In [None]:
%pip install -q transformers
%pip install -q git+https://github.com/cthiounn/dalle-tiny.git

In [None]:
from dalle_tiny.model import TinyDalleModel
from dalle_tiny.util import TinyDalleDataset
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the CPU generator and every CUDA generator so that weight
# initialisation and training-set shuffling are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Number of (caption, label) items per batch, shared by both loaders.
BATCH_SIZE = 10

# Train / validation splits, read from the parquet files published in the
# dalle-tiny GitHub repository (requires network access).
training_data = TinyDalleDataset(
    parquet_file="https://github.com/cthiounn/dalle-tiny/raw/main/archive_train.parquet",
    dataset_type="train",
)
test_data = TinyDalleDataset(
    parquet_file="https://github.com/cthiounn/dalle-tiny/raw/main/archive_val.parquet",
    dataset_type="val",
)

# Only the training loader is shuffled. The validation metric is a mean of
# per-batch means, so a fixed order keeps the batch composition (including a
# possibly shorter last batch) identical from one epoch to the next.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
from tqdm import tqdm
from transformers import BartForConditionalGeneration
import gc

def freeze_params(model):
    """Turn off gradient tracking for every parameter of ``model``, in place.

    Args:
        model: any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        None. The parameters are modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)

# Free whatever a previous run of this cell left behind before allocating a
# new model. If GPU memory is tight when re-running, `del model` first so the
# old weights become collectable.
gc.collect()
torch.cuda.empty_cache()

# Start from the pretrained BART summarisation checkpoint.
model=TinyDalleModel.from_pretrained('facebook/bart-large-cnn')
# Presumably swaps the text LM head for one over image tokens -- see
# dalle_tiny.model for the exact behaviour.
model.reinit_lm_head_for_images()
# Replace the logits bias with a 16384-entry vector to match the image-token
# logits used by the training loop.
# NOTE(review): this bias is uniform random in [0, 1) and, if it remains a
# buffer as in stock BART, is never trained -- confirm that a zero bias was
# not intended.
model.final_logits_bias=torch.rand(16384)
model=model.to(device)

# Train everything except the encoder.
freeze_params(model.get_encoder())
model.train()

# Hand the optimizer only the parameters that can actually receive gradients,
# so it does not carry the frozen encoder weights in its parameter groups.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=5e-5)
def loss_fn(logits, labels):
    """Mean cross-entropy between ``logits`` and ``labels``, classes on the LAST axis.

    ``F.cross_entropy`` always treats dimension 1 as the class dimension. For
    tensors laid out as ``(batch, seq, num_classes)`` that would normalise over
    the sequence axis instead of the classes, so every leading axis is
    flattened first; the last axis is then the class axis for any rank.

    Args:
        logits: floating-point tensor of shape ``(..., num_classes)``.
        labels: either a floating-point tensor with the same shape as
            ``logits`` (one-hot vectors or class probabilities), or an integer
            tensor of class indices with shape ``logits.shape[:-1]``.

    Returns:
        A scalar tensor: the loss averaged over all positions.
    """
    num_classes = logits.shape[-1]
    flat_logits = logits.reshape(-1, num_classes)
    if labels.is_floating_point():
        # One-hot / probability targets: one row of class weights per position.
        flat_labels = labels.reshape(-1, num_classes)
    else:
        # Class-index targets: one integer per position.
        flat_labels = labels.reshape(-1)
    # The default reduction is already the mean over positions.
    return F.cross_entropy(flat_logits, flat_labels)

# Train for a fixed number of epochs, printing the loss of every training
# batch and the mean training / validation loss after each epoch.
num_batches_test = len(test_dataloader)
num_batches_train = len(train_dataloader)


for epoch in range(5):
    # The evaluation pass below switches the model to eval mode, so training
    # mode (dropout etc.) has to be re-enabled at the start of every epoch.
    model.train()
    train_loss=0
    for batch in tqdm(train_dataloader):
        caption,label =batch
        inp=caption.to(device)
        # NOTE(review): label[0] takes the first element of the collated label.
        # Confirm against TinyDalleDataset that this is the full batch of
        # token ids and not just the first sample.
        lab=label[0].to(device)
        # NOTE(review): the same tensor is fed to the decoder and used as the
        # target, with no shift. Unless TinyDalleModel shifts internally, the
        # decoder can learn to copy its input -- confirm.
        predict=model(input_ids=inp, decoder_input_ids =lab)

        optimizer.zero_grad()
        # .float() keeps the one-hot targets on lab's device, so this also
        # works on CPU (torch.cuda.FloatTensor does not).
        loss = loss_fn(predict.logits, F.one_hot(lab,16384).float())
        loss.backward()
        train_loss+=loss.item()
        optimizer.step()
        print(f"train_loss by batch:{loss.item()}")
    if num_batches_train:
        print(f"mean train loss:{train_loss/num_batches_train}")

    # Evaluate with dropout disabled and without building autograd graphs.
    # After the last epoch the model is left in eval mode.
    model.eval()
    test_loss=0
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            caption,label =batch
            inp=caption.to(device)
            lab=label[0].to(device)
            predict=model(input_ids=inp, decoder_input_ids =lab)
            loss = loss_fn(predict.logits, F.one_hot(lab,16384).float())
            test_loss += loss.item()
        # Guard against an empty validation loader.
        mean_test_loss=test_loss/num_batches_test if num_batches_test else float("nan")
        print(f"mean test loss:{mean_test_loss}")