In [21]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torch.nn.functional as F
import pytorch_lightning as pl

import torchxrayvision as xrv
from torchvision import transforms

import nltk
from nltk.tokenize import word_tokenize
# import tensorflow as tf
# from tensorflow.keras.preprocessing.text import Tokenizer

from PIL import Image

from typing import Optional
from pathlib import Path, PurePath
import pandas as pd

In [22]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:1


Create Dataset

In [23]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [24]:
class VqaRadDataset(Dataset):
    def __init__(self, root, split:str = 'train',
                 transform=None, tokenizer=None):
        
        self.root = root
        self.transform = transform
        
        assert split in ['train', 'val', 'test']
        self.split = split
        if split == "train":
            self.annotations = pd.read_csv(os.path.join(root,"vqa_rad_train.csv"))
        elif split == "val":
            self.annotations = pd.read_csv(os.path.join(root,"vqa_rad_valid.csv"))
        elif split == "test":
            self.annotations = pd.read_csv(os.path.join(root,"vqa_rad_test.csv"))

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        cur_ann = self.annotations.iloc[idx]
        image_name = cur_ann["img_id"]
        answer = cur_ann["answer"]
        mode = cur_ann["mode"]
        question = cur_ann["question"]
        
        #####
        ## Image Processing
        #####
        # Img path of the given question
        cur_image_path = PurePath(self.root, "imgs", image_name).as_posix()
        
        img = Image.open(cur_image_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)
            
        #####
        ## Text Processing
        #####
        question_tokens = word_tokenize(question)
        answer_tokens = word_tokenize(answer)
        tokens = np.concatenate([["<BOS>"],["<image>"],["Question:"],question_tokens,["Answer:"],answer_tokens,["<EOC>"]]).ravel()
        
        #text = f"<BOS> <image> Question: {question} Answer: {answer} <EOC>"
        #tokens = word_tokenize(text)
        
        sample = {'image': img, 'text': tokens}
        return sample

Create Classifier

In [25]:
import numpy as np
import matplotlib.pyplot as plt
from flamingo_pytorch import FlamingoPaLM

In [26]:
from torch.utils.tensorboard import SummaryWriter

class FlamingoModule(pl.LightningModule):
    def __init__(self, image_encoder, target):
        super().__init__()
        self.image_encoder = image_encoder
        self.model = FlamingoPaLM(
                        num_tokens = 15185,          # number of tokens
                        dim = 18,                  # dimensions
                        depth = 12,                  # depth
                        heads = 8,                   # attention heads
                        dim_head = 64,               # dimension per attention head
                        img_encoder = image_encoder, # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
                        media_token_id = 3,          # the token id representing the [media] or [image]
                        cross_attn_every = 3,        # how often to cross attend
                        perceiver_num_latents = 16,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
                        perceiver_depth = 2          # perceiver resampler depth
                    )

        # TODO DEFINE LOSS cross entropy?
        self.loss = torch.nn.CrossEntropyLoss()
        # elif:
            # self.loss =
        
        self.target = target

        self.train_preds = []
        self.train_gts = []
        self.val_preds = []
        self.val_gts = []
        self.test_preds = []
        self.test_gts = []
        self.reset_metrics()

        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

        self.writer = SummaryWriter()

    def training_step(self, batch, batch_idx):

        if self.target == 'palm':
            text, y = batch
            logits = self.model(
                text = text,
            )
        elif self.target == 'flamingo':
            text, image, y = batch
            logits = self.model(
                text = text,
                images = image
            )
        else:
            raise NotImplementedError()

        loss = self.loss(logits.float(), y)
        self.log('train/loss', loss, on_step=False, on_epoch=True)
        return {'loss': loss}


    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.

        if self.target == 'palm':
            text, y = batch
            logits = self.model(
                text = text,
            )
        elif self.target == 'flamingo':
            text, image, y = batch
            logits = self.model(
                text = text,
                images = image
            )
        else:
            raise NotImplementedError()
        loss = self.loss(logits.squeeze(), y)
        self.update_metrics(text, logits, split='val')
        self.val_loss.append(loss.item())
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        if self.target == 'palm':
            text, y = batch
            logits = self.model(
                text = text,
            )
        elif self.target == 'flamingo':
            text, image, y = batch
            logits = self.model(
                text = text,
                images = image
            )
        else:
            raise NotImplementedError()
        loss = self.loss(logits.squeeze(), y)
        self.update_metrics(text, logits, split='test')
        self.test_loss.append(loss.item())
        return {'loss': loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer

    def reset_metrics(self, split=None):
        if split == 'train':
            self.train_preds = []
            self.train_gts = []
        elif split == 'val':
            self.val_preds = []
            self.val_gts = []
        elif split == 'test':
            self.test_preds = []
            self.test_gts = []
        else:
            self.train_preds = []
            self.train_gts = []
            self.val_preds = []
            self.val_gts = []
            self.test_preds = []
            self.test_gts = []

    def update_metrics(self, gt, pred, split='train'):
        if split == 'train':
            self.train_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.train_gts.extend(gt.detach().cpu().numpy())
        elif split == 'val':
            self.val_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.val_gts.extend(gt.detach().cpu().numpy())
        elif split == 'test':
            self.test_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.test_gts.extend(gt.detach().cpu().numpy())
        else:
            raise NotImplementedError()

    def training_epoch_end(self, output):
        loss = 0
        for o in output:
            loss +=  o['loss'].item()
        loss = loss / len(output)
        self.writer.add_scalar('Epoch_loss/train', loss, self.current_epoch)
        self.reset_metrics(split='train')

    def validation_epoch_end(self, output):
        loss = 0
        for o in output:
            loss +=  o['loss'].item()
        loss = loss / len(output)
        self.log('val_loss', loss)
        self.writer.add_scalar('Epoch_loss/validation', loss, self.current_epoch)
        self.reset_metrics(split='val')
    
    def test_epoch_end(self, output):
        loss = 0
        for o in output:
            loss +=  o['loss'].item()
        loss = loss / len(output)
        self.writer.add_scalar('Epoch_loss/test', loss, self.current_epoch)
        self.reset_metrics(split='test')

### Load Data

In [31]:
path = PurePath(r"D:\Dev\PythonProjects\Project\emic-vqa\data\external\vqa_rad")
path

PureWindowsPath('D:/Dev/PythonProjects/Project/emic-vqa/data/external/vqa_rad')

In [32]:
root = path.as_posix()
print(root)

D:/Dev/PythonProjects/Project/emic-vqa/data/external/vqa_rad


In [33]:
def transform(img):
    img = xrv.datasets.normalize(img, 255) 

    # Check that images are 2D arrays
    if len(img.shape) > 2:
        img = img[:, :, 0]
    if len(img.shape) < 2:
        print("error, dimension lower than 2 for image")

    # Add color channel
    img = img[None, :, :]

    transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),
                                                xrv.datasets.XRayResizer(224)])

    img = transform(img)

In [41]:
pil_transform = transforms.Compose([xrv.datasets.XRayCenterCrop(),
                                     xrv.datasets.XRayResizer(224),])

# train_dataset = VqaRadDataset(root, split='train',transform=pil_transform)
# val_dataset = VqaRadDataset(root, split='val',transform=pil_transform)
# test_dataset = VqaRadDataset(root, split='test',transform=pil_transform)

train_dataset = VqaRadDataset(root, split='train')
val_dataset = VqaRadDataset(root, split='val')
test_dataset = VqaRadDataset(root, split='test')

In [42]:
print('Train Dataset Size:', len(train_dataset))
print('Validation Dataset Size:', len(val_dataset))
print('Test Dataset Size:', len(test_dataset))

Train Dataset Size: 1797
Validation Dataset Size: 226
Test Dataset Size: 225


In [43]:
sample = train_dataset[3]
print(sample)

{'image': tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039],
         [0.0039, 0.0039, 0.0039,  ..., 0.0039, 0.0039, 0.0039]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 

# Data Processing

In [44]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4)


image_encoder = xrv.models.DenseNet(weights="densenet121-res224-all")
model = FlamingoModule(image_encoder, target='flamingo')

Downloading weights...
If this fails you can run `wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt -O C:\Users\alaed\.torchxrayvision\models_data\nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt`
[██████████████████████████████████████████████████]


TypeError: Can't instantiate abstract class FlamingoModule with abstract methods forward

Train PaLM

In [6]:
trainer = pl.Trainer(gpus=[1], max_epochs=100, num_sanity_val_steps=0) 
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type             | Params
---------------------------------------------------
0 | image_encoder | DenseNet         | 7.0 M 
1 | model         | FlamingoPaLM     | 7.8 M 
2 | loss          | CrossEntropyLoss | 0     
---------------------------------------------------
806 K     Trainable params
7.0 M     Non-trainable params
7.8 M     Total params
31.091    Total estimated model params size (MB)


Epoch 2:   1%|          | 12/1147 [00:56<1:28:21,  4.67s/it, loss=0.186, v_num=3]  

In [None]:
loss_values = model.val_loss
plt.plot(loss_values)
plt.title("Validation loss")
plt.show()

In [None]:
# save language model
torch.save(model.state_dict(), 'PaLM_mimic_BPETokenizer_bs32_Adam.pt')

Create Training Setup For Flamingo + PaLM Training

In [None]:
pil_transform = transforms.Compose([xrv.datasets.XRayCenterCrop(),
                                    xrv.datasets.XRayResizer(224),])
                                    
train_dataset = MimicDatset(target='flamingo', split='train', pil_transform=pil_transform)
val_dataset = MimicDatset(target='flamingo', split='val', pil_transform=pil_transform)
test_dataset = MimicDatset(target='flamingo', split='test', pil_transform=pil_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)

Test Model

In [None]:
trainer.test(model, dataloaders=test_loader)