# clustering of strains and type naming

## steps
- get core alignment
- generate snp dists
- cluster
- retain members
- run again with previous members and new core alignment/snpdists

## be able to
- add new samples using only their SNPs, without re-aligning them against the other samples. This lets us type new samples quickly from SNPs alone.

## requirements
* current core alignments for baseline pop - can be recreated at any time 
* current cluster membership
* snpdist matrix, recreated either a) from the full alignment or b) on the fly when new samples are added

## Refs
* https://www.biorxiv.org/content/10.1101/2023.02.03.527052v1.full

In [2]:
import sys,os,random,time,string,itertools
import json
import subprocess
import math
from importlib import reload
import numpy as np
import pandas as pd
import pylab as plt
import seaborn as sns
import networkx as nx
from IPython.display import display, HTML
import toyplot,toytree

from btbabm import models,utils
from snipgenie import clustering,tools,plotting

## simulate test data

In [None]:
# Re-import the simulation modules so local edits are picked up without restarting the kernel.
reload(models)
reload(utils)
# Build and run the farm/pathogen simulation used as test data.
# NOTE(review): the meaning of F, C and S is not shown here — confirm against btbabm.models.
model = models.FarmPathogenModel(F=40,C=800,S=10,seq_length=800)
model.run(1000)
# get_metadata returns three objects; the tree path passed here is read back by draw_tree
# on the next line, so the tree is presumably written to it — confirm.
gdf, snpdist, meta = model.get_metadata(treefile='cluster_test/tree.newick')
utils.draw_tree('cluster_test/tree.newick')
# Save the distance matrix and metadata; the following cell reloads them from these files.
snpdist.to_csv('cluster_test/snpdist.csv')
meta.to_csv('cluster_test/meta.csv',index=False)

In [133]:
# Load the simulated test data from the cluster_test folder.
# The first CSV column of the distance matrix holds the sample labels.
snpdist = pd.read_csv('cluster_test/snpdist.csv',index_col=0)
meta = pd.read_csv('cluster_test/meta.csv')
treefile = 'cluster_test/tree.newick'
# Metadata indexed by sample id, keeping only the two columns used later.
X=meta.set_index('id')[['species','strain']]
# Print both row counts so a mismatch between matrix and metadata is visible.
print (len(snpdist),len(meta))

302 302


In [176]:
# Re-import snipgenie.clustering so edits to the module are picked up without restarting the kernel.
reload(clustering)

def get_subset(snpdist,X,n=10):
    """Pick n random samples and return their distance sub-matrix and metadata.

    Args:
        snpdist: square DataFrame whose index and columns carry the sample labels.
        X: DataFrame indexed by the same sample labels.
        n: number of rows to draw from snpdist.

    Returns:
        (S, X): snpdist restricted to the drawn labels on both axes, and the
        rows of X for those labels, in the same order.
    """
    # draw rows at random, then reuse their labels for rows, columns and metadata
    chosen = snpdist.sample(n).index.tolist()
    return snpdist.loc[chosen, chosen], X.loc[chosen]
    
def get_subtree(S,treefile,snpdist,filename):
    """Write the part of a tree that covers only the samples in S.

    Args:
        S: DataFrame whose index lists the samples to keep.
        treefile: path of the newick file holding the full tree.
        snpdist: DataFrame whose index lists every sample; rows whose labels
            are in S.index are removed and the remaining labels are pruned.
        filename: path the pruned tree is written to.
    """
    full_tree = toytree.tree(treefile)
    # every label of the full matrix that is not in S is a tip to prune
    unwanted = snpdist.drop(index=S.index).index.tolist()
    full_tree.drop_tips(unwanted).write(filename)


In [None]:
# Two nested random subsets: S2 has 20 samples and S1 is 15 of those, so S2 can
# be treated as S1 plus newly added samples.
# NOTE: T and scol are assigned in the "iteratively add samples" cell further
# down; that cell has to be run before this one.
S2,X2 = get_subset(snpdist,X,n=20)
S1,X1 = get_subset(S2,X2,n=15)
labels,clusters1 = clustering.dm_cluster(S1, T)

X1[scol]=labels
# prune the full tree down to the S1 samples and draw it coloured by cluster label
get_subtree(S1,treefile,snpdist,'sub1.newick')
utils.draw_tree('sub1.newick',X1,scol,tip_labels=True,width=600,cmap=clustering.snp200_cmap)

In [None]:
# Cluster the larger subset, passing the clusters found for S1 as the third argument.
labels,clusters2 = clustering.dm_cluster(S2, T,clusters1)
X2[scol]=labels
# prune the full tree down to the S2 samples and draw it coloured by cluster label
get_subtree(S2,treefile,snpdist,'sub2.newick')
utils.draw_tree('sub2.newick',X2,scol,tip_labels=True,width=600,cmap=clustering.snp200_cmap)
# NOTE(review): X2 is built from the 'species' and 'strain' columns only; the
# 'color' column read here is presumably added by draw_tree — confirm.
cm=dict(zip(X2[scol],X2.color))
#X2.style.applymap(lambda x: "background-color: %s" %clustering.snp200_cmap[x], subset=[scol])

In [None]:
# Display X1 with each cell of the cluster label column shaded using the entry
# of clustering.snp200_cmap looked up by that label.
X1.style.applymap(lambda x: "background-color: %s" %clustering.snp200_cmap[x], subset=[scol])

In [None]:
# Clustered heatmap of the S2 distances with rows coloured per sample, plus a
# legend built from cm (the label -> colour dict assigned in the cell above).
# NOTE(review): relies on X2 having a 'color' column, which is not created in
# the visible cells — confirm where it comes from.
cg = sns.clustermap(S2, cmap='Blues', row_colors=X2.color,figsize=(6,6))
p=plotting.make_legend(cg.fig, cm, loc=(1.1, .6), title='cluster',fontsize=10)

## iteratively add samples to test clusters

In [217]:
# Threshold passed to clustering.dm_cluster, and the name of the column that
# holds the resulting cluster labels.
T = 7
scol = f'snp{T}'

def test_cluster_runs(X,snpdist,T,treefile=None,n=50,steps=None):
    """Cluster growing prefixes of one random sample to check label stability.

    A random subset of n samples is drawn once. For each size in steps, the
    first `size` samples of that subset are clustered with
    clustering.dm_cluster, and the clusters returned by one run are passed to
    the next run.

    Args:
        X: metadata DataFrame indexed by sample label.
        snpdist: square distance DataFrame over the same labels.
        T: threshold passed to clustering.dm_cluster; the label column is
            named 'snp'+str(T).
        treefile: optional newick file of the full tree. When given, the
            subtree for each run is written to 'sub.newick' and drawn.
        n: number of samples drawn for the test.
        steps: iterable of prefix sizes; defaults to np.arange(10,n,20).

    Returns:
        A list with one DataFrame per step, each holding the metadata of the
        samples in that run, the label column and a 1-based 'run' column.
    """
    # Derive the label column from T rather than reading a notebook-level
    # global, so the function gives the same result whatever cell ran last.
    label_col = 'snp'+str(T)

    S1,X1 = get_subset(snpdist,X,n=n)
    # `is None`, not `== None`: the comparison is elementwise on a numpy
    # array and would raise when used as a condition.
    if steps is None:
        steps = np.arange(10,n,20)

    prevclusters = None
    res = []
    for run, size in enumerate(steps, start=1):
        Sr = S1.iloc[:size,:size].copy()
        Xr = X1.iloc[:size].copy()
        labels,clusters = clustering.dm_cluster(Sr, T, prevclusters)
        # hand this run's clusters to the next, larger run
        prevclusters = clusters
        Xr[label_col] = labels
        if treefile is not None:
            get_subtree(Sr,treefile,snpdist,'sub.newick')
            # scale the drawing height with the number of tips
            h = len(Xr)*10+100
            utils.draw_tree('sub.newick',Xr,label_col,tip_labels=True,width=500,height=h,cmap=clustering.snp200_cmap)
        Xr['run'] = run
        res.append(Xr)
    return res

# Run the stability test on the simulated data with the default sample size and
# steps; passing treefile makes each run draw its subtree.
Xtest=test_cluster_runs(X,snpdist,T,treefile)

In [173]:
#distances within each cluster
for i, df in Xr.groupby(scol):
    #print (df)
    idx=df.index
    s=Sr.loc[idx,idx]
    print (i,s.max().max())

1 4
2 4
3 0
4 0
5 0
6 2
7 2
8 1
9 2


In [None]:
reload(utils)
reload(clustering)
# Cluster the full distance matrix with T.
# NOTE(review): prevclusters is not assigned at notebook level in the visible
# cells (it is a local of test_cluster_runs) — confirm where it should come
# from before re-running this cell.
labels,clusters = clustering.dm_cluster(snpdist, T, prevclusters)

X[scol]=labels
#print (newclusters)
# this value is only referenced by the commented-out cmap argument below and is
# replaced a few lines further down
cm=clustering.snp12_cmap
utils.draw_tree(treefile,X,scol,tip_labels=False,height=700)#,cmap=cm)
# NOTE(review): relies on X having a 'color' column, which is not created in
# this cell — presumably added by draw_tree; confirm.
cg=sns.clustermap(snpdist, cmap='Blues', row_colors=X.color,xticklabels=[],figsize=(8,8))
# map each cluster label to a colour for the legend
cm=dict(zip(X[scol],X.color))
p=plotting.make_legend(cg.fig, cm, loc=(1.1, .6), title='cluster',fontsize=10)
#X.style.applymap(lambda x: "background-color: %s" %clustering.snp200_cmap[x], subset=[scol])

In [None]:
# NOTE(review): 'snp200', 'SB' and 'Species' are not columns of the simulated X
# built earlier (it has 'species' and 'strain') — this cell presumably expects
# a different metadata table; confirm.
# The result of the next line is not shown, because it is not the last
# expression in the cell.
X.groupby(['snp200','SB']).count().sort_values('Species')[-12:]
#X.to_csv('newclusts.csv')
# samples labelled 1 or 12 in the cluster column, and the distances among them
w=X[X[scol].isin([1,12])].index
W=snpdist.loc[w,w]
#sns.heatmap(W)
# display the metadata rows of those samples
X.loc[W.index]

## strain naming from clusters

In [50]:
reload(clustering)
# Cluster levels for the smaller subset S1; members1 is passed on when the
# larger subset is processed in the next cell.
cl,members1 = clustering.get_cluster_levels(S1)
st1 = clustering.generate_strain_names(cl)
# metadata columns to show next to the strain names (also used by the next cell)
cols=['species','strain']#['Species','SB','County']
# join metadata and strain names on the sample index
st1=X1[cols].merge(st1,left_index=True,right_index=True)
st1.to_csv('sub1_strains.csv')

In [51]:
# Same steps for the larger subset S2, this time passing the members returned
# for S1 as the second argument.
cl,members2 = clustering.get_cluster_levels(S2,members1)
st2 = clustering.generate_strain_names(cl)
# join metadata and strain names on the sample index (cols comes from the previous cell)
st2=X2[cols].merge(st2,left_index=True,right_index=True)
st2.to_csv('sub2_strains.csv')

## ireland data

In [174]:
# Load the all-Ireland results: distance matrix, sample metadata and tree path.
iresnpdist = pd.read_csv('/storage/btbgenie/all_ireland_results/snpdist.csv',index_col=0)
iremeta = pd.read_csv('/storage/btbgenie/all_ireland_results/metadata.csv')
iretree = '/storage/btbgenie/all_ireland_results/tree.newick'
# Metadata indexed by sample name, restricted to the columns used below.
Xire = iremeta.set_index('sample')[['Species','SB','SB1','County','county1','Year']]
# NOTE(review): the recorded output shows 1436 vs 1435 rows, so the matrix and
# the metadata do not cover exactly the same samples — code that looks up
# matrix labels in Xire may hit a missing label; confirm.
print (len(iresnpdist),len(Xire))

1436 1435


In [231]:
# Repeat the stability test on the Ireland data with T=50: 600 samples are
# drawn and clustered at prefix sizes 400, 500 and 550. No treefile is passed,
# so no trees are drawn.
T=50
scol='snp'+str(T)
res = test_cluster_runs(Xire,iresnpdist,T,n=600,steps=[400,500,550])

In [234]:
# Stack the per-run tables and move the sample index into a 'sample' column.
x=pd.concat(res).reset_index()
# One row per sample and one column per run, holding that sample's cluster
# label in that run; NaN where the sample was not part of the run.
# pivot_table aggregates with the mean by default, so labels come out as floats.
p=pd.pivot_table(x,index='sample',columns=['run'],values=scol)
p[:10]

run,1,2,3
sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1147,,,7.0
1212,,11.0,11.0
1418,10.0,10.0,10.0
1468,10.0,10.0,10.0
1829,,,13.0
19-1426,7.0,7.0,7.0
19-1601,7.0,7.0,7.0
19-1603,13.0,13.0,13.0
19-1690,7.0,7.0,7.0
19-2438,27.0,27.0,27.0


In [189]:
# Cluster levels and strain names for the whole Ireland matrix; no previous
# members are passed here.
cl,members = clustering.get_cluster_levels(iresnpdist)
stire = clustering.generate_strain_names(cl)
cols=['Species','SB1','County','county1']
# join the selected metadata columns and the strain names on the sample index
stire=Xire[cols].merge(stire,left_index=True,right_index=True)
#stire.to_csv('new_strains.csv')
stire

Unnamed: 0,Species,SB1,County,county1,snp500,snp200,snp50,snp12,snp3,strain_name,code
1034,Bovine,SB0054,,Other,1,1,9,200,800,ST-1-1-9-800,9375ec3b
13-11594,,SB0054,,,1,1,9,75,82,ST-1-1-9-82,91e9844d
14-MBovis,,SB0054,,,1,1,9,78,93,ST-1-1-9-93,3d698822
15-11643,,SB0054,,,1,1,9,194,103,ST-1-1-9-103,d26f9cfc
17-11662,,SB0054,,,1,1,9,75,539,ST-1-1-9-539,e5e32355
...,...,...,...,...,...,...,...,...,...,...,...
SRR8600250,,Other,,,1,3,13,293,293,ST-1-3-13-293,6031dc6b
SRR8600292,,Other,,,1,3,13,68,327,ST-1-3-13-327,d3add6b2
SRR8600306,,SB0263,,,1,3,13,8,331,ST-1-3-13-331,44afef11
SRR8600308,,SB0263,,,1,3,13,6,255,ST-1-3-13-255,2dab4566
