In [1]:
import os
import sys

# Make the repository root importable so the `src.*` imports in the next cell resolve.
# Assumes the notebook runs from a direct subdirectory of the repo (e.g. notebooks/) — TODO confirm.
sys.path.append('..')
# Set before `transformers` is imported in the next cell so the tokenizers library sees it.
# Presumably this silences the tokenizers fork-parallelism warning raised with DataLoader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:

import os
import pytorch_lightning as pl
from src.models.multimodal.flamingo_module import FlamingoModule
from src.datasets.imageclef_dataset import ImageCLEF2021DataModule
from src.utils.utils import load_flamingo_weights, print_hyperparams

from pytorch_lightning import Trainer, seed_everything
import torchvision.transforms as T
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
from transformers import GPT2Tokenizer
import numpy as np

from torch import nn as nn
import torch.nn.functional as F

In [3]:
# Normalization statistics shared by every split (the same value on all three channels).
_IMAGECLEF_NORMALIZE = T.Normalize(mean=(0.2570, 0.2570, 0.2570), std=(0.2710, 0.2710, 0.2710))

# Image pipelines per split.
# Only 'train' applies random augmentation. 'val' and 'test' must be deterministic,
# otherwise validation metrics and the demo's predictions change from run to run on the same image.
augmentations = {
    'train': T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(degrees=10),
        _IMAGECLEF_NORMALIZE,
    ]),
    'val': T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        _IMAGECLEF_NORMALIZE,
    ]),
    'test': T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        _IMAGECLEF_NORMALIZE,
    ]),
}

    
# ---- Run selection -----------------------------------------------------------
DATASET = "IMAGECLEF"          # NOTE(review): not referenced in the cells shown
ONLY_IMAGES = False            # NOTE(review): not referenced in the cells shown
# When True, the checkpoint cell loads IMAGECLEF_CHECKPOINT_PATH into the model.
LOAD_TRAINED_IMAGECLEF = True

# ---- Data loading / schedule -------------------------------------------------
BATCH_SIZE = 1
NUM_EPOCHS = 120
NUM_DATA_WORKERS = 2
LIMIT_NUM_SAMPLES = None       # passed to the datamodule; presumably None = no limit


# Machine-specific configuration, selected by the notebook's working directory.
# Each known machine sets the accelerator/devices, pretrained-weight paths, dataset
# roots and checkpoint paths.
# NOTE(review): absolute per-user paths are fragile; consider reading them from
# environment variables or a config file.
if os.getcwd().startswith('/home/mlmi-matthias'):
    # Multi-GPU server.
    ACCELERATOR = "gpu"
    DEVICES = [6, 7]
    PRETRAINED_CLIP_PATH = '/home/mlmi-matthias/Caghan/pretrained_models/PubMedCLIP_ViT32.pth'
    PRETRAINED_GPT2_PATH = "/home/mlmi-matthias/Caghan/pretrained_models/gpt2-pytorch_model.bin"
    MIMIC_CXR_DCM_PATH = '/home/mlmi-matthias/physionet.org/files/mimic-cxr/2.0.0/files/'
    MIMIC_CXR_JPG_PATH = "/home/mlmi-matthias/physionet.org/files/mimic-cxr-jpg/2.0.0/files/"
    SPLIT_PATH = '/home/mlmi-matthias/Caghan/mlmi-vqa/data/external/'
    IMAGECLEF_PATH = '/home/mlmi-matthias/imageclef/'
    # Previous checkpoint, kept for reference:
    # CHECKPOINT_PATH = "/home/mlmi-matthias/Caghan/mlmi-vqa/notebooks/lightning_logs/version_20/checkpoints/epoch=114-val_loss=0.84-other_metric=0.00.ckpt"
    # Latest ROCO training run.
    CHECKPOINT_PATH = "/home/mlmi-matthias/Caghan/mlmi-vqa/notebooks/lightning_logs/version_77/checkpoints/epoch=61-val_loss_generation_epoch=1.80.ckpt"
    ANSWERS_LIST_PATH = '/home/mlmi-matthias/Caghan/mlmi-vqa//data/external/answer_list_imageclef.txt'
    IMAGECLEF_CHECKPOINT_PATH = "/home/mlmi-matthias/Caghan/mlmi-vqa/notebooks/lightning_logs/version_101/checkpoints/last.ckpt"

elif os.getcwd().startswith('/Users/caghankoksal'):
    # CPU-only laptop.
    PRETRAINED_CLIP_PATH = '/Users/caghankoksal/Desktop/development/PubMedCLIP_ViT32.pth'
    PRETRAINED_GPT2_PATH = "/Users/caghankoksal/Desktop/development/TransformerPlay/gpt2-pytorch_model.bin"
    ACCELERATOR = "cpu"
    DEVICES = 1
    MIMIC_CXR_DCM_PATH = '/Users/caghankoksal/Desktop/development/Flamingo-playground/physionet.org/files/mimic-cxr/2.0.0/files/'
    MIMIC_CXR_JPG_PATH = '/Users/caghankoksal/Desktop/development/physionet.org/files/mimic-cxr-jpg/2.0.0/files/'
    SPLIT_PATH = '/Users/caghankoksal/Desktop/SS2022/mlmi-vqa/data/external/'
    IMAGECLEF_PATH = "/Users/caghankoksal/Desktop/imageclef/"
    CHECKPOINT_PATH = "/Users/caghankoksal/Desktop/SS2022/lightning_logs/version_77/checkpoints/epoch=66-val_loss_generation_epoch=1.80.ckpt"
    ANSWERS_LIST_PATH = '/Users/caghankoksal/Desktop/SS2022/mlmi-vqa/data/external/answer_list_imageclef.txt'
    IMAGECLEF_CHECKPOINT_PATH = "/Users/caghankoksal/Desktop/SS2022/lightning_logs/version_102/checkpoints/last.ckpt"

else:
    # Fail here with an actionable message instead of a NameError on
    # IMAGECLEF_PATH / ACCELERATOR several cells later.
    raise RuntimeError(
        f"Unrecognized working directory {os.getcwd()!r}: add a configuration branch "
        "with accelerator, dataset, pretrained-weight and checkpoint paths for this machine."
    )


# Dataset / tokenizer options for the ImageCLEF datamodule.
IMAGE_TYPE = "jpg"       # NOTE(review): not referenced in the cells shown
TOKENIZER = "gpt2"
PREPROCESSED = True      # NOTE(review): not referenced in the cells shown
RETURN_IDX_EOC = True

dataset_hyperparameters = {
    "root": IMAGECLEF_PATH,
    "batch_size": BATCH_SIZE,
    "tokenizer": TOKENIZER,
    "num_data_workers": NUM_DATA_WORKERS,
    "return_size": False,
    "answers_list_path": ANSWERS_LIST_PATH,
    "return_idx_answer_eoc": RETURN_IDX_EOC,
    "transforms": augmentations,
    "limit_num_samples": LIMIT_NUM_SAMPLES,
}

datamodule = ImageCLEF2021DataModule(**dataset_hyperparameters)

train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

print("Len training dataset : ", len(datamodule.train_dataset),
      "Batch Size : ", BATCH_SIZE, "NUM_EPOCHS : ", NUM_EPOCHS)
# Count batches from the loader itself: `len(dataset) // BATCH_SIZE` under-counts by
# one batch per epoch whenever the final partial batch is kept (drop_last=False).
print("Total training steps : ", len(train_loader) * NUM_EPOCHS)


Len training dataset :  4500 Batch Size :  1 NUM_EPOCHS :  120
Total training steps :  540000


In [4]:
# MODEL HPARAMS
# Base GPT-2 vocabulary size, hard-coded (the tokenizer's `vocab_size` for gpt2).
VOCAB_SIZE_OF_TOKENIZER = 50257
LANGUAGE_MODEL = 'gpt2'
# Size of the token embedding table.
# For gpt2 it is the base vocabulary plus 4 extra slots, presumably for the special
# tokens the dataset tokenizer adds (e.g. '<image>') — TODO confirm the count matches.
# 31092 is used for any other language model.
NUM_TOKENS = VOCAB_SIZE_OF_TOKENIZER + 4 if LANGUAGE_MODEL == "gpt2" else 31092
FLAMINGO_EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 8
ATT_HEAD_DIM = 64
CROOS_ATT_EVERY = 3  # (sic) passed to the model as 'cross_attn_every'
# Token id of the '<image>' special token in the dataset tokenizer.
MEDIA_TOKEN_ID = datamodule.train_dataset.tokenizer.\
    all_special_ids[datamodule.train_dataset.tokenizer.all_special_tokens.index('<image>')]
# The media token id presumably indexes the num_tokens-sized token embedding,
# so catch a vocabulary/embedding mismatch here rather than deep inside the model.
assert MEDIA_TOKEN_ID < NUM_TOKENS, (
    f"media_token_id {MEDIA_TOKEN_ID} does not fit in a token embedding of size {NUM_TOKENS}"
)
PERCEIVER_NUM_LATENTS = 64
PERCEIVER_DEPTH = 2
IMAGE_ENCODER = "clip"
CLASSIFICATION_MODE = False
NUM_CLASSES = 332  # classification head size used for ImageCLEF
FLAMINGO_MODE = True
# Label smoothing for the classification task.
LABEL_SMOOTHING = 0.5
# Label smoothing for token prediction (presumably the generation loss).
TOKEN_LABEL_SMOOTHING = 0.0
GRADIENT_CLIP_VAL = 1  # NOTE(review): not passed to the model or used in the cells shown
LEARNING_RATE = 1e-4
USE_IMAGE_EMBEDDINGS = True
TRAIN_EMBEDDING_LAYER = True
CLASSIFIER_DROPOUT = 0.5


hyperparams = {
    'pretrained_clip_path': PRETRAINED_CLIP_PATH,
    'warmup_steps': 30,
    'num_tokens': NUM_TOKENS,
    'dim': FLAMINGO_EMBED_DIM,
    'depth': DEPTH,
    'num_heads': NUM_HEADS,
    'dim_head': ATT_HEAD_DIM,
    'cross_attn_every': CROOS_ATT_EVERY,
    'media_token_id': MEDIA_TOKEN_ID,
    'perceiver_num_latents': PERCEIVER_NUM_LATENTS,
    'perceiver_depth': PERCEIVER_DEPTH,
    'image_encoder': IMAGE_ENCODER,
    'language_model': LANGUAGE_MODEL,
    'pretrained_gpt2_path': PRETRAINED_GPT2_PATH,
    'classification_mode': CLASSIFICATION_MODE,
    'classification_num_classes': NUM_CLASSES,  # 332 if DATASET=="IMAGECLEF"
    'flamingo_mode': FLAMINGO_MODE,
    "label_smoothing": LABEL_SMOOTHING,
    "token_label_smoothing": TOKEN_LABEL_SMOOTHING,
    "learning_rate": LEARNING_RATE,
    "use_image_embeddings": USE_IMAGE_EMBEDDINGS,
    "train_embedding_layer": TRAIN_EMBEDDING_LAYER,
    "classifier_dropout": CLASSIFIER_DROPOUT,
}

print_hyperparams(hyperparams)

model = FlamingoModule(**hyperparams)
START_FROM_CHECKPOINT = True  # NOTE(review): not referenced in the cells shown



INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /var/folders/61/9c2llh9n2pjb81c4dmhmb67w0000gn/T/tmpe0ztr82p
INFO:torch.distributed.nn.jit.instantiator:Writing /var/folders/61/9c2llh9n2pjb81c4dmhmb67w0000gn/T/tmpe0ztr82p/_remote_module_non_sriptable.py


pretrained_clip_path /Users/caghankoksal/Desktop/development/PubMedCLIP_ViT32.pth
warmup_steps 30
num_tokens 50261
dim 768
depth 12
num_heads 8
dim_head 64
cross_attn_every 3
media_token_id 50258
perceiver_num_latents 64
perceiver_depth 2
image_encoder clip
language_model gpt2
pretrained_gpt2_path /Users/caghankoksal/Desktop/development/TransformerPlay/gpt2-pytorch_model.bin
classification_mode False
classification_num_classes 332
flamingo_mode True
label_smoothing 0.5
token_label_smoothing 0.0
learning_rate 0.0001
use_image_embeddings True
train_embedding_layer True
classifier_dropout 0.5
Clip architecture is being loaded
Clip pretrained weights are being loaded
Flamingo is being initialized with  gpt2  as language model
GPT 2 Weights are loading...
Loaded GPT2 weights and Embeddings num_weights loaded :  156


In [5]:
# Prefer the newer ImageCLEF fine-tuning run (version_110) over the per-machine default.
# That checkpoint lives under the laptop's local lightning_logs, so only override there.
# Other machines keep their configured path instead of one that presumably does not exist.
if os.getcwd().startswith('/Users/caghankoksal'):
    IMAGECLEF_CHECKPOINT_PATH = "/Users/caghankoksal/Desktop/SS2022/lightning_logs/version_110/checkpoints/last.ckpt"

In [6]:
# Optionally load fine-tuned ImageCLEF weights into `model`.
if LOAD_TRAINED_IMAGECLEF:
    print("Pretrained Flamingo Model is loaded from checkpoint : ", IMAGECLEF_CHECKPOINT_PATH)
    # On the GPU server keep tensors where they were saved (map_location=None is torch.load's default).
    # On any other machine, remap every tensor to the CPU.
    _map_location = None if os.getcwd().startswith('/home/mlmi-matthias') else torch.device('cpu')
    _load_result = model.load_state_dict(
        torch.load(IMAGECLEF_CHECKPOINT_PATH, map_location=_map_location)["state_dict"],
        strict=False,
    )
    # strict=False tolerates key mismatches between checkpoint and model.
    # Report them instead of silently running with partially-initialized weights.
    if _load_result.missing_keys:
        print(f"{len(_load_result.missing_keys)} model keys missing from checkpoint, e.g.",
              _load_result.missing_keys[:5])
    if _load_result.unexpected_keys:
        print(f"{len(_load_result.unexpected_keys)} checkpoint keys not used by the model, e.g.",
              _load_result.unexpected_keys[:5])
    print("Checkpoint Weights are loaded")

Pretrained Flamingo Model is loaded from checkpoint :  /Users/caghankoksal/Desktop/SS2022/lightning_logs/version_110/checkpoints/last.ckpt
Checkpoint Weights are loaded


In [7]:

def generate_gradio(image, context, cur_model, ntok=20):
    """Greedily extend ``context`` by ``ntok`` tokens conditioned on ``image``.

    Args:
        image: Image input passed to the model under the ``'image'`` key
            (predict_gradio passes a batch of one transformed image).
        context: Integer tensor of prompt token ids, shape (batch, seq_len).
        cur_model: Callable taking ``{'image': ..., 'input_ids': ...}`` and returning
            logits, assumed to be (batch, seq_len, vocab).
        ntok: Number of tokens to append.

    Returns:
        ``context`` with ``ntok`` greedily chosen token ids appended on the last dim.

    If ``cur_model`` is an ``nn.Module``, it is run in eval mode (dropout off).
    Its previous train/eval mode is restored afterwards.
    """
    is_module = isinstance(cur_model, torch.nn.Module)
    was_training = is_module and cur_model.training
    if is_module:
        cur_model.eval()
    try:
        # Inference only: no autograd graph is needed across the generation loop.
        with torch.no_grad():
            for _ in range(ntok):
                logits = cur_model({'image': image, 'input_ids': context})
                # Greedy decoding. The former top-k masking and softmax cannot change the
                # argmax, so they were dropped (np.NINF no longer exists in NumPy 2.0).
                next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
                context = torch.cat([context, next_tok], dim=-1)
    finally:
        if is_module:
            cur_model.train(was_training)
    return context


# Tokenizer of the training dataset; predict_gradio uses it to encode prompts and decode generated ids.
tokenizer = datamodule.train_dataset.tokenizer

In [8]:
def predict_gradio(image, question):
    """Answer ``question`` about ``image`` with the loaded Flamingo model.

    Gradio callback: receives a PIL image (the Image input uses ``type='pil'``) and the
    question text. Returns the decoded answer text, cut at the first ``<EOC>``.
    """
    if image is None:
        return "Please upload an image."
    # Uploads may be grayscale or RGBA; the 'val' Normalize uses three channel
    # statistics, so force a 3-channel RGB image first.
    process_img = augmentations["val"](image.convert("RGB")).unsqueeze(0)
    prompt = "<|endoftext|> <image> question: " + (question or "") + ' <EOQ>' + ' answer:'
    context = torch.tensor([tokenizer.encode(prompt)])
    out = generate_gradio(process_img, context, model, ntok=20)
    # Decode only the generated continuation. Splitting the full text on 'answer:'
    # would break when the question itself contains that string.
    generated = tokenizer.decode(out[0, context.shape[1]:])
    return generated.split('<EOC>')[0].strip()

    

In [9]:
import gradio as gr




# Gradio front end: an image upload plus a free-text question, answered by predict_gradio.
title = "Visual_Question_Answering on Medical Data ImageCLEF"
description = "Gradio Demo for Medical Visual_Question_Answering. Upload your own image (high-resolution images are recommended) or click any one of the examples, and click " \
              "\"Submit\" and then wait for the model's answer. "
# NOTE(review): `article` is not passed to gr.Interface below, so it is never displayed.
# It also links to the OFA repository rather than this project's; fix the link before wiring it in.
article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
          "Repo</a></p> "
# Example image paths are relative to the notebook's working directory.
examples = [['demo_images/synpic16279.jpg', 'what is abnormal in the x-ray?'],
            ['demo_images/synpic17959.jpg', 'what is most alarming about this mri?']]
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces (removed in Gradio 4);
# migrate to gr.Image / gr.Textbox once the installed Gradio version is confirmed.
# NOTE(review): `io` shadows the standard-library module name.
io = gr.Interface(fn=predict_gradio,
                  inputs=[gr.inputs.Image(type='pil'), "textbox"],
                  outputs=gr.outputs.Textbox(label="Answer"),
                  examples=examples, title=title, description=description)
# `debug` is a launch() option, not an Interface option.
# share=True also publishes the demo through a temporary public gradio.app link.
io.launch(share=True, debug=False)

Running on local URL:  http://127.0.0.1:7884/


INFO:paramiko.transport:Connected (version 2.0, client OpenSSH_7.6p1)
INFO:paramiko.transport:Authentication (publickey) successful!


Running on public URL: https://45285.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://huggingface.co/spaces)


(<gradio.routes.App at 0x29d6f6ee0>,
 'http://127.0.0.1:7884/',
 'https://45285.gradio.app')