In [13]:
import os
import numpy as np
import math
from itertools import chain
import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from matplotlib import cm
from flask import Flask
from scipy.interpolate import interp1d
from dash import Dash, html, dcc, Input, Output
from cit_vis import *

In [14]:
G_DF = pd.read_csv('data/genus_abundances.csv', index_col=0)
B_DF = pd.read_csv('data/assay_scores.csv', index_col=0)

In [18]:
# Initialize an empty dictionary to store the DataFrames
dfs = {}

# Iterate through the directory and its subdirectories
for root, dirs, files in os.walk('data/', topdown=False):
    for name in files:
        # Check if the file meets the specified conditions
        if ('Assay_' in name) & ('.csv' in name):
            # Extract the assay name from the file name
            assay = name[:-4]
            # Create the full path to the file
            path = os.path.join(root, name)
            # Load the DataFrame
            df = pd.read_csv(path, index_col=0)
            # Store the DataFrame in the dictionary
            dfs[assay] = df

# Get a list of all the assay names
assays = list(dfs.keys())
print(assays)

# Select the first assay from the list
b = assays[0]
#b = assays[14]

['Assay_10_predictions', 'Assay_11_predictions', 'Assay_12_predictions', 'Assay_13_predictions', 'Assay_14_predictions', 'Assay_15_predictions', 'Assay_16_predictions', 'Assay_17_predictions', 'Assay_18_predictions', 'Assay_19_predictions', 'Assay_1_predictions', 'Assay_20_predictions', 'Assay_21_predictions', 'Assay_22_predictions', 'Assay_23_predictions', 'Assay_24_predictions', 'Assay_25_predictions', 'Assay_26_predictions', 'Assay_27_predictions', 'Assay_28_predictions', 'Assay_29_predictions', 'Assay_2_predictions', 'Assay_30_predictions', 'Assay_3_predictions', 'Assay_4_predictions', 'Assay_5_predictions', 'Assay_6_predictions', 'Assay_7_predictions', 'Assay_8_predictions', 'Assay_9_predictions']


In [16]:
color_range = [-3, 3]
color_domain = [0, 1]
color_map = 'RdYlBu_r'
def get_color(v):
    # Get a colormap (RdYlBu) from matplotlib
    my_cmap = cm.get_cmap(color_map)

    # Define a linear interpolation function that maps a value from the range [-3, 3] to the range [0, 1]
    m = interp1d(color_range, color_domain)

    # Get the RGBA values from the colormap based on the interpolated value of v
    r, g, b, a = my_cmap(m(v))

    # Return the RGBA color value as a formatted string
    return f'rgba({r}, {g}, {b}, {a})'

In [5]:
def get_relevant(genus='g__Turicibacter'):
    focus_bs = []
    for assay, df in split_dfs.items():
        if genus in df['vnames'].values:
            focus_bs.append(assay)
    return focus_bs

In [6]:
def make_bubbles(leaf_df, assay):
    fig = go.Figure(go.Scatter(
        #x=leaf_df['level'],
        x=leaf_df.index,#['node'],
        y=leaf_df.index,#['node'],
        marker_size=[math.sqrt(n)*10 for n in leaf_df['nobs']],
        marker_color=[get_color(m) for m in leaf_df['y_mean']],
        marker_opacity=1,
        marker_line_color='black',
        mode = 'markers'
    ))
    fig.update_layout(template='plotly_white', title=assay)
    return fig

In [7]:
def divide_into_bins(slice, g, total_bins, split_value):
    min_value = min(slice[g])
    max_value = max(slice[g])
    range_value = max_value - min_value
    bin_size = range_value / total_bins
    if split_value >= 0:
        above_bins = int((range_value - split_value) / bin_size)
        below_bins = total_bins - above_bins
    else:
        below_bins = int(abs(split_value) / bin_size)
        above_bins = total_bins - below_bins
    return below_bins, above_bins

In [8]:
split_dfs[assays[0]]

Unnamed: 0,node,vnames,split,split_value,y_mean,error,nobs,pvalue,rules,isLeaf
0,1,root,,,-0.045163,0.99474,378,0.0058,,False
1,2,g__unclassified_Porphyromonadaceae,<=,-1.57138,0.855799,1.090892,22,0.0157,g__unclassified_Porphyromonadaceae <= -1.57137...,False
2,3,g__unclassified_Firmicutes,<=,-1.63094,1.690039,0.561726,9,,g__unclassified_Porphyromonadaceae <= -1.57137...,True
3,4,g__unclassified_Firmicutes,>,-1.63094,0.278247,0.994431,13,,g__unclassified_Porphyromonadaceae <= -1.57137...,True
4,5,g__unclassified_Porphyromonadaceae,>,-1.57138,-0.10084,0.962819,356,0.0059,g__unclassified_Porphyromonadaceae > -1.571377...,False
5,6,g__Syntrophococcus,<=,0.27801,-0.24517,0.944481,218,,g__unclassified_Porphyromonadaceae > -1.571377...,True
6,7,g__Syntrophococcus,>,0.27801,0.127159,0.950619,138,0.0326,g__unclassified_Porphyromonadaceae > -1.571377...,False
7,8,g__Blautia,<=,-0.02646,-0.138688,0.999037,56,,g__unclassified_Porphyromonadaceae > -1.571377...,True
8,9,g__Blautia,>,-0.02646,0.308713,0.876643,82,0.0102,g__unclassified_Porphyromonadaceae > -1.571377...,False
9,10,g__Anaerotruncus,<=,0.77112,0.501029,0.888555,57,0.0338,g__unclassified_Porphyromonadaceae > -1.571377...,False


In [9]:
G_DF

Unnamed: 0,Mouse ID,g__Syntrophococcus,g__unclassified_Bacteroidales,g__Turicibacter,g__unclassified_Porphyromonadaceae,g__Roseburia,g__Fusicatenibacter,g__Acetatifactor,g__Barnesiella,g__Clostridium_XlVb,...,g__Parvibacter,g__Dorea,g__Escherichia_Shigella,g__unclassified_Peptococcaceae_1,g__Anaerofustis,g__unclassified_Bacteria,g__Enterococcus,g__Desulfonispora,g__Alistipes,g__Staphylococcus
0,7578,-0.877994,0.507025,-0.169459,-0.209865,-1.981873,-0.744681,0.169459,-0.009921,-0.607475,...,-0.647744,-0.404268,-0.309086,-0.122662,-0.735982,-0.552691,-0.026458,-0.006614,-0.003307,-0.092726
1,7579,0.347478,0.537345,1.287589,2.306044,-1.630943,0.587708,0.507025,2.557126,2.148649,...,0.503268,0.937921,-0.309086,-0.122662,-0.735982,-0.552691,-0.026458,-0.006614,-0.003307,1.594527
2,7580,0.503268,0.714474,-0.169459,0.433147,1.142375,0.162754,0.473462,1.265186,0.365096,...,-0.647744,-0.404268,-0.309086,-0.122662,-0.735982,-0.552691,-0.026458,-0.006614,-0.003307,-0.092726
3,7581,-1.981873,-0.298697,-0.169459,-0.560413,0.689092,-0.807245,-1.350733,0.076132,0.948239,...,-0.647744,-0.404268,0.882866,-0.122662,-0.735982,1.318504,-0.026458,-0.006614,-0.003307,1.915655
4,7582,-0.595587,-1.017939,-0.169459,0.076132,1.410774,-0.727339,0.811835,-0.049624,-0.069501,...,0.209865,-0.404268,-0.309086,-0.122662,-0.735982,-0.552691,-0.026458,-0.006614,-0.003307,-0.092726
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
373,8680,0.250616,0.631514,1.287589,-0.697503,-0.350993,-0.382828,0.664147,-0.243797,0.333461,...,-0.647744,-0.404268,-0.309086,-0.122662,1.958837,0.877994,-0.026458,-0.006614,-0.003307,-0.092726
374,8681,0.023150,-0.969183,-0.169459,-1.001422,-0.518339,1.129765,-0.518339,-1.310657,0.365096,...,-0.647744,-0.404268,-0.309086,-0.122662,-0.735982,-0.552691,-0.026458,-0.006614,-0.003307,-0.092726
375,8682,-1.895339,-0.389955,-0.169459,-1.034737,-0.868311,1.643576,1.643576,-1.136048,-0.753437,...,-0.647744,-0.404268,1.771262,-0.122662,0.397101,1.075118,-0.026458,-0.006614,-0.003307,1.594527
376,8683,-0.631514,-2.655650,-0.169459,-0.958658,-1.630943,0.418664,0.109345,-1.287589,0.102694,...,0.503268,-0.404268,0.882866,-0.122662,-0.735982,1.161638,-0.026458,-0.006614,-0.003307,-0.092726


In [19]:
def get_subset(row, current_bug):
    rules = row[1]['rules']
    subset = G_DF.copy()
    if rules:
        rules = rules.split(' & ')
        for rule in rules:
            bug, sign, val = rule.split(' ')
            if current_bug != bug:
                if sign == '<=':
                    subset = subset[subset[bug] <= float(val)]
                elif sign == '>':
                    subset = subset[subset[bug] > float(val)]
    return subset

In [20]:
def make_charts(assay):
    try:
        os.mkdir(f'output')
    except:
        pass
    try:
        os.mkdir(f'output/{assay}')
    except:
        pass
    assay_lookup = dict(zip(B_DF['Mouse ID'], B_DF[assay]))
    split_df = split_dfs[assay]
    split_df['level'] = [(len(str(s).split(' '))) for s in split_df['rules']]
    split_df['level'] = split_df['level'].rank(method = 'dense')
    leaf_df = split_df[split_df['isLeaf']==True]#.reset_index(drop=True)
    bubbles = make_bubbles(leaf_df, assay)
    bubbles.write_image(f'output/{assay}/bubbles.svg')
    seen_bugs = {'root':0}
    levels = { l : [] for l in set(split_df['level'].values) }
    num_cols = split_df['isLeaf'].sum()
    num_rows = split_df['level'].max()

    for row in split_df.iterrows():
        g = row[1]['vnames']
        v = row[1]['split_value']
        compute = False
        if g not in seen_bugs.keys():
            compute = True
        if g in seen_bugs.keys():
            if g != 'root':
                if seen_bugs[g] != v:
                    compute = True
        if compute == True:
            #slice = G_DF.copy().filter([g, 'Mouse ID'])
            slice = get_subset(row, g)
            slice[assay] = [assay_lookup[m] for m in slice['Mouse ID']]
            slice['group'] = slice[g].apply(lambda x: 'below' if x < v else 'above')
            above_group = slice[slice['group'] == 'above']
            below_group = slice[slice['group'] == 'below']
            above_mean = above_group[assay].mean()
            #fig.add_vline(x=0, line_color='lightgray', line_width=2)
            below_mean = below_group[assay].mean()
            fig = px.histogram(slice, x=g, color='group', 
                            #histnorm='probability density', 
                            color_discrete_map={'above': get_color(above_mean), 'below': get_color(below_mean)})
            fig.update_traces(marker_line_width=1, marker_line_color="black")
            fig.update_layout(template='plotly_white', width=300, height=200, showlegend=False, bargap=0, barmode='stack')
            fig.update_xaxes(title=g.replace('g__','').replace('_','<br>'), zeroline=True)
            fig.update_yaxes(title='')#title='Probability<br>Density', visible=False)
            fig.add_vline(x=v, line_color='red', line_dash="dot")
            fig.add_annotation(x=v+0.5, y=1, yref='paper', text=f"{v:.2f}", font_color='red',showarrow=False,)
            # fig.show()
            seen_bugs[g]=v
            this_level = row[1]['level']
            this_col = len(levels[this_level])
            #fig.update_layout(title=f'Row: {this_level}, Col: {this_col}')
            fig.update_layout(title=f"Node: {row[1]['node']-1}, Nobs: {row[1]['nobs']}, P-val: {row[1]['pvalue']}, Split: {v}")
            levels[this_level].insert(this_col, dcc.Graph(figure=fig, responsive=True, style={'display': 'inline-block', 'background-color': '#000'}))
            fig.update_layout(autosize=False,
                              width=400, height=300,
                              paper_bgcolor='rgba(0,0,0,0)',
                              plot_bgcolor='rgba(0,0,0,0)'
                              )
            fig.write_image(f"output/{assay}/{row[1]['node']}_{g}.svg")
    return levels


In [21]:
def make_webpage(levels):
    content = []
    for level, figs in levels.items():
        figs = [f for f in figs if f is not None]
        content.append(html.Div(figs, style={'margin': '0 auto', 'background-color': '#666'}))
    server = Flask(__name__)
    app = Dash(__name__, server=server)
    assays.sort()
    app.layout = html.Div(content, style={'margin': '0 auto', 'background-color': '#333'})
    if __name__ == '__main__':
        app.run_server()

In [22]:
bug = 'g__Turicibacter'
relevant = get_relevant(genus=bug)

In [26]:
%%capture
for b in relevant:
    make_charts(b)

In [27]:
%%capture
make_charts('OFA_total_distance_traveled_batch_ranknorm')