In [None]:
# Import comet_ml at the top of your file
import comet_ml
from pytorch_lightning.loggers import CometLogger
import sys

# Make the parent directory importable (it presumably contains the `src`
# package imported by later cells — confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if '..' not in sys.path:
    sys.path.append('..')



In [None]:
import os
from src.datasets.mimic_cxr_dataset import MIMICCXRDataModule
from pytorch_lightning import Trainer, seed_everything
import torchvision.transforms as T
import torch
# Fix the global RNG seed once, up front, so runs are repeatable
# (covers numpy, torch, python.random and PYTHONHASHSEED).
seed_everything(seed=42)


In [None]:

# Per-split image pipelines, keyed by split name; the dict is passed to the
# data module as its `transforms` argument.
augmentations = {
    # Training: resize to 224x224, then stochastic colour / grayscale / blur
    # augmentation before conversion to a tensor.
    'train': T.Compose(
        [
            T.Resize((224, 224)),
            T.RandomApply([T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.GaussianBlur(kernel_size=9),
            T.ToTensor(),
        ]
    ),
    # Validation: deterministic preprocessing only. Random augmentation here
    # (ColorJitter / RandomGrayscale / GaussianBlur with its randomly sampled
    # sigma) would make validation metrics noisy and not comparable between
    # epochs or runs.
    'val': T.Compose(
        [
            T.Resize((224, 224)),
            T.ToTensor(),
        ]
    ),
}


In [None]:
# Run-size / data-loading settings for this notebook.
BATCH_SIZE: int = 32         # forwarded to the data module as batch_size
# NOTE(review): the Trainer in this notebook sets max_epochs=6 directly;
# confirm which epoch count is actually intended.
NUM_EPOCHS: int = 200
LIMIT_NUM_SAMPLES: int = 64  # forwarded to the data module as limit_num_samples
NUM_DATA_WORKERS: int = 0    # forwarded to the data module as num_data_workers
ONLY_IMAGES: bool = False    # counterpart of the data module's only_images option

In [None]:
# Root directory of the MIMIC-CXR files. The default is the original
# machine-specific location; set the MIMIC_CXR_DATA_PATH environment variable
# to run the notebook elsewhere without editing this cell.
data_path = os.environ.get(
    "MIMIC_CXR_DATA_PATH",
    '/Users/caghankoksal/Desktop/development/Flamingo-playground/physionet.org/files/mimic-cxr/2.0.0/files/',
)

# Build the data module from the notebook-level settings. `only_images` now
# follows the ONLY_IMAGES constant instead of a duplicated literal.
mimic_datamodule = MIMICCXRDataModule(
    data_path,
    transforms=augmentations,
    only_images=ONLY_IMAGES,
    batch_size=BATCH_SIZE,
    limit_num_samples=LIMIT_NUM_SAMPLES,
    num_data_workers=NUM_DATA_WORKERS,
    tokenizer="gpt2",
)

# Loaders handed to trainer.fit() later in the notebook.
train_loader = mimic_datamodule.train_dataloader()
val_loader = mimic_datamodule.val_dataloader()

In [None]:
# Size of the training dataset held by the data module.
len(mimic_datamodule.train_dataset)

In [None]:
# Length of the training loader (the number of batches per epoch, assuming it
# is a standard PyTorch DataLoader — TODO confirm).
len(train_loader)

In [None]:
# Sanity check: look at the first training batch only. Each batch is indexed
# by the keys "image", "text", "input_ids" and "targets".
for batch in train_loader:
    print("Image shape : ",batch["image"].shape)
    print(batch["text"][0]) # text of the first sample in the batch
    print("Input Ids shape : ",batch["input_ids"].shape)
    print("Targets shape : ",batch["targets"].shape)
    break  # stop after the first batch

In [None]:
from src.models.multimodal.flamingo_module import FlamingoModule
import pytorch_lightning as pl

In [None]:
# Tokenizer owned by the training dataset; looked up once and reused below.
tokenizer = mimic_datamodule.train_dataset.tokenizer

VOCAB_SIZE_OF_TOKENIZER = tokenizer.vocab_size
LANGUAGE_MODEL = 'gpt2'
# NOTE(review): the +3 presumably accounts for special tokens added on top of
# the base GPT-2 vocabulary, and 31092 for a different tokenizer — confirm.
NUM_TOKENS = VOCAB_SIZE_OF_TOKENIZER + 3 if LANGUAGE_MODEL == "gpt2" else 31092
FLAMINGO_EMBED_DIM = 768
DEPTH = 12
NUM_HEADS = 8
ATT_HEAD_DIM = 64
# Name kept as-is (including the spelling) because the model cell refers to it.
CROOS_ATT_EVERY = 3
# Id of the '<image>' special token: find its position among the special
# tokens, then read the id at the same position.
MEDIA_TOKEN_ID = tokenizer.all_special_ids[tokenizer.all_special_tokens.index('<image>')]
PERCEIVER_NUM_LATENTS = 64
PERCEIVER_DEPTH = 2
IMAGE_ENCODER = "clip"
# Checkpoint locations. Defaults are the original machine-specific paths; set
# the environment variables of the same name to override them without editing
# the notebook.
PRETRAINED_CLIP_PATH = os.environ.get(
    "PRETRAINED_CLIP_PATH",
    '/Users/caghankoksal/Desktop/development/PubMedCLIP_ViT32.pth',
)
PRETRAINED_GPT2_PATH = os.environ.get(
    "PRETRAINED_GPT2_PATH",
    "/Users/caghankoksal/Desktop/development/TransformerPlay/gpt2-pytorch_model.bin",
)

# Echo the configuration. Adjacent string literals ("\n" "NAME : ") are joined
# by Python, so every setting starts on its own line.
print("LANGUAGE_MODEL : ",LANGUAGE_MODEL, "\n"
        "NUM_TOKENS : ",NUM_TOKENS, "\n"
        "FLAMINGO_EMBED_DIM : ",FLAMINGO_EMBED_DIM, "\n"
        "DEPTH : ",DEPTH, "\n"
        "NUM_HEADS : ",NUM_HEADS, "\n"
        "ATT_HEAD_DIM : ",ATT_HEAD_DIM, "\n"
        "CROOS_ATT_EVERY : ",CROOS_ATT_EVERY, "\n"
        "MEDIA_TOKEN_ID : ",MEDIA_TOKEN_ID, "\n"
        "PERCEIVER_NUM_LATENTS : ",PERCEIVER_NUM_LATENTS, "\n"
        "PERCEIVER_DEPTH : ",PERCEIVER_DEPTH, "\n"
        "IMAGE_ENCODER : ",IMAGE_ENCODER, "\n"
        "PRETRAINED_CLIP_PATH : ",PRETRAINED_CLIP_PATH, "\n"
        "PRETRAINED_GPT2_PATH : ",PRETRAINED_GPT2_PATH, "\n")

In [None]:
# Instantiate the Flamingo Lightning module from the hyperparameters defined
# in the configuration cell.
# NOTE(review): total_steps is a literal 100 here; confirm it matches the
# number of optimisation steps the Trainer will actually run.
model = FlamingoModule(
    pretrained_clip_path=PRETRAINED_CLIP_PATH,
    total_steps=100,
    num_tokens=NUM_TOKENS,
    dim=FLAMINGO_EMBED_DIM,
    depth=DEPTH,
    heads=NUM_HEADS,
    dim_head=ATT_HEAD_DIM,
    media_token_id=MEDIA_TOKEN_ID,
    cross_attn_every=CROOS_ATT_EVERY,
    perceiver_num_latents=PERCEIVER_NUM_LATENTS,
    perceiver_depth=PERCEIVER_DEPTH,
    image_encoder=IMAGE_ENCODER,
    language_model=LANGUAGE_MODEL,
    pretrained_gpt2_path=PRETRAINED_GPT2_PATH,
)



In [None]:
import os
from pytorch_lightning.loggers import CometLogger

# Credentials must not live in the notebook: read the Comet API key from the
# COMET_API_KEY environment variable. The key that was previously hard-coded
# in this cell should be revoked and rotated.
# If the variable is unset this is None; CometLogger presumably then falls
# back to Comet's own configuration or raises — confirm for the installed
# version.
COMET_API_KEY = os.environ.get("COMET_API_KEY")

# NOTE(review): PROJECT_KEY is not passed to the logger below, which logs to
# the "flamingo-gpt2" project — confirm which project name is intended.
# (A stray trailing comma previously made both constants one-element tuples.)
PROJECT_KEY = "flamingo-playground"

comet_logger = CometLogger(
    api_key=COMET_API_KEY,
    project_name="flamingo-gpt2",
)


In [None]:
from pytorch_lightning.callbacks import LearningRateMonitor
# Learning-rate monitoring callback, configured with logging_interval='step';
# it is passed to the Trainer through its `callbacks` list.
lr_monitor = LearningRateMonitor(logging_interval='step')

In [None]:
from pytorch_lightning import loggers as pl_loggers

# TensorBoard logger with save_dir "pll_logs/" (a relative path); it is passed
# to the Trainer together with the Comet logger.
tb_logger = pl_loggers.TensorBoardLogger(save_dir="pll_logs/")



In [None]:
# Trainer for a short, deterministic, single-device CPU run that reports to
# both the TensorBoard and Comet loggers and logs on every step.
trainer = pl.Trainer(
    max_epochs=6,
    deterministic=True,
    accelerator="cpu",
    devices=1,
    logger=[tb_logger, comet_logger],
    callbacks=[lr_monitor],
    log_every_n_steps=1,
)



In [None]:
# Train the model, validating on val_loader, using the loaders built from the
# MIMIC-CXR data module.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Explicitly end the Comet experiment once training is done (presumably this
# finalises the run on Comet's side — confirm against the comet_ml docs).
comet_logger.experiment.end()

In [None]:
!tensorboard --logdir=pll_logs/

In [None]:

# Load a TorchXRayVision DenseNet using the "densenet121-res224-mimic_nb"
# weights identifier.
import torchxrayvision as xrv
# NOTE(review): `image_encoder` is not referenced by any other cell visible in
# this notebook; presumably an experiment with an alternative encoder.
image_encoder = xrv.models.DenseNet(weights="densenet121-res224-mimic_nb")

In [None]:
# Show the kernel's working directory; the relative paths used in this
# notebook ('..' on sys.path, "pll_logs/") are resolved against it.
os.getcwd()