In [1]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import torchxrayvision as xrv
import numpy as np
import pydicom as dicom
import pickle as pkl

from torch import nn
from typing import Optional
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


Create Dataset

In [17]:
import pydicom as dicom
from transformers import BatchFeature, PreTrainedTokenizerFast
import mpu.ml

class MimicDatset(Dataset):
    def __init__(self, target:str = 'palm', split:str = 'train', pil_transform: Optional[transforms.Compose] = None, tensor_transform: Optional[transforms.Compose] = None):
        self.target = target
        assert target in ['palm', 'flamingo']

        self.split = split
        assert split in ['train', 'val', 'test']

        self.pil_transform = pil_transform
        self.tensor_transform = tensor_transform
        
        self.dataset_path = Path('/mnt/209C31C29C3192F0/Datasets/Mimic-CXR/physionet.org/files/mimic-cxr/2.0.0/')

        with open(self.dataset_path / 'images2reports.pkl', 'rb') as f:
            data_list_pkl = pkl.load(f)

        self.data = data_list_pkl
        del data_list_pkl
        self.dataset_length = len(self.data)

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(self.dataset_path / 'tokenizer_mimic.json'), pad_token='[PAD]')

        # max length of text tokens
        self.max_length = 512
        
    def __len__(self):                  # TODO figure this out, for now only limited data, TODO dynamic loading
        
        if self.split == "palm":
            match self.split:
                case 'train':
                    return int(0.9 * self.dataset_length)
                case 'val':
                    return int(0.1 * self.dataset_length)
                case 'test':
                    return int(0.2 * self.dataset_length)
        else:
            match self.split:
                case 'train':
                    return int(0.8 * self.dataset_length)
                case 'val':
                    return int(0.2 * self.dataset_length)
                case 'test':
                    return int(0.2 * self.dataset_length)
        # return 10



    def __getitem__(self, index):
        # don't forget about the target

        # get item and convert text into tokens
        # data_sample = self.data[index]

        # because I don't have the entire dataset on my computer, find a sample to load without read error 
        while True:
            try:
                data_sample = self.data[index]
                with open(self.dataset_path / data_sample['report_path'], 'r') as f:
                    report_text = f.readlines()
                    
                # TODO make this faster with batch_encode_plus()
                report_tokenized = self.tokenizer.encode_plus(report_text[0], padding=True, truncation=True, max_length=self.max_length)
                batch_ids = report_tokenized['input_ids']

                # manually pad for batch stack
                batch_ids = [batch_ids + [self.tokenizer.pad_token_id] * (self.max_length - len(batch_ids))]
                # print(f'{batch_ids}')
                batch_one_hot = mpu.ml.indices2one_hot(batch_ids[0], nb_classes=15185)
                image = np.array(dicom.dcmread(self.dataset_path / data_sample['image_path']).pixel_array[None, :, :])
                break
            except:
                index = np.random.randint(0,self.dataset_length)


        batch_mask = report_tokenized['attention_mask']
        batch_type_ids = report_tokenized['token_type_ids']
        # print(f'report: {report_tokenized}')

        if self.target == 'palm':
            return torch.LongTensor(np.array(batch_ids)).squeeze(), torch.FloatTensor(np.array(batch_one_hot)).squeeze()
        
        else:
            # get corresponding images
            image = self.pil_transform(image)

            # maybe normalize?
            xrv.datasets.normalize(image, maxval=np.max(image))

            # add temporal channel to images
            return torch.LongTensor(np.array(batch_ids)).squeeze(), torch.from_numpy(image[None,:,:,:]), torch.FloatTensor(np.array(batch_one_hot)).squeeze()


Create Classifier

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from flamingo_pytorch import FlamingoPaLM

In [4]:
from torch.utils.tensorboard import SummaryWriter

class FlamingoModule(pl.LightningModule):
    def __init__(self, image_encoder, target):
        super().__init__()
        self.image_encoder = image_encoder
        self.model = FlamingoPaLM(
                        num_tokens = 15185,          # number of tokens
                        dim = 18,                  # dimensions
                        depth = 12,                  # depth
                        heads = 8,                   # attention heads
                        dim_head = 64,               # dimension per attention head
                        img_encoder = image_encoder, # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
                        media_token_id = 3,          # the token id representing the [media] or [image]
                        cross_attn_every = 3,        # how often to cross attend
                        perceiver_num_latents = 16,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
                        perceiver_depth = 2          # perceiver resampler depth
                    )

        # TODO DEFINE LOSS cross entropy?
        self.loss = torch.nn.CrossEntropyLoss()
        # elif:
            # self.loss =
        
        self.target = target

        self.train_preds = []
        self.train_gts = []
        self.val_preds = []
        self.val_gts = []
        self.test_preds = []
        self.test_gts = []
        self.reset_metrics()

        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

        self.writer = SummaryWriter()

    def training_step(self, batch, batch_idx):

        if self.target == 'palm':
            text, y = batch
            logits = self.model(
                text = text,
            )
        elif self.target == 'flamingo':
            text, image, y = batch
            logits = self.model(
                text = text,
                images = image
            )
        else:
            raise NotImplementedError()

        loss = self.loss(logits.float(), y)
        self.log('train/loss', loss, on_step=False, on_epoch=True)
        return {'loss': loss}


    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.

        if self.target == 'palm':
            text, y = batch
            logits = self.model(
                text = text,
            )
        elif self.target == 'flamingo':
            text, image, y = batch
            logits = self.model(
                text = text,
                images = image
            )
        else:
            raise NotImplementedError()
        loss = self.loss(logits.squeeze(), y)
        self.update_metrics(text, logits, split='val')
        self.val_loss.append(loss.item())
        return {'loss': loss}

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        if self.target == 'palm':
            text, y = batch
            logits = self.model(
                text = text,
            )
        elif self.target == 'flamingo':
            text, image, y = batch
            logits = self.model(
                text = text,
                images = image
            )
        else:
            raise NotImplementedError()
        loss = self.loss(logits.squeeze(), y)
        self.update_metrics(text, logits, split='test')
        self.test_loss.append(loss.item())
        return {'loss': loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer

    def reset_metrics(self, split=None):
        if split == 'train':
            self.train_preds = []
            self.train_gts = []
        elif split == 'val':
            self.val_preds = []
            self.val_gts = []
        elif split == 'test':
            self.test_preds = []
            self.test_gts = []
        else:
            self.train_preds = []
            self.train_gts = []
            self.val_preds = []
            self.val_gts = []
            self.test_preds = []
            self.test_gts = []

    def update_metrics(self, gt, pred, split='train'):
        if split == 'train':
            self.train_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.train_gts.extend(gt.detach().cpu().numpy())
        elif split == 'val':
            self.val_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.val_gts.extend(gt.detach().cpu().numpy())
        elif split == 'test':
            self.test_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.test_gts.extend(gt.detach().cpu().numpy())
        else:
            raise NotImplementedError()

    def training_epoch_end(self, output):
        loss = 0
        for o in output:
            loss +=  o['loss'].item()
        loss = loss / len(output)
        self.writer.add_scalar('Epoch_loss/train', loss, self.current_epoch)
        self.reset_metrics(split='train')

    def validation_epoch_end(self, output):
        loss = 0
        for o in output:
            loss +=  o['loss'].item()
        loss = loss / len(output)
        self.log('val_loss', loss)
        self.writer.add_scalar('Epoch_loss/validation', loss, self.current_epoch)
        self.reset_metrics(split='val')
    
    def test_epoch_end(self, output):
        loss = 0
        for o in output:
            loss +=  o['loss'].item()
        loss = loss / len(output)
        self.writer.add_scalar('Epoch_loss/test', loss, self.current_epoch)
        self.reset_metrics(split='test')

Create Training Setup For PaLM Training

In [5]:
pil_transform = transforms.Compose([xrv.datasets.XRayCenterCrop(),
                                    xrv.datasets.XRayResizer(224),])
train_dataset = MimicDatset(target='palm', split='train', pil_transform=pil_transform)
val_dataset = MimicDatset(target='palm', split='val', pil_transform=pil_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4)

image_encoder = xrv.models.DenseNet(weights="densenet121-res224-mimic_nb")
model = FlamingoModule(image_encoder, target='palm')                            # maybe change target

Train PaLM

In [12]:
trainer = pl.Trainer(gpus=[1], max_epochs=100, num_sanity_val_steps=0) 
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type             | Params
---------------------------------------------------
0 | image_encoder | DenseNet         | 7.0 M 
1 | model         | FlamingoPaLM     | 7.8 M 
2 | loss          | CrossEntropyLoss | 0     
---------------------------------------------------
806 K     Trainable params
7.0 M     Non-trainable params
7.8 M     Total params
31.091    Total estimated model params size (MB)


Epoch 6:  45%|████▌     | 520/1147 [30:30<36:47,  3.52s/it, loss=0.186, v_num=4]   

In [None]:
loss_values = model.val_loss
plt.plot(loss_values)
plt.title("Validation loss")
plt.show()

In [None]:
# save language model
torch.save(model.state_dict(), 'PaLM_mimic_BPETokenizer_bs32_Adam.pt')

Create Training Setup For Flamingo + PaLM Training

In [18]:
pil_transform = transforms.Compose([xrv.datasets.XRayCenterCrop(),
                                    xrv.datasets.XRayResizer(224),])
                                    
train_dataset = MimicDatset(target='flamingo', split='train', pil_transform=pil_transform)
val_dataset = MimicDatset(target='flamingo', split='val', pil_transform=pil_transform)
test_dataset = MimicDatset(target='flamingo', split='test', pil_transform=pil_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, pin_memory=True, num_workers=4)

In [19]:
image_encoder = xrv.models.DenseNet(weights="densenet121-res224-mimic_nb")
model = FlamingoModule(image_encoder, target='flamingo')                            # maybe change target
model.load_state_dict(torch.load('/home/andrei/mlmi/home/mlmi-matthias/Andrei/mlmi-vqa/models/PaLM_mimic_BPETokenizer_unfiltered_bs32_Adam.pt'))

<All keys matched successfully>

In [20]:
trainer = pl.Trainer(gpus=[1], max_epochs=100, num_sanity_val_steps=0) 
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/andrei/mlmi/home/mlmi-matthias/Andrei/mlmi-vqa/src/models/multimodal/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type             | Params
---------------------------------------------------
0 | image_encoder | DenseNet         | 7.0 M 
1 | model         | FlamingoPaLM     | 7.8 M 
2 | loss          | CrossEntropyLoss | 0     
---------------------------------------------------
806 K     Trainable params
7.0 M     Non-trainable params
7.8 M     Total params
31.091    Total estimated model params size (MB)


Epoch 0:   4%|▍         | 48/1147 [06:20<2:25:13,  7.93s/it, loss=0.186, v_num=0]

Test Model

In [None]:
trainer.test(model, dataloaders=test_loader)