In [None]:
import pandas as pd
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from workflow.wrangling_funcs import clean_string

In [None]:
# Build a symmetric pairwise distance matrix from the edge list and cut a
# single-linkage hierarchy at a fixed distance threshold.
EDGE_LIST_PATH = "mashtriangle_edge.tsv"
# Maximum linkage distance at which two clusters are still merged.
DISTANCE_THRESHOLD = 0.0001

# The file has no header row; column labels are supplied here
# (the last label's spelling is kept exactly as it was).
edge_df = pd.read_csv(
    EDGE_LIST_PATH,
    sep="\t",
    names=["source", "target", "dist", "p-val", "shared-hases"],
)

# Every identifier that occurs on either side of an edge, in sorted order,
# so rows and columns of the matrix share the same labels.
nodes = sorted(set(edge_df["source"]).union(edge_df["target"]))

# NOTE(review): adding the transpose assumes each pair is listed in one
# direction only; a pair present in both directions would be doubled - confirm.
# NOTE(review): pairs absent from the edge list end up with distance 0 and are
# therefore treated as identical by the clustering below - confirm the edge
# list is complete (not distance-filtered).
dist = (
    edge_df.pivot(index="source", columns="target", values="dist")
    .reindex(index=nodes, columns=nodes)
    # `downcast=` is deprecated in recent pandas; the matrix is float either way.
    .fillna(0)
    .pipe(lambda x: x + x.values.T)
)

AC = AgglomerativeClustering(
    n_clusters=None,  # let the distance threshold decide the number of clusters
    metric="precomputed",
    compute_full_tree=True,
    linkage="single",
    distance_threshold=DISTANCE_THRESHOLD,
)
clusters = AC.fit_predict(dist.to_numpy())

# One row per matrix label with its assigned cluster id.
cluster_df = pd.DataFrame({"path": list(dist.index), "cluster": clusters})

In [None]:
# Attach sample metadata to every clustered entry via its run accession.
# NOTE(review): machine-specific absolute path - consider making it configurable.
METADATA_CSV = "/home/bayraktar/PycharmProjects/reconstruct_plasmids_snakemake/metadata.csv"
METADATA_COLUMNS = [
    'run_accession',
    'taxon_id',
    'scientific_name',
    'strain',
    'inferred_source',
    'inferred_collection_year',
    'inferred_continent',
    'inferred_country',
    'inferred_city',
]

# Last path component, then the part of it before the first underscore.
cluster_df['plasmid_bin'] = cluster_df['path'].str.rsplit('/', n=1).str[-1]
cluster_df['run_accession'] = cluster_df['plasmid_bin'].str.split('_', n=1).str[0]

metadata_df = pd.read_csv(METADATA_CSV, sep=",")
metadata_df['inferred_country'] = metadata_df['inferred_country'].apply(clean_string)
relevant_df = metadata_df[METADATA_COLUMNS]

# Inner join: entries without a matching metadata row are dropped.
merged = cluster_df.merge(relevant_df, on="run_accession", how='inner')

In [None]:
# Show the joined cluster/metadata table as this cell's output.
merged

## Plot: country composition of clusters with at least two members

In [None]:
# merged.groupby(['cluster']).size().reset_index(name='count')

In [None]:
# Clusters need at least this many rows in `merged` to be kept.
MIN_CLUSTER_SIZE = 2

# Number of rows per cluster, as a two-column frame (cluster, count).
counts = (
    merged.groupby('cluster')
    .size()
    .rename('count')
    .reset_index()
)

# Ids of the clusters that reach the minimum size, in ascending cluster order.
large_cluster_names = counts.loc[counts['count'].ge(MIN_CLUSTER_SIZE), 'cluster'].tolist()

In [None]:
# Keep only rows of the large clusters, reduced to the two columns needed here.
in_large_cluster = merged['cluster'].isin(large_cluster_names)
large_cluster = merged.loc[in_large_cluster, ['cluster', 'inferred_country']]

# Row count for every (cluster, country) combination that occurs.
grouped = (
    large_cluster
    .groupby(['cluster', 'inferred_country'])
    .size()
    .rename('count')
    .reset_index()
)
grouped

In [None]:
# Nested count table: country -> {cluster -> number of rows}.
# Clusters are listed in their order of first appearance in `grouped`;
# countries are sorted so the key order is the same on every kernel start
# (iterating a bare set of strings is not reproducible between processes).
cluster_order = grouped['cluster'].unique().tolist()
country_order = sorted(set(grouped['inferred_country']), key=str)

# Start every (country, cluster) cell at zero so absent combinations stay 0.
dicto = {country: dict.fromkeys(cluster_order, 0) for country in country_order}

for cluster, country, count in zip(grouped['cluster'], grouped['inferred_country'], grouped['count']):
    dicto[country][cluster] += count

dicto

In [None]:
# Flatten the nested table into column form: country -> [count per cluster],
# plus a 'cluster' column naming the cluster behind each position.
# The cluster labels are taken from `grouped` (the same source the counts were
# built from), so every column is guaranteed to have the same length even when
# a large cluster has no usable country and is therefore missing from `grouped`.
cluster_order = grouped['cluster'].unique().tolist()

dicto = {
    country: (
        [per_cluster.get(cluster_id, 0) for cluster_id in cluster_order]
        if isinstance(per_cluster, dict)
        else list(per_cluster)  # already flattened: the cell was run before
    )
    for country, per_cluster in dicto.items()
    if country != 'cluster'  # re-added below; keeps the cell re-runnable
}
dicto['cluster'] = cluster_order
print(dicto)

In [None]:
# Cluster labels as strings (categorical axis labels must be strings).
cluster = [str(name) for name in dicto['cluster']]

# Distinct countries in a fixed (sorted) order: the position in this list
# decides each country's colour and legend slot in the plot, and the iteration
# order of a bare set of strings changes between kernel sessions.
countries = sorted(set(grouped['inferred_country']), key=str)

print(cluster)
print(countries)

In [None]:
# Wide count table: one row per cluster, one column per country,
# with the cluster id stored as a string.
test = pd.DataFrame(dicto).astype({'cluster': str})
test

In [None]:
# Columns whose distinct per-cluster values are attached to the count table.
DETAIL_COLUMNS = [
    'inferred_collection_year',
    'inferred_source',
    'inferred_city',
    'plasmid_bin',
]

# Group on a string-typed copy of the cluster id so it matches `test['cluster']`.
# `merged` itself is left untouched, so re-running earlier cells after this one
# still sees the original cluster dtype.
rows_by_cluster = merged.assign(cluster=merged['cluster'].astype(str)).groupby('cluster')

# Join the array of unique values of each detail column onto the count table;
# the inner join keeps only clusters present on both sides.
test2 = test
for column in DETAIL_COLUMNS:
    test2 = pd.merge(test2, rows_by_cluster[column].unique(), on="cluster", how='inner')
test2

In [None]:
from bokeh.palettes import Category20
from bokeh.plotting import figure, show
from bokeh.io import output_notebook

output_notebook()

# Category20 only provides palettes for 3 to 20 colours. Pick the nearest
# available size and cycle through it, so fewer than 3 or more than 20
# countries no longer raise a KeyError (colours repeat beyond 20).
n_countries = len(countries)
base_palette = Category20[min(max(n_countries, 3), 20)]
palette = [base_palette[i % len(base_palette)] for i in range(n_countries)]

# One tooltip row per column of the table. The field name is wrapped in
# braces (`@{name}`) so that column names containing spaces resolve too.
tooltips = [(column, f"@{{{column}}}\n") for column in test2.columns]

p = figure(x_range=list(test2['cluster']), height=1000, width=1500, tooltips=tooltips)

# Stack the per-country counts of each cluster on top of each other. Separate
# vbar calls would all start at 0 on the same x position, so taller bars
# would hide the shorter ones.
p.vbar_stack(
    countries,
    x='cluster',
    width=0.9,
    source=test2,
    color=palette,
    legend_label=countries,
)

p.y_range.start = 0
p.x_range.range_padding = 0.1
p.xgrid.grid_line_color = None
p.axis.minor_tick_line_color = None
p.outline_line_color = None
p.legend.location = "top_right"
p.legend.orientation = "horizontal"

show(p)
