In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("raddar/chest-xrays-indiana-university")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/chest-xrays-indiana-university


In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(torch.__version__)

Using device: cuda
2.6.0+cu124


In [3]:
import pandas

dataset_folder = path
images_folder = dataset_folder + "/images/images_normalized"
projections = pandas.read_csv(dataset_folder + "/indiana_projections.csv")
reports = pandas.read_csv(dataset_folder + "/indiana_reports.csv")

combined_dataset = projections.merge(reports, on="uid", how="inner")

def IsNotAvailable(value):
    return value.str.contains("unavailable", case=False, na=False) \
        | value.str.contains("not available", case=False, na=False) \
        | value.str.contains("none", case=False, na=False)

combined_dataset.loc[IsNotAvailable(combined_dataset["comparison"]), "comparison"] = "None"

combined_dataset["indication"] = combined_dataset["indication"].fillna("None")
combined_dataset["findings"] = combined_dataset["findings"].fillna("None")
combined_dataset["impression"] = combined_dataset["impression"].fillna("None")
combined_dataset["comparison"] = combined_dataset["comparison"].fillna("None")
combined_dataset["report"] = combined_dataset["findings"]
combined_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7466 entries, 0 to 7465
Data columns (total 11 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   uid         7466 non-null   int64 
 1   filename    7466 non-null   object
 2   projection  7466 non-null   object
 3   MeSH        7466 non-null   object
 4   Problems    7466 non-null   object
 5   image       7466 non-null   object
 6   indication  7466 non-null   object
 7   comparison  7466 non-null   object
 8   findings    7466 non-null   object
 9   impression  7466 non-null   object
 10  report      7466 non-null   object
dtypes: int64(1), object(10)
memory usage: 641.7+ KB


In [4]:
combined_dataset.head()

Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,report
0,1,1_IM-0001-4001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...
1,1,1_IM-0001-3001.dcm.png,Lateral,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...
2,2,2_IM-0652-1001.dcm.png,Frontal,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...
3,2,2_IM-0652-2001.dcm.png,Lateral,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...
4,3,3_IM-1384-1001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,"rib pain after a XXXX, XXXX XXXX steps this XX...",,,"No displaced rib fractures, pneumothorax, or p...",


In [5]:
for r in combined_dataset["report"].head(5).to_list():
    print(r)
    print("-----")

The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
-----
The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
-----
Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.
-----
Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.
-----
None
-----


In [6]:
# reduced_dataset = combined_dataset.loc[:, ("filename", "projection", "report")]
# reduced_dataset["report"] = reduced_dataset["report"].fillna("None")
# reduced_dataset

In [7]:
import torchxrayvision as xrv
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import Dataset, DataLoader
import skimage
import os
import pandas as pd
import numpy
from torchvision import transforms

class XRayVisionKeywordsExtractor():
    def __init__(self, model_name="densenet121-res224-all"):
        self.model = xrv.models.DenseNet(weights=model_name)
        self.transform = transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)])

    def extract(self, images):
        keywords = []
        for filepath in images:
            img = skimage.io.imread(filepath)
            img = xrv.datasets.normalize(img, 255)
            img = numpy.expand_dims(img, axis=0)  # Add channel dimension

            img = self.transform(img)
            img = torch.from_numpy(img)

            outputs = self.model(img[None, ...])
            predicted_pathologies = pd.DataFrame(zip(self.model.pathologies, outputs[0].detach().numpy()), columns=["Pathology", "Score"])
            top_pathologies = predicted_pathologies.loc[predicted_pathologies["Score"] >= 0.5].sort_values(by="Score", ascending=False)
            keywords.append(top_pathologies["Pathology"].values)

        return keywords

# xrv_keywords_extractor = XRayVisionKeywordsExtractor()
# xrv_keywords = xrv_keywords_extractor.extract(reduced_dataset["filename"].head(2).to_list())
# xrv_keywords
# xrv_model = xrv.models.DenseNet(weights="densenet121-res224-all")

# xrv_transform = transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)])

# def add_keywords(images):
#     keywords = []
#     total = len(images)
#     count = 0
#     for filename in images:
#         img = skimage.io.imread(os.path.join(images_folder, filename))
#         img = xrv.datasets.normalize(img, 255)
#         img = numpy.expand_dims(img, axis=0)  # Add channel dimension

#         img = xrv_transform(img)
#         img = torch.from_numpy(img)

#         outputs = xrv_model(img[None, ...])
#         predicted_pathologies = pandas.DataFrame(zip(xrv_model.pathologies, outputs[0].detach().numpy()), columns=["Pathology", "Score"])
#         top_pathologies = predicted_pathologies.loc[predicted_pathologies["Score"] >= 0.5].sort_values(by="Score", ascending=False)
#         keywords.append(top_pathologies["Pathology"].values)

#         count += 1
#         print(f"Processed {count}/{total} images        ", end="\r")

#     return keywords

# reduced_dataset["keywords"] = add_keywords(reduced_dataset["filename"].to_list())
# reduced_dataset

In [8]:
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
# reduced_dataset.to_csv("./reduced_dataset_with_keywords.csv", index=False)
reduced_dataset = pandas.read_csv("/content/drive/MyDrive/Colab Notebooks/reduced_dataset_with_keywords.csv")
reduced_dataset["report"] = reduced_dataset["report"].fillna("None")
reduced_dataset["keywords"] = reduced_dataset["keywords"].fillna("").apply(lambda x: x.replace(";", ", "))

combined_dataset["keywords"] = reduced_dataset["keywords"]
merged_dataset = combined_dataset.groupby("uid").agg({
    "filename": list,
    "projection": list,
    "MeSH": "first",
    "Problems": "first",
    "image": "first",
    "indication": "first",
    "comparison": "first",
    "findings": "first",
    "impression": "first",
    "report": "first",
    "keywords": lambda x: [i for i in set(", ".join(x).split(", ")) if i != ""]  # Combine keywords into a single list
}).reset_index()

merged_dataset.head()

Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,report,keywords
0,1,"[1_IM-0001-4001.dcm.png, 1_IM-0001-3001.dcm.png]","[Frontal, Lateral]",normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,The cardiac silhouette and mediastinum size ar...,"[Edema, Consolidation, Infiltration, Effusion,..."
1,2,"[2_IM-0652-1001.dcm.png, 2_IM-0652-2001.dcm.png]","[Frontal, Lateral]",Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Borderline cardiomegaly. Midline sternotomy XX...,"[Fibrosis, Pleural_Thickening, Consolidation, ..."
2,3,"[3_IM-1384-1001.dcm.png, 3_IM-1384-2001.dcm.png]","[Frontal, Lateral]",normal,normal,Xray Chest PA and Lateral,"rib pain after a XXXX, XXXX XXXX steps this XX...",,,"No displaced rib fractures, pneumothorax, or p...",,"[Fibrosis, Consolidation, Nodule, Enlarged Car..."
3,4,"[4_IM-2050-1001.dcm.png, 4_IM-2050-2001.dcm.png]","[Frontal, Lateral]","Pulmonary Disease, Chronic Obstructive;Bullous...","Pulmonary Disease, Chronic Obstructive;Bullous...","PA and lateral views of the chest XXXX, XXXX a...",XXXX-year-old XXXX with XXXX.,,There are diffuse bilateral interstitial and a...,1. Bullous emphysema and interstitial fibrosis...,There are diffuse bilateral interstitial and a...,"[Fibrosis, Pleural_Thickening, Edema, Consolid..."
4,5,"[5_IM-2117-1003002.dcm.png, 5_IM-2117-1004003....","[Frontal, Lateral]",Osteophyte/thoracic vertebrae/multiple/small;T...,Osteophyte;Thickening;Lung,Xray Chest PA and Lateral,Chest and nasal congestion.,,The cardiomediastinal silhouette and pulmonary...,No acute cardiopulmonary abnormality.,The cardiomediastinal silhouette and pulmonary...,"[Fibrosis, Edema, Effusion, Consolidation, Ate..."


In [10]:
# Use the value from "impression" if "report" is "None"
merged_dataset.loc[merged_dataset["report"] == "None", "report"] = merged_dataset.loc[merged_dataset["report"] == "None", "impression"]

# Drop rows where "report" is still "None"
merged_dataset = merged_dataset[merged_dataset["report"] != "None"]

# Needed dataset for training/evaluation
cleaned_dataset = merged_dataset.loc[:, ("filename", "keywords", "report")]
cleaned_dataset.head()

Unnamed: 0,filename,keywords,report
0,"[1_IM-0001-4001.dcm.png, 1_IM-0001-3001.dcm.png]","[Edema, Consolidation, Infiltration, Effusion,...",The cardiac silhouette and mediastinum size ar...
1,"[2_IM-0652-1001.dcm.png, 2_IM-0652-2001.dcm.png]","[Fibrosis, Pleural_Thickening, Consolidation, ...",Borderline cardiomegaly. Midline sternotomy XX...
2,"[3_IM-1384-1001.dcm.png, 3_IM-1384-2001.dcm.png]","[Fibrosis, Consolidation, Nodule, Enlarged Car...","No displaced rib fractures, pneumothorax, or p..."
3,"[4_IM-2050-1001.dcm.png, 4_IM-2050-2001.dcm.png]","[Fibrosis, Pleural_Thickening, Edema, Consolid...",There are diffuse bilateral interstitial and a...
4,"[5_IM-2117-1003002.dcm.png, 5_IM-2117-1004003....","[Fibrosis, Edema, Effusion, Consolidation, Ate...",The cardiomediastinal silhouette and pulmonary...


In [11]:
from sklearn.model_selection import train_test_split

train_df, temp_df = train_test_split(cleaned_dataset, test_size=0.2, random_state=42, shuffle=True)
test_df, valid_df = train_test_split(temp_df, test_size=0.5, random_state=42, shuffle=True)
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
print(f"Train shape: {train_df.shape}, Test shape: {test_df.shape}, Valid shape: {valid_df.shape}")

Train shape: (3060, 3), Test shape: (383, 3), Valid shape: (383, 3)


In [12]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from timm import create_model, list_models

sample_tfms = [
    A.HorizontalFlip(),
    A.RandomBrightnessContrast(),
    A.ColorJitter(),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.3, rotate_limit=45, p=0.5),
    A.HueSaturationValue(p=0.3),
]

train_tfms = A.Compose([
    *sample_tfms,
    A.Resize(224,224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]),
    ToTensorV2()
])

valid_tfms = A.Compose([
    A.Resize(224,224),
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5]),
    ToTensorV2()
])

In [13]:
from PIL import Image

def concat_images(image_paths, direction="horizontal"):
    """
    Concatenate multiple images into a single image.

    :param image_paths: List of image file paths.
    :param direction: 'horizontal' or 'vertical'.
    :return: PIL.Image object of the concatenated image.
    """
    if (len(image_paths) == 1):
        return Image.open(image_paths[0]).convert("RGB")

    mid = len(image_paths) // 2
    left, right = image_paths[:mid], image_paths[mid:]
    left_img = concat_images(left, direction="vertical" if direction == "horizontal" else "horizontal")
    right_img = concat_images(right, direction="vertical" if direction == "horizontal" else "horizontal")

    images = [left_img, right_img] # [Image.open(os.path.join(images_folder, p)).convert("RGB") for p in image_paths]
    widths, heights = zip(*(img.size for img in images))

    if direction == "horizontal":
        total_width = sum(widths)
        max_height = max(heights)
        new_im = Image.new('RGB', (total_width, max_height))
        x_offset = 0
        for img in images:
            new_im.paste(img, (x_offset, 0))
            x_offset += img.size[0]
    else:  # vertical
        max_width = max(widths)
        total_height = sum(heights)
        new_im = Image.new('RGB', (max_width, total_height))
        y_offset = 0
        for img in images:
            new_im.paste(img, (0, y_offset))
            y_offset += img.size[1]

    return new_im

# Example usage: concatenate the first row's images horizontally
# big_image = concat_images(cleaned_dataset.loc[43, "filename"], direction='horizontal')
# big_image.show()

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd
import numpy
from torchvision import transforms

class ImageReportWithKeywordsDataset(Dataset):
    def __init__(self, dataset, img_dir, tokenizer, transform=None):
        self.data = dataset
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_paths = [os.path.join(self.img_dir, filename) for filename in self.data.loc[idx, "filename"]]

        # Concatenate all images of one report into a single image
        image = concat_images(img_paths)
        image = numpy.array(image)

        if self.transform:
            image = self.transform(image=image)["image"]

        keywords = self.data.loc[idx, "keywords"]
        keywords = "keywords: " + ", ".join(self.data.loc[idx, "keywords"])
        keywords = self.tokenizer(keywords, truncation=True, padding="max_length", max_length=50, return_tensors="pt")["input_ids"]

        report = self.data.loc[idx, "report"] + "<|endoftext|>"
        inputs = self.tokenizer(report, truncation=True)
        input_ids = inputs["input_ids"]
        labels = input_ids.copy()
        labels[:-1] = input_ids[1:]
        return image, input_ids, keywords, labels

In [15]:
from transformers import GPT2TokenizerFast
from torchvision import transforms

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

def collate_fn(batch):
    image = [i[0] for i in batch]
    input_ids = [i[1] for i in batch]
    keywords = [i[2] for i in batch]
    labels = [i[3] for i in batch]
    image = torch.stack(image, dim=0)
    input_ids = tokenizer.pad(
        {"input_ids": input_ids},
        padding="longest",
        return_attention_mask=False,
        return_tensors="pt"
    )['input_ids']

    keywords = tokenizer.pad(
        {"input_ids": keywords},
        padding="longest",
        return_attention_mask=False,
        return_tensors="pt"
    )['input_ids']

    labels = tokenizer.pad(
        {"input_ids": labels},
        padding="longest",
        return_attention_mask=False,
        return_tensors="pt"
    )['input_ids']

    mask = (input_ids != tokenizer.pad_token_id).long()
    labels[mask==0] = -100
    return image, input_ids, keywords, labels

In [16]:
class GPT2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, "embedding dimension should be divisible by number of heads"
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        self.c_attn = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3,bias=True)
        self.scale = self.head_size ** -0.5

        self.register_buffer("mask", torch.tril(torch.ones(1,1,self.seq_len,self.seq_len)))

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)

    def forward(self, x):
        b, t, c = x.shape
        # q,k,v shape individually: batch_size x seq_len x embed_dim
        # we know that qk_t = q x k_t, where q=bxtxhead_dim, k_t=bxhead_timxt
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        q = q.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3) # batch x n_heads x seq_len x head_dim
        k = k.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, t, self.n_heads, self.head_size).permute(0, 2, 1, 3)

        qk_t = (q @ k.transpose(-2, -1)) * self.scale
        qk_t = qk_t.masked_fill(self.mask[:, :, :t, :t] == 0, float("-inf"))
        qk_t = F.softmax(qk_t, dim=-1)
        weights = self.attn_dropout(qk_t)

        attention = weights @ v # batch x n_heads x t x head_size
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c) # batch x t x embed_dim

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

In [17]:
class GPT2CrossAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, "embedding dimension must be divisible by number of heads"
        self.head_size = self.embed_dim // self.n_heads
        self.seq_len = config.seq_len

        self.q = nn.Linear(self.embed_dim, self.embed_dim)
        self.k = nn.Linear(self.embed_dim, self.embed_dim)
        self.v = nn.Linear(self.embed_dim, self.embed_dim)
        self.scale = self.head_size ** -0.5

        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

        self.attn_dropout = nn.Dropout(config.attention_dropout)
        self.resid_dropout = nn.Dropout(config.residual_dropout)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, q, k, v):
        b, t, c = q.shape

        q = self.q(q)
        k = self.k(k)
        v = self.v(v)

        q = q.view(b, q.size(1), self.n_heads, self.head_size).permute(0, 2, 1, 3) # batch x n_heads x seq_len x head_dim
        k = k.view(b, k.size(1), self.n_heads, self.head_size).permute(0, 2, 1, 3)
        v = v.view(b, v.size(1), self.n_heads, self.head_size).permute(0, 2, 1, 3)

        qk_t = (q @ k.transpose(-2, -1)) * self.scale
        qk_t = F.softmax(qk_t, dim=-1)
        weights = self.attn_dropout(qk_t)

        attention = weights @ v # batch x n_heads x t x head_size
        attention = attention.permute(0, 2, 1, 3).contiguous().view(b, t, c) # batch x t x embed_dim

        out = self.c_proj(attention)
        out = self.resid_dropout(out)

        return out

In [18]:
class GPT2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.mlp_ratio = config.mlp_ratio
        self.mlp_dropout = config.mlp_dropout

        self.c_fc = nn.Linear(self.embed_dim,self.embed_dim * self.mlp_ratio)
        self.c_proj = nn.Linear(self.embed_dim * self.mlp_ratio,self.embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(self.mlp_dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

In [19]:
class GPT2Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.ln_1 = nn.LayerNorm(self.embed_dim)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(self.embed_dim)
        self.mlp = GPT2MLP(config)
        self.ln_3 = nn.LayerNorm(self.embed_dim)
        self.cross_attn_1 = GPT2CrossAttention(config)
        self.ln_4 = nn.LayerNorm(self.embed_dim)
        self.cross_attn_2 = GPT2CrossAttention(config)

    def forward(self, x, enc_out, keywords):
        x = x + self.attn(self.ln_1(x))
        x = x + self.cross_attn_1(self.ln_2(x), enc_out, enc_out)
        x = x + self.cross_attn_2(self.ln_3(x), keywords, keywords)
        x = x + self.mlp(self.ln_4(x))
        return x

In [20]:
class VisionGPT2Model(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        vit = create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
        self.patch_embed = vit.patch_embed
        num_patches = self.patch_embed.num_patches

        self.cls_token = vit.cls_token
        embed_len = num_patches + vit.num_prefix_tokens
        self.pos_embed = vit.pos_embed
        self.pos_drop = nn.Dropout(p=0.)

        self.blocks = nn.ModuleList([vit.blocks[i] for i in range(config.depth)])

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.embed_dim),
            wpe = nn.Embedding(config.seq_len, config.embed_dim),
            drop = nn.Dropout(config.emb_dropout),
            h = nn.ModuleList([GPT2Block(config) for _ in range(config.depth)]),
            ln_f = nn.LayerNorm(config.embed_dim)
        ))

        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight

    def _pos_embed(self,x):
        pos_embed = self.pos_embed
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + pos_embed
        return self.pos_drop(x)

    def pretrained_layers_trainable(self, trainable=False):
        layers = [
            self.cls_token, self.patch_embed, self.pos_embed, self.blocks,
            self.transformer.wte, self.transformer.wpe,
            self.transformer.ln_f, self.lm_head
        ]

        gpt_layers = [[
            self.transformer.h[i].ln_1, self.transformer.h[i].ln_2,
            self.transformer.h[i].attn, self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]

        for l in gpt_layers:
            layers.extend(l)

        for layer in layers:
            if not isinstance(layer, nn.Parameter):
                for p in layer.parameters():
                    p.requires_grad = trainable
            else:
                layer.requires_grad = trainable

        total_frozen_params = sum([p.numel() for p in self.parameters() if not p.requires_grad])
        print(f'{total_frozen_params = }')

    def unfreeze_gpt_layers(self,):
        gpt_layers = [[
            self.transformer.h[i].ln_1, self.transformer.h[i].ln_2,
            self.transformer.h[i].attn, self.transformer.h[i].mlp
        ] for i in range(self.config.depth)]

        flatten = []
        for l in gpt_layers:
            flatten.extend(l)

        for layer in flatten:
            if not isinstance(layer,nn.Parameter):
                for p in layer.parameters():
                    p.requires_grad = True
            else:
                layer.requires_grad = True

    @classmethod
    def from_pretrained(self, config):
        model = VisionGPT2Model(config)
        sd = model.state_dict()
        keys = sd.keys()
        ignore_matches = ["blocks.", "cross_attn.", "ln_3", "cls_token", "pos_embed", "patch_embed.", ".attn.mask"]
        vit_keys = [key for key in keys if any(match in key for match in ignore_matches)]
        gpt_keys = [key for key in keys if key not in vit_keys]

        gpt2_small = GPT2LMHeadModel.from_pretrained("gpt2")
        sd_hf = gpt2_small.state_dict()
        hf_keys = sd_hf.keys()
        hf_keys = [k for k in hf_keys if not k.endswith(".attn.masked_bias")]
        hf_keys = [k for k in hf_keys if not k.endswith(".attn.bias")]
        transposed = ["attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight"]

        for k in hf_keys:
            if any(match in k for match in ignore_matches):
                continue

            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        model.load_state_dict(sd)

        return model

    def forward(self, image, input_ids, keywords, labels=None):

        image = self.patch_embed(image)
        image = self._pos_embed(image)

        token_embeddings = self.transformer.wte(input_ids) # batch x seq_len
        pos_embs = torch.arange(0, input_ids.size(1)).to(input_ids.device)
        positional_embeddings = self.transformer.wpe(pos_embs)
        input_ids = self.transformer.drop(token_embeddings + positional_embeddings)

        keywords = self.transformer.wte(keywords).squeeze(1)  # batch x seq_len x embed_dim

        for i in range(self.config.depth):
            image = self.blocks[i](image)
            input_ids = self.transformer.h[i](input_ids, image, keywords)

        input_ids = self.transformer.ln_f(input_ids)

        if labels is not None:
            lm_logits = self.lm_head(input_ids)
            loss = F.cross_entropy(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
            return loss

        lm_logits = self.lm_head(input_ids[:, [-1], :])
        return lm_logits

    def generate(self, image, sequence, keywords, max_tokens=50, temperature=1.0, deterministic=False):
        for _ in range(max_tokens):
            out = self(image, sequence, keywords)
            out = out[:, -1, :] / temperature
            probs = F.softmax(out, dim=-1)
            if deterministic:
                next_token = torch.argmax(probs, dim=-1, keepdim=True)
            else:
                next_token = torch.multinomial(probs, num_samples=1)
            sequence = torch.cat([sequence,next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

        return sequence.cpu().flatten()

In [21]:
from torch import GradScaler, autocast
from tqdm.auto import tqdm
import gc

class Trainer:
    def __init__(self, model_config, train_config, dls):
        self.train_config = train_config
        self.model_config = model_config
        self.device = self.train_config.device

        self.model = VisionGPT2Model.from_pretrained(model_config).to(self.device)
        self.model.pretrained_layers_trainable(trainable=False)

        print(f"Trainable parameters: {sum([p.numel() for p in self.model.parameters() if p.requires_grad])}")

        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.scaler = GradScaler(self.device)

        self.train_dl, self.val_dl = dls

        total_steps = len(self.train_dl)

        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.train_config.lr / 25.)
        self.sched = torch.optim.lr_scheduler.OneCycleLR(
            self.optim,
            max_lr=self.train_config.lr,
            epochs=self.train_config.epochs,
            steps_per_epoch=total_steps
        )

        self.metrics = pandas.DataFrame()
        self.metrics[["train_loss", "train_perplexity", "val_loss", "val_perplexity"]] = None

        self.gen_tfms = A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ToTensorV2()
        ])

    def save_model(self,):
        self.train_config.model_path.mkdir(exist_ok=True)
        sd = self.model.state_dict()
        torch.save(self.model, self.train_config.model_path/"captioner.pt")

    def load_best_model(self,):
        sd = torch.load(self.train_config.model_path/"captioner.pt", weights_only=False)
        self.model.load_state_dict(sd)

    def train_one_epoch(self, epoch):
        prog = tqdm(self.train_dl, total=len(self.train_dl))
        running_loss = 0.

        for image, input_ids, keywords, labels in prog:
            with autocast(self.device):
                image = image.to(self.device)
                input_ids = input_ids.to(self.device)
                keywords = keywords.to(self.device)
                labels = labels.to(self.device)

                loss = self.model(image, input_ids, keywords, labels)

                self.scaler.scale(loss).backward()
                self.scaler.step(self.optim)
                self.scaler.update()
                self.sched.step()
                self.optim.zero_grad(set_to_none=True)

                running_loss += loss.item()

                prog.set_description(f'train loss: {loss.item():.3f}')

            del image, input_ids, keywords, labels, loss

        print()

        train_loss = running_loss / len(self.train_dl)
        train_pxp = numpy.exp(train_loss)

        self.metrics.loc[epoch, ["train_loss", "train_perplexity"]] = (train_loss, train_pxp)

    @torch.no_grad()
    def valid_one_epoch(self, epoch):

        prog = tqdm(self.val_dl, total=len(self.val_dl))

        running_loss = 0.

        for image, input_ids, keywords, labels in prog:

            with autocast(self.device):
                image = image.to(self.device)
                input_ids = input_ids.to(self.device)
                keywords = keywords.to(self.device)
                labels = labels.to(self.device)

                loss = self.model(image, input_ids, keywords, labels)
                running_loss += loss.item()

                prog.set_description(f"Valid loss: {loss.item():.3f}")

            del image, input_ids, keywords, labels, loss

        print()

        val_loss = running_loss / len(self.val_dl)
        val_pxp = numpy.exp(val_loss)

        self.metrics.loc[epoch, ["val_loss", "val_perplexity"]] = (val_loss,val_pxp)

        return val_pxp

    def clean(self):
        gc.collect()
        torch.cuda.empty_cache()

    def fit(self,):
        best_pxp = 1e9
        best_epoch = -1
        prog = tqdm(range(self.train_config.epochs))

        for epoch in prog:
            if epoch == self.train_config.freeze_epochs_gpt:
                self.model.unfreeze_gpt_layers()
                print("Unfreezing GPT2 entirely...")

            if epoch == self.train_config.freeze_epochs_all:
                self.model.pretrained_layers_trainable(trainable=True)

            self.model.train()
            prog.set_description("Training")
            self.train_one_epoch(epoch)
            self.clean()

            self.model.eval()
            prog.set_description("Validating")
            pxp = self.valid_one_epoch(epoch)
            self.clean()

            print(self.metrics.tail(1))

            if pxp < best_pxp:
                best_pxp = pxp
                best_epoch = epoch
                print("Saving best model...")
                self.save_model()

        self.metrics.to_csv(self.train_config.model_path/"metrics.csv", index=False)

        return {
            "best_perplexity": best_pxp,
            "best_epoch": best_epoch
        }

    @torch.no_grad()
    def generate_caption(self, image, max_tokens=50, temperature=1.0, deterministic=False):

        self.model.eval()

        image = Image.open(image).convert("RGB")
        image = numpy.array(image)
        image = self.gen_tfms(image=image)["image"]
        image = image.unsqueeze(0).to(self.device)
        sequence = torch.ones(1, 1).to(device=self.device).long() * self.tokenizer.bos_token_id

        caption = self.model.generate(
            image,
            sequence,
            max_tokens=max_tokens,
            temperature=temperature,
            deterministic=deterministic
        )
        caption = self.tokenizer.decode(caption.numpy(),skip_special_tokens=True)

        return caption

In [22]:
from types import SimpleNamespace
from pathlib import Path

model_config = SimpleNamespace(
    vocab_size = 50_257,
    embed_dim = 768, # 768
    num_heads = 12,
    seq_len = 1024,
    depth = 12,
    attention_dropout = 0.1,
    residual_dropout = 0.1,
    mlp_ratio = 4,
    mlp_dropout = 0.1,
    emb_dropout = 0.1,
)

train_config = SimpleNamespace(
    epochs = 20,
    freeze_epochs_gpt = 1,
    freeze_epochs_all = 2,
    lr = 1e-4,
    device = 'cuda',
    model_path = Path('/content/drive/MyDrive/models/multi_modal_singleimage'),
    batch_size = 32
)

In [23]:
train_dataset = ImageReportWithKeywordsDataset(train_df, images_folder, tokenizer, train_tfms)
valid_dataset = ImageReportWithKeywordsDataset(valid_df, images_folder, tokenizer, valid_tfms)
test_dataset = ImageReportWithKeywordsDataset(test_df, images_folder, tokenizer, valid_tfms)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_config.batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=2,
    persistent_workers=True,
    collate_fn=collate_fn
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=train_config.batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=2,
    persistent_workers=True,
    collate_fn=collate_fn
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=train_config.batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=2,
    persistent_workers=True,
    collate_fn=collate_fn
)

In [24]:
trainer = Trainer(model_config, train_config, (train_dataloader, valid_dataloader))
trainer.fit()

total_frozen_params = 210236928
Trainable parameters: 56733696


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/96 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.





  0%|          | 0/12 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



  train_loss train_perplexity  val_loss val_perplexity
0  10.574119      39109.42996  9.868452   19311.424671
Saving best model...
Unfreezing GPT2 entirely...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
1   7.478172      1769.004359  5.680498     293.095414
Saving best model...
total_frozen_params = 0


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
2   4.737868       114.190489  3.245229      25.667595
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
3   2.744952        15.563873  2.229294       9.293305
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
4    2.07475         7.962555  1.881847       6.565623
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
5   1.765016         5.841669  1.665977        5.29084
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
6   1.563225         4.774193  1.553606       4.728489
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
7   1.415324         4.117822  1.465965       4.331723
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
8   1.312149         3.714148  1.403186       4.068139
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


  train_loss train_perplexity  val_loss val_perplexity
9   1.234064         3.435162  1.358592       3.890711
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity  val_loss val_perplexity
10   1.162972         3.199427  1.329612       3.779577
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity val_loss val_perplexity
11   1.102239           3.0109  1.30485       3.687135
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity val_loss val_perplexity
12   1.052183         2.863897  1.30486       3.687171


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity  val_loss val_perplexity
13   1.010568         2.747161  1.279318       3.594188
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity val_loss val_perplexity
14   0.972637          2.64491  1.29166       3.638823


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity  val_loss val_perplexity
15   0.944458          2.57142  1.274871       3.578241
Saving best model...


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity  val_loss val_perplexity
16   0.920059         2.509437  1.278804       3.592342


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity  val_loss val_perplexity
17   0.904023         2.469518  1.279219       3.593833


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity val_loss val_perplexity
18   0.894609         2.446379  1.27848       3.591178


  0%|          | 0/96 [00:00<?, ?it/s]




  0%|          | 0/12 [00:00<?, ?it/s]


   train_loss train_perplexity  val_loss val_perplexity
19   0.892395         2.440968  1.280034       3.596761


{'best_perplexity': np.float64(3.578240912985376), 'best_epoch': 15}

In [26]:
trainer.metrics

Unnamed: 0,train_loss,train_perplexity,val_loss,val_perplexity
0,10.574119,39109.42996,9.868452,19311.424671
1,7.478172,1769.004359,5.680498,293.095414
2,4.737868,114.190489,3.245229,25.667595
3,2.744952,15.563873,2.229294,9.293305
4,2.07475,7.962555,1.881847,6.565623
5,1.765016,5.841669,1.665977,5.29084
6,1.563225,4.774193,1.553606,4.728489
7,1.415324,4.117822,1.465965,4.331723
8,1.312149,3.714148,1.403186,4.068139
9,1.234064,3.435162,1.358592,3.890711


In [27]:
best_model = torch.load("/content/drive/MyDrive/models/multi_modal_singleimage/captioner.pt", weights_only=False)
best_model.to(device)

VisionGPT2Model(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
        (fc2): Linear(in_features=3072, out_

In [28]:
xrv_keywords_extractor = XRayVisionKeywordsExtractor()

Downloading weights...
If this fails you can run `wget https://github.com/mlmed/torchxrayvision/releases/download/v1/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt -O /root/.torchxrayvision/models_data/nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt`
[██████████████████████████████████████████████████]


In [42]:
res = xrv_keywords_extractor.extract([os.path.join(images_folder, x) for x in test_df.loc[100, "filename"]])
# xrv_keywords_extractor.extract(test_df["filename"].head(5))

In [49]:
numpy.unique(numpy.concat(res))

array(['Atelectasis', 'Consolidation', 'Effusion', 'Emphysema',
       'Enlarged Cardiomediastinum', 'Fibrosis', 'Infiltration',
       'Lung Lesion', 'Lung Opacity', 'Mass', 'Nodule',
       'Pleural_Thickening', 'Pneumothorax'], dtype=object)

In [58]:

def generate_caption(model, img_paths, max_tokens=200, temperature=1.0, deterministic=False):
    # model.eval()
    gen_tfms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2()
    ])

    image = concat_images(img_paths)
    image = numpy.array(image)
    image = gen_tfms(image=image)['image']
    image = image.unsqueeze(0).to(device)  # Move the input image tensor to the same device as the model
    sequence = torch.ones(1, 1).long().to(device) * tokenizer.bos_token_id

    keywords = "keywords: " + ", ".join(numpy.unique(numpy.concat(xrv_keywords_extractor.extract(img_paths))))
    keywords = tokenizer(keywords, truncation=True, padding="max_length", max_length=50, return_tensors="pt")["input_ids"]
    keywords = keywords.to(device)  # Move the keywords tensor to the same device as the model

    caption = model.generate(
        image,
        sequence,
        keywords,
        max_tokens=max_tokens,
        temperature=temperature,
        deterministic=deterministic
    )
    caption = tokenizer.decode(caption.cpu().numpy(), skip_special_tokens=True)  # Move the generated caption back to CPU for decoding

    return caption

img_paths = [os.path.join(images_folder, x) for x in test_df.loc[100, "filename"]]
generated_report = generate_caption(best_model, img_paths)
print(test_df.loc[100, "report"])
print("------------")
print(generated_report)

Clear lungs bilaterally. No pneumothorax or pleural effusion. Normal cardiac contours
------------
 comparison XXXX sternotomy XXXX and bypass graft markers. Severe lumbar degenerative disc disease in the right base. Soft tissue tissues and bony structures are unremarkable.


In [59]:
import torch
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import nltk

nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
# Function to calculate BLEU score
def compute_bleu(reference_texts, generated_texts):
    """
    Compute BLEU score between generated texts and references.

    :param reference_texts: List of lists of reference texts (for each generated report)
    :param generated_texts: List of generated reports
    :return: BLEU score
    """
    references = [[ref.split()] for ref in reference_texts]  # List of list of reference tokens
    candidates = [gen.split() for gen in generated_texts]   # List of list of generated tokens
    bleu_score = corpus_bleu(references, candidates)
    return bleu_score

# Function to calculate ROUGE score
def compute_rouge(reference_texts, generated_texts):
    """
    Compute ROUGE score between generated texts and references.

    :param reference_texts: List of reference reports
    :param generated_texts: List of generated reports
    :return: ROUGE score
    """
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    rouge_scores = {"rouge1": [], "rouge2": [], "rougeL": []}

    for reference, generated in zip(reference_texts, generated_texts):
        scores = scorer.score(reference, generated)
        for key in rouge_scores:
            rouge_scores[key].append(scores[key].fmeasure)

    avg_rouge1 = sum(rouge_scores["rouge1"]) / len(rouge_scores["rouge1"])
    avg_rouge2 = sum(rouge_scores["rouge2"]) / len(rouge_scores["rouge2"])
    avg_rougeL = sum(rouge_scores["rougeL"]) / len(rouge_scores["rougeL"])

    return avg_rouge1, avg_rouge2, avg_rougeL

# Evaluation function
def evaluate_model(model, folder_path, eval_set):
    generated_reports = []
    reference_reports = []

    print(f"Starting evaluation for model using {len(eval_set)} items")
    for idx in range(len(eval_set)):
        # Generate report for each image
        generated_report = generate_caption(model, os.path.join(folder_path, eval_set.loc[idx, "filename"]))
        reference_report = eval_set.loc[idx, "report"]

        generated_reports.append(generated_report)
        reference_reports.append(reference_report)

        print(f"Generated report {idx + 1}/{len(eval_set)}    \r", end="")

    print()

    # Compute BLEU
    bleu_score = compute_bleu(reference_reports, generated_reports)
    print(f"BLEU Score: {bleu_score:.4f}")

    # Compute ROUGE
    rouge1, rouge2, rougeL = compute_rouge(reference_reports, generated_reports)
    print(f"ROUGE-1: {rouge1:.4f}, ROUGE-2: {rouge2:.4f}, ROUGE-L: {rougeL:.4f}")

In [None]:
# Evaluate the model
evaluate_model(best_model, images_folder, test_df)

Starting evaluation for model using 747 items
Generated report 747/747    
BLEU Score: 0.0458
ROUGE-1: 0.2948, ROUGE-2: 0.0945, ROUGE-L: 0.2019


In [60]:
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from collections import defaultdict

def compute_bleu_scores(reference_texts, generated_texts):
    references = [[ref.split()] for ref in reference_texts]
    candidates = [gen.split() for gen in generated_texts]
    smoothie = SmoothingFunction().method4
    bleu1 = corpus_bleu(references, candidates, weights=(1, 0, 0, 0), smoothing_function=smoothie)
    bleu2 = corpus_bleu(references, candidates, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothie)
    bleu3 = corpus_bleu(references, candidates, weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothie)
    bleu4 = corpus_bleu(references, candidates, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothie)
    return bleu1, bleu2, bleu3, bleu4

def compute_rouge_l(reference_texts, generated_texts):
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    scores = [scorer.score(ref, gen)["rougeL"].fmeasure for ref, gen in zip(reference_texts, generated_texts)]
    return numpy.mean(scores)

def compute_cider(reference_texts, generated_texts):
    # Simple CIDEr implementation for demonstration (not official)
    # For real evaluation, use pycocoevalcap or similar library
    def ngram_counts(text, n):
        tokens = text.split()
        return defaultdict(int, {tuple(tokens[i:i+n]): 1 for i in range(len(tokens)-n+1)})

    cider_scores = []
    for ref, gen in zip(reference_texts, generated_texts):
        score = 0
        for n in range(1, 5):
            ref_ngrams = ngram_counts(ref, n)
            gen_ngrams = ngram_counts(gen, n)
            overlap = sum(min(gen_ngrams[ng], ref_ngrams[ng]) for ng in gen_ngrams)
            total = max(len(gen.split())-n+1, 1)
            score += overlap / total
        cider_scores.append(score / 4)
    return numpy.mean(cider_scores)

# Generate reports for test set
generated_reports = []
reference_reports = []
for idx in range(len(test_df)):
    img_paths = [os.path.join(images_folder, x) for x in test_df.loc[idx, "filename"]]
    generated_report = generate_caption(best_model, img_paths)
    reference_report = test_df.loc[idx, "report"]
    generated_reports.append(generated_report)
    reference_reports.append(reference_report)
    if idx % 50 == 0:
        print(f"Processed {idx+1}/{len(test_df)}")

bleu1, bleu2, bleu3, bleu4 = compute_bleu_scores(reference_reports, generated_reports)
rouge_l = compute_rouge_l(reference_reports, generated_reports)
cider = compute_cider(reference_reports, generated_reports)

print(f"BLEU-1: {bleu1:.4f}")
print(f"BLEU-2: {bleu2:.4f}")
print(f"BLEU-3: {bleu3:.4f}")
print(f"BLEU-4: {bleu4:.4f}")
print(f"ROUGE-L: {rouge_l:.4f}")
print(f"CIDEr: {cider:.4f}")

Processed 1/383
Processed 51/383
Processed 101/383
Processed 151/383
Processed 201/383
Processed 251/383
Processed 301/383
Processed 351/383
BLEU-1: 0.2187
BLEU-2: 0.1111
BLEU-3: 0.0676
BLEU-4: 0.0422
ROUGE-L: 0.2198
CIDEr: 0.0755
