In [2]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("raddar/chest-xrays-indiana-university")

print("Path to dataset files:", path)

Path to dataset files: C:\Users\ANNU-10\.cache\kagglehub\datasets\raddar\chest-xrays-indiana-university\versions\2


In [3]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(torch.__version__)

Using device: cuda
2.7.1+cu128


In [4]:
import pandas

dataset_folder = path
images_folder = dataset_folder + "/images/images_normalized"
projections = pandas.read_csv(dataset_folder + "/indiana_projections.csv")
reports = pandas.read_csv(dataset_folder + "/indiana_reports.csv")

combined_dataset = projections.merge(reports, on="uid", how="inner")

def IsNotAvailable(value):
    return value.str.contains("unavailable", case=False, na=False) \
        | value.str.contains("not available", case=False, na=False) \
        | value.str.contains("none", case=False, na=False)

combined_dataset.loc[IsNotAvailable(combined_dataset["comparison"]), "comparison"] = "None"

combined_dataset["indication"] = combined_dataset["indication"].fillna("None")
combined_dataset["findings"] = combined_dataset["findings"].fillna("None")
combined_dataset["impression"] = combined_dataset["impression"].fillna("None")
combined_dataset["comparison"] = combined_dataset["comparison"].fillna("None")
combined_dataset["report"] = (
    "Indication: " + combined_dataset["indication"].astype(str) + "\n"
    + "Findings: " + combined_dataset["findings"].astype(str) + "\n"
    + "Impression: " + combined_dataset["impression"].astype(str) + "\n"
    + "Comparison: " + combined_dataset["comparison"].astype(str)
)

combined_dataset.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7466 entries, 0 to 7465
Data columns (total 11 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   uid         7466 non-null   int64 
 1   filename    7466 non-null   object
 2   projection  7466 non-null   object
 3   MeSH        7466 non-null   object
 4   Problems    7466 non-null   object
 5   image       7466 non-null   object
 6   indication  7466 non-null   object
 7   comparison  7466 non-null   object
 8   findings    7466 non-null   object
 9   impression  7466 non-null   object
 10  report      7466 non-null   object
dtypes: int64(1), object(10)
memory usage: 641.7+ KB


In [5]:
combined_dataset.head()

Unnamed: 0,uid,filename,projection,MeSH,Problems,image,indication,comparison,findings,impression,report
0,1,1_IM-0001-4001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,Indication: Positive TB test\nFindings: The ca...
1,1,1_IM-0001-3001.dcm.png,Lateral,normal,normal,Xray Chest PA and Lateral,Positive TB test,,The cardiac silhouette and mediastinum size ar...,Normal chest x-XXXX.,Indication: Positive TB test\nFindings: The ca...
2,2,2_IM-0652-1001.dcm.png,Frontal,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Indication: Preop bariatric surgery.\nFindings...
3,2,2_IM-0652-2001.dcm.png,Lateral,Cardiomegaly/borderline;Pulmonary Artery/enlarged,Cardiomegaly;Pulmonary Artery,"Chest, 2 views, frontal and lateral",Preop bariatric surgery.,,Borderline cardiomegaly. Midline sternotomy XX...,No acute pulmonary findings.,Indication: Preop bariatric surgery.\nFindings...
4,3,3_IM-1384-1001.dcm.png,Frontal,normal,normal,Xray Chest PA and Lateral,"rib pain after a XXXX, XXXX XXXX steps this XX...",,,"No displaced rib fractures, pneumothorax, or p...","Indication: rib pain after a XXXX, XXXX XXXX s..."


In [6]:
for r in combined_dataset["report"].head(5).to_list():
    print(r)
    print("-----")

Indication: Positive TB test
Findings: The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
Impression: Normal chest x-XXXX.
Comparison: None
-----
Indication: Positive TB test
Findings: The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.
Impression: Normal chest x-XXXX.
Comparison: None
-----
Indication: Preop bariatric surgery.
Findings: Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.
Impression: No acute pulmonary findings.
Comparison: None
-----
Indication: Preop bariatric surgery.
Findings: Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XX

In [7]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(combined_dataset, test_size=0.2, random_state=42, shuffle=True)
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
print(f"Train shape: {train_df.shape}, Test shape: {test_df.shape}")

Train shape: (5972, 11), Test shape: (1494, 11)


In [11]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pandas as pd
from torchvision import transforms

class MIMICDataset(Dataset):
    def __init__(self, dataset, img_dir, tokenizer, transform=None):
        self.data = dataset
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.data.loc[idx, "filename"])
        image = Image.open(img_name).convert('RGB')
        if self.transform:
            image = self.transform(image)
        report = self.data.loc[idx, "report"]
        inputs = self.tokenizer(report, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
        return image, inputs['input_ids'].squeeze(), inputs['attention_mask'].squeeze()

class ResNetGPT2ReportGenerator(nn.Module):
    def __init__(self, gpt2_model_name='gpt2', resnet_model_name='resnet50', device='cuda'):
        super(ResNetGPT2ReportGenerator, self).__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
        resnet = getattr(models, resnet_model_name)(pretrained=True)
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.resnet_fc = nn.Linear(resnet.fc.in_features, self.gpt2.config.n_embd)
        self.device = device
        self.to(device)

    def forward(self, images, input_ids, attention_mask):
        with torch.no_grad():
            image_features = self.resnet(images).view(images.size(0), -1)
        image_features = self.resnet_fc(image_features)
        outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask, return_dict=True, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]
        image_features = image_features.unsqueeze(1).repeat(1, hidden_states.size(1), 1)
        # combined_features = torch.cat((hidden_states, image_features), dim=-1)
        combined_features = hidden_states + image_features
        logits = self.gpt2.lm_head(combined_features)
        return logits

    def generate_report(self, image, max_length=100):
        image = image.unsqueeze(0).to(self.device)
        input_ids = torch.tensor(self.tokenizer.encode("Indication:")).unsqueeze(0).to(self.device)
        attention_mask = torch.ones(input_ids.shape, device=self.device)
        self.eval()
        with torch.no_grad():
            for _ in range(max_length):
                logits = self(image, input_ids, attention_mask)
                next_token_id = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
                input_ids = torch.cat((input_ids, next_token_id), dim=1)
                attention_mask = torch.cat((attention_mask, torch.ones_like(next_token_id)), dim=1)
                if next_token_id.item() == self.tokenizer.encode("end")[0]:
                    break
        generated_report = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_report


In [12]:
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

def train(model, train_loader, val_loader, epochs=3, lr=5e-5):
    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        for images, input_ids, attention_mask in train_loader:
            images = images.to(model.device)
            input_ids = input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)
            optimizer.zero_grad()
            logits = model(images, input_ids, attention_mask)
            loss = criterion(logits.view(-1, logits.size(-1)), input_ids.view(-1))
            loss.backward()
            optimizer.step()
            scheduler.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

        evaluate(model, val_loader)

def evaluate(model, val_loader):
    model.eval()
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for images, input_ids, attention_mask in val_loader:
            images = images.to(model.device)
            input_ids = input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)
            logits = model(images, input_ids, attention_mask)
            loss = criterion(logits.view(-1, logits.size(-1)), input_ids.view(-1))
            total_loss += loss.item()
    print(f"Validation Loss: {total_loss / len(val_loader)}")


In [None]:
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer
from torchvision import transforms

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = MIMICDataset(train_df, img_dir=images_folder, tokenizer=tokenizer, transform=transform)
val_dataset = MIMICDataset(test_df, img_dir=images_folder, tokenizer=tokenizer, transform=transform)

batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = ResNetGPT2ReportGenerator(device=device)

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
train(model, train_loader, val_loader, epochs=3, lr=5e-5)


Epoch 1/3, Loss: 6.477251008618623e-05
