In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("raddar/chest-xrays-indiana-university")

print("Path to dataset files:", path)

Path to dataset files: C:\Users\v-mziadeh\.cache\kagglehub\datasets\raddar\chest-xrays-indiana-university\versions\2


In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(torch.__version__)

Using device: cuda
2.7.1+cu128


In [3]:
import pandas

dataset_folder = path
images_folder = dataset_folder + "/images/images_normalized"
projections = pandas.read_csv(dataset_folder + "/indiana_projections.csv")
reports = pandas.read_csv(dataset_folder + "/indiana_reports.csv")

combined_dataset = projections.merge(reports, on="uid", how="inner")

def IsNotAvailable(value):
    return value.str.contains("unavailable", case=False, na=False) \
        | value.str.contains("not available", case=False, na=False) \
        | value.str.contains("none", case=False, na=False)

combined_dataset.loc[IsNotAvailable(combined_dataset["comparison"]), "comparison"] = "None"

combined_dataset["indication"] = combined_dataset["indication"].fillna("None")
combined_dataset["findings"] = combined_dataset["findings"].fillna("None")
combined_dataset["impression"] = combined_dataset["impression"].fillna("None")
combined_dataset["comparison"] = combined_dataset["comparison"].fillna("None")
combined_dataset["report"] = combined_dataset["findings"]
# combined_dataset["report"] = (
#     "Indication: " + combined_dataset["indication"].astype(str) + "\n"
#     + "Findings: " + combined_dataset["findings"].astype(str) + "\n"
#     + "Impression: " + combined_dataset["impression"].astype(str) + "\n"
#     + "Comparison: " + combined_dataset["comparison"].astype(str)
# )

combined_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7466 entries, 0 to 7465
Data columns (total 11 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   uid         7466 non-null   int64 
 1   filename    7466 non-null   object
 2   projection  7466 non-null   object
 3   MeSH        7466 non-null   object
 4   Problems    7466 non-null   object
 5   image       7466 non-null   object
 6   indication  7466 non-null   object
 7   comparison  7466 non-null   object
 8   findings    7466 non-null   object
 9   impression  7466 non-null   object
 10  report      7466 non-null   object
dtypes: int64(1), object(10)
memory usage: 641.7+ KB


In [4]:
combined_dataset.head()

Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,report
0,1,1_IM-0001-4001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...
1,1,1_IM-0001-3001.dcm.png,Lateral,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...
2,2,2_IM-0652-1001.dcm.png,Frontal,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...
3,2,2_IM-0652-2001.dcm.png,Lateral,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...
4,3,3_IM-1384-1001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,"rib pain after a XXXX, XXXX XXXX steps this XX...",,,"No displaced rib fractures, pneumothorax, or p...",


In [5]:
for r in combined_dataset["report"].head(5).to_list():
    print(r)
    print("-----")

The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
-----
The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
-----
Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.
-----
Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.
-----
None
-----


In [6]:
reduced_dataset = combined_dataset.loc[:, ("filename", "report")]
reduced_dataset

Unnamed: 0,filename,report
0,1_IM-0001-4001.dcm.png,The cardiac silhouette and mediastinum size ar...
1,1_IM-0001-3001.dcm.png,The cardiac silhouette and mediastinum size ar...
2,2_IM-0652-1001.dcm.png,Borderline cardiomegaly. Midline sternotomy XX...
3,2_IM-0652-2001.dcm.png,Borderline cardiomegaly. Midline sternotomy XX...
4,3_IM-1384-1001.dcm.png,
...,...,...
7461,3997_IM-2048-1002.dcm.png,"Heart size within normal limits. Small, nodula..."
7462,3998_IM-2048-1001.dcm.png,
7463,3998_IM-2048-1002.dcm.png,
7464,3999_IM-2049-1001.dcm.png,


In [7]:
from sklearn.model_selection import train_test_split

train_df, temp_df = train_test_split(combined_dataset, test_size=0.2, random_state=42, shuffle=True)
test_df, valid_df = train_test_split(temp_df, test_size=0.5, random_state=42, shuffle=True)
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
print(f"Train shape: {train_df.shape}, Test shape: {test_df.shape}, Valid shape: {valid_df.shape}")

Train shape: (5972, 11), Test shape: (747, 11), Valid shape: (747, 11)


In [8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from timm import create_model, list_models

sample_tfms = [
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(),
    A.ColorJitter(),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.3, rotate_limit=45, p=0.5),
    A.HueSaturationValue(p=0.3),
]

train_tfms = A.Compose([
    *sample_tfms,
    A.Resize(224,224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]),
    ToTensorV2()
])

valid_tfms = A.Compose([
    A.Resize(224,224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]),
    ToTensorV2()
])

  original_init(self, **validated_kwargs)


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd
import numpy
from torchvision import transforms

class ImageReportDataset(Dataset):
    def __init__(self, dataset, img_dir, tokenizer, transform=None):
        self.data = dataset
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.data.loc[idx, "filename"])
        image = Image.open(img_name).convert("RGB")
        image = numpy.array(image)
        if self.transform:
            image = self.transform(image=image)["image"]
        report = self.data.loc[idx, "report"] + "<|endoftext|>"
        # inputs = self.tokenizer(report, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
        # return image, inputs["input_ids"].squeeze(), inputs["attention_mask"].squeeze()
        inputs = self.tokenizer(report, truncation=True)
        input_ids = inputs["input_ids"]
        labels = input_ids.copy()
        labels[:-1] = input_ids[1:]
        return image, input_ids, labels
        # return image, inputs["input_ids"].squeeze(), inputs["attention_mask"].squeeze()

In [10]:
from transformers import GPT2TokenizerFast
from torchvision import transforms

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

def collate_fn(batch):
    image = [i[0] for i in batch]
    input_ids = [i[1] for i in batch]
    labels = [i[2] for i in batch]
    image = torch.stack(image, dim=0)
    input_ids = tokenizer.pad(
        {"input_ids": input_ids},
        padding="longest",
        return_attention_mask=False,
        return_tensors="pt"
    )['input_ids']
    
    labels = tokenizer.pad(
        {"input_ids": labels},
        padding="longest",
        return_attention_mask=False,
        return_tensors="pt"
    )['input_ids']
    
    mask = (input_ids != tokenizer.pad_token_id).long()
    labels[mask==0] = -100
    return image, input_ids, labels

In [11]:
class GPT2Attention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, "embedding dimension should be divisible by number of heads"
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        self.c_attn = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3,bias=True)
        self.scale = self.head_size ** -0.5

        self.register_buffer("mask", torch.tril(torch.ones(1,1,self.seq_len,self.seq_len)))

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)

    def forward(self, x):
        b, t, c = x.shape
        # q,k,v shape individually: batch_size x seq_len x embed_dim
        # we know that qk_t = q x k_t, where q=bxtxhead_dim, k_t=bxhead_timxt
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        q = q.view(b, t, self.n_heads,self.head_size).permute(0, 2, 1, 3) # batch x n_heads x seq_len x head_dim
        k = k.view(b, t, self.n_heads,self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, t, self.n_heads,self.head_size).permute(0, 2, 1, 3)

        qk_t = (q @ k.transpose(-2, -1)) * self.scale
        qk_t = qk_t.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
        qk_t = F.softmax(qk_t, dim=-1)
        weights = self.attn_dropout(qk_t)

        attention = weights @ v # batch x n_heads x t x head_size
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c) # batch x t x embed_dim

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

In [12]:
class GPT2CrossAttention(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, "embedding dimension by be divisible by number of heads"
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        self.q = nn.Linear(self.embed_dim, self.embed_dim)
        self.k = nn.Linear(self.embed_dim, self.embed_dim)
        self.v = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale = self.head_size ** -0.5

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, q, k, v):
        b, t, c = q.shape

        q = self.q(q)
        k = self.k(k)
        v = self.v(v)

        q = q.view(b, q.size(1), self.n_heads,self.head_size).permute(0, 2, 1, 3) # batch x n_heads x seq_len x head_dim
        k = k.view(b, k.size(1), self.n_heads,self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, v.size(1), self.n_heads,self.head_size).permute(0, 2, 1, 3)

        qk_t = (q @ k.transpose(-2, -1)) * self.scale
        qk_t = F.softmax(qk_t,dim=-1)
        weights = self.attn_dropout(qk_t)

        attention = weights @ v # batch x n_heads x t x head_size
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c) # batch x t x embed_dim

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

In [13]:
class GPT2MLP(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.mlp_ratio = config.mlp_ratio
        self.mlp_dropout = config.mlp_dropout

        self.c_fc = nn.Linear(self.embed_dim,self.embed_dim*self.mlp_ratio)
        self.c_proj = nn.Linear(self.embed_dim*self.mlp_ratio,self.embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(self.mlp_dropout)

    def forward(self,x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

In [14]:
class GPT2Block(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.ln_1 = nn.LayerNorm(self.embed_dim)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(self.embed_dim)
        self.mlp = GPT2MLP(config)
        self.ln_3 = nn.LayerNorm(self.embed_dim)
        self.cross_attn = GPT2CrossAttention(config)

    def forward(self,x,enc_out):
        x = x+self.attn(self.ln_1(x))
        x = x+self.cross_attn(self.ln_2(x),enc_out,enc_out)
        x = x+self.mlp(self.ln_3(x))
        return x

In [None]:
class VisionGPT2Model(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        vit = create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
        self.patch_embed = vit.patch_embed
        num_patches = self.patch_embed.num_patches

        self.cls_token = vit.cls_token
        embed_len = num_patches + vit.num_prefix_tokens
        self.pos_embed = vit.pos_embed
        self.pos_drop = nn.Dropout(p=0.)

        self.blocks = nn.ModuleList([vit.blocks[i] for i in range(config.depth)])

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size,config.embed_dim),
            wpe = nn.Embedding(config.seq_len,config.embed_dim),
            drop = nn.Dropout(config.emb_dropout),
            h = nn.ModuleList([GPT2Block(config) for _ in range(config.depth)]),
            ln_f = nn.LayerNorm(config.embed_dim)
        ))
        
        self.lm_head = nn.Linear(config.embed_dim,config.vocab_size,bias=False)
        self.transformer.wte.weight = self.lm_head.weight

    def _pos_embed(self,x):
        pos_embed = self.pos_embed
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + pos_embed
        return self.pos_drop(x)

    def pretrained_layers_trainable(self,trainable=False):
        layers = [
            self.cls_token, self.patch_embed, self.pos_embed, self.blocks,
            self.transformer.wte, self.transformer.wpe,
            self.transformer.ln_f, self.lm_head
        ]
        gpt_layers = [[
            self.transformer.h[i].ln_1,self.transformer.h[i].ln_2,
            self.transformer.h[i].attn,self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]
        for l in gpt_layers:
            layers.extend(l)

        for layer in layers:
            if not isinstance(layer,nn.Parameter):
                for p in layer.parameters():
                    p.requires_grad = trainable
            else:
                layer.requires_grad = trainable

        total_frozen_params = sum([p.numel() for p in self.parameters() if not p.requires_grad])
        print(f'{total_frozen_params=}')

    def unfreeze_gpt_layers(self,):
        gpt_layers = [[
            self.transformer.h[i].ln_1,self.transformer.h[i].ln_2,
            self.transformer.h[i].attn,self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]
        flatten = []
        for l in gpt_layers:
            flatten.extend(l)

        for layer in flatten:
            if not isinstance(layer,nn.Parameter):
                for p in layer.parameters():
                    p.requires_grad = True
            else:
                layer.requires_grad = True

    @classmethod
    def from_pretrained(self, config):
        model = VisionGPT2Model(config)
        sd = model.state_dict()
        keys = sd.keys()
        ignore_matches = ["blocks.", "cross_attn.", "ln_3", "cls_token", "pos_embed", "patch_embed.", ".attn.mask"]
        vit_keys = [key for key in keys if any(match in key for match in ignore_matches)]
        gpt_keys = [key for key in keys if key not in vit_keys]

        gpt2_small = GPT2LMHeadModel.from_pretrained("gpt2")
        sd_hf = gpt2_small.state_dict()
        hf_keys = sd_hf.keys()
        hf_keys = [k for k in hf_keys if not k.endswith(".attn.masked_bias")]
        hf_keys = [k for k in hf_keys if not k.endswith(".attn.bias")]
        transposed = ["attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight"]

        for k in hf_keys:
            if any(match in k for match in ignore_matches):
                continue
                
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        model.load_state_dict(sd)

        return model

    def forward(self, image, input_ids, labels=None):

        image = self.patch_embed(image)
        image = self._pos_embed(image)

        token_embeddings = self.transformer.wte(input_ids) # batch x seq_len
        pos_embs = torch.arange(0, input_ids.size(1)).to(input_ids.device)
        positional_embeddings = self.transformer.wpe(pos_embs)
        input_ids = self.transformer.drop(token_embeddings+positional_embeddings)

        for i in range(self.config.depth):
            image = self.blocks[i](image)
            input_ids = self.transformer.h[i](input_ids, image)

        input_ids = self.transformer.ln_f(input_ids)

        if labels is not None:
            lm_logits = self.lm_head(input_ids)
            loss = F.cross_entropy(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
            return loss

        lm_logits = self.lm_head(input_ids[:, [-1], :])
        return lm_logits

    def generate(self,image,sequence,max_tokens=50,temperature=1.0,deterministic=False):
        for _ in range(max_tokens):
            out = self(image,sequence)
            out = out[:, -1, :] / temperature
            probs = F.softmax(out, dim=-1)
            if deterministic:
                next_token = torch.argmax(probs, dim=-1, keepdim=True)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            sequence = torch.cat([sequence,next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

        return sequence.cpu().flatten()

In [16]:
from torch import GradScaler, autocast
from tqdm.auto import tqdm
import gc

class Trainer:
    def __init__(self, model_config, train_config, dls):
        self.train_config = train_config
        self.model_config = model_config
        self.device = self.train_config.device

        self.model = VisionGPT2Model.from_pretrained(model_config).to(self.device)
        self.model.pretrained_layers_trainable(trainable=False)

        print(f"Trainable parameters: {sum([p.numel() for p in self.model.parameters() if p.requires_grad])}")

        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.scaler = GradScaler(self.device)

        self.train_dl, self.val_dl = dls

        total_steps = len(self.train_dl)

        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.train_config.lr / 25.)
        self.sched = torch.optim.lr_scheduler.OneCycleLR(
            self.optim,
            max_lr=self.train_config.lr,
            epochs=self.train_config.epochs,
            steps_per_epoch=total_steps
        )

#         self.sched = get_linear_schedule_with_warmup(self.optim,num_warmup_steps=0,num_training_steps=total_steps)

        self.metrics = pandas.DataFrame()
        self.metrics[["train_loss", "train_perplexity", "val_loss", "val_perplexity"]] = None

        self.gen_tfms = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ])

    def save_model(self,):
        self.train_config.model_path.mkdir(exist_ok=True)
        sd = self.model.state_dict()
        torch.save(self.model, self.train_config.model_path/"captioner.pt")

    def load_best_model(self,):
        sd = torch.load(self.train_config.model_path/"captioner.pt", weights_only=False)
        self.model.load_state_dict(sd)

    def train_one_epoch(self, epoch):
        prog = tqdm(self.train_dl, total=len(self.train_dl))
        running_loss = 0.

        for image, input_ids, labels in prog:
            with autocast(self.device):
                image = image.to(self.device)
                input_ids = input_ids.to(self.device)
                labels = labels.to(self.device)

                loss = self.model(image,input_ids,labels)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optim)
                self.scaler.update()
                self.sched.step()
                self.optim.zero_grad(set_to_none=True)

                running_loss += loss.item()

                prog.set_description(f'train loss: {loss.item():.3f}')

            del image, input_ids, labels, loss
        
        print()

        train_loss = running_loss / len(self.train_dl)
        train_pxp = numpy.exp(train_loss)

        self.metrics.loc[epoch, ["train_loss", "train_perplexity"]] = (train_loss, train_pxp)

    @torch.no_grad()
    def valid_one_epoch(self, epoch):

        prog = tqdm(self.val_dl, total=len(self.val_dl))

        running_loss = 0.

        for image, input_ids, labels in prog:

            with autocast(self.device):
                image = image.to(self.device)
                input_ids = input_ids.to(self.device)
                labels = labels.to(self.device)

                loss = self.model(image,input_ids,labels)
                running_loss += loss.item()

                prog.set_description(f"Valid loss: {loss.item():.3f}")

            del image, input_ids, labels, loss

        print()

        val_loss = running_loss / len(self.val_dl)
        val_pxp = numpy.exp(val_loss)

        self.metrics.loc[epoch, ["val_loss", "val_perplexity"]] = (val_loss,val_pxp)

        return val_pxp

    def clean(self):
        gc.collect()
        torch.cuda.empty_cache()

    def fit(self,):
        best_pxp = 1e9
        best_epoch = -1
        prog = tqdm(range(self.train_config.epochs))

        for epoch in prog:
            if epoch == self.train_config.freeze_epochs_gpt:
                self.model.unfreeze_gpt_layers()
                print("Unfreezing GPT2 entirely...")

            if epoch == self.train_config.freeze_epochs_all:
                self.model.pretrained_layers_trainable(trainable=True)

            self.model.train()
            prog.set_description("Training")
            self.train_one_epoch(epoch)
            self.clean()

            self.model.eval()
            prog.set_description("Validating")
            pxp = self.valid_one_epoch(epoch)
            self.clean()

            print(self.metrics.tail(1))

            if pxp < best_pxp:
                best_pxp = pxp
                best_epoch = epoch
                print("Saving best model...")
                self.save_model()

        return {
            "best_perplexity": best_pxp,
            "best_epoch": best_epoch
        }

    @torch.no_grad()
    def generate_caption(self, image, max_tokens=50, temperature=1.0, deterministic=False):

        self.model.eval()

        image = Image.open(image).convert("RGB")
        image = numpy.array(image)
        image = self.gen_tfms(image=image)["image"]
        image = image.unsqueeze(0).to(self.device)
        sequence = torch.ones(1, 1).to(device=self.device).long() * self.tokenizer.bos_token_id

        caption = self.model.generate(
            image,
            sequence,
            max_tokens=max_tokens,
            temperature=temperature,
            deterministic=deterministic
        )
        caption = self.tokenizer.decode(caption.numpy(),skip_special_tokens=True)

        return caption

In [17]:
from types import SimpleNamespace
from pathlib import Path

model_config = SimpleNamespace(
    vocab_size = 50_257,
    embed_dim = 768, # 768
    num_heads = 12,
    seq_len = 1024,
    depth = 12,
    attention_dropout = 0.1,
    residual_dropout = 0.1,
    mlp_ratio = 4,
    mlp_dropout = 0.1,
    emb_dropout = 0.1,
)

train_config = SimpleNamespace(
    epochs = 5,
    freeze_epochs_gpt = 1,
    freeze_epochs_all = 2,
    lr = 1e-4,
    device = 'cuda',
    model_path = Path('./training/multi_modal'),
    batch_size = 32
)

In [18]:
train_dataset = ImageReportDataset(train_df, images_folder, tokenizer, train_tfms)
valid_dataset = ImageReportDataset(valid_df, images_folder, tokenizer, valid_tfms)
test_dataset = ImageReportDataset(test_df, images_folder, tokenizer, valid_tfms)

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=train_config.batch_size,
    shuffle=True,
    # pin_memory=True,
    # num_workers=2,
    # persistent_workers=True,
    collate_fn=collate_fn
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=train_config.batch_size,
    shuffle=False,
    # pin_memory=True,
    # num_workers=2,
    # persistent_workers=True,
    collate_fn=collate_fn
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=train_config.batch_size,
    shuffle=False,
    # pin_memory=True,
    # num_workers=2,
    # persistent_workers=True,
    collate_fn=collate_fn
)

In [28]:
trainer = Trainer(model_config, train_config, (train_dataloader, valid_dataloader))
trainer.fit()

total_frozen_params=210236928
Trainable parameters: 28366848


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]




  0%|          | 0/24 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
0   8.665057      5796.772245  5.812038     334.299603
Saving best model...
Unfreezing GPT2 entirely...


  0%|          | 0/187 [00:00<?, ?it/s]




  0%|          | 0/24 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
1   3.367834        29.015617  1.855544       6.395175
Saving best model...
total_frozen_params=0


  0%|          | 0/187 [00:00<?, ?it/s]




  0%|          | 0/24 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
2   1.749689         5.752815  1.401026       4.059363
Saving best model...


  0%|          | 0/187 [00:00<?, ?it/s]




  0%|          | 0/24 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
3   1.439368         4.218028  1.270694       3.563325
Saving best model...


  0%|          | 0/187 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [32]:
trainer.load_best_model()

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL __main__.VisionGPT2Model was not an allowed global by default. Please use `torch.serialization.add_safe_globals([__main__.VisionGPT2Model])` or the `torch.serialization.safe_globals([__main__.VisionGPT2Model])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [19]:
pandas.read_csv("training\multi_modal\metrics.csv")

  pandas.read_csv("training\multi_modal\metrics.csv")


Unnamed: 0,train_loss,train_perplexity,val_loss,val_perplexity
0,10.132064,25136.198009,8.420802,4540.544581
1,6.224047,504.742017,4.302428,73.878977
2,3.132028,22.920412,2.013945,7.492819
3,1.939077,6.952334,1.537005,4.65064
4,1.532418,4.629359,1.272288,3.569008
5,1.304426,3.685574,1.132638,3.103833
6,1.143736,3.138472,1.027499,2.794068
7,1.02434,2.785256,0.946582,2.576886
8,0.927683,2.528643,0.891535,2.438871
9,0.844745,2.327385,0.826734,2.285841


In [20]:
best_model = torch.load("training/multi_modal/captioner.pt", weights_only=False)
best_model.to(device)

VisionGPT2Model(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_

In [21]:
def generate_caption(model, image, max_tokens=200, temperature=1.0, deterministic=False):
    # model.eval()
    gen_tfms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2()
    ])

    image = Image.open(image).convert('RGB')
    image = numpy.array(image)
    image = gen_tfms(image=image)['image']
    image = image.unsqueeze(0).to(device)  # Move the input image tensor to the same device as the model
    sequence = torch.ones(1, 1).long().to(device) * tokenizer.bos_token_id

    caption = model.generate(
        image,
        sequence,
        max_tokens=max_tokens,
        temperature=temperature,
        deterministic=deterministic
    )
    caption = tokenizer.decode(caption.cpu().numpy(), skip_special_tokens=True)  # Move the generated caption back to CPU for decoding

    return caption

img_path = os.path.join(images_folder, test_df.loc[0, "filename"])
generated_report = generate_caption(best_model, img_path)
print(test_df.loc[0, "report"])
print("------------")
print(generated_report)

Cardiomediastinal contour stable and within normal limits. Changes of prior CABG again noted. Normal pulmonary vascularity. Streaky bibasilar opacities decreased from previous, possibly subsegmental atelectasis and/or scar. No pneumothorax or pleural effusion demonstrated. Redemonstrated severe L1 XXXX fracture. Slight interval increase in XXXX loss of T11 and there is XXXX mild to moderate anterior XXXX loss of T10. Degenerative changes of the spine. Abdominal aortic stent.
------------
 XXXX examination consists of frontal and lateral radiographs of the chest. The cardiomediastinal contours are within normal limits allowing for low lung volumes and patient rotation. There is posterolateral view of the patient body without significant change. Removal of XXXX opacities in the left lung base may represent atelectasis or scarring. No focal consolidation, pleural effusion, or pneumothorax identified. Dense fracture of the left hemidiaphragm is stable. Degenerative disease of the thoracic 

In [22]:
import torch
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
from rouge_score import rouge_scorer
import nltk

nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\v-mziadeh\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [26]:
# Function to calculate BLEU score
def compute_bleu(reference_texts, generated_texts):
    """
    Compute BLEU score between generated texts and references.
    
    :param reference_texts: List of lists of reference texts (for each generated report)
    :param generated_texts: List of generated reports
    :return: BLEU score
    """
    references = [[ref.split()] for ref in reference_texts]  # List of list of reference tokens
    candidates = [gen.split() for gen in generated_texts]   # List of list of generated tokens
    bleu_score = corpus_bleu(references, candidates)
    return bleu_score

# Function to calculate ROUGE score
def compute_rouge(reference_texts, generated_texts):
    """
    Compute ROUGE score between generated texts and references.
    
    :param reference_texts: List of reference reports
    :param generated_texts: List of generated reports
    :return: ROUGE score
    """
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    rouge_scores = {"rouge1": [], "rouge2": [], "rougeL": []}
    
    for reference, generated in zip(reference_texts, generated_texts):
        scores = scorer.score(reference, generated)
        for key in rouge_scores:
            rouge_scores[key].append(scores[key].fmeasure)
    
    avg_rouge1 = sum(rouge_scores["rouge1"]) / len(rouge_scores["rouge1"])
    avg_rouge2 = sum(rouge_scores["rouge2"]) / len(rouge_scores["rouge2"])
    avg_rougeL = sum(rouge_scores["rougeL"]) / len(rouge_scores["rougeL"])
    
    return avg_rouge1, avg_rouge2, avg_rougeL

# Evaluation function
def evaluate_model(model, folder_path, eval_set):
    generated_reports = []
    reference_reports = []

    print(f"Starting evaluation for model using {len(eval_set)} items")
    for idx in range(len(eval_set)):
        # Generate report for each image
        generated_report = generate_caption(model, os.path.join(folder_path, eval_set.loc[idx, "filename"]))
        reference_report = eval_set.loc[idx, "report"]
        
        generated_reports.append(generated_report)
        reference_reports.append(reference_report)

        print(f"Generated report {idx + 1}/{len(eval_set)}    \r", end="")

    print()
    
    # Compute BLEU
    bleu_score = compute_bleu(reference_reports, generated_reports)
    print(f"BLEU Score: {bleu_score:.4f}")
    
    # Compute ROUGE
    rouge1, rouge2, rougeL = compute_rouge(reference_reports, generated_reports)
    print(f"ROUGE-1: {rouge1:.4f}, ROUGE-2: {rouge2:.4f}, ROUGE-L: {rougeL:.4f}")

In [27]:
# Evaluate the model
evaluate_model(best_model, images_folder, test_df)

Starting evaluation for model using 747 items
Generated report 747/747    
BLEU Score: 0.0458
ROUGE-1: 0.2948, ROUGE-2: 0.0945, ROUGE-L: 0.2019
