# In [None]:
# CNN-RNN + Beam Search Model for Chest X-Ray Report Generation
# Based on the Boag et al. paper "Baselines for Chest X-Ray Report Generation"

import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import matplotlib.pyplot as plt
from PIL import Image
import pickle
import time
import tqdm
import pydicom
import string
import certifi
import ssl

# Fix SSL certificate issues in macOS: point HTTPS clients at certifi's CA
# bundle. SSL_CERT_FILE is consulted for Python's default ssl verify paths;
# REQUESTS_CA_BUNDLE is the variable the `requests` library reads.
# NOTE(review): the model in this file is built with pretrained=False, so no
# weight download may actually happen -- confirm whether this is still needed.
# `os` is already imported above; this second import is redundant but harmless.
import os
os.environ['SSL_CERT_FILE'] = certifi.where()
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()

# Simple tokenizer to avoid NLTK dependency.
# Translation table that surrounds every ASCII punctuation character with
# spaces, so each punctuation mark survives as its own token after splitting.
_PUNCT_SPACING_TABLE = str.maketrans({char: f" {char} " for char in string.punctuation})


def simple_tokenize(text):
    """Lower-case *text* and split it into word and punctuation tokens.

    Every character in ``string.punctuation`` is kept as a separate token
    (it is padded with spaces, not removed), then the text is split on
    whitespace.

    Args:
        text: The string to tokenize. Any non-``str`` value (e.g. ``None`` or
            a float ``NaN`` coming from pandas) is treated as empty.

    Returns:
        list[str]: The tokens in order; empty for non-string or blank input.

    >>> simple_tokenize("Lungs are clear.")
    ['lungs', 'are', 'clear', '.']
    """
    if not isinstance(text, str):
        return []
    # One translate pass replaces the 32 sequential str.replace calls;
    # split() with no argument already drops empty strings.
    return text.translate(_PUNCT_SPACING_TABLE).lower().split()

# Define paths. The project root can be overridden with the MIMIC_CXR_BASE
# environment variable; the default is the original author's checkout, so
# existing runs behave exactly as before.
base_path = os.environ.get(
    'MIMIC_CXR_BASE', '/Users/simeon/Documents/DLH/content/mimic-cxr-project'
)
data_dir = os.path.join(base_path, 'data')        # train.tsv / test.tsv
files_path = os.path.join(base_path, 'files')     # DICOM images
output_dir = os.path.join(base_path, 'output')    # generated report TSVs
reports_dir = os.path.join(base_path, 'reports')  # free-text report .txt files
models_dir = os.path.join(base_path, 'models')    # vocab.pkl and checkpoints

# Create the directories this script writes to (input dirs must already exist).
os.makedirs(output_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)

# Import the project-local report parser module from <base_path>/modules.
# parse_report is used below to pull the 'findings' section out of each
# report file. MIMIC_RE is imported but not referenced in the code visible here.
import sys
sys.path.append(f"{base_path}/modules")
from report_parser import parse_report, MIMIC_RE
print("Successfully imported report parser module")

# Read the train/test split tables (tab-separated) from data_dir.
train_df, test_df = (
    pd.read_csv(os.path.join(data_dir, split_file), sep='\t')
    for split_file in ('train.tsv', 'test.tsv')
)

print(f"Train data shape: {train_df.shape}")
print(f"Test data shape: {test_df.shape}")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Image preprocessing applied to every PIL image: resize to 224x224, convert
# to a float tensor, then normalise each channel with the commonly used
# ImageNet mean/std statistics.
_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_transform_steps)

# Vocabulary class to handle tokenization
class Vocabulary:
    """Bidirectional word <-> integer-id mapping built from report text.

    Ids 0-3 are reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``;
    other code in this file hard-codes those ids (e.g. 0 as the ignored
    padding index, 1/2 as start/end tokens).

    Args:
        freq_threshold: Minimum number of occurrences (within one
            ``build_vocabulary`` call) before a word gets an id.
    """

    def __init__(self, freq_threshold=2):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.itos)

    def build_vocabulary(self, sentence_list):
        """Add every word seen at least ``freq_threshold`` times in *sentence_list*.

        Words receive consecutive ids in the order in which they reach the
        threshold. Counts are local to this call.

        Args:
            sentence_list: Iterable of strings; non-strings tokenize to nothing.
        """
        frequencies = {}
        # Continue after the ids already assigned, so calling this again
        # extends the vocabulary instead of overwriting ids 4, 5, ...
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in simple_tokenize(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                if word not in self.stoi and frequencies[word] >= self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Return the ids of *text*'s tokens, mapping unknown words to ``<UNK>``."""
        unk_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_id) for token in simple_tokenize(text)]

# Class to load and preprocess the data
class ChestXRayReportDataset(Dataset):
    """Map-style dataset of MIMIC-CXR DICOM images and, for training, their reports.

    Only rows whose ``.dcm`` file exists under ``files_path`` are exposed.

    Args:
        df: DataFrame with ``subject_id``, ``study_id`` and ``dicom_id`` columns.
        is_train: When True, each study's report "findings" text is extracted
            and ``__getitem__`` returns ``(image, report_text, dicom_id)``
            (``report_text`` is ``""`` when no findings were found). Otherwise
            it returns ``(image, dicom_id)``.
        transform: Callable applied to the PIL image. When None the image is
            only converted with ``transforms.ToTensor()`` (no resize/normalise).
        max_seq_length: Stored on the instance; not used inside this class.
    """

    def __init__(self, df, is_train=True, transform=None, max_seq_length=100):
        super().__init__()
        self.df = df
        self.is_train = is_train
        self.transform = transform
        self.max_seq_length = max_seq_length
        # (dicom_id, findings_text) pairs, in DataFrame row order.
        self.reports = []
        # dicom_id -> findings text. Keeps the first entry per id (what a
        # linear scan of self.reports would find) but gives O(1) lookup.
        self._report_by_dicom = {}

        if self.is_train:
            self._extract_reports()

        # (dicom_id, path) for every row whose DICOM file exists on disk.
        self.dicom_paths = []
        for _, row in self.df.iterrows():
            prefix, subject_dir, study_dir = self._study_path_parts(row['subject_id'], row['study_id'])
            dicom_id = row['dicom_id']
            dicom_path = os.path.join(files_path, prefix, subject_dir, study_dir, f"{dicom_id}.dcm")
            if os.path.exists(dicom_path):
                self.dicom_paths.append((dicom_id, dicom_path))

    @staticmethod
    def _study_path_parts(subject_id, study_id):
        """Return the ``(pXX, p<subject_id>, s<study_id>)`` directory names for a study."""
        return f"p{str(subject_id)[:2]}", f"p{subject_id}", f"s{study_id}"

    def _extract_reports(self):
        """Populate the report lookups with each row's non-empty findings section."""
        print("Extracting report texts for training data...")
        unparseable = 0
        for _, row in tqdm.tqdm(self.df.iterrows(), total=len(self.df)):
            prefix, subject_dir, study_dir = self._study_path_parts(row['subject_id'], row['study_id'])
            report_path = os.path.join(reports_dir, 'files', prefix, subject_dir, f"{study_dir}.txt")
            if not os.path.exists(report_path):
                continue
            try:
                report = parse_report(report_path)
                findings = report['findings'] if 'findings' in report else None
            except Exception:
                # Best effort: skip reports the parser cannot handle, but count them.
                unparseable += 1
                continue
            if findings:
                dicom_id = row['dicom_id']
                self.reports.append((dicom_id, findings))
                self._report_by_dicom.setdefault(dicom_id, findings)

        print(f"Extracted {len(self.reports)} reports from training data")
        if unparseable:
            print(f"Skipped {unparseable} reports that could not be parsed")

    def _load_image(self, dicom_path):
        """Decode a DICOM file into a 3-channel tensor.

        Returns ``torch.zeros(3, 224, 224)`` if reading, decoding or the
        transform raises, so one bad file does not stop an epoch.
        """
        try:
            ds = pydicom.dcmread(dicom_path)
            pixels = np.asarray(ds.pixel_array, dtype=np.float64)

            lowest = pixels.min()
            if lowest < 0:
                # Signed pixel data: shift so values map into [0, 1] instead
                # of wrapping around when cast to uint8.
                pixels = pixels - lowest
            peak = pixels.max()
            # Guard blank images: dividing by a zero maximum would give NaNs.
            pixels = pixels / peak if peak > 0 else np.zeros_like(pixels)

            # In MONOCHROME1 the minimum value is displayed as white; invert
            # so every image follows the MONOCHROME2 (higher = brighter) convention.
            if getattr(ds, 'PhotometricInterpretation', '') == 'MONOCHROME1':
                pixels = 1.0 - pixels

            img = np.uint8(pixels * 255)
            # Replicate grayscale into three channels for the RGB backbone.
            if img.ndim == 2:
                img = np.stack([img, img, img], axis=2)
            elif img.shape[2] == 1:
                img = np.concatenate([img, img, img], axis=2)

            pil_img = Image.fromarray(img)
            if self.transform:
                return self.transform(pil_img)
            return transforms.ToTensor()(pil_img)
        except Exception:
            # Blank image when loading fails (deliberate best-effort fallback).
            return torch.zeros(3, 224, 224)

    def __len__(self):
        return len(self.dicom_paths)

    def __getitem__(self, idx):
        dicom_id, dicom_path = self.dicom_paths[idx]
        image = self._load_image(dicom_path)

        # Return image and report for training, only image for testing.
        if self.is_train:
            return image, self._report_by_dicom.get(dicom_id, ""), dicom_id
        return image, dicom_id

# Create a simplified CNN-RNN model that follows the paper
class CNNRNNModel(nn.Module):
    """DenseNet-121 image encoder followed by a single-layer LSTM decoder.

    The projected image embedding is the LSTM's first input. Each later input
    is the embedding of the previous token: the ground-truth token under
    teacher forcing, otherwise the model's own argmax prediction.

    Args:
        vocab_size: Number of tokens (embedding rows and output logits).
        embed_size: Size of the image embedding and the word embeddings.
        hidden_size: LSTM hidden size.
        cnn_feature_size: Input width of DenseNet's classifier, which is
            replaced by a projection to ``embed_size``.
    """

    # Token id that ends a sequence (matches the Vocabulary's <EOS>).
    _EOS_IDX = 2

    def __init__(self, vocab_size, embed_size, hidden_size, cnn_feature_size):
        super().__init__()

        # CNN encoder (DenseNet121), randomly initialised.
        print("Loading DenseNet121 model without pretrained weights...")
        self.densenet = models.densenet121(pretrained=False)
        # Replace the ImageNet classifier with a projection into embedding space.
        self.densenet.classifier = nn.Linear(cnn_feature_size, embed_size)

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        # Applied to LSTM outputs before the output layer (identity in eval mode).
        self.dropout = nn.Dropout(0.5)

    def _init_state(self, batch_size, like):
        """Zero ``(h, c)`` LSTM state on the same device/dtype as tensor *like*."""
        shape = (1, batch_size, self.lstm.hidden_size)
        return like.new_zeros(shape), like.new_zeros(shape)

    def _step(self, x, state):
        """Run one decoder step on input *x* [batch, 1, embed]; return (logits [batch, 1, vocab], state)."""
        lstm_out, state = self.lstm(x, state)
        return self.fc(self.dropout(lstm_out)), state

    def forward(self, images, captions=None, teacher_forcing_ratio=0.5, max_length=100):
        """Score *captions* (training) or greedily decode (inference).

        Args:
            images: Image batch accepted by the DenseNet encoder.
            captions: Optional LongTensor [batch, length] of target ids. When
                given, ``outputs[:, t]`` are the logits for ``captions[:, t]``
                and depend only on the image and ``captions[:, :t]``.
            teacher_forcing_ratio: Per-step probability of feeding the
                ground-truth token instead of the model's own prediction.
            max_length: Maximum number of tokens generated in inference mode.

        Returns:
            Training: float tensor [batch, length, vocab_size] of logits.
            Inference: LongTensor [batch, n] of predicted ids, n <= max_length;
            decoding stops early once every sequence emits <EOS>.
        """
        # [batch, 3, 224, 224] -> [batch, embed_size] via the replaced classifier.
        features = self.densenet(images)
        state = self._init_state(features.size(0), features)
        # The image embedding is the first decoder input.
        x = features.unsqueeze(1)

        if captions is not None:
            return self._decode_teacher_forced(x, state, captions, teacher_forcing_ratio)
        return self._decode_greedy(x, state, max_length)

    def _decode_teacher_forced(self, x, state, captions, teacher_forcing_ratio):
        """Training-mode unroll with scheduled sampling; returns [batch, length, vocab] logits."""
        batch_size, caption_length = x.size(0), captions.size(1)
        if caption_length == 0:
            return x.new_zeros((batch_size, 0, self.fc.out_features))

        step_logits = []
        for t in range(caption_length):
            logits, state = self._step(x, state)
            step_logits.append(logits.squeeze(1))

            # One draw per step decides for the whole batch.
            if torch.rand(1).item() < teacher_forcing_ratio:
                # Feed the ground truth for position t, so step t + 1 predicts
                # captions[:, t + 1] from the previous token (not from itself).
                next_tokens = captions[:, t].unsqueeze(1)
            else:
                next_tokens = logits.argmax(2)  # [batch, 1]
            x = self.embedding(next_tokens)

        return torch.stack(step_logits, dim=1)

    def _decode_greedy(self, x, state, max_length):
        """Inference-mode greedy decoding; returns a [batch, n] LongTensor of ids."""
        batch_size = x.size(0)
        if max_length < 1:
            return x.new_zeros((batch_size, 0), dtype=torch.long)

        tokens = []
        for _ in range(max_length):
            logits, state = self._step(x, state)
            pred = logits.argmax(2)  # [batch, 1]
            tokens.append(pred.squeeze(1))
            if (pred == self._EOS_IDX).all():
                break
            x = self.embedding(pred)

        return torch.stack(tokens, dim=1)

# Beam Search implementation
class BeamSearch:
    """Beam-search decoder for a CNN-RNN captioning model.

    The model must expose ``densenet``, ``embedding``, ``lstm`` and ``fc``
    attributes, used here the same way ``CNNRNNModel`` uses them.

    Args:
        model: The captioning model.
        beam_size: Number of hypotheses kept after each step.
        max_length: Maximum number of decoding steps.
        repetition_penalty: Factor (>= 1) that lowers the logits of tokens
            already present in a hypothesis; 1.0 disables it.
    """

    EOS_IDX = 2
    NUM_SPECIAL_TOKENS = 4  # ids 0-3: <PAD>, <SOS>, <EOS>, <UNK>
    MIN_CAPTION_WORDS = 5

    def __init__(self, model, beam_size=4, max_length=100, repetition_penalty=1.5):
        self.model = model
        self.beam_size = beam_size
        self.max_length = max_length
        self.repetition_penalty = repetition_penalty

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, penalty):
        """Lower the logits of *token_ids* in row 0 of *logits*, in place.

        Positive logits are divided by *penalty* and non-positive ones are
        multiplied by it, so the penalty always reduces a token's score.
        (Dividing a negative logit would instead raise it.)

        Returns:
            The same *logits* tensor, for convenience.
        """
        if penalty == 1.0 or not token_ids:
            return logits
        idx = torch.tensor(sorted(token_ids), dtype=torch.long, device=logits.device)
        scores = logits[0, idx]
        logits[0, idx] = torch.where(scores > 0, scores / penalty, scores * penalty)
        return logits

    @staticmethod
    def _length_normalized(beam):
        """Average per-token log-probability of a ``(score, seq, h, c)`` beam."""
        return beam[0] / len(beam[1])

    def _to_words(self, token_ids, vocab):
        """Map ids to words, stopping at <EOS> and dropping the special tokens."""
        words = []
        for token_id in token_ids:
            if token_id == self.EOS_IDX:
                break
            if token_id >= self.NUM_SPECIAL_TOKENS:
                words.append(vocab.itos[token_id])
        return words

    def search(self, image, vocab):
        """Generate a caption for one image with beam search and a repetition penalty.

        Args:
            image: Input image tensor [1, 3, 224, 224] (a batch of exactly one).
            vocab: Object whose ``itos`` maps token ids to words.

        Returns:
            str: Space-joined words of the hypothesis with the best
            length-normalised log-probability. If it has fewer than
            ``MIN_CAPTION_WORDS`` words, the next-best hypothesis that has
            enough words is returned instead.
        """
        model = self.model
        model.eval()
        with torch.no_grad():
            features = model.densenet(image)
            hidden_shape = (1, 1, model.lstm.hidden_size)  # batch of one image
            h = features.new_zeros(hidden_shape)
            c = features.new_zeros(hidden_shape)

            # The image embedding is the first LSTM input.
            lstm_out, (h, c) = model.lstm(features.unsqueeze(1), (h, c))
            first_log_probs = F.log_softmax(model.fc(lstm_out).squeeze(1), dim=1)

            # topk cannot ask for more entries than the vocabulary has.
            k = min(self.beam_size, first_log_probs.size(1))
            top_lp, top_ids = torch.topk(first_log_probs, k)
            # Each beam: (summed log-prob, token ids, h, c).
            beams = [(top_lp[0, i].item(), [top_ids[0, i].item()], h, c) for i in range(k)]

            for _ in range(self.max_length - 1):
                candidates = []
                for score, seq, beam_h, beam_c in beams:
                    # Finished hypotheses are carried forward unchanged.
                    if seq[-1] == self.EOS_IDX:
                        candidates.append((score, seq, beam_h, beam_c))
                        continue

                    word = torch.tensor([[seq[-1]]], dtype=torch.long, device=features.device)
                    lstm_out, (h_new, c_new) = model.lstm(model.embedding(word), (beam_h, beam_c))
                    logits = model.fc(lstm_out).squeeze(1)

                    # Penalise tokens already used, after the first token.
                    if len(seq) > 1:
                        self._apply_repetition_penalty(logits, set(seq), self.repetition_penalty)

                    step_lp, step_ids = torch.topk(F.log_softmax(logits, dim=1), k)
                    for i in range(k):
                        token = step_ids[0, i].item()
                        # Never emit the same token three times in a row.
                        if len(seq) >= 2 and seq[-1] == seq[-2] == token:
                            continue
                        candidates.append((score + step_lp[0, i].item(), seq + [token], h_new, c_new))

                # Every expansion hit the repetition limit.
                if not candidates:
                    break

                candidates.sort(key=lambda cand: cand[0], reverse=True)
                beams = candidates[:self.beam_size]

                if all(beam[1][-1] == self.EOS_IDX for beam in beams):
                    break

        # Rank by length-normalised score; the stable sort keeps max()'s tie-breaking.
        ranked = sorted(beams, key=self._length_normalized, reverse=True)
        caption = self._to_words(ranked[0][1], vocab)

        # If the best caption is too short, fall back to a longer hypothesis.
        if len(caption) < self.MIN_CAPTION_WORDS:
            for beam in ranked[1:]:
                alternative = self._to_words(beam[1], vocab)
                if len(alternative) >= self.MIN_CAPTION_WORDS:
                    return " ".join(alternative)

        return " ".join(caption)

# Function to generate dummy data for testing the model
def create_dummy_batch(batch_size=2, seq_len=10, vocab_size=100):
    """Build a random ``(images, captions)`` pair for shape smoke tests.

    Returns:
        images: float tensor [batch_size, 3, 224, 224] drawn from N(0, 1).
        captions: int64 tensor [batch_size, seq_len] with ids in [0, vocab_size).
    """
    # Tuple elements are evaluated left to right, so the RNG is consumed in
    # the same order as drawing images first and captions second.
    return (
        torch.randn(batch_size, 3, 224, 224),
        torch.randint(0, vocab_size, (batch_size, seq_len)),
    )

# Main execution function

# Longest caption (in tokens, including <SOS> and <EOS>) fed to the decoder.
MAX_CAPTION_TOKENS = 100


def _encode_captions(captions, vocab, max_tokens=MAX_CAPTION_TOKENS):
    """Encode report strings as a right-padded LongTensor of token ids.

    Each caption becomes ``[1 (<SOS>), *vocab.numericalize(text), 2 (<EOS>)]``,
    truncated to *max_tokens*. Rows are padded with 0 (``<PAD>``) to the
    longest row in this batch. *captions* must be non-empty.
    """
    encoded = [([1] + vocab.numericalize(text) + [2])[:max_tokens] for text in captions]
    width = max(len(tokens) for tokens in encoded)
    return torch.tensor(
        [tokens + [0] * (width - len(tokens)) for tokens in encoded], dtype=torch.long
    )


def _train_model(model, train_loader, vocab, model_path, num_epochs=64):
    """Train *model*, saving the checkpoint with the lowest epoch loss to *model_path*."""
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding tokens
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=16, gamma=0.5)  # LR decay every 16 epochs
    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        trained_batches = 0

        # "We increase the probability of feeding a sample of the inferred
        # probability to itself by 0.05 per 16 epochs"
        teacher_forcing_ratio = max(0.5 - (epoch // 16) * 0.05, 0.0)

        progress = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for i, (images, captions, _) in enumerate(progress):
            # Images with no findings come back with an empty report; training
            # on them would teach the decoder to emit <EOS> immediately.
            keep = [j for j, text in enumerate(captions) if text]
            if not keep:
                continue
            images = images[keep].to(device)
            captions_tensor = _encode_captions([captions[j] for j in keep], vocab).to(device)

            optimizer.zero_grad()
            outputs = model(images, captions_tensor, teacher_forcing_ratio)

            # Flatten [batch, seq, vocab] / [batch, seq] for cross-entropy.
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions_tensor.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            trained_batches += 1

            if (i + 1) % 50 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}")

        scheduler.step()

        if trained_batches == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}]: no batches with reports, skipping")
            continue

        avg_loss = train_loss / trained_batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Teacher forcing: {teacher_forcing_ratio:.2f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_path)
            print(f"Model saved with improved loss: {best_loss:.4f}")


def _generate_sample_reports(model, vocab, max_samples=3):
    """Beam-search reports for the first *max_samples* test images that decode cleanly.

    Returns:
        dict: dicom_id -> generated report text.
    """
    print("Creating test dataset...")
    test_dataset = ChestXRayReportDataset(test_df, is_train=False, transform=transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,  # Process one at a time for beam search
        shuffle=False,
        num_workers=0,  # Set to 0 to avoid multiprocessing issues on Mac
    )

    print("Generating reports with beam search...")
    model.eval()
    beam_search = BeamSearch(model, beam_size=4)  # Beam size 4 as in the paper

    generated_reports = {}
    sample_count = 0
    with torch.no_grad():
        for images, dicom_ids in tqdm.tqdm(test_loader):
            if sample_count >= max_samples:
                break

            images = images.to(device)
            dicom_id = dicom_ids[0]  # Batch size is 1

            try:
                caption = beam_search.search(images, vocab)
            except Exception as e:
                # Best effort: report the failure and try the next image.
                print(f"Error generating report for {dicom_id}: {e}")
                continue
            generated_reports[dicom_id] = caption
            sample_count += 1
            print(f"Generated report {sample_count}/{max_samples}")

    return generated_reports


def run():
    """Build the vocabulary, train or load the model, and decode a few test reports.

    Side effects: writes ``vocab.pkl`` and ``cnn_rnn_beam.pth`` under
    ``models_dir`` and ``cnn_rnn_beam_sample.tsv`` under ``output_dir``.
    """
    print("Creating training dataset...")
    train_dataset = ChestXRayReportDataset(train_df, is_train=True, transform=transform)

    print("Building vocabulary...")
    vocab = Vocabulary()
    # One report string per image with a DICOM file (first match, "" if
    # none), which is what iterating the dataset yields, without decoding
    # every image just to read its caption.
    report_by_dicom = {}
    for dicom_id, text in train_dataset.reports:
        report_by_dicom.setdefault(dicom_id, text)
    all_reports = [report_by_dicom.get(dicom_id, "") for dicom_id, _ in train_dataset.dicom_paths]
    vocab.build_vocabulary(all_reports)
    print(f"Built vocabulary with {len(vocab)} tokens")

    vocab_path = os.path.join(models_dir, 'vocab.pkl')
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)

    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=0,  # Set to 0 to avoid multiprocessing issues on Mac
    )

    model = CNNRNNModel(
        vocab_size=len(vocab),
        embed_size=256,
        hidden_size=512,
        cnn_feature_size=1024,
    ).to(device)

    # Smoke-test the forward pass on random data before using real batches.
    print("Testing model with dummy data...")
    dummy_images, dummy_captions = create_dummy_batch(batch_size=2, seq_len=10, vocab_size=len(vocab))
    dummy_images = dummy_images.to(device)
    dummy_captions = dummy_captions.to(device)
    try:
        with torch.no_grad():
            outputs = model(dummy_images, dummy_captions)
            print(f"Dummy forward pass successful! Output shape: {outputs.shape}")
    except Exception as e:
        print(f"Error in dummy forward pass: {e}")
        return

    # Reuse a saved checkpoint when present; otherwise train from scratch.
    model_path = os.path.join(models_dir, 'cnn_rnn_beam.pth')
    if os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print("Training model for 64 epochs...")
        _train_model(model, train_loader, vocab, model_path, num_epochs=64)  # As specified in the paper
        # Generate with the best (saved) weights, the same ones a later run
        # would load, rather than the last epoch's weights.
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))

    generated_reports = _generate_sample_reports(model, vocab, max_samples=3)

    print("\nSample generated reports:")
    for dicom_id, report in generated_reports.items():
        print(f"\nDICOM ID: {dicom_id}")
        print(f"Report: {report}")

    if generated_reports:
        report_df = pd.DataFrame({
            'dicom_id': list(generated_reports.keys()),
            'generated': list(generated_reports.values()),
        })
        output_file = os.path.join(output_dir, 'cnn_rnn_beam_sample.tsv')
        report_df.to_csv(output_file, sep='\t', index=False)
        print(f"Generated {len(generated_reports)} reports and saved to {output_file}")
    else:
        print("No reports were generated")


# Run the main function
if __name__ == "__main__":
    run()