In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import stanza

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision.models import resnet50, ResNet50_Weights
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torchvision import transforms
from PIL import Image


In [None]:
SEED = 6464

# Data / model locations (relative to this notebook).
DATASET_DIR = '../../../dataset'
MIMIC_JPG_DIR = DATASET_DIR + '/mimic-cxr-jpg-2.0.0-resized'
MODELS_DIR = '../../../models'

# Seed every RNG the notebook draws from: Python's `random`, NumPy's global state and torch.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

Device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# DataFrame column names.
SubjectId, StudyId, DicomId = 'subject_id', 'study_id', 'dicom_id'
# Keys of the tokenizer output dict, plus the label key.
InputIds, TokenTypeIds, AttentionMask = 'input_ids', 'token_type_ids', 'attention_mask'
Labels = 'labels'

In [None]:
from pathlib import Path

def jpg_exists(row):
    """Return True if the resized JPG referenced by ``row`` exists on disk.

    Args:
        row: any object exposing ``subject_id``, ``study_id`` and ``dicom_id``
            attributes (e.g. a DataFrame row).
    """
    # Reuse get_jpg_path so the existence check and the loader can never disagree on layout.
    return Path(get_jpg_path(row.subject_id, row.study_id, row.dicom_id)).exists()

def to_bin_to_dec(row):
    """Interpret a row of 0/1 flags as a binary number, first element most significant.

    Each value is truncated with ``int()`` and written as one digit; values are expected
    to be 0 or 1, otherwise the base-2 parse will generally fail.
    """
    digits = [str(int(flag)) for flag in row.tolist()]
    return int(''.join(digits), 2)

def get_xr_labels_reports():
    """Join each study's report text with a binary study-level label.

    Returns:
        DataFrame with columns subject_id, study_id, sentence (the stripped report text)
        and label (1 if any kept finding column of the label CSV is non-zero, else 0).
        Only studies present in both CSVs are kept, and only reports with more than
        3 whitespace-separated words; missing reports are dropped.
    """
    df_labels = pd.read_csv(f'{DATASET_DIR}/metadata/mimic-cxr-2.0.0-chexpert.csv')
    # -1 (presumably "uncertain") and missing values are both treated as negative.
    df_labels = df_labels.replace(-1.0, 0.0).fillna(0)
    # Drop the last column by position. NOTE(review): presumably 'Support Devices' — confirm.
    # .copy() so the column assignments below act on an independent frame.
    df_labels = df_labels.iloc[:, :-1].copy()
    # Assumes columns 0-1 are subject_id/study_id and the rest are finding flags — TODO confirm.
    df_labels['multi_label'] = df_labels.iloc[:, 2:].apply(to_bin_to_dec, axis=1)
    # NOTE(review): if 'No Finding' is among the flag columns, a positive 'No Finding'
    # also yields label 1 — verify this is intended.
    df_labels['label'] = df_labels.multi_label.apply(lambda x: 1 if x > 0 else 0)
    df_labels = df_labels[[SubjectId, StudyId, 'label']]

    df_reports = pd.read_csv(f'{DATASET_DIR}/processed/mimic-cxr-reports/reports.csv')
    df_reports['report'] = df_reports['report'].str.strip()
    df_reports = df_reports.sort_values([SubjectId, StudyId])
    # Positional rename: assumes the CSV columns are exactly subject_id, study_id, report.
    df_reports.columns = [SubjectId, StudyId, 'sentence']

    df = pd.merge(df_reports, df_labels, on=[SubjectId, StudyId], how='inner')
    # read_csv turns empty report cells into NaN. .str.len() keeps those as NaN and
    # NaN > 3 is False, so missing reports are dropped instead of crashing len() on a
    # float. Empty strings have 0 words and are excluded by the same test.
    n_words = df['sentence'].str.split().str.len()
    return df[n_words > 3]

def get_xr_dicom_splits():
    """Load the split metadata CSV, keeping only the subject/study/dicom ids and split name."""
    split_columns = [SubjectId, StudyId, DicomId, 'split']
    return pd.read_csv(f'{DATASET_DIR}/metadata/mimic-cxr-2.0.0-split.csv')[split_columns]

def get_xr_df():
    """Build the image/report table used by ContrastiveDataset.

    Returns:
        DataFrame with subject_id, study_id, sentence, label, dicom_id and split columns,
        one row per (report, DICOM) pair whose resized JPG exists on disk, with a fresh
        RangeIndex.
    """
    reports_df = get_xr_labels_reports()
    splits_df = get_xr_dicom_splits()
    # Left join pairs each report with every DICOM of its study. Studies without a split
    # entry get NaN dicom_id/split; their synthesized path presumably never exists, so
    # the filter below drops them.
    merged_df = pd.merge(reports_df, splits_df, how='left', on=[SubjectId, StudyId])
    # itertuples is far cheaper than a row-wise apply. jpg_exists only reads the
    # subject_id / study_id / dicom_id attributes, which the namedtuple rows provide.
    # Building the mask directly also avoids a temporary column dropped by position.
    found = pd.Series(
        [jpg_exists(row) for row in merged_df.itertuples(index=False)],
        index=merged_df.index,
        dtype=bool,
    )
    return merged_df[found].reset_index(drop=True)

def get_jpg_path(subject_id, study_id, dicom_id):
    """Return the path of one resized JPG under MIMIC_JPG_DIR.

    Layout: <MIMIC_JPG_DIR>/p<first two chars of subject_id>/p<subject_id>/s<study_id>/<dicom_id>.jpg
    """
    group_dir = 'p' + str(subject_id)[:2]
    return '/'.join([MIMIC_JPG_DIR, group_dir, f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'])

def random_crop(img):
    """Build (not apply) a RandomCrop transform sized to a random fraction of ``img``.

    A single factor drawn from ``random.uniform(0.6, 1.0)`` scales both sides, so the
    crop keeps the image's aspect ratio up to int truncation. ``img.size`` is read as
    (width, height); RandomCrop receives (height, width).
    """
    scale = random.uniform(0.6, 1.0)
    width, height = img.size
    crop_size = (int(height * scale), int(width * scale))
    return transforms.RandomCrop(crop_size)


In [None]:
# (351894, 2835, 4637)
# NOTE(review): presumably the train / validate / test row counts of get_xr_df() — confirm.
class ContrastiveDataset(Dataset):
    """Paired chest X-ray image / report-sentence samples for contrastive training.

    Every item is ``[image, tokenized, label]``:
    - ``image`` is a randomly augmented, single-channel 224x224 tensor;
    - ``tokenized`` is the ClinicalBERT tokenization of one randomly chosen sentence of
      the study's report;
    - ``label`` is the row's label as a ``torch.int`` tensor.
    """

    def __init__(self, df, mode='train', subset=None):
        """
        Args:
            df: frame with subject_id, study_id, dicom_id, sentence, label and split
                columns (as produced by get_xr_df).
            mode: 'train', 'validate', 'test', 'test_val' (test + validate rows) or 'all'.
            subset: None keeps every selected row; a value > 1 samples that many rows;
                any other value samples that fraction of rows.

        Raises:
            KeyError: if ``mode`` is not one of the values above.
        """
        valid_modes = ('train', 'validate', 'test', 'test_val', 'all')
        if mode not in valid_modes:
            raise KeyError(f'mode must be one of {valid_modes}, got {mode!r}')
        if mode == 'test_val':
            df = df[df.split.isin(['test', 'validate'])]
        elif mode != 'all':
            df = df[df.split == mode]
        if subset is not None:
            # subset > 1 is an absolute row count, otherwise a fraction of the rows.
            if subset > 1:
                df = df.sample(n=subset).reset_index(drop=True)
            else:
                df = df.sample(frac=subset).reset_index(drop=True)
        self.df = df[[SubjectId, StudyId, 'sentence', 'label', DicomId]]
        self.tokenizer = AutoTokenizer.from_pretrained(f'{MODELS_DIR}/lib/emilyalsentzer/ClinicalBERT')
        # stanza is only used to split reports into sentences.
        self.nlp = stanza.Pipeline(lang='en', processors='tokenize')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        jpg_path = get_jpg_path(row[SubjectId], row[StudyId], row[DicomId])
        report = row['sentence']
        label = row['label']

        # Close the file deterministically and force a single grayscale channel, which is
        # what the 1-input-channel conv1 of ContrastiveModel expects.
        with Image.open(jpg_path) as raw_img:
            img = raw_img.convert('L')

        # The crop size depends on this image's size, so the pipeline is built per item.
        transform = transforms.Compose([
            random_crop(img),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=20, translate=(0.1, 0.1), scale=(0.95, 1.05)),
            transforms.ColorJitter(brightness=(0.6, 1.4), contrast=(0.6, 1.4)),
            transforms.GaussianBlur(kernel_size=(5,5), sigma=(0.1, 3.0)),
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        img = transform(img)

        # One random sentence of the report is the text view of this sample.
        text = random.choice(self.nlp(report).sentences).text
        tokenized = self.tokenizer(text, padding='max_length', truncation=True, max_length=299, return_tensors='pt')

        return [img, tokenized, torch.tensor(label, dtype=torch.int)]

    @staticmethod
    def collate(batch):
        """Stack a list of items into ``(images, tokenized dict, labels)`` batch tensors.

        Each tokenized field is (1, seq) per item (single string, return_tensors='pt'),
        so stacking gives (B, 1, seq) and ``squeeze(1)`` yields (B, seq).
        """
        imgs, texts, labels = zip(*batch)

        input_ids = torch.stack([m[InputIds] for m in texts]).squeeze(1)
        token_type_ids = torch.stack([m[TokenTypeIds] for m in texts]).squeeze(1)
        attention_mask = torch.stack([m[AttentionMask] for m in texts]).squeeze(1)
        texts = {InputIds: input_ids, TokenTypeIds: token_type_ids, AttentionMask: attention_mask}

        return torch.stack(imgs), texts, torch.stack(labels)



In [None]:
class ContrastiveModel(nn.Module):
    """Image/text encoder pair trained with a bidirectional contrastive loss.

    - Image branch: ResNet-50 (1 input channel, fc -> ``encoding_size_d``), ReLU, then
      the linear projection ``gv``.
    - Text branch: ClinicalBERT with its classifier replaced by a linear layer to
      ``encoding_size_d``, ReLU, then the linear projection ``gu``. BERT embeddings and
      the first 6 encoder layers are frozen.

    ``forward`` returns the scalar loss, not the embeddings.

    Args:
        encoding_size_d: embedding size shared by both branches.
        tau: temperature dividing the cosine similarities.
        loss_weight: weight of the image->text term; the text->image term gets
            ``1 - loss_weight``.
    """

    def __init__(self, encoding_size_d=512, tau=0.1, loss_weight=0.75):
        super(ContrastiveModel, self).__init__()

        self.encoding_size_d = encoding_size_d
        self.tau = tau
        self.loss_weight = loss_weight

        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        pretrained_conv1 = self.resnet.conv1
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Keep the ImageNet filters instead of a random init: summing over the RGB input
        # channels gives exactly the response to a gray image replicated across 3 channels.
        with torch.no_grad():
            self.resnet.conv1.weight.copy_(pretrained_conv1.weight.sum(dim=1, keepdim=True))
        self.resnet.fc = nn.Linear(in_features=self.resnet.fc.in_features, out_features=self.encoding_size_d)
        self.gv = nn.Linear(in_features=self.encoding_size_d, out_features=self.encoding_size_d)

        self.cbert = AutoModelForSequenceClassification.from_pretrained(f'{MODELS_DIR}/lib/emilyalsentzer/ClinicalBERT')
        for p in self.cbert.bert.embeddings.parameters():
            p.requires_grad = False
        for i in range(6):
            for p in self.cbert.bert.encoder.layer[i].parameters():
                p.requires_grad = False
        self.cbert.classifier = nn.Linear(
            in_features=self.cbert.classifier.in_features,
            out_features=self.encoding_size_d
        )
        self.gu = nn.Linear(in_features=self.encoding_size_d, out_features=self.encoding_size_d)

    def forward(self, img, text):
        """Return the weighted bidirectional contrastive loss for a batch of pairs.

        Args:
            img: image batch for ``self.resnet``; row i is paired with text row i.
            text: keyword tensors for ``self.cbert`` (input_ids, token_type_ids, ...).

        Returns:
            0-dim tensor: the mean over the batch of
            ``loss_weight * l_vu + (1 - loss_weight) * l_uv``.
        """
        v = self.gv(F.relu(self.resnet(img)))
        u = self.gu(F.relu(self.cbert(**text).logits))

        # sim[i, k] = cos(v_i, u_k) / tau, shape (B, B). dim=-1 reduces over the embedding
        # axis of the broadcast (B, B, D) pair; the default dim=1 would reduce over the
        # batch axis and yield a meaningless (B, D) result.
        sim = F.cosine_similarity(v.unsqueeze(1), u.unsqueeze(0), dim=-1) / self.tau
        positives = sim.diagonal()

        # Image -> text: softmax over all texts for each image (rows of sim).
        loss_vu = torch.logsumexp(sim, dim=1) - positives
        # Text -> image: softmax over all images for each text (columns of sim).
        loss_uv = torch.logsumexp(sim, dim=0) - positives

        return (self.loss_weight * loss_vu + (1 - self.loss_weight) * loss_uv).mean()



In [None]:
class TrainingConfig:
    """Plain container for the hyper-parameters consumed by ``train``.

    Attributes mirror the constructor arguments one-to-one, except that
    ``weight_decay=None`` is stored as ``0.0``.
    """

    def __init__(
        self,
        train_sub=1.0,
        val_sub=1.0,
        test_sub=1.0,
        epochs=150,
        batch_size=32,
        lr=1e-4,
        weight_decay=1e-6,
        lr_sched_steps=5,
        lr_sched_gamma=0.5
    ):
        # Dataset subset sizes (see ContrastiveDataset's `subset` argument).
        self.train_sub, self.val_sub, self.test_sub = train_sub, val_sub, test_sub
        # Optimisation settings.
        self.epochs, self.batch_size = epochs, batch_size
        self.lr = lr
        if weight_decay is None:
            weight_decay = 0.0
        self.weight_decay = weight_decay
        # Plateau-based learning-rate decay settings.
        self.lr_sched_steps, self.lr_sched_gamma = lr_sched_steps, lr_sched_gamma

In [None]:
from tqdm.notebook import tqdm, trange

def train(config, batch_log_idx=10):
    """Train a ContrastiveModel, checkpointing on the best mean validation loss.

    Args:
        config: TrainingConfig with subset sizes, optimiser and LR-schedule settings.
        batch_log_idx: print the mean training loss every this many batches.

    Each epoch makes one pass over the (reshuffled) training subset and one gradient-free
    pass over the validation subset. Whenever the mean validation loss improves, model
    and optimizer state are saved to ``{MODELS_DIR}/checkpoints/checkpoint.pt``. Every
    ``config.lr_sched_steps`` epochs, if the latest validation loss is not below the one
    recorded ``lr_sched_steps - 1`` epochs earlier, the learning rate is multiplied by
    ``config.lr_sched_gamma``.
    """
    from pathlib import Path

    checkpoint_file = f'{MODELS_DIR}/checkpoints/checkpoint.pt'
    # Create the folder up front so torch.save cannot fail on it after a full epoch.
    Path(checkpoint_file).parent.mkdir(parents=True, exist_ok=True)

    model = ContrastiveModel()
    model.to(Device)
    optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    xr_df = get_xr_df()
    train_dataset = ContrastiveDataset(xr_df, subset=config.train_sub)
    # Reshuffle every epoch so the in-batch negatives change between epochs.
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=ContrastiveDataset.collate)
    # NOTE(review): validation draws from both the 'test' and 'validate' splits, so
    # test-split rows influence checkpoint selection.
    val_dataset = ContrastiveDataset(xr_df, mode='test_val', subset=config.val_sub)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, collate_fn=ContrastiveDataset.collate)
    # step_size=1: each manual step() below multiplies the LR by gamma once.
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=config.lr_sched_gamma)
    val_epoch_losses = []
    min_val_epoch_loss = float('inf')

    for epoch in trange(config.epochs, desc='Epoch'):
        model.train()
        running_loss = 0.0
        batch_loss = 0.0
        batch_pbar = tqdm(total=len(train_loader), desc=f'[{epoch}] Train', unit='batch')
        for i, (img, text, _labels) in enumerate(train_loader):
            img = img.to(Device)
            text = {k: vs.to(Device) for k, vs in text.items()}
            optimizer.zero_grad()
            loss = model(img, text)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            batch_loss += loss_value
            if (i + 1) % batch_log_idx == 0:
                print(f'Batch {i + 1}: train loss {(batch_loss)/batch_log_idx:.4f}')
                batch_loss = 0.0
            # Weight by batch size so the epoch mean is per-sample despite a short last batch.
            running_loss += loss_value * img.size(0)
            batch_pbar.update(1)
        batch_pbar.close()
        epoch_loss = running_loss / len(train_dataset)
        print(f'Epoch {epoch + 1}/{config.epochs}: train loss {epoch_loss:.4f}')

        # Validation: no autograd graph is needed, so skip building one.
        model.eval()
        val_losses = []
        val_pbar = tqdm(total=len(val_loader), desc=f'[{epoch}] Validation', unit='batch')
        with torch.no_grad():
            for img, text, _labels in val_loader:
                img = img.to(Device)
                text = {k: vs.to(Device) for k, vs in text.items()}
                val_losses.append(model(img, text).item())
                val_pbar.update(1)
        val_pbar.close()
        val_epoch_loss = sum(val_losses)/len(val_losses)
        print(f'Epoch {epoch + 1}: Val epoch loss: {val_epoch_loss:.4f}')
        val_epoch_losses.append(val_epoch_loss)
        if val_epoch_loss < min_val_epoch_loss:
            min_val_epoch_loss = val_epoch_loss
            print(f'Saving checkpoint with {val_epoch_loss=:.4f}')
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': val_epoch_loss
                },
                checkpoint_file
            )

        # Plateau check every `lr_sched_steps` epochs, using the same window for the
        # cadence and for the comparison.
        sched_steps = config.lr_sched_steps
        n_seen = len(val_epoch_losses)
        if sched_steps and n_seen > 1 \
                and n_seen % sched_steps == 0 \
                and val_epoch_losses[-1] >= val_epoch_losses[-sched_steps]:
            print(f'\tStepping learning rate to lr * {config.lr_sched_gamma} (current={optimizer.param_groups[0]["lr"]})')
            lr_scheduler.step()
            print(f'\tLearning rate now {optimizer.param_groups[0]["lr"]}')




In [None]:
# Sample 218,000 rows of the train split and 5,000 rows of test+validate for this run.
config = TrainingConfig(val_sub=5000, train_sub=218000)
train(config, batch_log_idx=100)