In [None]:
# --- standard library ---
import glob
import os
import pickle
import random as rand
import sys
from math import log
from random import randint, sample

# --- third-party ---
import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
# `mpl_toolkits.axes_grid` was removed in Matplotlib 3.6; `axes_grid1` provides the same helpers.
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
try:
    from mpl_toolkits.axes_grid1.inset_locator import InsetPosition
except ImportError:  # deprecated in Matplotlib 3.8 and dropped in later releases
    InsetPosition = None
from numba import njit
from scipy.spatial.distance import pdist, squareform

# --- project-local ---
import config
import figure_utils as fu
from return_gene_descriptions import return_gene_descriptions

# Render all figure text with LaTeX (requires a TeX installation) and make amsmath available.
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

# Colormap used for haplotype matrices (one colour per integer code).
hap_cmap = ListedColormap(['grey', 'red', 'black', 'black','blue'], 'indexed')



In [None]:
## useful utility for quickly returning the upper triangle of a 2-d array as a 1-d array
def take_triu(df):
    """Return the strictly upper-triangular entries (diagonal excluded) of a square 2-D array.

    Entries are returned as a 1-D array in row-major order of the upper triangle.
    """
    rows, cols = np.triu_indices(df.shape[0], k=1)
    return df[rows, cols]

## returns annotations 
def read_sites(species, snps_directory="/u/project/ngarud/Garud_lab/HumanizedMouse/merged_midas_output/snps"):
    """Load the MIDAS ``snps_info`` annotation table for `species`.

    Reads ``{snps_directory}/{species}/snps_info.txt.bz2``. The file is tab-separated and its
    first column is ``"<contig>|<position>"``. The table is re-indexed by
    (contig, gene_id, site_pos).

    Parameters
    ----------
    species : str
        Species sub-directory name.
    snps_directory : str, optional
        Root of the merged MIDAS ``snps`` output. Defaults to the original hard-coded path.

    Returns
    -------
    pandas.DataFrame
        The remaining annotation columns, indexed by (contig, gene_id, site_pos).
        Sites with no gene_id are labelled "non coding", and site_pos is converted to int.
    """
    df_sites = pd.read_csv(f"{snps_directory}/{species}/snps_info.txt.bz2",sep="\t",index_col=0,na_values="NaN")

    ## split "<contig>|<position>" into a contig column and a site_pos index
    df_sites["contig"] = [d.split("|")[0] for d in df_sites.index]
    df_sites.index = pd.Index([d.split("|")[1] for d in df_sites.index], name="site_pos")

    df_sites["gene_id"] = df_sites["gene_id"].fillna("non coding")

    ## The previous version also built gene-break / gene-length arrays here that were never used.
    ## That loop compared strings with `is not` (identity, not equality) and indexed a Series
    ## positionally with `[]` (deprecated fallback), so it has been removed.
    df_sites = (df_sites.set_index(["gene_id", "contig"], append=True)
                        .reorder_levels(["contig", "gene_id", "site_pos"]))

    ## site positions were parsed as strings; convert that index level to int
    site_level = 2
    df_sites.index = df_sites.index.set_levels(df_sites.index.levels[site_level].astype(int), level=site_level)

    return(df_sites)

def read_haplotypes(species,good_samples=None,clade_control=False,output_dir="/u/scratch/r/rwolff/LD/HMP"):
    """Concatenate all per-chunk haplotype CSVs for `species`.

    Files matching ``{output_dir}/haplotypes/{species}/*_haplotypes.csv`` are read in sorted
    filename order, so the row order is reproducible. The first four columns of each file
    (contig, gene_id, site_pos, site_type) form the row index.

    Parameters
    ----------
    species : str
        Species sub-directory name.
    good_samples : iterable of str, optional
        If given (and clade_control is False), only these sample columns are read.
    clade_control : bool
        If True and good_samples is None, only the samples returned by
        ``clade_utils.load_largest_clade(species)`` are read.
    output_dir : str, optional
        Root of the LD output. Defaults to the original hard-coded path.

    Returns
    -------
    pandas.DataFrame indexed by (contig, gene_id, site_pos, site_type).

    Raises
    ------
    ValueError
        If no haplotype files are found.
    """
    haplotype_directory = "%s/haplotypes" % output_dir
    hap_files = sorted(glob.glob(f"{haplotype_directory}/{species}/*_haplotypes.csv"))
    if not hap_files:
        raise ValueError(f"No *_haplotypes.csv files found in {haplotype_directory}/{species}")

    idx_cols = ["contig","gene_id","site_pos","site_type"]

    ## choose the sample columns once (the same for every file)
    if good_samples is not None and clade_control == False:
        usecols = idx_cols + list(good_samples)
    elif good_samples is None and clade_control != False:
        ## clade_utils was never imported by this notebook, so this branch raised NameError.
        ## NOTE(review): assumes the project module `clade_utils` is importable from here.
        import clade_utils
        usecols = idx_cols + list(clade_utils.load_largest_clade(species))
    else:
        usecols = None   # read every column (also when both options are supplied)

    all_haps = [pd.read_csv(hap,index_col=[0,1,2,3],usecols=usecols) for hap in hap_files]

    return(pd.concat(all_haps))
    

In [None]:
def get_cmap(n, name='Set3_r'):
    '''Return colormap `name` resampled to `n` entries.

    Integer indices 0, 1, ..., n-1 map to distinct RGBA colours. `name` must be a registered
    Matplotlib colormap name. Uses `plt.get_cmap` because `matplotlib.cm.get_cmap`
    (reached via `plt.cm`) was deprecated in Matplotlib 3.7 and removed in 3.9.
    '''
    return plt.get_cmap(name, n)

In [None]:
### use of numba/njit below *substantially* increases computational efficiency

## calculate distances for forward polarization 
@njit
def D_mat_fun1(num,F,D,D_mat):
    """Fill the upper triangle of `D_mat` with forward-polarization trajectory distances.

    For every pair r < c < num, stores
        2 * nanmean((D[r] + D[c]) * (F[r] - F[c])**2 / ((F[r] + F[c]) * (1 - F[r] + 1 - F[c])))
    in D_mat[r, c].

    Rows 0..num-2 are overwritten, with entries at c <= r set to 0. The last row is left
    untouched. `D_mat` is modified in place and also returned.
    """
    for r in range(num - 1):
        f_r = F[r]
        d_r = D[r]
        D_mat[r, :] = 0.0
        for c in range(r + 1, num):
            f_c = F[c]
            depth = d_r + D[c]
            D_mat[r, c] = 2*np.nanmean(depth*((f_r - f_c)**2)/((f_r + f_c)*(1 - f_r + 1 - f_c)))
    return D_mat

## calculate distances for reverse polarization
@njit
def D_mat_fun2(num,F,D,D_mat_in):
    """Fill the upper triangle of `D_mat_in` with reverse-polarization trajectory distances.

    Same formula as the forward version, except that the row SNV's frequencies are flipped
    (1 - F[r]) before being compared with F[c]. D_mat_in[r, c] is therefore the distance
    between the flipped trajectory of r and the trajectory of c.

    Rows 0..num-2 are overwritten, with entries at c <= r set to 0. The last row is left
    untouched. The array is modified in place and also returned.
    """
    for r in range(num - 1):
        f_r = 1-F[r]
        d_r = D[r]
        D_mat_in[r, :] = 0.0
        for c in range(r + 1, num):
            f_c = F[c]
            depth = d_r + D[c]
            D_mat_in[r, c] = 2*np.nanmean(depth*((f_r - f_c)**2)/((f_r + f_c)*(1 - f_r + 1 - f_c)))
    return D_mat_in

def return_clus(D_mat_close,Fs_sub):
    """Extract the next cluster from a boolean linkage matrix.

    The seed is the SNV whose column has the most True entries (the most links). The
    candidates are all rows linked to the seed. A candidate is kept only if it is linked to
    more than 25% of the candidates.

    Returns
    -------
    (rows of `Fs_sub` for the kept SNVs, their index labels), in D_mat_close's row order.
    """
    link_counts = D_mat_close.sum().sort_values()
    seed = link_counts.index[-1]
    candidates = D_mat_close.loc[D_mat_close[seed]].index

    ## fraction of the candidate group that each candidate is linked to
    cohesion = D_mat_close.loc[candidates, candidates].mean(axis=1)
    clus_idxs = cohesion[cohesion > .25].index

    return Fs_sub.loc[clus_idxs], clus_idxs

def drop_clus_idxs(D_mat_close,clus_idxs):
    """Return a copy of `D_mat_close` with the labels in `clus_idxs` removed from rows and columns."""
    return D_mat_close.drop(index=clus_idxs, columns=clus_idxs)

def polarize_clus(clus,clus_idxs,D_mat_1,D_mat_2):
    """Polarize every SNV of a cluster relative to the cluster's first SNV (the anchor).

    An SNV is flipped (f -> 1 - f) when its forward distance to the anchor is not smaller
    than its reverse-polarized distance (D_mat_1 >= D_mat_2). The anchor itself is never
    flipped. SNVs with NaN distances are kept unflipped rather than dropped.

    Parameters
    ----------
    clus : pandas.DataFrame
        SNV frequency rows; must contain every label in `clus_idxs`.
    clus_idxs : pandas.Index
        Cluster members. The first label is the anchor.
    D_mat_1, D_mat_2 : pandas.DataFrame
        Forward / reverse distance matrices labelled like `clus`. Only the upper triangle
        needs to be filled (as produced by the njit distance functions). The anchor's row and
        column are summed, so symmetric matrices work too.

    Returns
    -------
    pandas.DataFrame
        The polarized rows in `clus_idxs` order, labelled by `clus_idxs`. Previously the rows
        were returned flipped-first with a fresh RangeIndex, so a caller assigning
        `index = clus_idxs` mislabelled them.
    """
    if len(clus_idxs) == 0:
        return clus.loc[clus_idxs]

    anchor = clus_idxs[:1]
    ## In a triangular matrix one of D[anchor, x] / D[x, anchor] holds the distance and the
    ## other is 0; summing both makes the comparison independent of row order.
    fwd = D_mat_1.loc[anchor, clus_idxs].to_numpy().ravel() + D_mat_1.loc[clus_idxs, anchor].to_numpy().ravel()
    rev = D_mat_2.loc[anchor, clus_idxs].to_numpy().ravel() + D_mat_2.loc[clus_idxs, anchor].to_numpy().ravel()

    flip = fwd >= rev
    ## The anchor's self-distance is 0 in both matrices (0 >= 0), which used to flip the anchor
    ## against the rest of its cluster. The anchor is the reference, so it is never flipped.
    flip[0] = False

    members = clus.loc[clus_idxs]
    values = members.to_numpy(dtype=float, copy=True)
    values[flip] = 1 - values[flip]

    return pd.DataFrame(values, index=members.index, columns=members.columns)

@njit
def symmetrize(D_mat):
    """Copy the upper triangle of square `D_mat` onto its lower triangle, in place; returns D_mat."""
    n = D_mat.shape[0]
    for row in range(n - 1):
        for col in range(row + 1, n):
            D_mat[col, row] = D_mat[row, col]
    return D_mat

In [None]:
## non-looped version of distance calculation
def calc_dis(di,dj,fi,fj):
    """Distance between two SNV frequency trajectories, averaged over samples.

    Returns 2 * nanmean over samples of
        (di + dj) * (fi - fj)**2 / ((fi + fj) * (1 - fi + 1 - fj)),
    where di/dj are read depths and fi/fj frequencies. Samples that give NaN (e.g. 0/0)
    are skipped by nanmean.
    """
    weight = di + dj
    numerator = weight*((fi - fj)**2)
    denominator = (fi + fj)*(1 - fi + 1 - fj)
    return 2*np.nanmean(numerator/denominator)


In [None]:
def return_FAD(species,min_coverage=10,min_sample_coverage=5,poly_cov_frac=1/5,input_dir=None):
    """Load StrainFinder inputs for `species`; return filtered frequency, count and depth tables.

    Reads ``{input_dir}/{species}/{species}.strainfinder{.p,.samples.p,.locations.p}``. The
    alignment is indexed [sample, snp, allele]. A = allele-0 counts and D = the total over
    alleles.

    Parameters
    ----------
    min_coverage : int
        Per-sample depth below which a count is masked (NaN) in Ass and Fs.
    min_sample_coverage : float
        Samples whose mean depth over all SNVs is not above this are dropped.
    poly_cov_frac : float
        An SNV is kept only if it has depth >= min_coverage and a non-zero count in more
        than ``int(n_samples * poly_cov_frac)`` samples.
    input_dir : str, optional
        Directory holding the pickles. Defaults to the notebook global ``strainfinder_dir``
        (the original behaviour).

    Returns
    -------
    (Fs, Ass, Dss) : DataFrames
        Rows are indexed by (contig, site_pos, all_site_pos), where all_site_pos places contigs
        end to end. Columns are indexed by (mouse_number, region, diet, sample), parsed from
        sample names of the form ``<mouse(2 chars)><region><diet(1 char)>_...``.
    """
    base_dir = strainfinder_dir if input_dir is None else input_dir
    prefix = "%s/%s/%s.strainfinder" % (base_dir, species, species)

    snp_alignment = np.asarray(pd.read_pickle(prefix + ".p"))
    samples = [s.decode("utf-8") for s in pd.read_pickle(prefix + ".samples.p")]
    snp_locations = pd.read_pickle(prefix + ".locations.p")

    ## (snp x sample) tables, built without a per-SNP Python loop
    As = pd.DataFrame(snp_alignment[:, :, 0].T, columns=samples, index=snp_locations)
    Dss = pd.DataFrame(snp_alignment.sum(axis=2).T, columns=samples, index=snp_locations)

    Ass = As.mask(Dss < min_coverage)
    Fs = Ass/Dss

    ## drop samples with low mean depth
    samps = Dss.mean() > min_sample_coverage
    samps = samps[samps].index
    Fs, Ass, Dss = Fs[samps], Ass[samps], Dss[samps]

    ## keep SNVs that are covered and non-zero in enough samples
    n_informative = Fs.mask(Ass == 0).mask(Dss < min_coverage).notna().sum(axis=1)
    keep = n_informative > int(Fs.shape[1]*poly_cov_frac)
    Fs, Ass, Dss = Fs.loc[keep], Ass.loc[keep], Dss.loc[keep]

    ## column index (mouse_number, region, diet, sample) parsed from the sample names
    names = list(Fs.columns)
    tags = [n[2:].split("_")[0] for n in names]
    col_idx = pd.MultiIndex.from_arrays(
        [[n[:2] for n in names], [t[:-1] for t in tags], [t[-1] for t in tags], names],
        names=["mouse_number","region","diet","sample"])

    ## row index (contig, site_pos); the ref/alt component of the location is discarded
    row_idx = pd.MultiIndex.from_tuples(Fs.index,names=["contig","site_pos","ref/alt"]).droplevel("ref/alt")

    Fs, Ass, Dss = [df.set_axis(row_idx, axis=0)
                      .set_axis(col_idx, axis=1)
                      .sort_index(level="mouse_number", axis=1)
                      .sort_index(level=["contig","site_pos"])
                    for df in (Fs, Ass, Dss)]

    ## concatenated-genome coordinate: each contig is shifted by the last SNV position of the
    ## contigs before it (plotting convenience only)
    all_site_pos = []
    offset = 0
    for contig in np.unique(Fs.index.get_level_values("contig")):
        site_pos = Fs.loc[contig].index
        all_site_pos.extend(site_pos + offset)
        offset += site_pos[-1]

    full_idx = pd.MultiIndex.from_arrays(
        [Fs.index.get_level_values("contig"), Fs.index.get_level_values("site_pos"), all_site_pos],
        names=["contig","site_pos","all_site_pos"])
    Fs, Ass, Dss = [df.set_axis(full_idx, axis=0) for df in (Fs, Ass, Dss)]

    return(Fs,Ass,Dss)

In [None]:
## defines an order for each level type
## M1 --> M6
## Upper gut --> lower gut
## Control --> Guar gum
## Co-housing treatment 1 --> co-housing treatment 3
order_dict = {"M1":0,"M2":1,"M3":2,"M4":3,"M5":4,"M6":5,
              'D': 0, 'J': 1, 'I': 2,"Ce":3,"Co":4,
              "C":0,"G":1,
              "C1":0,"C2":1,"C3":2} 

## function sorts our multiindex of frequencies according to whatever key we specify, e.g. mouse_number
## while maintaining the order of subsequent levels according to order_dict
def reorder_sort(df,first_idx,order_dict=order_dict):
    """Move index level `first_idx` to the front of `df`'s index and sort all levels by `order_dict`.

    Works on any Series/DataFrame whose row index is a MultiIndex containing `first_idx`.
    The level names now come from `df` itself; they used to be read from the global `Fs`,
    which broke whenever `Fs` was missing or had different levels. Values missing from
    `order_dict` map to NaN and sort last.
    """
    reorder = list(df.index.names)
    reorder.remove(first_idx)
    reorder.insert(0, first_idx)
    return df.reorder_levels(reorder).sort_index(key=lambda x: x.map(order_dict))


In [None]:
## list the species available as StrainFinder inputs (the same directory is assigned to `strainfinder_dir` below)
!ls /u/project/ngarud/michaelw/PaulAllen/humanized_mouse/strainfinder/input/

In [None]:
## root directory of the StrainFinder input pickles (one sub-directory per species)
strainfinder_dir = "/u/project/ngarud/michaelw/PaulAllen/humanized_mouse/strainfinder/input/"

## species to phase into strains
species = "Clostridiales_bacterium_61057"


In [None]:
## Meta-parameters: experiment with these -- there are no hard and fast rules!

## --- cluster acceptance ---
## Minimum number of SNVs that must be clustered together to qualify as a "strain".
## Without the max_num_snvs cap below this would be O(10^4), the typical evolutionary
## distance between strains.
min_cluster_size = 1000

## Minimum fraction of the coverage-passing sites that a cluster must contain to qualify as a strain.
## The idea: the more sites we start with, the bigger a cluster must be. Here 10% of all variable
## sites must fall in a cluster. Largely redundant with min_cluster_size, but another knob to play with.
min_cluster_fraction = 1/10

## --- computational cost ---
## For efficiency, strain phasing is run on a subsample of SNVs; this should still give the same
## strain trajectories. Clustering 20k SNVs takes ~90 seconds.
max_num_snvs = 20000

## --- linkage ---
## Distance threshold for two SNVs to be considered linked: lower means trajectories must be more
## similar, higher means less similar, to share a cluster.
max_d = 3.5

## --- site / sample filters ---
## Minimum coverage for an allele frequency at a site to be used in clustering.
min_coverage = 10

## Minimum average sample coverage at polymorphic sites (i.e. sites in the A/D matrices).
min_sample_coverage = 5

## Polymorphic & covered fraction: the fraction of samples in which a site must have coverage
## >= min_coverage AND be polymorphic to be included in downstream analyses.
## NOTE: we may want to separate coverage from polymorphism so as not to lose evolutionary SNVs,
## but for strain clustering we focus on SNVs that are polymorphic in a good number of samples.
poly_cov_frac = 1/5


In [None]:
## load filtered frequency (Fs), allele-count (Ass) and depth (Dss) tables using the meta-parameters above
Fs, Ass, Dss = return_FAD(
    species,
    min_coverage=min_coverage,
    min_sample_coverage=min_sample_coverage,
    poly_cov_frac=poly_cov_frac,
)


In [None]:
## preview of the SNV x sample frequency table (the full table has Fs.shape[0] rows)
Fs.head()

In [None]:
%%time


fss = Ass.values/(Dss.values + (Dss.values == 0))

cluster_As = Ass.values
cluster_Ds = Dss.values
cluster_fs = cluster_As/(cluster_Ds + (cluster_Ds == 0))

## for compatibility in case of threshold number of SNVs
num = min(max_num_snvs,Fs.shape[0])

i_list = Dss.T.mean().sort_values(ascending=False).index[:num]

sys.stderr.write("Processing %s SNVs" % num)

## simply shuffles indices if no threshold is specified
#i_list = sample(range(Fs.shape[0]),num)
i_list_idx = Fs.loc[i_list].index

Ass_sub = Ass.loc[i_list_idx]
Dss_sub = Dss.loc[i_list_idx]
Fs_sub = Fs.loc[i_list_idx]

fss_sub = Ass_sub.values/(Dss_sub.values + (Dss_sub.values == 0))

cluster_As_sub = Ass_sub.values
cluster_Ds_sub = Dss_sub.values
cluster_fs_sub = cluster_As_sub/(cluster_Ds_sub + (cluster_Ds_sub == 0))

D_mat = np.zeros([num,num])
D_mat_1 = D_mat_fun1(num,fss_sub,cluster_Ds_sub,D_mat)
D_mat = np.zeros([num,num]) 
D_mat_2 = D_mat_fun2(num,fss_sub,cluster_Ds_sub,D_mat)

D_mat = np.fmin(D_mat_1,D_mat_2)
D_mat = symmetrize(D_mat)

D_mat_1 = pd.DataFrame(D_mat_1,index=Fs_sub.index,columns=Fs_sub.index)
D_mat_2 = pd.DataFrame(D_mat_2,index=Fs_sub.index,columns=Fs_sub.index)

D_mat_close = pd.DataFrame(D_mat < max_d) 

D_mat_close.index = Fs_sub.index
D_mat_close.columns = Fs_sub.index


In [None]:
## Greedily extract clusters of linked SNVs:
##   1. take the most-linked SNV and collect the SNVs within max_d of it (return_clus);
##   2. polarize them against the cluster's first SNV (polarize_clus);
##   3. remove them from the pool.
## Runs until every SNV is assigned or max_cluster_rounds is reached. Only groups larger than
## min_cluster_size AND larger than min_cluster_fraction of all filtered SNVs are kept as strains.
## This replaces a `for ... try/except: pass` loop. The bare except hid real errors, and when
## polarizing failed the group was never removed, so the same failure repeated every round.

all_clus_pol = []
all_clus_idx = []
all_clus_A = []
all_clus_D = []

all_clus_F = []

max_cluster_rounds = 100

for _round in range(max_cluster_rounds):

    if D_mat_close.empty:
        break   # every SNV has been assigned to a group

    clus,clus_idxs = return_clus(D_mat_close,Fs_sub)
    ## remove the group first so the loop always makes progress
    D_mat_close = drop_clus_idxs(D_mat_close,clus_idxs)

    try:
        clus_pol = polarize_clus(clus,clus_idxs,D_mat_1,D_mat_2)
        clus_pol.index = clus_idxs
    except (ValueError, KeyError, IndexError) as err:
        sys.stderr.write(f"Skipping group of {len(clus_idxs)} SNVs: {err!r}\n")
        continue

    if clus_pol.shape[0] > min_cluster_size and clus_pol.shape[0] > Fs.shape[0]*min_cluster_fraction:

        ## per-sample mean depth and mean polarized frequency of the cluster
        all_clus_D.append(Dss.loc[clus.index].mean().values)
        all_clus_pol.append(clus_pol)
        all_clus_A.append(clus_pol.mean()*all_clus_D[-1])
        all_clus_F.append(clus_pol.mean())

        print(clus_pol.shape[0])
    

In [None]:
## Take each cluster's centroid trajectory as a representative, then find all other sites
## (not just the max_num_snvs subsample) that are consistent with being linked to it.
## Each site goes to the first cluster it is linked to, polarized to agree with that cluster.
## Distances use the calc_dis formula, evaluated for all sites at once instead of a Python loop.

final_clusters = []

all_aligned_sites = []
_already_assigned = set()   # same contents as all_aligned_sites, for O(1) membership tests

dss = Dss.values
fss = Fs.values
fss_rev = 1 - fss

for i in range(len(all_clus_D)):

    sys.stderr.write(f'\n\nCluster {i+1}\n')
    ancD = all_clus_D[i]
    ancF = all_clus_F[i]

    anc_d = np.asarray(ancD, dtype=float)
    anc_f = np.asarray(ancF, dtype=float)
    depth = anc_d + dss

    ## Per-site distance to the centroid, as is (forward) and with the site flipped (backward).
    ## NaN terms (e.g. 0/0) are skipped by nanmean; all-NaN sites give NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        disAnc_forward = 2*np.nanmean(depth*((anc_f - fss)**2)/((anc_f + fss)*(1 - anc_f + 1 - fss)), axis=1)
        disAnc_backward = 2*np.nanmean(depth*((anc_f - fss_rev)**2)/((anc_f + fss_rev)*(1 - anc_f + 1 - fss_rev)), axis=1)

    ## a site's distance is its better polarization (NaN only if both are undefined)
    disAnc = np.fmin(disAnc_forward, disAnc_backward)
    f_dist = pd.DataFrame(np.array([disAnc_forward,disAnc_backward]).T,index=Fs.index)
    ## True where the flipped site is the closer match (ties keep the forward orientation)
    pols = pd.Series((disAnc_backward < disAnc_forward) |
                     (np.isnan(disAnc_forward) & ~np.isnan(disAnc_backward)), index=Fs.index)

    ## linked sites not already claimed by an earlier cluster
    aligned_sites = [a for a in Fs.index[disAnc < max_d] if a not in _already_assigned]
    _already_assigned.update(aligned_sites)
    all_aligned_sites.extend(aligned_sites)

    pols = pols.loc[aligned_sites]
    re_polarize = pols.loc[pols].index

    ## explicit copy: the in-place flip below must not write into (a view of) Fs
    Fs_cluster = Fs.loc[aligned_sites].copy()
    Fs_cluster.loc[re_polarize] = 1 - Fs_cluster.loc[re_polarize]

    final_clusters.append(Fs_cluster)

sys.stderr.write("\n")

## Creating polarized clusters

Once clusters have been identified and internally polarized, they need to be polarized relative to one another. In the best case, the sum of strain frequencies will be 1. 

In [None]:
import itertools

## If only a single cluster is detected, add a second "cluster" that is simply 1 minus the allele
## frequencies of the first. This aids visualization for people not familiar with this kind of clustering.
if len(final_clusters) == 1:
    final_clusters.append(1-final_clusters[0])

## cluster centroids: mean (within-cluster polarized) frequency of each cluster in every sample
final_f = [cluster.mean() for cluster in final_clusters]
df_final_f = pd.DataFrame(final_f)

## Choose which clusters to flip (f -> 1 - f) so that the sum over samples of
## (1 - sum of centroids)^2 is minimized, since accurate strain frequencies should sum to 1.
## Every flip combination is tried (2^k for k clusters). k is small because each kept cluster
## must hold more than min_cluster_fraction of the SNVs. The previous version only tried
## flipping one cluster at a time. Ties prefer fewer flips, so nothing is flipped unless
## it strictly lowers the error.
polarize = True

if polarize:

    centroids = df_final_f.to_numpy(dtype=float)

    def _polarization_error(flips):
        """Sum over samples of (1 - sum of the possibly-flipped centroids)^2; NaN centroids are ignored."""
        flip_rows = np.array(flips, dtype=bool)[:, None]
        flipped = np.where(flip_rows, 1 - centroids, centroids)
        return ((1 - np.nansum(flipped, axis=0))**2).sum()

    best_flips = min(itertools.product([False, True], repeat=len(final_f)),
                     key=lambda flips: (_polarization_error(flips), sum(flips)))

    for k, flip in enumerate(best_flips):
        if flip:
            final_f[k] = 1 - final_f[k]
            final_clusters[k] = 1 - final_clusters[k]
    df_final_f = pd.DataFrame(final_f)

In [None]:
## plot the chosen polarization of strains
## sum of strain frequencies should be ~1 in every sample

strain_freq_sum = pd.DataFrame(final_f).sum().values

fig,ax = plt.subplots(figsize=(12,8))
ax.plot(strain_freq_sum,zorder=10,lw=3)
ax.axhline(1,color="k",ls="--")
ax.set_ylim([.5,1.5])
ax.set_xlabel("Sample index",size=20)
ax.set_ylabel("Sum of strain frequencies",size=20);



## Plotting strain frequencies

Strain frequencies are plotted with samples grouped by a primary key (e.g. `mouse_number`) and labelled by a secondary key (e.g. `region`), giving every sample a two-level label.

In [None]:
## More ordering utilities: for each level, map every value (e.g. each mouse) to the x-positions of
## its samples once the columns are sorted by that level with reorder_sort.
## Values are listed in that sorted order. The previous `list(set(...))` order depends on Python's
## per-process string hashing, so plot colours and label order could change between runs.

def _sample_positions(level):
    """Return {value: positions} for `level`, in the order the values appear after reorder_sort."""
    sorted_values = reorder_sort(Fs.T, level).index.get_level_values(level)
    return {m: np.flatnonzero(sorted_values == m) for m in pd.unique(sorted_values)}

mnum_sample_dic = _sample_positions("mouse_number")
msite_sample_dic = _sample_positions("region")
mdiet_sample_dic = _sample_positions("diet")

mnum = list(mnum_sample_dic)
msite = list(msite_sample_dic)
mdiet = list(mdiet_sample_dic)

all_sample_dics = {"diet":mdiet_sample_dic,"region":msite_sample_dic,"mouse_number":mnum_sample_dic}


In [None]:
key_to_sort = "mouse_number"
secondary_key = "region"

## One background colour per group of key_to_sort, one line colour per strain. The strain
## colormap used to be sized by the number of mice, so with more strains than mice the extra
## strains all got the colormap's last colour.
cmap = get_cmap(len(all_sample_dics[key_to_sort]))
cmap_clus = get_cmap(max(len(final_f), 1),name="Set3")

fig,ax = plt.subplots(figsize=(16,8))
fig.suptitle(fu.get_pretty_species_name(species),size=30)

for i,f in enumerate(final_f):

    ## Strain centroid (thick line) and up to 10,000 member SNV trajectories (faint).
    ## Samples are ordered by key_to_sort, then by order_dict.
    ff = reorder_sort(f,key_to_sort).sort_index(key=lambda x: x.map(order_dict))
    ff_c = reorder_sort(final_clusters[i].T,key_to_sort).T.sort_index(key=lambda x: x.map(order_dict),axis=1)

    ax.plot(ff.values,zorder=100,lw=6,color=cmap_clus(i),label=f"Strain {i+1}");
    ax.plot(ff.values,zorder=80,lw=7,color="k");
    ## fixed random_state so the faint background trajectories are identical on every run
    ax.plot(ff_c.sample(min(ff_c.shape[0],10000),random_state=0).T.values,color=cmap_clus(i),alpha=.01)

major_x = []
minor_x = []
labels = []

## secondary-key label for every plotted sample position
second_xlabels = list(ff.index.get_level_values(secondary_key))

for i, (key, item) in enumerate(all_sample_dics[key_to_sort].items()):

    ## shade the group; one vertical guide and secondary label per sample
    xmin = item.min() 
    xmax = item.max()
    ax.axvspan(xmin - .1,xmax+.1,alpha=.2,color=cmap(i))

    for e in item:
        ax.axvline(e,color="k",zorder=0,alpha=.5)
        ax.text(e, -0.1, second_xlabels[e], ha='center', clip_on=False,size=15)

    ## group label centred under the group (major ticks carry labels, minor ticks draw the marks)
    if xmin != xmax:
        major_x.extend([xmin,(xmax + xmin)/2,xmax])
        minor_x.append((xmax + xmin)/2)
        labels.extend(["",key,""])
    else:
        major_x.append(xmin)
        minor_x.append(xmax)
        labels.append(key)

    ## group separators below the axis (x in data units, y in axes units)
    ax.vlines(item[0] - 0.5, 0, -0.15, color='black', lw=0.8, clip_on=False, transform=ax.get_xaxis_transform())
    ax.vlines(item[-1] + 0.5, 0, -0.15, color='black', lw=0.8, clip_on=False, transform=ax.get_xaxis_transform())

ax.set_xticks(major_x)
ax.set_xticks(minor_x, minor = True)
ax.set_xticklabels(labels);

ax.axhline(0,color="grey")
ax.axhline(1,color="grey")

ax.tick_params(axis = 'x', which = 'major', length=0,labelsize = 20,pad=45)
ax.tick_params(axis = 'x', which = 'minor', length = 10,labelsize = 0)

ax.set_ylabel("SNV frequency",size=20)
ax.set_ylim([-0.05,1.05]);

fig.legend(prop={"size":20});

In [None]:
## save the strain-frequency figure, tagged with the species and the minimum cluster size
strain_fig_path = (f"/u/project/ngarud/michaelw/PaulAllen/humanized_mouse/figures/strain_phasing/"
                   f"{species}_strains_minclustersize{min_cluster_size}.png")
fig.savefig(strain_fig_path, facecolor='white', transparent=False)

## Plot locations of SNVs and SFS's

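Plotting the physical location of SNVs can help identify instances of HGT: many SNVs lying very close to one another are a good indication of possible HGT, though synteny issues of course persist. Note that positions are on a concatenated coordinate (`all_site_pos`) in which contigs are laid end to end, so distances across contig boundaries are not meaningful.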

In [None]:
## Two example samples. For each: SNV frequency along the genome (left) and each strain's
## site-frequency spectrum (right, sharing the y-axis).
sample_1 = 1
sample_2 = 7

alpha=1
B = np.linspace(0,1,30)   # histogram bin edges on [0, 1]

fig = plt.figure(constrained_layout=True,figsize=(32,12))
gs = fig.add_gridspec(2, 16)

ax1 = fig.add_subplot(gs[0, :-2])
ax1_H = fig.add_subplot(gs[0, -2:],sharey=ax1)
ax2 = fig.add_subplot(gs[1, :-2])
ax2_H = fig.add_subplot(gs[1, -2:],sharey=ax2)

for row, (ax_pos, ax_sfs, samp) in enumerate([(ax1, ax1_H, sample_1), (ax2, ax2_H, sample_2)]):

    ax_pos.grid(True)
    ax_pos.tick_params(axis='both', which='major', labelsize=20)
    ax_sfs.tick_params(axis='both', which='major', labelsize=0)

    for s, cluster in enumerate(final_clusters):
        ax_pos.scatter(cluster.index.get_level_values("all_site_pos"),cluster.iloc[:,samp],
                       alpha=alpha,color=cmap_clus(s),zorder=s)
        ## only the top row is labelled, so the figure legend lists each strain once
        ax_sfs.hist(cluster.iloc[:,samp],density=True,alpha=.8,color=cmap_clus(s),orientation='horizontal',
                    bins=B,label=f"Strain {s+1}" if row == 0 else None)

    ax_pos.set_ylim([-0.05,1.05])
    ax_sfs.tick_params(labelleft=False)
    ax_sfs.xaxis.set_major_locator(plt.NullLocator())
    ax_sfs.xaxis.set_minor_locator(plt.NullLocator())

ax1.set_xticklabels([])

## to zoom in on a region set ll (lower limit), ul (upper limit) and off (offset), e.g.
# ll, ul, off = 53799, 53799, 10000
# ax1.set_xlim([ll-off,ul+off]); ax2.set_xlim([ll-off,ul+off])

ax2.set_xlabel("Genomic position (bp)",size=30)
ax2_H.set_xlabel("Site frequency\nspectra",size=30)
fig.text(0.075,0.42,"SNV frequency",size=30,rotation=90)

## (plt.subplots_adjust was dropped: it is incompatible with constrained_layout, and Matplotlib ignores it)
fig.legend(prop={"size":20},bbox_to_anchor=(.98,.89));

In [None]:
## One row per sample: SNV frequency along the genome (left) and per-strain site-frequency spectra (right).
no_of_samples = final_clusters[0].shape[1]

alpha=1
B = np.linspace(0,1,30)   # histogram bin edges on [0, 1]

ax = []
ax_H = []

fig = plt.figure(constrained_layout=True,figsize=(32,6*no_of_samples))
gs = fig.add_gridspec(no_of_samples, 16)

## "<mouse_number> <region>" title for each column (all clusters share the same columns)
sample_titles = [f"{m} {r}" for m, r in zip(final_clusters[0].columns.get_level_values("mouse_number"),
                                            final_clusters[0].columns.get_level_values("region"))]

## loop variable renamed from `sample`, which shadowed random.sample imported at the top
for samp_idx in range(no_of_samples):

    ax_pos = fig.add_subplot(gs[samp_idx, :-2])
    ax_sfs = fig.add_subplot(gs[samp_idx, -2:],sharey=ax_pos)
    ax.append(ax_pos)
    ax_H.append(ax_sfs)

    ax_pos.grid(True)
    ax_pos.tick_params(axis='both', which='major', labelsize=20)
    ax_sfs.tick_params(axis='both', which='major', labelsize=0)

    for s, cluster in enumerate(final_clusters):
        ax_pos.scatter(cluster.index.get_level_values("all_site_pos"),cluster.iloc[:,samp_idx],
                       alpha=alpha,color=cmap_clus(s),zorder=s)
        ## only the first row is labelled; otherwise fig.legend repeats every strain once per sample
        ax_sfs.hist(cluster.iloc[:,samp_idx],density=True,alpha=.8,color=cmap_clus(s),orientation='horizontal',
                    bins=B,label=f"Strain {s+1}" if samp_idx == 0 else None)

    ax_pos.set_title(sample_titles[samp_idx], fontsize=20)
    ax_pos.set_ylim([-0.05,1.05])
    if samp_idx != (no_of_samples-1):
        ax_pos.set_xticklabels([])
    ax_sfs.tick_params(labelleft=False)
    ax_sfs.xaxis.set_major_locator(plt.NullLocator())
    ax_sfs.xaxis.set_minor_locator(plt.NullLocator())

    ## to zoom in on a region set ll (lower limit), ul (upper limit) and off (offset), e.g.
    # ll, ul, off = 53799, 53799, 10000
    # ax_pos.set_xlim([ll-off,ul+off])

ax[-1].set_xlabel("Genomic position (bp)",size=30)
ax_H[-1].set_xlabel("Site frequency\nspectra",size=30)
fig.text(0.075,0.42,"SNV frequency",size=30,rotation=90)

## (plt.subplots_adjust was dropped: it is incompatible with constrained_layout, and Matplotlib ignores it)
fig.legend(prop={"size":20},bbox_to_anchor=(.98,.89));

In [None]:
## save the per-sample genomic-locus figure at 300 dpi, tagged with species and minimum cluster size
genomic_fig_path = (f"/u/project/ngarud/michaelw/PaulAllen/humanized_mouse/figures/strain_phasing/"
                    f"{species}_genomiclocus_minclustersize{min_cluster_size}.png")
fig.savefig(genomic_fig_path,
            facecolor='white', transparent=False, dpi=300,
            bbox_inches='tight', pad_inches=0.5)



In [None]:
## Plotting 2-D sfs w/ marginal 1-D sfs's for two samples (column positions in Fs).
## x and y are passed by keyword: positional data arguments were removed in seaborn 0.12.

sample_1 = 0
sample_2 = 5
g = sns.JointGrid(x=Fs.iloc[:,sample_1].values, y=Fs.iloc[:,sample_2].values,
                  height=8, ratio=5, space=.05, ylim=(-.05, 1.05), xlim=(-.05, 1.05))

g.plot_joint(sns.scatterplot, s=25, alpha=.1)
g.plot_marginals(sns.histplot, kde=False, bins=25)
g.set_axis_labels(f"SNV frequency (sample {sample_1})", f"SNV frequency (sample {sample_2})");

## Identifying potential anchors
There are lots of ways to find potentially evolutionarily interesting SNVs. A few that I came up with are:
* Finding SNVs whose mean frequency differs most by some level (e.g. diet)
* Finding SNVs with particularly high variance relative to some level

Identifying potential anchors using some strategies is just step 1. Step 2 involves sorting through what comes out and actually finding SNVs that have some behavior we might be interested in (e.g. looks like there's a sweep between samples etc.)

Michael's ideas:
- subset by mouse, and see if all SNVs always fall into the same clusters
- Identify SNVs that are outside of 2 SDs from the cluster mean in ANY sample
    - maybe further filter out sites if they don't have some number of sites that are traveling under tight linkage with the anchor SNV

## Method 1: look for frequency differences across levels

In [None]:
## First, pull out the SNV frequency matrices again, this time with different filtering if desired.
## poly_cov_frac=1/2 keeps only SNVs covered and polymorphic in more than half of the samples;
## the other arguments keep return_FAD's defaults.
Fs, Ass, Dss = return_FAD(
    species,
    poly_cov_frac=1/2,
)


In [None]:
## Rank SNVs by the absolute difference in mean frequency between the control (C) and
## guar gum (G) diets, and keep the top_n.
top_n = 25
mean_C = Fs.T.xs("C",level="diet").mean()
mean_G = Fs.T.xs("G",level="diet").mean()
var_sites = (mean_C - mean_G).abs().dropna().sort_values(ascending=False).index[:top_n]

## alternative: frequency has high variance across mice
#var_sites = Fs.T.groupby("mouse_number").mean().var().sort_values(ascending=False)[:top_n].index

## Fs_sites is the dataframe of candidate evolutionary SNVs. Optionally polarize them all
## relative to the first sample (commented line), or apply your own polarization.
Fs_sites = Fs.loc[var_sites]
# Fs_sites.loc[Fs_sites.iloc[:,0] > 0.5] = 1 - Fs_sites.loc[Fs_sites.iloc[:,0] > 0.5]


In [None]:
## display the top-ranked candidate SNVs (top_n rows)
Fs_sites

In [None]:
## Trial and error produces a potentially interesting site. Take the top-ranked SNV and turn
## its per-sample frequencies into a table indexed by "<mouse_number>, <region>".
site_table = Fs.loc[var_sites[0]].reset_index()
site_table["label"] = site_table["mouse_number"] + ", " + site_table["region"]
Fs_sites = site_table.set_index("label")


In [None]:
## frequency of the selected SNV in every sample (rows indexed by "<mouse_number>, <region>")
site_ax = Fs_sites.plot(color="grey", legend=None)
site_ax.set_ylabel("SNV frequency");

## Method 2: identify SNVs that fall outside of 2 SDs from the cluster mean

(TODO: not yet implemented — the cell below only reloads the data.)

In [None]:
## reload the tables with the stricter polymorphism filter
## (a site must be covered and polymorphic in more than half of the samples)
Fs, Ass, Dss = return_FAD(species, poly_cov_frac=1/2)

## Identifying evolutionary SNVs using anchors
If we have identified an SNV that looks like a promising evolutionary target, we can find all SNVs that are clustered with it using the following cells.

In [None]:
## Anchor SNV: (contig, site_pos) of a site whose trajectory looks evolutionarily interesting.
## Empirically found anchors for other species (swap into anchor_site to use):
##   B. vulgatus: ('NC_009614', 663753), ('NC_009614', 1854143)
##   B. wex:      ("AXVN01000080", 11540)
##   B. uni:      ('NZ_DS362247', 176768)
## NOTE(review): the original "C. bac" section repeated the B. uni site in a comment; only the
## site below was active.
anchor_site = ('FP929062', 2094823)

## Fs/Dss rows are indexed (contig, site_pos, all_site_pos), so a (contig, site_pos) key is partial
## and returns a 1-row DataFrame. Flatten to 1-D per-sample arrays for the distance calculation.
ancD = Dss.loc[anchor_site].to_numpy(dtype=float).ravel()
ancF = Fs.loc[anchor_site].to_numpy(dtype=float).ravel()

In [None]:
## can play around w/ max_d again here to find particularly tightly linked sites
## NOTE: this overwrites the global max_d (3.5) used for strain clustering above; re-running
## the clustering cells after this one will use the tighter threshold.
max_d = 2.5

In [None]:
## Distance from every SNV trajectory to the anchor (ancD/ancF) under both polarizations.
## Uses the calc_dis formula, averaged over samples, computed for all rows at once.
dss = Dss.values
fss = Fs.values
fss_rev = 1 - fss

anc_d = np.asarray(ancD, dtype=float).ravel()
anc_f = np.asarray(ancF, dtype=float).ravel()
depth = anc_d + dss

with np.errstate(divide="ignore", invalid="ignore"):
    disAnc_forward = 2*np.nanmean(depth*((anc_f - fss)**2)/((anc_f + fss)*(1 - anc_f + 1 - fss)), axis=1)
    disAnc_backward = 2*np.nanmean(depth*((anc_f - fss_rev)**2)/((anc_f + fss_rev)*(1 - anc_f + 1 - fss_rev)), axis=1)

## a site's distance is its better polarization (NaN only if both are undefined)
disAnc = np.fmin(disAnc_forward, disAnc_backward)
var_sites = Fs.index[disAnc < max_d]


In [None]:
## Very simple polarization w.r.t. a given sample: flip every site whose frequency in column
## samp_to_pol is above 0.5, so all sites are the minor allele there (NaN there -> unflipped).
samp_to_pol = 5

Fs_sites = Fs.loc[var_sites].copy()   # explicit copy: the in-place flip below must not touch Fs
to_flip = Fs_sites.iloc[:, samp_to_pol] > 0.5
Fs_sites.loc[to_flip] = 1 - Fs_sites.loc[to_flip]


In [None]:
fig,ax = plt.subplots(figsize=(16,8))

fig.suptitle(fu.get_pretty_species_name(species),size=30)

key_to_sort = "mouse_number"
secondary_key = "region"

## one background colour per group of key_to_sort; one line colour per strain (sized by the
## number of strains, not mice, so strains never share the colormap's last colour)
cmap = get_cmap(len(all_sample_dics[key_to_sort]))
cmap_clus = get_cmap(max(len(final_f), 1),name="Set3")

## candidate SNVs (red): one line plus markers per site; samples on the x-axis
Fs_sites_plot = reorder_sort(Fs_sites.T,key_to_sort).sort_index(key=lambda x: x.map(order_dict))

ax.plot(Fs_sites_plot.values,color="red",alpha=.5,zorder=500);

for site_col in range(Fs_sites_plot.shape[1]):
    ax.scatter(range(Fs_sites_plot.shape[0]),Fs_sites_plot.values[:,site_col],color="red",alpha=.5,zorder=500,s=90);

for i,f in enumerate(final_f):

    ## strain centroid (thick) and up to 10,000 member SNV trajectories (faint)
    ff = reorder_sort(f,key_to_sort).sort_index(key=lambda x: x.map(order_dict))
    ff_c = reorder_sort(final_clusters[i].T,key_to_sort).T.sort_index(key=lambda x: x.map(order_dict),axis=1)

    ax.plot(ff.values,zorder=100,lw=6,color=cmap_clus(i),label=f"Strain {i+1}");
    ax.plot(ff.values,zorder=80,lw=7,color="k");
    ## fixed random_state so the faint background trajectories are identical on every run
    ax.plot(ff_c.sample(min(ff_c.shape[0],10000),random_state=0).T.values,color=cmap_clus(i),alpha=.01)

major_x = []
minor_x = []
labels = []

## secondary-key label for every plotted sample position
second_xlabels = list(ff.index.get_level_values(secondary_key))

for i, (key, item) in enumerate(all_sample_dics[key_to_sort].items()):

    ## shade the group; one vertical guide and secondary label per sample
    xmin = item.min() 
    xmax = item.max()
    ax.axvspan(xmin - .1,xmax+.1,alpha=.2,color=cmap(i))

    for e in item:
        ax.axvline(e,color="k",zorder=0,alpha=.5)
        ax.text(e, -0.1, second_xlabels[e], ha='center', clip_on=False,size=15)

    ## group label centred under the group (major ticks carry labels, minor ticks draw the marks)
    if xmin != xmax:
        major_x.extend([xmin,(xmax + xmin)/2,xmax])
        minor_x.append((xmax + xmin)/2)
        labels.extend(["",key,""])
    else:
        major_x.append(xmin)
        minor_x.append(xmax)
        labels.append(key)

    ## group separators below the axis (x in data units, y in axes units)
    ax.vlines(item[0] - 0.5, 0, -0.15, color='black', lw=0.8, clip_on=False, transform=ax.get_xaxis_transform())
    ax.vlines(item[-1] + 0.5, 0, -0.15, color='black', lw=0.8, clip_on=False, transform=ax.get_xaxis_transform())

ax.set_xticks(major_x)
ax.set_xticks(minor_x, minor = True)
ax.set_xticklabels(labels);

ax.axhline(0,color="grey")
ax.axhline(1,color="grey")

ax.tick_params(axis = 'x', which = 'major', length=0,labelsize = 20,pad=45)
ax.tick_params(axis = 'x', which = 'minor', length = 10,labelsize = 0)

ax.set_ylabel("SNV frequency",size=20)
ax.set_ylim([-0.05,1.05]);

fig.legend(prop={"size":20});

## to save: fig.savefig(f"figures/strains/{species}_strains")


## Plotting locations of evolutionary SNVs
Same procedure as above, but instead of plotting only the SNVs that segregate between strains relative to one another, this highlights the evolutionary SNVs against the background of a single strain.

In [None]:
fig,axs = plt.subplots(2,1,figsize=(18,12))
axs = axs.ravel()
ax1, ax2 = axs

## Sample column (of all_clus_pol[s1] and Fs_sites) shown in each panel, top to bottom,
## and the strain cluster s1 whose sites form the grey background.
panel_samples = (3, 0)
s1 = 0
## Optional zoom window in bp applied to both panels, e.g. (11678 - 2500, 11678 + 2500).
## None keeps matplotlib's autoscaled x range.
x_window = None

# One loop instead of two copy-pasted panel blocks; `ax_panel` avoids clobbering the global `ax`.
for ax_panel, i in zip((ax1, ax2), panel_samples):
    ax_panel.grid(True)
    # grey: sites of strain cluster s1; tomato/yellow: the focal (evolutionary) SNVs in Fs_sites
    ax_panel.scatter(all_clus_pol[s1].index.get_level_values("all_site_pos"),all_clus_pol[s1].iloc[:,i],alpha=.3,color="grey")
    ax_panel.scatter(Fs_sites.index.get_level_values("all_site_pos"),Fs_sites.iloc[:,i],color="tomato",s=220,edgecolor="yellow")
    ax_panel.set_ylim([-0.05,1.05])
    if x_window is not None:
        ax_panel.set_xlim(x_window)

ax2.set_xlabel("Genomic position (bp)",size=30)

# shared y label placed just left of the figure's own area
fig.text(-0.05,0.44,"SNV frequency",size=30,rotation=90)

# Optional panel annotations, e.g.:
# fig.text(.775,.575,"Mouse 2",size=45,rotation=0,color="darkslategrey")
# fig.text(.775,.0925,"Mouse 1",size=45,rotation=0,color="darkslategrey")

fig.tight_layout()

## Evolutionary SNVs: functions
Next, look up the functions of the genes that carry potentially interesting (evolutionary) SNVs.

In [None]:
## Gene descriptions and per-site SNP info, both written by the write_snps_info.py script.
sites_root = "/u/project/ngarud/rwolff/mouse_sites"

gene_descriptions = pd.read_pickle(f"{sites_root}/gene_descriptions/{species}_gene_descriptions.pkl")

# move gene_id out of the index into a column, leaving only the site-identifying index levels
snps_info = pd.read_pickle(f"{sites_root}/snps_info/{species}_snps_info.pkl").reset_index("gene_id")




In [None]:
## snps_info is not indexed by all_site_pos (the overall position in the genome), so drop that
## level from var_sites before using it to look sites up in snps_info
var_sites_xref = var_sites.droplevel(level="all_site_pos")

In [None]:
## Which genes are the evolutionarily interesting SNVs in?
## Shows descriptions of the 20 genes holding the most such SNVs.
n_snvs_by_gene = (
    snps_info.loc[var_sites_xref]
    .groupby("gene_id")
    .size()
    .sort_values(ascending=False)
)
gene_descriptions.loc[n_snvs_by_gene.index].head(20)

In [None]:
## How many evolutionary SNVs are in each gene? (most first)
evo_snv_info = snps_info.loc[var_sites_xref]
evo_snv_info.groupby("gene_id").size().sort_values(ascending=False)


In [None]:
## Is each evolutionary SNV synonymous or non-synonymous?
## site_type "1D" is reported as non-syn and every other site type as syn.
# snps_info is not indexed by all_site_pos, so iterate var_sites_xref; var_sites still carries
# that extra level and would not match snps_info's index.
for site in var_sites_xref:
    site_info = snps_info.loc[[site]]
    gene = site_info.gene_id.values[0]
    effect = "non-syn" if site_info.site_type.values[0] == "1D" else "syn"
    # same "gene: (effect) description" layout for both outcomes
    print(f"{gene}: ({effect}) {gene_descriptions.loc[gene]}\n")

In [None]:
## Count the evolutionary SNVs by site type. Sites are looked up via var_sites_xref because
## snps_info is not indexed by all_site_pos.
np.unique(snps_info.loc[var_sites_xref].site_type,return_counts=True)

## Experimental: comparing w/ HMP haplotypes
It could be interesting to check whether the strains we see in the mice have close relatives among the HMP (Human Microbiome Project) haplotypes. My preliminary testing suggests they may, at least for some species, but this needs more work. If so, where do the mouse strains and their HMP relatives differ at the evolutionary SNVs? Can we spot mouse-specific adaptations by comparing with HMP?

We could look at whether the strains appear in HMP, at the linkage or allele frequencies of the evolutionary SNVs, or at other features. Essentially, this covers any question where a broader cohort of controls would help.

In [None]:
## Load the HMP haplotypes for this species. read_haplotypes is not defined in this section.
## The cells below drop the "gene_id" and "site_type" index levels from dfA and cluster its
## columns, so presumably rows are sites and columns are HMP haplotypes/samples — TODO confirm.
dfA = read_haplotypes(species)

In [None]:
## Collect (1 - frequency) of the sites in up to n_genes genes that contain evolutionary SNVs,
## one transposed table per gene.
## NOTE(review): 1 - F flips the allele polarisation, presumably to match dfA's allele coding — confirm.
n_genes = 15
F_rec = []

# Look the sites up via var_sites_xref (snps_info has no all_site_pos level).
# pd.unique keeps first-appearance order. list(set(...)) used hash order, which for string gene
# ids changes between Python sessions, so a different set of genes could be picked on each run.
genes_with_evo_snvs = pd.unique(snps_info.loc[var_sites_xref, "gene_id"])[:n_genes]

# loop-invariant: drop all_site_pos once so Fs aligns with snps_info's site index
Fs_xref = Fs.droplevel("all_site_pos")

for gene_id in genes_with_evo_snvs:
    # boolean mask over snps_info, aligned to Fs_xref's rows by index label
    gene_sites = 1 - Fs_xref.loc[snps_info.gene_id == gene_id].T
    F_rec.append(gene_sites)

In [None]:
## Put the per-gene tables side by side: one row per row of each F_rec table (presumably
## samples) and one column per focal site. Stacking them with the default axis=0 repeated every
## row once per gene and NaN-padded the other genes' sites. The column set is the same either
## way, because each site belongs to exactly one gene.
focal_sites = pd.concat(F_rec, axis=1)

In [None]:
## HMP haplotypes indexed by site only (gene_id and site_type levels dropped), restricted to the
## focal sites that dfA actually contains, kept in focal_sites' column order.
haplo_by_site = dfA.droplevel(["gene_id","site_type"])
shared_sites = [site for site in focal_sites.columns if site in haplo_by_site.index]
dfA_focal = haplo_by_site.loc[shared_sites]

In [None]:
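# Alternative single-gene view: restrict dfA to one gene (`gene_id`, e.g. the last one from the
# collection loop above) instead of all focal genes, then uncomment both lines.
# dfA_focal =  dfA.xs(gene_id,level="gene_id").droplevel("site_type")


# df_gene = dfA_focal.loc[[g for g in focal_sites.columns if g in dfA_focal.index]]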

In [None]:
## Work with all focal genes at once. df_gene is bound to the same object as dfA_focal (no copy);
## the cells below only read it, and the plotted table is built from df_gene.copy().
df_gene = dfA_focal

In [None]:
## Cluster the columns of df_gene by pairwise Hamming distance and compute a column order that
## places similar haplotypes next to each other in the heatmap below.
D = squareform(pdist(df_gene.T,metric="hamming"))
D = pd.DataFrame(D,index=df_gene.columns,columns=df_gene.columns)

## Columns closer than this percentile of the non-zero pairwise distances count as neighbours.
thresh_percentile = 20
pairwise_d = take_triu(D.values)
nonzero_d = pairwise_d[pairwise_d > 0]
# If all columns are identical there are no non-zero distances, and np.percentile raises on an
# empty array. Use an infinite threshold instead, so everything falls into a single group.
thresh = np.percentile(nonzero_d, thresh_percentile) if nonzero_d.size else np.inf


def _greedy_groups(dist, cutoff):
    """Greedily partition the labels of a square distance frame into neighbour groups.

    Repeatedly takes the remaining label with the most remaining neighbours (distance < cutoff)
    as a centroid, then removes it together with those neighbours. Requires cutoff > 0: with the
    zero diagonal of a squareform distance matrix, every centroid then lies in its own group, so
    at least one label is removed per pass and the loop terminates.

    Returns (groups, centroids), where groups[k] lists the members of group k sorted by their
    distance to centroids[k].
    """
    groups, centroids = [], []
    close = dist < cutoff
    while close.shape[0] > 0:
        centroid = close.sum().sort_values(ascending=False).index[0]
        row = close.loc[centroid]
        members = dist.loc[centroid, row.index[row.values]].sort_values().index
        groups.append(list(members))
        centroids.append(centroid)
        close = close.drop(members, axis=0).drop(members, axis=1)
    return groups, centroids


groups, group_centroids = _greedy_groups(D, thresh)

## Order the groups by their centroid's distance to the first centroid, then flatten them.
## A comprehension is used because np.array on ragged groups raises on NumPy >= 1.24, and
## sum(..., []) is quadratic in the number of groups.
centroid_d = D.loc[group_centroids[0], group_centroids].to_numpy()
order = [label for g in np.argsort(centroid_d) for label in groups[g]]

## Multiply the values at sites whose snps_info site_type is not "1D" by 4, so they pick a
## different hap_cmap entry than the 1D sites.
## NOTE(review): assumes haplotype entries are 0/1 (1 -> 4 at non-1D sites) — confirm.
not_1D = snps_info.loc[df_gene.index, "site_type"] != "1D"
dfplot = df_gene.copy()
dfplot.loc[not_1D] = 4*dfplot.loc[not_1D]
dfplot = dfplot[order]

In [None]:
fig,ax = plt.subplots(figsize=(20,14))

# Pin the colour scale to hap_cmap's full index range (0 .. hap_cmap.N - 1). Without vmin/vmax,
# seaborn rescales to the data's own min/max. If no cell equals 4, a value of 1 would then be
# drawn in the last colour (blue) rather than red.
sns.heatmap(dfplot.T,cmap=hap_cmap,vmin=0,vmax=hap_cmap.N - 1,ax=ax,linecolor='k',cbar=False)

# site labels along x are hidden; one heatmap row per (reordered) column of dfplot
ax.set_xticklabels([]);

# single-gene view: ax.set_title(gene_id,size=60,color="tomato")


In [None]:
## Annotations of the focal sites, ordered by their mean value across df_gene's columns (highest first).
sites_by_mean = df_gene.T.mean().sort_values(ascending=False).index
snps_info.loc[sites_by_mean]