# Split into train, val and test with no patient id overlap

# plus ensure that the same pct of 1s & 0s in each disease column

In [9]:
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the datasets
df1 = pd.read_csv("data/nih/train-small.csv")  
df2 = pd.read_csv("data/nih/valid-small.csv")

# Concatenate the two DataFrames along rows
df = (pd.concat([df1, df2], axis=0, ignore_index=True)
      .sample(frac=1, random_state=42)
      .reset_index(drop=True))

# Ensure 'PatientId' is correctly specified
patient_df = df.groupby('PatientId').first().reset_index()  # Group by PatientId to get unique rows for each patient

# Specify the columns that you want to use for stratification (columns with 1s and 0s)
stratification_columns = list(df.columns)
stratification_columns.remove('Image')
stratification_columns.remove('PatientId')

# Create a stratification key (sum of the selected columns or any other strategy)
patient_df['stratify_key'] = patient_df[stratification_columns].sum(axis=1)

# Group by stratification key and get counts
stratify_counts = patient_df['stratify_key'].value_counts()

# Combine rare classes to ensure a minimum of min_class_size samples per class
min_class_size = 8
patient_df['stratify_key'] = patient_df['stratify_key'].apply(
    lambda x: x if stratify_counts[x] >= min_class_size else -1  # Use -1 or any other placeholder value
)

# Check the distribution of stratify_key after combining rare classes
print("Distribution of stratification key after re-aggregation:")
print("A stratify_key of 0 means this number of people had none of the 14 diseases")
print(patient_df['stratify_key'].value_counts())

# Verify that no class has fewer than min_class_size samples
assert all(patient_df['stratify_key'].value_counts() >= 2), "Some classes still have fewer than 2 samples!"

# Now perform the train-test split
train_ids, temp_ids = train_test_split(patient_df['PatientId'], 
                                       test_size=0.2, 
                                       stratify=patient_df['stratify_key'], 
                                       random_state=42)

# Create a stratification key for the remaining set
temp_df = patient_df[patient_df['PatientId'].isin(temp_ids)]

# Split temp_ids into validation and test sets
val_ids, test_ids = train_test_split(temp_df['PatientId'], 
                                     test_size=0.5, 
                                     stratify=temp_df['stratify_key'], 
                                     random_state=42)

# Create train, validation, and test sets from the original DataFrame based on patient IDs
train_df = df[df['PatientId'].isin(train_ids)]
val_df   = df[df['PatientId'].isin(val_ids)]
test_df  = df[df['PatientId'].isin(test_ids)]

Distribution of stratification key after re-aggregation:
A stratify_key of 0 means this number of people had none of the 14 diseases
stratify_key
 0    582
 1    262
 2    112
 3     50
-1      8
Name: count, dtype: int64


In [3]:
# Display the sizes of each set
print(f"Train set size: {train_df.shape[0]}")
print(f"Validation set size: {val_df.shape[0]}")
print(f"Test set size: {test_df.shape[0]}")

Train set size: 895
Validation set size: 106
Test set size: 108
