# Data Split

Holding out a test set gives us an unbiased estimate of how well the final model can be expected to perform on unseen data.

The dataset is small (3,881 rows, just under 4k examples). Cross-validation on the training set will therefore give a more stable estimate of the evaluation metric than a single validation split would.
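A minimal sketch of stratified k-fold cross-validation on labeled training rows. It uses synthetic data so that it runs on its own. The classifier (`LogisticRegression`), the metric (`balanced_accuracy`) and the 5-fold setting are illustrative assumptions, not choices made elsewhere in this notebook. On the real data, `X` and `y` would come from the rows of `strat_train_set` whose `phenotype` is not missing.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic stand-in for the labeled training rows (assumption, not the real data):
# 300 samples, 20 binary features, 24 positives (~8%) to mimic the class imbalance
rng = np.random.default_rng(0)
X = rng.integers(0, 2, size=(300, 20)).astype(float)
y = np.zeros(300, dtype=int)
y[:24] = 1
X[:24, 0] = 1.0  # make feature 0 informative for the minority class

# Stratified folds keep roughly the same positive rate in every validation fold
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(
    LogisticRegression(max_iter=1000), X, y, cv=cv, scoring="balanced_accuracy"
)
print(f"balanced accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
```

Reporting the mean together with the spread across folds is what makes the estimate "stable": a single small validation split with only a handful of minority-class rows can swing the metric considerably.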

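The label (`phenotype`) is missing for a sizeable share of rows and is imbalanced (far more Susceptible than Resistant isolates). Rows with a missing label cannot be used for evaluation, so they all go to the training set. The labeled rows are then split with stratified sampling so that the Susceptible/Resistant ratio is preserved in both the train and test sets.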
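To illustrate why stratification matters for a rare class, here is a small self-contained sketch on a toy label with 8 positives out of 100 (made-up numbers chosen only for illustration). `StratifiedShuffleSplit` allocates each class to the train and test sides in proportion to its frequency, so the 8% positive rate carries over to both sides; an unstratified random split of the same size could leave the test side with noticeably more or fewer positives.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Toy imbalanced label: 92 negatives, 8 positives (8% minority class)
y = np.array([0] * 92 + [1] * 8)
X = np.zeros((len(y), 1))  # features are irrelevant for choosing the split

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
train_idx, test_idx = next(splitter.split(X, y))

# Count positives on each side of the split
print("train:", len(train_idx), "rows,", y[train_idx].sum(), "positives")
print("test: ", len(test_idx), "rows,", y[test_idx].sum(), "positives")
```

With these settings the test side receives 25 rows, 2 of them positive, which is the same 8% rate as the full toy label.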

In [12]:
filepath = './../../DataSets/ResistanceCiprofloxacinStrict.tsv.gz'

In [13]:
import pandas as pd

In [14]:
df = pd.read_csv(filepath, sep='\t', compression='gzip')

In [15]:
df.head()

| Unnamed: 0 | accession | genus | species | phenotype | mic | 3005053 | 3000830 | 3003838 | 3000508 | 3003890 | ... | 3007751-D87Y | 3003926-D87Y | 3003709-G46S | 3004851-A39T | 3004832-A501P | 3003381-R20H | 3003926-S83I | 3003381-G121D | 3004832-T483S | 3004832-A311V |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | SRR3138666 | Campylobacter | jejuni | Susceptible | 0.12 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | SRR3138667 | Campylobacter | jejuni | Susceptible | 0.06 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | SRR3138668 | Campylobacter | jejuni | Susceptible | 0.06 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | SRR3138669 | Campylobacter | jejuni | Susceptible | 0.06 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | SRR3138670 | Campylobacter | jejuni | Susceptible | 0.06 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |


In [16]:
df.shape

(3881, 880)

In [18]:
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


In [19]:
def stratified_split_with_missing(df, target_column, test_size=0.2, random_state=42):
    """
    Splits the DataFrame into training and test sets. All rows with a missing value in the
    target column go to the training set; the remaining rows are split according to
    `test_size` (80/20 by default), stratified on the target column.

    Parameters:
    df (pd.DataFrame): The input DataFrame.
    target_column (str): The name of the target column to stratify on.
    test_size (float): The proportion of the dataset to include in the test split.
    random_state (int): Random seed for reproducibility.

    Returns:
    pd.DataFrame: Training set.
    pd.DataFrame: Test set.
    """
    # Separate rows with missing values in the target column
    missing_target = df[df[target_column].isna()]
    non_missing_target = df[df[target_column].notna()]

    # Create the stratified shuffle split object
    split = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)

    # Perform the split on the non-missing target data (n_splits=1 yields exactly one pair)
    train_index, test_index = next(split.split(non_missing_target, non_missing_target[target_column]))
    strat_train_set = non_missing_target.iloc[train_index]
    strat_test_set = non_missing_target.iloc[test_index]

    # Add the rows with missing target to the training set
    strat_train_set = pd.concat([strat_train_set, missing_target])

    return strat_train_set, strat_test_set
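For comparison, the same idea can be sketched with scikit-learn's `train_test_split` and its `stratify` argument, which avoids handling the `StratifiedShuffleSplit` iterator directly. This is an alternative formulation, not the function used above; the helper name `split_keep_missing_in_train` and the toy data are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

def split_keep_missing_in_train(df, target_column, test_size=0.2, random_state=42):
    """Hypothetical alternative to stratified_split_with_missing using train_test_split."""
    labeled = df[df[target_column].notna()]
    unlabeled = df[df[target_column].isna()]
    # stratify= keeps the class proportions of the labeled rows in both parts
    train, test = train_test_split(
        labeled,
        test_size=test_size,
        stratify=labeled[target_column],
        random_state=random_state,
    )
    # Unlabeled rows cannot be scored, so they only ever go to the training side
    return pd.concat([train, unlabeled]), test

# Toy data: 80 Susceptible, 10 Resistant and 20 unlabeled rows
toy = pd.DataFrame({
    "phenotype": ["Susceptible"] * 80 + ["Resistant"] * 10 + [np.nan] * 20,
    "feature": np.arange(110),
})
train, test = split_keep_missing_in_train(toy, "phenotype")
print(len(train), len(test))  # 92 18
```

Of the 90 labeled toy rows, 18 (20%) land in the test set; the 72 remaining labeled rows plus all 20 unlabeled rows form the training set.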

In [20]:
# Perform the stratified split
strat_train_set, strat_test_set = stratified_split_with_missing(df, 'phenotype')

# Display the shapes of the resulting splits
print("Training set shape:", strat_train_set.shape)
print("Test set shape:", strat_test_set.shape)

# Display the distribution of 'phenotype' in the splits
print("\nTraining set 'phenotype' distribution:")
print(strat_train_set['phenotype'].value_counts(dropna=False))

print("\nTest set 'phenotype' distribution:")
print(strat_test_set['phenotype'].value_counts(dropna=False))

Training set shape: (3317, 880)
Test set shape: (564, 880)

Training set 'phenotype' distribution:
phenotype
Susceptible    2083
NaN            1063
Resistant       171
Name: count, dtype: int64

Test set 'phenotype' distribution:
phenotype
Susceptible    521
Resistant       43
Name: count, dtype: int64
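The counts above can be turned into class proportions to confirm that the stratification worked. The sketch below simply re-enters the printed counts. In the notebook itself, `strat_train_set['phenotype'].value_counts(normalize=True)` should give the same shares directly, since `value_counts` drops `NaN` by default and so normalizes over labeled rows only.

```python
import pandas as pd

# Labeled-row counts copied from the printed output above (NaN rows excluded)
train_counts = pd.Series({"Susceptible": 2083, "Resistant": 171})
test_counts = pd.Series({"Susceptible": 521, "Resistant": 43})

# Class shares within each split
shares = pd.DataFrame({
    "train": train_counts / train_counts.sum(),
    "test": test_counts / test_counts.sum(),
})
print(shares.round(3))

# Fraction of labeled rows that ended up in the test set
test_fraction = test_counts.sum() / (train_counts.sum() + test_counts.sum())
print(f"test fraction of labeled rows: {test_fraction:.3f}")
```

Resistant isolates make up about 7.6% of the labeled rows on both sides, and the test set holds 564 of the 2,818 labeled rows (about 20%). The 80/20 ratio therefore applies to the labeled rows only, not to all 3,881 rows.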


In [23]:
import os

# Save the stratified train and test sets as tab-separated, gzip-compressed files
os.makedirs('./data', exist_ok=True)  # to_csv does not create missing directories
strat_train_set.to_csv('./data/stratified_train_set.tsv.gz', sep='\t', compression='gzip', index=False)
strat_test_set.to_csv('./data/stratified_test_set.tsv.gz', sep='\t', compression='gzip', index=False)

## Sanity Check

In [24]:
# Load the stratified train and test sets
strat_train_set = pd.read_csv('./data/stratified_train_set.tsv.gz', sep='\t', compression='gzip')
strat_test_set = pd.read_csv('./data/stratified_test_set.tsv.gz', sep='\t', compression='gzip')

# Check that the sum of rows in stratified train and test sets equals the number of rows in the original dataframe
total_rows = strat_train_set.shape[0] + strat_test_set.shape[0]
original_rows = df.shape[0]

print(f"Total rows in stratified train and test sets: {total_rows}")
print(f"Total rows in original dataframe: {original_rows}")
print(f"Do they match? {'Yes' if total_rows == original_rows else 'No'}")

Total rows in stratified train and test sets: 3881
Total rows in original dataframe: 3881
Do they match? Yes
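Matching row counts do not rule out a row appearing in both sets while another goes missing. A slightly stricter check can compare sample identifiers directly. The sketch below assumes that `accession` uniquely identifies a sample; the helper `check_split` and the toy frame are hypothetical.

```python
import numpy as np
import pandas as pd

def check_split(original, train, test, id_column="accession", target_column="phenotype"):
    """Hypothetical helper: stricter checks than comparing row counts.

    Assumes `id_column` uniquely identifies a sample in `original`.
    """
    train_ids = set(train[id_column])
    test_ids = set(test[id_column])
    # A sample in both sets would leak test information into training
    if not train_ids.isdisjoint(test_ids):
        raise ValueError("accessions overlap between train and test")
    # Together the two sets must contain exactly the original samples
    if train_ids | test_ids != set(original[id_column]):
        raise ValueError("accessions were lost or added by the split")
    # Every test row needs a label, otherwise it cannot be scored
    if test[target_column].isna().any():
        raise ValueError("test set contains missing labels")
    return True

# Toy example: 10 samples, two of them unlabeled
toy = pd.DataFrame({
    "accession": [f"S{i}" for i in range(10)],
    "phenotype": ["Susceptible"] * 6 + [np.nan] * 2 + ["Resistant"] * 2,
})
print(check_split(toy, toy.iloc[:8], toy.iloc[8:]))  # True
```

In the notebook this would be called as `check_split(df, strat_train_set, strat_test_set)`.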

