In [1]:
import pandas as pd
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from workflow.wrangling_funcs import clean_string

In [2]:
edge_df = pd.read_csv("mashtriangle_edge.tsv",sep="\t", names=["source", "target", "dist", "p-val", "shared-hases"])

idx = sorted(set(edge_df["source"]).union(edge_df["target"]))
dist = (
    edge_df.pivot(index="source", columns="target", values="dist")
    .reindex(index=idx, columns=idx)
    .fillna(0, downcast="infer")
    .pipe(lambda x: x + x.values.T)
)

AC = AgglomerativeClustering(
    n_clusters=None,
    metric="precomputed",
    compute_full_tree=True,
    linkage="single",
    distance_threshold=0.0001,
)
clusters = AC.fit_predict(np.array(dist))
cluster_df = pd.DataFrame()
cluster_df["path"] = list(dist.index)
cluster_df["cluster"] = clusters

  .fillna(0, downcast="infer")


In [3]:
cluster_df['plasmid_bin'] = cluster_df['path'].str.split('/').str[-1]
cluster_df['run_accession'] = cluster_df['plasmid_bin'].str.split('_').str[0]
metadata_df = pd.read_csv("/home/bayraktar/PycharmProjects/reconstruct_plasmids_snakemake/metadata.csv",sep=",")
metadata_df['inferred_country'] = metadata_df['inferred_country'].apply(clean_string)
relevant_df = metadata_df[['run_accession', 'taxon_id', 'scientific_name', 'strain', 'inferred_source','inferred_collection_year', 'inferred_continent', 'inferred_country', 'inferred_city']]
merged = pd.merge(cluster_df, relevant_df, on="run_accession", how='inner')

In [4]:
merged

Unnamed: 0,path,cluster,plasmid_bin,run_accession,taxon_id,scientific_name,strain,inferred_source,inferred_collection_year,inferred_continent,inferred_country,inferred_city
0,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,71,ERR10074377_bin_1.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,south_africa,Western Cape
1,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,1597,ERR10074377_bin_2.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,south_africa,Western Cape
2,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,1,ERR10074377_bin_Isolated_1.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,south_africa,Western Cape
3,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,1337,ERR10074377_bin_Unbinned.fasta,ERR10074377,624,Shigella sonnei,,Stool culture,2015,Africa,south_africa,Western Cape
4,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,935,ERR10074378_bin_1.fasta,ERR10074378,624,Shigella sonnei,,Stool culture,2013,Africa,south_africa,Gauteng
...,...,...,...,...,...,...,...,...,...,...,...,...
2369,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,416,ERR11597012_bin_4.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,lebanon,
2370,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,193,ERR11597012_bin_5.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,lebanon,
2371,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,12,ERR11597012_bin_6.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,lebanon,
2372,/hpc/dla_mm/dbayraktar/data/29_03_2024_run_Shi...,596,ERR11597012_bin_7.fasta,ERR11597012,624,Shigella sonnei,Cl-059,stool,2016,Asia,lebanon,


In [5]:
merged.groupby('cluster')['inferred_collection_year'].unique()

cluster
0                         [2010, 2015]
1       [2015, 2011, 2012, 2013, 2014]
2                         [2018, 2012]
3       [2012, 2014, 2011, 2013, 2015]
4       [2013, 2014, 2011, 2015, 2012]
                     ...              
1719                            [2012]
1720                            [2013]
1721                            [2014]
1722                            [2012]
1723                            [2012]
Name: inferred_collection_year, Length: 1724, dtype: object

## Plot

In [6]:
# merged.groupby(['cluster']).size().reset_index(name='count')

In [7]:
counts = merged.groupby(['cluster']).size().reset_index(name='count')
mask = counts['count'] >= 2
large_cluster_names = counts[mask]['cluster'].tolist()
large_cluster_names

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 59,
 60,
 61,
 62,
 63,
 65,
 66,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 110,
 114,
 119,
 123,
 124,
 125,
 126,
 127,
 128,
 133,
 134,
 137,
 138,
 141,
 142,
 143,
 144,
 145,
 146,
 148,
 149,
 150,
 151,
 152,
 155,
 156,
 158,
 165,
 169,
 171,
 172,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 186,
 187,
 188,
 189,
 190,
 192,
 197,
 199,
 202,
 204,
 206,
 207,
 208,
 209,
 210,
 212,
 214,
 216,
 222,
 240,
 248,
 251,
 256,
 270,
 275,
 276,
 283,
 285,
 287,
 288,
 292,
 297,
 301,
 302,
 304,
 306,
 344,
 358,
 362,
 364,
 368,
 374,
 376,
 

In [8]:
merged_cols = merged[['cluster', 'inferred_country']]
large_cluster = merged_cols.loc[merged_cols['cluster'].isin(large_cluster_names)]
grouped = large_cluster.groupby(['cluster','inferred_country']).size().reset_index(name='count')
grouped

Unnamed: 0,cluster,inferred_country,count
0,0,lebanon,2
1,1,south_africa,93
2,2,belgium,1
3,2,lebanon,1
4,3,south_africa,25
...,...,...,...
220,610,south_africa,2
221,613,south_africa,5
222,614,lebanon,6
223,729,south_africa,2


In [9]:
from collections import defaultdict
dicto = defaultdict(dict)

for country in set(grouped['inferred_country']):
    dicto[country] = dict.fromkeys(grouped['cluster'], 0)

for cluster, country, count in zip(grouped['cluster'], grouped['inferred_country'], grouped['count']):
    dicto[country][cluster] += count

dicto

defaultdict(dict,
            {'belgium': {0: 0,
              1: 0,
              2: 1,
              3: 0,
              4: 0,
              5: 0,
              6: 0,
              7: 3,
              8: 0,
              9: 0,
              10: 2,
              11: 0,
              12: 1,
              13: 0,
              14: 2,
              15: 0,
              16: 0,
              17: 0,
              18: 0,
              19: 0,
              20: 0,
              21: 0,
              22: 0,
              23: 0,
              24: 0,
              25: 1,
              26: 0,
              27: 2,
              28: 18,
              29: 2,
              30: 1,
              31: 0,
              32: 0,
              33: 6,
              34: 0,
              35: 0,
              36: 2,
              37: 1,
              38: 0,
              39: 0,
              40: 0,
              41: 0,
              42: 0,
              43: 0,
              44: 0,
              45: 0,
              

In [10]:
for a, b in dicto.items():
    # print(a, list(b.values()))
    dicto[a] = list(b.values())
dicto['cluster'] = large_cluster_names
print(dicto)

defaultdict(<class 'dict'>, {'belgium': [0, 0, 1, 0, 0, 0, 0, 3, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 18, 2, 1, 0, 0, 6, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 5, 2, 0, 6, 2, 0, 2, 0, 0, 2, 0, 1, 3, 0, 0, 0, 0, 0, 0, 2, 0, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0, 2, 0, 1, 0, 2, 0, 0, 3, 1, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 12, 0, 0, 0, 2, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, 0, 18, 2, 7, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 1, 0, 0, 2, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0, 0, 0, 4, 0, 0, 2, 0, 0, 2, 0, 0, 4, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0], 'south_africa': [0, 93, 0, 25, 21, 11, 2, 0, 5, 2, 0, 9, 3, 2, 0, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 0, 2, 0, 0, 0, 1, 5, 4, 0, 2, 2, 0, 1, 21, 0, 2, 4, 2, 4, 8, 0, 0, 3, 0, 2, 5, 0, 0, 2, 0, 0, 2, 0, 2, 2, 0, 2, 10, 0, 2, 2, 4, 7, 0, 1, 2, 2, 0, 1, 2, 0, 2, 0, 2, 5, 0, 2, 0, 2, 1, 10, 0, 1, 0, 0, 1, 0, 0, 2, 2, 2, 3, 3, 10, 0, 0, 0, 2, 1, 0, 9, 2, 3, 0, 2, 

In [11]:
cluster = list(map(str, dicto['cluster']))
countries = list(set(grouped['inferred_country'].tolist()))

print(cluster)
print(countries)

['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '59', '60', '61', '62', '63', '65', '66', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '110', '114', '119', '123', '124', '125', '126', '127', '128', '133', '134', '137', '138', '141', '142', '143', '144', '145', '146', '148', '149', '150', '151', '152', '155', '156', '158', '165', '169', '171', '172', '177', '178', '179', '180', '181', '182', '183', '186', '187', '188', '189', '190', '192', '197', '199', '202', '204', '206', '207', '208', '209', '210', '212', '214'

In [12]:
test = pd.DataFrame(dicto)
test['cluster'] = test['cluster'].astype(str)
test

Unnamed: 0,belgium,south_africa,colombia,lebanon,cluster
0,0,0,0,2,0
1,0,93,0,0,1
2,1,0,0,1,2
3,0,25,0,0,3
4,0,21,0,0,4
...,...,...,...,...,...
195,0,2,0,0,610
196,0,5,0,0,613
197,0,0,0,6,614
198,0,2,0,0,729


In [13]:
merged['cluster'] = merged['cluster'].astype(str)
years_per_cluster = merged.groupby('cluster')['inferred_collection_year'].unique()
sample_source = merged.groupby('cluster')['inferred_source'].unique()
sample_city = merged.groupby('cluster')['inferred_city'].unique()
test2 = pd.merge(test, years_per_cluster, on="cluster", how='inner')
test2 = pd.merge(test2, sample_source, on="cluster", how='inner')
test2 = pd.merge(test2, sample_city, on="cluster", how='inner')
test2

Unnamed: 0,belgium,south_africa,colombia,lebanon,cluster,inferred_collection_year,inferred_source,inferred_city
0,0,0,0,2,0,"[2010, 2015]",[stool],[nan]
1,0,93,0,0,1,"[2015, 2011, 2012, 2013, 2014]",[Stool culture],"[Western Cape, Gauteng, Eastern Cape, Free Sta..."
2,1,0,0,1,2,"[2018, 2012]","[feces, stool]",[nan]
3,0,25,0,0,3,"[2012, 2014, 2011, 2013, 2015]",[Stool culture],"[Gauteng, Eastern Cape, Western Cape, Mpumalan..."
4,0,21,0,0,4,"[2013, 2014, 2011, 2015, 2012]",[Stool culture],"[Gauteng, Free State, Western Cape, KwaZulu-Na..."
...,...,...,...,...,...,...,...,...
195,0,2,0,0,610,"[2011, 2012]","[Blood culture, Stool culture]",[Gauteng]
196,0,5,0,0,613,"[2011, 2013, 2014, 2015]",[Stool culture],"[Gauteng, KwaZulu-Natal, Western Cape]"
197,0,0,0,6,614,"[2012, 2015]",[stool],[nan]
198,0,2,0,0,729,"[2012, 2014]",[Stool culture],"[Gauteng, Western Cape]"


In [17]:
from bokeh.palettes import Category20
from bokeh.plotting import figure, show
from bokeh.io import output_notebook

output_notebook()

palette = Category20[len(countries)]
tooltips = [(column, f"@{column}") for column in test2.columns]

p = figure(x_range=test2.cluster, height=250, width=1500, tooltips=tooltips)

for idx, country in enumerate(countries):
    p.vbar(x='cluster', top=country, width=0.9, source=test2,
           color=palette[idx], legend_label=country)

p.y_range.start = 0
p.x_range.range_padding = 0.1
p.xgrid.grid_line_color = None
p.axis.minor_tick_line_color = None
p.outline_line_color = None
p.legend.location = "top_right"
p.legend.orientation = "horizontal"

show(p)
