In [1]:
# Setting working directory
import sys
sys.path.insert(0, "/u/project/ngarud/michaelw/Diversity-Along-Gut/ConventionalMouse/scripts/postprocessing/postprocessing_scripts/")
sys.path.insert(0, "/u/project/ngarud/michaelw/Diversity-Along-Gut/ConventionalMouse/scripts/strain_phasing/")

# Normal Libraries
import pandas as pd
import numpy as np
import pickle
import scipy.stats
import random as rand
from random import randint,sample,choices
from math import log
import os 
from datetime import datetime

# Plotting libraries
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
from matplotlib.lines import Line2D
import figure_utils as fu
from matplotlib import colormaps
plasma_cmap = colormaps.get_cmap('plasma')

# config
import config

# predefined functions
from strain_phasing_functions import *
from parse_midas_data import *
import diversity_utils
import species_utils

In [2]:
# Read the species IDs listed in species_snps.txt (one per line, whitespace stripped).
species_path = f"{config.metadata_directory}species_snps.txt"
species_list = []
with open(species_path, 'r') as file:
    for line in file:
        species_list.append(line.strip())

# parse_species_code_maps returns several objects; only the first is used here
# (it is indexed by species ID further down to build display names).
species_code_maps = species_utils.parse_species_code_maps()
species_code_map = species_code_maps[0]

sample_metadata_map = parse_sample_metadata_map()


In [3]:
# Species ID to analyse; it is used below to build the cluster directory path.
species = "229547"

## Directories

In [4]:
# Resolve every path first, then create whichever directories are missing.
strainfinder_dir = "%sinput" % (config.strain_phasing_directory)

# Raw clusters live under <strain_phasing_directory>/strain_clusters/<species>/
raw_cluster_path = "%s%s" % (config.strain_phasing_directory, "strain_clusters/")
species_raw_cluster_dir = "%s%s/" % (raw_cluster_path, species)
strain_phasing_figures_dir = "%s%s" % (config.figure_directory, "strain_phasing/")

if os.path.exists(species_raw_cluster_dir):
    print("Cluster directory already exists.")
else:
    os.makedirs(species_raw_cluster_dir)
    print("Cluster directory created successfully!")

if os.path.exists(strain_phasing_figures_dir):
    print("Figure directory already exists.")
else:
    os.makedirs(strain_phasing_figures_dir)
    print("Figure directory created successfully!")


Cluster directory already exists.
Figure directory already exists.


## Selecting host

In [5]:
# Survey every species in species_list: among its high-coverage samples, count
# how many are returned by calculate_haploid_samples (labelled "QP") and how
# many are not (labelled "Not QP"). Species with no high-coverage sample are skipped.
#
# The loop variable is deliberately NOT called `species`: that name holds the
# species selected for phasing, and reusing it here would silently replace the
# selection with the last entry of species_list.
qp_array = []
for species_id in species_list:
    highcoverage_samples = diversity_utils.calculate_highcoverage_samples(species_id)
    if len(highcoverage_samples) == 0:
        continue
    haploid_samples = diversity_utils.calculate_haploid_samples(species_id, quick_and_dirty=True) #  quick_and_dirty=True
    # Build the lookup set once per species instead of once per sample.
    haploid_sample_set = set(haploid_samples)
    non_haploid_samples = [s for s in highcoverage_samples if s not in haploid_sample_set]
    number_of_haploid_samples = len(haploid_samples)
    number_of_non_haploid_samples = len(non_haploid_samples)
    qp_array.append([species_id, number_of_haploid_samples, "QP"])
    qp_array.append([species_id, number_of_non_haploid_samples, "Not QP"])

# Long-format table: two rows per species (one "QP", one "Not QP").
qp_df = pd.DataFrame(data = qp_array, columns = ["species", "sample_count", "QP"])
qp_df['sample_count'] = qp_df["sample_count"].astype(float)
qp_df['species_name'] = [species_code_map[species_id] for species_id in qp_df['species']]

In [6]:
# Species ranked by their number of samples labelled "Not QP", largest first.
qp_df[qp_df['QP'] == "Not QP"].sort_values("sample_count", ascending = False)

Unnamed: 0,species,sample_count,QP,species_name
33,207693,6.0,Not QP,f__Oscillospiraceae (207693)
11,100555,4.0,Not QP,g__Lawsonibacter (100555)
51,214603,4.0,Not QP,g__CAG-95 (214603)
17,203686,4.0,Not QP,g__UBA3282 (203686)
97,261672,2.0,Not QP,g__Angelakisella (261672)
...,...,...,...,...
63,217378,0.0,Not QP,g__1XD42-69 (217378)
67,229722,0.0,Not QP,g__Ruminiclostridium_E
69,231109,0.0,Not QP,g__CAG-81
71,231118,0.0,Not QP,g__Clostridium_Q


In [7]:
# Total sample count per species ("QP" + "Not QP" rows summed), largest first.
qp_df.groupby("species")['sample_count'].sum().sort_values(ascending = False)

species
205567    37.0
263040    27.0
234341    18.0
261649    18.0
213999    18.0
          ... 
231349     1.0
217240     1.0
261755     1.0
217378     1.0
100111     1.0
Name: sample_count, Length: 64, dtype: float64

In [8]:
# Species ID to phase in the Analysis section below.
species = "229547"


## Analysis

In [26]:
## Meta-parameters: experiment with these—no hard and fast rules!

## minimum number of SNVs which need to be clustered together in order to qualify as a "strain"
min_cluster_size = 1000

## minimum fraction of sites which pass our coverage threshold which must be in a cluster in order for it to qualify 
## as a strain
min_cluster_fraction = 1/10

## For computational efficiency, we can downsample the SNVs we actually perform strain phasing on
max_num_snvs = 20000

## distance threshold to be considered linked: a pair of SNVs (or a site and a cluster centroid) counts as
## linked when its distance is strictly below max_d, so a lower value is stricter
## previously tried value:
# max_d = 3.5
# NOTE(review): the inline note below says 4.5 for species 229547 but the value set is 4.75 — confirm which is intended
max_d = 4.75 # 4.5 for species 229547

## minimum coverage to consider allele frequency at a site for purposes of clustering
min_coverage = 10 

## minimum average sample coverage at polymorphic sites (e.g. sites in the A/D matrices)
min_sample_coverage = 5


## polymorphic & covered fraction: what percentage of samples does a site need 
## with coverage > min_coverage and polymorphic to be included in downstream analyses? 
poly_cov_frac = 1/5 #

## Number of clusters to calculate (upper bound on cluster-extraction attempts)
n_clusters = 100

#Minimum number of snvs per sample; samples with fewer non-missing SNVs are filtered out
min_num_snvs_per_sample = 100

In [27]:
# Load the three site-by-sample tables for `species` with the thresholds chosen above.
# NOTE(review): presumably Fs = allele frequencies, Ass = alternate-allele read counts and
# Dss = read depths (the cells below compute frequencies as Ass/Dss) — confirm against return_FAD.
Fs,Ass,Dss = return_FAD(species, min_coverage=min_coverage, 
                        min_sample_coverage=min_sample_coverage, 
                        poly_cov_frac = poly_cov_frac, 
                        calculate_poly_cov_frac=False, 
                        read_support = True) 


Using a read support of 4 for each polymorphism.


In [28]:
# Keep only the samples (columns) that retain at least `min_num_snvs_per_sample`
# non-missing allele frequencies, and apply the same column mask to all three tables.
sample_with_adequate_snv_count = Fs.notna().sum() >= min_num_snvs_per_sample

Fs, Ass, Dss = (
    table.loc[:, sample_with_adequate_snv_count] for table in (Fs, Ass, Dss)
)

In [29]:
%%time


# Per-site, per-sample frequency A/D. Where depth is 0 the boolean (D == 0) adds 1 to the
# denominator, so the ratio evaluates to A/1 instead of dividing by zero.
fss = Ass.values/(Dss.values + (Dss.values == 0)) #This is so it doesn't produce a na (division by 0)

# NOTE(review): cluster_As, cluster_Ds and cluster_fs (full, un-subsampled matrices) are not
# read again in this cell.
cluster_As = Ass.values
cluster_Ds = Dss.values
cluster_fs = cluster_As/(cluster_Ds + (cluster_Ds == 0))

## for compatibility in case of threshold number of SNVs
num = min(max_num_snvs,Fs.shape[0])

# Labels of the `num` sites with the highest mean depth across samples.
i_list = Dss.T.mean().sort_values(ascending=False).index[:num]

sys.stderr.write("Processing %s SNVs" % num)

## simply shuffles indices if no threshold is specified
#i_list = sample(range(Fs.shape[0]),num)
i_list_idx = Fs.loc[i_list].index

# Restrict all three tables to the selected sites.
Ass_sub = Ass.loc[i_list_idx]
Dss_sub = Dss.loc[i_list_idx]
Fs_sub = Fs.loc[i_list_idx]

# Same zero-depth-safe A/D ratio as above, on the subsampled sites.
fss_sub = Ass_sub.values/(Dss_sub.values + (Dss_sub.values == 0))

# NOTE(review): cluster_As_sub and cluster_fs_sub are not read again in this cell;
# only cluster_Ds_sub is passed on.
cluster_As_sub = Ass_sub.values
cluster_Ds_sub = Dss_sub.values
cluster_fs_sub = cluster_As_sub/(cluster_Ds_sub + (cluster_Ds_sub == 0))

# Two num x num matrices from D_mat_fun1 / D_mat_fun2, each given a fresh zero matrix.
# NOTE(review): presumably these are the site-to-site distances under the two possible
# relative polarizations — confirm in strain_phasing_functions.
D_mat = np.zeros([num,num])
D_mat_1 = D_mat_fun1(num,fss_sub,cluster_Ds_sub,D_mat)
D_mat = np.zeros([num,num]) 
D_mat_2 = D_mat_fun2(num,fss_sub,cluster_Ds_sub,D_mat)

# Element-wise minimum of the two matrices (np.fmin ignores a NaN when the other value is a number).
D_mat = np.fmin(D_mat_1,D_mat_2) #I believe this is filling in the minimum of the two polarizations
D_mat = symmetrize(D_mat)

# Label the two matrices by site so they can be looked up by index later.
D_mat_1 = pd.DataFrame(D_mat_1,index=Fs_sub.index,columns=Fs_sub.index)
D_mat_2 = pd.DataFrame(D_mat_2,index=Fs_sub.index,columns=Fs_sub.index)

# Boolean matrix: True where a pair of sites is closer than max_d.
D_mat_close = pd.DataFrame(D_mat < max_d) 

D_mat_close.index = Fs_sub.index
D_mat_close.columns = Fs_sub.index


Processing 14897 SNVs

CPU times: user 27 s, sys: 3.25 s, total: 30.2 s
Wall time: 30.1 s


In [30]:
## extracts up to `n_clusters` clusters
## in practice all SNVs should fall into one of a fairly small number of clusters
## really should re-write this with a while loop but this works for now
## the idea is that we exhaust all clusters—there should only be a small number of them ultimately

###Idea with while loop:
##### While there are still variants out there, have it try to be clusterings

all_clus_pol = []
# NOTE(review): all_clus_idx is initialised here but never appended to in this cell.
all_clus_idx = []
all_clus_A = []
all_clus_D = []

all_clus_F = []

# Failed attempts are still skipped (best effort, as before), but they are now counted and
# summarised instead of being silently discarded, and only `Exception` is caught so that
# an interrupt (Ctrl-C / kernel interrupt) stops the loop.
failed_attempts = 0
last_error = None

for i in range(n_clusters):
    
    try:

        clus,clus_idxs = return_clus(D_mat_close,Fs_sub)
#         clus,clus_idxs = return_clus(D_mat_close,Fs_sub, co_cluster_pct=0.5) #Finding points that cluster with 25% other points. That's a cluster.
                                                            #We would modify this function to get smaller clusters...
        clus_pol = polarize_clus(clus,clus_idxs,D_mat_1,D_mat_2)
        clus_pol.index = clus_idxs
        # Remove the extracted sites so the next attempt works on what is left.
        D_mat_close = drop_clus_idxs(D_mat_close,clus_idxs)

        # Keep the cluster only if it is large both in absolute terms and relative to all sites in Fs.
        if clus_pol.shape[0] > min_cluster_size and clus_pol.shape[0] > Fs.shape[0]*min_cluster_fraction:

            all_clus_D.append(Dss.loc[clus.index].mean().values)
            all_clus_pol.append(clus_pol)
            all_clus_A.append(clus_pol.mean()*all_clus_D[-1])
            all_clus_F.append(clus_pol.mean())

            print(clus_pol.shape[0])

    except Exception as err:
        failed_attempts += 1
        last_error = err

if failed_attempts:
    sys.stderr.write("%s of %s cluster extraction attempts raised an exception and were skipped (last: %r)\n"
                     % (failed_attempts, n_clusters, last_error))


10420


In [31]:
## now, using each cluster's centroid (mean depth and mean polarized frequency),
## find all other sites (not just limited to the subsampled ones) which are consistent w/ being linked to it

final_clusters = []

all_aligned_sites = []
# Set mirror of all_aligned_sites: O(1) membership tests instead of scanning the growing list.
claimed_sites = set()

# These arrays do not depend on the cluster, so build them once.
dss = Dss.values
fss = Fs.values
n_sites = Dss.shape[0]

for i in range(len(all_clus_D)):
    
    sys.stderr.write(f'\n\nCluster {i+1}\n')
    ancD = all_clus_D[i]
    ancF = all_clus_F[i]
    
    # Distance of every site to the centroid, for the site as-is (forward) and flipped to 1 - f (backward).
    disAnc_forward = []
    disAnc_backward = []

    for j in range(n_sites):
        disAnc_forward.append(calc_dis(ancD,dss[j],ancF,fss[j]))
        disAnc_backward.append(calc_dis(ancD,dss[j],ancF,1-fss[j]))
        if j % 1000 == 0:
            sys.stderr.write(f"\n\t{np.around(100*j/n_sites,3)}% finished")
    
    # A site is aligned to the cluster if its better polarization is closer than max_d.
    disAnc = [min(els) for els in zip(disAnc_forward, disAnc_backward)]
    disAnc = np.array(disAnc)
    aligned_sites = Fs.loc[disAnc < max_d].index
    # Column 0 = forward distance, column 1 = backward distance; `pols` is True where backward is the smaller.
    f_dist =  pd.DataFrame(np.array([disAnc_forward,disAnc_backward]).T,index=Fs.index)
    pols = f_dist.T.idxmin() > 0
    
    # Each site is assigned to the first cluster that claims it.
    aligned_sites = [a for a in aligned_sites if a not in claimed_sites]
    
    pols = pols.loc[aligned_sites]
    re_polarize = pols.loc[pols].index
    
    all_aligned_sites.extend(aligned_sites)
    claimed_sites.update(aligned_sites)
    
    # Explicit copy: the flip below must never write through to Fs.
    Fs_cluster = Fs.loc[aligned_sites].copy()
    
    Fs_cluster.loc[re_polarize] = 1 - Fs_cluster.loc[re_polarize]
        
    final_clusters.append(Fs_cluster)
    



Cluster 1

	0.0% finished
	6.713% finished
	13.426% finished
	20.138% finished
	26.851% finished
	33.564% finished
	40.277% finished
	46.989% finished
	53.702% finished
	60.415% finished
	67.128% finished
	73.84% finished
	80.553% finished
	87.266% finished
	93.979% finished

In [32]:
#SAVING RAW FILE
# Pickle the list of per-cluster frequency tables to <species>_RawCluster.pckl.
species_raw_cluster_path = "%s%s%s" % (species_raw_cluster_dir, species, "_RawCluster.pckl")

# The context manager closes the file even if pickle.dump raises.
with open(species_raw_cluster_path, "wb") as pickle_object:
    pickle.dump(final_clusters, pickle_object)

In [None]:
# #LOADING THE RAW FILE
# species_raw_cluster_path = "%s%s%s" % (species_raw_cluster_dir, species, "_RawCluster.pckl")


# final_clusters = pd.read_pickle(species_raw_cluster_path)

## Creating polarized clusters

Once clusters have been identified and internally polarized, they need to be polarized relative to one another. In the best case, the sum of strain frequencies will be 1. 

In [None]:
# Polarize the clusters relative to one another so that, per sample, the cluster
# centroids sum as close to 1 as possible.
# NOTE(review): this cell mutates final_clusters in place (append / flip), so re-running it
# without re-running the cells above changes the result.
if len(final_clusters) > 1:
    sys.stderr.write("More than 1 strains detected.\n")
## If only a single cluster is detected, add a second "cluster" which is simply 1 minus the allele frequencies
## in the first cluster
## aids in visualization for people not familiar with this kind of clustering
if len(final_clusters) == 1:
    final_clusters.append(1-final_clusters[0])

## add cluster centroids (per-sample mean frequency over each cluster's sites)
final_f = []
for cluster in final_clusters:
    final_f.append(cluster.mean())
df_final_f = pd.DataFrame(final_f)

## now, polarize clusters so that the sum of squared deviations of the summed centroids from 1 is minimized
## the idea here is that accurate strain frequencies should sum to 1
# NOTE(review): `polarize` is assigned but never read in this cell.
polarize = True

pol_d2 = {}

for i in range(df_final_f.shape[0]):
    df_final_f_temp = df_final_f.copy() #Makes a copy of the centroids
    df_final_f_temp.iloc[i] = 1 - df_final_f_temp.iloc[i] #gets the polarized version of ONE of the centroids.
    pol_d2[i] =  ((1 - df_final_f_temp.sum())**2).sum()   #Get the across centroids for all samples (should be close to 1), 
                                                                #subtract this from 1, and square. Sum all those values
                                                                #Ideally, this value is really close to 0. 
                                                                #Add this value to the dictionary.

    # NOTE(review): as indented, the conversion and the check below run on EVERY iteration, so
    # pol_d2 is a Series from the second iteration on and still holds scores computed before any
    # earlier flip. If the intent was "score every single flip, then apply the best one", these
    # lines belong after the loop — confirm.
    pol_d2 = pd.Series(pol_d2)                                #Make the dictionary a series 

    if pol_d2.min() < ((1 - df_final_f.sum())**2).sum(): #If any of the above repolarizations actually made the overall sum of centroids closer to 1, repolarize.
        clus_to_re_pol = pol_d2.idxmin()
        final_f[clus_to_re_pol] = 1 - final_f[clus_to_re_pol]
        final_clusters[clus_to_re_pol] = 1 - final_clusters[clus_to_re_pol]
        df_final_f = pd.DataFrame(final_f)  


    

### Filtering out samples in which each cluster does not have adequate snvs

In [None]:
good_indices = []

# One boolean vector per cluster: True for samples (columns) with more than
# `min_num_snvs_per_sample` non-missing SNV frequencies in that cluster.
per_cluster_pass = [
    (cluster.notna().sum(axis = 0) > min_num_snvs_per_sample).values
    for cluster in final_clusters
]

# A sample is kept only if it passes in every cluster.
for i, passes in enumerate(per_cluster_pass):
    good_samples = passes if i == 0 else good_samples & passes

# Apply the combined column mask to every cluster, its centroid and the site tables.
for i in range(len(final_clusters)):
    final_clusters[i] = final_clusters[i].loc[:, good_samples]
    final_f[i] = final_f[i].loc[good_samples]

Fs = Fs.loc[:, good_samples]
Ass = Ass.loc[:, good_samples]
Dss = Dss.loc[:, good_samples]


In [None]:
# Drop samples (columns) that are entirely NaN in the FIRST cluster; the same mask is then
# applied to every cluster, every centroid and Fs. The original author notes this step is redundant.
if len(final_clusters) > 0:
    mask = final_clusters[0].notna().any(axis = 0)
    for i in range(len(final_clusters)):
        final_clusters[i] = final_clusters[i].loc[:, mask]
        final_f[i] = final_f[i].loc[mask]
    Fs = Fs.loc[:, mask]

In [None]:

# #SAVING RAW FILE
# species_centroid_cluster_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_ClusterCentroid.pckl")
# species_polarized_cluster_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_PolarizedCluster.pckl")
# final_Fs_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_final_Fs.pckl")

# pickle_object = open(species_centroid_cluster_path, "wb")
# pickle.dump(final_f, pickle_object)
# pickle_object.close()

# pickle_object = open(species_polarized_cluster_path, "wb")
# pickle.dump(final_clusters, pickle_object)
# pickle_object.close()

# pickle_object = open(final_Fs_path, "wb")
# pickle.dump(Fs, pickle_object)
# pickle_object.close()

In [None]:
## Sanity check on the chosen polarization: per sample, the cluster centroids
## should add up to roughly 1 (dashed reference line).
strain_frequency_sum = pd.DataFrame(final_f).sum()

fig,ax = plt.subplots(figsize=(12,8))
ax.plot(strain_frequency_sum.values, zorder=10, lw=3)
ax.set_ylim(0.5, 1.5)
ax.axhline(1, color="k", ls="--")
ax.set_ylabel("Sum of strain frequencies", size=20)



## Plotting strain frequencies

Strain frequencies can be plotted using a main key (e.g. Day) and a secondary key (e.g. Capsule), yielding a two-level identification of each sample

In [None]:
## More ordering utilities
def _sample_positions_by_level(frame, level, level_values):
    """Map each value of index level `level` to the row positions it occupies after reorder_sort.

    `frame` is sorted once with reorder_sort(frame, level); for every distinct entry of
    `level_values` the returned dict holds a 1-D array of the integer positions whose
    level value equals that entry.
    """
    # Sort once per level rather than once per key.
    ordered_level = reorder_sort(frame, level).index.get_level_values(level)
    return {value: np.argwhere(ordered_level == value).ravel() for value in set(level_values)}


# Distinct values of each metadata level carried by the sample index.
mnum = list(set(Fs.T.index.get_level_values("mouse_number")))
msite = list(set(Fs.T.index.get_level_values("region")))
mcage = list(set(Fs.T.index.get_level_values("cage")))

mnum_sample_dic = _sample_positions_by_level(Fs.T, "mouse_number", mnum)

msite_sample_dic = _sample_positions_by_level(Fs.T, "region", msite)

mcage_sample_dic = _sample_positions_by_level(Fs.T, "cage", mcage)

all_sample_dics = {"region":msite_sample_dic,"mouse_number":mnum_sample_dic,"cage":mcage_sample_dic}


In [None]:
# Plot configuration: samples are ordered by `key_to_sort`; `secondary_key` supplies the
# per-sample text label drawn under the axis.
key_to_sort = "cage"
secondary_key = "region"

# When `subset` is True, only the `subset_value` highest-depth SNVs of each cluster are
# drawn as background trajectories.
subset = True
subset_value = 5000

if subset:
    # Site labels ordered by median depth across samples, deepest first.
    high_coverage_snv_idxs = Dss.median(axis = 1).sort_values(ascending = False).index


# Background colours (RGBA), one appended per cage present in mcage_sample_dic.
cmap_cage = []
# Earlier palette, kept for reference:
# colors_library = [(1, 1, 0.8, 0.25), #creamy yellow
#                   (0.529, 0.808, 0.922, 0.25), #sky blue
#                   (0.529, 0.808, 0.922, 0.25)]
#                   #(0.529, 0.808, 0.722, 0.25), #green blue
#                   #(0.294, 0.427, 0.804, 0.15)] #purple blue
colors_library = [(0.9058823529411765, 0.8823529411764706, 0.9372549019607843, 2/3), #light purple
                  (0.6666666666666666, 0.596078431372549, 0.6627450980392157, 2/3), #dark purple
                  (0.9058823529411765, 0.8823529411764706, 0.9372549019607843, 2/3)]
                  #(0.529, 0.808, 0.722, 0.25), #green blue
                  #(0.294, 0.427, 0.804, 0.15)] #purple blue
if "Cage 1" in mcage_sample_dic:
#     cmap_cage.append(colors_library(0))
    cmap_cage.append(colors_library[0])
if "Cage 2" in mcage_sample_dic:
    # Cage 2 gets white rather than colors_library[1]:
    # cmap_cage.append(colors_library[1])
    cmap_cage.append((1,1,1,2/3))
    
if "Cage 3" in mcage_sample_dic:
    cmap_cage.append(colors_library[2])
    

# NOTE(review): cmap_clus is not used by the plotting code in this cell; manual_cmap_clus is.
cmap_clus = get_cmap(5,name="Set3")
# Earlier strain colours, kept for reference:
# manual_cmap_clus = [(0.396078431372549,0.2627450980392157,0.12941176470588237,1),
#                     (0.8235294117647058,0.7058823529411765,0.5490196078431373,1)]

# One RGBA colour per strain; only two entries are defined.
manual_cmap_clus = [(0.16470588235294117,0.4,0.4549019607843137,1),
                    (0.5019607843137255,0.5019607843137255,0.5019607843137255,1)]

# Draw, for each strain, its centroid trajectory (thick coloured line with a black outline)
# over faint per-SNV trajectories from the same cluster.
fig,ax = plt.subplots(figsize=(16,8))
fig.suptitle(fu.get_pretty_species_name(species),size=30)

# NOTE(review): manual_cmap_clus[i] has only two entries, so a third strain raises IndexError.
# NOTE(review): order_dict is not defined in the cells shown here — presumably it comes from
# one of the star imports; confirm.
for i,f in enumerate(final_f):
    if subset:
        # Up to subset_value highest-depth sites that belong to this cluster.
        high_coverage_snv_idxs_strain = high_coverage_snv_idxs.intersection(final_clusters[i].index)[:subset_value]
    ff = reorder_sort(f,key_to_sort).sort_index(key=lambda x: x.map(order_dict))
    # Coloured centroid on top (zorder 100) of a slightly wider black line (zorder 80) acting as an outline.
    ax.plot(ff.values,zorder=100,lw=6,color=manual_cmap_clus[i],label=f"Strain {i+1}");
    ax.plot(ff.values,zorder=80,lw=7,color="k");
    if len(final_clusters) != 0:
        if subset:
            ff_c = reorder_sort(final_clusters[i].loc[high_coverage_snv_idxs_strain].T,key_to_sort).T.sort_index(key=lambda x: x.map(order_dict),axis=1)
        else:
            ff_c = reorder_sort(final_clusters[i].T,key_to_sort).T.sort_index(key=lambda x: x.map(order_dict),axis=1)
        # At most 10000 randomly sampled SNV trajectories, drawn almost transparent.
        ax.plot(ff_c.sample(min(ff_c.shape[0],10000)).T.values,color=manual_cmap_clus[i],alpha=.01)
    else:
        raise ValueError("No clusters to plot.")


# Tick positions and labels collected while walking the sample groups.
major_x = []
minor_x = []
labels = []

# Secondary label of every sample, taken from `ff` as left by the last iteration of the plotting loop.
second_xlabels = list(ff.index.get_level_values(secondary_key))


#Making the vertical lines and labels
i = 0
for key, item in all_sample_dics[key_to_sort].items():
    
    # `item` holds the x positions of the samples in this group.
    xmin = item.min() 
    xmax = item.max()
    
    for e in item:
        # One faint vertical guide per sample, with its abbreviated secondary label at y = -0.1.
        ax.axvline(e,color="k",zorder=0,alpha=.5)   
        ax.text(e, -0.1, abbreviate_gut_site(second_xlabels[e], blank_inoc = True), ha='center',va='top', clip_on=False,size=15, rotation=0)
        
    if xmin != xmax:
        # Multi-sample group: label only the midpoint; the two edge ticks get empty labels.
        major_x.extend([xmin,(xmax + xmin)/2,xmax])
        minor_x.append((xmax + xmin)/2)
        labels.extend(["",key,""])
    else:
        # Single-sample group: one tick carrying the label.
        major_x.append(xmin)
        minor_x.append(xmax)
        labels.extend([key])   
        
    i+=1
    
    # Short separator lines below the axis at both edges of the group; the group whose key is
    # the maximum of all keys gets a heavier right-hand line.
    # NOTE(review): with keys like "Cage 1".."Cage 3" and "Inoculum" the string maximum would
    # be "Inoculum" — confirm that this is the intended special case.
    if (max(all_sample_dics[key_to_sort]) == key):
        ax.vlines(item[0] - 0.5, 0, -0.15, color='black', lw=0.8, clip_on=False, transform=ax.get_xaxis_transform())
        ax.vlines(item[-1] + 0.5, 0, -0.15, color='black', lw=2, clip_on=False, transform=ax.get_xaxis_transform())
    else:
        ax.vlines(item[0] - 0.5, 0, -0.15, color='black', lw=0.8, clip_on=False, transform=ax.get_xaxis_transform())
        ax.vlines(item[-1] + 0.5, 0, -0.15, color='black', lw=0.8, clip_on=False, transform=ax.get_xaxis_transform())

#Making the vertical colors
# The background shading is always grouped by cage, whatever key_to_sort was above.
key_to_sort = "cage"
i = 0    
for key, item in sorted(all_sample_dics[key_to_sort].items()):
        
    xmin = item.min() 
    xmax = item.max()
        
    if (key == max(all_sample_dics[key_to_sort])):
        # The group with the maximum key is shaded translucent black instead of a cage colour.
        ax.vlines(item[0] - 0.5, 0, -0.15, color='black', lw=2, clip_on=False, transform=ax.get_xaxis_transform())
        ax.axvspan(xmin - .25,xmax+.25,alpha=.2,color='black') 

        i+=1
    else:
        # cmap_cage is indexed by position in key-sorted order.
        ax.axvspan(xmin - .1,xmax+.1,color=cmap_cage[i]) #alpha=.2,
        ax.vlines(item[0] - 0.5, 0, -0.15, color='black', lw=2, clip_on=False, transform=ax.get_xaxis_transform())
        ax.vlines(item[-1] + 0.5, 0, -0.15, color='black', lw=2, clip_on=False, transform=ax.get_xaxis_transform())

        i+=1


ax.set_xticks(major_x)
ax.set_xticks(minor_x, minor = True)
# The "Inoculum" group keeps its tick but shows no text label.
ax.set_xticklabels([label if label != "Inoculum" else "" for label in labels])

# Reference lines at frequency 0 and 1.
ax.axhline(0,color="grey")
ax.axhline(1,color="grey")

# Major ticks carry the (padded) group labels with no tick mark; minor ticks are marks only.
ax.tick_params(axis = 'x', which = 'major', length=0,labelsize = 20,pad=45)
ax.tick_params(axis = 'x', which = 'minor', length = 10,labelsize = 0)
    
ax.set_ylabel("Strain frequency",size=30)
ax.set_ylim([-0.05,1.05])
if "Inoculum" in labels:
    ax.set_xlim([-1,max(major_x)+0.25])
    

# ax.set_xlabel("Sample", size = 30)


fig.legend(prop={"size":20})
# fig.legend(prop={"size":20}, bbox_to_anchor=(0.85, 0.5));

#plt.tight_layout()


In [None]:
# Inspect the centroid left in `f` by the plotting loop (i.e. the last strain), sorted by cage.
# NOTE(review): relies on the leftover loop variable `f` and on `order_dict` from an earlier cell.
reorder_sort(f,"cage").sort_index(key=lambda x: x.map(order_dict))

## Bootstrapping support and saving file for R plotting

In [None]:
# Preview the first strain's centroid as a long table (one row per sample), ordered by cage then mouse number.
pd.DataFrame(final_f[0]).reset_index().sort_values(by = ["cage", "mouse_number"]).rename(columns = {0: "freq"})

In [None]:
all_strains = True
bootstrap_ci = True   # when False the CI columns are written as NaN
boostrap_k = 100      # SNV frequencies drawn (with replacement) per bootstrap replicate
bootstrap_N = 1000    # bootstrap replicates per sample
# The resampling below uses `choices` from the stdlib `random` module, which np.random.seed
# does not affect — seed both generators so the confidence intervals are reproducible.
np.random.seed(6)
rand.seed(6)
columns_to_keep = ['sample',
                   'species', 
                   'strain', 
                   'cage', 
                   'mouse_number', 
                   'region', 
                   'freq']


def _bootstrap_mean_ci(snv_freqs, n_replicates, replicate_size):
    """Return the (lower, upper) 2.5% / 97.5% quantiles of bootstrapped mean frequencies.

    Each of the `n_replicates` replicates draws `replicate_size` values from `snv_freqs`
    with replacement and records their mean. Returns (nan, nan) if `snv_freqs` is empty.
    """
    if len(snv_freqs) == 0:
        return np.nan, np.nan
    replicate_means = [np.mean(choices(snv_freqs, k = replicate_size)) for _ in range(n_replicates)]
    return np.quantile(replicate_means, 0.025), np.quantile(replicate_means, 0.975)


def _strain_frequency_table(strain_number, centroid, cluster):
    """Build the per-sample summary rows for one strain.

    `centroid` is the strain's per-sample mean frequency and `cluster` its site-by-sample
    frequency table. The result has the `columns_to_keep` columns followed by the 25% and
    75% quantiles of the cluster's frequencies and the bootstrap `upper_ci` / `lower_ci`.
    """
    table = pd.DataFrame(centroid).reset_index().rename(columns = {0: "freq"})
    table['species'] = species
    table['strain'] = strain_number
    table = table[columns_to_keep]
    # Interquartile bounds of the SNV frequencies, joined on the sample name.
    for quantile in (0.25, 0.75):
        quantile_df = pd.DataFrame(cluster.quantile(quantile)).reset_index().loc[:, ['sample', quantile]]
        table = table.merge(quantile_df, on = ["sample"])

    # Bootstrap CI of the mean frequency, one (lower, upper) pair per sample column.
    if bootstrap_ci:
        bounds = [_bootstrap_mean_ci(cluster.iloc[:, k].dropna().to_list(), bootstrap_N, boostrap_k)
                  for k in range(cluster.shape[1])]
    else:
        bounds = [(np.nan, np.nan)] * cluster.shape[1]
    ci_df = pd.DataFrame({"sample": cluster.columns.get_level_values("sample"),
                          "upper_ci": [upper for _, upper in bounds],
                          "lower_ci": [lower for lower, _ in bounds]})
    # Join on the sample name so each interval lands on its own sample's row.
    return table.merge(ci_df, on = ["sample"])


if len(final_clusters) > 0:
    # One table per strain (strains numbered from 1), stacked in a single concat.
    strain_tables = [_strain_frequency_table(i + 1, final_f[i], final_clusters[i])
                     for i in range(len(final_f))]
    final_f_all_strains = pd.concat(strain_tables, ignore_index=True)
    # Renaming
    final_f_all_strains = final_f_all_strains.rename(columns = {0.25: "quantile_25", 0.75: "quantile_75"})
    # sorting
    final_f_all_strains = final_f_all_strains.sort_values(by = ["strain", "cage", "mouse_number"]).reset_index(drop = True)

    #Saving 
    # Note: the file is tab-separated despite its .csv extension.
    output_file = "%sstrain_phasing/strain_clusters/%s/%s_strain_frequency.csv" % (config.project_directory, species, species)
    final_f_all_strains.to_csv(output_file, sep = "\t", index = False)
else:
    # This branch is reached only when final_clusters is empty.
    sys.stderr.write("No strain clusters detected; nothing to save.\n")
    

In [None]:
"%s%s" % (config.metadata_directory, "species_snps.txt")

# For loop through species list

In [33]:
# variables
# NOTE(review): this cell repeats the meta-parameters of the single-species analysis above,
# except that max_d is 3.5 here and 4.75 there — confirm the difference is intended.
## Meta-parameters: experiment with these—no hard and fast rules!

## minimum number of SNVs which need to be clustered together in order to qualify as a "strain"
min_cluster_size = 1000

## minimum fraction of sites which pass our coverage threshold which must be in a cluster in order for it to qualify 
## as a strain
min_cluster_fraction = 1/10

## For computational efficiency, we can downsample the SNVs we actually perform strain phasing on
max_num_snvs = 20000

## distance threshold to be considered linked—a lower value is a stricter linkage criterion
max_d = 3.5

## minimum coverage to consider allele frequency at a site for purposes of clustering
min_coverage = 10 

## minimum average sample coverage at polymorphic sites (e.g. sites in the A/D matrices)
min_sample_coverage = 5


## polymorphic & covered fraction: what percentage of samples does a site need 
## with coverage > min_coverage and polymorphic to be included in downstream analyses? 
poly_cov_frac = 1/5 #

## Number of clusters to calculate
n_clusters = 100

#Minimum number of snvs per sample
min_num_snvs_per_sample = 100


# Load species list (one species ID per line of species_snps.txt)
species_path = "%s%s" % (config.metadata_directory, "species_snps.txt")
with open(species_path, 'r') as file:
    species_list = [line.strip() for line in file]

# Load subject list
sample_metadata_map = parse_sample_metadata_map()
subject_sample_map = parse_subject_sample_map()
# `subjects` is a live view of the map's keys, not a copy.
subjects = subject_sample_map.keys()

# Species processed in this run. Currently pinned to a single species; assign
# `species_list` to `species_to_process` to run every species in species_snps.txt.
# for species_i,species in enumerate(species_list):
species_to_process = ["229547"]
for species_i,species in enumerate(species_to_process):
    # Report progress against the list that is actually being iterated, so the
    # denominator is correct whether one species or the full list is processed.
    sys.stderr.write("Processing %s (%s / %s)\n" % (species, species_i + 1, len(species_to_process)))

    # Defining directories
    strainfinder_dir = "%sinput" % (config.strain_phasing_directory)

    # Load the frequency (Fs), alt-allele count (Ass) and depth (Dss) tables,
    # but only if the StrainFinder input pickle for this species exists.
    snp_alignment_path = "%s/%s/%s.strainfinder.p" %  (strainfinder_dir ,species, species)
    if os.path.exists(snp_alignment_path):
        Fs,Ass,Dss = return_FAD(species,
                                min_coverage=min_coverage,
                                min_sample_coverage=min_sample_coverage,
                                poly_cov_frac = poly_cov_frac,
                                calculate_poly_cov_frac=False,
                                read_support = True)
    else:
        sys.stderr.write("No SNP alignment file for species. Skipping to next species.\n")
        continue

    # Filter out samples without adequate read coverage: keep the columns of Fs
    # that have at least min_num_snvs_per_sample non-NaN frequencies.
    sample_with_adequate_snv_count = (~np.isnan(Fs)).sum() >= min_num_snvs_per_sample

    Fs = Fs.loc[:,sample_with_adequate_snv_count]
    Ass = Ass.loc[:,sample_with_adequate_snv_count]
    Dss = Dss.loc[:,sample_with_adequate_snv_count]


    # Paths of the per-species cluster directory and of the cached raw clusters.
    # The directory itself is only created later, right before the cache is written.
    raw_cluster_path = "%s%s" % (config.strain_phasing_directory, "strain_clusters/")
    species_raw_cluster_dir = "%s%s/" % (raw_cluster_path, species)



    # Raw cluster cache file for this species
    species_raw_cluster_path = "%s%s%s" % (species_raw_cluster_dir, species, "_RawCluster.pckl")

    # No cached raw clusters for this species: build them from scratch.
    if not os.path.exists(species_raw_cluster_path):
        sys.stderr.write(f"\tFile not found: {species_raw_cluster_path}, making raw cluster.\n")
        
        ## CONSTRUCTING THE DISTANCE MATRIX
        # Allele frequency = alt count / depth. Adding (depth == 0) turns a zero
        # denominator into 1, so zero-depth entries give 0 instead of NaN/inf.
        fss = Ass.values/(Dss.values + (Dss.values == 0))

        # Full-size arrays. Within this cell, cluster_As / cluster_Ds / cluster_fs are not
        # read again after these three lines (the *_sub versions below feed the distance
        # functions), and fss is reassigned further down.
        cluster_As = Ass.values
        cluster_Ds = Dss.values
        cluster_fs = cluster_As/(cluster_Ds + (cluster_Ds == 0))

        ## Cap the number of SNVs used for the pairwise distance matrix at max_num_snvs
        num = min(max_num_snvs,Fs.shape[0])

        # Keep the `num` rows of Dss with the highest mean depth across samples.
        i_list = Dss.T.mean().sort_values(ascending=False).index[:num]

        sys.stderr.write("Processing %s SNVs" % num)

        ## Alternative (disabled): draw `num` row positions at random instead of by depth
        #i_list = sample(range(Fs.shape[0]),num)
        i_list_idx = Fs.loc[i_list].index

        # Subset the count, depth and frequency tables to the selected sites.
        Ass_sub = Ass.loc[i_list_idx]
        Dss_sub = Dss.loc[i_list_idx]
        Fs_sub = Fs.loc[i_list_idx]

        # Frequencies on the subset, with the same zero-depth guard as above.
        fss_sub = Ass_sub.values/(Dss_sub.values + (Dss_sub.values == 0))

        cluster_As_sub = Ass_sub.values
        cluster_Ds_sub = Dss_sub.values
        cluster_fs_sub = cluster_As_sub/(cluster_Ds_sub + (cluster_Ds_sub == 0))

        # Each distance function receives its own freshly zeroed num x num buffer.
        # NOTE(review): the re-zeroing suggests the helpers fill the buffer in place - confirm
        # in strain_phasing_functions. D_mat_fun1 / D_mat_fun2 presumably correspond to the
        # two possible polarizations of a pair of SNVs.
        D_mat = np.zeros([num,num])
        D_mat_1 = D_mat_fun1(num,fss_sub,cluster_Ds_sub,D_mat)
        D_mat = np.zeros([num,num]) 
        D_mat_2 = D_mat_fun2(num,fss_sub,cluster_Ds_sub,D_mat)

        # Element-wise minimum of the two matrices; np.fmin ignores a NaN when the other
        # operand is a number.
        D_mat = np.fmin(D_mat_1,D_mat_2)
        D_mat = symmetrize(D_mat)

        # Label both distance matrices by site so they can be indexed by SNV later.
        D_mat_1 = pd.DataFrame(D_mat_1,index=Fs_sub.index,columns=Fs_sub.index)
        D_mat_2 = pd.DataFrame(D_mat_2,index=Fs_sub.index,columns=Fs_sub.index)

        # Boolean "linked" matrix: True where the pairwise distance is below max_d.
        D_mat_close = pd.DataFrame(D_mat < max_d) 

        D_mat_close.index = Fs_sub.index
        D_mat_close.columns = Fs_sub.index
        
        ## Extract up to n_clusters clusters from the thresholded distance matrix.
        ## In practice all SNVs should fall into one of a fairly small number of clusters,
        ## so most iterations are expected to find nothing new.
        ## TODO: rewrite as a while loop that runs until no unclustered variants remain.

        all_clus_pol = []
        all_clus_idx = []
        all_clus_A = []
        all_clus_D = []

        all_clus_F = []

        # Failed attempts are counted and summarised once after the loop instead of being
        # silently discarded.
        n_failed_extractions = 0
        last_extraction_error = None

        for i in range(n_clusters):

            try:

                clus,clus_idxs = return_clus(D_mat_close,Fs_sub)
                # Alternative (disabled): pass co_cluster_pct to return_clus to change how
                # many co-clustering partners a site needs, e.g. to obtain smaller clusters:
                # clus,clus_idxs = return_clus(D_mat_close,Fs_sub, co_cluster_pct=0.5)
                clus_pol = polarize_clus(clus,clus_idxs,D_mat_1,D_mat_2)
                clus_pol.index = clus_idxs
                # Remove the sites just clustered so the next attempt works on the remainder.
                D_mat_close = drop_clus_idxs(D_mat_close,clus_idxs)

                if clus_pol.shape[0] > min_cluster_size and clus_pol.shape[0] > Fs.shape[0]*min_cluster_fraction:

                    # Compute every per-cluster summary before appending anything, so a
                    # failure part-way through cannot leave the parallel lists with
                    # different lengths.
                    clus_D = Dss.loc[clus.index].mean().values   # mean depth per sample
                    clus_F = clus_pol.mean()                      # mean polarized frequency per sample
                    clus_A = clus_F*clus_D                        # implied mean alt count per sample

                    all_clus_D.append(clus_D)
                    all_clus_pol.append(clus_pol)
                    all_clus_A.append(clus_A)
                    all_clus_F.append(clus_F)

                    print(clus_pol.shape[0])

            except Exception as extraction_error:
                # Best-effort: a failed attempt must not abort the species. Catching
                # Exception rather than using a bare except keeps KeyboardInterrupt and
                # SystemExit working, so the cell can still be interrupted.
                n_failed_extractions += 1
                last_extraction_error = extraction_error

        if n_failed_extractions > 0:
            sys.stderr.write("\t%s of %s cluster extraction attempts raised an exception (last: %r)\n"
                             % (n_failed_extractions, n_clusters, last_extraction_error))

        # Only a notice: execution continues, the alignment loop below is a no-op on an
        # empty list, and the species is skipped once final_clusters is found to be empty.
        if len(all_clus_D) == 0:
            sys.stderr.write("No cluster found. Skipping to next species.\n")
        ## Now, for each accepted cluster, use its per-sample mean depth and mean frequency
        ## as a reference trajectory and find all sites in Fs (not just the subset used for
        ## the distance matrix) that are consistent with being linked to it.

        final_clusters = []

        # Sites already claimed by an earlier cluster; each site joins at most one cluster.
        all_aligned_sites = []

        for i in range(len(all_clus_D)):

            sys.stderr.write(f'\n\nCluster {i+1}\n')
            ancD = all_clus_D[i]
            ancF = all_clus_F[i]

            dss = Dss.values
            fss = Fs.values

            # Distance of every site to the cluster reference, in both polarizations
            # (frequency as given, and 1 - frequency).
            disAnc_forward = []
            disAnc_backward = []

            for j in range(Dss.shape[0]):
                disAnc_forward.append(calc_dis(ancD,dss[j],ancF,fss[j]))
                disAnc_backward.append(calc_dis(ancD,dss[j],ancF,1-fss[j]))
                if j % 1000 == 0:
                    sys.stderr.write(f"\n\t{np.around(100*j/Dss.shape[0],3)}% finished")

            # Best polarization per site; a site is aligned when that distance is < max_d.
            disAnc = [min(els) for els in zip(disAnc_forward, disAnc_backward)]
            disAnc = np.array(disAnc)
            aligned_sites = Fs.loc[disAnc < max_d].index
            # Column 0 holds the forward distance, column 1 the backward distance;
            # pols is True where the backward (flipped) polarization is the closer one.
            f_dist =  pd.DataFrame(np.array([disAnc_forward,disAnc_backward]).T,index=Fs.index)
            pols = f_dist.T.idxmin() > 0

            # Drop sites already assigned to an earlier cluster. Membership is tested
            # against a set: testing against the growing list costs
            # O(len(aligned_sites) * len(all_aligned_sites)) per cluster.
            already_assigned = set(all_aligned_sites)
            aligned_sites = [a for a in aligned_sites if a not in already_assigned]

            pols = pols.loc[aligned_sites]
            re_polarize = pols.loc[pols].index

            all_aligned_sites.extend(aligned_sites)

            # Explicit copy so the re-polarization below never writes through to Fs.
            Fs_cluster = Fs.loc[aligned_sites].copy()

            # Flip the sites whose backward polarization matched the cluster better.
            Fs_cluster.loc[re_polarize] = 1 - Fs_cluster.loc[re_polarize]

            final_clusters.append(Fs_cluster)

        if len(final_clusters) == 0:
            sys.stderr.write("No clusters detected. Skipping to next species.\n")
        else:
            if not os.path.exists(species_raw_cluster_dir):
                os.makedirs(species_raw_cluster_dir)
                print("Cluster directory created successfully!")
            else:
                print("Cluster directory already exists.")
            # #SAVING RAW FILE
            pickle_object = open(species_raw_cluster_path, "wb")
            pickle.dump(final_clusters, pickle_object)
            pickle_object.close()

    else:
        final_clusters = pd.read_pickle(species_raw_cluster_path)

    # Polarizing clusters
    # Skip the species when no cluster was found (or the cached cluster list is empty).
    if len(final_clusters) == 0:
        sys.stderr.write("\tNo strains detected, skipping species.\n")
        continue
    
    ## If only a single cluster is detected, add a second "cluster" which is simply 1 minus
    ## the allele frequencies in the first cluster.
    ## This aids visualization for people not familiar with this kind of clustering.
    if len(final_clusters) == 1:
        final_clusters.append(1-final_clusters[0])

    ## Cluster centroids: the per-sample mean frequency of each cluster.
    ## df_final_f has one row per cluster and one column per sample.
    final_f = []
    for cluster in final_clusters:
        final_f.append(cluster.mean())
    df_final_f = pd.DataFrame(final_f)

    ## Now polarize clusters so that the sum over samples of (1 - sum of centroids)^2 is
    ## reduced; the idea is that accurate strain frequencies should sum to 1 in each sample.
    # NOTE(review): `polarize` is assigned here but not read anywhere in this cell.
    polarize = True

    # Score of each candidate flip, keyed by cluster position.
    pol_d2 = {}

    # The Series conversion and the flip test below are indented inside this loop, so this
    # is a sequential pass: after scoring cluster i, the best score recorded so far is
    # compared against the current objective and, if lower, that cluster is flipped before
    # cluster i+1 is scored. Scores stored in earlier iterations are not recomputed after a
    # flip.
    # NOTE(review): confirm that the per-iteration flip, rather than a single flip decided
    # after all clusters are scored, is the intended behaviour.
    for i in range(df_final_f.shape[0]):
        df_final_f_temp = df_final_f.copy() # work on a copy of the current centroids
        df_final_f_temp.iloc[i] = 1 - df_final_f_temp.iloc[i] # flip the polarization of centroid i only
        pol_d2[i] =  ((1 - df_final_f_temp.sum())**2).sum()   # per sample: sum the centroids across clusters,
                                                                    # subtract from 1 and square; then sum over samples.
                                                                    # Ideally this is close to 0.
                                                                    # Stored as the score of flipping cluster i.

        pol_d2 = pd.Series(pol_d2)                                # dict on the first pass, already a Series afterwards

        if pol_d2.min() < ((1 - df_final_f.sum())**2).sum(): # a recorded flip beats the current, unflipped objective
            clus_to_re_pol = pol_d2.idxmin()
            # Flip the centroid and every site frequency of that cluster, then rebuild the
            # centroid table so later iterations score against the updated state.
            final_f[clus_to_re_pol] = 1 - final_f[clus_to_re_pol]
            final_clusters[clus_to_re_pol] = 1 - final_clusters[clus_to_re_pol]
            df_final_f = pd.DataFrame(final_f)  

    # Filtering out samples without adequate snvs after polarization step
    good_indices = []

    for i,cluster in enumerate(final_clusters):
        if i == 0:
            good_samples = (len(final_clusters[i]) - np.isnan(final_clusters[i]).sum(axis = 0) > min_num_snvs_per_sample).values
        else:
            new_good_samples = (len(final_clusters[i]) - np.isnan(final_clusters[i]).sum(axis = 0) > min_num_snvs_per_sample).values
            good_samples = good_samples & new_good_samples

    if sum(good_samples == False) == len(good_samples): # if no sample is "good" across all clusters, then skip to the next species
        sys.stderr.write("\tNo good samples after filter, skipping species.\n")
        continue

    for i,cluster in enumerate(final_clusters): 
        final_clusters[i] = final_clusters[i].T.loc[good_samples].T
        final_f[i] = final_f[i].T.loc[good_samples]

    Fs = Fs.T.loc[good_samples].T
    Ass = Ass.T.loc[good_samples].T
    Dss = Dss.T.loc[good_samples].T

    # Filtering out columsn that have only na (redundant)
    #Filter all all na columns if there are any - THis is redundant
    if len(final_clusters) > 0:
        for i,cluster in enumerate(final_clusters):
            if i == 0:
                mask = ~np.isnan(cluster).all(axis = 0)
            final_clusters[i] = cluster.loc[:,mask]
            final_f[i] = final_f[i][mask]
        # Fs = Fs.loc[:,mask]
        Fs = Fs.loc[:, Fs.columns[mask]]

    
    

    # Bootstrapping CI and outputting dataframe
    #
    # For every strain (cluster) and sample this records the centroid frequency, the 25th
    # and 75th percentiles of the site frequencies and, when bootstrap_ci is True, a 95%
    # bootstrap interval of the mean: bootstrap_N resamples of boostrap_k site frequencies
    # drawn with replacement, summarised by the 2.5% and 97.5% quantiles of their means.
    all_strains = True
    bootstrap_ci = True
    boostrap_k = 100
    bootstrap_N = 1000
    # Resampling below uses NumPy's global generator so that this seed actually controls
    # it; seeding NumPy has no effect on the stdlib `random` module.
    np.random.seed(6)
    columns_to_keep = ['sample',
                    'species',
                    'strain',
                    'cage',
                    'mouse_number',
                    'region',
                    'freq']
    if len(final_clusters) > 0:
        # One frame per strain, concatenated once after the loop.
        strain_frames = []
        for i in np.arange(len(final_f)):
            strain_f = pd.DataFrame(final_f[i]).reset_index().rename(columns = {0: "freq"})
            strain_f['species'] = [species]*len(strain_f)
            strain_f['strain'] = [i + 1]*len(strain_f)
            strain_f = strain_f[columns_to_keep]
            # Per-sample quartiles of the site frequencies in this cluster.
            iqr_75 = pd.DataFrame(final_clusters[i].quantile(0.75)).reset_index().loc[:,['sample', 0.75]]
            iqr_25 = pd.DataFrame(final_clusters[i].quantile(0.25)).reset_index().loc[:,['sample', 0.25]]
            strain_f = strain_f.merge(iqr_25, on = ["sample"]).merge(iqr_75, on = ["sample"])

            # CI columns are only built and attached when requested, so turning
            # bootstrap_ci off no longer refers to undefined CI vectors.
            if bootstrap_ci:
                upper_ci_vec = []
                lower_ci_vec = []

                # Samples are visited in column order, which is the row order of strain_f.
                for sample_i in range(final_clusters[i].shape[1]):
                    snv_freqs = final_clusters[i].iloc[:,sample_i].dropna().to_numpy()
                    if snv_freqs.size == 0:
                        # Nothing to resample from: record a missing interval.
                        upper_ci_vec.append(np.nan)
                        lower_ci_vec.append(np.nan)
                        continue
                    # Draw all resamples at once (rows = resamples) and average each row.
                    mean_freq_vec = np.random.choice(snv_freqs, size = (bootstrap_N, boostrap_k)).mean(axis = 1)
                    upper_ci_vec.append(np.quantile(mean_freq_vec, 0.975))
                    lower_ci_vec.append(np.quantile(mean_freq_vec, 0.025))

                strain_f['upper_ci'] = upper_ci_vec
                strain_f['lower_ci'] = lower_ci_vec

            strain_frames.append(strain_f)

        final_f_all_strains = pd.concat(strain_frames, ignore_index=True)

        # Renaming
        final_f_all_strains = final_f_all_strains.rename(columns = {0.25: "quantile_25", 0.75: "quantile_75"})
        # sorting
        final_f_all_strains = final_f_all_strains.sort_values(by = ["strain", "cage", "mouse_number"]).reset_index(drop = True)

        # Saving (tab-separated, despite the .csv extension)
        output_file = "%sstrain_phasing/strain_clusters/%s/%s_strain_frequency.csv" % (config.project_directory, species, species)
        final_f_all_strains.to_csv(output_file, sep = "\t", index = False)

    else:
        sys.stderr.write("\tOnly one strain detected\n\n")
        



    

    

    
    

Processing 229547 (1 / 117)
Using a read support of 4 for each polymorphism.
