In [1]:
import pandas as pd
import numpy as np
from sklearn.cluster import AgglomerativeClustering

In [2]:
edge_df = pd.read_csv("mashtriangle_edge.tsv",sep="\t", names=["source", "target", "dist", "p-val", "shared-hases"])

idx = sorted(set(edge_df["source"]).union(edge_df["target"]))
dist = (
    edge_df.pivot(index="source", columns="target", values="dist")
    .reindex(index=idx, columns=idx)
    .fillna(0, downcast="infer")
    .pipe(lambda x: x + x.values.T)
)

AC = AgglomerativeClustering(
    n_clusters=None,
    metric="precomputed",
    compute_full_tree=True,
    linkage="single",
    distance_threshold=0.0001,
)
clusters = AC.fit_predict(np.array(dist))
cluster_df = pd.DataFrame()
cluster_df["path"] = list(dist.index)
cluster_df["cluster"] = clusters

  .fillna(0, downcast="infer")


In [3]:
cluster_df['plasmid_bin'] = cluster_df['path'].str.split('/').str[-1]
cluster_df['run_accession'] = cluster_df['plasmid_bin'].str.split('_').str[0]
metadata_df = pd.read_csv("/home/bayraktar/PycharmProjects/reconstruct_plasmids_snakemake/metadata.csv",sep=",")
relevant_df = metadata_df[['run_accession', 'taxon_id', 'scientific_name', 'strain', 'inferred_source','inferred_collection_year', 'inferred_continent', 'inferred_country', 'inferred_city']]
merged = pd.merge(cluster_df, relevant_df, on="run_accession", how='inner')

In [4]:
merged

Unnamed: 0,path,cluster,plasmid_bin,run_accession,taxon_id,scientific_name,strain,inferred_source,inferred_collection_year,inferred_continent,inferred_country,inferred_city
0,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,71,ERR10074377_bin_1.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,South Africa,Western Cape
1,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,1597,ERR10074377_bin_2.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,South Africa,Western Cape
2,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,1,ERR10074377_bin_Isolated_1.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,South Africa,Western Cape
3,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,1337,ERR10074377_bin_Unbinned.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,South Africa,Western Cape
4,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,935,ERR10074378_bin_1.fasta,ERR10074378,624,Shigella sonnei,,Stool culture,2013,Africa,South Africa,Gauteng
...,...,...,...,...,...,...,...,...,...,...,...,...
2369,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,416,ERR11597012_bin_4.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,Lebanon,
2370,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,193,ERR11597012_bin_5.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,Lebanon,
2371,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,12,ERR11597012_bin_6.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,Lebanon,
2372,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,596,ERR11597012_bin_7.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,Lebanon,


In [13]:
merged.groupby('cluster')['inferred_collection_year'].unique()

cluster
0                         [2010, 2015]
1       [2015, 2011, 2012, 2013, 2014]
2                         [2018, 2012]
3       [2012, 2014, 2011, 2013, 2015]
4       [2013, 2014, 2011, 2015, 2012]
                     ...              
1719                            [2012]
1720                            [2013]
1721                            [2014]
1722                            [2012]
1723                            [2012]
Name: inferred_collection_year, Length: 1724, dtype: object

## Plot

In [14]:
# merged.groupby(['cluster']).size().reset_index(name='count')

In [15]:
counts = merged.groupby(['cluster']).size().reset_index(name='count')
mask = counts['count'] >= 3
large_cluster_names = counts[mask]['cluster'].tolist()
large_cluster_names

[1,
 3,
 4,
 5,
 7,
 8,
 11,
 12,
 21,
 28,
 30,
 31,
 32,
 33,
 38,
 41,
 43,
 44,
 47,
 50,
 51,
 54,
 65,
 66,
 70,
 71,
 73,
 74,
 77,
 83,
 90,
 94,
 97,
 101,
 102,
 103,
 105,
 106,
 114,
 119,
 124,
 127,
 134,
 137,
 138,
 142,
 144,
 145,
 148,
 149,
 150,
 152,
 171,
 179,
 192,
 207,
 208,
 214,
 216,
 251,
 256,
 276,
 283,
 287,
 302,
 306,
 358,
 368,
 374,
 382,
 414,
 430,
 606,
 613,
 614,
 764]

In [16]:
merged_cols = merged[['cluster', 'inferred_country']]
large_cluster = merged_cols.loc[merged_cols['cluster'].isin(large_cluster_names)]
grouped = large_cluster.groupby(['cluster','inferred_country']).size().reset_index(name='count')
grouped

Unnamed: 0,cluster,inferred_country,count
0,1,South Africa,93
1,3,South Africa,25
2,4,South Africa,21
3,5,South Africa,11
4,7,Belgium,3
...,...,...,...
84,606,Belgium,3
85,606,South Africa,1
86,613,South Africa,5
87,614,Lebanon,6


In [17]:
from collections import defaultdict
dicto = defaultdict(dict)

for country in set(grouped['inferred_country']):
    dicto[country] = dict.fromkeys(grouped['cluster'], 0)

for cluster, country, count in zip(grouped['cluster'], grouped['inferred_country'], grouped['count']):
    dicto[country][cluster] += count

dicto

defaultdict(dict,
            {'Lebanon': {1: 0,
              3: 0,
              4: 0,
              5: 0,
              7: 6,
              8: 0,
              11: 0,
              12: 1,
              21: 0,
              28: 0,
              30: 7,
              31: 0,
              32: 0,
              33: 0,
              38: 0,
              41: 0,
              43: 0,
              44: 0,
              47: 0,
              50: 0,
              51: 0,
              54: 0,
              65: 0,
              66: 0,
              70: 0,
              71: 0,
              73: 3,
              74: 0,
              77: 2,
              83: 0,
              90: 0,
              94: 0,
              97: 0,
              101: 0,
              102: 0,
              103: 0,
              105: 4,
              106: 7,
              114: 0,
              119: 0,
              124: 0,
              127: 0,
              134: 0,
              137: 0,
              138: 6,
              142: 0

In [18]:
for a, b in dicto.items():
    # print(a, list(b.values()))
    dicto[a] = list(b.values())
dicto['cluster'] = large_cluster_names
print(dicto)

defaultdict(<class 'dict'>, {'Lebanon': [0, 0, 0, 0, 6, 0, 0, 1, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 2, 0, 0, 0, 0, 0, 0, 0, 4, 7, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0], 'South Africa': [93, 25, 21, 11, 0, 5, 9, 3, 4, 0, 1, 5, 4, 0, 21, 4, 4, 8, 3, 5, 0, 0, 10, 0, 4, 7, 1, 2, 1, 5, 10, 0, 0, 3, 3, 10, 0, 0, 0, 9, 3, 18, 3, 0, 0, 5, 14, 0, 0, 5, 3, 3, 3, 3, 4, 17, 4, 3, 4, 5, 4, 14, 0, 0, 4, 8, 0, 4, 5, 3, 3, 5, 1, 5, 0, 3], 'Belgium': [0, 0, 0, 0, 3, 0, 0, 1, 0, 18, 1, 0, 0, 6, 0, 0, 0, 0, 0, 0, 5, 6, 1, 3, 0, 0, 0, 2, 1, 0, 0, 3, 3, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 3, 0, 0, 0, 18, 7, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 4, 0, 0, 0, 0, 0, 3, 0, 0, 0], 'cluster': [1, 3, 4, 5, 7, 8, 11, 12, 21, 28, 30, 31, 32, 33, 38, 41, 43, 44, 47, 50, 51, 54, 65, 66, 70, 71, 73, 74, 77, 83, 90, 94, 97, 101, 102, 103, 105, 106, 114, 119, 124, 127, 134, 137, 138, 142, 144, 145, 148, 149, 150, 

In [19]:
cluster = list(map(str, dicto['cluster']))
countries = list(set(grouped['inferred_country'].tolist()))

print(cluster)
print(countries)

['1', '3', '4', '5', '7', '8', '11', '12', '21', '28', '30', '31', '32', '33', '38', '41', '43', '44', '47', '50', '51', '54', '65', '66', '70', '71', '73', '74', '77', '83', '90', '94', '97', '101', '102', '103', '105', '106', '114', '119', '124', '127', '134', '137', '138', '142', '144', '145', '148', '149', '150', '152', '171', '179', '192', '207', '208', '214', '216', '251', '256', '276', '283', '287', '302', '306', '358', '368', '374', '382', '414', '430', '606', '613', '614', '764']
['Lebanon', 'South Africa', 'Belgium']


In [20]:
test = pd.DataFrame(dicto)
test['cluster'] = test['cluster'].astype(str)
test

Unnamed: 0,Lebanon,South Africa,Belgium,cluster
0,0,93,0,1
1,0,25,0,3
2,0,21,0,4
3,0,11,0,5
4,6,0,3,7
...,...,...,...,...
71,0,5,0,430
72,0,1,3,606
73,0,5,0,613
74,6,0,0,614


In [38]:
merged['cluster'] = merged['cluster'].astype(str)
years_per_cluster = merged.groupby('cluster')['inferred_collection_year'].unique()
sample_source = merged.groupby('cluster')['inferred_source'].unique()
sample_city = merged.groupby('cluster')['inferred_city'].unique()
test2 = pd.merge(test, years_per_cluster, on="cluster", how='inner')
test2 = pd.merge(test2, sample_source, on="cluster", how='inner')
test2 = pd.merge(test2, sample_city, on="cluster", how='inner')
test2

Unnamed: 0,Lebanon,South Africa,Belgium,cluster,inferred_collection_year,inferred_source,inferred_city
0,0,93,0,1,"[2015, 2011, 2012, 2013, 2014]",[Stool culture],"[Western Cape, Gauteng, Eastern Cape, Free Sta..."
1,0,25,0,3,"[2012, 2014, 2011, 2013, 2015]",[Stool culture],"[Gauteng, Eastern Cape, Western Cape, Mpumalan..."
2,0,21,0,4,"[2013, 2014, 2011, 2015, 2012]",[Stool culture],"[Gauteng, Free State, Western Cape, KwaZulu-Na..."
3,0,11,0,5,"[2011, 2012, 2014, 2013]",[Stool culture],"[KwaZulu-Natal, Mpumalanga, Gauteng, Western C..."
4,6,0,3,7,"[2018, 2011, 2012, 2009, 2013, 2015]","[feces, stool]",[nan]
...,...,...,...,...,...,...,...
71,0,5,0,430,"[2012, 2011, 2014]",[Stool culture],"[Gauteng, KwaZulu-Natal, Western Cape, Norther..."
72,0,1,3,606,"[2014, 2018]","[Stool culture, feces]","[Gauteng, nan]"
73,0,5,0,613,"[2011, 2013, 2014, 2015]",[Stool culture],"[Gauteng, KwaZulu-Natal, Western Cape]"
74,6,0,0,614,"[2012, 2015]",[stool],[nan]


In [47]:
from bokeh.palettes import HighContrast3
from bokeh.plotting import figure, show
from bokeh.io import output_notebook

output_notebook()

tooltips = [(column, f"@{column}") for column in test2.columns]

p = figure(x_range=test2.cluster, height=250, width=1500, tooltips=tooltips)

p.vbar_stack(countries, x='cluster', width=0.9, color=HighContrast3, source=test2,
             legend_label=countries)

p.y_range.start = 0
p.x_range.range_padding = 0.1
p.xgrid.grid_line_color = None
p.axis.minor_tick_line_color = None
p.outline_line_color = None
p.legend.location = "top_right"
p.legend.orientation = "horizontal"

show(p)