In [2]:
# Setting working directory
import sys
sys.path.insert(0, "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/scripts/strain_phasing/")

# Normal Libraries
import pandas as pd
import numpy as np
import pickle
import scipy.stats
import random as rand
from random import randint,sample,choices
from math import log
import os 
from datetime import datetime

# Plotting libraries
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
from matplotlib.lines import Line2D
import figure_utils as fu
from matplotlib import colormaps
plasma_cmap = colormaps.get_cmap('plasma')

# config
import config

# predefined functions
from strain_phasing_functions import *
from microbiome_evolution_functions import *

In [3]:
# Read the species identifiers, one per line, from the project's SNP species list.
species_path = f"{config.analysis_directory}metadata/species_snps.txt"
with open(species_path, 'r') as file:
    # Strip surrounding whitespace (including the trailing newline) from every line.
    species_list = list(map(str.strip, file))

In [4]:
# Echo the loaded species identifiers as the cell output (the whole list is displayed).
species_list

['Acidaminococcus_intestini_54097',
 'Actinomyces_graevenitzii_58300',
 'Actinomyces_sp_57735',
 'Actinomyces_sp_62581',
 'Actinomyces_viscosus_57672',
 'Adlercreutzia_equolifaciens_60310',
 'Aggregatibacter_aphrophilus_58143',
 'Akkermansia_muciniphila_55290',
 'Alistipes_finegoldii_56071',
 'Alistipes_indistinctus_62207',
 'Alistipes_onderdonkii_55464',
 'Alistipes_putredinis_61533',
 'Alistipes_senegalensis_58364',
 'Alistipes_shahii_62199',
 'Anaerostipes_hadrus_55206',
 'Atopobium_parvulum_59960',
 'Atopobium_sp_59401',
 'Bacteroidales_bacterium_58650',
 'Bacteroides_caccae_53434',
 'Bacteroides_cellulosilyticus_58046',
 'Bacteroides_clarus_62282',
 'Bacteroides_coprocola_61586',
 'Bacteroides_eggerthii_54457',
 'Bacteroides_finegoldii_57739',
 'Bacteroides_fragilis_54507',
 'Bacteroides_intestinalis_61596',
 'Bacteroides_massiliensis_44749',
 'Bacteroides_ovatus_58035',
 'Bacteroides_pectinophilus_61619',
 'Bacteroides_plebeius_61623',
 'Bacteroides_salyersiae_54873',
 'Bacteroid

In [29]:
# Species identifier used by the single-species cells that follow.
species = "Desulfovibrio_piger_61475"

## Directories

In [30]:
# Strain-phasing input directory (path only; nothing is created for it here).
strainfinder_dir = "%sinput" % (config.strain_phasing_directory)

# Per-species directory for the raw cluster files.
raw_cluster_path = "%s%s" % (config.strain_phasing_directory, "strain_clusters/")
species_raw_cluster_dir = "%s%s/" % (raw_cluster_path, species)

# Directory the strain-phasing figures are written to.
strain_phasing_figures_dir = "%s%s" % (config.figure_directory, "strain_phasing/")

# Create each output directory if it is missing and report what happened.
# exist_ok=True keeps this safe if another process creates the directory
# between the check and the call.
for dir_label, dir_path in (("Cluster", species_raw_cluster_dir),
                            ("Figure", strain_phasing_figures_dir)):
    if os.path.isdir(dir_path):
        print("%s directory already exists." % dir_label)
    else:
        os.makedirs(dir_path, exist_ok=True)
        print("%s directory created successfully!" % dir_label)


Cluster directory already exists.
Figure directory already exists.


## Selecting host

In [31]:
# Load the sample metadata map and the subject -> samples map through the project helpers.
# NOTE(review): the structure of both maps is defined in the imported helper modules;
# in this notebook sample_metadata_map[sample] is indexed by position and
# subject_sample_map[subject_id] is read through .keys() — confirm against the helpers.
sample_metadata_map = parse_sample_metadata_map()
subject_sample_map = parse_subject_sample_map()

In [32]:
# Load the tab-separated species coverage table that the per-subject
# median-coverage summary is computed from.
coverage_df_path = f"{config.data_directory}species/coverage.txt.bz2"
coverage_df = pd.read_csv(coverage_df_path, delimiter = "\t")




In [33]:
# Median coverage of `species` within each subject, sorted from highest to lowest,
# used to pick a host in which the species is well covered.

# The species row does not depend on the subject, so select it once up front and
# fail with an explicit message (instead of an IndexError on .values[0]) if it is absent.
species_coverage = coverage_df.loc[coverage_df.species_id == species]
if species_coverage.empty:
    raise ValueError("Species %s is not present in the coverage table" % species)

covered_samples = set(coverage_df.columns)

cvg_data = []
for subject_id in subject_sample_map.keys():
    # Only samples that actually appear as columns of the coverage table can be used.
    samples_of_interest = [sample_name for sample_name in subject_sample_map[subject_id].keys()
                           if sample_name in covered_samples]
    # Median across this subject's samples; NaN when the subject has no covered samples.
    median_coverage = species_coverage[samples_of_interest].median(axis = 1).values[0]
    cvg_data.append([subject_id, median_coverage])
cvg_data = pd.DataFrame(data = cvg_data, columns = ['subject', 'median_coverage']).sort_values(by = 'median_coverage', ascending = False).reset_index(drop=True)


In [34]:
# Display the per-subject median coverage table (sorted from highest to lowest).
cvg_data

Unnamed: 0,subject,median_coverage
0,9,2.601905
1,12,0.352792
2,14,0.277947
3,2,0.221344
4,1,0.0
5,10,0.0
6,3,0.0
7,11,0.0
8,8,0.0
9,7,0.0


In [35]:
# Subject analysed in the cells below.
subject_id = "9"
# Use every sample recorded for this subject. The commented-out alternative restricts
# the list to samples that are present as columns of the coverage table.
# samples_of_interest = [sample for sample in list(subject_sample_map[subject_id].keys())if sample in coverage_df.columns]
samples_of_interest = list(subject_sample_map[subject_id].keys())



## Analysis

In [36]:
## Meta-parameters: experiment with these—no hard and fast rules!

## Minimum number of SNVs which need to be clustered together in order to qualify as a "strain"
## (a cluster is kept only if its size is strictly greater than this value).
min_cluster_size = 1000

## Minimum fraction of the sites passing the coverage filters (the rows of Fs) which must be
## in a cluster in order for it to qualify as a strain.
min_cluster_fraction = 1/10

## For computational efficiency, cap the number of SNVs the pairwise distance matrix is built on;
## the sites with the highest mean depth are the ones kept.
max_num_snvs = 20000

## Distance threshold below which two SNV trajectories are considered linked—
## lower means trajectories have to be more similar to be linked.
max_d = 3.5

## minimum coverage to consider allele frequency at a site for purposes of clustering
min_coverage = 10 

## minimum average sample coverage at polymorphic sites (e.g. sites in the A/D matrices)
min_sample_coverage = 5


## polymorphic & covered fraction: what fraction of samples does a site need 
## with coverage > min_coverage and polymorphic to be included in downstream analyses? 
poly_cov_frac = 1/5 # passed to return_FAD

## Maximum number of cluster-extraction iterations to run
n_clusters = 100

## Minimum number of non-NaN SNVs a sample must have to be kept
min_num_snvs_per_sample = 100

In [37]:
# Build the three per-site matrices for this species and subject through the project helper.
# NOTE(review): presumably Fs holds allele frequencies, Ass alternate-allele counts and
# Dss total depths (a later cell computes Ass/Dss as a frequency) — confirm against return_FAD.
Fs,Ass,Dss = return_FAD(species, min_coverage=min_coverage, 
                        min_sample_coverage=min_sample_coverage, 
                        poly_cov_frac = poly_cov_frac, 
                        calculate_poly_cov_frac=False, 
                        read_support = False, # read support is disabled for this run
                        read_support_no = 2,
                        subject_id=subject_id) 


Not using read support.


In [38]:
#filter out samples without an adequate number of SNVs when all is said and done
# Keep only the samples (columns) whose count of non-NaN SNV frequencies
# reaches min_num_snvs_per_sample.
sample_with_adequate_snv_count = (~np.isnan(Fs)).sum() >= min_num_snvs_per_sample

# Apply the same column mask to the frequency, allele-count and depth matrices.
Fs, Ass, Dss = (frame.loc[:, sample_with_adequate_snv_count] for frame in (Fs, Ass, Dss))

In [39]:
%%time


# Per-site, per-sample frequency A/D. Adding (D == 0) turns a zero denominator into 1,
# so no division by zero occurs and no NaN is produced.
fss = Ass.values/(Dss.values + (Dss.values == 0)) #This is so it doesn't produce a na (division by 0)

# NOTE(review): cluster_As, cluster_Ds and cluster_fs are not used again in this cell.
cluster_As = Ass.values
cluster_Ds = Dss.values
cluster_fs = cluster_As/(cluster_Ds + (cluster_Ds == 0))

## Number of SNVs the distance matrix is built on: all of them, capped at max_num_snvs
num = min(max_num_snvs,Fs.shape[0])

# Index labels of the `num` sites with the highest mean depth across samples.
i_list = Dss.T.mean().sort_values(ascending=False).index[:num]

sys.stderr.write("Processing %s SNVs" % num)

## Alternative (disabled): draw a random subset of site positions instead
#i_list = sample(range(Fs.shape[0]),num)
i_list_idx = Fs.loc[i_list].index

# Restrict the three matrices to the selected sites.
Ass_sub = Ass.loc[i_list_idx]
Dss_sub = Dss.loc[i_list_idx]
Fs_sub = Fs.loc[i_list_idx]

# Same zero-safe frequency computation as above, on the selected sites only.
fss_sub = Ass_sub.values/(Dss_sub.values + (Dss_sub.values == 0))

cluster_As_sub = Ass_sub.values
cluster_Ds_sub = Dss_sub.values
cluster_fs_sub = cluster_As_sub/(cluster_Ds_sub + (cluster_Ds_sub == 0))

# Two num x num distance matrices from the project helpers.
# NOTE(review): presumably one per allele polarization; D_mat is re-zeroed between the
# calls, presumably because the helpers fill the array they are given — confirm.
D_mat = np.zeros([num,num])
D_mat_1 = D_mat_fun1(num,fss_sub,cluster_Ds_sub,D_mat)
D_mat = np.zeros([num,num]) 
D_mat_2 = D_mat_fun2(num,fss_sub,cluster_Ds_sub,D_mat)

# Element-wise minimum of the two matrices (np.fmin ignores a NaN on one side),
# then symmetrised by the project helper.
D_mat = np.fmin(D_mat_1,D_mat_2) #I believe this is filling in the minimum of the two polarizations
D_mat = symmetrize(D_mat)

# Label both distance matrices with the site index so they can be looked up by site.
D_mat_1 = pd.DataFrame(D_mat_1,index=Fs_sub.index,columns=Fs_sub.index)
D_mat_2 = pd.DataFrame(D_mat_2,index=Fs_sub.index,columns=Fs_sub.index)

# Boolean "linked" matrix: True where the distance is below max_d.
D_mat_close = pd.DataFrame(D_mat < max_d) 

D_mat_close.index = Fs_sub.index
D_mat_close.columns = Fs_sub.index


CPU times: user 16.1 ms, sys: 5.86 ms, total: 21.9 ms
Wall time: 15.9 ms


Processing 250 SNVs

In [40]:
## Extracts up to n_clusters clusters.
## In practice all SNVs should fall into one of a fairly small number of clusters.
## Really should be re-written with a while loop, but this works for now:
## the idea is that we exhaust all clusters—there should only be a small number of them ultimately.

### Idea with while loop:
##### while there are still unclustered variants, keep trying to cluster them

all_clus_pol = []
all_clus_idx = []
all_clus_A = []
all_clus_D = []

all_clus_F = []

# (iteration, exception) pairs for iterations that raised, so that failures are
# summarised afterwards instead of being silently discarded.
failed_cluster_iterations = []

for i in range(n_clusters):

    try:

        clus,clus_idxs = return_clus(D_mat_close,Fs_sub)
#         clus,clus_idxs = return_clus(D_mat_close,Fs_sub, co_cluster_pct=0.5) # alternative co-clustering threshold
        clus_pol = polarize_clus(clus,clus_idxs,D_mat_1,D_mat_2)
        clus_pol.index = clus_idxs
        # Remove the sites of this cluster so the next iteration works on the remainder.
        D_mat_close = drop_clus_idxs(D_mat_close,clus_idxs)

        # Keep the cluster only if it is large enough both in absolute terms and
        # relative to the number of sites in Fs.
        if clus_pol.shape[0] > min_cluster_size and clus_pol.shape[0] > Fs.shape[0]*min_cluster_fraction:

            # Compute every summary before appending so the parallel lists cannot
            # end up with different lengths if one of the computations raises.
            clus_D = Dss.loc[clus.index].mean().values
            clus_F = clus_pol.mean()

            all_clus_D.append(clus_D)
            all_clus_pol.append(clus_pol)
            all_clus_A.append(clus_F*clus_D)
            all_clus_F.append(clus_F)

            print(clus_pol.shape[0])

    except Exception as err:
        # Best effort: a failing iteration is skipped and the loop continues.
        # Catching Exception (not a bare except) lets KeyboardInterrupt through.
        failed_cluster_iterations.append((i, err))

if failed_cluster_iterations:
    first_failed_i, first_err = failed_cluster_iterations[0]
    sys.stderr.write("%s of %s cluster-extraction iterations raised; first failure at iteration %s: %s: %s\n"
                     % (len(failed_cluster_iterations), n_clusters, first_failed_i,
                        type(first_err).__name__, first_err))


In [41]:
## For each retained cluster, use its centroid (mean depth and mean frequency per sample)
## as a representative and find all other sites (not just the SNVs used for the distance
## matrix) whose trajectories are within max_d of it, in either polarization.

final_clusters = []

all_aligned_sites = []
# Set mirror of all_aligned_sites: gives constant-time membership tests instead of
# scanning the growing list once per candidate site.
assigned_sites = set()

# These arrays do not change between clusters, so build them once.
dss = Dss.values
fss = Fs.values
n_sites = Dss.shape[0]

for i in range(len(all_clus_D)):
    
    sys.stderr.write(f'\n\nCluster {i+1}\n')
    ancD = all_clus_D[i]
    ancF = all_clus_F[i]
    
    # Distance of every site to the centroid, as-is and with the frequency flipped (1 - f).
    disAnc_forward = []
    disAnc_backward = []

    for j in range(n_sites):
        disAnc_forward.append(calc_dis(ancD,dss[j],ancF,fss[j]))
        disAnc_backward.append(calc_dis(ancD,dss[j],ancF,1-fss[j]))
        if j % 1000 == 0:
            sys.stderr.write(f"\n\t{np.around(100*j/n_sites,3)}% finished")
    
    disAnc = np.array([min(els) for els in zip(disAnc_forward, disAnc_backward)])
    aligned_sites = Fs.loc[disAnc < max_d].index
    f_dist =  pd.DataFrame(np.array([disAnc_forward,disAnc_backward]).T,index=Fs.index)
    # True where the flipped (backward) polarization is the closer one.
    pols = f_dist.T.idxmin() > 0
    
    # A site belongs to the first cluster that claims it.
    aligned_sites = [a for a in aligned_sites if a not in assigned_sites]
    
    pols = pols.loc[aligned_sites]
    re_polarize = pols.loc[pols].index
    
    all_aligned_sites.extend(aligned_sites)
    assigned_sites.update(aligned_sites)
    
    # Explicit copy: the re-polarisation below must not write into Fs itself.
    Fs_cluster = Fs.loc[aligned_sites].copy()
    
    Fs_cluster.loc[re_polarize] = 1 - Fs_cluster.loc[re_polarize]
        
    final_clusters.append(Fs_cluster)
    

In [None]:
# #SAVING RAW FILE
# species_raw_cluster_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_RawCluster.pckl")

# pickle_object = open(species_raw_cluster_path, "wb")
# pickle.dump(final_clusters, pickle_object)
# pickle_object.close()

In [None]:
# LOADING THE RAW FILE
# Replaces final_clusters with the pickled clusters stored for this species and subject.
# NOTE(review): when the notebook is run top to bottom this overwrites the final_clusters
# computed in the previous cell, and it fails if the pickle was never written (the saving
# cell above is commented out) — confirm whether this cell should run unconditionally.
species_raw_cluster_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_RawCluster.pckl")

final_clusters = pd.read_pickle(species_raw_cluster_path)

## Creating polarized clusters

Once clusters have been identified and internally polarized, they need to be polarized relative to one another. In the best case, the sum of strain frequencies will be 1. 

In [42]:
# Flags describing the outcome of this cell.
# NOTE(review): `polarize` is assigned here but not read anywhere in this cell.
no_cluster = False
polarize = True

# Purely informational: report when several clusters were found.
if len(final_clusters) > 1:
    sys.stderr.write("More than 1 strains detected.\n")
## If only a single cluster is detected, a second "cluster" is added further down which is
## simply 1 minus the allele frequencies in the first cluster;
## this aids in visualization for people not familiar with this kind of clustering
    
if len(final_clusters) == 0:
    # No cluster passed the thresholds: reload the matrices with min_coverage=0 and
    # poly_cov_frac=0 and treat all sites as one single cluster.
    sys.stderr.write("No clusters detected.\n")
    no_cluster = True

    Fs,Ass,Dss = return_FAD(species, min_coverage=0, 
                            min_sample_coverage=min_sample_coverage, 
                            poly_cov_frac = 0, 
                            calculate_poly_cov_frac=False, 
                            read_support = False, 
                            subject_id=subject_id) 
    
    # Choose the polarization in which the overall mean frequency is >= 0.5.
    # In both branches the centroid is then set to 1 for every sample.
    # NOTE(review): the value computed on the first line of each branch is overwritten by
    # the `.loc[:,:] = 1` that follows; that assignment addresses the Series with two
    # indexers, which presumably relies on the sample index being a MultiIndex — confirm.
    if Fs.mean().mean() < 0.5:
        df_final_f = 1 - Fs.mean().T
        df_final_f.loc[:,:] = 1
        final_f = []
        final_f.append(df_final_f)
        final_clusters = []
        final_clusters.append(1 - Fs)
    else:
        df_final_f = Fs.mean().T
        df_final_f.loc[:,:] = 1
        final_f = []
        final_f.append(df_final_f)
        final_clusters = []
        final_clusters.append(Fs)
        
else:
    sys.stderr.write("At least one strain is present. Polarizing\n")
    # With exactly one cluster, add its complement (1 - frequency) as a second cluster.
    if len(final_clusters) == 1:
        final_clusters.append(1-final_clusters[0])

    ## add cluster centroids (mean frequency per sample, one row per cluster)
    final_f = []
    for cluster in final_clusters:
        final_f.append(cluster.mean())
    df_final_f = pd.DataFrame(final_f)

    ## now, polarize clusters so that the sum of squared deviations of the summed centroids from 1 is minimized
    ## the idea here is that accurate strain frequencies should sum to 1

    pol_d2 = {}

    for i in range(df_final_f.shape[0]):
        df_final_f_temp = df_final_f.copy() # work on a copy of the centroids
        df_final_f_temp.iloc[i] = 1 - df_final_f_temp.iloc[i] # flip the polarization of centroid i only
        pol_d2[i] =  ((1 - df_final_f_temp.sum())**2).sum()   # per sample: sum the centroids (ideally close to 1),
                                                                    # subtract that sum from 1 and square it; then add
                                                                    # the squares over all samples. Ideally close to 0.
                                                                    # Stored under the index of the flipped centroid.

        # NOTE(review): this conversion and the check below sit inside the loop, so the
        # comparison runs once per centroid i with the scores collected so far (a greedy
        # pass) rather than once after every flip has been scored — confirm this is intended.
        pol_d2 = pd.Series(pol_d2)                                # scores as a Series so min/idxmin are available

        if pol_d2.min() < ((1 - df_final_f.sum())**2).sum(): # a flip brings the summed centroids closer to 1 than the current polarization: apply it
            clus_to_re_pol = pol_d2.idxmin()
            final_f[clus_to_re_pol] = 1 - final_f[clus_to_re_pol]
            final_clusters[clus_to_re_pol] = 1 - final_clusters[clus_to_re_pol]
            df_final_f = pd.DataFrame(final_f)  


    

No clusters detected.
Not using read support.


### Filtering out samples in which any cluster lacks an adequate number of SNVs

In [43]:
good_indices = []

# One boolean mask per cluster: True for the samples (columns) in which that cluster
# has more than min_num_snvs_per_sample non-NaN SNVs.
per_cluster_masks = [
    (len(cluster_frame) - np.isnan(cluster_frame).sum(axis = 0) > min_num_snvs_per_sample).values
    for cluster_frame in final_clusters
]

# A sample is kept only if it passes in every cluster.
good_samples = per_cluster_masks[0]
for extra_mask in per_cluster_masks[1:]:
    good_samples = good_samples & extra_mask

# Apply the sample mask to every cluster and to its centroid.
for i in range(len(final_clusters)):
    final_clusters[i] = final_clusters[i].T.loc[good_samples].T
    final_f[i] = final_f[i].T.loc[good_samples]

# Keep the frequency, allele-count and depth matrices on the same set of samples.
Fs, Ass, Dss = (frame.T.loc[good_samples].T for frame in (Fs, Ass, Dss))


In [44]:
# Drop any sample (column) that is NaN at every SNV of the first cluster.
# This is redundant with the per-sample SNV-count filter applied just before.
if len(final_clusters) > 0:
    # The mask is derived from the first cluster only and reused for all of them.
    mask = ~np.isnan(final_clusters[0]).all(axis = 0)
    for i in range(len(final_clusters)):
        final_clusters[i] = final_clusters[i].loc[:,mask]
        final_f[i] = final_f[i][mask]
    Fs = Fs.loc[:,mask]

In [None]:

# #SAVING RAW FILE
# species_centroid_cluster_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_ClusterCentroid.pckl")
# species_polarized_cluster_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_PolarizedCluster.pckl")
# final_Fs_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_final_Fs.pckl")

# pickle_object = open(species_centroid_cluster_path, "wb")
# pickle.dump(final_f, pickle_object)
# pickle_object.close()

# pickle_object = open(species_polarized_cluster_path, "wb")
# pickle.dump(final_clusters, pickle_object)
# pickle_object.close()

# pickle_object = open(final_Fs_path, "wb")
# pickle.dump(Fs, pickle_object)
# pickle_object.close()

In [None]:
## Diagnostic plot for the chosen polarization of strains:
## the strain frequencies should sum to ~1 in every sample.

fig,ax = plt.subplots(figsize=(12,8))
# Sum of the cluster centroids in each sample.
strain_freq_sum = pd.DataFrame(final_f).sum().values
ax.plot(strain_freq_sum, zorder=10, linewidth=3)
ax.set_ylim([.5,1.5])
# Reference line at the expected total of 1.
ax.axhline(1, color="k", linestyle="--")
ax.set_ylabel("Sum of strain frequencies",size=20)



## Plotting strain frequencies

Strain frequencies can be plotted using a main key (e.g. Day) and a secondary key (e.g. Capsule), yielding a two-level identification of each sample

In [None]:
# When `subset` is True only the `subset_value` SNVs with the highest median depth
# are drawn for each strain; otherwise every SNV of the cluster is drawn.
subset = True
subset_value = 1000

# Colours per strain. The first three match the original scheme; the list is indexed
# modulo its length so more strains than colours cannot raise an IndexError.
strain_colors = ["red", "blue", "green", "orange", "purple", "brown", "deeppink", "teal"]

if subset:
    # SNVs ordered from highest to lowest median depth across samples.
    high_coverage_snv_idxs = Dss.median(axis = 1).sort_values(ascending = False).index

# Create the line plot
fig, ax = plt.subplots(figsize=(20, 8))

for i,f in enumerate(final_f):
    strain_color = strain_colors[i % len(strain_colors)]

    if subset:
        high_coverage_snv_idxs_strain = high_coverage_snv_idxs.intersection(final_clusters[i].index)[:subset_value]
        snv_trajectories = final_clusters[i].loc[high_coverage_snv_idxs_strain]
    else:
        snv_trajectories = final_clusters[i]

    # Faint per-SNV trajectories. The palette has exactly one entry per drawn SNV, so
    # its length always matches the number of lines even when fewer than subset_value
    # SNVs are available.
    if snv_trajectories.shape[0] > 0:
        sns.lineplot(data = snv_trajectories.T.values, ax = ax, palette=[strain_color]*snv_trajectories.shape[0], alpha = 0.075, dashes = False, legend = False)

    # Strain centroid: a wider black line underneath gives the coloured line an outline.
    sns.lineplot(data = pd.DataFrame(f), x = "sample", y = 0,  ax = ax, color = "black", linewidth = 9)
    sns.lineplot(data = pd.DataFrame(f), x = "sample", y = 0,  ax = ax, color = strain_color, linewidth = 7)

    

# Build the x axis: one minor tick per sample, one major tick per timepoint, and the
# [xmin, xmax] extents of each timepoint for background shading.
# NOTE(review): columns are read by position — column[0] is used as the sample type and
# column[1], column[2] are joined into the timepoint label; confirm the column levels.
# NOTE(review): with a single column the second branch below reads `xmin` before this
# cell assigns it — confirm that case cannot occur.
major_ticks = []
major_tick_labels = []
minor_ticks = []
minor_tick_labels = []
time_point = ""
x_ticks_loc = ax.get_xticks()
vspan_counter = 0
vspan_vec = []
for i, column in enumerate(final_clusters[0].columns):
    # Major ticks: timepoint (the leading newlines push the label below the minor labels)
    new_time_point = "\n\n" +column[1] + ",\n" + column[2]
    # Case 1: a new timepoint starts and this is not the last column.
    if (time_point != new_time_point) & (i != len(final_clusters[0].columns) - 1):
        time_point = new_time_point
        # major_ticks.append(x_ticks_loc[i])
        major_tick_labels.append(time_point)

        # add vspan
        vspan_counter += 1

        if vspan_counter == 1:
            # First timepoint: only open the span, there is nothing to close yet.
            xmin = 0
        else: 
            # Close the previous span halfway between the two neighbouring samples.
            xmax = x_ticks_loc[i] - 0.5
            vspan_vec.append([xmin, xmax])
            xmin = xmax
    # Case 2: a new timepoint starts on the last column: close the previous span and
    # add a final span for the last sample.
    elif (time_point != new_time_point) & (i == len(final_clusters[0].columns) - 1):  
        time_point = new_time_point    
        major_tick_labels.append(time_point)
        
        xmax = x_ticks_loc[i] - 0.5
        vspan_vec.append([xmin, xmax])
        xmin = xmax
        xmax = x_ticks_loc[i]
        vspan_vec.append([xmin, xmax])
    # Case 3: last column within the current timepoint: close the open span at it.
    elif (i == len(final_clusters[0].columns) - 1):
        xmax = x_ticks_loc[i]
        vspan_vec.append([xmin, xmax])

    # Minor ticks: sample type
    sample_type = column[0]
    minor_ticks.append(x_ticks_loc[i])
    minor_tick_labels.append(sample_type)

# Shade every second timepoint and place one major tick per span
for i,v in enumerate(vspan_vec):
    # adding vspan
    if i % 2 == 1:
        ax.axvspan(v[0],v[1],alpha=.2,color='grey') 
    
    # A first span ending at 0.5 gets its major tick at 0.
    if (i == 0) & (v[1] == 0.5):
        major_ticks.append(0)
    # The last span ends on a sample position rather than halfway past it, so its
    # left edge is shifted by 0.5 before taking the midpoint.
    elif (i == len(vspan_vec) - 1):
        major_ticks.append(np.mean([v[0] + 0.5, v[1]]))
    else:
        major_ticks.append(np.mean(v))


ax.set_xticks(major_ticks)
ax.set_xticklabels(major_tick_labels)
ax.set_xticks(minor_ticks, minor=True)
ax.set_xticklabels(minor_tick_labels, minor=True)

# Keep minor ticks that coincide with a major tick instead of letting matplotlib drop them.
ax.xaxis.remove_overlapping_locs = False

# Hide the major tick marks on the x axis (their labels remain).
plt.tick_params(axis='x',which='major',bottom=False,left=False,top=False) 

# Titles

ax.set_title("%s%s%s%s" % ("Strains of ", species, " in subject ", subject_id), size = 20)
ax.set_xlabel("Sample", size = 20)
ax.set_ylabel("Frequency", size = 20)

        
# Legend: one entry per strain.
# The first three colours match the original scheme; the list is indexed modulo its
# length so more strains than colours cannot raise an IndexError.
legend_colors = ["red", "blue", "green", "orange", "purple", "brown", "deeppink", "teal"]

legend_elements = [
    Line2D([0], [0], color=legend_colors[strain_i % len(legend_colors)], lw=8, label='Strain %s' % (strain_i+1))
    for strain_i in range(len(final_f))
]

ax.legend(handles=legend_elements, fontsize = 16)


# Saving
plt.tight_layout()


# out_file = "%s%s%s%s%s" % (strain_phasing_figures_dir, species, "_subject_", subject_id, "_StrainFreq.png")
# fig.savefig(out_file, dpi = 300)



In [None]:
# Write the strain-frequency figure for this species and subject to the figures directory.
out_file = f"{strain_phasing_figures_dir}{species}_subject_{subject_id}_StrainFreq_ReadSupport251104.png"
fig.savefig(out_file, dpi = 300)

## Reclustering around a new 

## Bootstrapping support and saving file for R plotting

In [45]:
# Build one table of strain frequencies per sample (centroid, inter-quartile range and a
# bootstrap confidence interval of the mean SNV frequency) and save it for plotting in R.
all_strains = True
bootstrap_ci = True   # when False the CI columns are written as NaN
boostrap_k = 100      # SNVs drawn, with replacement, per bootstrap resample (name kept as-is)
bootstrap_N = 1000    # number of bootstrap resamples per sample
# NOTE(review): the resampling uses random.choices without a seed, so the interval
# bounds differ slightly from run to run.


def _bootstrap_mean_ci(values, k, n_resamples):
    """Return the (2.5%, 97.5%) percentile bootstrap interval of the mean of `values`.

    Each resample draws `k` values with replacement; `n_resamples` means are collected.
    Returns (nan, nan) when `values` is empty, since nothing can be resampled.
    """
    if len(values) == 0:
        return np.nan, np.nan
    resampled_means = [np.mean(choices(values, k = k)) for _ in range(n_resamples)]
    return np.quantile(resampled_means, 0.025), np.quantile(resampled_means, 0.975)


def _strain_frequency_table(strain_idx):
    """Return the per-sample frequency table of one strain (0-based `strain_idx`).

    Reads final_f / final_clusters at `strain_idx` plus the notebook-level species,
    sample_metadata_map and bootstrap settings.
    """
    centroid = final_f[strain_idx]
    cluster = final_clusters[strain_idx]

    # Centroid frequency per sample; the index levels become ordinary columns.
    table = pd.DataFrame(centroid).reset_index().rename(columns = {0: "freq"})
    table['species'] = species
    table['strain'] = strain_idx + 1
    # Subject, sample type and tissue are read from fixed positions of the metadata record.
    table['subject'] = table['sample'].apply(lambda x: sample_metadata_map[x][0])
    table['sample_type'] = table['sample'].apply(lambda x: sample_metadata_map[x][9])
    table['tissue'] = table['sample'].apply(lambda x: sample_metadata_map[x][10])
    table = table[['sample_type',
                   'date',
                   'time',
                   'sample',
                   'freq',
                   'species',
                   'strain',
                   'subject',
                   'tissue']]

    # 25th and 75th percentile of the SNV frequencies in each sample.
    iqr_75 = pd.DataFrame(cluster.quantile(0.75)).reset_index().loc[:,['sample', 0.75]]
    iqr_25 = pd.DataFrame(cluster.quantile(0.25)).reset_index().loc[:,['sample', 0.25]]
    table = table.merge(iqr_25, on = ["sample"]).merge(iqr_75, on = ["sample"])

    if bootstrap_ci:
        # One interval per sample, in the column order of the cluster (the same order as
        # the rows of `table`).
        intervals = [_bootstrap_mean_ci(cluster.iloc[:,sample_i].dropna().to_list(), boostrap_k, bootstrap_N)
                     for sample_i in range(cluster.shape[1])]
        table['upper_ci'] = [upper for _, upper in intervals]
        table['lower_ci'] = [lower for lower, _ in intervals]
    else:
        # Keep the output schema stable when the bootstrap is switched off.
        table['upper_ci'] = np.nan
        table['lower_ci'] = np.nan

    return table


if len(final_clusters) > 0:
    # Build every per-strain table first and concatenate once.
    final_f_all_strains = pd.concat([_strain_frequency_table(i) for i in range(len(final_f))],
                                    ignore_index = True)

    #Renaming
    final_f_all_strains = final_f_all_strains.rename(columns = {0.25: "quantile_25", 0.75: "quantile_75"})
    
    #Saving 
    
    output_file = "%sstrain_phasing/strain_clusters/%s/%s_subject_%s_strain_frequency.csv" % (config.project_directory, species, species, subject_id)
    # This path is built from config.project_directory, so make sure the folder exists.
    os.makedirs(os.path.dirname(output_file), exist_ok = True)
    final_f_all_strains.to_csv(output_file, sep = "\t", index = False)
else:
    # NOTE(review): this branch is reached when final_clusters is empty, which does not
    # match the message text — confirm the intended wording.
    sys.stderr.write("Only one strain detected\n")
    

# For loop through species list

In [None]:
# Echo the configured strain-phasing directory as the cell output.
config.strain_phasing_directory

In [None]:
# variables
## Meta-parameters: experiment with these—no hard and fast rules!
## NOTE(review): these repeat the values of the single-species meta-parameter cell
## earlier in the notebook; keep the two in sync.

## minimum number of SNVs which need to be clustered together in order to qualify as a "strain"
min_cluster_size = 1000

## minimum fraction of sites which pass our coverage threshold which must be in a cluster in order for it to qualify 
## as a strain
min_cluster_fraction = 1/10

## For computational efficiency, cap the number of SNVs we actually perform strain phasing on
max_num_snvs = 20000

## Distance threshold below which two SNV trajectories are considered linked—
## lower means trajectories have to be more similar to be linked.
max_d = 3.5

## minimum coverage to consider allele frequency at a site for purposes of clustering
min_coverage = 10 

## minimum average sample coverage at polymorphic sites (e.g. sites in the A/D matrices)
min_sample_coverage = 5


## polymorphic & covered fraction: what fraction of samples does a site need 
## with coverage > min_coverage and polymorphic to be included in downstream analyses? 
poly_cov_frac = 1/5 # passed to return_FAD

## Maximum number of cluster-extraction iterations to run
n_clusters = 100

## Minimum number of non-NaN SNVs a sample must have to be kept
min_num_snvs_per_sample = 100


# Species to consider: one species name per line of the metadata file.
species_path = f"{config.analysis_directory}metadata/species_snps.txt"
with open(species_path, 'r') as file:
    species_list = list(map(str.strip, file))

# Sample metadata and the subject -> samples mapping (project helpers);
# every key of the latter is a subject id.
sample_metadata_map = parse_sample_metadata_map()
subject_sample_map = parse_subject_sample_map()
subjects = subject_sample_map.keys()

debug_counter = 0

# NOTE(review): the two assignments below replace the full lists loaded above,
# so the loop that follows only processes subject "9" and a single species.
# Remove them to run over every subject and species.
subjects = ["9"]
species_list = ["Blautia_wexlerae_56130"]

# Build and save a strain-frequency table for every (subject, species) pair.
#
# For each pair the loop below:
#   1. loads the pickled raw SNV clusters (skipping the pair if there are none),
#   2. polarizes the clusters so that per-sample centroid sums move towards 1,
#   3. drops samples that have too few SNVs in any cluster,
#   4. summarises every cluster per sample (mean, quartiles, bootstrap CI),
#   5. writes the combined table as a tab-separated file.

# Bootstrap settings for the confidence interval on each strain's mean frequency.
bootstrap_ci = True      # set to False to omit the upper_ci / lower_ci columns
bootstrap_k = 100        # SNV frequencies drawn (with replacement) per replicate
bootstrap_N = 1000       # number of bootstrap replicates per sample
bootstrap_seed = 42      # fixed seed so the reported intervals are reproducible
bootstrap_rng = np.random.default_rng(bootstrap_seed)

# Column order of the per-strain table. 'date', 'time' and 'sample' are expected
# to come out of reset_index() on the centroid's index -- presumably its index
# level names; confirm against the pickled clusters.
_STRAIN_TABLE_COLUMNS = ['sample_type', 'date', 'time', 'sample', 'freq',
                         'species', 'strain', 'subject', 'tissue']

# Output column -> position read from each sample_metadata_map entry.
_METADATA_POSITIONS = (('subject', 0), ('sample_type', 9), ('tissue', 10))


def _polarization_cost(centroids):
    """Return the sum over samples of (1 - summed centroid frequency)**2."""
    return ((1 - centroids.sum()) ** 2).sum()


def _bootstrap_mean_ci(snv_freqs, rng, n_resamples, resample_size):
    """Percentile-bootstrap 95% interval for the mean SNV frequency of one sample.

    Parameters
    ----------
    snv_freqs : array-like of float
        SNV frequencies of one sample within one cluster; NaNs are dropped.
    rng : numpy.random.Generator
        Random generator used for resampling.
    n_resamples : int
        Number of bootstrap replicates.
    resample_size : int
        Number of frequencies drawn with replacement per replicate.

    Returns
    -------
    (lower, upper) : tuple of float
        2.5% and 97.5% quantiles of the replicate means, or (nan, nan) when no
        non-NaN frequency is available.
    """
    values = np.asarray(snv_freqs, dtype=float)
    values = values[~np.isnan(values)]
    if values.size == 0:
        return np.nan, np.nan
    # One row per replicate; the row mean is that replicate's estimate.
    replicate_means = rng.choice(values, size=(n_resamples, resample_size), replace=True).mean(axis=1)
    lower, upper = np.quantile(replicate_means, [0.025, 0.975])
    return lower, upper


def _strain_frequency_table(centroid, cluster, species, strain_number, metadata_map,
                            rng=None, n_resamples=1000, resample_size=100):
    """Summarise one cluster ("strain") as one row per sample.

    Parameters
    ----------
    centroid : pandas.Series
        Per-sample mean frequency of the cluster.
    cluster : pandas.DataFrame
        SNV-by-sample frequencies of the cluster (columns align with `centroid`).
    species : str
        Species name written to the 'species' column.
    strain_number : int
        1-based strain label written to the 'strain' column.
    metadata_map : mapping
        Sample name -> metadata record, indexed by position.
    rng : numpy.random.Generator or None
        When given, bootstrap 'upper_ci' / 'lower_ci' columns are added.
    n_resamples, resample_size : int
        Bootstrap settings, see `_bootstrap_mean_ci`.

    Returns
    -------
    pandas.DataFrame
        Columns `_STRAIN_TABLE_COLUMNS`, then the 0.25 and 0.75 quantile columns,
        then (optionally) 'upper_ci' and 'lower_ci'.
    """
    table = pd.DataFrame(centroid).reset_index().rename(columns = {0: "freq"})
    table['species'] = species
    table['strain'] = strain_number
    for column, position in _METADATA_POSITIONS:
        table[column] = table['sample'].apply(lambda name, position=position: metadata_map[name][position])
    table = table[_STRAIN_TABLE_COLUMNS]

    # Per-sample quartiles of the SNV frequencies; columns are named 0.25 / 0.75 here.
    for quantile_level in (0.25, 0.75):
        quantile_table = pd.DataFrame(cluster.quantile(quantile_level)).reset_index().loc[:, ['sample', quantile_level]]
        table = table.merge(quantile_table, on = ["sample"])

    if rng is not None:
        # Intervals are attached positionally, so this relies on the table rows
        # being in the same order as the cluster's columns.
        bounds = [_bootstrap_mean_ci(cluster.iloc[:, sample_i], rng, n_resamples, resample_size)
                  for sample_i in range(cluster.shape[1])]
        table['upper_ci'] = [upper for _, upper in bounds]
        table['lower_ci'] = [lower for lower, _ in bounds]

    return table


# Loop-invariant directories.
strainfinder_dir = "%sinput" % (config.strain_phasing_directory)
raw_cluster_path = "%s%s" % (config.strain_phasing_directory, "strain_clusters/")

for subject_i,subject_id in enumerate(list(subjects)):
    sys.stderr.write("Processing subject %s (%s / %s)\n" % (subject_id, subject_i + 1, len(subjects)))
    samples_of_interest = list(subject_sample_map[subject_id].keys())

    for species_i,species in enumerate(species_list):
        sys.stderr.write("Processing %s (%s / %s)\n" % (species, species_i + 1, len(species_list)))

        species_raw_cluster_dir = "%s%s/" % (raw_cluster_path, species)
        species_raw_cluster_path = "%s%s%s%s%s" % (species_raw_cluster_dir, species, "_subject_", subject_id, "_RawCluster.pckl")

        # Cheap checks first: nothing else needs loading when there are no clusters.
        if not os.path.exists(species_raw_cluster_path):
            sys.stderr.write(f"\tFile not found: {species_raw_cluster_path}, skipping species.\n")
            continue

        # NOTE(security): unpickling executes arbitrary code; only read cluster
        # files produced by this pipeline.
        final_clusters = pd.read_pickle(species_raw_cluster_path)

        if len(final_clusters) == 0:
            sys.stderr.write("\tNo strains detected, skipping species.\n")
            continue

        # Load the frequency (F), alternate-count (A) and depth (D) matrices.
        Fs,Ass,Dss = return_FAD(species, min_coverage=min_coverage,
                                min_sample_coverage=min_sample_coverage,
                                poly_cov_frac = poly_cov_frac,
                                calculate_poly_cov_frac=False,
                                read_support = False,
                                subject_id=subject_id)

        # Keep only samples with at least min_num_snvs_per_sample non-NaN frequencies.
        sample_with_adequate_snv_count = Fs.count() >= min_num_snvs_per_sample
        Fs = Fs.loc[:,sample_with_adequate_snv_count]
        Ass = Ass.loc[:,sample_with_adequate_snv_count]
        Dss = Dss.loc[:,sample_with_adequate_snv_count]

        ## If only a single cluster is detected, add a second "cluster" which is simply
        ## 1 minus the allele frequencies in the first cluster; this aids visualization
        ## for people not familiar with this kind of clustering.
        if len(final_clusters) == 1:
            final_clusters.append(1-final_clusters[0])

        ## Cluster centroids: per-sample mean frequency of every cluster.
        final_f = [cluster.mean() for cluster in final_clusters]
        df_final_f = pd.DataFrame(final_f)

        ## Polarize clusters so that the squared deviation of the per-sample centroid
        ## sums from 1 shrinks; accurate strain frequencies should sum to 1.
        ## Each cluster is visited once, in order, and flipped (f -> 1 - f) when that
        ## lowers the cost given the flips already accepted.
        ## NOTE(review): this greedy pass reproduces the original cell, whose check ran
        ## inside the candidate loop. If the intent was to evaluate all candidates
        ## first and flip only the single best one, move the check out of the loop.
        for cluster_i in range(df_final_f.shape[0]):
            candidate = df_final_f.copy()
            candidate.iloc[cluster_i] = 1 - candidate.iloc[cluster_i]
            if _polarization_cost(candidate) < _polarization_cost(df_final_f):
                final_f[cluster_i] = 1 - final_f[cluster_i]
                final_clusters[cluster_i] = 1 - final_clusters[cluster_i]
                df_final_f = pd.DataFrame(final_f)

        # A sample is "good" when every cluster has more than
        # min_num_snvs_per_sample non-NaN SNV frequencies for it.
        enough_snvs_per_cluster = [(cluster.count() > min_num_snvs_per_sample).to_numpy()
                                   for cluster in final_clusters]
        good_samples = np.logical_and.reduce(enough_snvs_per_cluster)

        if not good_samples.any():
            sys.stderr.write("\tNo good samples after filter, skipping species.\n")
            continue

        final_clusters = [cluster.loc[:, good_samples] for cluster in final_clusters]
        final_f = [centroid.loc[good_samples] for centroid in final_f]

        Fs = Fs.loc[:, good_samples]
        Ass = Ass.loc[:, good_samples]
        Dss = Dss.loc[:, good_samples]

        # The original cell also removed all-NaN columns here. With a non-negative
        # min_num_snvs_per_sample the filter above already guarantees at least one
        # non-NaN value per retained column, so that step was a no-op and is dropped.

        # One table per strain, stacked into a single frame.
        strain_tables = [
            _strain_frequency_table(centroid, cluster, species, strain_i + 1, sample_metadata_map,
                                    rng = bootstrap_rng if bootstrap_ci else None,
                                    n_resamples = bootstrap_N,
                                    resample_size = bootstrap_k)
            for strain_i, (centroid, cluster) in enumerate(zip(final_f, final_clusters))
        ]
        final_f_all_strains = pd.concat(strain_tables, ignore_index=True)
        final_f_all_strains = final_f_all_strains.rename(columns = {0.25: "quantile_25", 0.75: "quantile_75"})

        # Saving
        # NOTE(review): the output path is built from config.project_directory while the
        # input clusters are read from config.strain_phasing_directory -- confirm that
        # both resolve to the same strain_clusters/ directory.
        output_file = "%sstrain_phasing/strain_clusters/%s/%s_subject_%s_strain_frequency.csv" % (config.project_directory, species, species, subject_id)
        final_f_all_strains.to_csv(output_file, sep = "\t", index = False)
        sys.stderr.write("\t%s done!\n\n" % (species))
            



    

    

    
    

In [None]:
# Sum the strain frequencies within each sample of the most recently built table.
# Totals close to 1 are the expected outcome if the strain frequencies are accurate.
final_f_all_strains.groupby(['sample'])['freq'].sum()