Splitting the dataset into training, validation, and test sets
- Match clinical records to images by comparing each record's ID1 value with the ID extracted from the image file name.
- Training + Validation and Test Split: The data is split into two parts: a large part for training and validation (80% of the data) and a smaller part for testing (20% of the data).
- Training and Validation Split: The training + validation set is further split into training (90%) and validation (10%), i.e. roughly 72% and 8% of the full dataset. This split allows the model to be trained on one set and validated on a different set before its performance is evaluated on the test set.

In [3]:

import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import os

# Load the preprocessed clinical table from Excel
clinical_data_path = '/Users/ellaquan/Project/preprocessed_clinical_data.xlsx'
clinical_df = pd.read_excel(clinical_data_path)

# Trim surrounding whitespace so ID1 values compare equal to the IDs parsed from file names
clinical_df['ID1'] = clinical_df['ID1'].str.strip()

# Check the clinical data for correct loading
print("Clinical Data Head:")
print(clinical_df.head())

# Build a lookup: ID1 -> dict of that row's remaining columns
clinical_map = clinical_df.set_index('ID1').to_dict('index')

# Collect the image file paths (assuming .png images after preprocessing).
# os.listdir returns entries in arbitrary order, so the names are sorted first;
# otherwise the seeded train/val/test split below would differ between
# machines or file systems even though random_state is fixed.
image_dir = '/Users/ellaquan/Project/Preprocessed_Images'
image_paths = [
    os.path.join(image_dir, fname)
    for fname in sorted(os.listdir(image_dir))
    if fname.endswith('.png')
]

# Extract ID1 from image file names (e.g., 'D2-0002' from 'D2-0002_1-3.png')
def extract_id_from_image(file_path):
    """Return the patient ID (ID1) that prefixes an image file name.

    Args:
        file_path: Path to an image whose file name has the form
            '<ID1>_<suffix>.png', e.g. '/some/dir/D2-0002_1-3.png'.

    Returns:
        The part of the file name before the first underscore. If the file
        name contains no underscore, the whole file name is returned.
    """
    # The paths are built with os.path.join, so strip the directory with
    # os.path.basename rather than a hard-coded '/' separator.
    file_name = os.path.basename(file_path)
    return file_name.split('_', 1)[0]

# Debugging aid: show every ID1 present in the clinical table
print("ID1 values from clinical data:")
print(clinical_df['ID1'].tolist())

# Pair each image with its patient's clinical record. One record is appended
# per matching image, so a patient's record is repeated for each of their
# images; images whose ID has no clinical record are dropped.
matched_clinical_data, matched_image_paths = [], []
for img_path in image_paths:
    patient_id = extract_id_from_image(img_path)  # e.g. 'D2-0002'

    # Debugging aid: show the ID parsed from this file name
    print(f"Extracted ID1 from image: {patient_id}")

    if patient_id not in clinical_map:
        print(f"Warning: No clinical data found for {patient_id}")
        continue

    matched_clinical_data.append(clinical_map[patient_id])
    matched_image_paths.append(img_path)

# Stop early when nothing matched: there is nothing to split or train on
if not matched_clinical_data:
    raise ValueError("No valid clinical data available after matching images.")

# One row per matched image, in the same order as matched_image_paths
matched_clinical_df = pd.DataFrame(matched_clinical_data)

# Keep only the 'Age' feature as a 2-D array (one column)
preprocessed_clinical_data = matched_clinical_df[['Age']].values

# float32 tensor copy of the feature array
preprocessed_clinical_data_tensor = torch.tensor(
    preprocessed_clinical_data, dtype=torch.float32
)

# Function to split data into training, validation, and test sets
def split_data(image_paths, clinical_data, test_size=0.2, val_size=0.1, *, random_state=42):
    """Split paired image paths and clinical features into train/val/test sets.

    The split is done in two stages with ``train_test_split``: first a test
    set is held out, then the remainder is divided into training and
    validation sets. Both stages shuffle with the same ``random_state``.

    Args:
        image_paths: Sequence of image paths, aligned index-by-index with
            ``clinical_data``.
        clinical_data: Array-like of clinical features, one entry per image.
        test_size: Fraction of the whole dataset held out for testing.
        val_size: Fraction of the *remaining* (train + validation) data used
            for validation, not a fraction of the whole dataset.
        random_state: Seed passed to both splits; the default of 42 keeps the
            previous hard-coded behaviour.

    Returns:
        A 6-tuple ``(train_image_paths, val_image_paths, test_image_paths,
        train_clinical_data, val_clinical_data, test_clinical_data)``.

    NOTE(review): the split is per image. If one patient contributes several
    images, they may land in different subsets; consider a group-aware split
    keyed on patient ID if that matters for evaluation — confirm with the
    data owner.
    """
    # Stage 1: hold out the test set from the full dataset
    trainval_image_paths, test_image_paths, trainval_clinical_data, test_clinical_data = train_test_split(
        image_paths, clinical_data, test_size=test_size, random_state=random_state
    )

    # Stage 2: carve the validation set out of what is left
    train_image_paths, val_image_paths, train_clinical_data, val_clinical_data = train_test_split(
        trainval_image_paths, trainval_clinical_data, test_size=val_size, random_state=random_state
    )

    return (train_image_paths, val_image_paths, test_image_paths,
            train_clinical_data, val_clinical_data, test_clinical_data)

# train_test_split is fed a NumPy array, so convert the tensor back first
preprocessed_clinical_data = preprocessed_clinical_data_tensor.numpy()

# Run the two-stage split with the default proportions
(train_img, val_img, test_img,
 train_clinical, val_clinical, test_clinical) = split_data(
    matched_image_paths, preprocessed_clinical_data
)

# Report how many images and clinical records ended up in each subset
for _subset_name, _subset_images, _subset_records in (
    ("Training", train_img, train_clinical),
    ("Validation", val_img, val_clinical),
    ("Test", test_img, test_clinical),
):
    print(f"{_subset_name} Set: {len(_subset_images)} images, {len(_subset_records)} clinical records")


Clinical Data Head:
       ID1 LeftRight       Age  number    abnormality classification  \
0  D2-0001         L  0.651515       2  calcification      Malignant   
1  D2-0002         R  0.727273       2  calcification      Malignant   
2  D2-0003         L  0.348485       2  calcification      Malignant   
3  D2-0004         L  0.257576       2  calcification      Malignant   
4  D2-0005         R  0.303030       2  calcification      Malignant   

         subtype  target  
0      Luminal B       1  
1      Luminal B       1  
2      Luminal B       1  
3      Luminal B       1  
4  HER2-enriched       0  
ID1 values from clinical data:
['D2-0001', 'D2-0002', 'D2-0003', 'D2-0004', 'D2-0005', 'D2-0006', 'D2-0007', 'D2-0008', 'D2-0009', 'D2-0010', 'D2-0011', 'D2-0012', 'D2-0013', 'D2-0014', 'D2-0015', 'D2-0016', 'D2-0017', 'D2-0018', 'D2-0019', 'D2-0020', 'D2-0021', 'D2-0022', 'D2-0023', 'D2-0024', 'D2-0025', 'D2-0026', 'D2-0027', 'D2-0028', 'D2-0029', 'D2-0030', 'D2-0031', 'D2-0032', '