# Import Required Libraries
This section imports all the necessary libraries for the project.

In [None]:
import os
import argparse
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import (Input, GlobalAveragePooling2D, Dense, Dropout,
                                     GlobalMaxPooling1D, Lambda, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import AUC, BinaryAccuracy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import Sequence
from sklearn.metrics import roc_auc_score, accuracy_score, log_loss
from skimage.transform import resize # Using scikit-image for resizing

# Report the installed TensorFlow version and the number of GPU devices that
# TensorFlow can see, as a quick sanity check of the runtime environment.
print("TensorFlow Version:", tf.__version__)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Configuration
Set the paths and hyperparameters for the project.

In [None]:
# --- Filesystem locations ---------------------------------------------------
# Root of the extracted MRNet dataset.
BASE_DATA_DIR = './MRNet-v1.0/'
# Destination for trained models and logs.
OUTPUT_DIR = './output_models/'

# --- Input geometry ---------------------------------------------------------
# Spatial input size used for Xception.
IMG_SIZE = (299, 299)
# Xception expects 3-channel input.
N_CHANNELS = 3

# --- Training hyperparameters -----------------------------------------------
# Tune to the available GPU memory; a smaller value may be required.
BATCH_SIZE = 8
# Upper bound on training epochs (EarlyStopping can end training sooner).
EPOCHS = 50
LEARNING_RATE = 1e-4
DROPOUT_RATE = 0.5

# Helper Functions
Define utility functions for loading the per-task label CSV files and for preprocessing individual 2D slices.

In [None]:
def load_labels(label_dir, task, split):
    """Read one label CSV and return it as an ``{exam_id: label}`` mapping.

    The file is expected at ``<label_dir>/<split>-<task>.csv`` and to have no
    header row: the first column is taken as the exam id and the second as
    the label.

    Args:
        label_dir: Directory that contains the label CSV files.
        task: Task name used in the file name (e.g. ``acl``, ``meniscus``).
        split: Dataset split used in the file name (e.g. ``train``, ``valid``).

    Returns:
        A dict mapping each exam id to its label. Column dtypes are inferred
        by pandas, so ids consisting only of digits come back as integers
        (leading zeros are not preserved).
    """
    csv_name = f"{split}-{task}.csv"
    frame = pd.read_csv(
        os.path.join(label_dir, csv_name),
        header=None,
        names=['exam_id', 'label'],
        index_col='exam_id',
    )
    label_column = frame['label']
    return label_column.to_dict()

def preprocess_slice(slice_img, target_size):
    """Resize, min-max normalise and channel-replicate one 2D slice.

    Args:
        slice_img: The 2D slice to process.
        target_size: Output shape passed to ``resize``.

    Returns:
        A ``float32`` array in which the resized, rescaled slice is repeated
        three times along a new trailing axis.
    """
    resized = resize(slice_img, target_size, anti_aliasing=True)
    lowest = np.min(resized)
    highest = np.max(resized)
    # The small epsilon keeps the division finite for a constant-valued slice.
    scaled = (resized - lowest) / (highest - lowest + 1e-6)
    three_channel = np.stack([scaled, scaled, scaled], axis=-1)
    return three_channel.astype(np.float32)

# Keras Sequence for Data Loading
Define a custom Keras `Sequence` class that loads MRNet exams batch by batch and emits every 2D slice of each exam as a separate sample.

In [None]:
class MRNetSequence(Sequence):
    """Keras ``Sequence`` that yields per-slice batches for one imaging plane.

    Each batch covers up to ``batch_size`` *exams*. Every 2D slice of every
    exam in the batch becomes one sample, and the exam's ACL and meniscus
    labels are repeated for each of its slices. The number of samples in a
    batch therefore depends on the depth of the loaded volumes, not only on
    ``batch_size``.

    Exams with a missing volume file or a missing label are skipped with a
    warning instead of aborting the epoch.

    Args:
        data_dir: Directory that contains one sub-directory per plane.
        plane: Name of the plane sub-directory to read volumes from.
        labels_acl: Mapping ``exam_id -> label`` for the ACL task.
        labels_meniscus: Mapping ``exam_id -> label`` for the meniscus task.
        exam_ids: Indexable collection of exam ids to iterate over.
        batch_size: Number of exams (not slices) per batch.
        target_size: ``(height, width)`` handed to ``preprocess_slice``.
        **kwargs: Forwarded to the Keras base-class constructor.
    """

    def __init__(self, data_dir, plane, labels_acl, labels_meniscus, exam_ids,
                 batch_size, target_size, **kwargs):
        # Newer Keras releases expect data-set subclasses to call the base
        # constructor; with no extra kwargs this is also valid on older ones.
        super().__init__(**kwargs)
        self.data_dir = data_dir
        self.plane = plane
        self.labels_acl = labels_acl
        self.labels_meniscus = labels_meniscus
        self.exam_ids = exam_ids
        self.batch_size = batch_size
        self.target_size = target_size
        self.indices = np.arange(len(self.exam_ids))

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        return int(np.ceil(len(self.exam_ids) / self.batch_size))

    def _exam_path(self, exam_id):
        """Return the ``.npy`` path of *exam_id* inside the plane directory.

        The literal ``<exam_id>.npy`` name is always preferred. Only when that
        file is absent and the id is an integer is a four-digit zero-padded
        name (``0007.npy``) tried as well, because integer ids lose the
        leading zeros used in MRNet file names.
        NOTE(review): confirm the padding width against the dataset on disk.
        """
        plane_dir = os.path.join(self.data_dir, self.plane)
        literal = os.path.join(plane_dir, f"{exam_id}.npy")
        if isinstance(exam_id, (int, np.integer)) and not os.path.exists(literal):
            padded = os.path.join(plane_dir, f"{int(exam_id):04d}.npy")
            if os.path.exists(padded):
                return padded
        return literal

    @staticmethod
    def _warn_skipped(exam_id, reason):
        """Emit a warning explaining why *exam_id* was left out of a batch."""
        import warnings  # local import keeps this block self-contained

        warnings.warn(f"Skipping exam {exam_id!r}: {reason}", stacklevel=3)

    def __getitem__(self, index):
        """Build batch number *index*.

        Returns:
            A tuple ``(images, labels)``. ``images`` is a ``float32`` array of
            shape ``(n_slices, height, width, 3)``. ``labels`` is a dict with
            keys ``'acl_output'`` and ``'meniscus_output'``, each a ``float32``
            array of length ``n_slices``. If every exam in the batch is
            skipped, ``n_slices`` is 0 but the image array keeps its rank.
        """
        start = index * self.batch_size
        batch_indices = self.indices[start:start + self.batch_size]

        batch_slices, batch_labels_acl, batch_labels_meniscus = [], [], []
        for position in batch_indices:
            exam_id = self.exam_ids[position]

            # Check the labels first so that no volume is read from disk only
            # to be thrown away.
            label_acl = self.labels_acl.get(exam_id)
            label_meniscus = self.labels_meniscus.get(exam_id)
            if label_acl is None or label_meniscus is None:
                self._warn_skipped(exam_id, "missing ACL or meniscus label")
                continue

            exam_path = self._exam_path(exam_id)
            try:
                volume = np.load(exam_path)
            except FileNotFoundError:
                self._warn_skipped(exam_id, f"volume not found at {exam_path}")
                continue

            # One sample per slice; the exam-level labels are repeated.
            for slice_img in volume:
                batch_slices.append(preprocess_slice(slice_img, self.target_size))
                batch_labels_acl.append(label_acl)
                batch_labels_meniscus.append(label_meniscus)

        if batch_slices:
            batch_slices_np = np.stack(batch_slices)
        else:
            # Keep a 4-D shape even when nothing was loaded; a bare
            # ``np.array([])`` would have shape ``(0,)``.
            batch_slices_np = np.empty((0, *self.target_size, 3), dtype=np.float32)
        batch_labels_acl_np = np.array(batch_labels_acl, dtype=np.float32)
        batch_labels_meniscus_np = np.array(batch_labels_meniscus, dtype=np.float32)
        return batch_slices_np, {'acl_output': batch_labels_acl_np,
                                 'meniscus_output': batch_labels_meniscus_np}