In [1]:
# Reload changed modules (e.g. the src.* project code imported below) before
# each cell executes, so edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import clip
from transformers import BertTokenizer, BertModel

from src.data.diffusion_db_module import DiffusionDBModule
from src.models import Seq2SeqDecoder

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the pretrained encoders.
# CLIP ViT-B/32 model, together with the image preprocessing transform it
# ships with (used below as the dataset's img_transform).
clip_model, preprocess = clip.load("ViT-B/32")
# clip.load can hand back float16 weights; cast the whole module back to
# float32. nn.Module.float() converts in place and returns the same module,
# so no rebinding of `clip_model` is needed.
clip_model.float()

# BERT (bert-base-uncased): tokenizer and encoder from the same checkpoint.
bert_tokenizer, bert_model = (
    BertTokenizer.from_pretrained('bert-base-uncased'),
    BertModel.from_pretrained('bert-base-uncased'),
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Load the dataset: DiffusionDB "large_first_1k" subset, images transformed
# with CLIP's preprocessing and prompts tokenized with the BERT tokenizer.
dm = DiffusionDBModule(batch_size=32,
                       subset_name="large_first_1k",
                       img_transform=preprocess,
                       bert_tokenizer=bert_tokenizer)
# LightningDataModule.setup() stages are "fit" / "validate" / "test" /
# "predict" ("train" is not one of them). "fit" is also the stage that
# trainer.fit() passes when it calls setup() again itself; this eager call
# only makes the datasets available for inspection in the notebook.
dm.setup("fit")

Found cached dataset diffusiondb (/home/minhduc0711/.cache/huggingface/datasets/poloclub___diffusiondb/large_first_1k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 81.15it/s]


In [None]:
# Build the model, logging and checkpointing, then train.
SEED = 42
# Seed python / numpy / torch (and dataloader workers) so the decoder's
# weight initialisation and batch shuffling are reproducible on re-run.
pl.seed_everything(SEED, workers=True)

model = Seq2SeqDecoder(clip_model, bert_model)

logger = TensorBoardLogger("training_logs", name="seq2seq")
# Keep the 2 checkpoints with the lowest validation MSE. The metric name
# contains "/", so auto_insert_metric_name is disabled and the value is
# placed into the template by hand (auto-insertion would put "val/..." into
# the file name and create an extra sub-directory).
ckpt_callback = ModelCheckpoint(dirpath="model_ckpts/seq2seq",
                                filename="epoch={epoch}-step={step}-val_loss={val/mse_loss:.6f}",
                                save_top_k=2,
                                monitor="val/mse_loss",
                                mode="min",
                                auto_insert_metric_name=False)

# Train on a single GPU when one is visible, otherwise on the CPU, instead
# of hard-coding the accelerator and editing it per machine.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(max_epochs=5,
                     accelerator=accelerator,
                     devices=1,
                     logger=logger,
                     log_every_n_steps=1,
                     callbacks=[ckpt_callback])
# NOTE(review): the model summary of the previous run reports 0
# non-trainable params, i.e. CLIP (151 M) is fine-tuned together with the
# decoder — freeze it if that is not intended.
# NOTE(review): the previous run's progress bar shows 800 batches for
# epoch 0 on a 1k-image subset with batch_size=32 (~32 would be expected) —
# confirm that DiffusionDBModule passes batch_size to its DataLoaders.
trainer.fit(model=model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found cached dataset diffusiondb (/home/minhduc0711/.cache/huggingface/datasets/poloclub___diffusiondb/large_first_1k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 232.75it/s][A

  | Name       | Type    | Params
---------------------------------------
0 | clip_model | CLIP    | 151 M 
1 | decoder    | Decoder | 243 M 
---------------------------------------
394 M     Trainable params
0         Non-trainable params
394 M     Total params
1,579.390 Total estimated model params size (MB)


Epoch 0:   0%|                                                                                                   | 0/800 [00:00<?, ?it/s]