In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, EncoderDecoderCache
from tqdm import tqdm  # Import progress bar

In [2]:
import nltk
# Fetch the 'punkt_tab' tokenizer data up front so that nltk.word_tokenize,
# which the evaluation cell uses to split reports into tokens, can run.
# NOTE(review): newer NLTK releases appear to load 'punkt_tab' where older ones
# loaded 'punkt' — confirm against the installed NLTK version.
nltk.download('punkt_tab')


[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\bagga\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [3]:
# --- Runtime configuration ---------------------------------------------------
# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Mini-batch size for the DataLoader and the target size given to transforms.Resize.
BATCH_SIZE, IMG_SIZE = 16, (224, 224)

# Hugging Face checkpoint used for both tokenizer and language model, and the
# file the trained weights are saved to / restored from.
MODEL_NAME = "t5-small"  # Change to a medical-specific model if available
MODEL_PATH = "xray_report_model.pth"

# Input locations: the report table and the directory holding the images.
data_path = "processed_xray_data.csv"
image_folder = "images/images_normalized"

# Load data
df = pd.read_csv(data_path)

# Fail fast with a readable message if the table lacks a column that
# XRayDataset.__getitem__ reads for every sample; otherwise the problem only
# shows up as a KeyError once the DataLoader starts fetching items.
_missing_columns = {"filename", "findings", "impression"} - set(df.columns)
if _missing_columns:
    raise KeyError(f"{data_path} is missing required column(s): {sorted(_missing_columns)}")

In [4]:
# Image preprocessing pipeline applied to every X-ray before it reaches the CNN.
# Per-channel mean / std handed to transforms.Normalize.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(IMG_SIZE),  # bring every image to the configured size
    transforms.ToTensor(),        # PIL image -> tensor
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
transform = transforms.Compose(_preprocessing_steps)


In [5]:
# Custom dataset
class XRayDataset(Dataset):
    """Pairs each X-ray image with its tokenized report text.

    Every item is a tuple ``(image, input_ids, attention_mask)``. The report is
    built from the ``findings`` and ``impression`` columns joined by a space and
    tokenized/padded/truncated to 256 tokens.

    Args:
        dataframe: table with ``filename``, ``findings`` and ``impression`` columns.
        img_dir: directory the ``filename`` values are relative to.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    # Columns concatenated (in this order) to form the target report.
    TEXT_COLUMNS = ("findings", "impression")

    def __init__(self, dataframe, img_dir, transform=None):
        self.dataframe = dataframe
        self.img_dir = img_dir
        self.transform = transform
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def __len__(self):
        """Return the number of rows (samples) in the backing dataframe."""
        return len(self.dataframe)

    @classmethod
    def _report_text(cls, row):
        """Join the non-missing report sections of ``row`` with single spaces.

        A missing section (NaN/None, as pandas produces for empty CSV cells) is
        skipped instead of raising ``TypeError`` on ``float + str``.
        """
        parts = [str(row[col]) for col in cls.TEXT_COLUMNS if not pd.isna(row[col])]
        return " ".join(parts)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, input_ids, attention_mask)``."""
        row = self.dataframe.iloc[idx]
        img_path = os.path.join(self.img_dir, row['filename'])
        # convert() returns an independent in-memory image, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        report_text = self._report_text(row)
        tokenized_text = self.tokenizer(report_text, padding='max_length', truncation=True, max_length=256, return_tensors="pt")

        # return_tensors="pt" adds a leading batch dimension of 1; drop exactly
        # that dimension so each item is a 1-D sequence.
        return image, tokenized_text.input_ids.squeeze(0), tokenized_text.attention_mask.squeeze(0)


In [6]:
# Model
class XRayReportGenerator(nn.Module):
    """ResNet-50 image encoder feeding a pretrained seq2seq language model.

    The CNN's final layer projects the pooled image features to the language
    model's hidden width; that single vector is run through the language
    model's encoder as a one-token sequence, and the decoder generates (or is
    trained on) the report text conditioned on it.

    Attribute names ``cnn`` and ``transformer`` are part of the saved
    ``state_dict`` keys and must not change.

    NOTE: the training targets built in ``forward`` differ from earlier
    revisions (first token is now learned, padding is ignored), so checkpoints
    trained before this change still load but were optimised for a slightly
    different objective.
    """

    def __init__(self):
        super().__init__()
        # ``pretrained=True`` is deprecated in torchvision; IMAGENET1K_V1 is the
        # weight set it used to select.
        self.cnn = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.transformer = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
        # Size the projection from the actual modules rather than hard-coding
        # 2048 -> 512, so changing MODEL_NAME does not silently break shapes.
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, self.transformer.config.d_model)

    def forward(self, images, input_ids=None, attention_mask=None):
        """Compute the training loss or generate report token ids.

        Args:
            images: batch of preprocessed image tensors accepted by the CNN.
            input_ids: optional tokenized target reports, one row per image.
                When given, the call runs in training mode.
            attention_mask: optional mask for ``input_ids`` (non-zero = real
                token). If omitted, padding is detected via the language
                model's ``pad_token_id``.

        Returns:
            ``(loss, logits)`` when ``input_ids`` is given, otherwise the tensor
            of generated token ids (at most 256 tokens per image).
        """
        img_features = self.cnn(images)
        # One "token" per image: (batch, d_model) -> (batch, 1, d_model).
        encoder_outputs = self.transformer.get_encoder()(inputs_embeds=img_features.unsqueeze(1))

        if input_ids is not None:
            if attention_mask is None:
                attention_mask = input_ids.ne(self.transformer.config.pad_token_id)
            # The seq2seq model builds decoder inputs itself by shifting
            # ``labels`` right, so the full sequence is passed as labels.
            # Shifting manually as well would drop the first report token from
            # the targets. Padding positions are set to -100 so the
            # cross-entropy loss ignores them.
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            # No encoder attention mask is passed: the encoder sequence is the
            # single image token, which is always valid.
            outputs = self.transformer(
                encoder_outputs=encoder_outputs,
                labels=labels,
            )
            return outputs.loss, outputs.logits

        # Inference Mode
        generated_ids = self.transformer.generate(encoder_outputs=encoder_outputs, max_length=256)
        return generated_ids


In [7]:
# Build the training dataset and wrap it in a shuffling mini-batch loader.
dataset = XRayDataset(dataframe=df, img_dir=image_folder, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Instantiate the network, move it to the target device, and create its optimizer.
model = XRayReportGenerator()
model.to(DEVICE)  # nn.Module.to moves parameters in place
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)



In [9]:
# Train once and cache the weights at MODEL_PATH; later runs just reload them.
if not os.path.exists(MODEL_PATH):
    print("Training model...")
    EPOCHS = 5
    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")  # Add progress bar
        for images, input_ids, attention_mask in progress_bar:
            images, input_ids, attention_mask = images.to(DEVICE), input_ids.to(DEVICE), attention_mask.to(DEVICE)
            optimizer.zero_grad()
            # The model returns (loss, logits) in training mode; only the loss is needed here.
            loss, _ = model(images, input_ids, attention_mask)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            progress_bar.set_postfix(loss=batch_loss)  # Live loss update

        # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches.
        avg_loss = epoch_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1}/{EPOCHS}, Avg Loss: {avg_loss:.4f}", flush=True)  # Ensure immediate print

    torch.save(model.state_dict(), MODEL_PATH)
    print("Model saved successfully!")
    # Leave the model in inference mode, matching the load branch below.
    model.eval()
else:
    print("Loading trained model...")
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa), consistent with the evaluation cell.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    print("Model loaded successfully!")


Loading trained model...
Model loaded successfully!


In [10]:
# Function for Inference
# Tokenizers already loaded by generate_report, keyed by checkpoint name, so
# repeated calls do not reload the tokenizer from disk every time.
_REPORT_TOKENIZERS = {}


def generate_report(image_path):
    """Generate, print and return a text report for one X-ray image.

    Uses the notebook-level ``model``, ``transform``, ``DEVICE`` and
    ``MODEL_NAME``.

    Args:
        image_path: path to the image file to describe.

    Returns:
        The decoded report string, or ``None`` (after printing an error) when
        ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        print("Error: Image file not found!", flush=True)
        return None

    model.eval()
    # convert() yields an in-memory copy, so the file can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image = transform(image).unsqueeze(0).to(DEVICE)  # add batch dimension

    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks run.
        generated_ids = model(image)

    tokenizer = _REPORT_TOKENIZERS.get(MODEL_NAME)
    if tokenizer is None:
        tokenizer = _REPORT_TOKENIZERS[MODEL_NAME] = AutoTokenizer.from_pretrained(MODEL_NAME)
    report = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print("Generated Report:", report, flush=True)
    return report


In [11]:
# Example Usage: run inference on one known image, guarding against a bad path.
sample_image_path = "images/images_normalized/1000_IM-0003-2001.dcm.png" # Change to actual image path
if not os.path.exists(sample_image_path):
    print("Error: Sample image not found. Please provide a valid image path.")
else:
    generate_report(sample_image_path)

Generated Report: heart size and mediastinal contours are within normal limits. The lungs are clear. No acute disease.


In [12]:
import os
# Peek at the first 10 entries of the image directory.
first_ten_images = os.listdir("images/images_normalized")[:10]
print(first_ten_images)


['1000_IM-0003-1001.dcm.png', '1000_IM-0003-2001.dcm.png', '1000_IM-0003-3001.dcm.png', '1001_IM-0004-1001.dcm.png', '1001_IM-0004-1002.dcm.png', '1002_IM-0004-1001.dcm.png', '1002_IM-0004-2001.dcm.png', '1003_IM-0005-2002.dcm.png', '1004_IM-0005-1001.dcm.png', '1004_IM-0005-2001.dcm.png']


In [14]:
from torch.utils.data import Subset
import nltk
from nltk.translate.bleu_score import corpus_bleu
from transformers import AutoTokenizer
from tqdm import tqdm

# Make sure the Punkt tokenizer data is present locally before tokenizing.
nltk.download('punkt')

def tokenize_text(text):
    """Lower-case ``text`` and split it into word tokens via ``nltk.word_tokenize``."""
    lowered = text.lower()
    return nltk.word_tokenize(lowered)

# Load your test CSV file (update the path if necessary).
# NOTE(review): this is the same file as ``data_path``, which the training
# DataLoader was built from, so the BLEU below is measured on samples the model
# has already seen. Point this at a held-out split for a genuine test score.
test_csv = "processed_xray_data.csv"  # Update if needed.
test_df = pd.read_csv(test_csv)

# Create the full test dataset using your previously defined XRayDataset.
test_dataset = XRayDataset(test_df, image_folder, transform=transform)

# Use a subset of the test data for faster evaluation (e.g., first 100 samples).
# Capped at the dataset length so a small CSV does not cause an index error.
subset_size = min(100, len(test_dataset))
subset_indices = list(range(subset_size))
test_dataset_subset = Subset(test_dataset, subset_indices)
test_loader_subset = DataLoader(test_dataset_subset, batch_size=1, shuffle=False)

# Load your trained model (same checkpoint path the training cell writes to).
model = XRayReportGenerator()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

# Load the tokenizer separately.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Lists to store tokenized references and hypotheses.
all_references = []
all_hypotheses = []

# Evaluate the model on the test subset with progress tracking.
with torch.no_grad():
    # The tokenized targets yielded by the dataset are not needed here; the
    # reference text is taken from the dataframe instead.
    for idx, (image, _, _) in tqdm(enumerate(test_loader_subset),
                                   total=len(test_loader_subset),
                                   desc="Evaluating"):
        image = image.to(DEVICE)

        # Generate predicted report tokens.
        generated_ids = model(image)
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # Extract the reference text for the current sample.
        # The loader is unshuffled with batch_size=1 over indices 0..subset_size-1,
        # so ``idx`` is also the row position in test_df.
        row = test_df.iloc[idx]
        # Skip missing sections (NaN) instead of failing on float + str.
        reference_text = " ".join(
            str(row[col]) for col in ("findings", "impression") if pd.notna(row[col])
        )

        # Tokenize both generated and reference texts.
        hyp_tokens = tokenize_text(generated_text)
        ref_tokens = tokenize_text(reference_text)

        all_hypotheses.append(hyp_tokens)
        all_references.append([ref_tokens])  # Wrap reference tokens in a list to support multiple references.

        # Optional: print the first three sample outputs for verification.
        if idx < 3:
            print(f"\nSample {idx+1}:")
            print("Generated:", generated_text)
            print("Reference:", reference_text)
            print("-" * 50)

# Compute the BLEU score using NLTK's corpus_bleu.
bleu_score = corpus_bleu(all_references, all_hypotheses)
print(f"\nBLEU score for the subset: {bleu_score:.4f}")


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\bagga\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Evaluating:   1%|▋                                                                     | 1/100 [00:00<00:44,  2.23it/s]


Sample 1:
Generated: heart size and mediastinal contours are within normal limits. The lungs are clear. No acute disease.
Reference: The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax. Normal chest x-XXXX.
--------------------------------------------------


Evaluating:   2%|█▍                                                                    | 2/100 [00:00<00:41,  2.34it/s]


Sample 2:
Generated: heart size and mediastinal contours are within normal limits. The lungs are clear. No acute disease.
Reference: The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax. Normal chest x-XXXX.
--------------------------------------------------


Evaluating:   3%|██                                                                    | 3/100 [00:01<00:41,  2.31it/s]


Sample 3:
Generated: heart is normal in size. The mediastinum is unremarkable. The lungs are clear. No acute disease.
Reference: Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX. No acute pulmonary findings.
--------------------------------------------------


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 100/100 [00:40<00:00,  2.46it/s]


BLEU score for the subset: 0.0607





In [15]:
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

smoothie = SmoothingFunction().method4  # Helps with short sentences or no matches

# n-gram weight vectors for BLEU-1 through BLEU-4, in that order: each spreads
# the weight evenly over the first n n-gram orders.
_bleu_weights = (
    (1.0, 0, 0, 0),
    (0.5, 0.5, 0, 0),
    (1/3, 1/3, 1/3, 0),
    (0.25, 0.25, 0.25, 0.25),
)

# Score the same references/hypotheses once per weight vector.
bleu1, bleu2, bleu3, bleu4 = [
    corpus_bleu(all_references, all_hypotheses, weights=w, smoothing_function=smoothie)
    for w in _bleu_weights
]

# Print results
print(f"\nBLEU-1: {bleu1:.4f}")
print(f"BLEU-2: {bleu2:.4f}")
print(f"BLEU-3: {bleu3:.4f}")
print(f"BLEU-4: {bleu4:.4f}")



BLEU-1: 0.1460
BLEU-2: 0.1064
BLEU-3: 0.0797
BLEU-4: 0.0607
