In [1]:
import sys

# Make the parent directory importable so the `src.*` imports used by this
# notebook resolve. The membership check keeps the cell idempotent: re-running
# it no longer appends a duplicate entry on every execution.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import pytorch_lightning as pl
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,PreTrainedTokenizerFast, GPT2Tokenizer
from PIL import Image
from tqdm import tqdm as tqdm
import torchvision.transforms as T

from src.datasets.imageclef_dataset import ImageCLEF2021DataModule

In [3]:

# Shared preprocessing settings. All three splits use exactly the same
# deterministic pipeline (resize -> tensor -> normalize); no random
# augmentation is applied despite the variable name.
IMAGE_SIZE = (224, 224)
# NOTE(review): these look like the usual ImageNet channel statistics, while
# IMAGE_ENCODER is set to "clip" later in this notebook -- confirm that the
# pretrained encoder expects inputs normalized with these values.
NORMALIZE_MEAN = (0.485, 0.456, 0.406)
NORMALIZE_STD = (0.229, 0.224, 0.225)


def build_preprocessing():
    """Return a new resize -> ToTensor -> Normalize pipeline.

    A fresh ``T.Compose`` is created on every call so that each split owns
    its own transform object, exactly as in the hand-written version.
    """
    return T.Compose([
        T.Resize(IMAGE_SIZE),
        T.ToTensor(),
        T.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ])


# Keyed by split name; passed to the datamodule as its `transforms` argument.
augmentations = {split: build_preprocessing() for split in ('train', 'val', 'test')}

In [4]:
# Keyword arguments forwarded to ImageCLEF2021DataModule.
# The dataset location is machine-specific, so it can be overridden with the
# IMAGECLEF_ROOT environment variable; the previous hard-coded path is kept as
# the default so existing setups behave exactly as before.
dataset_params = {
    "root": os.environ.get("IMAGECLEF_ROOT", "/Users/caghankoksal/Desktop/imageclef/"),
    "batch_size": 32,
    "tokenizer": "gpt2",
    "return_size": False,
    "num_data_workers": 0,
    # presumably caps how many samples are loaded (quick debug runs) -- TODO confirm
    "limit_num_samples": 128,
}

In [5]:
# Create the ImageCLEF 2021 datamodule from the dataset settings, handing it
# the per-split image transforms.
imageclef_datamodule = ImageCLEF2021DataModule(
    transforms=augmentations,
    **dataset_params,
)

In [6]:
# Fetch the training and validation loaders from the datamodule
# (training loader is requested first, validation second).
train_loader, val_loader = (
    imageclef_datamodule.train_dataloader(),
    imageclef_datamodule.val_dataloader(),
)

In [7]:
from src.models.multimodal.flamingo_module import FlamingoModule
import pytorch_lightning as pl

In [8]:
# --- Language model / vocabulary -------------------------------------------
# Hard-coded base vocabulary size; it could instead be read from the
# datamodule tokenizer's `vocab_size`.
VOCAB_SIZE_OF_TOKENIZER = 50257
LANGUAGE_MODEL = 'gpt2'
# With GPT-2, three extra ids are reserved on top of the base vocabulary
# (presumably the special tokens added by the datamodule, e.g. '<image>' --
# TODO confirm); any other language model uses a fixed size of 31092.
NUM_TOKENS = VOCAB_SIZE_OF_TOKENIZER + 3 if LANGUAGE_MODEL == "gpt2" else 31092

# --- Flamingo architecture --------------------------------------------------
FLAMINGO_EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 8
ATT_HEAD_DIM = 64
# (sic) the misspelt name is kept because the hyperparameter dict below reads it.
CROOS_ATT_EVERY = 3

# Id of the '<image>' special token. The tokenizer is bound once instead of
# repeating the attribute chain; `.index` still raises ValueError if the
# token has not been registered as a special token.
_tokenizer = imageclef_datamodule.train_dataset.tokenizer
MEDIA_TOKEN_ID = _tokenizer.all_special_ids[_tokenizer.all_special_tokens.index('<image>')]

# --- Perceiver resampler / image encoder -------------------------------------
PERCEIVER_NUM_LATENTS = 64
PERCEIVER_DEPTH = 2
IMAGE_ENCODER = "clip"

# --- Pretrained checkpoints ---------------------------------------------------
# Checkpoint locations are machine-specific: each can be overridden through an
# environment variable of the same name, with the original path as default.
PRETRAINED_CLIP_PATH = os.environ.get(
    "PRETRAINED_CLIP_PATH",
    '/Users/caghankoksal/Desktop/development/PubMedCLIP_ViT32.pth',
)
PRETRAINED_GPT2_PATH = os.environ.get(
    "PRETRAINED_GPT2_PATH",
    "/Users/caghankoksal/Desktop/development/TransformerPlay/gpt2-pytorch_model.bin",
)


In [9]:
# Gather all keyword arguments for FlamingoModule in a single mapping so the
# full configuration can be inspected before the model is built.
model_hyperparams = dict(
    pretrained_clip_path=PRETRAINED_CLIP_PATH,
    warmup_steps=569,
    num_tokens=NUM_TOKENS,
    dim=FLAMINGO_EMBED_DIM,
    depth=DEPTH,
    num_heads=NUM_HEADS,
    dim_head=ATT_HEAD_DIM,
    cross_attn_every=CROOS_ATT_EVERY,
    media_token_id=MEDIA_TOKEN_ID,
    perceiver_num_latents=PERCEIVER_NUM_LATENTS,
    perceiver_depth=PERCEIVER_DEPTH,
    image_encoder=IMAGE_ENCODER,
    language_model=LANGUAGE_MODEL,
    pretrained_gpt2_path=PRETRAINED_GPT2_PATH,
)

# Print the resolved configuration, one "name: value" pair per line.
for hp_name, hp_value in model_hyperparams.items():
    print(f"{hp_name}: {hp_value}")

pretrained_clip_path: /Users/caghankoksal/Desktop/development/PubMedCLIP_ViT32.pth
warmup_steps: 569
num_tokens: 50260
dim: 768
depth: 12
num_heads: 8
dim_head: 64
cross_attn_every: 3
media_token_id: 50258
perceiver_num_latents: 64
perceiver_depth: 2
image_encoder: clip
language_model: gpt2
pretrained_gpt2_path: /Users/caghankoksal/Desktop/development/TransformerPlay/gpt2-pytorch_model.bin


In [10]:
# Build the Flamingo model from the hyperparameter mapping. According to the
# recorded cell output, construction loads the pretrained CLIP weights and
# the GPT-2 weights from the configured checkpoint paths.
model = FlamingoModule(**model_hyperparams)

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /var/folders/61/9c2llh9n2pjb81c4dmhmb67w0000gn/T/tmp02saf68j
INFO:torch.distributed.nn.jit.instantiator:Writing /var/folders/61/9c2llh9n2pjb81c4dmhmb67w0000gn/T/tmp02saf68j/_remote_module_non_sriptable.py


Pretrained clip is being loaded
Flamingo is being initialized with  gpt2  as language model
GPT 2 Weights are loading...
Loaded GPT2 weights and Embeddings num_weights loaded :  156


In [11]:
from pytorch_lightning.callbacks import LearningRateMonitor

# Callback that tracks the learning rate at step granularity.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Trainer settings: a short 6-epoch run on a single CPU device, logging at
# every training step.
trainer_options = {
    "max_epochs": 6,
    "accelerator": "cpu",
    "devices": 1,
    "callbacks": [lr_monitor],
    "log_every_n_steps": 1,
}
trainer = pl.Trainer(**trainer_options)



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
# Start training: fit the model on the training loader and validate on the
# validation loader using the trainer configured above.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


  | Name          | Type          | Params
------------------------------------------------
0 | flamingo_palm | FlamingoModel | 249 M 
------------------------------------------------
76.8 M    Trainable params
172 M     Non-trainable params
249 M     Total params
998.888   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]