In [1]:
from __future__ import print_function
import torch.utils.data as data
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem


def get_morgan_fingerprint(mol, radius, nBits, FCFP=False):
    """Compute a Morgan fingerprint bit vector for a SMILES string.

    Args:
        mol (str): SMILES string; it is parsed with ``Chem.MolFromSmiles``.
        radius (int): Morgan radius forwarded to RDKit.
        nBits (int): Length of the bit vector forwarded to RDKit.
        FCFP (bool): Forwarded as RDKit's ``useFeatures`` flag.

    Returns:
        numpy.ndarray: 1-D ``uint8`` array of 0/1 values, one per fingerprint bit.

    Raises:
        ValueError: If RDKit cannot parse ``mol`` (``MolFromSmiles`` returned None).
    """
    m = Chem.MolFromSmiles(mol)
    if m is None:
        # Fail with the offending input instead of an opaque RDKit argument error.
        raise ValueError('Could not parse SMILES string: {!r}'.format(mol))
    fp = AllChem.GetMorganFingerprintAsBitVect(m, radius=radius, nBits=nBits, useFeatures=FCFP)
    fp_bits = fp.ToBitString()
    # The bit string is ASCII '0'/'1'; read the raw bytes and shift them to 0/1.
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    finger_print = np.frombuffer(fp_bits.encode('ascii'), dtype='u1') - ord('0')
    return finger_print


class DrugGene(data.Dataset):
    """Dataset of concatenated drug-fingerprint and gene-feature rows with labels.

    Each sample is the drug fingerprint of a SMILES string followed by the
    feature vector of a gene, both looked up from CSV tables read in
    ``__init__``.
    """

    def __init__(self, df, down_sample=True, random_seed=0):
        """Load lookup tables and build the feature matrix for ``df``.

        Args:
            df (pandas.DataFrame): Must provide 'label', 'smiles' and 'gene' columns.
            down_sample (bool): If True, sub-sample the most frequent label
                via ``down_sampling`` before building features.
            random_seed (int): Seed used by ``down_sampling``.
        """
        fn = '../data/go_fingerprints_2020.csv'
        gene_map = pd.read_csv(fn)
        self.gene_name = gene_map['gene']
        self.gene_map = gene_map.drop(columns='gene').to_numpy()

        fn = '../data/drug_fingerprints-1024.csv'
        fp_map = pd.read_csv(fn, header=None, index_col=0)
        self.fp_name = fp_map.index
        self.fp_map = fp_map.to_numpy()

        self.df = df
        self.random_seed = random_seed
        self.down_sample = down_sample  # training set or test.txt set

        print(df.shape)
        labels = np.asarray(df['label'])
        smiles = df['smiles']
        genes = df['gene']
        if self.down_sample:
            idx_in = self.down_sampling(labels)
            # down_sampling returns positions into `labels`, so select rows
            # positionally; label-based lookup breaks on a non-default index.
            smiles = df['smiles'].iloc[idx_in]
            genes = df['gene'].iloc[idx_in]
            labels = labels[idx_in]

        print("get drug features")
        smiles_feature = self.get_drug_fp_batch(smiles).astype(np.float32)
        print("get gene features")
        genes_feature = self.get_gene_ft_batch(genes).astype(np.float32)
        # Named `features` so the local does not shadow the imported `data` module.
        features = np.concatenate([smiles_feature, genes_feature], axis=1)

        self.data, self.labels = features, labels
        self.genes, self.smiles = genes, smiles

        unique, counts = np.unique(self.labels, return_counts=True)
        print(counts)

        print('data shape:')
        print(self.data.shape)
        print(self.labels.shape)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (features, label, index) where features is the concatenated
            drug/gene row and label is the entry of ``self.labels`` at ``index``.
        """
        return self.data[index], self.labels[index], index

    def __len__(self):
        """Return the number of rows in the feature matrix."""
        return len(self.data)

    def down_sampling(self, y):
        """Sub-sample the most frequent label and keep every other label.

        The majority class is reduced to half the combined size of all other
        classes (never more than it actually has). The result lists indices
        class by class in sorted label order, with the majority class replaced
        by its sorted random subset.

        Args:
            y (array-like): 1-D label array.

        Returns:
            numpy.ndarray: Integer positions into ``y`` to keep.
        """
        y = np.asarray(y)
        unique, counts = np.unique(y, return_counts=True)
        if unique.size == 0:
            return np.array([], dtype=np.intp)

        max_idx = np.argmax(counts)
        max_value = unique[max_idx]
        max_counts = counts[max_idx]
        # Built-in int replaces np.int, which was removed from NumPy.
        n_select = min(int((np.sum(counts) - max_counts) * 0.5), int(max_counts))
        print('max_value, max_counts, n_select')
        print(max_value, max_counts, n_select)

        random.seed(self.random_seed)
        majority_idx = list(np.where(y == max_value)[0])
        # Explicit integer dtype keeps an empty selection usable as an index.
        idx_select = np.array(sorted(random.sample(majority_idx, k=n_select)), dtype=np.intp)

        # Keep all minority classes whatever their label values are, instead of
        # assuming they are exactly 0 and 2.
        parts = [
            idx_select if value == max_value else np.where(y == value)[0]
            for value in unique
        ]
        return np.concatenate(parts)

    def get_gene_ft_batch(self, gene):
        """Stack the feature row of every gene name in ``gene``.

        Raises:
            IndexError: If a gene name is not present in ``self.gene_name``.
        """
        gene_features = []
        for g in tqdm(gene):
            idx = np.where(self.gene_name == g)[0][0]
            gene_features.append(self.gene_map[idx])
        return np.array(gene_features)

    def get_drug_fp_batch(self, smile):
        """Stack the fingerprint of every SMILES string in ``smile``.

        Fingerprints come from the preloaded table; a SMILES missing from the
        table is printed and computed with ``get_morgan_fingerprint``.
        """
        fp_features = []
        for s in tqdm(smile):
            try:
                idx = np.where(self.fp_name == s)[0][0]
            except IndexError:
                # No match in the table: report it and compute the fingerprint.
                print(s)
                fp_features.append(get_morgan_fingerprint(s, 3, 1024, FCFP=False))
            else:
                fp_features.append(self.fp_map[idx])
        return np.array(fp_features)


class Drug(data.Dataset):
    """Dataset of drug fingerprints with float labels.

    Each sample is the fingerprint of one SMILES string, looked up from a CSV
    table read in ``__init__``.
    """

    def __init__(self, df):
        """Load the fingerprint table and build the feature matrix for ``df``.

        Args:
            df (pandas.DataFrame): Must provide 'label' and 'smiles' columns.
        """
        fn = '../data/drug_fingerprints-1024.csv'
        # fn = '../data/ref_gnn_fp-512.csv'
        fp_map = pd.read_csv(fn, header=None, index_col=0)
        self.fp_name = fp_map.index
        self.fp_map = fp_map.to_numpy()

        labels = np.asarray(df['label']).astype(np.float32)
        smiles = df['smiles'].to_list()
        print("get drug features")
        # Named `features` so the local does not shadow the imported `data` module.
        features = self.get_drug_fp_batch(smiles).astype(np.float32)

        self.data, self.labels = features, labels
        self.smiles, self.df = smiles, df

        print('data shape:')
        print(self.data.shape)
        print(self.labels.shape)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (fingerprint, label, index) where fingerprint is the row of
            ``self.data`` and label is the entry of ``self.labels`` at ``index``.
        """
        return self.data[index], self.labels[index], index

    def __len__(self):
        """Return the number of rows in the feature matrix."""
        return len(self.data)

    def get_drug_fp_batch(self, smile):
        """Stack the fingerprint of every SMILES string in ``smile``.

        Fingerprints come from the preloaded table; a SMILES missing from the
        table is printed and computed with ``get_morgan_fingerprint``.
        """
        fp_features = []
        for s in tqdm(smile):
            try:
                idx = np.where(self.fp_name == s)[0][0]
            except IndexError:
                # No match in the table: report it and compute the fingerprint.
                print(s)
                fp_features.append(get_morgan_fingerprint(s, 3, 1024, FCFP=False))
            else:
                fp_features.append(self.fp_map[idx])
        return np.array(fp_features)


In [5]:
import pickle
import pandas as pd

# Which dataset class to build per cell line: 'drug' or 'druggene'.
# Named input_type so the builtin input() is not shadowed in the kernel.
input_type = 'drug'
# input_type = 'druggene'

# Expression column used as the regression label.
target_gene = 'MYC'

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("anchor_trial_drugs", "rb") as fp:
    anchor_drugs = pickle.load(fp)

print('gene == ' + target_gene)
cell_list = ['ASC', 'NPC', 'HCC515', 'HT29', 'A375', 'HA1E', 'VCAP', 'A549', 'PC3', 'MCF7']
df_data = pd.read_csv('/localscratch2/han/Pretrain_LINCS/data/LINCS2020_l5_cmpd_24h_10uM/level5_beta_trt_cp_24h_10uM.csv')
df_data = df_data[df_data['cell_iname'].isin(cell_list)]
# df_data.to_csv('10_cell_line.csv')

# Keeps every cell line's (train, test) pair; the plain variables below only
# hold the datasets of the last cell line once the loop finishes.
finetune_datasets = {}
for cell in cell_list:
    print('train cell '+cell)
    df_finetune = df_data[df_data['cell_iname'] == cell]
    # One row per compound: median label over all rows sharing a SMILES.
    df_finetune = df_finetune[['SMILES', target_gene]].groupby(by='SMILES').median().reset_index()
    df_finetune = df_finetune.rename(columns={target_gene: 'label', 'SMILES': 'smiles'})

    # Anchor drugs form the training split; every other compound is test.
    is_anchor = df_finetune['smiles'].isin(anchor_drugs)
    df_fine_train = df_finetune[is_anchor]
    df_fine_test = df_finetune[~is_anchor]
    test_drugs = df_fine_test['smiles'].to_list()

    if input_type == 'drug':
        finetune_train_dataset = Drug(df=df_fine_train)
        finetune_test_dataset = Drug(df=df_fine_test)
    elif input_type == 'druggene':
        # NOTE(review): DrugGene reads df['gene'], which these frames do not
        # contain; confirm how the gene column should be supplied.
        finetune_train_dataset = DrugGene(df=df_fine_train, down_sample=False)
        finetune_test_dataset = DrugGene(df=df_fine_test, down_sample=False)
    else:
        raise ValueError("input_type must be 'drug' or 'druggene', got {!r}".format(input_type))

    finetune_datasets[cell] = (finetune_train_dataset, finetune_test_dataset)

gene == MYC
train cell ASC
get drug features


100%|██████████| 76/76 [00:00<00:00, 2168.29it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 945/945 [00:00<00:00, 2240.49it/s]


data shape:
(945, 1024)
(945,)
train cell NPC
get drug features


100%|██████████| 76/76 [00:00<00:00, 1624.15it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 1158/1158 [00:00<00:00, 1865.50it/s]


data shape:
(1158, 1024)
(1158,)
train cell HCC515
get drug features


100%|██████████| 76/76 [00:00<00:00, 2311.35it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 1301/1301 [00:00<00:00, 2193.86it/s]


data shape:
(1301, 1024)
(1301,)
train cell HT29
get drug features


100%|██████████| 76/76 [00:00<00:00, 2213.92it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 1348/1348 [00:00<00:00, 2351.61it/s]


data shape:
(1348, 1024)
(1348,)
train cell A375
get drug features


100%|██████████| 76/76 [00:00<00:00, 2263.89it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 1959/1959 [00:00<00:00, 2302.90it/s]


data shape:
(1959, 1024)
(1959,)
train cell HA1E
get drug features


100%|██████████| 76/76 [00:00<00:00, 2315.80it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 2056/2056 [00:00<00:00, 2057.60it/s]


data shape:
(2056, 1024)
(2056,)
train cell VCAP
get drug features


100%|██████████| 76/76 [00:00<00:00, 2440.04it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 2105/2105 [00:00<00:00, 2637.36it/s]


data shape:
(2105, 1024)
(2105,)
train cell A549
get drug features


100%|██████████| 76/76 [00:00<00:00, 3108.95it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 2497/2497 [00:00<00:00, 2990.91it/s]


data shape:
(2497, 1024)
(2497,)
train cell PC3
get drug features


100%|██████████| 76/76 [00:00<00:00, 3110.71it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 2984/2984 [00:01<00:00, 2363.01it/s]


data shape:
(2984, 1024)
(2984,)
train cell MCF7
get drug features


100%|██████████| 76/76 [00:00<00:00, 2468.92it/s]

data shape:
(76, 1024)
(76,)





get drug features


100%|██████████| 3178/3178 [00:01<00:00, 2232.33it/s]

data shape:
(3178, 1024)
(3178,)





In [1]:
import pandas as pd

In [14]:
# Load the encoded cell-line expression features; the first CSV column is used
# as the row index (index_col=0). The bare `df` below displays the frame.
df = pd.read_csv('/egr/research-aidd/menghan1/AnchorDrug/data/CellLineEncode/test_cell_line_expression_features_128_encoded_20240111.csv', index_col=0)
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,118,119,120,121,122,123,124,125,126,127
PC3,-0.312864,1.835685,-0.553374,1.626945,1.882913,-3.021637,-1.74472,0.774464,0.279778,1.372928,...,-1.371159,-1.811128,-2.606264,-0.24849,1.861487,-0.837809,-0.397481,2.38903,1.974789,2.150697
A549,-0.411082,-0.611002,0.75447,-0.360971,-0.377884,0.132923,1.506301,3.744813,-0.209339,1.329737,...,-2.42139,0.237642,0.273738,-1.381825,-0.789053,1.590336,0.346917,1.029464,-1.565292,-1.662287
MCF7,-2.455031,-0.053608,-1.027708,-0.003675,-1.523577,0.937249,1.359096,-4.238229,-0.315477,-2.155535,...,1.893407,0.97538,2.97398,0.56829,-1.163515,-0.803519,0.246548,-2.377404,-1.030902,-2.768969


In [15]:
# Split the frame into its row labels and a plain NumPy matrix of the values.
cell_name = df.index
cell_map = df.to_numpy()

In [16]:
# Display the row labels taken from the frame's index.
cell_name

Index(['PC3', 'A549', 'MCF7'], dtype='object')

In [17]:
# Display the (rows, columns) shape of the feature matrix.
cell_map.shape

(3, 128)