# Stratify Data into Training and Testing Sets

**Gregory Way, 2019**

Split the input data into training and testing sets balanced by guide infection.

In [1]:
import os
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from pycytominer.get_na_columns import get_na_columns

In [2]:
np.random.seed(123)

## Load X Matrices

In [3]:
batch = "CRISPR_PILOT_B1"
data_dir = os.path.join("..", "0.generate-profiles", "data")
profile_dir = os.path.join(data_dir, "profiles", batch)

all_profile_files = []
for plates in os.listdir(profile_dir):
    plate_dir = os.path.join(profile_dir, plates)
    for profile_file in os.listdir(plate_dir):
        if "feature_select" in profile_file:
            all_profile_files.append(os.path.join(plate_dir, profile_file))

In [4]:
x_df = (
    pd.concat(
        [pd.read_csv(x) for x in all_profile_files],
        sort=True
    )
    .rename(
        {
            "Image_Metadata_Plate": "Metadata_Plate",
            "Image_Metadata_Well": "Metadata_Well"
        },
        axis="columns")
)

# Drop all features that have missing values
additional_exclude_features = get_na_columns(x_df, features="infer", cutoff=0)
x_df = x_df.drop(additional_exclude_features, axis="columns")

print(x_df.shape)
x_df.head(2)

(3456, 1288)


Unnamed: 0,Cells_AreaShape_Area,Cells_AreaShape_Center_X,Cells_AreaShape_Center_Y,Cells_AreaShape_EulerNumber,Cells_AreaShape_Extent,Cells_AreaShape_MaxFeretDiameter,Cells_AreaShape_MaximumRadius,Cells_AreaShape_MeanRadius,Cells_AreaShape_MedianRadius,Cells_AreaShape_MinFeretDiameter,...,Nuclei_Texture_Variance_DNA_5_0,Nuclei_Texture_Variance_ER_10_0,Nuclei_Texture_Variance_ER_20_0,Nuclei_Texture_Variance_ER_5_0,Nuclei_Texture_Variance_Mito_10_0,Nuclei_Texture_Variance_Mito_20_0,Nuclei_Texture_Variance_Mito_5_0,Nuclei_Texture_Variance_RNA_10_0,Nuclei_Texture_Variance_RNA_20_0,Nuclei_Texture_Variance_RNA_5_0
0,-0.95807,0.814507,-0.768176,0.0,0.976303,-1.094077,-1.005394,-0.861253,-0.872141,-0.81719,...,-1.08292,-0.403869,-0.273727,-0.892703,-1.142401,-1.181333,-1.404747,1.008647,0.830481,1.025472
1,0.547816,1.209206,-0.413024,0.0,0.822687,0.543376,0.903009,0.897552,0.707922,0.557282,...,0.746211,1.900085,1.554589,1.421829,0.556411,1.134919,0.348166,1.136893,1.948684,0.944716


## Load Y Matrix

In [5]:
file = os.path.join(data_dir, "labels", "normalized_cell_health_labels.tsv")
y_df = pd.read_csv(file, sep='\t').drop(["plate_name", "well_col", "well_row"], axis="columns")

print(y_df.shape)
y_df.head(2)

(2302, 72)


Unnamed: 0,cell_id,guide,cc_all_high_n_spots_h2ax_mean,cc_all_large_notround_polynuclear_mean,cc_all_large_round_polyploid_mean,cc_all_n_objects,cc_all_n_spots_mean,cc_all_n_spots_per_nucleus_area_mean,cc_all_nucleus_area_mean,cc_all_nucleus_roundness_mean,...,vb_num_live_cells,vb_percent_all_apoptosis,vb_percent_all_early_apoptosis,vb_percent_all_late_apoptosis,vb_percent_caspase_dead_only,vb_percent_dead,vb_percent_dead_only,vb_percent_live,vb_ros_back_mean,vb_ros_mean
0,ES2,AKT1-1,0.655229,-0.565658,-0.839186,-0.513748,0.3136,0.263062,0.109983,-0.226513,...,0.281397,-0.279051,-0.429141,-0.177258,-0.9203,-0.139875,-0.016549,0.14057,,
1,ES2,AKT1-1,-0.251336,-0.816445,-0.52594,-0.81981,-0.450799,-0.811628,-0.468875,-0.167787,...,0.543716,-0.221588,-0.311041,-0.149198,-1.070176,-0.046783,0.268559,0.040163,-0.29248,0.008339


## Determine how many profiles have status labels

In [6]:
x_groupby_cols = ["Metadata_gene_name", "Metadata_pert_name", "Metadata_cell_line"]

In [7]:
x_meta_df = (
    x_df
    .loc[:, x_groupby_cols]
    .assign(n_measurements=1)
    .groupby(x_groupby_cols)
    .count()
    .reset_index()
    .assign(data_type="cell_painting")
    .merge(x_df.loc[:, x_groupby_cols + ["Metadata_Well", "Metadata_Plate"]],
           how="left",
           on=x_groupby_cols)
)

print(x_meta_df.shape)
x_meta_df.head(8)

(3456, 7)


Unnamed: 0,Metadata_gene_name,Metadata_pert_name,Metadata_cell_line,n_measurements,data_type,Metadata_Well,Metadata_Plate
0,AKT1,AKT1-1,A549,6,cell_painting,A03,SQ00014611
1,AKT1,AKT1-1,A549,6,cell_painting,O22,SQ00014611
2,AKT1,AKT1-1,A549,6,cell_painting,A03,SQ00014610
3,AKT1,AKT1-1,A549,6,cell_painting,O22,SQ00014610
4,AKT1,AKT1-1,A549,6,cell_painting,A03,SQ00014612
5,AKT1,AKT1-1,A549,6,cell_painting,O22,SQ00014612
6,AKT1,AKT1-1,ES2,6,cell_painting,A03,SQ00014615
7,AKT1,AKT1-1,ES2,6,cell_painting,O22,SQ00014615


In [8]:
y_groupby_cols = ["guide", "cell_id"]

In [9]:
y_meta_df = (
    y_df
    .loc[:, y_groupby_cols]
    .assign(n_measurements=1)
    .groupby(y_groupby_cols)
    .count()
    .reset_index()
    .assign(data_type="cell_health")
)

print(y_meta_df.shape)
y_meta_df.head(8)

(364, 4)


Unnamed: 0,guide,cell_id,n_measurements,data_type
0,AKT1-1,A549,4,cell_health
1,AKT1-1,ES2,4,cell_health
2,AKT1-1,HCC44,4,cell_health
3,AKT1-2,A549,4,cell_health
4,AKT1-2,ES2,4,cell_health
5,AKT1-2,HCC44,4,cell_health
6,ARID1B-1,A549,4,cell_health
7,ARID1B-1,ES2,4,cell_health


In [10]:
all_measurements_df = (
    x_meta_df
    .merge(
        y_meta_df,
        left_on=["Metadata_pert_name", "Metadata_cell_line"],
        right_on=["guide", "cell_id"],
        suffixes=["_paint", "_health"],
        how="inner")
    .sort_values(by=["Metadata_cell_line", "Metadata_pert_name"])
    .reset_index(drop=True)
)

file = os.path.join("results", "all_profile_metadata.tsv")
all_measurements_df.to_csv(file, sep='\t', index=False)

print(all_measurements_df.shape)
all_measurements_df.head()

(3456, 11)


Unnamed: 0,Metadata_gene_name,Metadata_pert_name,Metadata_cell_line,n_measurements_paint,data_type_paint,Metadata_Well,Metadata_Plate,guide,cell_id,n_measurements_health,data_type_health
0,AKT1,AKT1-1,A549,6,cell_painting,A03,SQ00014611,AKT1-1,A549,4,cell_health
1,AKT1,AKT1-1,A549,6,cell_painting,O22,SQ00014611,AKT1-1,A549,4,cell_health
2,AKT1,AKT1-1,A549,6,cell_painting,A03,SQ00014610,AKT1-1,A549,4,cell_health
3,AKT1,AKT1-1,A549,6,cell_painting,O22,SQ00014610,AKT1-1,A549,4,cell_health
4,AKT1,AKT1-1,A549,6,cell_painting,A03,SQ00014612,AKT1-1,A549,4,cell_health


## Aggregate Profiles and Outcomes Further

Because the plates do not match (no way to map wells across experiments), we must aggregate the ~6 cell painting replicates per guide and ~4 cell health replicates per guide together to form a single profile and single outcome.

In [11]:
x_columns = x_groupby_cols + x_df.loc[:, ~x_df.columns.str.startswith("Metadata_")].columns.tolist()

In [12]:
x_agg_df = (
    x_df
    .loc[:, x_columns]
    .groupby(x_groupby_cols)
    .median()
    .reset_index()
    .query("Metadata_gene_name in @all_measurements_df.Metadata_gene_name.unique()")
    .query("Metadata_pert_name in @all_measurements_df.Metadata_pert_name.unique()")
    .query("Metadata_cell_line in @all_measurements_df.Metadata_cell_line.unique()")
    .sort_values(by=["Metadata_cell_line", "Metadata_pert_name"])
    .reset_index(drop=True)
    .reset_index()
    .rename({"index": "Metadata_profile_id"}, axis='columns')
)

x_agg_df.Metadata_profile_id = ["profile_{}".format(x) for x in x_agg_df.Metadata_profile_id]


print(x_agg_df.shape)
x_agg_df.head(5)

(357, 1285)


Unnamed: 0,Metadata_profile_id,Metadata_gene_name,Metadata_pert_name,Metadata_cell_line,Cells_AreaShape_Area,Cells_AreaShape_Center_X,Cells_AreaShape_Center_Y,Cells_AreaShape_EulerNumber,Cells_AreaShape_Extent,Cells_AreaShape_MaxFeretDiameter,...,Nuclei_Texture_Variance_DNA_5_0,Nuclei_Texture_Variance_ER_10_0,Nuclei_Texture_Variance_ER_20_0,Nuclei_Texture_Variance_ER_5_0,Nuclei_Texture_Variance_Mito_10_0,Nuclei_Texture_Variance_Mito_20_0,Nuclei_Texture_Variance_Mito_5_0,Nuclei_Texture_Variance_RNA_10_0,Nuclei_Texture_Variance_RNA_20_0,Nuclei_Texture_Variance_RNA_5_0
0,profile_0,AKT1,AKT1-1,A549,0.128119,-0.523081,0.656915,0.0,-0.216196,0.179985,...,0.943312,0.957411,0.980166,0.92517,0.694211,0.396416,0.943503,0.258312,0.146497,0.439103
1,profile_1,AKT1,AKT1-2,A549,-0.151237,0.433182,0.560337,0.0,0.34288,-0.154219,...,0.606264,0.089609,-0.064143,0.230317,-0.062713,-0.133535,0.186883,0.225188,0.021131,0.002909
2,profile_2,ARID1B,ARID1B-1,A549,-0.050721,0.866254,-0.4125,0.0,-0.821827,0.140568,...,-0.537697,0.043528,0.328203,-0.029953,-0.033669,0.071269,-0.317794,0.186336,0.207975,0.258786
3,profile_3,ARID1B,ARID1B-2,A549,0.500161,-0.464468,0.568008,0.0,-0.325497,0.647413,...,0.087831,-0.230552,0.015408,-0.262526,0.010822,-0.025603,0.200284,0.725877,0.172991,0.567965
4,profile_4,ATF4,ATF4-1,A549,3.247545,0.587303,-0.31146,0.0,0.405565,3.322467,...,0.225977,-0.167351,0.243971,-0.05826,-0.073743,-0.379313,0.288709,1.120616,2.423174,0.255682


In [13]:
y_meta_cols = ["Metadata_profile_id", "Metadata_gene_name", "Metadata_pert_name", "Metadata_cell_line"]

y_agg_df = (
    y_df
    .groupby(y_groupby_cols)
    .median()
    .reset_index()
    .query("guide in @all_measurements_df.Metadata_pert_name.unique()")
    .query("cell_id in @all_measurements_df.Metadata_cell_line.unique()")
    .sort_values(by=["cell_id", "guide"])
    .reset_index(drop=True)
    .merge(
        x_agg_df.loc[:, y_meta_cols],
        left_on=["guide", "cell_id"],
        right_on=["Metadata_pert_name", "Metadata_cell_line"]
    )
)

y_columns = y_meta_cols + y_agg_df.loc[:, ~y_agg_df.columns.str.startswith("Metadata_")].columns.tolist()
y_agg_df = y_agg_df.loc[:, y_columns].drop(["guide", "cell_id"], axis="columns")

print(y_agg_df.shape)
y_agg_df.head(2)

(357, 74)


Unnamed: 0,Metadata_profile_id,Metadata_gene_name,Metadata_pert_name,Metadata_cell_line,cc_all_high_n_spots_h2ax_mean,cc_all_large_notround_polynuclear_mean,cc_all_large_round_polyploid_mean,cc_all_n_objects,cc_all_n_spots_mean,cc_all_n_spots_per_nucleus_area_mean,...,vb_num_live_cells,vb_percent_all_apoptosis,vb_percent_all_early_apoptosis,vb_percent_all_late_apoptosis,vb_percent_caspase_dead_only,vb_percent_dead,vb_percent_dead_only,vb_percent_live,vb_ros_back_mean,vb_ros_mean
0,profile_0,AKT1,AKT1-1,A549,0.008156,0.587977,0.01882,0.381501,0.176564,0.187675,...,0.399842,0.0,0.0,0.0,-0.118976,-0.132871,-0.12109,0.132882,0.80697,1.293984
1,profile_1,AKT1,AKT1-2,A549,0.056667,1.264627,0.24145,0.568443,0.235304,0.372684,...,0.10167,0.318027,0.132751,0.467027,0.621374,0.100032,0.074036,-0.099917,0.558041,1.151867


In [14]:
# Confirm that matrices are aligned
pd.testing.assert_series_equal(x_agg_df.Metadata_profile_id, y_agg_df.Metadata_profile_id, check_names=False)

# Are the guides aligned?
pd.testing.assert_series_equal(x_agg_df.Metadata_pert_name, y_agg_df.Metadata_pert_name, check_names=False)

# Are the cells aligned?
pd.testing.assert_series_equal(x_agg_df.Metadata_cell_line, y_agg_df.Metadata_cell_line, check_names=False)

## Split into Training and Testing

In [15]:
test_proportion = 0.15

In [16]:
x_train_df, x_test_df, y_train_df, y_test_df = train_test_split(
    x_agg_df,
    y_agg_df,
    test_size=test_proportion,
    random_state=42)

In [17]:
print(x_train_df.shape)
print(x_test_df.shape)

(303, 1285)
(54, 1285)


In [18]:
file = os.path.join("data", "x_train.tsv.gz")
x_train_df.to_csv(file, sep="\t", index=False)

file = os.path.join("data", "y_train.tsv.gz")
y_train_df.to_csv(file, sep="\t", index=False)

file = os.path.join("data", "x_test.tsv.gz")
x_test_df.to_csv(file, sep="\t", index=False)

file = os.path.join("data", "y_test.tsv.gz")
y_test_df.to_csv(file, sep="\t", index=False)