In [1]:
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from utils import get_cell_type_compound_gene

os.environ['CUDA_VISIBLE_DEVICES'] = "1"

In [2]:
def get_data(df, cell_type_embedding, compound_embedding, gene_embedding, cell_types, compounds, genes, get_target=True):
    n_samples = len(df)
    n_features = cell_type_embedding.shape[1] + compound_embedding.shape[1] + gene_embedding.shape[1]
    
    # get features
    x = np.zeros((n_samples, n_features))
    cell_type_idxs = np.zeros(n_samples)
    compound_idxs = np.zeros(n_samples)
    gene_idxs = np.zeros(n_samples)
    for i in tqdm(range(n_samples)):
        cell_type_idx = cell_types.index(df['cell_type'][i])
        compound_idx = compounds.index(df['sm_name'][i])
        gene_idx = genes.index(df['gene'][i])

        cell_type_vec = cell_type_embedding[cell_type_idx]
        compound_vec = compound_embedding[compound_idx]
        gene_vec = gene_embedding[gene_idx]

        x[i] = torch.concat([cell_type_vec, compound_vec, gene_vec])
        cell_type_idxs[i] = cell_type_idx
        compound_idxs[i] = compound_idx
        gene_idxs[i] = gene_idx
        
            
    if get_target:
        y = np.zeros(n_samples)
        for i in range(n_samples):
            y[i] = df['target'][i]
            
        return x, y, cell_type_idxs, compound_idxs, gene_idxs
    else:
        return x, cell_type_idxs, compound_idxs, gene_idxs

In [3]:
cell_types, compounds, genes = get_cell_type_compound_gene()

cell_type_names = {'NK cells': 'nk',
                   'T cells CD4+': 't_cd4',
                   'T cells CD8+': 't_cd8',
                   'T regulatory cells': 't_reg'}

In [4]:
if not os.path.exists('../../results/PerturbNet/deep_tf_v2'):
    os.makedirs('../../results/PerturbNet/deep_tf_v2')

In [5]:
# get embedding for cell type, compound, and gene
state_dict = torch.load('/data/pinello/PROJECTS/2023_08_ZL/kaggle_scp/model/deep_tf_v2/model.pth')
cell_type_embedding = state_dict['state_dict']['cell_type_embedding.weight'].cpu()
compound_embedding = state_dict['state_dict']['compound_embedding.weight'].cpu()
gene_embedding = state_dict['state_dict']['gene_embedding.weight'].cpu()

In [6]:
df = pd.read_parquet('/data/pinello/PROJECTS/2023_08_ZL/kaggle_scp/data/de_train.parquet')

In [7]:
for key, cell_type in cell_type_names.items():
    print(cell_type)
    
    df_train = pd.read_csv(f'../../results/PerturbNet/splited_data/train_{cell_type}.csv')
    df_valid = pd.read_csv(f'../../results/PerturbNet/splited_data/valid_{cell_type}.csv')
    
    # training data
    x, y, cell_type_idxs, compound_idxs, gene_idxs  = get_data(df=df_train,
                                                               cell_type_embedding=cell_type_embedding,
                                                               compound_embedding=compound_embedding,
                                                               gene_embedding=gene_embedding,
                                                               cell_types=cell_types,
                                                               compounds=compounds,
                                                               genes=genes)
    
    np.savez(f'../../results/PerturbNet/deep_tf_v2/train_{cell_type}.npz', 
             x=x, y=y, 
             cell_types=cell_type_idxs,
             compounds=compound_idxs,
             genes=gene_idxs)
    
    # validation data
    x, y, cell_type_idxs, compound_idxs, gene_idxs = get_data(df=df_valid,
                                                              cell_type_embedding=cell_type_embedding,
                                                              compound_embedding=compound_embedding,
                                                              gene_embedding=gene_embedding,
                                                              cell_types=cell_types,
                                                              compounds=compounds,
                                                              genes=genes)
    np.savez(f'../../results/PerturbNet/deep_tf_v2/valid_{cell_type}.npz', 
             x=x, y=y,
             cell_types=cell_type_idxs,
             compounds=compound_idxs,
             genes=gene_idxs)

nk


100%|██████████| 9742885/9742885 [21:32<00:00, 7539.82it/s] 
100%|██████████| 1438669/1438669 [03:11<00:00, 7528.54it/s]


t_cd4


100%|██████████| 9742885/9742885 [21:36<00:00, 7512.28it/s] 
100%|██████████| 1438669/1438669 [03:12<00:00, 7482.04it/s]


t_cd8


100%|██████████| 9779307/9779307 [21:39<00:00, 7525.33it/s] 
100%|██████████| 1402247/1402247 [03:07<00:00, 7472.46it/s]


t_reg


100%|██████████| 9742885/9742885 [21:39<00:00, 7499.23it/s] 
100%|██████████| 1438669/1438669 [03:12<00:00, 7455.61it/s]


In [8]:
df_test = pd.read_csv(f'../../results/PerturbNet/splited_data/test.csv')

x, cell_type_idxs, compound_idxs, gene_idxs = get_data(
    df = df_test,
    cell_type_embedding=cell_type_embedding,
    compound_embedding=compound_embedding,
    gene_embedding=gene_embedding,
    cell_types=cell_types,
    compounds=compounds,
    genes=genes,
    get_target=False)  

np.savez(f'../../results/PerturbNet/deep_tf_v2/test.npz', 
         x=x, 
         cell_types=cell_type_idxs, 
         compounds=compound_idxs, 
         genes=gene_idxs)

100%|██████████| 4643805/4643805 [10:17<00:00, 7520.74it/s] 
