In [87]:
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer
import os
from docx import Document
from sklearn.model_selection import StratifiedShuffleSplit
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from transformers import BertTokenizer, BertModel
from PIL import Image
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

# Data Prep

In [27]:
# Load and clean all clinical notes
notes_path = "data/notes"

def load_docx(file_path):
    """Load a .docx file and extract text"""
    doc = Document(file_path)
    text = "\n".join([para.text for para in doc.paragraphs])
    return text

def clean_clinical_text(text):
    """Cleans clinical text for NLP processing."""
    text = "\n".join([line.strip() for line in text.split("\n") if line.strip()])
    text = text.lower()
    text = text.split("\n", 1)[1]
    text = text.replace("\n", " ")
    return text

def load_clinical_notes(file_path):
    notes = {}
    for root, dirs, files in os.walk(file_path):
        for file in files:
            filename, file_extension = os.path.splitext(file)
            if file.endswith(".docx"):
                note = load_docx(os.path.join(root, file))
                note = clean_clinical_text(note)
                notes[filename] = note
    return notes

d = load_clinical_notes(notes_path)


In [31]:
image_folder = 'data/images/all'
notes_folder = 'data/notes'
image_list = os.listdir(image_folder)

def extract_image_id(image_path):
    # image_path = 'P100_L_CM_MLO.jpg'
    # Extract the image ID from the filename

    # Assuming the image ID is the part before the first underscore
    image_id = image_path.split('_')[0]
    return image_id

image_ids = [extract_image_id(image_path) for image_path in image_list]

# create list of paths for all 503 images and their corresponding notes
image_paths = [os.path.join(image_folder, image_path) for image_path in image_list]
notes = [d[image_id] if image_id in d else None for image_id in image_ids]

In [78]:
# Now we have a list of image paths and their corresponding notes
# Time to extract the corresponding labels from the CSV file

# Length of image_paths, notes and labels = 503
# Normal = 0 
# Benign = 1
# Malignant = 2

# Load the CSV file
csv_path = 'data/clinical_data.csv'
df = pd.read_csv(csv_path)
labels = []
encode = {'Benign': 1, 'Malignant': 2, 'Normal': 0}

for id in image_list:
    c = id.split('.')[0]
    
    # extract row where column 'Image_name' = c and find pathology column
    row = df[df['Image_name'] == c]
    if not row.empty:
        pathology = row['Pathology Classification/ Follow up'].values[0]
        encode_pathology = encode.get(pathology, None)
        if encode_pathology is not None:
            labels.append(encode_pathology)
        else:
            # If the pathology is not in the encode dictionary, append None
            print(f"Pathology '{pathology}' not found in encoding dictionary for image {c}.")
    else:
        print(c)
        print(f"No matching row found in CSV for image {c}")

In [47]:
# Check if the lengths of image_paths, notes, and labels match
print(f"Length of image_paths: {len(image_paths)}")
print(f"Length of notes: {len(notes)}")
print(f"Length of labels: {len(labels)}")

print("Normal: ", labels.count(0))
print("Benign: ", labels.count(1))
print("Malignant: ", labels.count(2))

Length of image_paths: 503
Length of notes: 503
Length of labels: 503
Normal:  206
Benign:  128
Malignant:  168


# Dataset

In [50]:
# ============================
#    Custom Dataset Class
# ============================
class CancerDataset(Dataset):
    def __init__(self, image_paths, notes, labels, tokenizer, max_len, transform=None):
        self.image_paths = image_paths
        self.notes = notes
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # Load Image
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image) # apply resize / normalization / convert to tensor 
        
        # Tokenize Clinical Notes
        text = str(self.notes[idx])
        encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_len, return_tensors='pt')
        input_ids = encoding['input_ids'].squeeze(0)  # (max_len,)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        # Label
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        
        return image, input_ids, attention_mask, label

# Feature Extraction

In [82]:
# ============================
#    Feature Extractors
# ============================
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])  # Remove last FC layer
        self.bert = BertModel.from_pretrained('bert-base-uncased')
    
    def forward(self, image, input_ids, attention_mask):
        # Image Features
        img_feat = self.cnn(image).flatten(1)  # (batch, 512)
        
        # Text Features
        text_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        text_feat = text_outputs.last_hidden_state[:, 0, :]  # CLS token (batch, 768)
        
        return img_feat, text_feat

# Fusion

In [91]:
import torch
import torch.nn as nn

class CAM(nn.Module):
    def __init__(self):
        super(CAM, self).__init__()
        self.encoder1 = nn.Linear(512, 128)  # Image
        self.encoder2 = nn.Linear(768, 128)  # Text
        self.affine_a = nn.Linear(128, 128, bias=False)
        self.affine_v = nn.Linear(128, 128, bias=False)
        self.W_ca = nn.Linear(128, 32, bias=False)  # Cross attention weight for image
        self.W_cv = nn.Linear(128, 32, bias=False)  # Cross attention weight for text
        self.W_ha = nn.Linear(32, 128, bias=False)  # Attention map weight for image
        self.W_hv = nn.Linear(32, 128, bias=False)  # Attention map weight for text
        self.relu = nn.ReLU()

        # Fully connected layers for final prediction
        self.regressor = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 3),  # 3 classes: Normal, Benign, Malignant
        )
    
    def forward(self, img_feat, text_feat):
        # Image and text feature encoders
        img_feat = self.encoder1(img_feat)  # X_v (image)
        text_feat = self.encoder2(text_feat)  # X_a (text)

        # Attention calculation for both modalities
        att_img = self.affine_a(img_feat)
        att_text = self.affine_v(text_feat)

        # Joint attention for both modalities (image and text)
        H_a = self.relu(self.W_ca(att_img))
        H_v = self.relu(self.W_cv(att_text))

        # Attending to features (modulation based on attention maps)
        img_out = self.W_ha(H_a) + img_feat
        text_out = self.W_hv(H_v) + text_feat

        # Fusing the attended features (concatenation)
        fused_feat = torch.cat((img_out, text_out), dim=1)

        # Final regression to get prediction
        output = self.regressor(fused_feat)
        return output


# Training set up

In [84]:
# we already have the image paths, notes and labels
image_paths = [i.replace('\\', '/') for i in image_paths]

print("Length of image_paths: ", len(image_paths))
print("Length of notes: ", len(notes))
print("Length of labels: ", len(labels))

# check for None values in labels
labels.count(None)  # should be 0

Length of image_paths:  503
Length of notes:  503
Length of labels:  503


0

In [85]:
# Hyperparameters
BATCH_SIZE = 16
MAX_LEN = 128
EPOCHS = 10
LEARNING_RATE = 1e-4

# Data Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Split the image_paths,notes and labels
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, val_index in sss.split(image_paths, labels):
    train_image_paths, val_image_paths = [image_paths[i] for i in train_index], [image_paths[i] for i in val_index]
    train_notes, val_notes = [notes[i] for i in train_index], [notes[i] for i in val_index]
    train_labels, val_labels = [labels[i] for i in train_index], [labels[i] for i in val_index]


# Create the training and validation datasets
train_dataset = CancerDataset(train_image_paths, train_notes, train_labels, tokenizer, MAX_LEN, transform)
val_dataset = CancerDataset(val_image_paths, val_notes, val_labels, tokenizer, MAX_LEN, transform)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Train and Test

In [92]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
extractor = FeatureExtractor().to(device)
cam_model = CAM().to(device)

# Use CrossEntropyLoss for multi-class classification
criterion = nn.CrossEntropyLoss()

# Adam optimizer
optimizer = optim.Adam(cam_model.parameters(), lr=LEARNING_RATE)
# ============================


#    Training Loop
# ============================

best_val_acc = 0.0  # Track the best validation accuracy
best_val_auc = 0.0  # Track the best validation AUC score

for epoch in range(EPOCHS):
    cam_model.train()
    extractor.train()
    total_loss = 0
    correct_preds = 0
    total_preds = 0
    all_preds = []
    all_labels = []
    
    for images, input_ids, attention_mask, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # Feature extraction
        img_feat, text_feat = extractor(images, input_ids, attention_mask)
        
        # CAM model forward pass
        outputs = cam_model(img_feat, text_feat)  # Multi-class outputs (logits)
        
        # Check if the output is a 2D tensor with shape [batch_size, num_classes]
        assert outputs.shape[1] == 3, f"Expected 3 classes in output, but got {outputs.shape[1]}"
        # Loss calculation (CrossEntropyLoss expects raw logits)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate accuracy
        _, predicted_labels = torch.max(outputs, 1)  # Get the class with the highest probability
        correct_preds += (predicted_labels == labels).sum().item()
        total_preds += labels.size(0)

        # For AUC-ROC calculation, we need probabilities
        probs = torch.nn.functional.softmax(outputs, dim=1)
        all_preds.append(probs.cpu().detach().numpy())
        all_labels.append(labels.cpu().detach().numpy())

    train_acc = correct_preds / total_preds
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Compute AUC-ROC score
    train_auc = roc_auc_score(all_labels, all_preds, multi_class='ovr', average='macro')
    
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(train_loader):.4f}, Training Accuracy: {train_acc:.4f}, Training AUC-ROC: {train_auc:.4f}")
    
    # ============================
    # Validation Loop
    # ============================
    cam_model.eval()  # Switch to evaluation mode
    val_correct_preds = 0
    val_total_preds = 0
    val_all_preds = []
    val_all_labels = []
    with torch.no_grad():
        for images, input_ids, attention_mask, labels in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{EPOCHS}"):
            images, input_ids, attention_mask, labels = images.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)
            
            # Feature extraction
            img_feat, text_feat = extractor(images, input_ids, attention_mask)
            
            # CAM model forward pass
            outputs = cam_model(img_feat, text_feat)  # Multi-class outputs (logits)
            
            # Calculate accuracy
            _, predicted_labels = torch.max(outputs, 1)  # Get the class with the highest probability
            val_correct_preds += (predicted_labels == labels).sum().item()
            val_total_preds += labels.size(0)

            # For AUC-ROC calculation, we need probabilities
            probs = torch.nn.functional.softmax(outputs, dim=1)
            val_all_preds.append(probs.cpu().detach().numpy())
            val_all_labels.append(labels.cpu().detach().numpy())
    
    val_acc = val_correct_preds / val_total_preds
    val_all_preds = np.concatenate(val_all_preds, axis=0)
    val_all_labels = np.concatenate(val_all_labels, axis=0)

    # Compute AUC-ROC score for validation
    val_auc = roc_auc_score(val_all_labels, val_all_preds, multi_class='ovr', average='macro')
    
    print(f"Epoch {epoch+1}/{EPOCHS}, Validation Accuracy: {val_acc:.4f}, Validation AUC-ROC: {val_auc:.4f}")
    
    # Save the best model based on validation accuracy or AUC
    if val_acc > best_val_acc or val_auc > best_val_auc:
        best_val_acc = val_acc
        best_val_auc = val_auc
        torch.save(cam_model.state_dict(), "best_cancer_cam_model.pth")
        print(f"Saved best model with Validation Accuracy: {val_acc:.4f} and Validation AUC-ROC: {val_auc:.4f}")


Epoch 1/10: 100%|██████████| 26/26 [04:19<00:00,  9.99s/it]


Epoch 1/10, Loss: 1.0841, Training Accuracy: 0.4005, Training AUC-ROC: 0.5297


Validation Epoch 1/10: 100%|██████████| 7/7 [00:26<00:00,  3.72s/it]


Epoch 1/10, Validation Accuracy: 0.4158, Validation AUC-ROC: 0.6326
Saved best model with Validation Accuracy: 0.4158 and Validation AUC-ROC: 0.6326


Epoch 2/10: 100%|██████████| 26/26 [04:24<00:00, 10.18s/it]


Epoch 2/10, Loss: 1.0526, Training Accuracy: 0.4627, Training AUC-ROC: 0.6528


Validation Epoch 2/10: 100%|██████████| 7/7 [00:28<00:00,  4.08s/it]


Epoch 2/10, Validation Accuracy: 0.4158, Validation AUC-ROC: 0.6767
Saved best model with Validation Accuracy: 0.4158 and Validation AUC-ROC: 0.6767


Epoch 3/10: 100%|██████████| 26/26 [04:23<00:00, 10.15s/it]


Epoch 3/10, Loss: 1.0452, Training Accuracy: 0.4975, Training AUC-ROC: 0.6729


Validation Epoch 3/10: 100%|██████████| 7/7 [00:25<00:00,  3.65s/it]


Epoch 3/10, Validation Accuracy: 0.4653, Validation AUC-ROC: 0.6885
Saved best model with Validation Accuracy: 0.4653 and Validation AUC-ROC: 0.6885


Epoch 4/10: 100%|██████████| 26/26 [04:23<00:00, 10.13s/it]


Epoch 4/10, Loss: 1.0077, Training Accuracy: 0.5448, Training AUC-ROC: 0.7287


Validation Epoch 4/10: 100%|██████████| 7/7 [00:25<00:00,  3.64s/it]


Epoch 4/10, Validation Accuracy: 0.4752, Validation AUC-ROC: 0.7060
Saved best model with Validation Accuracy: 0.4752 and Validation AUC-ROC: 0.7060


Epoch 5/10: 100%|██████████| 26/26 [04:31<00:00, 10.46s/it]


Epoch 5/10, Loss: 0.9962, Training Accuracy: 0.5423, Training AUC-ROC: 0.7494


Validation Epoch 5/10: 100%|██████████| 7/7 [00:27<00:00,  3.92s/it]


Epoch 5/10, Validation Accuracy: 0.5545, Validation AUC-ROC: 0.6819
Saved best model with Validation Accuracy: 0.5545 and Validation AUC-ROC: 0.6819


Epoch 6/10: 100%|██████████| 26/26 [04:37<00:00, 10.66s/it]


Epoch 6/10, Loss: 0.9708, Training Accuracy: 0.5647, Training AUC-ROC: 0.7351


Validation Epoch 6/10: 100%|██████████| 7/7 [00:25<00:00,  3.65s/it]


Epoch 6/10, Validation Accuracy: 0.5149, Validation AUC-ROC: 0.7037
Saved best model with Validation Accuracy: 0.5149 and Validation AUC-ROC: 0.7037


Epoch 7/10: 100%|██████████| 26/26 [04:29<00:00, 10.36s/it]


Epoch 7/10, Loss: 0.9356, Training Accuracy: 0.5771, Training AUC-ROC: 0.7787


Validation Epoch 7/10: 100%|██████████| 7/7 [00:24<00:00,  3.49s/it]


Epoch 7/10, Validation Accuracy: 0.5545, Validation AUC-ROC: 0.6942
Saved best model with Validation Accuracy: 0.5545 and Validation AUC-ROC: 0.6942


Epoch 8/10: 100%|██████████| 26/26 [04:14<00:00,  9.77s/it]


Epoch 8/10, Loss: 0.9203, Training Accuracy: 0.5821, Training AUC-ROC: 0.7475


Validation Epoch 8/10: 100%|██████████| 7/7 [00:25<00:00,  3.67s/it]


Epoch 8/10, Validation Accuracy: 0.5743, Validation AUC-ROC: 0.7265
Saved best model with Validation Accuracy: 0.5743 and Validation AUC-ROC: 0.7265


Epoch 9/10: 100%|██████████| 26/26 [04:09<00:00,  9.59s/it]


Epoch 9/10, Loss: 0.9042, Training Accuracy: 0.5746, Training AUC-ROC: 0.7802


Validation Epoch 9/10: 100%|██████████| 7/7 [00:25<00:00,  3.65s/it]


Epoch 9/10, Validation Accuracy: 0.5743, Validation AUC-ROC: 0.7106


Epoch 10/10: 100%|██████████| 26/26 [04:21<00:00, 10.07s/it]


Epoch 10/10, Loss: 0.8847, Training Accuracy: 0.5821, Training AUC-ROC: 0.7801


Validation Epoch 10/10: 100%|██████████| 7/7 [00:25<00:00,  3.67s/it]

Epoch 10/10, Validation Accuracy: 0.5644, Validation AUC-ROC: 0.7259



