# Import Packages

In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)  
warnings.filterwarnings("ignore", category=FutureWarning)
import seaborn as sns
import re
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
import glob
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pydicom
from scipy.ndimage import zoom
from scipy.stats import chi2_contingency
from scipy import stats
from scipy.stats import ttest_ind
from fastai.vision.all import *
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import Image
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import torch.optim as optim
import torch.cuda.amp as amp
from torchmetrics.classification import BinaryAccuracy
from sklearn.metrics import roc_auc_score, roc_curve, auc
from sklearn.utils import resample
import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display, clear_output
import nibabel as nib
from skimage.transform import resize
from torchsummary import summary
from torchvision import models
from torchvision.models.resnet import ResNet, BasicBlock
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.utils import concordance_index
from lifelines.statistics import multivariate_logrank_test, logrank_test
from lifelines.plotting import add_at_risk_counts
import torchio as tio
import random

# Define Pre-Processing Functions

Dataset Class

In [3]:
class Dataset(object):
    """Minimal abstract base for map-style datasets.

    Concrete datasets derive from this class and provide their own
    ``__len__`` (number of samples) and ``__getitem__`` (integer indexing
    over ``0 .. len(self) - 1``).

    Note: defining this class rebinds the module-level name ``Dataset``,
    which the import section also pulls in from ``torch.utils.data``.
    """

    def __len__(self):
        # Subclasses must report their size.
        raise NotImplementedError

    def __getitem__(self, index):
        # Subclasses must return the sample stored at ``index``.
        raise NotImplementedError

    def __add__(self, other):
        # ``a + b`` chains both datasets through torch's ConcatDataset.
        combined = [self, other]
        return ConcatDataset(combined)

In [4]:
class MRI(Dataset):
    """Two-class image dataset with a switchable train/validation view.

    Samples from ``rd`` receive label 1.0 and samples from ``dfs`` label 0.0.
    A global sample index below ``len(rd)`` addresses ``rd``; anything above
    addresses ``dfs`` (offset by ``len(rd)``). Only index arrays are split,
    the image arrays themselves are stored by reference.
    """

    def __init__(self, rd, dfs, sf, test=0.2):
        """Store both image stacks, build labels and split the indices.

        Args:
            rd: Array-like of positive-class images (indexed on axis 0).
            dfs: Array-like of negative-class images (indexed on axis 0).
            sf: Value forwarded as ``random_state`` to ``train_test_split``.
            test: Value forwarded as ``test_size`` to ``train_test_split``.
        """
        self.mode = 'train'

        # Keep references only; images are looked up lazily in __getitem__
        self.rd = rd
        self.dfs = dfs

        n_rd = rd.shape[0]
        n_dfs = dfs.shape[0]

        # rd -> 1, dfs -> 0, in that order
        self.labels = np.concatenate((
            np.ones(n_rd, dtype=np.float32),
            np.zeros(n_dfs, dtype=np.float32),
        ))
        self.image_indices = np.concatenate((np.arange(n_rd), np.arange(n_dfs)))

        self.train_val_split(sf, test)

    def train_val_split(self, sf, test):
        """Shuffle and split the global sample indices into train/val parts."""
        all_indices = np.arange(len(self.labels))
        self.train_idx, self.val_idx = train_test_split(
            all_indices, test_size=test, random_state=sf, shuffle=True
        )

    def __len__(self):
        """Size of the currently active split."""
        active = self.train_idx if self.mode == 'train' else self.val_idx
        return len(active)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'label': ...}`` for position ``idx`` of the active split."""
        active = self.train_idx if self.mode == 'train' else self.val_idx
        idx = active[idx]

        # Global indices below len(rd) live in rd, the rest in dfs
        n_rd = len(self.rd)
        image = self.rd[idx] if idx < n_rd else self.dfs[idx - n_rd]

        return {'image': image, 'label': self.labels[idx]}

    def set_mode(self, mode):
        """Select which split is served: ``'train'`` or ``'val'``."""
        if mode not in ['train', 'val']:
            raise ValueError("Mode should be 'train' or 'val'")
        self.mode = mode

    def get_train_rd_count(self):
        """Returns the number of rd samples in the training set."""
        n_rd = len(self.rd)
        return sum(1 for idx in self.train_idx if idx < n_rd)

    def get_train_dfs_count(self):
        """Returns the number of dfs samples in the training set."""
        n_rd = len(self.rd)
        return sum(1 for idx in self.train_idx if idx >= n_rd)

Generating Table 1

In [5]:
def perform_chi_square(observed, expected, total_observed, total_expected):
    """Return significance stars for a 2x2 chi-square comparison.

    The contingency table is ``[[observed, total_observed - observed],
    [expected, total_expected - expected]]``.

    Returns:
        ``"***"`` for p < 0.001, ``"**"`` for p < 0.01, ``"*"`` for p < 0.05,
        otherwise ``""``. An empty string is also returned when either total
        is zero or when the baseline row contains a zero.
    """
    # Nothing to compare when one of the groups is empty
    if total_observed == 0:
        return ""
    if total_expected == 0:
        return ""

    baseline_row = [expected, total_expected - expected]

    # The test is skipped when the baseline row has an empty cell
    if 0 in baseline_row:
        return ""

    group_row = [observed, total_observed - observed]
    _stat, p_value, _dof, _expected = chi2_contingency([group_row, baseline_row])

    # Thresholds are checked from strictest to loosest
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < cutoff:
            return stars
    return ""
        
def format_output(count, total_in_interval, baseline_count, total_baseline):
    """Format a categorical count as ``"<n> (<pct>%) <stars>"``.

    The percentage is ``count`` relative to ``total_in_interval`` (0 when the
    interval is empty); the stars come from ``perform_chi_square`` against
    the baseline counts.
    """
    if total_in_interval != 0:
        percentage = 100 * count / total_in_interval
    else:
        percentage = 0

    star = perform_chi_square(count, baseline_count, total_in_interval, total_baseline)
    return "{} ({:.2f}%) {}".format(int(count), percentage, star)

def summarize_data(group, baseline_group, total_baseline, cat_columns_dict, cont_columns_dict):
    """Build one column of the summary table for a single group.

    Args:
        group: DataFrame holding the rows of the group being summarized.
        baseline_group: DataFrame holding the rows of the reference group.
        total_baseline: Number of rows in the reference group.
        cat_columns_dict: ``{display name: column name}``; an empty column
            name yields an empty cell. Columns are summed to obtain counts.
        cont_columns_dict: ``{display name: (column name, stat_method)}``
            where ``stat_method`` is ``'mean_std'``, ``'median_iqr'`` or
            anything else (rendered as ``mean (std)``).

    Returns:
        ``pd.Series`` keyed by display name, continuous rows first.
    """
    n_cases = len(group)
    cells = {'Number of Patients, n': n_cases}

    # Significance markers, strictest cutoff first
    star_cutoffs = ((0.001, "***"), (0.01, "**"), (0.05, "*"))

    # Continuous variables: Welch t-test against the baseline group
    for display_name, (col_name, stat_method) in cont_columns_dict.items():
        values = group[col_name]
        _t_stat, p_value = ttest_ind(values, baseline_group[col_name], nan_policy='omit', equal_var=False)
        star = next((mark for cutoff, mark in star_cutoffs if p_value < cutoff), "")

        if stat_method == 'median_iqr':
            median = values.median()
            q1 = values.quantile(0.25)
            q3 = values.quantile(0.75)
            cells[display_name] = f"{median:.2f} [{q1:.2f}, {q3:.2f}] {star}"
            continue

        mean = values.mean()
        std = values.std()
        if stat_method == 'mean_std':
            cells[display_name] = f"{mean:.2f} ± {std:.2f} {star}"
        else:
            # Unrecognized method: fall back to mean (SD)
            cells[display_name] = f"{mean:.2f} ({std:.2f}) {star}"

    # Categorical variables: count, percentage and chi-square stars
    for display_name, col_name in cat_columns_dict.items():
        if col_name == "":
            cells[display_name] = ""
            continue
        cells[display_name] = format_output(
            group[col_name].sum(),
            n_cases,
            baseline_group[col_name].sum(),
            total_baseline,
        )

    return pd.Series(cells)

def generate_summary_table(df, cat_columns_dict, cont_columns_dict, baseline_interval=1, variable='COVID'):
    """Build a summary table with one column per level of ``variable``.

    Every group is summarized with ``summarize_data`` and compared against
    the rows where ``variable == baseline_interval``.

    Args:
        df: Source DataFrame. It is NOT modified; the categorical conversion
            of ``variable`` happens on an internal copy.
        cat_columns_dict: Passed through to ``summarize_data``.
        cont_columns_dict: Passed through to ``summarize_data``.
        baseline_interval: Level of ``variable`` used as reference group.
        variable: Name of the grouping column.

    Returns:
        DataFrame with summary rows as index and group levels as columns,
        the baseline level first when it is present.
    """
    # Copy first: converting the grouping column in place would silently
    # change the dtype of the caller's DataFrame.
    df = df.copy()
    df[variable] = pd.Categorical(df[variable])

    baseline_group = df[df[variable] == baseline_interval]
    total_baseline = len(baseline_group)

    # observed=False spells out the historical default for categorical
    # groupers, keeping results stable across pandas versions.
    table = df.groupby(variable, observed=False).apply(
        lambda group: summarize_data(group, baseline_group, total_baseline, cat_columns_dict, cont_columns_dict)
    )

    table = table.transpose()

    # Put the baseline column first when it exists
    if baseline_interval in table.columns:
        cols = [baseline_interval] + [col for col in table.columns if col != baseline_interval]
        table = table[cols]

    return table

Duke

In [6]:
def count_dcm_files(folder_path):
    """Return how many entries matching ``*.dcm`` sit directly in ``folder_path``."""
    pattern = os.path.join(folder_path, '*.dcm')
    return sum(1 for _match in glob.iglob(pattern))

In [7]:
def find_folders_with_dcm_files(root_folder):
    """List every folder under ``root_folder`` (inclusive) that directly holds a ``*.dcm`` file.

    Folders are reported with the path produced by ``os.walk`` and in the
    order ``os.walk`` visits them.
    """
    return [
        folder_path
        for folder_path, _subdirs, _filenames in os.walk(root_folder)
        if glob.glob(os.path.join(folder_path, '*.dcm'))
    ]

In [8]:
def sort_by_number_after_last_slash(lst):
    """Return a new list ordered by the number that starts the last path component.

    The sort key is the text between the final ``'/'`` and the first ``'.'``
    after it, converted with ``float``.
    """
    def leading_number(path):
        last_component = path.rsplit('/', 1)[-1]
        return float(last_component.partition('.')[0])

    ordered = list(lst)
    ordered.sort(key=leading_number)
    return ordered

In [9]:
def display_first_file_name(folder_path):
    """Return the name of the first regular file that ``os.listdir`` reports for ``folder_path``.

    Sub-directories are ignored. A descriptive message string is returned
    instead when the folder holds no files or does not exist.
    """
    try:
        entries = os.listdir(folder_path)
    except FileNotFoundError:
        return "The specified folder does not exist."

    for entry in entries:
        # Skip anything that is not a regular file (e.g. directories)
        if os.path.isfile(os.path.join(folder_path, entry)):
            return entry

    return "The folder is empty or contains no files."

In [10]:
def count_digits_after_dash(filename):
    """Return the length of the first digit run that directly follows a ``-`` in ``filename``.

    Returns 0 when no dash is immediately followed by a digit.
    """
    found = re.search(r'-(\d+)', filename)
    return len(found.group(1)) if found else 0

In [11]:
def resample_slices_gpu(image_np, target_slices):
    """Resample axis 1 of a 4-D array to ``target_slices`` entries with trilinear interpolation.

    The work runs on the GPU when CUDA is available and falls back to the
    CPU otherwise, so the function also works on machines without a GPU.

    Args:
        image_np: 4-D NumPy array; axis 1 is the slice axis, the last two
            axes are kept at their size.
        target_slices: Desired length of axis 1.

    Returns:
        NumPy array with axis 1 resized. When axis 1 already has
        ``target_slices`` entries a copy of the input is returned unchanged.
        Non-floating input is converted to float32 before interpolation
        (trilinear interpolation is defined for floating-point tensors), so
        the result is float32 in that case.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = torch.tensor(image_np, device=device)
    original_slices = image_tensor.shape[1]  # Number of slices in the original image

    if original_slices != target_slices:
        if not image_tensor.is_floating_point():
            image_tensor = image_tensor.float()

        # interpolate() expects (N, C, D, H, W): add a singleton channel axis
        volume = image_tensor.unsqueeze(1)  # (N, S, H, W) -> (N, 1, S, H, W)
        volume = F.interpolate(
            volume,
            size=(target_slices, volume.shape[3], volume.shape[4]),
            mode='trilinear',
            align_corners=False,
        )
        image_tensor = volume.squeeze(1)  # (N, 1, S', H, W) -> (N, S', H, W)

    return image_tensor.cpu().numpy()

In [12]:
def pad_to_shape(image_np, target_shape, device='cpu'):
    """Zero-pad a 4-D array so each axis is at least as long as in ``target_shape``.

    Padding is split around the data (the extra element, if any, goes after
    it). Axes that already meet or exceed the target are left untouched;
    nothing is ever cropped.

    Args:
        image_np: 4-D NumPy array.
        target_shape: Sequence of four target sizes.
        device: Torch device on which the padding is performed.

    Returns:
        The padded data as a NumPy array.

    Raises:
        ValueError: If the input or the target shape is not 4-D.
    """
    if not (len(image_np.shape) == 4 and len(target_shape) == 4):
        raise ValueError("Both input tensor and target shape must be 4D.")

    image_tensor = torch.tensor(image_np.copy(), device=device)

    # torch's pad() takes (before, after) pairs starting from the LAST axis,
    # so walk the axes back to front while building the flat list.
    pad_spec = []
    for current_size, target_size in zip(reversed(image_tensor.shape), reversed(target_shape)):
        if current_size < target_size:
            missing = target_size - current_size
            before = missing // 2
            pad_spec.extend((before, missing - before))
        else:
            pad_spec.extend((0, 0))

    padded = torch.nn.functional.pad(image_tensor, pad_spec, mode='constant', value=0)
    return padded.cpu().numpy()

In [13]:
def resize_image(image_np, target_shape):
    """Resize every 2-D slice of a (channel, slice, H, W) array to the target height/width.

    Only the last two entries of ``target_shape`` are used; the first two
    (channels, slices) are unpacked but do not influence the result.

    Returns:
        New NumPy array holding the resized slices, grouped per channel.
    """
    _target_channels, _target_slices, target_height, target_width = target_shape
    output_hw = (target_height, target_width)

    resized_channels = []
    for channel in image_np:
        resized_channels.append(
            [resize(plane, output_hw, anti_aliasing=True) for plane in channel]
        )

    return np.array(resized_channels)

In [14]:
def resize_array_scipy(image_np, target_size):
    """
    Rescale the last two axes of a 5D array with ``scipy.ndimage.zoom``.

    Parameters:
        image_np (numpy.ndarray): Input array with shape (N, C, D, H, W).
        target_size (tuple): Requested (target_H, target_W).

    Returns:
        numpy.ndarray: Array whose first three axes are untouched and whose
        last two axes are zoomed by target_H / H and target_W / W.
    """
    _n, _c, _d, height, width = image_np.shape
    new_height, new_width = target_size

    # Leave N, C and D alone; scale only the in-plane axes
    scale = [1, 1, 1, new_height / height, new_width / width]

    # order=1 selects linear interpolation
    return zoom(image_np, scale, order=1)

In [15]:
def find_midpoint(coord1, coord2):
    """Return the component-wise midpoint of two 3-element coordinates as a list.

    Raises:
        ValueError: If either coordinate does not have exactly 3 elements.
    """
    for coord in (coord1, coord2):
        if len(coord) != 3:
            raise ValueError("Both coordinates must be lists or tuples of length 3.")

    midpoint = []
    for first, second in zip(coord1, coord2):
        midpoint.append((first + second) / 2)
    return midpoint

In [16]:
def sort_by_number_after_slash(lst):
    """Return a new list ordered by the integer that follows the final ``'-'`` in each item.

    Despite the function name, the sort key is taken after the last dash,
    not after a slash.
    """
    def trailing_number(item):
        return int(item.rsplit('-', 1)[-1])

    return sorted(lst, key=trailing_number)

In [17]:
def dcm_to_np_duke(patients, target_shape, device, segmentations):
    """Load, normalize, crop and pad the DICOM series of each patient folder.

    For every path in ``patients`` the function collects the series folders
    that contain ``.dcm`` files, reads the pixel data, rescales it to
    [0, 1], crops a window centred on the annotated tumor extent along the
    three spatial axes and zero-pads the result with ``pad_to_shape``.

    Args:
        patients: Iterable of patient folder paths. The last 14 characters
            are matched against ``segmentations['Patient ID']`` and the last
            3 characters are parsed as the integer patient number.
        target_shape: 4-element shape. ``target_shape[0]`` is the number of
            series folders taken (and the leading axis of the reshaped
            volume); entries 1-3 are used as crop sizes and padding targets.
        device: Torch device forwarded to ``pad_to_shape``.
        segmentations: DataFrame with columns 'Patient ID',
            'Start/End Sagittal', 'Start/End Coronal', 'Start/End Axial'.

    Returns:
        Tuple ``(output_np, patients_list)``: a float32 array stacking all
        patients on axis 0, and the list of integer patient numbers.
    """
    patients_list = []
    list_of_people_scans = []

    # Iterate over each patient
    for patient_path in patients:

        # Get the DCE MRI loaded in as a np array
        phase_paths = find_folders_with_dcm_files(patient_path)
        sorted_phase_paths = sort_by_number_after_last_slash(phase_paths)
        # Find the first series folder whose name contains 'dyn' or 'vibrant'.
        # NOTE(review): if nothing matches, `offset` is left at the last index;
        # if the list is empty, `offset` is never bound and the next line fails.
        for offset in range(len(sorted_phase_paths)):
            if 'dyn' in sorted_phase_paths[offset].split('/')[-1].lower() or 'vibrant' in sorted_phase_paths[offset].split('/')[-1].lower():
                break
        # Take target_shape[0] consecutive series starting at the match
        selected_phase_paths = sorted_phase_paths[offset:target_shape[0]+offset]
        # get_dicom_file_paths is defined outside this view; presumably it
        # returns the ordered .dcm file paths of the given folders — TODO confirm
        dcms = get_dicom_file_paths(selected_phase_paths)
        print(int(patient_path[-3:]))
        entire = np.stack([pydicom.dcmread(dcm).pixel_array for dcm in dcms], axis=0)
        # Split the flat stack of images into target_shape[0] equally sized groups
        entire = np.reshape(entire, (target_shape[0], entire.shape[0] // target_shape[0], entire.shape[1], entire.shape[2]))

        # Normalize (min-max over the whole volume)
        # NOTE(review): a constant-valued volume makes the denominator zero
        entire = (entire - np.min(entire)) / (np.max(entire) - np.min(entire))

        # Grab sagittal tumor slices (annotations are shifted by -1 before use)
        start = segmentations[segmentations['Patient ID'] == patient_path[-14:]]['Start Sagittal'].values[0] -1
        end = segmentations[segmentations['Patient ID'] == patient_path[-14:]]['End Sagittal'].values[0] - 1
        middle = math.floor((start + end) // 2)
        slices = target_shape[1]
        # Window of `slices` entries centred on the tumor, clamped to the array
        start_at = max(0, int(middle - (slices // 2)))
        end_at = min(entire.shape[3], int(start_at + slices))
        entire = entire[:, :, :, start_at:end_at]

        # Grab coronal tumor slices
        start = segmentations[segmentations['Patient ID'] == patient_path[-14:]]['Start Coronal'].values[0] -1
        end = segmentations[segmentations['Patient ID'] == patient_path[-14:]]['End Coronal'].values[0] - 1
        middle = math.floor((start + end) // 2)
        slices = target_shape[3]
        start_at = max(0, int(middle - (slices // 2)))
        end_at = min(entire.shape[2], int(start_at + slices))
        entire = entire[:, :, start_at:end_at, :]

        # Grab axial tumor slices
        start = segmentations[segmentations['Patient ID'] == patient_path[-14:]]['Start Axial'].values[0] -1
        end = segmentations[segmentations['Patient ID'] == patient_path[-14:]]['End Axial'].values[0] - 1
        middle = math.floor((start + end) // 2)
        slices = target_shape[2]
        start_at = max(0, int(middle - (slices // 2)))
        end_at = min(entire.shape[1], int(start_at + slices))
        entire = entire[:, start_at:end_at, :, :]

        # NOTE(review): the crops above size axis 3 with target_shape[1],
        # axis 2 with target_shape[3] and axis 1 with target_shape[2], while
        # pad_to_shape pads axis k to target_shape[k]; these only agree when
        # the three spatial targets are compatible — confirm this is intended.
        # Pad to shape if needed and add to large list
        list_of_people_scans.append(pad_to_shape(entire, target_shape, device))

        # Output patient ID
        patients_list.append(int(patient_path[-3:]))

    # Create output np array
    output_np = np.stack(list_of_people_scans, axis=0).astype(np.float32)

    return output_np, patients_list

In [18]:
def create_dummies(df, columns, prefix=True):
    """Append integer dummy (one-hot) columns for each of ``columns``.

    Missing values in a listed column are replaced by the string
    ``'Missing'`` before encoding, so they get their own dummy column.

    Args:
        df: Source DataFrame. It is left untouched; all changes are made on
            a copy. (Previously the first listed column was filled in place
            on the caller's frame while later columns were not.)
        columns: Iterable of column names to encode.
        prefix: When true, dummy columns are named ``<column>_<value>``;
            otherwise just ``<value>``.

    Returns:
        New DataFrame holding the original columns (with missing values in
        the encoded columns filled) followed by the dummy columns.
    """
    result = df.copy()

    for column in columns:
        # Replace null values with a placeholder category
        result[column] = result[column].fillna('Missing')

        # Encode as 0/1 integers rather than booleans
        dummies = pd.get_dummies(
            result[column], prefix=column if prefix else None
        ).astype(int)

        result = pd.concat([result, dummies], axis=1)

    return result

Plotting

In [27]:
def plot_slice(image_np, slice_index, size=(24,9), normalize_entire=True, cmap='gray', title="MRI Slice"):
    """Show one slice of every phase side by side and redisplay the slice slider.

    Args:
        image_np: Array indexed as ``[phase, slice, row, col]``; a 3-D array
            is treated as a single phase.
        slice_index: Index along axis 1 of the slice to draw.
        size: Matplotlib figure size.
        normalize_entire: When true, all phases share the intensity range of
            the whole array; otherwise each phase uses its own min/max.
        cmap: Matplotlib colormap name.
        title: Accepted for interface compatibility; not used by the plot.

    The previous cell output is cleared before drawing. If a module-level
    ``slice_slider`` widget exists (created by ``scroll``) it is displayed
    again below the figure; when none exists the plot is still drawn
    instead of failing with a NameError.
    """
    if len(image_np.shape) == 3:
        image_np = np.expand_dims(image_np, axis=0)
    clear_output(wait=True)  # Clear previous plots
    plt.figure(figsize=size)
    num_phases = image_np.shape[0]

    # The shared intensity range does not depend on the phase: compute it once
    if normalize_entire:
        shared_min = np.min(image_np)
        shared_max = np.max(image_np)

    for i in range(num_phases):
        phase = image_np[i]
        if normalize_entire:
            vmin, vmax = shared_min, shared_max
        else:
            vmin, vmax = np.min(phase), np.max(phase)

        plt.subplot(1, num_phases, i + 1)
        plt.imshow(phase[slice_index, :, :], cmap=cmap, vmin=vmin, vmax=vmax)
        plt.title(f'Phase {i+1}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

    # clear_output() removed the slider as well: show it again if one exists
    slider = globals().get('slice_slider')
    if slider is not None:
        display(slider)



In [28]:
def scroll(image_np, normalize_entire=True):
    """Display an integer slider that scrolls through the slices of ``image_np``.

    A 3-D input is treated as a single phase. The slider is stored in the
    module-level name ``slice_slider`` (which ``plot_slice`` redisplays),
    and every value change redraws the selected slice.
    """
    global slice_slider

    if len(image_np.shape) == 3:
        image_np = np.expand_dims(image_np, axis=0)

    last_index = image_np.shape[1] - 1
    slice_slider = widgets.IntSlider(value=0, min=0, max=last_index, step=1)

    def redraw(index):
        plot_slice(image_np, index, normalize_entire=normalize_entire)

    # Redraw whenever the slider value changes
    slice_slider.observe(lambda change: redraw(change['new']), names='value')

    display(slice_slider)  # Initial display of the slider
    redraw(slice_slider.value)  # Initial plot


# Define Model Functions

In [29]:
def threshold(scores,threshold=0.50, minimum=0, maximum = 1.0):
    """Binarize ``scores`` around ``threshold``.

    Args:
        scores: Iterable of numeric scores.
        threshold: Cut-off; scores ``>= threshold`` become ``maximum``,
            scores ``< threshold`` become ``minimum``.
        minimum: Value written for scores below the cut-off.
        maximum: Value written for scores at or above the cut-off.

    Returns:
        New NumPy array with the replaced values (values that compare
        neither way, e.g. NaN, are left as they are).
    """
    x = np.array(list(scores))

    # Build both masks from the ORIGINAL scores before writing anything;
    # otherwise values just set to `maximum` would be overwritten again
    # whenever `maximum` itself is below `threshold`.
    at_or_above = x >= threshold
    below = x < threshold

    x[at_or_above] = maximum
    x[below] = minimum
    return x

In [30]:
def plot_confusion_matrix(y_true, y_pred):
    """Draw the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated heatmap.

    Both axes are labelled with the class names 'pCR' and 'Non pCR', in
    that order.
    """
    plt.figure(figsize=(16,9))
    counts = confusion_matrix(y_true, y_pred)
    ax = plt.subplot()
    sns.heatmap(counts, annot=True, fmt='g', ax=ax, annot_kws={"size": 20})

    # Axis labels and title share one font size
    text_setters = (
        (ax.set_xlabel, 'Predicted labels'),
        (ax.set_ylabel, 'True labels'),
        (ax.set_title, 'Confusion Matrix'),
    )
    for setter, text in text_setters:
        setter(text, fontsize=20)

    # Same class names on both axes
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticklabels(['pCR','Non pCR'], fontsize=20)

In [31]:
def plot_loss(epoch_train_loss, epoch_val_loss, lr, bs):
    """Plot per-epoch training (blue) and testing (red) loss; ``lr`` and ``bs`` go into the title."""
    plt.figure(figsize=(16,9))

    curves = (
        (epoch_train_loss, 'b', 'Training loss'),
        (epoch_val_loss, 'r', 'Testing loss'),
    )
    for history, colour, legend_text in curves:
        plt.plot(history, c=colour, label=legend_text)

    plt.legend()
    plt.grid()
    plt.xlabel('Epochs', fontsize=20)
    plt.ylabel('Loss', fontsize=20)
    plt.title(f'Learning Rate: {lr} Batch Size: {bs}')
    plt.show()

In [32]:
def plot_accuracy(epoch_train_accuracy, epoch_val_accuracy, lr, bs):
    """Plot per-epoch training (blue) and testing (red) accuracy; ``lr`` and ``bs`` go into the title."""
    plt.figure(figsize=(16,9))

    curves = (
        (epoch_train_accuracy, 'b', 'Training accuracy'),
        (epoch_val_accuracy, 'r', 'Testing accuracy'),
    )
    for history, colour, legend_text in curves:
        plt.plot(history, c=colour, label=legend_text)

    plt.legend()
    plt.grid()
    plt.xlabel('Epochs', fontsize=20)
    plt.ylabel('Accuracy', fontsize=20)
    plt.title(f'Learning Rate: {lr} Batch Size: {bs}')
    plt.show()

In [33]:
def plot_auc(epoch_train_auc, epoch_val_auc, lr, bs):
    """Plot per-epoch training (blue) and testing (red) AUC; ``lr`` and ``bs`` go into the title."""
    plt.figure(figsize=(16,9))

    curves = (
        (epoch_train_auc, 'b', 'Training AUC'),
        (epoch_val_auc, 'r', 'Testing AUC'),
    )
    for history, colour, legend_text in curves:
        plt.plot(history, c=colour, label=legend_text)

    plt.legend()
    plt.grid()
    plt.xlabel('Epochs', fontsize=20)
    plt.ylabel('AUC', fontsize=20)
    plt.title(f'Learning Rate: {lr} Batch Size: {bs}')
    plt.show()

In [34]:
def bootstrap_auc_ci(y_true, y_score, n_bootstraps=1000, alpha=0.95, random_state=None):
    """
    Compute AUC and its (alpha*100)% confidence interval via bootstrapping.

    Args:
        y_true: Binary ground-truth labels.
        y_score: Predicted scores, same length as ``y_true``.
        n_bootstraps: Number of resamples drawn with replacement.
        alpha: Confidence level of the interval (0.95 -> 95% CI).
        random_state: Seed for ``np.random.default_rng``.

    Returns:
        ``(auc_original, lower_bound, upper_bound)``. The bounds are NaN when
        no resample contained both classes, so there is nothing to rank.
    """

    # Convert to numpy arrays in case they aren't
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    # Initialize random state
    rng = np.random.default_rng(random_state)

    # AUC on the full dataset
    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_original = auc(fpr, tpr)

    # Perform bootstrapping
    n_samples = len(y_score)
    bootstrapped_scores = []
    for _ in range(n_bootstraps):
        # Sample with replacement
        indices = rng.integers(0, n_samples, n_samples)
        # AUC is undefined unless both classes are present
        if len(np.unique(y_true[indices])) < 2:
            continue
        fpr_bs, tpr_bs, _ = roc_curve(y_true[indices], y_score[indices])
        bootstrapped_scores.append(auc(fpr_bs, tpr_bs))

    # Every resample was rejected: no interval can be formed
    if not bootstrapped_scores:
        return auc_original, float('nan'), float('nan')

    sorted_scores = np.sort(np.asarray(bootstrapped_scores))

    # Confidence interval boundaries, clamped to valid positions so that
    # e.g. alpha == 1.0 selects the largest score instead of running past
    # the end of the array.
    last = len(sorted_scores) - 1
    lower_idx = min(max(int((1.0 - alpha) / 2 * len(sorted_scores)), 0), last)
    upper_idx = min(max(int((1.0 + alpha) / 2 * len(sorted_scores)), 0), last)
    lower_bound = sorted_scores[lower_idx]
    upper_bound = sorted_scores[upper_idx]

    return auc_original, lower_bound, upper_bound

def plot_roc_curve(train_targets, train_probs, val_targets, val_probs, n_bootstraps=1000, alpha=0.95):
    """Plot training and testing ROC curves with bootstrapped AUC confidence intervals.

    Prints both AUCs with their confidence intervals and the validation
    threshold that maximizes Youden's J (TPR - FPR), then draws both curves,
    the chance diagonal and the optimal operating point.

    Args:
        train_targets, train_probs: Labels and scores of the training set.
        val_targets, val_probs: Labels and scores of the validation set.
        n_bootstraps: Number of bootstrap resamples per interval.
        alpha: Confidence level; also used for the "NN% CI" text, which was
            previously always printed as 95% regardless of ``alpha``.
    """
    # Convert inputs to NumPy arrays if needed
    train_targets = np.asarray(train_targets)
    train_probs   = np.asarray(train_probs)
    val_targets   = np.asarray(val_targets)
    val_probs     = np.asarray(val_probs)

    # Compute AUCs and their confidence intervals (fixed seed for repeatability)
    roc_auc_train, lower_train, upper_train = bootstrap_auc_ci(
        train_targets, train_probs, n_bootstraps=n_bootstraps, alpha=alpha, random_state=42
    )
    roc_auc_val, lower_val, upper_val = bootstrap_auc_ci(
        val_targets, val_probs, n_bootstraps=n_bootstraps, alpha=alpha, random_state=42
    )

    # Renders as "95% CI" for the default alpha
    ci_label = f"{alpha * 100:g}% CI"

    print(f"Training AUC = {roc_auc_train:.3f} ({ci_label}: {lower_train:.3f} - {upper_train:.3f})")
    print(f"Validation AUC = {roc_auc_val:.3f} ({ci_label}: {lower_val:.3f} - {upper_val:.3f})")

    # Compute ROC curves for plotting
    fpr_train, tpr_train, _ = roc_curve(train_targets, train_probs)
    fpr_val, tpr_val, thresholds_val = roc_curve(val_targets, val_probs)

    # Calculate Youden's J for the validation set
    youden_j = tpr_val - fpr_val
    optimal_idx = np.argmax(youden_j)
    optimal_threshold = thresholds_val[optimal_idx]
    optimal_point = (fpr_val[optimal_idx], tpr_val[optimal_idx])
    print(f"Optimal Threshold (Youden's J) = {optimal_threshold:.4f}")

    # Plotting
    plt.figure(figsize=(8, 6))
    plt.plot(
        fpr_train, tpr_train, color='blue', lw=2,
        label=f'Training ROC (AUC = {roc_auc_train:.3f}, {ci_label}: {lower_train:.3f}-{upper_train:.3f})'
    )
    plt.plot(
        fpr_val, tpr_val, color='red', lw=2,
        label=f'Testing ROC (AUC = {roc_auc_val:.3f}, {ci_label}: {lower_val:.3f}-{upper_val:.3f})'
    )
    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    plt.scatter(optimal_point[0], optimal_point[1], marker='o', color='green', label='Optimal Threshold')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16)
    plt.legend(loc="lower right", fontsize=12)
    plt.grid(True)
    plt.show()

In [35]:
def last_fm_neurons(model, data, device):
    """Run one sample through the model's Sequential blocks and report the final feature-map shape.

    The layers of every direct ``nn.Sequential`` child of ``model`` are
    chained in definition order, skipping ``nn.Dropout`` layers. Each layer
    is applied exactly once (the first block used to be collected twice,
    which pushed the input through it a second time).

    Args:
        model: Module whose direct children include ``nn.Sequential`` blocks.
        data: Indexable dataset; ``data[0]['image']`` must be a NumPy array.
        device: Device the sample is moved to (the model's layers must
            already live there).

    Returns:
        ``(last_fm_shape, last_fm_flattened_shape)``: the shape of the last
        layer's output for a batch of one, and the shape after flattening.

    Raises:
        ValueError: If the model has no Sequential child with usable layers.
    """
    # Gather the layers of all Sequential children, in order, without dropout
    layers = [
        layer
        for child in model.children()
        if isinstance(child, nn.Sequential)
        for layer in child.children()
        if not isinstance(layer, nn.Dropout)
    ]
    if not layers:
        raise ValueError("model has no nn.Sequential children with layers to run.")

    # Use the first sample of the dataset as a float32 batch of size 1
    img = torch.from_numpy(data[0]['image']).to(device).float().unsqueeze(0)

    # Only shapes are needed, so skip gradient tracking
    with torch.no_grad():
        feature_map = img
        for layer in layers:
            feature_map = layer(feature_map)

    last_fm_shape = feature_map.shape
    last_fm_flattened_shape = feature_map.view(-1).shape

    return last_fm_shape, last_fm_flattened_shape


In [36]:
def calculate_in_features(model, input_shape, device):
    """Return the output shape of ``model.cnn_model`` for one all-zero sample of ``input_shape``.

    The model is moved to ``device`` as a side effect; the forward pass runs
    without gradient tracking.
    """
    # Batch of one: prepend a singleton batch axis to the requested shape
    probe = torch.zeros((1,) + tuple(input_shape)).to(device)

    model.to(device)

    with torch.no_grad():
        features = model.cnn_model(probe)

    return features.shape

In [37]:
def calculate_class_weights(data):
    """Compute balanced class weights from the ``'label'`` entries of ``data``.

    Each class ``c`` gets ``total_samples / (num_classes * count_c)``, where
    the classes are the integer labels ``0 .. max(label)``.

    Returns:
        float32 torch tensor with one weight per class.
    """
    # Labels are read via .item() and truncated to integers
    labels = np.array([sample['label'].item() for sample in data]).astype(int)

    # Occurrences per class index
    counts = np.bincount(labels)

    weights = len(labels) / (counts.size * counts)
    return torch.tensor(weights, dtype=torch.float32)

In [38]:
def fit_km(df, outcomes_label, predictor, group_name_mapping, color_palette):
    """Plot Kaplan-Meier curves per group for each outcome and print a log-rank p-value.

    Args:
        df: DataFrame with, for every outcome ``X``, the columns ``X``
            (event indicator) and ``X_TIME`` (duration), plus ``PRE_X`` when
            pre-existing cases are to be excluded, and the ``predictor`` column.
        outcomes_label: ``{OUTCOME: (label, flip)}``. A truthy ``label``
            restricts the data to rows with ``PRE_<OUTCOME> == 0``; a truthy
            ``flip`` plots ``1 - survival`` instead of survival.
        predictor: Column whose values define the groups.
        group_name_mapping: ``{group value: display name}``; its key order
            fixes the plotting order.
        color_palette: ``{group value: (color, linestyle)}``; groups not
            listed fall back to ``('blue', 'solid')``.

    One figure is shown per outcome; nothing is returned.
    """
    for OUTCOME, (label, flip) in outcomes_label.items():
        # Select relevant columns based on the outcome label (PRE exclusion)
        if label:
            df_filtered = df[df[f'PRE_{OUTCOME}'] == 0]
            df_selected = df_filtered[[OUTCOME, f'{OUTCOME}_TIME', predictor]]
        else:
            df_selected = df[[OUTCOME, f'{OUTCOME}_TIME', predictor]]

        plt.figure(figsize=(10, 9))

        # Dictionary to store Kaplan-Meier fitters for each group
        kmf_group_dict = {}

        # Prepare lists to store survival times and events for the log-rank test
        times = []
        events = []
        groups = []

        # Iterate over the groups in the order of group_name_mapping
        # (groups absent from the selected data are skipped)
        for group in group_name_mapping.keys():
            if group in df_selected[predictor].unique():
                group_data = df_selected[df_selected[predictor] == group]
                kmf_group = KaplanMeierFitter()
                kmf_group.fit(
                    group_data[f'{OUTCOME}_TIME'],
                    event_observed=group_data[OUTCOME],
                    label=group_name_mapping[group]
                )

                # Retrieve color and line style for the current group
                color, linestyle = color_palette.get(group, ('blue', 'solid'))  # Default values if group not in color_palette

                if flip:
                    # For cumulative incidence (1 - survival probability)
                    survival_prob = kmf_group.survival_function_
                    # presumably column 0 is the lower and column 1 the upper
                    # confidence bound — TODO confirm against lifelines
                    ci_upper = kmf_group.confidence_interval_.iloc[:, 1]
                    ci_lower = kmf_group.confidence_interval_.iloc[:, 0]

                    plt.plot(
                        survival_prob.index,
                        1 - survival_prob.values.flatten(),
                        label=group_name_mapping[group],
                        color=color,
                        linewidth=2,
                        linestyle=linestyle
                    )

                    # Add confidence intervals for cumulative incidence
                    # (bounds swap roles after the 1 - x transformation)
                    plt.fill_between(
                        survival_prob.index,
                        1 - ci_upper,
                        1 - ci_lower,
                        color=color,
                        alpha=0.3
                    )
                else:
                    # Plot survival curve with specified color and line style
                    # (no confidence band is drawn in this mode)
                    kmf_group.plot_survival_function(
                        ci_show=False,
                        color=color,
                        linewidth=2,
                        linestyle=linestyle
                    )

                kmf_group_dict[group] = kmf_group

                # Collect data for log-rank test
                times.extend(group_data[f'{OUTCOME}_TIME'])
                events.extend(group_data[OUTCOME])
                groups.extend([group] * len(group_data))

        # Perform log-rank test: multivariate for 3+ groups, pairwise otherwise
        if len(set(groups)) > 2:
            results = multivariate_logrank_test(times, groups, events)
        else:
            # NOTE(review): with fewer than two groups present, the indexing
            # below raises IndexError
            unique_groups = list(set(groups))
            group_1 = df_selected[df_selected[predictor] == unique_groups[0]]
            group_2 = df_selected[df_selected[predictor] == unique_groups[1]]
            results = logrank_test(
                group_1[f'{OUTCOME}_TIME'],
                group_2[f'{OUTCOME}_TIME'],
                event_observed_A=group_1[OUTCOME],
                event_observed_B=group_2[OUTCOME]
            )

        # Output p-value
        print(f"Log-rank test p-value for {OUTCOME}: {results.p_value}")

        # Set the x-axis to intervals of 1
        plt.xticks(ticks=range(0, int(df_selected[f'{OUTCOME}_TIME'].max()) + 1, 1))

        # Set y-axis limits and enable grid
        # NOTE(review): the 0.5-1 window is applied in flip mode too, where
        # the plotted values are 1 - survival — confirm this is intended
        plt.ylim(0.5, 1)
        plt.grid(True)

        # Add legend in the top right
        plt.legend(loc="upper right")

        # Add titles and labels
        plt.xlabel('Time Since Breast Cancer Diagnosis (Years)')
        plt.ylabel('Recurrence-Free Survival Probability' if not flip else 'Cumulative Incidence')

        # Adding number at risk and censored information at the bottom with group names
        at_risk_labels = [group_name_mapping[g] for g in group_name_mapping.keys() if g in kmf_group_dict]
        add_at_risk_counts(*[kmf_group_dict[g] for g in group_name_mapping.keys() if g in kmf_group_dict], labels=at_risk_labels)

        # Adjust layout
        plt.tight_layout()
        plt.show()

In [39]:
def bootstrap_c_index(model, df, n_bootstrap=1000, random_seed=42):
    """
    Bootstrap the concordance index of a fitted model on ``df``.

    Rows are resampled with replacement ``n_bootstrap`` times (NumPy's
    global random generator is seeded with ``random_seed``) and the C-index
    is recomputed on every resample.

    Returns:
        ``(mean_c, lower, upper)``: mean of the resampled C-indices and
        their 2.5th / 97.5th percentiles.
    """
    np.random.seed(random_seed)

    durations = df[model.duration_col].values
    events = df[model.event_col].values

    # Sign-flip the partial hazards so that larger scores mean later events,
    # which is the ordering concordance_index expects
    scores = -model.predict_partial_hazard(df).values

    n_rows = len(df)

    def resampled_c_index():
        picked = np.random.choice(n_rows, size=n_rows, replace=True)
        return concordance_index(durations[picked], scores[picked], events[picked])

    c_index_list = [resampled_c_index() for _ in range(n_bootstrap)]

    mean_c = np.mean(c_index_list)
    lower, upper = np.percentile(c_index_list, [2.5, 97.5])

    return mean_c, lower, upper

def fit_cox_model(df, outcomes_label, predictors, 
                  bootstrap=True, n_bootstrap=1000):
    """
    Fit one Cox proportional-hazards model per outcome and print a report.

    For every ``OUTCOME`` key in ``outcomes_label`` the model is fitted on the
    columns ``OUTCOME`` (event), ``f'{OUTCOME}_TIME'`` (duration) and
    ``predictors``. When the value stored under the key is truthy, only rows
    with ``PRE_<OUTCOME> == 0`` are used. The report lists the hazard ratio,
    its 95% CI and the p-value of each covariate, the in-sample C-index and,
    if ``bootstrap`` is set, a bootstrap C-index with 95% CI.

    Returns
    -------
    models_dict : dict
        ``{OUTCOME: fitted CoxPHFitter}``, usable later for C-index
        comparison.
    """
    report_columns = ['exp(coef)', 'exp(coef) lower 95%',
                      'exp(coef) upper 95%', 'p']
    models_dict = {}

    for OUTCOME, label in outcomes_label.items():
        time_col = f'{OUTCOME}_TIME'

        # A truthy label restricts the cohort to rows whose PRE_ flag is 0.
        cohort = df[df[f'PRE_{OUTCOME}'] == 0] if label else df
        df_selected = cohort[[OUTCOME, time_col] + predictors]

        cph = CoxPHFitter()
        cph.fit(df_selected, duration_col=time_col, event_col=OUTCOME)

        # One line per covariate: HR [lower, upper] <tab> p-value.
        hr_table = cph.summary[report_columns]
        print(f"Results for {OUTCOME}:")
        for covariate, (hr, ci_low, ci_high, p) in zip(hr_table.index,
                                                       hr_table.to_numpy()):
            if p < 0.005:
                p_text = "<0.005"
            else:
                p_text = f"{p:.2g}"
            print(f"{covariate}\t{hr:.2f} [{ci_low:.2f}, {ci_high:.2f}]\t{p_text}")

        # In-sample C-index reported by the fitted model.
        print(f"\nSingle-run C-index: {cph.concordance_index_:.3f}")

        if bootstrap:
            mean_c, lower, upper = bootstrap_c_index(
                model=cph, df=df_selected, n_bootstrap=n_bootstrap
            )
            print(f"Bootstrap C-index: {mean_c:.3f} (95% CI: [{lower:.3f}, {upper:.3f}])")

        models_dict[OUTCOME] = cph

        print("="*60 + "\n")

    return models_dict

In [40]:
def bootstrap_compare_c_indices(modelA, modelB, df, 
                                duration_col='RD_TIME', event_col='RD',
                                n_bootstrap=1000, random_seed=42,
                                flip_sign=True):
    """
    Compare the C-index of two fitted Cox models on the same dataset
    using paired bootstrap resampling.

    Both models are scored on identical resamples of ``df``, so the
    distribution of the difference accounts for their correlation.

    Parameters
    ----------
    modelA, modelB : fitted Cox models
        Must expose ``predict_partial_hazard(df)``.
    df : pandas.DataFrame
        Data containing ``duration_col``, ``event_col`` and the covariates of
        both models.
    duration_col : str
        Column holding follow-up time. The default used to be swapped with
        ``event_col``, which made the event flag act as the duration.
    event_col : str
        Column holding the event indicator (1 = event observed).
    n_bootstrap : int
        Number of bootstrap resamples.
    random_seed : int
        Seed of the private random generator used for resampling.
    flip_sign : bool
        Negate the partial hazards before scoring. Keep this True for Cox
        partial hazards: a larger hazard means an earlier event, whereas
        concordance_index treats a larger score as a later event.

    Returns
    -------
    dict
        ``CindexA`` / ``CindexB`` (full-sample C-index), ``mean_cA`` /
        ``mean_cB`` (bootstrap means), ``mean_diff`` (A - B) with ``ci_lower``
        / ``ci_upper`` (2.5th / 97.5th percentiles) and an approximate
        two-sided ``p_value``; all rounded to 3 decimals.
    """
    import numpy as np
    from lifelines.utils import concordance_index

    # Same resamples as seeding the global NumPy RNG, without resetting it.
    rng = np.random.RandomState(random_seed)

    durations = df[duration_col].values
    events = df[event_col].values

    phA = modelA.predict_partial_hazard(df).values
    phB = modelB.predict_partial_hazard(df).values

    predsA = -phA if flip_sign else phA
    predsB = -phB if flip_sign else phB

    # Full-sample C-indices
    cA = concordance_index(durations, predsA, events)
    cB = concordance_index(durations, predsB, events)

    n = len(df)
    cA_values = np.empty(n_bootstrap)
    cB_values = np.empty(n_bootstrap)

    for b in range(n_bootstrap):
        idx = rng.choice(n, size=n, replace=True)
        cA_values[b] = concordance_index(durations[idx], predsA[idx], events[idx])
        cB_values[b] = concordance_index(durations[idx], predsB[idx], events[idx])

    diffs = cA_values - cB_values

    mean_diff = np.mean(diffs)
    ci_lower = np.percentile(diffs, 2.5)
    ci_upper = np.percentile(diffs, 97.5)

    # Approximate two-sided p-value: twice the smaller tail of the bootstrap
    # distribution of the difference, capped at 1.
    frac_lt0 = np.mean(diffs < 0)
    frac_gt0 = np.mean(diffs > 0)
    p_value = min(2 * min(frac_lt0, frac_gt0), 1.0)

    return {
        'CindexA': round(cA, 3),
        'CindexB': round(cB, 3),
        'mean_diff': round(mean_diff, 3),
        'ci_lower': round(ci_lower, 3),
        'ci_upper': round(ci_upper, 3),
        'p_value': round(p_value, 3),
        'mean_cA': round(cA_values.mean(), 3),
        'mean_cB': round(cB_values.mean(), 3)
    }

In [41]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and force deterministic kernels."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers every GPU in multi-GPU runs
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic implementations.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

In [42]:
def train_model(model, data, device, 
                lr=1e-4, epochs=100, batch_size=32, weight_decay=1e-3, 
                early_stopping_patience=10, lr_scheduler=True, 
                use_amp=True, num_workers=4):
    """
    Train a binary classifier with class-weighted BCE-with-logits loss.

    Each epoch runs a training pass and a validation pass over ``data``
    (switched with ``data.set_mode``), tracking loss, accuracy and ROC AUC.
    Validation loss drives the optional ReduceLROnPlateau scheduler, the
    checkpoint written to ``best_model.pth`` and early stopping. The metric
    curves and the last epoch's ROC curves are plotted at the end.

    Parameters
    ----------
    model : nn.Module
        Must return raw logits: the loss is ``BCEWithLogitsLoss`` and sigmoid
        is applied here only to derive probabilities.
    data : dataset
        Must provide ``train_labels``, ``set_mode('train' | 'val')`` and yield
        dicts with ``'image'`` and ``'label'``.
    device : torch.device
        Device that batches and class weights are moved to.
    lr, weight_decay : float
        Adam hyper-parameters.
    epochs : int
        Maximum number of epochs.
    batch_size, num_workers : int
        DataLoader settings for both phases.
    early_stopping_patience : int
        Stop after this many consecutive epochs without a new best val loss.
    lr_scheduler : bool
        Halve the learning rate after 5 epochs without val-loss improvement.
    use_amp : bool
        Run forward passes under autocast and scale gradients with GradScaler.

    Returns
    -------
    None

    NOTE(review): a second ``train_model`` with a different signature is
    defined later in this file; once that cell runs, this definition is
    replaced.
    NOTE(review): the best checkpoint is saved but never loaded back, so
    ``model`` keeps the weights of the last epoch that ran.
    """

    # Compute class weights: n_samples / (n_classes * count_of_class)
    train_labels = data.train_labels
    classes, class_counts = np.unique(train_labels, return_counts=True)
    class_weights = torch.tensor(len(train_labels) / (len(classes) * class_counts), dtype=torch.float32).to(device)
    
    # Use BCEWithLogitsLoss instead of BCELoss (for numerical stability with AMP);
    # reduction='none' keeps per-sample losses so they can be weighted below
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    # Learning Rate Scheduler (Reduce LR when validation loss stagnates)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5) if lr_scheduler else None
    
    # Early Stopping Setup
    best_val_loss = float('inf')
    patience_counter = 0
    
    # Gradient scaler for mixed precision (torch.cuda.amp API)
    scaler = amp.GradScaler() if use_amp else None
    
    # Tracking metrics (one entry per completed epoch)
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    train_aucs, val_aucs = [], []

    for epoch in range(1, epochs + 1):
        # === Training Phase ===
        model.train()
        data.set_mode('train')
        train_loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        
        train_correct, train_total, train_probs, train_targets = 0, 0, [], []
        epoch_train_losses = []

        for D in train_loader:
            optimizer.zero_grad()
            image, label = D['image'].to(device).float(), D['label'].to(device).float().view(-1, 1)
            
            with amp.autocast(enabled=use_amp):  # mixed-precision forward pass
                y_hat = model(image)  # No sigmoid in forward pass!
                # Look up each sample's weight by its label value
                # (assumes labels are 0/1 so they index class_weights directly — TODO confirm)
                sample_weights = class_weights[label.long()].view(-1, 1)
                loss = (criterion(y_hat, label) * sample_weights).mean()

            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            
            epoch_train_losses.append(loss.item())
            train_correct += (y_hat.sigmoid().round() == label).sum().item()  # Apply sigmoid only for prediction
            train_total += label.size(0)
            train_probs.extend(y_hat.sigmoid().detach().cpu().numpy().flatten())  # Convert logits to probabilities
            train_targets.extend(label.detach().cpu().numpy().flatten())

        train_auc = roc_auc_score(train_targets, train_probs)
        train_losses.append(np.mean(epoch_train_losses))
        train_accs.append(train_correct / train_total)
        train_aucs.append(train_auc)

        # === Validation Phase ===
        model.eval()
        data.set_mode('val')
        val_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        
        val_correct, val_total, val_probs, val_targets = 0, 0, [], []
        epoch_val_losses = []

        with torch.no_grad():
            for D in val_loader:
                image, label = D['image'].to(device).float(), D['label'].to(device).float().view(-1, 1)
                
                with amp.autocast(enabled=use_amp):
                    y_hat = model(image)  # No sigmoid in forward pass!
                    # Validation loss uses the training-set class weights as well
                    sample_weights = class_weights[label.long()].view(-1, 1)
                    loss = (criterion(y_hat, label) * sample_weights).mean()

                epoch_val_losses.append(loss.item())
                val_correct += (y_hat.sigmoid().round() == label).sum().item()
                val_total += label.size(0)
                val_probs.extend(y_hat.sigmoid().cpu().numpy().flatten())
                val_targets.extend(label.cpu().numpy().flatten())

        val_auc = roc_auc_score(val_targets, val_probs)
        val_losses.append(np.mean(epoch_val_losses))
        val_accs.append(val_correct / val_total)
        val_aucs.append(val_auc)

        # Learning Rate Scheduling (driven by this epoch's validation loss)
        if lr_scheduler:
            scheduler.step(val_losses[-1])

        # Early Stopping: checkpoint on improvement, otherwise count the stall
        if val_losses[-1] < best_val_loss:
            best_val_loss = val_losses[-1]
            patience_counter = 0  # Reset patience counter
            torch.save(model.state_dict(), "best_model.pth")  # Save best model (relative path, overwritten each time)
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch}")
            break

        # Print Progress (every 10th epoch and the final one)
        if epoch % 10 == 0 or epoch == epochs:
            print(f"Epoch {epoch} | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_losses[-1]:.4f} | "
                  f"Train Acc: {100 * train_accs[-1]:.2f}% | Val Acc: {100 * val_accs[-1]:.2f}% | "
                  f"Train AUC: {train_auc:.4f} | Val AUC: {val_auc:.4f}")

    # Plot Metrics (plot_* helpers are defined outside this function);
    # the ROC curve uses the probabilities of the last epoch that ran
    plot_loss(train_losses, val_losses, lr, batch_size)
    plot_accuracy(train_accs, val_accs, lr, batch_size)
    plot_auc(train_aucs, val_aucs, lr, batch_size)
    plot_roc_curve(train_targets, train_probs, val_targets, val_probs)

In [43]:
def evaluate_model(model, data, device):
    """
    Compute validation accuracy of a binary classifier.

    Parameters
    ----------
    model : nn.Module
        Its outputs are thresholded at 0.5, so it is expected to return
        probabilities of shape (batch, 1) rather than raw logits.
    data : dataset
        Switched to its validation split by setting ``data.mode = 'val'``;
        must yield dicts with ``'image'`` and ``'label'``.
    device : torch.device
        Device that batches are moved to.

    Returns
    -------
    (y_true, y_pred, accuracy)
        1-D arrays of true labels and thresholded predictions, plus the
        accuracy score.
    """
    data.mode = 'val'
    val_dataloader = DataLoader(data, shuffle=False, batch_size=2, num_workers=0)
    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for D in val_dataloader:
            # Cast to float32 like the training loops do; non-float32 arrays
            # coming from the dataset would not match the model's weights.
            images = D['image'].to(device).float()
            labels = D['label'].to(device).float().view(-1, 1)

            outputs = model(images)
            predictions = (outputs >= 0.5).float()

            y_true.append(labels.cpu().numpy())
            y_pred.append(predictions.cpu().numpy())

    # reshape(-1) keeps the arrays 1-D even for a single validation sample,
    # where squeeze() would collapse them to 0-d scalars.
    y_true = np.concatenate(y_true, axis=0).reshape(-1)
    y_pred = np.concatenate(y_pred, axis=0).reshape(-1)

    accuracy = accuracy_score(y_true, y_pred)
    return y_true, y_pred, accuracy

In [44]:
def train_model(model, data, device, lrng_rt=0.0001, EPOCH=200, batch_size=4, weight_decay=1e-4):
    """
    Train a binary classifier with class-weighted BCE on probabilities.

    The model must already output probabilities (``nn.BCELoss`` is used and
    predictions are thresholded at 0.5). Every epoch runs a training pass and
    a validation pass over ``data`` (switched via ``data.set_mode``), records
    loss, accuracy and ROC AUC, and prints them every 10th epoch and on the
    last one. After the last epoch the metric curves and the ROC curves of
    the final epoch are plotted. Nothing is returned.

    ``data`` must provide ``train_labels`` and yield dicts with ``'image'``
    and ``'label'``.
    """
    # Inverse-frequency class weights, normalised to sum to 1.
    labels_np = data.train_labels
    classes = np.unique(labels_np)
    counts = np.array([(labels_np == cls).sum() for cls in classes])
    balanced = len(labels_np) / (len(classes) * counts)
    class_weights = torch.tensor(balanced / balanced.sum(), dtype=torch.float32).to(device)

    # Per-sample loss so that each sample can be weighted by its class.
    criterion = nn.BCELoss(reduction='none')

    history = {name: [] for name in ('train_loss', 'val_loss',
                                     'train_acc', 'val_acc',
                                     'train_auc', 'val_auc')}

    optimizer = torch.optim.Adam(model.parameters(), lr=lrng_rt, weight_decay=weight_decay)

    def new_stats():
        # Accumulators for one pass over one split.
        return {'losses': [], 'correct': 0, 'total': 0, 'probs': [], 'targets': []}

    def score_batch(batch):
        # Forward one batch; returns (probabilities, targets, weighted mean loss).
        image = batch['image'].to(device).float()
        label = batch['label'].to(device).float()
        y_hat = model(image)
        label = label.view(-1, 1)
        per_sample = criterion(y_hat, label)
        weights = class_weights[label.view(-1).long()].view(-1, 1)
        return y_hat, label, (per_sample * weights).mean()

    def record(stats, y_hat, label, loss):
        # Fold one scored batch into the accumulators.
        stats['losses'].append(loss.item())
        stats['correct'] += ((y_hat > 0.5).float() == label).sum().item()
        stats['total'] += label.size(0)
        stats['probs'].extend(y_hat.detach().cpu().numpy().flatten())
        stats['targets'].extend(label.detach().cpu().numpy().flatten())

    for epoch in range(1, EPOCH + 1):
        # ---- training pass ----
        train_stats = new_stats()
        data.set_mode('train')
        train_dataloader = DataLoader(data, shuffle=True, batch_size=batch_size, num_workers=0)
        model.train()

        for batch in train_dataloader:
            optimizer.zero_grad()
            y_hat, label, loss = score_batch(batch)
            loss.backward()
            optimizer.step()
            record(train_stats, y_hat, label, loss)

        train_auc = roc_auc_score(train_stats['targets'], train_stats['probs'])
        history['train_auc'].append(train_auc)
        history['train_loss'].append(np.mean(train_stats['losses']))
        history['train_acc'].append(train_stats['correct'] / train_stats['total'])

        # ---- validation pass ----
        val_stats = new_stats()
        data.set_mode('val')
        val_dataloader = DataLoader(data, shuffle=False, batch_size=batch_size, num_workers=0)
        model.eval()

        with torch.no_grad():
            for batch in val_dataloader:
                record(val_stats, *score_batch(batch))

        val_auc = roc_auc_score(val_stats['targets'], val_stats['probs'])
        history['val_auc'].append(val_auc)
        history['val_loss'].append(np.mean(val_stats['losses']))
        history['val_acc'].append(val_stats['correct'] / val_stats['total'])

        if epoch % 10 == 0 or epoch == EPOCH:
            print('Epoch: {}\tTrain Loss: {:.6f}\tVal Loss: {:.6f}\tTrain Acc: {:.2f}%\tVal Acc: {:.2f}%\tTrain AUC: {:.4f}\tVal AUC: {:.4f}'.format(
                epoch,
                np.mean(train_stats['losses']),
                np.mean(val_stats['losses']),
                100.0 * train_stats['correct'] / train_stats['total'],
                100.0 * val_stats['correct'] / val_stats['total'],
                train_auc,
                val_auc
            ))

    # Curves over epochs
    plot_loss(history['train_loss'], history['val_loss'], lrng_rt, batch_size)
    plot_accuracy(history['train_acc'], history['val_acc'], lrng_rt, batch_size)
    plot_auc(history['train_auc'], history['val_auc'], lrng_rt, batch_size)

    # ROC curves from the final epoch's probabilities
    plot_roc_curve(train_stats['targets'], train_stats['probs'],
                   val_stats['targets'], val_stats['probs'])

# Train and Test

In [45]:
class MRI(Dataset):
    """
    Slice-level train/validation dataset built from two groups of volumes.

    ``rd`` volumes are labelled 1 and ``dfs`` volumes 0. Both are indexed as
    (patients, phases, slices, H, W). Patients are split once (stratified by
    label). Every slice of a training patient becomes one training sample
    with a fixed, position-dependent flip; every validation patient
    contributes only its central slice, unaugmented.

    Parameters
    ----------
    rd, dfs : np.ndarray
        Volumes of the positive (label 1) and negative (label 0) group.
    sf : int
        ``random_state`` of the stratified split.
    test : float
        Fraction of patients assigned to validation.
    patient_ids : sequence, optional
        Patient IDs in label order (``rd`` first, then ``dfs``). When omitted,
        the module-level lists ``HRp + HRn`` are used, as before.
    """

    # Flip axis applied to the training slice at each position of a
    # (phases, H, W) slice: 1 = vertical flip, 2 = horizontal flip,
    # None = unchanged. The pattern wraps around for volumes with more
    # slices than entries.
    _FLIP_AXES = (1, None, 2, 1, None, 2, 1, None, 2, None)

    def __init__(self, rd, dfs, sf, test=0.2, patient_ids=None):
        self.mode = 'train'

        # Labels: rd -> 1, dfs -> 0, concatenated in that order
        rd_label = np.ones(rd.shape[0], dtype=np.float32)
        dfs_label = np.zeros(dfs.shape[0], dtype=np.float32)
        self.labels = np.concatenate((rd_label, dfs_label))

        # Store original data
        self.rd_original = rd  # Shape: (patients, phases, slices, H, W)
        self.dfs_original = dfs  # Shape: (patients, phases, slices, H, W)

        # Patient IDs in the order of labels; falls back to the notebook
        # globals so existing three-argument calls keep working.
        if patient_ids is None:
            patient_ids = HRp + HRn
        self.combined_patient_ids = patient_ids

        # Split indices for train and validation
        self.train_val_split(sf, test)

        # Preprocess data
        self.preprocess_data()

    def train_val_split(self, sf, test):
        """Split patient indices into train/validation sets, stratified by label."""
        self.train_idx, self.val_idx = train_test_split(
            np.arange(len(self.labels)),
            test_size=test,
            random_state=sf,
            shuffle=True,
            stratify=self.labels
        )

        # Store patient indices for train and validation sets
        self.train_patient_indices = self.train_idx.copy()
        self.val_patient_indices = self.val_idx.copy()

    def _patient_volume(self, idx):
        """Return the (phases, slices, H, W) volume for a combined patient index."""
        n_rd = len(self.rd_original)
        if idx < n_rd:
            return self.rd_original[idx]
        return self.dfs_original[idx - n_rd]

    @staticmethod
    def _augment(slice_image, slice_idx):
        """Apply the flip assigned to ``slice_idx`` to a (phases, H, W) slice."""
        flip_axis = MRI._FLIP_AXES[slice_idx % len(MRI._FLIP_AXES)]
        if flip_axis is None:
            return slice_image
        return np.flip(slice_image, axis=flip_axis)

    def preprocess_data(self):
        """Materialise ``train_images/labels`` and ``val_images/labels`` arrays."""
        # Training: expand every patient into one sample per slice
        self.train_images = []
        self.train_labels = []

        for idx in self.train_idx:
            label = self.labels[idx]
            # (phases, slices, H, W) -> (slices, phases, H, W)
            slices = np.transpose(self._patient_volume(idx), (1, 0, 2, 3))
            for slice_idx, slice_image in enumerate(slices):
                self.train_images.append(self._augment(slice_image, slice_idx))
                self.train_labels.append(label)

        # Convert to NumPy arrays for efficient indexing
        self.train_images = np.array(self.train_images)
        self.train_labels = np.array(self.train_labels, dtype=np.float32)

        # Validation: one central slice per patient
        self.val_images = []
        self.val_labels = []

        for idx in self.val_idx:
            volume = self._patient_volume(idx)
            # Central slice; with an even slice count this is the lower of
            # the two middle slices. (The previous index, n_slices - 1,
            # selected the last slice instead.)
            center = (volume.shape[1] - 1) // 2
            self.val_images.append(volume[:, center, :, :])
            self.val_labels.append(self.labels[idx])

        # Convert to NumPy arrays
        self.val_images = np.array(self.val_images)
        self.val_labels = np.array(self.val_labels, dtype=np.float32)

    def __len__(self):
        if self.mode == 'train':
            return len(self.train_images)
        else:
            return len(self.val_images)

    def __getitem__(self, idx):
        if self.mode == 'train':
            image = self.train_images[idx]  # Shape: (phases, H, W)
            label = self.train_labels[idx]
        else:
            image = self.val_images[idx]  # Shape: (phases, H, W)
            label = self.val_labels[idx]

        return {'image': image, 'label': label}

    def set_mode(self, mode):
        """Select which split ``__len__`` / ``__getitem__`` serve."""
        if mode in ['train', 'val']:
            self.mode = mode
        else:
            raise ValueError("Mode should be 'train' or 'val'")

    def get_train_rd_count(self):
        """Returns the number of rd samples in the training set."""
        return np.sum(self.train_labels == 1)

    def get_train_dfs_count(self):
        """Returns the number of dfs samples in the training set."""
        return np.sum(self.train_labels == 0)

In [None]:
# Note: Replace the folder paths below with the locations of your images, demographics file, and annotation-boxes file.
# Make sure the first row of the demographics CSV is the header row (column names).

In [None]:
# Build one image-folder path per patient from the last three characters of
# each 'Patient ID' in the demographics sheet.
duke_list = pd.read_csv('/gs/gsfs0/users/rhadidchi/Duke pts.csv')
duke_list['Patient ID'] = duke_list['Patient ID'].str[-3:]
ids = duke_list['Patient ID'].tolist()
pts = [
    f'/gs/gsfs0/users/rhadidchi/Duke/manifest-1725137251434/Duke-Breast-Cancer-MRI/Breast_MRI_{patient_id}'
    for patient_id in ids
]

# Convert the DICOM series to a single array and cache it on disk.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
duke_segmentations = pd.read_excel('/gs/gsfs0/users/rhadidchi/Annotation_Boxes.xlsx')
target_shape = (4, 64, 64, 64)  # (acquisitions,axial,coronal,sagittal)
duke, duke_pts = dcm_to_np_duke(pts, target_shape, device, duke_segmentations)
np.save('/gs/gsfs0/users/rhadidchi/duke4phases.npy', duke)
duke.shape

In [None]:
# Reload the cached volume array saved by the conversion cell instead of
# re-reading the DICOM files.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
duke = np.load('/gs/gsfs0/users/rhadidchi/duke4phases.npy') # alternative path: '/kaggle/input/duke224/duke224.npy'
duke.shape

(922, 4, 64, 64, 64)

In [None]:
# Build the per-patient covariate table `pts` from the demographics CSV.
# Statements are order-dependent: several columns are recoded in place and
# later lines read the recoded values.
pts = pd.read_csv('/gs/gsfs0/users/rhadidchi/Duke pts.csv')
# Numeric patient number taken from the last three characters of the ID
pts['Patient ID'] = list(map(int, pts['Patient ID'].str[-3:]))
# HR > 0 when ER and/or PR is positive (assumes ER/PR are coded 0/1 — TODO confirm)
pts['HR'] = pts['ER'] + pts['PR']
pts = create_dummies(pts, ['Race and Ethnicity'])
# 'Other Race' pools every race/ethnicity code except 1 and 2
pts['Other Race'] = (pts['Race and Ethnicity_0'] | pts['Race and Ethnicity_3'] | pts['Race and Ethnicity_4'] | pts['Race and Ethnicity_5'] | pts['Race and Ethnicity_6'] | pts['Race and Ethnicity_7'] | pts['Race and Ethnicity_8']).astype(int)
# Short aliases for the TNM staging columns
pts['T stage'] = pts['Staging(Tumor Size)# [T]']
pts['N stage'] = pts['Staging(Nodes)#(Nx replaced by -1)[N]']
pts['M stage'] = pts['Staging(Metastasis)#(Mx -replaced by -1)[M]']
# NOTE(review): per the column name Nx is coded -1, so unknown nodal status counts as "no involvement" here — confirm
pts['Nodal Involvement'] = (pts['N stage'] > 0).astype(int)
pts['High T Stage'] = (pts['T stage']>1).astype(int)
# N stage 0 -> 1, 1 -> 2, anything else (including -1) -> 3
pts['Nottingham N Score'] = pts['N stage'].apply(lambda x: 1 if x == 0 else 2 if x == 1 else 3)
# presumably date of birth is stored as negative days relative to diagnosis, giving age in years — verify
pts['Age'] = pts['Date of Birth (Days)']/(-365.25)
# Keep only patients with all three grade components recorded
pts = pts.dropna(subset=['Tumor Grade(T)','Tumor Grade(N)','Tumor Grade(M)'])
pts['T Grade'] = pts['Tumor Grade(T)']
pts['N Grade'] = pts['Tumor Grade(N)']
pts['M Grade'] = pts['Tumor Grade(M)']
pts = create_dummies(pts, ['T stage', 'N stage', 'M stage', 'T Grade', 'N Grade', 'M Grade'])
# Total score = sum of the three components; binned below as 3-5 -> grade 1, 6-7 -> grade 2, otherwise grade 3
pts['Grade'] = pts['T Grade'] + pts['N Grade'] + pts['M Grade']
pts['Grade 7'] = (pts['Grade'] == 7).astype(int)
pts['Nottingham Grade'] = pts['Grade'].apply(lambda x: 1 if x in [3,4,5] else 2 if x in [6,7] else 3)
pts['Nottingham Grade 1'] = (pts['Nottingham Grade'] == 1).astype(int)
pts['Nottingham Grade 2'] = (pts['Nottingham Grade'] == 2).astype(int)
pts['Nottingham Grade 3'] = (pts['Nottingham Grade'] == 3).astype(int)
pts['Nottingham Grade 2 or 3'] = (pts['Nottingham Grade'] > 1).astype(int)
# Oncotype bins: <18 -> 0, 18-30 -> 1, >30 -> 2; a missing score fails every comparison and yields None
pts['Oncotype'] = pts['Oncotype score'].apply(lambda x: 0 if x < 18 else 1 if x >= 18 and x <= 30 else 2 if x > 30 else None)
pts = create_dummies(pts, ['Oncotype'])
pts['luminal A'] = ((pts['HR']>0)&(pts['HER2']==0)).astype(int) 
pts['luminal B'] = ((pts['HR']>0)&(pts['HER2']==1)).astype(int) 
pts['luminal'] = ((pts['luminal A'] == 1) | (pts['luminal B'] == 1)).astype(int)
# Recode 'Mol Subtype' to 5 for luminal B rows; every line below reads the recoded column
pts['Mol Subtype'] = pts.apply(lambda x: 5 if x['luminal B'] == 1 else x['Mol Subtype'], axis=1)
pts['TN'] = (pts['Mol Subtype'] == 3).astype(int)
pts['HR−HER2+'] = (pts['Mol Subtype'] == 2).astype(int)
# NOTE(review): this 'HR+HER2+' value is overwritten a few lines below
pts['HR+HER2+'] = (pts['Mol Subtype'] == 1).astype(int)
# NOTE(review): '> 1' also matches the recoded value 5 (luminal B) — confirm that is intended
pts['TN&HER2+'] = (pts['Mol Subtype'] > 1).astype(int)
pts['Simplified Mol Subtype'] = pts['Mol Subtype'].apply(lambda x: 1 if x==0 else 2 if x in [1,2,5] else 3)
# Any value other than 1 (including missing) is treated as no response
pts['pCR'] = pts['Overall Near-complete Response:  Looser Definition'].apply(lambda x: 1 if x == 1 else 0)
pts['HR+HER2−'] = ((pts['HR']>0) & (pts['HER2']==0)).astype(int)
pts['HR+HER2+'] = ((pts['HR']>0) & (pts['HER2']==1)).astype(int)
# Four-level subtype label from ER and HER2 only.
# NOTE(review): the last label says 'PR−' but PR is not tested in its condition
pts['Molecular Subtype'] = pts.apply(lambda x:
    'ER+/HER2+(±PR)' if x['ER'] == 1 and x['HER2'] == 1 else
    'ER−/HER2+(±PR)' if x['ER'] == 0 and x['HER2'] == 1 else
    'ER+/HER2−(±PR)' if x['ER'] == 1 and x['HER2'] == 0 else
    'ER−/HER2−(PR−)' if x['ER'] == 0 and x['HER2'] == 0 else
    'Unknown', axis=1)
pts = create_dummies(pts, ['Molecular Subtype'], prefix=False)
# Constant column, usable as a single-group label
pts['Everyone'] = 'Everyone'
pts['Age≥50'] = (pts['Age']>=50).astype(int)

In [48]:
# Derive the recurrence/death outcome `RD` and its follow-up time `RD_TIME`
# (years). In the source sheet the event-day columns hold the string 'NP'
# when the event did not happen.
death_col = 'Days to death (from the date of diagnosis) '
local_rec_col = 'Days to local recurrence (from the date of diagnosis) '
distant_rec_col = 'Days to distant recurrence(from the date of diagnosis) '
last_local_col = 'Days to last local recurrence free assessment (from the date of diagnosis) '
last_distant_col = 'Days to last distant recurrence free assemssment(from the date of diagnosis) '

# Event indicator: death, a flagged recurrence, or any dated recurrence.
pts['RD'] = (
    (pts[death_col] != 'NP')
    | (pts['Recurrence event(s)'] == 1)
    | (pts[local_rec_col] != 'NP')
    | (pts[distant_rec_col] != 'NP')
).astype(int)

# Recurrence day: the local recurrence when recorded, otherwise the distant one.
# NOTE(review): when both are recorded the local date wins even if the
# distant recurrence came first — confirm this precedence is intended.
pts['Recurrence_TIME'] = np.where(
    pts[local_rec_col] != 'NP',
    pts[local_rec_col],
    np.where(pts[distant_rec_col] != 'NP', pts[distant_rec_col], 'NP'),
)

# Day of the first event. The day counts are converted to numbers BEFORE
# taking the minimum: np.minimum on the raw object columns compares digit
# strings lexicographically (e.g. '1200' < '500'). np.fmin ignores NaN, so a
# patient with only one kind of event keeps that event's day.
recurrence_days = pd.to_numeric(pts['Recurrence_TIME'], errors='coerce').to_numpy()
death_days = pd.to_numeric(pts[death_col], errors='coerce').to_numpy()
pts['Outcome_TIME'] = np.fmin(recurrence_days, death_days)

# Censoring day: the later of the two last recurrence-free assessments.
pts[last_local_col] = pd.to_numeric(pts[last_local_col], errors='coerce')
pts[last_distant_col] = pd.to_numeric(pts[last_distant_col], errors='coerce')
pts['Non_Outcome_TIME'] = np.fmax(pts[last_local_col].to_numpy(),
                                  pts[last_distant_col].to_numpy())

# Drop patients without any recurrence-free assessment; copy so the column
# assignments below act on an independent frame.
pts = pts[pts['Non_Outcome_TIME'].notna()].copy()

# Follow-up time: event day for events, last assessment otherwise.
pts['RD_TIME'] = np.where(
    pts['RD'] == 1,
    pts['Outcome_TIME'],
    pts['Non_Outcome_TIME'],
)

# Days -> years
pts['RD_TIME'] = pts['RD_TIME'] / 365.25

In [None]:
# Pick the training cohorts: patients whose `variable` equals 1 or 3.
# Patient IDs are 1-based, rows of `duke` are 0-based.
variable = 'Nottingham Grade'
HRp = [pid - 1 for pid in pts.loc[pts[variable] == 1, 'Patient ID']]
HRn = [pid - 1 for pid in pts.loc[pts[variable] == 3, 'Patient ID']]

# First three phases, slices 28..35 of each selected volume
HRps = duke[HRp, :3, 28:36]
HRng = duke[HRn, :3, 28:36]
HRps.shape, HRng.shape

In [None]:
# HRps volumes get label 1, HRng volumes label 0; 42 seeds the stratified split.
MRIs = MRI(rd=HRps, dfs=HRng, sf=42)
MRIs.get_train_rd_count(), MRIs.get_train_dfs_count()

In [None]:
class CNN(nn.Module):
    """
    Six-stage 2-D CNN producing one probability per input.

    Each stage is Conv2d(3x3, padding 1) -> ReLU -> MaxPool2d(2), with the
    channel count going 3 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512. The
    flattened features feed a 512 -> 256 -> 1 head followed by a sigmoid,
    so the convolutional output must flatten to 512 values per sample.
    """

    def __init__(self):
        super().__init__()

        channels = [3, 16, 32, 64, 128, 256, 512]
        stages = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            stages.extend([
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])
        self.cnn_model = nn.Sequential(*stages)

        self.fc_model = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )

    def forward(self, x):
        features = self.cnn_model(x)
        flat = features.view(features.size(0), -1)
        return torch.sigmoid(self.fc_model(flat))

# Fix all RNG seeds before the model is created so that weight
# initialisation and batch shuffling are reproducible.
set_seed(45)

# Move model to device
model = CNN().to(device)

train_settings = dict(lrng_rt=5e-6, EPOCH=240, batch_size=32, weight_decay=1e-4)
train_model(model, MRIs, device, **train_settings)

In [None]:
# Score slice 31 (first three phases) of every patient in `pts`.
# Inference runs in eval mode and without autograd, so no computation graph
# is kept for the whole cohort.
model.eval()
with torch.no_grad():
    center_slices = torch.tensor(duke[(pts['Patient ID']-1),:3,31], dtype=torch.float32).to(device)
    # reshape(-1) always yields one value per row, even for a single patient
    grade3_probability = model(center_slices).reshape(-1).tolist()

pts['Probability of Grade 3'] = grade3_probability
pts['Grade 2 High'] = (pts['Probability of Grade 3'] > 0.50).astype(int)
pts[pts['Nottingham Grade']==2]['Grade 2 High'].value_counts()

In [56]:
# Persist the covariate table, including the model outputs, to the working
# directory (the DataFrame index is written as the first column).
pts.to_csv('Duke RFS Model Results.csv')

In [None]:
# Survival analysis is restricted to Nottingham grade 2 patients.
is_grade2 = pts['Nottingham Grade'] == 2
population = pts[is_grade2].reset_index(drop=True)
len(population)

In [None]:
# Full model: clinical covariates plus the CNN-derived 'Grade 2 High' flag.
# False = do not filter on a PRE_RD column.
outcomes = {'RD': False}

predictors = [
    'Grade 2 High',
    'Age',
    'Nodal Involvement',
    'High T Stage',
    'ER+/HER2+(±PR)',
    'ER−/HER2+(±PR)',
    'ER−/HER2−(PR−)',
]

full = fit_cox_model(population, outcomes, predictors)

In [None]:
# Reduced model: the same clinical covariates without the CNN-derived flag.
# False = do not filter on a PRE_RD column.
outcomes = {'RD': False}

predictors = [
    'Age',
    'Nodal Involvement',
    'High T Stage',
    'ER+/HER2+(±PR)',
    'ER−/HER2+(±PR)',
    'ER−/HER2−(PR−)',
]

reduced = fit_cox_model(population, outcomes, predictors)

In [None]:
# Paired bootstrap comparison of the full vs. reduced Cox model.
# The duration and event columns are passed explicitly: follow-up time is
# 'RD_TIME' and the event indicator is 'RD'. Partial hazards are negated
# (flip_sign=True) because a higher hazard means an earlier event, while
# concordance_index expects scores that increase with survival time.
comparison_result = bootstrap_compare_c_indices(
    full['RD'],
    reduced['RD'],
    df=population,
    duration_col='RD_TIME',
    event_col='RD',
    flip_sign=True,
)

print(comparison_result)