In [1]:
import sys

# Put the parent directory on the import path so the `src.*` package used by
# later cells can be imported. Guarded so that re-running this cell does not
# keep stacking duplicate '..' entries onto sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import pytorch_lightning as pl
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,PreTrainedTokenizerFast, GPT2Tokenizer
from PIL import Image
from tqdm import tqdm as tqdm
import torchvision.transforms as T

from src.datasets.imageclef_dataset import ImageCLEF2021DataModule

In [3]:

# Channel statistics applied by the normalized pipelines (these are the widely
# used ImageNet values).
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _make_pipeline(normalize):
    """Build a transform that resizes to 224x224 and converts to a tensor.

    When ``normalize`` is true, a per-channel ``T.Normalize`` step using the
    constants above is appended.
    """
    steps = [T.Resize((224,224)), T.ToTensor()]
    if normalize:
        steps.append(T.Normalize(mean=_NORM_MEAN, std=_NORM_STD))
    return T.Compose(steps)


augmentations = {
    # NOTE(review): the train pipeline deliberately has no Normalize step
    # (it was commented out), while val/test are normalized. That fits
    # computing raw dataset statistics, but confirm before training with it.
    'train': _make_pipeline(normalize=False),
    'val': _make_pipeline(normalize=True),
    'test': _make_pipeline(normalize=True),
}

In [4]:
# Keyword arguments forwarded to ImageCLEF2021DataModule. The dataset root is
# an absolute path on the author's machine and must be adapted locally.
dataset_params = dict(
    root="/Users/caghankoksal/Desktop/imageclef/",
    batch_size=32,
    tokenizer="gpt2",
    return_size=False,
    num_data_workers=0,
    limit_num_samples=128,
)

# The per-split transforms are passed alongside the settings above.
imageclef_datamodule = ImageCLEF2021DataModule(transforms=augmentations, **dataset_params)

In [5]:
# Ask the data module for its training and validation loaders (train first,
# then val, matching the original call order).
train_loader, val_loader = (
    imageclef_datamodule.train_dataloader(),
    imageclef_datamodule.val_dataloader(),
)

In [6]:
# Sanity check: pull only the first training batch and show the image tensor
# shape. Nothing is bound or printed if the loader is empty.
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    batch = _first_batch
    cur_images = batch['image']
    print(cur_images.shape)

torch.Size([32, 3, 224, 224])


In [7]:
# Per-channel mean / std of the training images.
#
# Running sums are kept in float64: adding up millions of pixel values in
# float32 loses precision, and the variance below is a difference of two
# large, nearly equal numbers.
psum    = torch.zeros(3, dtype=torch.float64)
psum_sq = torch.zeros(3, dtype=torch.float64)

# Pixels per channel that were actually iterated. Counting inside the loop
# (instead of len(dataset) * 224 * 224) stays correct if the loader drops or
# limits samples, or if the resize resolution changes.
count = 0

# loop through images
for batch in tqdm(train_loader):
    images = batch["image"].to(torch.float64)  # image batches are (B, 3, H, W)
    psum    += images.sum(dim=(0, 2, 3))
    psum_sq += (images ** 2).sum(dim=(0, 2, 3))
    count   += images.shape[0] * images.shape[2] * images.shape[3]

if count == 0:
    raise ValueError("train_loader yielded no images; cannot compute mean/std")

###### FINAL CALCULATIONS

# mean and std via Var[x] = E[x^2] - E[x]^2; clamp guards against a tiny
# negative variance from rounding, which would turn the sqrt into NaN.
mean64 = psum / count
var64  = ((psum_sq / count) - (mean64 ** 2)).clamp(min=0.0)

# Cast back to float32 so downstream code (and the printed repr) sees the
# default tensor dtype.
total_mean = mean64.float()
total_var  = var64.float()
total_std  = var64.sqrt().float()

# output
print('mean: '  + str(total_mean))
print('std:  '  + str(total_std))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


100%|██████████| 141/141 [00:36<00:00,  3.89it/s]

mean: tensor([-0.9954, -0.8883, -0.6623])
std:  tensor([1.1835, 1.2106, 1.2051])





In [None]:
# Display the image tensor shape of whatever `batch` currently holds (in
# notebook order, the last batch produced by the statistics loop).
batch["image"].shape

In [None]:
from src.models.multimodal.flamingo_module import FlamingoModule
import pytorch_lightning as pl

In [None]:
# --- Language model / vocabulary -------------------------------------------
VOCAB_SIZE_OF_TOKENIZER = 50257 # mimic_datamodule.train_dataset.tokenizer.vocab_size
LANGUAGE_MODEL = 'gpt2'
# NOTE(review): the "+3" presumably accounts for special tokens added on top
# of the base GPT-2 vocabulary, and 31092 for a different tokenizer; confirm.
if LANGUAGE_MODEL == "gpt2":
    NUM_TOKENS = VOCAB_SIZE_OF_TOKENIZER + 3
else:
    NUM_TOKENS = 31092

# --- Flamingo architecture --------------------------------------------------
FLAMINGO_EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 8
ATT_HEAD_DIM = 64
# (sic) The misspelled name is kept because other cells read it.
CROOS_ATT_EVERY = 3

# Look up the id of the '<image>' special token: find its position in the
# tokenizer's special-token list and take the id at the same position.
_special_ids = imageclef_datamodule.train_dataset.tokenizer.all_special_ids
_special_tokens = imageclef_datamodule.train_dataset.tokenizer.all_special_tokens
MEDIA_TOKEN_ID = _special_ids[_special_tokens.index('<image>')]

# --- Perceiver resampler / image encoder ------------------------------------
PERCEIVER_NUM_LATENTS = 64
PERCEIVER_DEPTH = 2
IMAGE_ENCODER = "clip"

# --- Pretrained weights (absolute paths on the author's machine) ------------
PRETRAINED_CLIP_PATH = '/Users/caghankoksal/Desktop/development/PubMedCLIP_ViT32.pth'
PRETRAINED_GPT2_PATH = "/Users/caghankoksal/Desktop/development/TransformerPlay/gpt2-pytorch_model.bin"


In [None]:
# Keyword arguments for FlamingoModule, assembled from the constants defined
# in the configuration cell. Key order is preserved for the printout below.
model_hyperparams = dict(
    pretrained_clip_path=PRETRAINED_CLIP_PATH,
    warmup_steps=569,
    num_tokens=NUM_TOKENS,
    dim=FLAMINGO_EMBED_DIM,
    depth=DEPTH,
    num_heads=NUM_HEADS,
    dim_head=ATT_HEAD_DIM,
    cross_attn_every=CROOS_ATT_EVERY,
    media_token_id=MEDIA_TOKEN_ID,
    perceiver_num_latents=PERCEIVER_NUM_LATENTS,
    perceiver_depth=PERCEIVER_DEPTH,
    image_encoder=IMAGE_ENCODER,
    language_model=LANGUAGE_MODEL,
    pretrained_gpt2_path=PRETRAINED_GPT2_PATH,
)

# Echo the final configuration, one "name: value" pair per line.
for name, value in model_hyperparams.items():
    print(f"{name}: {value}")

In [None]:
# Instantiate the Flamingo Lightning module from the hyperparameter dict.
model = FlamingoModule(**model_hyperparams)

In [None]:
from pytorch_lightning.callbacks import LearningRateMonitor
# Record the learning rate at every training step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# Trainer settings collected in one place: a short CPU-only run on a single
# device, logging on every step.
trainer_options = {
    "max_epochs": 6,
    "accelerator": "cpu",
    "devices": 1,
    "callbacks": [lr_monitor],
    "log_every_n_steps": 1,
}
trainer = pl.Trainer(**trainer_options)



In [None]:
# Train the model with the loaders built from the first data module
# (train_loader for optimisation, val_loader for validation).
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Lightning checkpoint to restore for inference (absolute path on the author's
# machine). NOTE(review): the file name says epoch=56, while the Trainer in
# this notebook runs 6 epochs, so this presumably comes from a separate
# training run; confirm.
CHECKPOINT_PATH = "/Users/caghankoksal/Desktop/SS2022/lightning_logs/version_21/checkpoints/epoch=56-val_loss=0.39-other_metric=0.00.ckpt"

In [None]:
# Read the checkpoint onto the CPU. NOTE: depending on the torch version and
# its settings, torch.load may unpickle arbitrary objects, so only load
# checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))

# Restore the weights stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])

# Switch to evaluation mode; as the last expression, the cell also displays
# the module.
model.eval()

In [None]:
# Fetch validation item 3 once and read all fields from that single sample,
# instead of indexing the dataset (and re-running its loading/transform code)
# separately for every field.
val_sample = imageclef_datamodule.val_dataset[3]
val_img = val_sample["image"]
val_qa_pair = val_sample["qa_pair"]
val_question = val_sample["question"]

# Show the question text as the cell output.
val_question


In [None]:
# Ground-truth answer for the same validation item (index 3) inspected above;
# displayed as the cell output.
val_answer =  imageclef_datamodule.val_dataset[3]["answer"]
val_answer

In [None]:
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
from transformers import GPT2Tokenizer
import numpy as np

from torch import nn as nn
import torch.nn.functional as F
def generate(image, context, cur_model, ntok=20, top_k=10, eos_token_id=None):
    """Extend ``context`` autoregressively using top-k sampling.

    At every step the model is called with ``{'image': image, 'input_ids':
    context}``; the logits at the last sequence position are restricted to the
    ``top_k`` highest-scoring tokens, one token is sampled from the resulting
    distribution, and it is appended to ``context``.

    Args:
        image: Image input forwarded unchanged to ``cur_model``.
        context: Integer tensor of token ids, indexed as (batch, seq_len).
        cur_model: Callable taking the dict above and returning a tensor that
            is indexed as ``out[:, -1, :]`` (assumed to be logits of shape
            (batch, seq_len, vocab); TODO confirm against the model).
        ntok: Maximum number of tokens to generate.
        top_k: Number of highest-logit candidates kept for sampling. Values
            larger than the vocabulary size are clamped to it.
        eos_token_id: Optional token id; generation stops early once every
            sequence in the batch samples this id in the same step. ``None``
            (default) always generates ``ntok`` tokens.

    Returns:
        A new tensor containing ``context`` followed by the generated tokens.
        The input ``context`` tensor is not modified.
    """
    # Inference only: avoid building an autograd graph on every step.
    with torch.no_grad():
        for _ in range(ntok):
            out = cur_model({'image': image, 'input_ids': context})
            logits = out[:, -1, :]

            # Smallest logit that still belongs to the top-k set, per row.
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k).values[..., -1, None]

            # masked_fill returns a new tensor, so the model output is left
            # intact; float('-inf') replaces np.NINF (removed in NumPy 2.0).
            logits = logits.masked_fill(logits < kth_best, float('-inf'))

            # Shape (batch, 1): ready to be concatenated onto the context.
            next_tok = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            context = torch.cat([context, next_tok], dim=-1)

            if eos_token_id is not None and bool((next_tok == eos_token_id).all()):
                break
    return context


# Reuse the tokenizer owned by the training dataset for encoding prompts and
# decoding generated ids.
tokenizer = imageclef_datamodule.train_dataset.tokenizer

In [None]:
# Rebuild the data module for qualitative inspection: one sample per batch and
# no cap on the number of samples. This rebinds `dataset_params` and
# `imageclef_datamodule`; loaders created earlier still point at the previous
# data module instance.
dataset_params = dict(
    root="/Users/caghankoksal/Desktop/imageclef/",
    batch_size=1,
    tokenizer="gpt2",
    return_size=False,
    num_data_workers=0,
    limit_num_samples=None,
)
imageclef_datamodule = ImageCLEF2021DataModule(transforms=augmentations, **dataset_params)

In [None]:
# Validation loader of the rebuilt (batch_size=1) data module.
val_dataloader = imageclef_datamodule.val_dataloader()

In [None]:
# Explicit iterator so the following cells can step through validation
# samples one at a time with next().
val_loader_iter = iter(val_dataloader)

In [None]:
# Pull the first validation batch from the iterator.
batch = next(val_loader_iter)

In [None]:
# Display the question text of the current batch (index 0 selects the first
# entry of the batch's question list).
batch["question"][0]

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])


In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])

In [None]:
# Draw the next validation sample and build the prompt in the
# "<|endoftext|> <image> question: ... answer:" format.
batch = next(val_loader_iter)
prompt = "<|endoftext|> <image> question: " + batch["question"][0] + ' answer:'
context = torch.tensor([tokenizer.encode(prompt)])

# Condition on the image that belongs to THIS question. The previous version
# always passed the fixed `val_img` (validation item 3), so the image did not
# match the question/answer being printed.
out = generate(batch["image"], context, model, ntok=20)
print("Model out : ", tokenizer.decode(out[0]))
print("Correct Answer: " + batch["answer"][0])