In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import seaborn as sns
from sklearn.metrics import roc_curve, roc_auc_score, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
import statsmodels.api as sm
from tqdm import tqdm
from functools import reduce

from src.config import *
from src.prediction_functions import *
from src.model_creation import *


# Functions

In [3]:
def set_perturabtion_type_values(perturbation, pert_type_lists):
    """Classify a perturbation name as compound, ligand or receptor.

    Parameters
    ----------
    perturbation
        Name looked up in the membership lists.
    pert_type_lists
        Mapping with the keys 'compound_list', 'ligand_list' and
        'receptor_list'; each value must support the ``in`` operator.

    Returns
    -------
    'cp', 'lig' or 'rec' for the first list (checked in that order) that
    contains ``perturbation``; ``np.nan`` when none of them does.
    """
    # Order matters: a name present in several lists gets the first label.
    ordered_checks = (
        ('compound_list', 'cp'),
        ('ligand_list', 'lig'),
        ('receptor_list', 'rec'),
    )
    for list_key, type_label in ordered_checks:
        if perturbation in pert_type_lists[list_key]:
            return type_label
    return np.nan

In [4]:
def determine_type_of_perturbation_of_samples(data_index, data_index_perturbation, pert_type_lists):
    """Build a per-sample table of perturbation names and their types.

    Parameters
    ----------
    data_index
        Labels used as the index of the returned frame (one per sample).
    data_index_perturbation
        Perturbation name of each sample, in the same order as ``data_index``.
    pert_type_lists
        Membership lists forwarded to ``set_perturabtion_type_values``.

    Returns
    -------
    pd.DataFrame
        Columns 'perturbation' and 'perturbation_type', indexed by
        ``data_index``.
    """
    def classify(perturbation):
        return set_perturabtion_type_values(perturbation, pert_type_lists)

    pert_types = pd.DataFrame({'perturbation': data_index_perturbation})
    pert_types['perturbation_type'] = pert_types['perturbation'].apply(classify)
    # The sample labels are attached last, once both columns are in place.
    pert_types.index = data_index
    return pert_types

# Read in data

##### RIDDEN

In [5]:
# LINCS signature table and the matching design matrix (names per the file
# paths); in both files the first CSV column becomes the row index.
lincs_data = pd.read_csv(
    'data/lincs_consensus/high_quality/lm_all_pert_cell_liana.csv',
    index_col=0,
)
lincs_design_matrix = pd.read_csv(
    'data/design_matrices/high_quality/all_pert_binary_liana.csv',
    index_col=0,
)

##### CytoSig

In [6]:
# data
# Read the gzip-compressed, tab-separated CytoSig table, transpose it and
# replace missing values with 0.
cytosig_data = (
    pd.read_csv('data/cytosig/diff.merge.gz', compression='gzip', sep='\t')
    .T
    .fillna(0)
)

In [7]:
# predictions
# calculate CytoSig on the LINCS inferred-genes consensus
# Created by running CytoSig_run.py -i input.csv -o output_filename
cytosig_lincs = pd.read_table(
    'results/benchmark/cytosig_prediction_inferred_signature.Coef',
    sep='\t',
)
# Predictions are stored transposed under the 'lincs' key.
cytosig = {'lincs': cytosig_lincs.T}

In [8]:
# Ligand-receptor associations: only the two gene-symbol columns are kept
# (source_genesymbol is used as the ligand and target_genesymbol as the
# receptor further down in this notebook).
lr_associations = pd.read_csv(LIG_REC_DF, index_col=0)[
    ['source_genesymbol', 'target_genesymbol']
]

In [9]:
# Compound annotation: keep name, target and sign, and renumber the rows.
compound_info = (
    pd.read_csv(
        'data/filtered_lincs_meta/filtered_coumpound_info_to_receptor_perturbation_signatures_signed.csv',
        index_col=0,
    )
    .loc[:, ['cmap_name', 'target', 'sign']]
    .reset_index(drop=True)
)

## Data preparation

In [10]:
# Create perturbation metadata series
# Unique compound names found in the compound annotation table.
lincs_compounds_list = list(compound_info.cmap_name.unique())

# Sample names are '_'-separated; the third field is compared against 'lig'
# and the first field is the perturbation name. The fields are read straight
# from the index with `.str[i]`, so this does not depend on `reset_index()`
# producing a column called 'index' (true only for an unnamed index), and
# names with fewer than three fields simply do not match instead of raising.
lincs_ligand_list = list(
    lincs_data.index[lincs_data.index.str.split('_').str[2] == 'lig']
    .str.split('_')
    .str[0]
    .unique()
)

In [11]:
# Membership lists used to classify each perturbation name.
pert_type_lists = {
    'compound_list': lincs_compounds_list,
    # ligands from the ligand-receptor table plus those found in the LINCS sample names
    'ligand_list': list(lr_associations.source_genesymbol) + lincs_ligand_list,
    'receptor_list': list(lr_associations.target_genesymbol),
}


# Matching data
- translating from ligand to receptor and vice versa

In [12]:
# lincs
# The perturbation name is the first '_'-separated field of each sample name.
tmp = lincs_data.index.str.split('_', expand=True).get_level_values(0)
lincs_pert_types = determine_type_of_perturbation_of_samples(
    data_index=lincs_data.index,
    data_index_perturbation=tmp,
    pert_type_lists=pert_type_lists,
)

In [13]:
def get_target_list_cp(sample: pd.Series):
    """Return the signed targets of a compound perturbation.

    Looks up ``sample.perturbation`` in the 'cmap_name' column of the
    module-level ``compound_info`` table.

    Returns
    -------
    dict
        Maps each 'target' of the compound to its 'sign'; empty when the
        compound is not listed. If a target occurs more than once, the last
        row wins.
    """
    is_this_drug = compound_info['cmap_name'] == sample.perturbation
    sign_by_target = (
        compound_info.loc[is_this_drug, ['target', 'sign']]
        .set_index('target')['sign']
    )
    return sign_by_target.to_dict()

def get_target_list_lig(sample: pd.Series):
    """Return the signed receptor targets of a ligand perturbation.

    The receptors of ``sample.perturbation`` are taken from the module-level
    ``lr_associations`` table. The sign comes from the third '_'-separated
    field of the sample name (PERT_CELL_PERTTYPE, e.g. ACVR1_MCF7_oe):
    'oe' gives +1, 'xpr' and 'sh' give -1.

    Returns
    -------
    dict or float
        Maps every receptor to the sign; ``np.nan`` when the ligand has no
        receptor in the table or the third field is none of the three above.
        NOTE(review): a third field of 'lig' therefore also yields NaN;
        confirm that this is intended.
    """
    sign_by_pert_type = {'oe': 1, 'xpr': -1, 'sh': -1}

    is_this_ligand = lr_associations.source_genesymbol == sample.perturbation
    receptors = list(lr_associations[is_this_ligand]['target_genesymbol'])
    if not receptors:
        return np.nan

    # The name is only parsed once receptors are known to exist.
    pert_type = sample.name.split('_')[2]
    if pert_type not in sign_by_pert_type:
        return np.nan
    return dict.fromkeys(receptors, sign_by_pert_type[pert_type])

def get_target_list_rec(sample: pd.Series):
    """Return the signed target of a receptor perturbation.

    The perturbed receptor is its own target. The sign comes from the third
    '_'-separated field of the sample name (PERT_CELL_PERTTYPE, e.g.
    RPS19_MCF7_oe): 'oe' gives +1, 'xpr' and 'sh' give -1.

    Returns
    -------
    dict or float
        ``{sample.perturbation: sign}``, or ``np.nan`` when the third field
        is none of the three above.
    """
    pert_type = sample.name.split('_')[2]
    sign = {'oe': 1, 'xpr': -1, 'sh': -1}.get(pert_type)
    if sign is None:
        return np.nan
    return {sample.perturbation: sign}
    
def add_target_to_compound_lincs(sample):
    """Dispatch a sample to the target lookup matching its perturbation type.

    ``sample.perturbation_type`` selects the handler:
    'cp' uses ``get_target_list_cp`` (compound targets),
    'lig' uses ``get_target_list_lig`` (receptors of the ligand, taken from
    lr_associations (LIANA)), and
    'rec' uses ``get_target_list_rec`` (the receptor itself).

    Returns
    -------
    Whatever the selected handler returns, or ``np.nan`` for any other type.
    """
    pert_type = sample.perturbation_type
    if pert_type == 'cp':
        handler = get_target_list_cp
    elif pert_type == 'lig':
        handler = get_target_list_lig
    elif pert_type == 'rec':
        handler = get_target_list_rec
    else:
        return np.nan
    return handler(sample)
   

In [14]:
# Row-wise: attach the signed target dictionary (or NaN) of every sample.
lincs_pert_types['signed_interactions_rec'] = lincs_pert_types.apply(
    add_target_to_compound_lincs,
    axis=1,
)


In [15]:
def filter_nan(dict):
    """Return a copy of the mapping without entries whose key or value is missing.

    An entry is dropped when ``pd.notna`` is false for its key or its value;
    the input mapping is not modified.
    """
    kept = {}
    for key, value in dict.items():
        if pd.isna(key) or pd.isna(value):
            continue
        kept[key] = value
    return kept

In [21]:
def fill_signed_interactions_ligand(rec, sign):
    """Return the ligands of a receptor, each paired with the given sign.

    Ligands are the 'source_genesymbol' entries of the module-level
    ``lr_associations`` table whose 'target_genesymbol' equals ``rec``.

    Returns
    -------
    dict
        ``{ligand: sign}`` for every ligand found, or the placeholder
        ``{np.nan: np.nan}`` when the receptor has none.
    """
    binds_receptor = lr_associations['target_genesymbol'] == rec
    ligands = list(lr_associations[binds_receptor].source_genesymbol)
    if not ligands:
        return {np.nan: np.nan}
    return dict.fromkeys(ligands, sign)

In [22]:
def translate_receptor_to_ligand(sample_row):
    """Translate a receptor-level signed dictionary into a ligand-level one.

    Parameters
    ----------
    sample_row
        Mapping ``{receptor: sign}``. Anything without an ``items`` method
        (for example the ``np.nan`` stored for samples that have no receptor
        targets) is treated as "nothing to translate".

    Returns
    -------
    dict or float
        ``{ligand: sign}`` merged over all receptors in iteration order (a
        ligand shared by several receptors keeps the sign of the last one),
        or ``np.nan`` when there is no input mapping or no ligand was found.
    """
    # Missing receptor information arrives as a float NaN, not as a dict;
    # calling .items() on it would raise AttributeError.
    if not hasattr(sample_row, 'items'):
        return np.nan

    flat_ligand_dict = {}
    for receptor, sign in sample_row.items():
        ligands = fill_signed_interactions_ligand(receptor, sign)
        # filter_nan drops the {nan: nan} placeholder of receptors without ligands.
        flat_ligand_dict.update(filter_nan(ligands))

    if not flat_ligand_dict:
        return np.nan
    return flat_ligand_dict

In [27]:
# Ligand-level view of every sample: ligand perturbations map to themselves
# with sign +1, all other samples have their receptor-level dictionary
# translated to ligands.
# NOTE(review): ligand samples get +1 whatever the perturbation suffix in the
# sample name is; confirm that this is intended.
lincs_pert_types['signed_interactions_lig'] = lincs_pert_types.apply(
    lambda sample: (
        {sample.perturbation: 1}
        if sample.perturbation_type == 'lig'
        else translate_receptor_to_ligand(sample.signed_interactions_rec)
    ),
    axis=1,
)

In [39]:
# Persist the annotated perturbation table.
lincs_pert_types.to_csv(path_or_buf='results/benchmark/lincs_translate_to_ligands.csv')