In [None]:
import random
import pandas as pd
import numpy as np

In [None]:
# CONSTANTS
# Seeding is currently disabled, so the random batch/well sampling further down
# draws a different split on every run. Uncomment both lines for reproducibility.
# random.seed(1034)
# np.random.seed(1034)

# SMILES strings of the positive-control compounds (matched against the
# `Metadata_SMILES` column when filtering the metadata).
POS_CONTROL = {
    'c1ccc(-c2nn3c(c2-c2ccnc4cc(OCCN5CCOCC5)ccc24)CCC3)nc1',
    'COc1ncc2cc(C(=O)Nc3cc(C(O)=NCc4cccc(Cl)c4)ccc3Cl)c(O)nc2n1',
    'CC1CC2C3CC=C4CC(=O)C=CC4(C)C3(F)C(O)CC2(C)C1(O)C(=O)CO',
    'C=CC1CN2CCC1CC2C(O)c1ccnc2ccc(OC)cc12',
    'CCOC(=O)C1OC1C(O)=NC(CC(C)C)C(O)=NCCC(C)C',
    'Cc1csc(-c2nnc(Nc3ccc(Oc4ncccc4-c4cc[nH]c(=N)n4)cc3)c3ccccc23)c1',
    'O=C(c1ccccc1)N1CCC(CCCCN=C(O)C=Cc2cccnc2)CC1',
    'CC(C)N=C(O)N1CCC(N=C2Nc3cc(F)ccc3N(CC(F)F)c3ccc(Cl)cc32)C1',
}

# SMILES string of the control compound (dimethyl sulfoxide).
CONTROL = 'CS(C)=O'

SEEN_SOURCE = ["source_3", "source_5", "source_7", "source_8", "source_11"]
UNSEEN_SOURCE = ["source_2"]

# CONTROL is a single string, so it has to be wrapped in a set before the union.
# `POS_CONTROL.union(CONTROL)` would iterate over the string and add its
# individual characters ('C', 'S', '(', ...) instead of the SMILES itself.
ALL_COMPOUND = POS_CONTROL | {CONTROL}
ALL_SOURCE = SEEN_SOURCE + UNSEEN_SOURCE

# Read in Metadata
# Image-path table first, then the plate / well / compound lookup tables.
metadata = pd.read_csv("datasets/metadata/local_image_paths.csv")
plate = pd.read_csv("datasets/metadata/plate.csv.gz")
well = pd.read_csv("datasets/metadata/well.csv.gz")  # loaded here but not joined in this cell
# Attach plate-level columns (default inner join on plate / batch / source).
metadata = pd.merge(
    metadata,
    plate,
    on=['Metadata_Plate', 'Metadata_Batch', 'Metadata_Source'],
)
compound = pd.read_csv("datasets/metadata/compound.csv")
# Attach compound-level columns via the JCP2022 identifier.
metadata = pd.merge(metadata, compound, on=['Metadata_JCP2022'])

In [None]:
# POSCTL dataset

# Keep rows from COMPOUND plates whose SMILES is one of ALL_COMPOUND and whose
# source is one of ALL_SOURCE.
pos_metadata = metadata.loc[
    (metadata['Metadata_PlateType'] == 'COMPOUND')
    & metadata['Metadata_SMILES'].isin(ALL_COMPOUND)
    & metadata['Metadata_Source'].isin(ALL_SOURCE)
]

def sample_train_test_ID_posctl(pos_metadata, sources, n_test_batches=2):
    """Split ``pos_metadata`` into train/test sets by holding out whole batches.

    For each entry of ``sources`` (in order), ``n_test_batches`` distinct values
    of ``Metadata_Batch`` are drawn with ``np.random.choice`` from the rows of
    that source. Every row whose batch was drawn goes to the test set; every
    row whose batch was never drawn goes to the train set. A summary of well,
    batch and plate counts is printed.

    Parameters
    ----------
    pos_metadata : pandas.DataFrame
        Must contain the columns ``Metadata_Source``, ``Metadata_Batch``,
        ``Metadata_Plate`` and ``Metadata_Well``.
    sources : iterable
        Source identifiers to draw test batches from.
    n_test_batches : int, default 2
        Number of batches held out per source.

    Returns
    -------
    (train_data, test_data) : tuple of pandas.DataFrame

    Raises
    ------
    ValueError
        If a source has fewer than ``n_test_batches`` distinct batches.
    """
    batch_col = pos_metadata['Metadata_Batch']
    test_batches = []
    test_parts = []
    # Batch Stratification
    for s in sources:
        source_batches = pos_metadata.loc[pos_metadata['Metadata_Source'] == s, 'Metadata_Batch'].unique()
        if len(source_batches) < n_test_batches:
            raise ValueError(
                f"Source {s!r} has {len(source_batches)} batch(es); "
                f"cannot hold out {n_test_batches} for testing."
            )
        test_batch = np.random.choice(source_batches, n_test_batches, replace=False).tolist()
        # Rows are selected by batch name only, exactly like the train mask below.
        test_parts.append(pos_metadata.loc[batch_col.isin(test_batch)])
        test_batches.extend(test_batch)

    in_test = batch_col.isin(test_batches)
    if test_parts:
        # One concat instead of growing a frame inside the loop.
        test_data = pd.concat(test_parts)
    else:
        # No sources given: empty test set that still has all columns.
        test_data = pos_metadata.loc[in_test]
    train_data = pos_metadata.loc[~in_test]

    # A "well" is a distinct (plate, well) pair.
    train_image_num = len(train_data.groupby(['Metadata_Plate','Metadata_Well']))
    test_image_num = len(test_data.groupby(['Metadata_Plate','Metadata_Well']))
    train_batch_num = len(train_data['Metadata_Batch'].unique())
    test_batch_num = len(test_data['Metadata_Batch'].unique())
    train_plate_num = len(train_data['Metadata_Plate'].unique())
    test_plate_num = len(test_data['Metadata_Plate'].unique())
    print(f"""
          Train well # : {train_image_num}
          Test well # : {test_image_num}
          Train batch # : {train_batch_num}
          Test batch # : {test_batch_num}
          Train plate # : {train_plate_num}
          Test plate # : {test_plate_num}
    """)
    return train_data, test_data
# In-distribution split of the positive-control rows: test batches are drawn from every source.
train_data, test_data = sample_train_test_ID_posctl(pos_metadata, ALL_SOURCE)

def sample_train_test_OOD_posctl(pos_metadata, seen_sources, unseen_sources, knn=False, n_test_batches=5):
    """Split ``pos_metadata`` into train/test sets by source (out-of-distribution).

    With ``knn=False`` the train set is every row whose ``Metadata_Source`` is
    in ``seen_sources`` and the test set is every row whose source is in
    ``unseen_sources``.

    With ``knn=True``, ``n_test_batches`` distinct batches are drawn with
    ``np.random.choice`` from the rows of the unseen sources. Rows in those
    batches form the test set. The train set is the seen-source rows followed
    by every other row that is not in a drawn batch. A summary of well, batch
    and plate counts is printed.

    Parameters
    ----------
    pos_metadata : pandas.DataFrame
        Must contain the columns ``Metadata_Source``, ``Metadata_Batch``,
        ``Metadata_Plate`` and ``Metadata_Well``.
    seen_sources, unseen_sources : list-like
        Source identifiers used for training / testing.
    knn : bool, default False
        Selects the batch-level hold-out described above.
    n_test_batches : int, default 5
        Number of unseen-source batches held out when ``knn`` is true.

    Returns
    -------
    (train_data, test_data) : tuple of pandas.DataFrame

    Raises
    ------
    ValueError
        Raised by ``np.random.choice`` when ``knn`` is true and the unseen
        sources have fewer than ``n_test_batches`` distinct batches.
    """
    is_seen = pos_metadata['Metadata_Source'].isin(seen_sources)
    is_unseen = pos_metadata['Metadata_Source'].isin(unseen_sources)
    train_data = pos_metadata.loc[is_seen]
    # Batch Stratification
    if knn:
        unseen_batches = pos_metadata.loc[is_unseen, 'Metadata_Batch'].unique()
        test_batch = np.random.choice(unseen_batches, n_test_batches, replace=False)
        in_test_batch = pos_metadata['Metadata_Batch'].isin(test_batch)
        test_data = pos_metadata.loc[in_test_batch]
        # Append only rows that are not already in train_data. Appending every
        # row outside the test batches would add the seen-source rows twice.
        train_data = pd.concat([train_data, pos_metadata.loc[~is_seen & ~in_test_batch]])
    else:
        test_data = pos_metadata.loc[is_unseen]

    # A "well" is a distinct (plate, well) pair.
    train_image_num = len(train_data.groupby(['Metadata_Plate','Metadata_Well']))
    test_image_num = len(test_data.groupby(['Metadata_Plate','Metadata_Well']))
    train_batch_num = len(train_data['Metadata_Batch'].unique())
    test_batch_num = len(test_data['Metadata_Batch'].unique())
    train_plate_num = len(train_data['Metadata_Plate'].unique())
    test_plate_num = len(test_data['Metadata_Plate'].unique())
    print(f"""
          Train well # : {train_image_num}
          Test well # : {test_image_num}
          Train batch # : {train_batch_num}
          Test batch # : {test_batch_num}
          Train plate # : {train_plate_num}
          Test plate # : {test_plate_num}
    """)
    return train_data, test_data
# Out-of-distribution split of the positive-control rows: train on SEEN_SOURCE, test on UNSEEN_SOURCE.
train_data, test_data = sample_train_test_OOD_posctl(pos_metadata, SEEN_SOURCE, UNSEEN_SOURCE, knn=False)

In [None]:
# TGT2 dataset

# Rows that come from TARGET2 plates.
target2 = metadata.loc[metadata['Metadata_PlateType'] == 'TARGET2']
# Candidate compounds, ordered by descending 'Average Rank'.
selected_compound = pd.read_csv("Selected_Compound_4.txt", delimiter="\t").sort_values(by=['Average Rank'], ascending=False)
target2_compound = set(target2["Metadata_JCP2022"].unique().tolist())
# Walk the ranked list and collect the first 184 compounds that do not occur
# on any TARGET2 plate.
unseen_compound = []
cnt = 0
for comp in selected_compound["CompoundID"].tolist():
    if comp in target2_compound:
        continue
    unseen_compound.append(comp)
    cnt += 1
    if cnt == 184:
        break
        
# All rows of the selected unseen compounds.
unseen_data = metadata.loc[metadata['Metadata_JCP2022'].isin(unseen_compound)]
# Add all control in same plate for post-processing only
# The plate list comes from `unseen_data`: taking it from `metadata` would list
# every plate and therefore pull in the control wells of the whole dataset.
unseen_plates = unseen_data["Metadata_Plate"].unique()
unseen_data_with_control = pd.concat([
    unseen_data,
    metadata.loc[(metadata["Metadata_SMILES"] == CONTROL) & metadata["Metadata_Plate"].isin(unseen_plates)],
])

def sample_train_test_ID_tgt2(tgt2_metadata, n_test_batches=2):
    """Split ``tgt2_metadata`` into train/test sets by holding out whole batches.

    For each distinct ``Metadata_Source`` in the frame, ``n_test_batches``
    distinct values of ``Metadata_Batch`` are drawn with ``np.random.choice``
    from the rows of that source. Every row whose batch was drawn goes to the
    test set; every row whose batch was never drawn goes to the train set.
    A summary of well, batch and plate counts is printed.

    Parameters
    ----------
    tgt2_metadata : pandas.DataFrame
        Must contain the columns ``Metadata_Source``, ``Metadata_Batch``,
        ``Metadata_Plate`` and ``Metadata_Well``.
    n_test_batches : int, default 2
        Number of batches held out per source.

    Returns
    -------
    (train_data, test_data) : tuple of pandas.DataFrame

    Raises
    ------
    ValueError
        If a source has fewer than ``n_test_batches`` distinct batches.
    """
    batch_col = tgt2_metadata['Metadata_Batch']
    test_batches = []
    test_parts = []
    # Batch Stratification
    for s in tgt2_metadata['Metadata_Source'].unique():
        source_batches = tgt2_metadata.loc[tgt2_metadata['Metadata_Source'] == s, 'Metadata_Batch'].unique()
        if len(source_batches) < n_test_batches:
            raise ValueError(
                f"Source {s!r} has {len(source_batches)} batch(es); "
                f"cannot hold out {n_test_batches} for testing."
            )
        test_batch = np.random.choice(source_batches, n_test_batches, replace=False).tolist()
        # Rows are selected by batch name only, exactly like the train mask below.
        test_parts.append(tgt2_metadata.loc[batch_col.isin(test_batch)])
        test_batches.extend(test_batch)

    in_test = batch_col.isin(test_batches)
    if test_parts:
        # One concat instead of growing a frame inside the loop.
        test_data = pd.concat(test_parts)
    else:
        # Frame without any source: empty test set that still has all columns.
        test_data = tgt2_metadata.loc[in_test]
    train_data = tgt2_metadata.loc[~in_test]

    # A "well" is a distinct (plate, well) pair.
    train_image_num = len(train_data.groupby(['Metadata_Plate','Metadata_Well']))
    test_image_num = len(test_data.groupby(['Metadata_Plate','Metadata_Well']))
    train_batch_num = len(train_data['Metadata_Batch'].unique())
    test_batch_num = len(test_data['Metadata_Batch'].unique())
    train_plate_num = len(train_data['Metadata_Plate'].unique())
    test_plate_num = len(test_data['Metadata_Plate'].unique())
    print(f"""
          Train well # : {train_image_num}
          Test well # : {test_image_num}
          Train batch # : {train_batch_num}
          Test batch # : {test_batch_num}
          Train plate # : {train_plate_num}
          Test plate # : {test_plate_num}
    """)
    return train_data, test_data
# In-distribution split of the TARGET2 rows: test batches are drawn from every source present.
train_data, test_data = sample_train_test_ID_tgt2(target2)

def sample_train_test_OOD_tgt2(tgt2_metadata, unseen_metadata, knn=False, n_test_wells=2):
    """Build a train/test split where the test compounds come from ``unseen_metadata``.

    With ``knn=False`` the train set is ``tgt2_metadata`` and the test set is
    ``unseen_metadata``, both returned as-is.

    With ``knn=True``, for every distinct ``Metadata_JCP2022`` in
    ``unseen_metadata`` the rows are grouped by (batch, plate, well), the group
    order is shuffled with ``np.random.shuffle`` and the rows of the first
    ``n_test_wells`` groups go to the test set. The remaining unseen rows are
    appended to ``tgt2_metadata`` to form the train set.

    The "remaining unseen rows" are computed with an outer merge on all shared
    columns followed by a ``left_only`` filter. That frame therefore has a fresh
    index and an extra ``_merge`` column, which is also present in the
    ``knn=True`` train set. A summary of well, batch and plate counts is printed.

    Parameters
    ----------
    tgt2_metadata, unseen_metadata : pandas.DataFrame
        Must contain ``Metadata_Batch``, ``Metadata_Plate`` and
        ``Metadata_Well``; ``unseen_metadata`` must also contain
        ``Metadata_JCP2022`` when ``knn`` is true.
    knn : bool, default False
        Selects the per-compound well hold-out described above.
    n_test_wells : int, default 2
        Number of (batch, plate, well) groups held out per compound.

    Returns
    -------
    (train_data, test_data, unseen_rest) : tuple of pandas.DataFrame
        ``unseen_rest`` holds the rows of ``unseen_metadata`` that have no
        identical row in ``test_data``.
    """
    # Batch Stratification
    if knn:
        held_out = []
        for comp in unseen_metadata["Metadata_JCP2022"].unique():
            selected = unseen_metadata.loc[unseen_metadata["Metadata_JCP2022"] == comp]
            temp = selected.groupby(['Metadata_Batch', 'Metadata_Plate','Metadata_Well'])
            # Random order of the group numbers; the first n_test_wells are held out.
            a = np.arange(temp.ngroups)
            np.random.shuffle(a)
            held_out.append(selected[temp.ngroup().isin(a[:n_test_wells])])
        # One concat instead of growing a frame inside the loop; with no
        # compounds at all, fall back to an empty frame that keeps the columns
        # so the merge below still has keys to join on.
        test_data = pd.concat(held_out) if held_out else unseen_metadata.iloc[:0]
    else:
        test_data = unseen_metadata

    # Anti-join, computed once: unseen rows that were not selected for testing.
    unseen_rest = unseen_metadata.merge(test_data, indicator=True,how='outer').loc[lambda x : x['_merge']=='left_only'].copy()

    if knn:
        train_data = pd.concat([tgt2_metadata, unseen_rest])
    else:
        train_data = tgt2_metadata

    # A "well" is a distinct (plate, well) pair.
    train_image_num = len(train_data.groupby(['Metadata_Plate','Metadata_Well']))
    test_image_num = len(test_data.groupby(['Metadata_Plate','Metadata_Well']))
    train_batch_num = len(train_data['Metadata_Batch'].unique())
    test_batch_num = len(test_data['Metadata_Batch'].unique())
    train_plate_num = len(train_data['Metadata_Plate'].unique())
    test_plate_num = len(test_data['Metadata_Plate'].unique())
    print(f"""
          Train well # : {train_image_num}
          Test well # : {test_image_num}
          Train batch # : {train_batch_num}
          Test batch # : {test_batch_num}
          Train plate # : {train_plate_num}
          Test plate # : {test_plate_num}
    """)
    return train_data, test_data, unseen_rest
# knn=True: hold out wells of each unseen compound for testing; `check` receives
# the unseen rows that were not selected for the test set.
train_data, test_data, check = sample_train_test_OOD_tgt2(target2, unseen_data, knn=True)