# Import Library

In [None]:
import os
import pandas as pd
import numpy as np
import pydicom
import psutil
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import models
from skimage.transform import resize
from sklearn.model_selection import train_test_split
import gc
import warnings
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler


# Config

In [None]:
# Constants
# Number of study directories handed to process_batch() per call by train_model_in_batches();
# also the default `batch_size` of train_model_in_batches(), where it is the number of slices
# per series taken in each forward sub-batch.
BATCH_SIZE = 8
# DataLoader batch size in train_model_in_batches(), i.e. number of image series per step.
CHUNK_SIZE = 8
# Default index of the first study directory that process_batch() starts from.
START_IDX = 0
# Upper bound on images per series, passed to process_batch() as max_images_per_series.
MAX_IMAGES_PER_SERIES = 128
# File that training checkpoints (and, by default, the final model) are written to.
CHECKPOINT_PATH = '/kaggle/working/model_checkpoint.pth'
# (height, width) each DICOM slice is resized to in preprocess_images().
IMG_SIZE = (224, 224)
# Default number of passes over the studies in train_model_in_batches().
NUM_EPOCHS = 5
# AdamW learning rate used by train_model_in_batches().
# NOTE(review): 5e-7 is unusually small for fine-tuning — confirm it is intentional.
LEARNING_RATE = 5e-7
# NOTE(review): LOSS_FUNCTION and METRICS look like Keras-style leftovers; the PyTorch code
# visible here uses nn.BCEWithLogitsLoss and never reads these two names.
LOSS_FUNCTION = 'binary_crossentropy'
METRICS = ['accuracy']

In [None]:
# Prefer the GPU when PyTorch can see one; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def read_clinical_data(file_path):
    '''
    Read the clinical-data sheet from an Excel workbook.

    The first sheet (in workbook order) whose lower-cased name contains one of
    'clinical', 'клинические' (Russian for "clinical") or 'сlinical' (a look-alike
    spelling, presumably with a Cyrillic first letter) is loaded.

    Parameters:
    file_path (str): Path to the Excel (.xlsx) file.

    Returns:
    pandas.DataFrame: Contents of the matching sheet.

    Raises:
    ValueError: If no sheet name contains any of the keywords.

    Note:
    The workbook is opened in a ``with`` block so the file handle is released
    even when no sheet matches or reading fails.
    '''
    keywords = ('clinical', 'клинические', 'сlinical')
    with pd.ExcelFile(file_path) as xls:
        # First sheet in workbook order whose lower-cased name contains a keyword.
        sheet_name = next(
            (name for name in xls.sheet_names
             if any(keyword in name.lower() for keyword in keywords)),
            None,
        )
        if sheet_name is None:
            raise ValueError('No clinical sheet found in the XLSX file.')
        return pd.read_excel(xls, sheet_name=sheet_name)

def read_dicom_files(study_path):
    '''
    Read every DICOM file of a study, grouped by series sub-directory.

    Each sub-directory of ``study_path`` is treated as one series; every file in it
    whose name ends in '.dcm' (case-insensitive) is read with ``pydicom.dcmread(force=True)``.
    Directory and file listings are sorted by name so the result (and any later
    truncation of a series) is deterministic — ``os.listdir`` order is arbitrary.
    Note that this is lexicographic filename order, not necessarily anatomical slice order.

    Parameters:
    study_path (str): Path to the study directory containing series sub-directories.

    Returns:
    dict: series directory name -> list of pydicom datasets, in sorted filename order.
          Series with no .dcm files are omitted.

    Note:
    Datasets missing 'SamplesPerPixel' or 'PhotometricInterpretation' get the defaults
    1 and 'MONOCHROME2' so that pixel data can still be decoded.
    '''
    dicom_series = {}
    for series in sorted(os.listdir(study_path)):
        series_path = os.path.join(study_path, series)
        if not os.path.isdir(series_path):
            continue

        series_images = []
        for file in sorted(os.listdir(series_path)):
            if not file.lower().endswith('.dcm'):
                continue
            dicom_image = pydicom.dcmread(os.path.join(series_path, file), force=True)
            if 'SamplesPerPixel' not in dicom_image:
                dicom_image.SamplesPerPixel = 1
            if 'PhotometricInterpretation' not in dicom_image:
                dicom_image.PhotometricInterpretation = 'MONOCHROME2'
            series_images.append(dicom_image)

        if series_images:
            dicom_series[series] = series_images
    return dicom_series

### Note

```python
max_value = np.max(image_array)
image_array = image_array / max_value if max_value > 0 else np.zeros_like(image_array)
```

- Reconsider normalizing by dividing by `max_value`:
  - Because `max_value` differs from image to image, two pixels with the same raw value in different images can end up with different normalized values. Does this affect training results?
  - Even if it doesn't affect results now, could it matter once we convert to Hounsfield Units (HU)? If values are windowed to highlight specific tissues, would per-image normalization cancel out that effect?
  - HU values can be negative (roughly [-1000, 2000]), so what happens when an image's maximum is 0 or below? With the code above the whole image becomes zeros in that case; and when the maximum is positive, negative pixels stay negative, so the output is not confined to [0, 1].

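- Anti-aliasing blurs edges. Can we control how strong the blurring is (e.g. with the `anti_aliasing_sigma` argument of `skimage.transform.resize`)? Can we plot a few sample images to see the effect?
  - Anti-aliasing only matters when the image is downsampled; if we cropped to the target size instead of shrinking the whole image, we might not need it.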

In [None]:
def preprocess_images(dicom_images):
    '''
    Convert DICOM datasets into a stack of model-ready RGB images.

    For each dataset:
    1. Take the stored pixel values (``pixel_array``). No RescaleSlope/RescaleIntercept
       is applied, so values are not Hounsfield Units.
    2. Divide by that image's own maximum. If the maximum is <= 0 the image becomes all
       zeros. Values are only guaranteed to lie in [0, 1] when the stored pixels are
       non-negative; negative stored values remain negative after the division.
    3. Resize to IMG_SIZE with anti-aliasing.
    4. If the result is 2-D (grayscale), replicate it into 3 identical channels.

    Parameters:
    dicom_images (list): pydicom datasets (any objects exposing ``pixel_array``).

    Returns:
    numpy.ndarray: float32 array of shape (n_images, IMG_SIZE[0], IMG_SIZE[1], 3) for
                   grayscale input; shape (0,) for an empty list.

    Note:
    - float32 halves the memory of the previous float64 output; the tensors built from
      this array are converted to float32 anyway.
    - NOTE(review): multi-frame datasets (3-D pixel_array) are not special-cased —
      confirm the data only contains single-frame slices.
    '''
    preprocessed_images = []
    for dicom_image in dicom_images:
        image_array = dicom_image.pixel_array.astype(np.float32)
        max_value = image_array.max()
        if max_value > 0:
            image_array /= max_value
        else:
            # Float zeros (not the source integer dtype) keep every image the same dtype.
            image_array = np.zeros_like(image_array)

        resized_image = resize(image_array, IMG_SIZE, anti_aliasing=True)
        if resized_image.ndim == 2:
            resized_image = np.stack((resized_image,) * 3, axis=-1)
        preprocessed_images.append(resized_image)
    return np.asarray(preprocessed_images, dtype=np.float32)

In [None]:
def _load_clinical_table(studies_folder_path, target_columns):
    '''
    Load the folder's single clinical workbook and keep only rows with usable labels.

    Returns None (after printing an error) unless the folder contains exactly one .xlsx
    file. Column names are lower-cased; rows whose 'comment' is 'Study without report'
    are dropped, as are rows whose value in any present target column is not 0 or 1.
    A missing target column only triggers a warning.
    '''
    xlsx_files = [f for f in os.listdir(studies_folder_path) if f.endswith('.xlsx')]
    if len(xlsx_files) != 1:
        print(f"Error: Expected exactly one XLSX file in {studies_folder_path}")
        return None

    clinical_data = read_clinical_data(os.path.join(studies_folder_path, xlsx_files[0]))
    clinical_data.columns = clinical_data.columns.str.lower()
    # Guarded so a workbook without a 'comment' column does not raise KeyError.
    if 'comment' in clinical_data.columns:
        clinical_data = clinical_data[clinical_data['comment'] != 'Study without report']

    # Keep only rows whose target columns are binary (0 or 1).
    for column in target_columns:
        if column in clinical_data.columns:
            clinical_data = clinical_data[clinical_data[column].isin([0, 1])]
        else:
            print(f"Warning: Column '{column}' not found in clinical data.")
    return clinical_data


def _process_study(study_path, study_uid, max_images_per_series):
    '''
    Preprocess every DICOM series of one study.

    Only the first ``max_images_per_series`` images of each series (in the order
    read_dicom_files() returns them; None means no limit) are preprocessed.
    Returns {series_uid: image array} for series that produced at least one image.
    '''
    processed_series = {}
    for series_uid, series_images in read_dicom_files(study_path).items():
        series_images = series_images[:max_images_per_series]
        processed_images = preprocess_images(series_images)
        if len(processed_images) > 0:
            processed_series[series_uid] = processed_images
            print(f"Processed {len(series_images)} DICOM images for study UID: {study_uid}, series UID: {series_uid}")
        else:
            print(f"No images processed for study UID: {study_uid}, series UID: {series_uid}")
    return processed_series


def process_batch(studies_folder, dataset_path, start_idx=START_IDX, batch_size=BATCH_SIZE, max_images_per_series=MAX_IMAGES_PER_SERIES):
    '''
    Process one page of studies: preprocess their DICOM images and attach clinical labels.

    Study directories inside ``<dataset_path>/<studies_folder>/<studies_folder>`` are listed
    in sorted order, and the slice ``[start_idx, start_idx + batch_size)`` is processed.
    Calling again with the returned index therefore pages through the folder with no gaps
    and no overlaps.

    Parameters:
    studies_folder (str): Name of the folder containing the studies.
    dataset_path (str): Path to the dataset root directory.
    start_idx (int): Index (in sorted order) of the first study directory to process.
    batch_size (int): Number of study directories to process in this page.
    max_images_per_series (int): Maximum number of images used from each series.

    Returns:
    tuple: (patient_data, all_clinical_data, next_start_idx)
        - patient_data (dict): study UID -> {'images': {series UID: image array},
          'labels': {target column: value}}, only for studies that have both
          processed images and a clinical row. Labels come from the study's first row.
        - all_clinical_data (pd.DataFrame): clinical rows matched to processed studies,
          added once per study (not once per series).
        - next_start_idx (int): start_idx plus the number of study directories in this
          page, whether or not they yielded usable data. It equals start_idx once the
          folder is exhausted, or when the folder is skipped or lacks a single XLSX file.

    Notes:
    - The folder '200_300_studies' is skipped entirely. NOTE(review): the reason is not
      visible here.
    - Target columns must be binary (0 or 1); other rows are dropped.
    - Warnings are printed for missing clinical data or unprocessed images.
    '''
    target_columns = ['epidural hemorrhage', 'subarachnoid hemorrhage', 'subdural hemorrhage',
                      'intracerebral hemorrhage', 'multiple hemorrhages', 'skull fracture']

    all_processed_series = {}
    study_labels = {}
    clinical_frames = []
    batch_study_uids = []

    if studies_folder != '200_300_studies':
        studies_folder_path = os.path.join(dataset_path, studies_folder, studies_folder)
        print(f"Processing folder {studies_folder}")
        clinical_data = _load_clinical_table(studies_folder_path, target_columns)

        if clinical_data is not None:
            # os.listdir order is arbitrary; sorting keeps pagination by start_idx stable.
            study_uids_in_folder = sorted(
                uid for uid in os.listdir(studies_folder_path)
                if os.path.isdir(os.path.join(studies_folder_path, uid))
            )
            batch_study_uids = study_uids_in_folder[start_idx:start_idx + batch_size]

            for study_uid in batch_study_uids:
                study_path = os.path.join(studies_folder_path, study_uid)
                processed_series = _process_study(study_path, study_uid, max_images_per_series)
                if not processed_series:
                    continue
                all_processed_series[study_uid] = processed_series

                # Looked up once per study (not per series) so rows are not duplicated.
                study_clinical_data = clinical_data[clinical_data['study_uid'] == study_uid]
                if study_clinical_data.empty:
                    print(f"Warning: No clinical data found for study UID: {study_uid}")
                else:
                    clinical_frames.append(study_clinical_data)
                    study_labels[study_uid] = study_clinical_data.iloc[0][target_columns].to_dict()

    all_clinical_data = pd.concat(clinical_frames, ignore_index=True) if clinical_frames else pd.DataFrame()

    print("Number of processed studies:", len(all_processed_series))
    print("Number of rows in clinical data:", len(all_clinical_data))

    patient_data = {}
    for study_uid, series_dict in all_processed_series.items():
        if study_uid in study_labels:
            patient_data[study_uid] = {
                'images': series_dict,
                'labels': study_labels[study_uid],
            }
        else:
            print(f"Warning: No clinical data for study UID: {study_uid}")

    # Advance by the number of directories consumed, not by the number that produced data,
    # so unusable studies are neither re-read nor allowed to stall the paging.
    return patient_data, all_clinical_data, start_idx + len(batch_study_uids)

In [None]:
from torchvision.models.resnet import ResNet50_Weights

def _build_resnet50(num_outputs=6):
    '''Return an ImageNet-pretrained ResNet50 whose final layer emits num_outputs logits.'''
    model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, num_outputs)
    return model


def create_or_load_model(checkpoint_path):
    '''
    Create a 6-output ResNet50, resuming from a checkpoint when one can be loaded.

    The network starts from torchvision's DEFAULT ImageNet weights with its fully
    connected layer replaced by ``nn.Linear(in_features, 6)``. If ``checkpoint_path``
    exists, its weights are loaded on top.

    Parameters:
    checkpoint_path (str): Path of the checkpoint file to resume from.

    Returns:
    torch.nn.Module: The ResNet50 (on the CPU), either freshly initialised or loaded.

    Notes:
    - Both a training checkpoint ``{'model_state_dict': ...}`` and a bare state_dict
      are accepted.
    - The checkpoint is loaded with ``map_location='cpu'`` so a file written on a GPU
      also loads on a CPU-only machine.
    - If loading fails (wrong keys/shapes, truncated or unreadable file), the error is
      printed and a brand-new model is returned. The model is rebuilt rather than reused,
      because ``load_state_dict`` may already have copied some tensors before failing.

    Raises:
    - OSError and other I/O errors from ``torch.load`` are not caught.
    '''
    import pickle  # only for pickle.UnpicklingError raised by torch.load(weights_only=True)

    model = _build_resnet50()
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Creating new model")
        return model

    print(f"Attempting to load checkpoint from {checkpoint_path}")
    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint  # bare state_dict
        model.load_state_dict(state_dict)
        print("Checkpoint loaded successfully")
    except (KeyError, RuntimeError, TypeError, EOFError, pickle.UnpicklingError) as e:
        print(f"Error loading checkpoint: {e}")
        print("Creating new model instead")
        model = _build_resnet50()
    return model

In [None]:
def custom_collate(batch):
    '''
    Collate items with different numbers of images into one padded batch.

    Intended as a DataLoader ``collate_fn``. Each item is a tuple
    ``(images, labels, patient_id, series_uid)``, where ``images`` has shape
    (num_images, channels, height, width). Shorter image sequences are padded at the end
    with all-zero frames up to the longest sequence in the batch.

    Parameters:
    batch (list): List of (images, labels, patient_id, series_uid) tuples.

    Returns:
    tuple:
        - images (torch.Tensor): shape (batch_size, max_images, channels, height, width)
        - labels (torch.Tensor): the label tensors stacked along dim 0
        - patient_ids (tuple): patient IDs in batch order
        - series_uids (tuple): series UIDs in batch order

    Notes:
    - Padding frames use the same dtype and device as the images they extend.
    - No mask or length is returned; consumers must recognise padding frames
      (they are all zeros) themselves.
    '''
    images, labels, patient_ids, series_uids = zip(*batch)

    max_images = max(img.size(0) for img in images)
    padded_images = []
    for img in images:
        missing = max_images - img.size(0)
        if missing > 0:
            # new_zeros inherits img's dtype/device, so cat never mixes types or devices.
            padding = img.new_zeros((missing, *img.shape[1:]))
            img = torch.cat([img, padding], dim=0)
        padded_images.append(img)

    return torch.stack(padded_images), torch.stack(labels), patient_ids, series_uids

In [None]:
class PatientDataset(Dataset):
    '''
    PyTorch Dataset with one item per (patient, image series).

    Attributes:
    patient_data (dict): patient ID -> {'images': {series UID: image array},
                         'labels': {label name: value}}.
    patient_ids (list): Patient IDs to include.
    series_data (list): (image_tensor, label_tensor, patient_id, series_uid) tuples,
                        built eagerly in __init__.

    Each item is:
      1. float32 image tensor of shape (num_images, channels, height, width), obtained by
         permuting the stored (num_images, height, width, channels) array;
      2. float32 label tensor whose entries follow the label names in sorted
         (alphabetical) order — not the order the labels dict was built in. For the
         keys produced by process_batch() that is: epidural, intracerebral, multiple,
         skull fracture, subarachnoid, subdural;
      3. patient ID;
      4. series UID.

    Note:
    - Images are converted with ``torch.as_tensor``: a float32 array is shared rather
      than copied, and a float64 array is converted in a single copy.
    - Every patient is assumed to have the same label names.
    '''
    def __init__(self, patient_data, patient_ids):
        self.patient_data = patient_data
        self.patient_ids = patient_ids
        self.series_data = []

        for patient_id in self.patient_ids:
            patient = self.patient_data[patient_id]
            labels = patient['labels']
            label_tensor = torch.tensor([labels[key] for key in sorted(labels)], dtype=torch.float32)

            for series_uid, series_images in patient['images'].items():
                # (N, H, W, C) -> (N, C, H, W), the layout torchvision models expect.
                series_tensor = torch.as_tensor(series_images, dtype=torch.float32).permute(0, 3, 1, 2)
                self.series_data.append((series_tensor, label_tensor, patient_id, series_uid))

    def __len__(self):
        return len(self.series_data)

    def __getitem__(self, idx):
        return self.series_data[idx]
In [None]:
def _optimizer_step(scaler, optimizer):
    '''Apply the accumulated (scaled) gradients, update the loss scale and clear gradients.'''
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def _train_on_loader(model, dataloader, criterion, optimizer, scaler, device, use_amp,
                     sub_batch_size, accumulation_steps):
    '''Train on one DataLoader, slicing each batch of series into sub-batches of slices.'''
    for i, (images, labels, _patient_ids, _series_uids) in enumerate(dataloader):
        print(f"Batch {i+1} - Images shape: {images.shape}, Labels shape: {labels.shape}")

        optimizer.zero_grad()
        pending_backward_passes = 0

        for j in range(0, images.size(1), sub_batch_size):
            chunk = images[:, j:j + sub_batch_size]             # (series, k, C, H, W)
            slices_per_series = chunk.size(1)                   # k < sub_batch_size on the last chunk
            batch = chunk.reshape(-1, *chunk.shape[2:])         # (series * k, C, H, W), series-major
            # One copy of each series' labels per slice taken from it, matching the
            # series-major order of `batch` (also correct for a short final chunk).
            batch_labels = labels.repeat_interleave(slices_per_series, dim=0)

            # Short series are padded with all-zero frames; skip those (and any blank
            # slice) so they are not trained against the series' labels.
            valid = batch.flatten(start_dim=1).abs().sum(dim=1) > 0
            if not valid.any():
                continue
            batch = batch[valid].to(device)
            batch_labels = batch_labels[valid].to(device)

            print(f"  Sub-batch {j//sub_batch_size + 1} - Batch shape: {batch.shape}, Batch labels shape: {batch_labels.shape}")

            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(batch)
                print(f"  Outputs shape: {outputs.shape}")
                # Averaged over the accumulation window rather than summed.
                loss = criterion(outputs, batch_labels) / accumulation_steps

            scaler.scale(loss).backward()
            pending_backward_passes += 1

            if pending_backward_passes == accumulation_steps:
                _optimizer_step(scaler, optimizer)
                pending_backward_passes = 0

        # Flush gradients left over when the number of sub-batches is not a multiple of
        # accumulation_steps; otherwise the next zero_grad() would silently drop them.
        if pending_backward_passes:
            _optimizer_step(scaler, optimizer)

        print(f"Batch {i+1} processed")


def train_model_in_batches(model, studies_folders, dataset_path, batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS, accumulation_steps=4):
    '''
    Train a model on DICOM studies, loading a few studies at a time to bound memory.

    Parameters:
    model (torch.nn.Module): The model to train (must output one logit per label).
    studies_folders (list): Folder names containing the study data.
    dataset_path (str): Path to the root directory of the dataset.
    batch_size (int): Number of slices per series fed to the model in each sub-batch.
    num_epochs (int): Number of passes over all folders.
    accumulation_steps (int): Backward passes accumulated before each optimizer step.

    Returns:
    torch.nn.Module: The trained model (left on the training device).

    Notes:
    - Loss is BCEWithLogitsLoss; optimizer is AdamW with LEARNING_RATE.
    - Mixed precision (autocast + GradScaler) is enabled only on CUDA; on CPU training runs
      in full precision.
    - Studies are paged through with process_batch() (BATCH_SIZE studies per page). A page
      without labelled studies is skipped; the folder ends when process_batch() stops
      advancing the start index.
    - All-zero frames (collate padding) are excluded from the loss. Any optimizer step still
      pending at the end of a DataLoader batch is applied.
    - A checkpoint {'epoch', 'model_state_dict', 'optimizer_state_dict'} is written to
      CHECKPOINT_PATH after every page, and the page's data is released.
    - Uses the globals LEARNING_RATE, CHUNK_SIZE, BATCH_SIZE, MAX_IMAGES_PER_SERIES and
      CHECKPOINT_PATH, plus PatientDataset, custom_collate and process_batch.
    '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    model = model.to(device)
    model.train()
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # Loss scaling is only needed for float16 autocast on CUDA; a disabled scaler is a no-op.
    scaler = GradScaler(enabled=use_amp)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        for studies_folder in studies_folders:
            start_idx = 0
            while True:
                patient_data, clinical_data, next_start_idx = process_batch(studies_folder, dataset_path, start_idx, BATCH_SIZE, MAX_IMAGES_PER_SERIES)

                if not patient_data:
                    if next_start_idx == start_idx:
                        break  # No more studies to read in this folder
                    # Studies were read but none had usable labels: go to the next page.
                    start_idx = next_start_idx
                    continue

                dataset = PatientDataset(patient_data, list(patient_data.keys()))
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=CHUNK_SIZE, shuffle=True, collate_fn=custom_collate)

                _train_on_loader(model, dataloader, criterion, optimizer, scaler, device, use_amp,
                                 batch_size, accumulation_steps)

                start_idx = next_start_idx

                # Save checkpoint after each batch
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, CHECKPOINT_PATH)
                print(f"Checkpoint saved to {CHECKPOINT_PATH}")

                # Release this page's arrays/tensors before the next page is loaded.
                del dataset, dataloader, patient_data, clinical_data
                gc.collect()
                if device.type == "cuda":
                    torch.cuda.empty_cache()

    return model

In [None]:
def save_final_model(model, path=CHECKPOINT_PATH):
    '''
    Save the model's weights to ``path``.

    The weights are wrapped as ``{'model_state_dict': ...}``, the same layout the
    per-page training checkpoints use. A file written here can therefore be reloaded
    with ``checkpoint['model_state_dict']`` exactly like a training checkpoint. This
    matters because the default ``path`` is the training checkpoint file, which this
    call overwrites.

    Parameters:
    model (torch.nn.Module): Model whose ``state_dict()`` is saved.
    path (str): Destination file (defaults to CHECKPOINT_PATH).
    '''
    torch.save({'model_state_dict': model.state_dict()}, path)
    print(f"Final model saved to {path}")

# Dataset root on Kaggle; process_batch() looks for studies under
# dataset_path/<folder>/<folder>/ for each folder in studies_folders.
dataset_path = '/kaggle/input/mosmeddata-ct-hemorrhage-type-viii/MosMedData-CT-HEMORRHAGE-type VIII/'

# Get a list of all studies folders
# studies_folders = [f for f in os.listdir(dataset_path) if f.endswith('_studies')]
# Currently restricted to a single folder; the commented line above would select every
# '*_studies' folder under dataset_path instead.
studies_folders = ['400_500_studies']
print(studies_folders)

#### Main execution
# Build the 6-output ResNet50, resuming from CHECKPOINT_PATH if that file exists and loads.
model = create_or_load_model(CHECKPOINT_PATH)

# Train the model in batches
model = train_model_in_batches(model, studies_folders, dataset_path)

# Save the final model
# NOTE(review): save_final_model() defaults to CHECKPOINT_PATH, so this overwrites the last
# training checkpoint.
save_final_model(model)