In [2]:
import torch
import torch.nn.functional as F
import torchvision
import pytorch_lightning as pl
import torchxrayvision as xrv
import numpy as np
import pydicom as dicom
import csv
import pickle as pkl

from torch import nn
from typing import Optional
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pathlib import Path

Create Dataset

In [3]:
# Run this only once to create the data dictionaries, DATA SET SPECIFIC
import multiprocessing as mp
from tqdm import tqdm
from multiprocessing import Process, Manager

dataset_path = Path('/mnt/209C31C29C3192F0/Datasets/Mimic-CXR/physionet.org/files/mimic-cxr/2.0.0/')

# be careful here, loading both amounts to ~10GB of memory
image_list = list(csv.DictReader(open(dataset_path / 'cxr-record-list.csv', 'r'), delimiter=','))         # list of dictionaries
report_list = list(csv.DictReader(open(dataset_path / 'cxr-study-list.csv', 'r'), delimiter=','))           # list of dictionaries

dataset_length = len(image_list)

n_proc = 6
offset = 0

chunksize = dataset_length // n_proc
proc_slices = []

for i_proc in range(n_proc):
        chunkstart = int(offset + (i_proc * chunksize))
        # make sure to include the division remainder for the last process
        chunkend = int(offset + (i_proc + 1) * chunksize) if i_proc < n_proc - 1 else int(offset + dataset_length)
        proc_slices.append(np.s_[chunkstart:chunkend])

print(f'Number of slices: {len(proc_slices)}\n{proc_slices}')

def process(data, slice, rank):                    # split it up into slices
    # preprocess reports and images
    # iterate through images and find corresponding report
    for image in tqdm(image_list[slice]):
        image_path = image['path']
        id = image['study_id']
        # find corresponding report
        report_path = None
        for report in report_list:
            if id in report['study_id']:
                report_path = report['path']
        
        if report is None:
                print(f'Not found report for image with path {image_path}, study id: {id} and dicom')
                continue

        entry = {'subject_id': report['subject_id'],
                'study_id': id,
                'dicom_id': image['dicom_id'],
                'report_path': report_path,
                'image_path': image_path,
                }

        data.append(entry)


data = Manager().list()
processes = []
for i in range(n_proc):
        p = Process(target=process, args=(data, proc_slices[i], i))  # Passing the list
        p.start()
        processes.append(p)
for p in processes:
        p.join()


data_list = list(data)
print(len(data_list))
with open(dataset_path / 'images2reports.pkl', 'wb') as f:
        pkl.dump(data_list, f, pkl.DEFAULT_PROTOCOL)

Number of slices: 6
[slice(0, 62851, None), slice(62851, 125702, None), slice(125702, 188553, None), slice(188553, 251404, None), slice(251404, 314255, None), slice(314255, 377110, None)]


  0%|          | 87/62855 [00:02<34:52, 30.00it/s]


KeyboardInterrupt: 





Process Process-3:
  0%|          | 93/62851 [00:02<32:35, 32.10it/s]Process Process-6:
Process Process-4:
Process Process-2:
Traceback (most recent call last):

Process Process-7:
Traceback (most recent call last):
  File "/home/andrei/anaconda3/envs/mlmi/lib/python3.10/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/andrei/anaconda3/envs/mlmi/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/andrei/anaconda3/envs/mlmi/lib/python3.10/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/andrei/anaconda3/envs/mlmi/lib/python3.10/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Process Process-5:
  File "/tmp/ipykernel_10782/3684903847.py", line 37, in process
    if id in report['study_id']:
  File "/home/andrei/anaconda3/envs/ml

In [4]:
# visualize data
dataset_path = Path('/mnt/209C31C29C3192F0/Datasets/Mimic-CXR/physionet.org/files/mimic-cxr/2.0.0/')
image_list = list(csv.DictReader(open(dataset_path / 'cxr-record-list.csv', 'r'), delimiter=','))         # list of dictionaries
report_list = list(csv.DictReader(open(dataset_path / 'cxr-study-list.csv', 'r'), delimiter=','))           # list of dictionaries

print(image_list[0])
print(report_list[0])

{'subject_id': '10000032', 'study_id': '50414267', 'dicom_id': '02aa804e-bde0afdd-112c0b34-7bc16630-4e384014', 'path': 'files/p10/p10000032/s50414267/02aa804e-bde0afdd-112c0b34-7bc16630-4e384014.dcm'}
{'subject_id': '10000032', 'study_id': '50414267', 'path': 'files/p10/p10000032/s50414267.txt'}


In [10]:
from transformers import PreTrainedTokenizerFast, AutoTokenizer
import pydicom as dicom

class MimicDatset(Dataset):
    def __init__(self, target:str = 'palm', split:str = 'train', pil_transform: Optional[transforms.Compose] = None, tensor_transform: Optional[transforms.Compose] = None):
        self.target = target
        assert target in ['palm', 'flamingo']

        self.split = split
        assert split in ['train', 'val', 'test']

        self.pil_transform = pil_transform
        self.tensor_transform = tensor_transform
        
        self.dataset_path = Path('/mnt/209C31C29C3192F0/Datasets/Mimic-CXR/physionet.org/files/mimic-cxr/2.0.0/')

        with open(dataset_path / 'images2reports.pkl', 'rb') as f:
            data_list_pkl = pkl.load(f)

        self.data = data_list_pkl
        del data_list_pkl
        self.dataset_length = len(self.data)

        # self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(self.data_dir / 'mimic_tokenizer.json'), pad_token='<pad>')
        
    def __len__(self):                  # TODO figure this out, for now only limited data, TODO dynamic loading
        if self.target == 'palm':
            # return self.dataset_length
            return 10
        else:
            # return self.dataset_length
            return 2000


    def __getitem__(self, index):
        # don't forget about the target

        # get item and convert text into tokens
        data_sample = self.data[index]

        with open(self.dataset_path / data_sample['report_path'], 'r') as f:
            unfiltered_report = f.readlines()

        rep = ''.join(unfiltered_report)
        report = rep[(rep.find('FINDINGS:') + 9):rep.find('IMPRESSION:')].replace('\n','')

        # tokenize report ? 
        # report_tokenized = np.array(self.tokenizer(report))

        if self.target == 'palm':
            return torch.from_numpy(report).to(self.device)
        
        else:
            # get corresponding images
            image = np.array(dicom.dcmread(self.dataset_path / data_sample['image_path']).pixel_array[None, :, :])
            image = self.pil_transform(image)

            # maybe normalize?
            # xrv.datasets.normalize(image,maxval=np.max(image))

            return torch.from_numpy(report).to(self.device), torch.from_numpy(image).to(self.device)


Create Classifier

In [11]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from flamingo_pytorch import FlamingoPaLM
from transformers import BatchFeature, PreTrainedTokenizerFast, AutoTokenizer

In [12]:
class FlamingoModule(pl.LightningModule):
    def __init__(self, image_encoder, target):
        super().__init__()
        self.image_encoder = image_encoder
        self.model = FlamingoPaLM(
                        num_tokens = 20000,          # number of tokens
                        dim = 18,                    # dimensions
                        depth = 12,                  # depth
                        heads = 8,                   # attention heads
                        dim_head = 64,               # dimension per attention head
                        img_encoder = self.image_encoder,           # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
                        media_token_id = 3,          # the token id representing the [media] or [image]
                        cross_attn_every = 3,        # how often to cross attend
                        perceiver_num_latents = 16,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
                        perceiver_depth = 2          # perceiver resampler depth
                    )

        # TODO DEFINE LOSS cross entropy? 
        self.loss = torch.nn.CrossEntropyLoss()
        self.target = target

        self.train_preds = []
        self.train_gts = []
        self.val_preds = []
        self.val_gts = []
        self.test_preds = []
        self.test_gts = []
        self.reset_metrics()

        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

    def training_step(self, batch, batch_idx):

        print(f'batch: {batch.shape}')

        model_output = None
        if self.target == 'palm':
            logits = self.model(
                text = batch.tokens.input_ids,
            )
        elif self.target == 'flamingo':
            logits = self.model(
                text = batch.tokens.input_ids,
                images = batch.images
            )
        else:
            raise NotImplementedError()

        print(f'model output: {logits.shape}')
        loss = self.loss(logits.squeeze(), y)
        self.log('train/loss', loss, on_step=False, on_epoch=True)
        return {'loss': loss, 'ACC': self.calculate_accuracy(logits).detach().cpu()}


    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.

        if self.target == 'palm':
            model_output = self.model(
                text = batch.tokens.input_ids,
            )
        elif self.target == 'flamingo':
            model_output = self.model(
                text = batch.tokens.input_ids,
                images = batch.images
            )
        else:
            raise NotImplementedError()
        loss = self.loss(y_hat.squeeze(), y)
        self.update_metrics(y, y_hat, split='val')
        self.val_loss.append(loss.item())
        return {'val_loss': loss}

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss(y_hat.squeeze(), y)
        self.update_metrics(y, y_hat, split='test')
        self.test_loss.append(loss.item())
        return {'test_loss': loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer

    def reset_metrics(self, split=None):
        if split == 'train':
            self.train_preds = []
            self.train_gts = []
        elif split == 'val':
            self.val_preds = []
            self.val_gts = []
        elif split == 'test':
            self.test_preds = []
            self.test_gts = []
        else:
            self.train_preds = []
            self.train_gts = []
            self.val_preds = []
            self.val_gts = []
            self.test_preds = []
            self.test_gts = []

    def update_metrics(self, gt, pred, split='train'):
        if split == 'train':
            self.train_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.train_gts.extend(gt.detach().cpu().numpy())
        elif split == 'val':
            self.val_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.val_gts.extend(gt.detach().cpu().numpy())
        elif split == 'test':
            self.test_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.test_gts.extend(gt.detach().cpu().numpy())
        else:
            raise NotImplementedError()

    def training_epoch_end(self, outputs):
        self.evaluate_predictions(split='train')
        self.reset_metrics(split='train')

    def validation_epoch_end(self, outputs):
        self.evaluate_predictions(split='val')
        self.reset_metrics(split='val')
    
    def test_epoch_end(self, outputs):
        self.evaluate_predictions(split='test')
        self.reset_metrics(split='test')

    def evaluate_predictions(self, split):
        if split == 'train':
            preds = self.train_preds
            gts = self.train_gts
        elif split == 'val':
            preds = self.val_preds
            gts = self.val_gts
        elif split == 'test':
            preds = self.test_preds
            gts = self.test_gts
        else:
            raise NotImplementedError()

        cls_report = classification_report(gts, preds)
        print(split)
        print(cls_report)


Create Training Setup For PaLM Training

In [13]:
pil_transform = transforms.Compose([xrv.datasets.XRayCenterCrop(),
                                    xrv.datasets.XRayResizer(224),])
train_dataset = MimicDatset(target='palm', split='train', pil_transform=pil_transform)
val_dataset = MimicDatset(target='palm', split='val', pil_transform=pil_transform)
test_dataset = MimicDatset(target='palm', split='test', pil_transform=pil_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)

image_encoder = xrv.models.DenseNet(weights="densenet121-res224-mimic_nb")
model = FlamingoModule(image_encoder, target='palm')                            # maybe change target

Train PaLM

In [14]:
trainer = pl.Trainer(gpus=[1], max_epochs=10, num_sanity_val_steps=0)                    
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/andrei/TUM/SoSe2022/MLMI/emic-vqa/notebooks/playground/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type             | Params
---------------------------------------------------
0 | image_encoder | DenseNet         | 7.0 M 
1 | model         | FlamingoPaLM     | 7.9 M 
2 | loss          | CrossEntropyLoss | 0     
---------------------------------------------------
893 K     Trainable params
7.0 M     Non-trainable params
7.9 M     Total params
31.438    Total estimated model params size (MB)


Epoch 0:   0%|          | 0/32 [00:00<?, ?it/s] 

FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/andrei/anaconda3/envs/mlmi/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/andrei/anaconda3/envs/mlmi/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/andrei/anaconda3/envs/mlmi/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_10782/2541399743.py", line 41, in __getitem__
    with open(self.dataset_path / data_sample['report_path'], 'r') as f:
FileNotFoundError: [Errno 2] No such file or directory: '/mnt/209C31C29C3192F0/Datasets/Mimic-CXR/physionet.org/files/mimic-cxr/2.0.0/files/p16/p16669959/s57012984.txt'


In [16]:
s = '/wholebrain/scratch/amancu/mergeError/Nodes/Trainings/SegSmall_betterDataSet_r3000_ConvPoint_SearchQuantized_betterDataset_Adam_StepLR_CrossEntropy_Classification/state_dict.pth'
import os
os.path.basename(os.path.dirname(s))

'SegSmall_betterDataSet_r3000_ConvPoint_SearchQuantized_betterDataset_Adam_StepLR_CrossEntropy_Classification'

Create Training Setup For Flamingo + PaLM Training

In [None]:
pil_transform = transforms.Compose([xrv.datasets.XRayCenterCrop(),
                                    xrv.datasets.XRayResizer(224),])
                                    
train_dataset = MimicDatset(target='flamingo', split='train', pil_transform=pil_transform)
val_dataset = MimicDatset(target='flamingo', split='val', pil_transform=pil_transform)
test_dataset = MimicDatset(target='flamingo', split='test', pil_transform=pil_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=4)

Test Model

In [None]:
trainer.test(model, dataloaders=test_loader)