In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("raddar/chest-xrays-indiana-university")

print("Path to dataset files:", path)

  from .autonotebook import tqdm as notebook_tqdm


Path to dataset files: C:\Users\ANNU-10\.cache\kagglehub\datasets\raddar\chest-xrays-indiana-university\versions\2


In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(torch.__version__)

Using device: cuda
2.7.1+cu128


In [3]:
import pandas

dataset_folder = path
images_folder = dataset_folder + "/images/images_normalized"
projections = pandas.read_csv(dataset_folder + "/indiana_projections.csv")
reports = pandas.read_csv(dataset_folder + "/indiana_reports.csv")

combined_dataset = projections.merge(reports, on="uid", how="inner")

def IsNotAvailable(value):
    return value.str.contains("unavailable", case=False, na=False) \
        | value.str.contains("not available", case=False, na=False) \
        | value.str.contains("none", case=False, na=False)

combined_dataset.loc[IsNotAvailable(combined_dataset["comparison"]), "comparison"] = "None"

combined_dataset["indication"] = combined_dataset["indication"].fillna("None")
combined_dataset["findings"] = combined_dataset["findings"].fillna("None")
combined_dataset["impression"] = combined_dataset["impression"].fillna("None")
combined_dataset["comparison"] = combined_dataset["comparison"].fillna("None")
combined_dataset["report"] = (
    "Indication: " + combined_dataset["indication"].astype(str) + "\n"
    + "Findings: " + combined_dataset["findings"].astype(str) + "\n"
    + "Impression: " + combined_dataset["impression"].astype(str) + "\n"
    + "Comparison: " + combined_dataset["comparison"].astype(str)
)

combined_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7466 entries, 0 to 7465
Data columns (total 11 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   uid         7466 non-null   int64 
 1   filename    7466 non-null   object
 2   projection  7466 non-null   object
 3   MeSH        7466 non-null   object
 4   Problems    7466 non-null   object
 5   image       7466 non-null   object
 6   indication  7466 non-null   object
 7   comparison  7466 non-null   object
 8   findings    7466 non-null   object
 9   impression  7466 non-null   object
 10  report      7466 non-null   object
dtypes: int64(1), object(10)
memory usage: 641.7+ KB


In [4]:
combined_dataset.head()

Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,report
0,1,1_IM-0001-4001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,Indication: Positive TB test\nFindings: The ca...
1,1,1_IM-0001-3001.dcm.png,Lateral,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,Indication: Positive TB test\nFindings: The ca...
2,2,2_IM-0652-1001.dcm.png,Frontal,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Indication: Preop bariatric surgery.\nFindings...
3,2,2_IM-0652-2001.dcm.png,Lateral,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Indication: Preop bariatric surgery.\nFindings...
4,3,3_IM-1384-1001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,"rib pain after a XXXX, XXXX XXXX steps this XX...",,,"No displaced rib fractures, pneumothorax, or p...","Indication: rib pain after a XXXX, XXXX XXXX s..."


In [5]:
for r in combined_dataset["report"].head(5).to_list():
    print(r)
    print("-----")

Indication: Positive TB test
Findings: The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
Impression: Normal chest x-XXXX.
Comparison: None
-----
Indication: Positive TB test
Findings: The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
Impression: Normal chest x-XXXX.
Comparison: None
-----
Indication: Preop bariatric surgery.
Findings: Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.
Impression: No acute pulmonary findings.
Comparison: None
-----
Indication: Preop bariatric surgery.
Findings: Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XX

In [6]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(combined_dataset, test_size=0.2, random_state=42, shuffle=True)
print(f"Train shape: {train_df.shape}, Test shape: {test_df.shape}")

Train shape: (5972, 11), Test shape: (1494, 11)


In [26]:
import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2LMHeadModel, GPT2Tokenizer, ViTModel
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

# GPT-2 with Cross-Attention
class GPT2WithCrossAttention(nn.Module):
    def __init__(self, gpt2_model, vit_model, gpt2_hidden_size, vit_hidden_size):
        super(GPT2WithCrossAttention, self).__init__()
        
        # Load pre-trained GPT2 model
        self.gpt2 = gpt2_model
        
        # Load pre-trained Vision Transformer model
        self.vit = vit_model
        
        # Cross-attention layer (between ViT and GPT2)
        self.cross_attention = nn.MultiheadAttention(embed_dim=gpt2_hidden_size, num_heads=8, batch_first=True)
        
        # Linear layer to project ViT features into GPT2 hidden size
        self.image_projection = nn.Linear(vit_hidden_size, gpt2_hidden_size)

    def forward(self, input_ids, attention_mask, images):
        # Get the image embeddings (from ViT)
        image_features = self.vit(images).last_hidden_state  # Shape: (batch_size, num_patches, vit_hidden_size)
        
        # Project image features to match GPT-2 hidden size
        image_features = image_features.mean(dim=1)  # Shape: (batch_size, vit_hidden_size)
        image_features = self.image_projection(image_features)  # Shape: (batch_size, gpt2_hidden_size)

        # Get GPT2 embeddings (language embeddings)
        gpt2_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, output_hidden_states=True)
        gpt2_hidden_states = gpt2_outputs.hidden_states[-1]  # Shape: (batch_size, seq_len, gpt2_hidden_size)

        # Cross-attention: GPT-2 attending to ViT embeddings
        # We concatenate the image features to the sequence of tokens in GPT-2
        image_features = image_features.unsqueeze(1).repeat(1, gpt2_hidden_states.size(1), 1)  # Broadcast image features
        
        # Apply multi-head attention (cross-attention) between GPT-2 hidden states and image features
        attn_output, attn_weights = self.cross_attention(query=gpt2_hidden_states, key=image_features, value=image_features)
        
        # Combine the attention output with GPT-2 hidden states
        combined_output = attn_output + gpt2_hidden_states  # Residual connection

        logits = self.gpt2.lm_head(combined_output)  # Project to vocabulary size
        
        return logits

# Define dataset
class XRayReportDataset(Dataset):
    def __init__(self, data, transform=None, tokenizer=None, max_length=512):
        self.data = data
        self.images_folder = images_folder
        self.image_paths = data["filename"].apply(lambda x: os.path.join(images_folder, x)).tolist()
        self.reports = data["report"].tolist()
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        report = self.reports[idx]
        
        # Load image
        image = Image.open(image_path).convert("RGB")
        
        # Apply transformations (if any)
        if self.transform:
            image = self.transform(image)
        
        # Tokenize text (report)
        inputs = self.tokenizer(report, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        
        return image, inputs.input_ids.squeeze(), inputs.attention_mask.squeeze()

# Load ViT and GPT-2 models
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token  # Set pad token to eos token for GPT-2

# Set models to eval mode
vit_model.eval()
gpt2_model.eval()

# Prepare dataset and dataloaders
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define the custom GPT-2 model with cross-attention
hidden_size = gpt2_model.config.hidden_size
image_feature_size = vit_model.config.hidden_size  # Image features from ViT are the same size as GPT-2 hidden states
gpt2_vit_with_cross_att_model = GPT2WithCrossAttention(gpt2_model, vit_model, hidden_size, image_feature_size).to(device)

print(f"GPT2-VIT-CrossAtten: GPT-2 hidden size: {hidden_size}, ViT hidden size: {image_feature_size}")

GPT2-VIT-CrossAtten: GPT-2 hidden size: 768, ViT hidden size: 768


In [39]:
import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer, ViTModel
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Hyperparameters
batch_size = 4
epochs = 5
learning_rate = 5e-5
max_length = 512  # GPT-2 max length for tokenized reports

train_dataset = XRayReportDataset(train_df, transform, gpt2_tokenizer, max_length)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Optimizer and scheduler
optimizer = torch.optim.AdamW(gpt2_vit_with_cross_att_model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Loss function
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(epochs):
    gpt2_vit_with_cross_att_model.train()
    running_loss = 0.0
    
    for batch_idx, (images, input_ids, attention_mask) in enumerate(tqdm(train_dataloader)):
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        optimizer.zero_grad()

        # Forward pass
        outputs = gpt2_vit_with_cross_att_model(input_ids=input_ids, attention_mask=attention_mask, images=images)
        logits = outputs  # Shape: (batch_size, seq_length, vocab_size)

        # Calculate loss (using the labels as the input_ids)
        labels = input_ids  # Labels are the same as the input_ids
        loss = criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        running_loss += loss.item()

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
    
    # Optionally, save model after each epoch
    torch.save(gpt2_vit_with_cross_att_model.state_dict(), f"training/gpt2_vit_cross_attention/epoch_{(epoch + 1):02d}.pth")

100%|██████████████████████████████████████████████████████████████████████████████| 1493/1493 [17:23<00:00,  1.43it/s]


Epoch [1/5], Loss: 0.0053


100%|██████████████████████████████████████████████████████████████████████████████| 1493/1493 [17:20<00:00,  1.43it/s]


Epoch [2/5], Loss: 0.0001


100%|██████████████████████████████████████████████████████████████████████████████| 1493/1493 [17:21<00:00,  1.43it/s]


Epoch [3/5], Loss: 0.0001


100%|██████████████████████████████████████████████████████████████████████████████| 1493/1493 [17:20<00:00,  1.43it/s]


Epoch [4/5], Loss: 0.0001


100%|██████████████████████████████████████████████████████████████████████████████| 1493/1493 [17:20<00:00,  1.43it/s]


Epoch [5/5], Loss: 0.0000


In [40]:
import torch
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
from rouge_score import rouge_scorer
import nltk

nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\ANNU-10\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


BLEU Score: 0.0000
ROUGE-1: 0.0142, ROUGE-2: 0.0000, ROUGE-L: 0.0131


In [49]:
# Function to generate a report from an image
def generate_report_from_image(model, image, tokenizer, device):
    model.eval()
    
    image = image.unsqueeze(0).to(device) # Adding a dummy batch dimmension
    input_ids = tokenizer.encode("Findings:", return_tensors="pt").to(device)
    attention_mask = torch.ones(input_ids.shape, device=device)
    
    # Generate a report
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask, images=image)
    
    generated_ids = torch.argmax(output, dim=-1)
    generated_report = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    return generated_report

In [None]:
# Function to calculate BLEU score
def compute_bleu(reference_texts, generated_texts):
    """
    Compute BLEU score between generated texts and references.
    
    :param reference_texts: List of lists of reference texts (for each generated report)
    :param generated_texts: List of generated reports
    :return: BLEU score
    """
    references = [[ref.split()] for ref in reference_texts]  # List of list of reference tokens
    candidates = [gen.split() for gen in generated_texts]   # List of list of generated tokens
    bleu_score = corpus_bleu(references, candidates)
    return bleu_score

# Function to calculate ROUGE score
def compute_rouge(reference_texts, generated_texts):
    """
    Compute ROUGE score between generated texts and references.
    
    :param reference_texts: List of reference reports
    :param generated_texts: List of generated reports
    :return: ROUGE score
    """
    scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
    rouge_scores = {"rouge1": [], "rouge2": [], "rougeL": []}
    
    for reference, generated in zip(reference_texts, generated_texts):
        scores = scorer.score(reference, generated)
        for key in rouge_scores:
            rouge_scores[key].append(scores[key].fmeasure)
    
    avg_rouge1 = sum(rouge_scores["rouge1"]) / len(rouge_scores["rouge1"])
    avg_rouge2 = sum(rouge_scores["rouge2"]) / len(rouge_scores["rouge2"])
    avg_rougeL = sum(rouge_scores["rougeL"]) / len(rouge_scores["rougeL"])
    
    return avg_rouge1, avg_rouge2, avg_rougeL

# Evaluation function
def evaluate_model(model, dataloader, tokenizer, device):
    generated_reports = []
    reference_reports = []
    
    # Collect ground truth reports and generated reports
    for images, input_ids, attention_mask in dataloader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        for i, image in enumerate(images):
            # Generate report for each image
            generated_report = generate_report_from_image(model, image, tokenizer, device)
            reference_report = tokenizer.decode(input_ids[i], skip_special_tokens=True)
            
            generated_reports.append(generated_report)
            reference_reports.append(reference_report)
    
    # Compute BLEU
    bleu_score = compute_bleu(reference_reports, generated_reports)
    print(f"BLEU Score: {bleu_score:.4f}")
    
    # Compute ROUGE
    rouge1, rouge2, rougeL = compute_rouge(reference_reports, generated_reports)
    print(f"ROUGE-1: {rouge1:.4f}, ROUGE-2: {rouge2:.4f}, ROUGE-L: {rougeL:.4f}")

# Define the dataloader for eevaluation
test_dataset = XRayReportDataset(test_df, transform, gpt2_tokenizer, max_length)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

gpt2_vit_with_cross_att_model.to(device)

# Evaluate the model
evaluate_model(gpt2_vit_with_cross_att_model, test_dataloader, gpt2_tokenizer, device)

In [51]:
index = 3
image_path = os.path.join(images_folder, combined_dataset.loc[index, "filename"])
image = transform(Image.open(image_path).convert("RGB"))
print("Actual Report:\n------------")
print(combined_dataset.loc[index, "report"])
print("------------")
print("Predicted Report:\n------------")
generate_report_from_image(gpt2_vit_with_cross_att_model, image, gpt2_tokenizer, device)
print("------------")

Actual Report:
------------
Indication: Preop bariatric surgery.
Findings: Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.
Impression: No acute pulmonary findings.
Comparison: None
------------
Predicted Report:
------------
------------
