## Imports

In [1]:
import pandas as pd
import plotly.graph_objects as go
from pathlib import Path
import logging


## Variables

In [2]:
ORGANISMS = ["human", "mouse"]
MIRNAS = ["mir197", "mir769"]
PROJECTDIR = Path('.').resolve().parents[2]

## Paths

In [3]:
# TargetScan predictions; the file name lists human/mouse and mir197/mir769 in one table.
targetpath = f'{PROJECTDIR}/milestones/data/targetscan/unweighted_TargetScanHuman_human_mouse_mir197_mir769.tsv'

# str.format templates filled positionally: the differential-genes file takes
# (PROJECTDIR, organism, mirna); the counts matrix presumably takes (PROJECTDIR, organism).
differential_genes_path = '{}/milestones/data/rnaseq/{}_{}_differential_genes.tsv'
rnaseq_path = '{}/external_data/counts_matrix/{}_counts.matrix'

# NOTE(review): not referenced by any cell shown here; plot_sankey saves under '../results'.
outdir = ''


## Collect differential-expression and target counts

In [9]:
from rodent_loss_src.rnaseq import split_up_down_regulated
from rodent_loss_src.doapr import enumerate_overlap, count_targets

# FIXME(review): `cutoff` is not assigned anywhere in this notebook; it only existed in the
# kernel from a deleted/unsaved cell. Without this guard, Restart & Run All fails with a bare
# NameError deep inside the loop. Define it (e.g. in the Variables cell); its intended value
# cannot be recovered from this notebook.
if 'cutoff' not in globals():
    raise NameError(
        "`cutoff` is not defined: set the cutoff passed to count_targets "
        "(e.g. in the Variables cell) before running this cell"
    )

# Long-format table: four rows per (organism, miRNA) pair. Labels omit the miRNA because
# mirna_specific_dict filters on the `mirna` column before looking labels up.
col = []
tardf = pd.read_csv(targetpath, sep='\t')
for organism in ORGANISMS:
    for mirna in MIRNAS:
        # Condition-specific targets (project helper; its result is only used as the
        # reference set for enumerate_overlap below).
        target_genes = count_targets(tardf, organism, mirna, cutoff)
        # Parse the differential-genes table and split it by direction of change.
        up, down = split_up_down_regulated(differential_genes_path.format(PROJECTDIR, organism, mirna))
        # Overlap of regulated genes with the targets, stored below as `target_count`.
        observed_targets_up = enumerate_overlap(up, target_genes)
        observed_targets_down = enumerate_overlap(down, target_genes)

        col.extend([
            [len(up), 'up_regulated', organism, mirna, f'{organism}_up'],
            [len(down), 'down_regulated', organism, mirna, f'{organism}_down'],
            [observed_targets_up, 'up_targets', organism, mirna, f'{organism}_up_targets'],
            [observed_targets_down, 'down_targets', organism, mirna, f'{organism}_down_targets'],
        ])

sankey = pd.DataFrame(col, columns=['target_count', 'type', 'organism', 'mirna', 'label'])
display(sankey)


Unnamed: 0,target_count,type,organism,mirna,label
0,85,up_regulated,human,mir197,human_up
1,109,down_regulated,human,mir197,human_down
2,2,up_targets,human,mir197,human_up_targets
3,24,down_targets,human,mir197,human_down_targets
4,25,up_regulated,human,mir769,human_up
5,78,down_regulated,human,mir769,human_down
6,1,up_targets,human,mir769,human_up_targets
7,24,down_targets,human,mir769,human_down_targets
8,101,up_regulated,mouse,mir197,mouse_up
9,115,down_regulated,mouse,mir197,mouse_down


## Plot Sankey diagrams (one per miRNA)

In [22]:

def mirna_specific_dict(df, mirna, organisms=None):
    """Summarise one miRNA's rows of the count table into a flat ``{label: count}`` dict.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``mirna``, ``label`` and ``target_count`` columns (as built in the
        collect-data cell). Only rows with ``df.mirna == mirna`` are used; for each
        label the first matching row is taken.
    mirna : str
        miRNA to select, e.g. ``'mir197'``.
    organisms : iterable of str, optional
        Organisms to summarise, in output order. Defaults to the notebook-level
        ``ORGANISMS``.

    Returns
    -------
    dict
        For each organism ``o``, in exactly this insertion order (the node order of
        ``plot_sankey`` depends on it):
        ``o_up``, ``o_down``, ``o_primary`` (the ``o_down_targets`` count),
        ``o_up_targets``, ``o_down_secondary`` (``down - primary``) and
        ``o_all_secondary`` (``up + down_secondary``).

    Raises
    ------
    KeyError
        If one of the required ``{organism}_...`` labels has no row for this miRNA.
    """
    if organisms is None:
        organisms = ORGANISMS
    rows = df[df.mirna == mirna]

    def first_count(label):
        # Same first-match lookup as before, but with a readable error instead of
        # the opaque IndexError from .iloc[0] on an empty selection.
        matches = rows.target_count[rows.label == label]
        if matches.empty:
            raise KeyError(f'no {label!r} row for miRNA {mirna!r}')
        return matches.iloc[0]

    summary = {}
    for organism in organisms:
        upregulated = first_count(f'{organism}_up')
        downregulated = first_count(f'{organism}_down')
        primary = first_count(f'{organism}_down_targets')
        up_targets = first_count(f'{organism}_up_targets')
        down_secondary = downregulated - primary
        secondary = upregulated + down_secondary

        # Insertion order is part of the contract (see Returns).
        summary[f'{organism}_up'] = upregulated
        summary[f'{organism}_down'] = downregulated
        summary[f'{organism}_primary'] = primary
        summary[f'{organism}_up_targets'] = up_targets
        summary[f'{organism}_down_secondary'] = down_secondary
        summary[f'{organism}_all_secondary'] = secondary

    return summary
    

def plot_sankey(df, mirna, results_dir='../results', verbose=True):
    """Build, save and return the Sankey diagram of one miRNA's gene counts.

    Nodes are the keys of ``mirna_specific_dict(df, mirna)`` in insertion order. For every
    organism in ``ORGANISMS`` four flows are drawn:

    * ``up   -> up_targets``     value ``up_targets``
    * ``up   -> all_secondary``  value ``up - up_targets`` (up-regulated non-targets)
    * ``down -> primary``        value ``primary``
    * ``down -> all_secondary``  value ``down_secondary``

    The ``{organism}_down_secondary`` node receives no flow.
    NOTE(review): the original link comments described the last flow as
    "down -> down_secondary", but the code routes it to ``all_secondary``; the code's
    routing is kept. Confirm which is intended.

    Parameters
    ----------
    df : pandas.DataFrame
        Count table, passed straight to ``mirna_specific_dict``.
    mirna : str
        miRNA to plot; also used as the figure title and in the output file name.
    results_dir : str or pathlib.Path, default '../results'
        Directory that ``{mirna}_sankey.svg`` is written to. It is created if missing.
        Relative paths resolve against the kernel's working directory.
    verbose : bool, default True
        Print ``index label count`` for every node (the original cell's output).

    Returns
    -------
    plotly.graph_objects.Figure
    """
    label2count = mirna_specific_dict(df, mirna)
    labels = list(label2count)
    # Resolve nodes by label rather than by hard-coded positional offsets, so the
    # links stay correct for any number of organisms.
    node_index = {label: i for i, label in enumerate(labels)}

    if verbose:
        for i, (label, count) in enumerate(label2count.items()):
            print(i, label, count)

    sources, targets, values = [], [], []

    def add_link(src, dst, value):
        # Append one flow src -> dst carrying `value` genes.
        sources.append(node_index[src])
        targets.append(node_index[dst])
        values.append(value)

    for organism in ORGANISMS:
        up = f'{organism}_up'
        down = f'{organism}_down'
        primary = f'{organism}_primary'
        up_targets = f'{organism}_up_targets'
        down_secondary = f'{organism}_down_secondary'
        all_secondary = f'{organism}_all_secondary'

        add_link(up, up_targets, label2count[up_targets])
        add_link(up, all_secondary, label2count[up] - label2count[up_targets])
        add_link(down, primary, label2count[primary])
        add_link(down, all_secondary, label2count[down_secondary])

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=50,
            thickness=15,
            line=dict(color="black", width=0.5),
            label=labels,
            align='left',
            color="blue",
        ),
        link=dict(source=sources, target=targets, value=values),
    )])
    fig.update_layout(title_text=mirna, font_size=10, height=600, width=250)

    out_dir = Path(results_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.write_image(str(out_dir / f'{mirna}_sankey.svg'))
    return fig
    
# One Sankey diagram per miRNA in MIRNAS, each displayed right after it is built
# (plot_sankey also writes it to disk).
sankey_figs = {}
for mirna in MIRNAS:
    sankey_figs[mirna] = plot_sankey(sankey, mirna)
    sankey_figs[mirna].show()

# Keep the per-miRNA names the original cell defined, in case later cells use them.
fig197 = sankey_figs.get('mir197')
fig769 = sankey_figs.get('mir769')

0 human_up 85
1 human_down 109
2 human_primary 24
3 human_up_targets 2
4 human_down_secondary 85
5 human_all_secondary 170
6 mouse_up 101
7 mouse_down 115
8 mouse_primary 20
9 mouse_up_targets 5
10 mouse_down_secondary 95
11 mouse_all_secondary 196


0 human_up 25
1 human_down 78
2 human_primary 24
3 human_up_targets 1
4 human_down_secondary 54
5 human_all_secondary 79
6 mouse_up 76
7 mouse_down 152
8 mouse_primary 17
9 mouse_up_targets 1
10 mouse_down_secondary 135
11 mouse_all_secondary 211
