# Stratify Data into Training and Testing Sets

**Gregory Way, 2019**

Split the input data into training and testing sets balanced by guide infection.

In [1]:
import os
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [2]:
# Seed NumPy's global random number generator so that any later use of it is repeatable.
np.random.seed(seed=123)

## Load X Matrices

In [3]:
# Directory of profile CSV files; every entry found in it is read and stacked into x_df.
profile_dir = os.path.join("..", "0.generate-profiles", "data", "profiles")

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting the
# file names makes the row order of the concatenated frame the same on every run.
profile_files = sorted(os.listdir(profile_dir))

x_df = pd.concat([pd.read_csv(os.path.join(profile_dir, x)) for x in profile_files])

print(x_df.shape)
x_df.head(2)

(3456, 262)


Unnamed: 0,Metadata_Plate,Metadata_Well,Metadata_Assay_Plate_Barcode,Metadata_Plate_Map_Name,Metadata_well_position,Metadata_WellRow,Metadata_WellCol,Metadata_gene_name,Metadata_pert_name,Metadata_broad_sample,...,Nuclei_Texture_InverseDifferenceMoment_AGP_10_0,Nuclei_Texture_InverseDifferenceMoment_AGP_5_0,Nuclei_Texture_InverseDifferenceMoment_DNA_5_0,Nuclei_Texture_InverseDifferenceMoment_Mito_5_0,Nuclei_Texture_SumAverage_AGP_5_0,Nuclei_Texture_SumAverage_DNA_20_0,Nuclei_Texture_SumAverage_ER_5_0,Nuclei_Texture_SumEntropy_RNA_20_0,Nuclei_Texture_Variance_ER_20_0,Nuclei_Texture_Variance_Mito_20_0
0,SQ00014617,A01,SQ00014617,DEPENDENCIES1_HCC44,A01,A,1,EMPTY,EMPTY,,...,-1.9629,-1.383255,0.015549,-0.393101,2.006827,0.216338,-0.011588,-0.974067,-0.675762,-0.670202
1,SQ00014617,A02,SQ00014617,DEPENDENCIES1_HCC44,A02,A,2,MCL1,MCL1-5,,...,-0.881302,-0.974382,0.554052,-1.274361,1.606921,0.407935,-0.087639,-0.39126,-0.518197,-1.080534


## Load Y Matrix

In [4]:
# Read the tab-separated cell health labels, then remove the plate_name, well_col
# and well_row columns.
file = os.path.join("data", "cell_health_labels.tsv")
dropped_cols = ["plate_name", "well_col", "well_row"]
y_df = pd.read_csv(file, sep='\t').drop(columns=dropped_cols)

print(y_df.shape)
y_df.head(2)

(2303, 72)


Unnamed: 0,cell_id,guide,cc_all_high_n_spots_h2ax_mean,cc_all_large_notround_polynuclear_mean,cc_all_large_round_polyploid_mean,cc_all_n_objects,cc_all_n_spots_mean,cc_all_n_spots_per_nucleus_area_mean,cc_all_nucleus_area_mean,cc_all_nucleus_roundness_mean,...,vb_num_live_cells,vb_percent_all_apoptosis,vb_percent_all_early_apoptosis,vb_percent_all_late_apoptosis,vb_percent_caspase_dead_only,vb_percent_dead,vb_percent_dead_only,vb_percent_live,vb_ros_back_mean,vb_ros_mean
0,ES2,AKT1-1,0.04287,0.007976,0.003988,1003,1.777,67.61,166.2,0.899,...,1465.0,0.0271,0.0119,0.0152,1.64,0.03173,0.01652,0.9683,,
1,ES2,AKT1-1,0.02635,0.005988,0.005988,835,1.582,58.99,156.2,0.9011,...,1575.0,0.03169,0.01463,0.01706,1.405,0.03961,0.02255,0.9598,279.6,2083.0


## Determine how many profiles have status labels

In [5]:
# Columns of x_df used as grouping keys when counting and aggregating profiles
x_groupby_cols = [
    "Metadata_gene_name",
    "Metadata_pert_name",
    "Metadata_cell_line",
]

In [6]:
# Number of x_df rows in each group defined by x_groupby_cols, stored as n_measurements
x_replicate_counts = (
    x_df
    .groupby(x_groupby_cols)
    .size()
    .reset_index(name="n_measurements")
)

# Tag the table so its origin stays visible after it is merged with the y counts
x_meta_df = x_replicate_counts.assign(data_type="cell_painting")

print(x_meta_df.shape)
x_meta_df.head(3)

(357, 5)


Unnamed: 0,Metadata_gene_name,Metadata_pert_name,Metadata_cell_line,n_measurements,data_type
0,AKT1,AKT1-1,A549,6,cell_painting
1,AKT1,AKT1-1,ES2,6,cell_painting
2,AKT1,AKT1-1,HCC44,6,cell_painting


In [7]:
# Columns of y_df used as grouping keys when counting and aggregating outcomes
y_groupby_cols = [
    "guide",
    "cell_id",
]

In [8]:
# Number of y_df rows in each group defined by y_groupby_cols, stored as n_measurements
y_replicate_counts = (
    y_df
    .groupby(y_groupby_cols)
    .size()
    .reset_index(name="n_measurements")
)

# Tag the table so its origin stays visible after it is merged with the x counts
y_meta_df = y_replicate_counts.assign(data_type="cell_health")

print(y_meta_df.shape)
y_meta_df.head(2)

(364, 4)


Unnamed: 0,guide,cell_id,n_measurements,data_type
0,AKT1-1,A549,4,cell_health
1,AKT1-1,ES2,4,cell_health


In [9]:
# Pair the per-group counts of the two tables. The inner merge keeps only the
# (Metadata_pert_name, Metadata_cell_line) pairs that also occur as (guide, cell_id).
all_measurements_df = (
    x_meta_df
    .merge(
        y_meta_df,
        left_on=["Metadata_pert_name", "Metadata_cell_line"],
        right_on=["guide", "cell_id"],
        suffixes=["_paint", "_health"],
        how="inner")
    .sort_values(by=["Metadata_cell_line", "Metadata_pert_name"])
    .reset_index(drop=True)
    # Turn the fresh 0..n-1 index into a column that becomes the profile identifier
    .reset_index()
    .rename({"index": "Metadata_profile_id"}, axis='columns')
    # Label the identifier inside the chain; assigning to an existing column keeps
    # its position as the first column
    .assign(
        Metadata_profile_id=lambda df: [
            "profile_{}".format(x) for x in df.Metadata_profile_id
        ]
    )
)

file = os.path.join("results", "all_profile_metadata.tsv")
# to_csv does not create parent directories, so make sure the output folder exists
os.makedirs(os.path.dirname(file), exist_ok=True)
all_measurements_df.to_csv(file, sep='\t', index=False)

print(all_measurements_df.shape)
all_measurements_df.head()

(357, 10)


Unnamed: 0,Metadata_profile_id,Metadata_gene_name,Metadata_pert_name,Metadata_cell_line,n_measurements_paint,data_type_paint,guide,cell_id,n_measurements_health,data_type_health
0,profile_0,AKT1,AKT1-1,A549,6,cell_painting,AKT1-1,A549,4,cell_health
1,profile_1,AKT1,AKT1-2,A549,6,cell_painting,AKT1-2,A549,4,cell_health
2,profile_2,ARID1B,ARID1B-1,A549,6,cell_painting,ARID1B-1,A549,4,cell_health
3,profile_3,ARID1B,ARID1B-2,A549,6,cell_painting,ARID1B-2,A549,4,cell_health
4,profile_4,ATF4,ATF4-1,A549,6,cell_painting,ATF4-1,A549,4,cell_health


## Aggregate Profiles and Outcomes Further

Because the plates do not match (there is no way to map wells across experiments), we must aggregate the ~6 cell painting replicates and the ~4 cell health replicates for each guide and cell line combination to form a single profile and a single outcome.

In [10]:
# Grouping keys first, followed by every x_df column whose name lacks the "Metadata_" prefix
feature_cols = [col for col in x_df.columns if not col.startswith("Metadata_")]
x_columns = x_groupby_cols + feature_cols

In [11]:
# Step 1: collapse the rows of every x_groupby_cols group into their per-column median
x_median_df = (
    x_df
    .loc[:, x_columns]
    .groupby(x_groupby_cols)
    .median()
    .reset_index()
)

# Step 2: keep rows whose identifiers occur in all_measurements_df, then order them
# by cell line and perturbation name
x_matched_df = (
    x_median_df
    .query("Metadata_gene_name in @all_measurements_df.Metadata_gene_name.unique()")
    .query("Metadata_pert_name in @all_measurements_df.Metadata_pert_name.unique()")
    .query("Metadata_cell_line in @all_measurements_df.Metadata_cell_line.unique()")
    .sort_values(by=["Metadata_cell_line", "Metadata_pert_name"])
    .reset_index(drop=True)
)

# Step 3: attach the profile identifier and put it in front of the selected columns
x_profile_key_df = all_measurements_df.loc[:, ["Metadata_profile_id"] + x_groupby_cols]
x_agg_df = (
    x_matched_df
    .merge(x_profile_key_df, left_on=x_groupby_cols, right_on=x_groupby_cols)
    .loc[:, ["Metadata_profile_id"] + x_columns]
)

print(x_agg_df.shape)
x_agg_df.head(5)

(357, 248)


Unnamed: 0,Metadata_profile_id,Metadata_gene_name,Metadata_pert_name,Metadata_cell_line,Cells_AreaShape_Solidity,Cells_AreaShape_Zernike_2_0,Cells_AreaShape_Zernike_2_2,Cells_AreaShape_Zernike_4_2,Cells_AreaShape_Zernike_4_4,Cells_Correlation_Correlation_DNA_ER,...,Nuclei_Texture_InverseDifferenceMoment_AGP_10_0,Nuclei_Texture_InverseDifferenceMoment_AGP_5_0,Nuclei_Texture_InverseDifferenceMoment_DNA_5_0,Nuclei_Texture_InverseDifferenceMoment_Mito_5_0,Nuclei_Texture_SumAverage_AGP_5_0,Nuclei_Texture_SumAverage_DNA_20_0,Nuclei_Texture_SumAverage_ER_5_0,Nuclei_Texture_SumEntropy_RNA_20_0,Nuclei_Texture_Variance_ER_20_0,Nuclei_Texture_Variance_Mito_20_0
0,profile_0,AKT1,AKT1-1,A549,0.291131,0.160739,0.429845,0.322007,0.376163,0.783602,...,-0.896489,-0.645535,-0.054866,-0.178363,0.598618,0.500371,0.791675,0.660538,1.019464,0.299209
1,profile_1,AKT1,AKT1-2,A549,0.546982,-0.074894,0.779504,0.177575,0.769397,0.089346,...,-0.463206,-0.317185,-0.403614,-0.299086,0.3459,0.165794,0.181927,0.007565,0.029475,-0.132508
2,profile_2,ARID1B,ARID1B-1,A549,-0.570857,-0.155407,0.15693,1.042704,0.966247,0.293263,...,0.718614,0.36819,0.266718,0.027826,-0.430721,-0.131549,0.266398,0.429656,0.285818,-0.004287
3,profile_3,ARID1B,ARID1B-2,A549,-0.24051,0.083044,-0.046454,0.652886,0.209439,0.958265,...,0.606592,0.614678,0.498258,0.127643,-0.444072,0.759505,0.245041,1.136552,0.064947,-0.1784
4,profile_4,ATF4,ATF4-1,A549,1.125073,0.883058,-3.061044,-0.529122,-1.522177,3.547337,...,0.522397,0.649992,2.718593,0.477541,-0.120796,1.145846,2.829938,1.730277,0.35289,-0.486197


In [12]:
# Step 1: collapse the rows of every y_groupby_cols group into their per-column median
y_median_df = (
    y_df
    .groupby(y_groupby_cols)
    .median()
    .reset_index()
)

# Step 2: keep rows whose guide and cell line occur in all_measurements_df, then
# order them by cell line and guide
y_matched_df = (
    y_median_df
    .query("guide in @all_measurements_df.Metadata_pert_name.unique()")
    .query("cell_id in @all_measurements_df.Metadata_cell_line.unique()")
    .sort_values(by=["cell_id", "guide"])
    .reset_index(drop=True)
)

# Step 3: attach the profile identifier and put it in front of the original y_df columns
y_profile_key_df = all_measurements_df.loc[:, ["Metadata_profile_id"] + y_groupby_cols]
y_agg_df = (
    y_matched_df
    .merge(y_profile_key_df, left_on=y_groupby_cols, right_on=y_groupby_cols)
    .loc[:, ["Metadata_profile_id"] + y_df.columns.tolist()]
)

print(y_agg_df.shape)
y_agg_df.head(2)

(357, 73)


Unnamed: 0,Metadata_profile_id,cell_id,guide,cc_all_high_n_spots_h2ax_mean,cc_all_large_notround_polynuclear_mean,cc_all_large_round_polyploid_mean,cc_all_n_objects,cc_all_n_spots_mean,cc_all_n_spots_per_nucleus_area_mean,cc_all_nucleus_area_mean,...,vb_num_live_cells,vb_percent_all_apoptosis,vb_percent_all_early_apoptosis,vb_percent_all_late_apoptosis,vb_percent_caspase_dead_only,vb_percent_dead,vb_percent_dead_only,vb_percent_live,vb_ros_back_mean,vb_ros_mean
0,profile_0,A549,AKT1-1,0.005171,0.010071,0.003162,4396.5,0.27135,7.489,170.95,...,2077.5,0.0,0.0,0.0,0.0,0.000974,0.000974,0.999027,1018.29,6336.67
1,profile_1,A549,AKT1-2,0.005909,0.014611,0.005027,4618.5,0.281,8.391,169.75,...,1987.0,0.00073,0.000231,0.000467,0.196429,0.00426,0.003757,0.99574,993.3095,6223.115


In [13]:
# Each pair names a column of x_agg_df and the y_agg_df column that must equal it
# row for row: profile ids first, then guides, then cell lines.
aligned_column_pairs = [
    ("Metadata_profile_id", "Metadata_profile_id"),
    ("Metadata_pert_name", "guide"),
    ("Metadata_cell_line", "cell_id"),
]

for x_col, y_col in aligned_column_pairs:
    # check_names=False: the paired columns are allowed to carry different names
    pd.testing.assert_series_equal(x_agg_df[x_col], y_agg_df[y_col], check_names=False)

## Split into Training and Testing

In [14]:
# Fraction of rows handed to train_test_split as test_size (the held-out testing set)
test_proportion = 0.15

In [15]:
# Partition both aligned matrices with one shared row assignment; random_state pins
# the shuffle. NOTE(review): no `stratify` argument is passed, so this is a seeded
# random split rather than an explicitly stratified one.
split_frames = train_test_split(
    x_agg_df, y_agg_df, test_size=test_proportion, random_state=42
)
x_train_df, x_test_df, y_train_df, y_test_df = split_frames

In [16]:
# Report the (rows, columns) shape of the training and then the testing x matrix
for split_df in (x_train_df, x_test_df):
    print(split_df.shape)

(303, 248)
(54, 248)


In [17]:
# File name / frame pairs, listed in the order they are written
split_outputs = [
    ("x_train.tsv.gz", x_train_df),
    ("y_train.tsv.gz", y_train_df),
    ("x_test.tsv.gz", x_test_df),
    ("y_test.tsv.gz", y_test_df),
]

# Write every split into the data folder as tab-separated text without the index
for file_name, split_df in split_outputs:
    file = os.path.join("data", file_name)
    split_df.to_csv(file, sep="\t", index=False)