In [1]:
import pandas as pd
import numpy as np
from sklearn.cluster import AgglomerativeClustering

In [2]:
edge_df = pd.read_csv("mashtriangle_edge.tsv",sep="\t", names=["source", "target", "dist", "p-val", "shared-hases"])

idx = sorted(set(edge_df["source"]).union(edge_df["target"]))
dist = (
    edge_df.pivot(index="source", columns="target", values="dist")
    .reindex(index=idx, columns=idx)
    .fillna(0, downcast="infer")
    .pipe(lambda x: x + x.values.T)
)

AC = AgglomerativeClustering(
    n_clusters=None,
    metric="precomputed",
    compute_full_tree=True,
    linkage="single",
    distance_threshold=0.0001,
)
clusters = AC.fit_predict(np.array(dist))
cluster_df = pd.DataFrame()
cluster_df["path"] = list(dist.index)
cluster_df["cluster"] = clusters

  .fillna(0, downcast="infer")


In [3]:
cluster_df['plasmid_bin'] = cluster_df['path'].str.split('/').str[-1]
cluster_df['run_accession'] = cluster_df['plasmid_bin'].str.split('_').str[0]
metadata_df = pd.read_csv("/home/bayraktar/PycharmProjects/reconstruct_plasmids_snakemake/metadata.csv",sep=",")
relevant_df = metadata_df[['run_accession', 'taxon_id', 'scientific_name', 'strain', 'inferred_source','inferred_collection_year', 'inferred_continent', 'inferred_country', 'inferred_city']]
merged = pd.merge(cluster_df, relevant_df, on="run_accession", how='inner')

## Plot

In [4]:
# merged.groupby(['cluster']).size().reset_index(name='count')

In [5]:
counts = merged.groupby(['cluster']).size().reset_index(name='count')
mask = counts['count'] > 10
large_cluster_names = counts[mask]['cluster'].tolist()
large_cluster_names

[1, 3, 4, 5, 28, 38, 65, 114, 127, 144, 145, 207, 276]

In [6]:
merged_cols = merged[['cluster', 'inferred_country']]
large_cluster = merged_cols.loc[merged_cols['cluster'].isin(large_cluster_names)]
grouped = large_cluster.groupby(['cluster','inferred_country']).size().reset_index(name='count')
grouped

Unnamed: 0,cluster,inferred_country,count
0,1,South Africa,93
1,3,South Africa,25
2,4,South Africa,21
3,5,South Africa,11
4,28,Belgium,18
5,38,South Africa,21
6,65,Belgium,1
7,65,South Africa,10
8,114,Belgium,12
9,127,South Africa,18


In [7]:
from collections import defaultdict
dicto = defaultdict(dict)

for country in set(grouped['inferred_country']):
    dicto[country] = dict.fromkeys(grouped['cluster'], 0)

for cluster, country, count in zip(grouped['cluster'], grouped['inferred_country'], grouped['count']):
    dicto[country][cluster] += count

dicto

defaultdict(dict,
            {'Lebanon': {1: 0,
              3: 0,
              4: 0,
              5: 0,
              28: 0,
              38: 0,
              65: 0,
              114: 0,
              127: 0,
              144: 0,
              145: 0,
              207: 1,
              276: 0},
             'South Africa': {1: 93,
              3: 25,
              4: 21,
              5: 11,
              28: 0,
              38: 21,
              65: 10,
              114: 0,
              127: 18,
              144: 14,
              145: 0,
              207: 17,
              276: 14},
             'Belgium': {1: 0,
              3: 0,
              4: 0,
              5: 0,
              28: 18,
              38: 0,
              65: 1,
              114: 12,
              127: 0,
              144: 0,
              145: 18,
              207: 0,
              276: 0}})

In [8]:
for a, b in dicto.items():
    # print(a, list(b.values()))
    dicto[a] = list(b.values())
dicto['cluster'] = large_cluster_names
print(dicto)

defaultdict(<class 'dict'>, {'Lebanon': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 'South Africa': [93, 25, 21, 11, 0, 21, 10, 0, 18, 14, 0, 17, 14], 'Belgium': [0, 0, 0, 0, 18, 0, 1, 12, 0, 0, 18, 0, 0], 'cluster': [1, 3, 4, 5, 28, 38, 65, 114, 127, 144, 145, 207, 276]})


In [9]:
cluster = list(map(str, dicto['cluster']))
countries = list(set(grouped['inferred_country'].tolist()))

print(cluster)
print(countries)

['1', '3', '4', '5', '28', '38', '65', '114', '127', '144', '145', '207', '276']
['Lebanon', 'South Africa', 'Belgium']


In [10]:
test = pd.DataFrame(dicto)
test['cluster'] = test['cluster'].astype(str)
test

Unnamed: 0,Lebanon,South Africa,Belgium,cluster
0,0,93,0,1
1,0,25,0,3
2,0,21,0,4
3,0,11,0,5
4,0,0,18,28
5,0,21,0,38
6,0,10,1,65
7,0,0,12,114
8,0,18,0,127
9,0,14,0,144


In [12]:
from bokeh.palettes import HighContrast3
from bokeh.plotting import figure, show
from bokeh.io import output_notebook

output_notebook()

p = figure(x_range=test.cluster, height=250, tooltips="$name")

p.vbar_stack(countries, x='cluster', width=0.9, color=HighContrast3, source=test,
             legend_label=countries)

p.y_range.start = 0
p.x_range.range_padding = 0.1
p.xgrid.grid_line_color = None
p.axis.minor_tick_line_color = None
p.outline_line_color = None
p.legend.location = "top_left"
p.legend.orientation = "horizontal"

show(p)