## Experiment 1: cellTypeSpecific

This experiment studies different options for the base GRN provided to CellOracle, similar to the **baseNetwork** experiment. 
It asks a narrower question with a different setup: within a single network source (e.g. ANANSE), do the Ko lab ESC data work best with an ESC-specific network structure? 

In [1]:
# Name of this experiment. It is used as the results sub-directory and in the file names of the
# memoized CellOracle objects, so it must differ from the baseNetwork experiment's name; otherwise
# this notebook overwrites that experiment's CSVs and collides with its memoized model files.
EXPERIMENT_NAME = "cellTypeSpecific_v1"

In [2]:
import warnings
warnings.filterwarnings('ignore')
import importlib
import os
import gc
import re
import sys
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import celloracle as co

#      visualization settings
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
plt.rcParams['figure.figsize'] = [6, 4.5]
plt.rcParams["savefig.dpi"] = 300

In [3]:
# Deal with various file paths specific to this project.
# NOTE(review): absolute, machine-specific path — adjust when running on another machine.
PROJECT_PATH = '/home/ekernf01/Desktop/jhu/research/projects/perturbation_prediction/cell_type_knowledge_transfer/'
# All relative "results/..." paths in this notebook are resolved against benchmarking/.
os.chdir(PROJECT_PATH + "benchmarking/")
# exist_ok replaces the previous bare `except: pass`, which also hid genuine failures
# (e.g. permission errors) and only meant to tolerate an already-existing directory.
os.makedirs(os.path.join("results", EXPERIMENT_NAME), exist_ok=True)

# Make the project-local helper modules importable.
sys.path.append(os.path.expanduser(PROJECT_PATH + 'networks/load_networks'))
sys.path.append(os.path.expanduser(PROJECT_PATH + 'perturbations/load_perturbations'))
sys.path.append(os.path.expanduser(PROJECT_PATH + 'benchmarking/evaluator'))
import evaluator
import load_networks
import load_perturbations
# Reload so edits to the helper modules are picked up without restarting the kernel.
importlib.reload(evaluator)
importlib.reload(load_networks)
importlib.reload(load_perturbations)
# Locations of the network and perturbation data, exposed via environment variables.
os.environ["GRN_PATH"]           = PROJECT_PATH + "networks/networks"
os.environ["PERTURBATION_PATH"]  = PROJECT_PATH + "perturbations/perturbations"

Loading prebuilt promoter base-GRN. Version: hg19_gimmemotifsv5_fpr2
Loading prebuilt promoter base-GRN. Version: hg19_gimmemotifsv5_fpr2


### Networks setup

This experiment compares a dense baseline network against published sparse regulatory network structures: the union of all ANANSE subnetworks, and each individual ANANSE subnetwork.

In [None]:
# Reference networks: a random network with density 1 (presumably all edges present) and the
# combination of all ANANSE subnetworks, converted to the evaluator's matrix format.
networks = dict()
networks['dense'] = evaluator.makeRandomNetwork(density = 1)
networks['ANANSE_all'] = evaluator.networkEdgesToMatrix(
    load_networks.load_grn_all_subnetworks("ANANSE")
)

In [None]:
# Add each individual ANANSE subnetwork, in the same matrix format as the reference networks.
# (The loop body previously referred to an undefined name `network` instead of `subnetwork`.)
for subnetwork in load_networks.list_subnetworks("ANANSE"):
    print("Loading " + subnetwork)
    if subnetwork not in networks:
        networks[subnetwork] = evaluator.networkEdgesToMatrix(
            load_networks.load_grn_by_subnetwork("ANANSE", subnetwork)
        )
    # Memory is tight for this experiment; release intermediates between loads.
    gc.collect()

# One more network used in a reprogramming-related project. Its components are not ANANSE
# subnetworks and are not loaded above, so build it only when both are available instead of
# failing with a KeyError.
MOGRIFY_COMPONENTS = ['MARA_FANTOM4', 'STRING']
if all(component in networks for component in MOGRIFY_COMPONENTS):
    networks["mogrify"] = pd.concat([networks[n] for n in MOGRIFY_COMPONENTS])
else:
    print("Skipping 'mogrify': component networks not loaded")

# Number of edges in each network, as a tidy (network, numEdges) table.
network_sizes = pd.DataFrame(
    {name: evaluator.countMatrixEdges(matrix) for name, matrix in networks.items()},
    index=["numEdges"],
)
network_sizes = network_sizes.T.reset_index().rename({"index": "network"}, axis=1)
network_sizes

### Memory consumption

This experiment has recently been problematic in terms of memory consumption, so we briefly check the in-memory size of each network (in MB).

In [None]:
# Approximate in-memory size of each network in megabytes (sys.getsizeof / 1e6).
memory_mb = {name: sys.getsizeof(matrix) / 1e6 for name, matrix in networks.items()}
pd.DataFrame(memory_mb, index=["memory"])

### Data setup

We use the Nakatake et al. data. This experiment is on per-cluster versus shared regression models, so we run Leiden clustering at many different resolutions. (**TODO:** no clustering step appears in this notebook yet.)

In [None]:
# Load the Nakatake et al. perturbation dataset from the perturbation data directory.
nakatake_path = os.environ["PERTURBATION_PATH"] + "/nakatake/" + "test.h5ad"
ko_lab_esc_data = sc.read_h5ad(nakatake_path)

In [None]:
# Which per-cell annotation columns are available (the "perturbation" column is used later)?
obs_columns = ko_lab_esc_data.obs.columns
obs_columns

In [None]:
# Restrict to regulators that are columns of every candidate network (presumably so that each
# network can model every held-out perturbation), then split into training and held-out data.
allowedRegulators = set.intersection(*(set(matrix.columns) for matrix in networks.values()))
ko_lab_esc_data_train, ko_lab_esc_data_heldout, perturbationsToPredict = evaluator.splitData(
    ko_lab_esc_data,
    allowedRegulators,
    minTestSetSize=250,
)

### Experimental metadata

In [None]:
# One experimental condition per network. All conditions share the same pruning parameters; the
# edge threshold equals the size of the largest network (presumably so no network is truncated).
n_networks = len(networks)
largest_network_size = int(network_sizes['numEdges'].max())
experiments = pd.DataFrame({
    "network": list(networks),
    "p": [1] * n_networks,
    "threshold_number": [largest_network_size] * n_networks,
    "pruning": ["none"] * n_networks,
})
experiments["index"] = experiments.index
experiments.to_csv("results/" + EXPERIMENT_NAME + "/networkExperiments.csv")
experiments

In [None]:
# Train one model per experimental condition and predict the held-out perturbations. Each trained
# model is memoized under results/<EXPERIMENT_NAME>/<i>.celloracle.oracle.
# TODO: the original call contained an incomplete `clusterColumnName = ,` (a SyntaxError), so it is
# omitted here. NOTE(review): this assumes trainCausalModelAndPredict has a default for
# clusterColumnName — confirm, and pass the relevant obs column once cell-type-specific
# models are wired up.
predictions = {
    i: evaluator.trainCausalModelAndPredict(
        expression=ko_lab_esc_data_train,
        baseNetwork=networks[experiments.loc[i, 'network']],
        memoizationName="results/" + EXPERIMENT_NAME + "/" + str(i) + ".celloracle.oracle",
        perturbations=perturbationsToPredict,
        pruningParameters={
            "p": experiments.loc[i, 'p'],
            "threshold_number": experiments.loc[i, 'threshold_number'],
        },
    )
    for i in experiments.index
}


In [None]:
# Spot-check the predictions of the first experimental condition.
first_prediction = predictions[0]
first_prediction

### Evaluation

We compute the Spearman correlation between the predictions and the held-out perturbation data, using the mean expression of the training-set control cells as the baseline.

In [None]:
# Score every condition's predictions against the held-out perturbations. The baseline (mean
# expression of training-set control cells) is identical for every condition, so it is computed
# once rather than on each loop iteration.
controlIndex = ko_lab_esc_data_train.obs["perturbation"] == "Control"
controlBaseline = ko_lab_esc_data_train[controlIndex, :].X.mean(axis=0)
evaluationResults = {}
for i in predictions:
    # The bare name `evaluateCausalModel` was never imported here (NameError).
    # NOTE(review): assumed to live in the evaluator module like the other helpers — confirm.
    # Only the first element of the return value is kept; presumably a per-perturbation table
    # (later cells use its 'perturbation' and 'spearman' columns).
    evaluationResults[i] = evaluator.evaluateCausalModel(
        ko_lab_esc_data_heldout,
        predictions[i],
        baseline=controlBaseline,
        doPlots=False,
    )[0]
    evaluationResults[i]["index"] = i  # join key back to the `experiments` table
evaluationResults = pd.concat(evaluationResults)
evaluationResults = evaluationResults.merge(experiments, how="left")
evaluationResults = pd.DataFrame(evaluationResults.to_dict())
evaluationResults.head()

In [None]:
# Perturbations for which at least one condition made no prediction (treated here as a Spearman
# correlation of exactly 0). They are flagged so all networks can be compared on the same set.
noPredictionMade = set(
    evaluationResults.loc[evaluationResults["spearman"] == 0, "perturbation"]
)
evaluationResults["somePredictionRefused"] = evaluationResults["perturbation"].isin(noPredictionMade)
# Save next to this experiment's other outputs (results/<EXPERIMENT_NAME>/, relative to
# benchmarking/). The previous "../results/" path pointed at a directory this notebook never creates.
evaluationResults.to_csv(
    os.path.join("results", EXPERIMENT_NAME, "networksExperimentEvaluation.csv")
)
evaluationResults.head()

In [None]:
# Distribution of per-perturbation Spearman correlations for each network (one panel per pruning
# setting), excluding perturbations for which any condition made no prediction.
baseNetworkComparisonFigure = sns.FacetGrid(evaluationResults[~evaluationResults['somePredictionRefused']],
                                            col = 'pruning',
                                            sharey = False,
                                            height=5,
                                            aspect=1).set(title = "Performance")
baseNetworkComparisonFigure.map(sns.violinplot, "spearman", "network",
                                palette=["r", "b", "k", "y", "g"]
                               ).add_legend()
# x holds the raw Spearman correlation and y the network name. The previous y label
# ("Spearman correlation minus average over all methods") described neither axis.
baseNetworkComparisonFigure.set_axis_labels("Spearman correlation", "Network")
plt.show()

In [None]:
# Mean Spearman correlation per (pruning, network), joined with network size, to relate
# network density to performance.
summary = evaluationResults[~evaluationResults['somePredictionRefused']]
# Select the numeric column before averaging. Calling .mean() on the whole grouped frame fails on
# non-numeric columns (e.g. 'perturbation') in pandas >= 2.0.
summary = summary.groupby(["pruning", "network"])["spearman"].mean().reset_index()
summary = summary.merge(network_sizes, on="network")
summary = summary.sort_values(['pruning', 'network'])
# Same results directory as the rest of this experiment (the old "../results/" was never created).
summary.to_csv(os.path.join("results", EXPERIMENT_NAME, "networksExperimentEvaluationSummary.csv"))
display(summary)
baseNetworkComparisonFigureCompact = sns.scatterplot(data=summary[summary["pruning"] != "harsh"],
                                                     x='numEdges',
                                                     y='spearman',
                                                     hue='network')
baseNetworkComparisonFigureCompact.set_xscale("log")
baseNetworkComparisonFigureCompact.set(title="Density vs performance")
baseNetworkComparisonFigureCompact.legend(loc='center left', bbox_to_anchor=(1, 0.5));
