In [None]:
from functools import reduce
import pandas as pd
from bokeh.plotting import figure, show, output_notebook
import panel as pn
import bokeh.models as bm
from bokeh.models.widgets import DataTable, TableColumn
from bokeh.models import HoverTool
from bokeh.io import curdoc
from bokeh.palettes import Colorblind
import itertools

# Directory that holds one Kraken2 table per barcode (barcode01.tsv ... barcode20.tsv);
# the merged result is written back into the same directory.
path = "/students/2024-2025/master/chili/output/kraken2/merged"
dfs = []
rank_list = ["superkingdom", "kingdom", "phylum", "class", "order", "family", "genus", "species"]

for barcode in range(1, 21):
    fragment_col = f"fragments{barcode}"
    ratio_col = f"ratio{barcode}"

    # usecols=[1, 2] keeps only the second and third column of the file; the code
    # below expects them to be named "tax" and "fragments".
    df = pd.read_csv(f"{path}/barcode{str(barcode).zfill(2)}.tsv", sep="\t", index_col=False, usecols=[1, 2])
    df = df.rename(columns={"fragments": fragment_col})

    # Drop every lineage that mentions d__Eukaryota. The explicit copy makes the
    # filtered frame independent, so adding the ratio column below does not raise
    # SettingWithCopyWarning.
    df = df[~df["tax"].str.contains("d__Eukaryota")].copy()

    # The first "|"-separated component of a lineage is its domain.
    domain = df["tax"].str.split("|").str[0]

    # Denominator per domain: the largest fragment count found on the domain's own
    # row (the row whose whole "tax" string is just the domain). Looking the totals
    # up this way works for any domain present in the file, and a domain without a
    # root row yields NaN instead of the ValueError that max() raises on an empty
    # selection.
    domain_totals = df[df["tax"] == domain].groupby("tax")[fragment_col].max()

    # Fraction of the domain's fragments assigned to each lineage, rounded to 10 decimals.
    df[ratio_col] = (df[fragment_col] / domain.map(domain_totals)).round(10)
    dfs.append(df)

# Outer-join all barcodes on the lineage string so lineages missing from a barcode are kept.
merged_df = reduce(lambda df1, df2: pd.merge(df1, df2, on="tax", how="outer"), dfs)
merged_df = merged_df.sort_values("tax")
# Lineages absent from a barcode (and ratios without a domain total) become 0.
merged_df = merged_df.fillna(0)
# Split the lineage positionally into one column per rank; this requires the split
# to produce exactly len(rank_list) columns.
merged_df[rank_list] = merged_df['tax'].str.split('|', expand=True)
merged_df = merged_df.set_index("tax")
# index=False: the "tax" index is not written; the rank columns carry the split lineage.
merged_df.to_csv(f"{path}/merged-all.tsv", sep="\t", index=False)



In [None]:
# The 20 ratio columns are split into four groups by barcode number; each group
# takes every fourth barcode:
#   group 0 -> 1, 5,  9, 13, 17
#   group 1 -> 2, 6, 10, 14, 18
#   group 2 -> 3, 7, 11, 15, 19
#   group 3 -> 4, 8, 12, 16, 20
ratio_cols = [
    [f"ratio{barcode}" for barcode in range(first, 21, 4)]
    for first in range(1, 5)
]

def get_delta(df, cols):
    """Add consecutive-column differences and their row-wise maximum to ``df``.

    For every group of column names in ``cols``, the absolute difference between
    each pair of neighbouring columns is stored as ``delta_<left>_<right>``, and
    the largest of those differences per row is stored as ``max_delta<k>``,
    where ``k`` is the position of the group in ``cols``.

    Args:
        df: DataFrame containing every column named in ``cols``; modified in place.
        cols: List of lists of column names, one inner list per group.

    Returns:
        None. All results are written into ``df``.
    """
    for part, part_cols in enumerate(cols):
        created = []
        # Walk over neighbouring pairs: (c0, c1), (c1, c2), ...
        for left, right in zip(part_cols, part_cols[1:]):
            name = f"delta_{left}_{right}"
            df[name] = (df[left] - df[right]).abs()
            created.append(name)
        df[f"max_delta{part}"] = df[created].max(axis=1)

# Adds the delta_* and max_delta<k> columns to merged_df in place (get_delta returns None).
get_delta(merged_df, ratio_cols)
    

In [None]:
# Largest per-group maximum delta of each row, taken over the four max_delta<k> columns.
part_max_cols = [f"max_delta{part}" for part in range(4)]
merged_df["max_delta_all"] = merged_df[part_max_cols].max(axis=1)

In [None]:
# Display the max_delta_all column sorted from the largest value to the smallest.
merged_df["max_delta_all"].sort_values(ascending=False)


In [None]:
# Turn the flat "|"-separated lineage index into a MultiIndex with one level per
# rank: superkingdom, kingdom, phylum, class, order, family, genus, species.
# The guard keeps this cell re-runnable: once the index is a MultiIndex it has no
# .str accessor, so repeating the split would raise.
if not isinstance(merged_df.index, pd.MultiIndex):
    merged_df.index = pd.MultiIndex.from_tuples(
        merged_df.index.str.split("|").tolist(),
        names=rank_list,
    )
# Written relative to the current working directory, with the index levels included.
merged_df.to_csv("test.tsv", sep="\t")


In [None]:
pn.extension()

# The first drop-down lists the distinct values of index level 0; deeper
# drop-downs are created on demand by update_table.
top_level_options = merged_df.index.get_level_values(0).unique().tolist()
selectors = {
    0: pn.widgets.Select(name="super_kingdom", options=top_level_options),
}

# The table/plot data source starts empty, with one column per index level name.
empty_frame = pd.DataFrame(columns=merged_df.index.names)
source = bm.ColumnDataSource(empty_frame)

plot = figure(
    title="Line Chart",
    x_axis_label="Sample",
    y_axis_label="Micro-organism ratio",
    width=800,
    height=400,
)
hover = HoverTool(tooltips=[("Sample", "$x"), ("Ratio", "$y")])
plot.add_tools(hover)

def update_table(event):
    """Refresh the data table after a selector changed and offer the next level.

    Watcher callback for the ``value`` parameter of every selector widget. The
    table is filtered on the values chosen at every level up to and including
    the one that fired, selectors for deeper levels are discarded (their options
    belonged to the previous choice), and a fresh selector for the next level is
    appended unless the deepest index level has been reached.

    Args:
        event: Change event delivered by the watcher; ``event.obj`` is the
            selector widget whose value changed. If it cannot be matched to a
            known selector, the deepest existing selector is used.
    """
    # Find the level of the selector that fired instead of assuming it is the last one.
    changed = getattr(event, "obj", None)
    level = next(
        (lvl for lvl, widget in selectors.items() if widget is changed),
        max(selectors),
    )

    # Selections below the changed level are stale: remove them from the
    # bookkeeping dict and from the row of widgets.
    for stale in [lvl for lvl in selectors if lvl > level]:
        del selectors[stale]
    selectors_row.objects = [selectors[lvl] for lvl in sorted(selectors)]

    # A row must match the chosen value at every level from 0 up to the changed one.
    level_matches = [
        merged_df.index.get_level_values(lvl) == selectors[lvl].value
        for lvl in range(level + 1)
    ]
    filtered_df = merged_df[reduce(lambda left, right: left & right, level_matches)]
    source.data = bm.ColumnDataSource.from_df(filtered_df)
    # Row positions from the previous table no longer mean anything; clearing the
    # selection also lets update_plot wipe the chart.
    source.selected.indices = []

    # Show only the ranks below the selected depth, then the delta summary and all ratios.
    depth = level + 1
    flat_ratios = rank_list[depth:] + ["max_delta_all"] + list(itertools.chain.from_iterable(ratio_cols))
    data_table.columns = [TableColumn(field=col, title=col) for col in flat_ratios]

    if depth < merged_df.index.nlevels:
        new_selector = pn.widgets.Select(
            name=filtered_df.index.names[depth],
            options=filtered_df.index.get_level_values(depth).unique().tolist()
        )
        selectors[depth] = new_selector
        new_selector.param.watch(update_table, 'value')
        selectors_row.append(new_selector)

def update_plot(attr, old, new):
    """Redraw the line chart for the row currently selected in the data table.

    Change callback for ``source.selected.indices``. One line is drawn per group
    in ``ratio_cols``, using the selected row's values for that group's columns.

    Args:
        attr: Name of the changed property (unused).
        old: Previously selected row indices (unused).
        new: Currently selected row indices; only the first one is plotted.
    """
    # Remove the previous lines together with their legend entries. Clearing only
    # the renderers would leave legend items pointing at renderers that are no
    # longer part of the plot, and they would pile up with every selection.
    plot.renderers = []
    for legend in plot.legend:
        legend.items = []
    if not new:
        return

    selected_row = source.to_df().iloc[new[0]]

    for idx, cols in enumerate(ratio_cols):
        plot.line(
            # One x position per column in the group, numbered from 1.
            list(range(1, len(cols) + 1)),
            # Plain floats: the row slice is not guaranteed to have a numeric dtype.
            [float(value) for value in selected_row[cols]],
            line_width=2,
            legend_label=str(idx),
            color=Colorblind[5][idx]
        )

# The table starts with one column per index level name; update_table replaces
# the column list whenever a selector changes.
columns = [TableColumn(field=level_name, title=level_name) for level_name in merged_df.index.names]
data_table = DataTable(source=source, columns=columns, width=1880)

# Page layout: selectors on top, then the table, then the line chart.
selectors_row = pn.Row(*selectors.values())
layout = pn.Column(selectors_row, data_table, plot)

# Selecting a table row redraws the chart; changing a selector refreshes the table.
source.selected.on_change("indices", update_plot)
for selector in selectors.values():
    selector.param.watch(update_table, 'value')

served = pn.serve(layout)


In [None]:
# Scratch check: current value of the selector stored under key 0.
selectors.get(0).value

In [None]:
# Scratch check: display the index of merged_df.
merged_df.index

In [None]:
# Scratch check: display the names of the index levels of merged_df.
merged_df.index.names

In [None]:
# Scratch check: display the Colorblind palette object imported from bokeh.palettes.
Colorblind

In [None]:
# Scratch check: display the name of the first index level of merged_df.
merged_df.index.names[0]