In [4]:
%load_ext autoreload
%autoreload 2
import itertools
import math

import torch
from torch import nn, Tensor
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

import clip
from transformers import BertTokenizer, BertModel

from src.data.diffusion_db_module import DiffusionDBModule
from src.models.transformer import TransformerImg2Prompt

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# --- Pretrained backbones -------------------------------------------------
CLIP_ARCH = "ViT-B/32"
BERT_CHECKPOINT = "bert-base-uncased"

# clip.load returns the model together with its image preprocessing transform.
clip_model, preprocess = clip.load(CLIP_ARCH)
# The loaded CLIP weights come back as float16; cast them back to float32.
clip_model = clip_model.float()

# Tokenizer and encoder are loaded from the same BERT checkpoint.
bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert_model = BertModel.from_pretrained(BERT_CHECKPOINT)

# --- Dataset --------------------------------------------------------------
# Images go through CLIP's preprocessing transform; the BERT tokenizer is
# handed to the data module alongside it.
dm = DiffusionDBModule(
    batch_size=4,
    subset_name="large_first_1k",
    img_transform=preprocess,
    bert_tokenizer=bert_tokenizer,
)
# NOTE(review): "train" is not one of Lightning's standard stage names
# ("fit", "validate", "test", "predict") — confirm DiffusionDBModule.setup
# handles it as intended.
dm.setup("train")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Found cached dataset diffusiondb (/home/minhduc0711/.cache/huggingface/datasets/poloclub___diffusiondb/large_first_1k/0.9.1/547894e3a57aa647ead68c9faf14832

In [7]:
# Seed Python, NumPy and PyTorch RNGs *before* the model is constructed so
# that repeated runs of this cell start from the same random state.
# workers=True additionally lets Lightning derive seeds for dataloader workers.
SEED = 42
pl.seed_everything(SEED, workers=True)

# --- Transformer hyperparameters -------------------------------------------
emsize = 768  # embedding dimension (of BERT)
d_hid = 500  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 10  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 6  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability

model = TransformerImg2Prompt(clip_model, bert_model,
    emsize, nhead, d_hid, nlayers, dropout)

# --- Logging and callbacks -------------------------------------------------
# Single source of truth for the metric that drives both checkpointing and
# early stopping, so the two callbacks cannot drift apart.
monitored_metric = "val/CE_loss"

logger = TensorBoardLogger("training_logs", name="transformers")
# Keep only the best checkpoint (lowest validation loss). The metric name
# contains a "/", so auto_insert_metric_name is disabled and the filename
# template spells the labels out explicitly.
ckpt_callback = ModelCheckpoint(dirpath="model_ckpts/transformers",
                                filename="epoch={epoch}-step={step}-val_loss={val/CE_loss:.6f}",
                                save_top_k=1,
                                monitor=monitored_metric,
                                mode="min",
                                auto_insert_metric_name=False)
lr_monitor_callback = LearningRateMonitor(logging_interval='step')
# Stop once the validation loss has failed to improve by at least 0.01 for
# 5 consecutive validation checks.
early_stop_callback = EarlyStopping(monitor=monitored_metric,
                                    min_delta=0.01, patience=5, verbose=False, mode="min")

# --- Training --------------------------------------------------------------
# accelerator="auto" picks the available hardware (falls back to CPU when no
# accelerator is present), so the cell no longer needs a manual edit per host.
# devices=1 keeps training single-process, which is what a notebook session
# expects even on multi-GPU machines.
trainer = pl.Trainer(max_epochs=100, accelerator="auto", devices=1,
                     logger=logger,
                     log_every_n_steps=30,
                     callbacks=[ckpt_callback, lr_monitor_callback, early_stop_callback])
trainer.fit(model=model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Found cached dataset diffusiondb (/home/minhduc0711/.cache/huggingface/datasets/poloclub___diffusiondb/large_first_1k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 210.29it/s][A

  | Name        | Type        | Params
--------------------------------------------
0 | clip_model  | CLIP        | 151 M 
1 | bert_model  | BertModel   | 109 M 
2 | transformer | Transformer | 54.8 M
3 | fc_img      | Linear      | 393 K 
--------------------------------------------
315 M     Trainable params
0         Non-trainable params
315 M     Total params
1,263.888 Total estimated model params size (MB)


                                                                                                                                         



Epoch 0:   4%|███▏                                                                   | 9/200 [00:27<09:49,  3.09s/it, loss=8.39, v_num=7]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
