# UNAGI walkthrough
A walkthrough of the UNAGI pipeline: loading the example data, training the UNAGI model, performing in-silico perturbation, and visualizing the results.
## Install UNAGI from pypi
The installation takes around 5 minutes, although the exact time depends on your download speed.  

In [None]:
!pip install scUNAGI

## Use the example dataset

In [None]:
!mkdir data
!cp -r ../../../UNAGI/data/example data/
!git clone https://github.com/phoenixding/idrem.git

## Load UNAGI 

In [None]:
import warnings
warnings.filterwarnings('ignore')
from UNAGI import UNAGI
unagi = UNAGI()

## Load the dataset
The example dataset contains 4 stages. The attribute that indicates the stage ID is 'stage', and the cell-type annotation attribute is 'name.simple'.

In [None]:
unagi.setup_data('./data/example',total_stage=4,stage_key='stage')

## Configure the model architecture of UNAGI and training hyper-parameters

In [None]:
unagi.setup_training('example',dist='ziln',device='cuda:0',GPU=True,epoch_iter=1,epoch_initial=1,max_iter=2,BATCHSIZE=560)
unagi.run_UNAGI(idrem_dir = './idrem')

In [None]:
import warnings
warnings.filterwarnings('ignore')
from UNAGI import UNAGI
unagi = UNAGI()
unagi.analyse_UNAGI('./data/example/1/stagedata/dataset.h5ad',iteration=1,progressionmarker_background_sampling_times=1,run_pertubration=False)

## Load the post-analysis data

In [None]:
import pickle
import scanpy as sc
from UNAGI import plotting
# Load the analysed dataset written to ./example_1.
adata = sc.read_h5ad('./example_1/dataset.h5ad')
# The `.uns` attributes are stored separately as a pickle and replace the
# ones read from the h5ad file. The context manager closes the file handle
# as soon as loading finishes.
# NOTE: pickle can execute arbitrary code on load; only open attribute
# files that you generated yourself.
with open('./example_1/attribute.pkl', 'rb') as attribute_file:
    adata.uns = pickle.load(attribute_file)

## Plot the composition of cells in individual stages

In [None]:
plotting.cell_type_composition(adata,'ident','stage',dpi=80)

## Plot the UMAPs of cell embeddings for individual stages

In [None]:
plotting.plot_stages_latent_representation(adata,'ident',stage_key='stage',dpi=80)

## Plot dotplots of cell type markers

In [None]:
plotted_genes = ['IGF1', 'NLGN1', 'ROBO2', 'SLIT2', 'EGFEM1P', 'LINC02388', 'DACH2', 'FGF14', 'ITGBL1', 'MMRN1', 'CCL21', 'ADARB2', 'LRRTM4', 'ZNF385D', 'PRUNE2', 'LDB2', 'AQP1', 'FLT1', 'SLCO2A1', 'EPAS1', 'PCDH17']
for i in adata.obs['stage'].unique():
    stageadata = adata[adata.obs['stage']==i]
    sc.pl.dotplot(stageadata, groupby='ident', var_names=plotted_genes,vmax=2,swap_axes=True)

## Plot increasing gene markers and decreasing gene markers from a specific track
Note: Some tracks might not have increasing or decreasing genes.

In [None]:
import matplotlib.pyplot as plt
import pickle
import pandas as pd
import os
import json
import numpy as np
import scanpy as sc
from cycler import cycler
def readIdremJson(path, filename):
    """Load the JSON payload of an iDREM ``DREM.json`` result file.

    The file at ``<path>/<filename>/DREM.json`` is not plain JSON: its first
    five characters and its last two characters are dropped and a closing
    ``]`` is appended before the text is parsed.

    Args:
        path: Directory that holds the iDREM output folders.
        filename: Name of one output folder, e.g. ``'<track>.txt_viz'``.

    Returns:
        The parsed JSON value. Parsing uses ``strict=False`` so control
        characters inside strings are accepted.
    """
    path = os.path.join(path, filename, 'DREM.json')
    # The context manager closes the file even if reading fails.
    with open(path, "r") as f:
        raw = f.read()
    return json.loads(raw[5:-2] + ']', strict=False)
def getvalueofMarkers(idrem,filename, gene):
    """Return one gene's values relative to its first-stage value.

    Element 8 of the parsed iDREM JSON is read as a table: the first row is
    skipped (presumably a header -- confirm against the iDREM output),
    column 0 holds the gene names and columns 1-4 hold the values that are
    compared. Only the requested gene's row is converted to float.

    Args:
        idrem: Directory that holds the iDREM output folders.
        filename: Output folder of the track to read.
        gene: Gene name to look up in column 0.

    Returns:
        ``[0, c2 - c1, c3 - c1, c4 - c1]`` where ``c1..c4`` are columns 1-4
        of the gene's row.

    Raises:
        ValueError: If ``gene`` is not present in column 0.
    """
    tt = readIdremJson(idrem, filename)
    rows = np.array(tt[8])[1:]
    gene_idx = rows[:, 0].tolist().index(gene)
    stages = rows[gene_idx, 1:5].astype(float)
    baseline = stages[0]
    # The leading 0 stands for the first stage, the reference for the rest.
    return [0, stages[1] - baseline, stages[2] - baseline, stages[3] - baseline]
def getvaluesFromIDREM(path,genes,tracks,target_track):
    """Collect the per-stage changes of each gene for the target track.

    Args:
        path: Directory that holds the iDREM output folders.
        genes: Gene labels. Only the text before the first backslash is
            passed on as the gene name; the full label is the output key.
        tracks: Track names to scan.
        target_track: Only tracks equal to this name are read.

    Returns:
        A dict mapping each gene label to the concatenated values returned
        by ``getvalueofMarkers`` for every matching track. The list is
        empty when no track matches.
    """
    out = {}
    for gene in genes:
        values = []
        for track in tracks:
            # Skip everything but the target; empty names and names that
            # start with '.' are never read.
            if track != target_track or not track or track.startswith('.'):
                continue
            values += getvalueofMarkers(path, track + '.txt_viz', gene.split('\\')[0])
        out[gene] = values
    return out



######
pm = adata.uns['progressionMarkers']
target_track = '12-13-15-12' # name of a track in the ./example_1/idrem folder, written in x-xnx-x-x format
increasing = []
decreasing = []
genes_increasing = []
genes_decreasing = []
plt.rcParams['axes.prop_cycle'] = plt.cycler("color", plt.cm.tab20.colors)
for each in pm.keys():
    if each == target_track:
        # Parse the track's iDREM output once; both directions read from it.
        tt = readIdremJson('./example_1/idrem',each+'.txt_viz')
        track = pm[each]

        # Increasing markers: keep the genes flagged in node 6 of tt[0]
        # (presumably the increasing path -- confirm against the iDREM
        # output layout); tt[3] supplies the gene name for each index.
        df = pd.DataFrame.from_dict(track['increasing'])
        scope = [tt[3][i] for i, flagged in enumerate(tt[0][6]['genesInNode']) if flagged]
        print(scope)
        # Sort a filtered copy rather than sorting a slice in place.
        df = df.loc[df['gene'].isin(scope)].sort_values(by=['rank'])
        increasing.append(df.values[:10])
        # Labels are '<gene>\<track>'; the part after the backslash is
        # stripped again before plotting.
        genes_increasing+=list(df['gene'].values[:10]+'\\'+str(each))

        # Decreasing markers: same procedure with node 8 of tt[0].
        df = pd.DataFrame.from_dict(track['decreasing'])
        scope = [tt[3][i] for i, flagged in enumerate(tt[0][8]['genesInNode']) if flagged]
        df = df.loc[df['gene'].isin(scope)].sort_values(by=['rank'])
        decreasing.append(df.values[:10])
        genes_decreasing+=list(df['gene'].values[:10]+'\\'+str(each))

# Per-gene value changes of the ten top-ranked genes in each direction.
out_increasing = getvaluesFromIDREM('./example_1/idrem',genes_increasing,list(pm.keys()),target_track=target_track)
out_decreasing = getvaluesFromIDREM('./example_1/idrem',genes_decreasing,list(pm.keys()),target_track=target_track)

# Draw one line per selected gene (increasing genes first, then decreasing)
# over the four x positions 0..3.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1,1,figsize=(4,4),dpi=80)

# Legend labels: the gene name is the part of each key before the backslash.
gene_names = list(out_increasing.keys())
increasing_gene_names = [x.split('\\')[0] for x in gene_names]
# NOTE: these tick settings are overridden by the second tick_params call
# further down, which moves ticks and labels back to the bottom.
ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_yticklabels(), rotation=-30, ha="right",rotation_mode="anchor")
plt.setp(ax.get_xticklabels(), rotation=-45, ha="right",rotation_mode="anchor")

# Transpose so that every column (one gene) becomes its own line.
im0 = ax.plot(np.array([0,1,2,3]),np.array(list(out_increasing.values())+list(out_decreasing.values())).T)
gene_names = list(out_decreasing.keys())
gene_names = [x.split('\\')[0] for x in gene_names]
# Legend order matches plotting order; it is anchored outside the axes.
ax.legend(increasing_gene_names+gene_names,loc=2,fontsize='xx-small',bbox_to_anchor=(1.08, 1))
ax.tick_params(top=False, bottom=True,
                   labeltop=False, labelbottom=True)
# Ticks sit 0.25 to the right of the data points at x = 0..3.
ax.set_xticks(np.array([0,1,2,3])+0.25,('Control','Stage 1', 'Stage 2','Stage 3'))
ax.set_ylim(-1.5,1)

# NOTE: the title is hard-coded and is not derived from target_track.
ax.set_title('Top Dynamic Genes in the FibAlv-4')
ax.set_ylabel('Expression Change (Log2FC)')
ax.set_xlabel('Disease Progression')

# plt.savefig('genes_in_fib4.pdf',bbox_inches='tight', pad_inches=0)
plt.show()

## Plot dendrogram of hierarchical clusterings

In [None]:
plotting.plot_hc_dendrogram(adata,'stage','ident',dpi=80)

## Plot hierarchical static markers heatmap

In [None]:
plotting.hierarchical_static_markers_heatmap(adata,stage=0,cluster=3,level=1,n_genes=20,stage_key='stage',celltype_key='ident')

## Perturbation using a customized pathway database

In [None]:
!cp ./UNAGI/UNAGI/data/gesa_pathways.npy ./data/gesa_pathways.npy
import numpy as np
# The .npy file stores a pickled object; `.item()` unwraps it from the
# 0-d array that np.load returns.
built_in_pathway_data = np.load('./data/gesa_pathways.npy',allow_pickle=True).item()

# The keys are the pathway names.
print(list(built_in_pathway_data.keys())[:10]) # show first 10 pathways
# The values are the gene sets.
print(built_in_pathway_data['BIOCARTA_GRANULOCYTES_PATHWAY'])
# Build a custom database in the same shape: pathway name -> list of gene names.
customized_pathway_database = {}
customized_pathway_database['Pathway_A'] = ['COL6A3', 'MET', 'COL7A1', 'MMP1', 'COL11A1', 'COL1A2', 'COL5A2', 'COL4A3', 'COL12A1', 'COL10A1', 'COL5A1', 'COL3A1', 'COL4A4', 'COL14A1', 'COL8A1', 'MMP9', 'COL4A1', 'MMP7', 'COL15A1', 'COL1A1', 'COL17A1', 'COL4A6']
customized_pathway_database['Pathway_B'] = ['MAP2','THBS1',]
# Save the dict next to the analysed dataset; the perturbation call below
# receives this file path.
np.save('./example_1/customized_pathway_database.npy',customized_pathway_database)
from UNAGI import UNAGI
import warnings
warnings.filterwarnings("ignore")
unagi = UNAGI()
data_path = './example_1/dataset.h5ad'
iteration = 1 #which iteration of the model to use
change_level = 0.5 #reduce the expression to 50% of the original value
customized_pathway = './example_1/customized_pathway_database.npy'
# Runs on the first CUDA device; `results` is passed to get_top_pathways
# in the next cell.
results = unagi.customize_pathway_perturbation(data_path,iteration,customized_pathway,change_level,target_dir='./example_1',device='cuda:0')

In [None]:
from UNAGI.perturbations import get_top_pathways
get_top_pathways(results, change_level, top_n=10)

## Customized drug perturbation


In [None]:
import numpy as np
# Custom drug database: drug name -> list of 'GENE:sign' strings.
# NOTE(review): presumably '+' / '-' give the direction in which the drug
# regulates the target gene -- confirm against the UNAGI documentation.
customized_drug_database = {}
customized_drug_database['Drug_A'] = ['TACR1:-', 'MAPK4:+', 'DUSP10:-', 'EGFR:+']
customized_drug_database['Drug_B'] = ['C3:-', 'KDR:+', 'INSR:-','PLA2G2A:-']
customized_drug_database['Drug_C'] = ['COL1A1:-', 'ROBO2:+']
# Save the dict as a .npy file; its path is given to the drug perturbation call.
np.save('./example_1/customized_drug_database.npy',customized_drug_database)

In [None]:
# Run the drug perturbation with the custom database saved in the previous cell.
from UNAGI import UNAGI
import warnings
warnings.filterwarnings("ignore")
unagi = UNAGI()
data_path = './example_1/dataset.h5ad'
iteration = 1 #which iteration of the model to use
change_level = 0.5 #reduce the expression to 50% of the original value
customized_drug = './example_1/customized_drug_database.npy'
# Runs on the first CUDA device; `results` is passed to get_top_compounds
# in the next cell.
results = unagi.customize_drug_perturbation(data_path,iteration,customized_drug,change_level,target_dir='./example_1/',device='cuda:0')

In [None]:
from UNAGI.perturbations import get_top_compounds
get_top_compounds(results, change_level, top_n=10)