In [17]:
import numpy as np
import pandas as pd
import os
import statsmodels as sm
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
from statsmodels.formula.api import mixedlm

import string


# plotting functions

import matplotlib as plt
import seaborn as sns

# predefined functions
from strain_phasing_functions import *
from microbiome_evolution_functions import *

In [18]:
# Input locations used throughout the notebook
species_list_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/metadata/species_snps.txt"  # one species ID per line
strain_frequency_dir = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/strain_phasing/strain_clusters/"  # <species>/<species>_subject_<id>_strain_frequency.csv


In [19]:
# Load species list: one identifier per line, surrounding whitespace removed
with open(species_list_path, "r") as species_handle:
    species_list = [entry.strip() for entry in species_handle.readlines()]

In [20]:
list(species_list)

['Acidaminococcus_intestini_54097',
 'Actinomyces_graevenitzii_58300',
 'Actinomyces_sp_57735',
 'Actinomyces_sp_62581',
 'Actinomyces_viscosus_57672',
 'Adlercreutzia_equolifaciens_60310',
 'Aggregatibacter_aphrophilus_58143',
 'Akkermansia_muciniphila_55290',
 'Alistipes_finegoldii_56071',
 'Alistipes_indistinctus_62207',
 'Alistipes_onderdonkii_55464',
 'Alistipes_putredinis_61533',
 'Alistipes_senegalensis_58364',
 'Alistipes_shahii_62199',
 'Anaerostipes_hadrus_55206',
 'Atopobium_parvulum_59960',
 'Atopobium_sp_59401',
 'Bacteroidales_bacterium_58650',
 'Bacteroides_caccae_53434',
 'Bacteroides_cellulosilyticus_58046',
 'Bacteroides_clarus_62282',
 'Bacteroides_coprocola_61586',
 'Bacteroides_eggerthii_54457',
 'Bacteroides_finegoldii_57739',
 'Bacteroides_fragilis_54507',
 'Bacteroides_intestinalis_61596',
 'Bacteroides_massiliensis_44749',
 'Bacteroides_ovatus_58035',
 'Bacteroides_pectinophilus_61619',
 'Bacteroides_plebeius_61623',
 'Bacteroides_salyersiae_54873',
 'Bacteroid

In [21]:
# Load maps
# Load maps (parsers presumably come from the star-imported helper modules)
sample_metadata_map = parse_sample_metadata_map()
subject_sample_map = parse_subject_sample_map()

## subject list: subject IDs in numeric order, stored as strings
subjects = [str(subject_number) for subject_number in sorted(map(int, subject_sample_map))]

# STRAIN ANOVA  

## 2-strain

In [22]:
# Load strain files iteratively: one table per species-subject pair, read tab-separated
# (despite the .csv extension). Paths are built from strain_frequency_dir rather than
# repeating the directory literal.
strain_freq_df_list = []

for species in species_list:
    for subject_id in subjects:
        file_path = os.path.join(
            strain_frequency_dir,
            species,
            "%s_subject_%s_strain_frequency.csv" % (species, subject_id),
        )
        if not os.path.exists(file_path):
            continue  # no strain table for this species-subject pair

        strain_freq_df_list.append(pd.read_csv(file_path, sep="\t"))

# pd.concat([]) raises an unhelpful "No objects to concatenate"; name the directory instead.
if not strain_freq_df_list:
    raise FileNotFoundError("No strain frequency files found under %s" % strain_frequency_dir)

strain_freq_df = pd.concat(strain_freq_df_list, ignore_index=True)
# sample_type_number is field [2] of each sample's metadata record
strain_freq_df['sample_type_number'] = strain_freq_df['sample'].map(lambda sample: sample_metadata_map[sample][2])

In [23]:
# Persist the combined, unfiltered strain frequency table
out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/strain_phasing/strain_freqs.csv"
strain_freq_df.to_csv(path_or_buf=out_path, index=False, sep=",")

In [None]:
## Keep species-subject pairs phased into exactly 2 strains. Then keep only capsule samples
## from pairs sampled at >= 2 sites (space) and >= 2 timepoints (time).

## 1. have only 2 strains
strains_per_pair = strain_freq_df.groupby(["species", "subject"])['strain'].transform("nunique")
strain_freq_df = strain_freq_df.loc[strains_per_pair == 2, :]

##### Split strains: strain 1 is analyzed; strain 2 is filtered in parallel
strain_freq_df_2 = strain_freq_df[strain_freq_df.strain == 2]
strain_freq_df = strain_freq_df[strain_freq_df.strain == 1]

##### Stool samples (strain 1), later restricted to the retained pairs in step 4
stool_samples_df = strain_freq_df[strain_freq_df.sample_type == "Stool"].reset_index(drop=True)

##### Filter for only capsule
strain_freq_df = strain_freq_df[strain_freq_df.sample_type == "Capsule"].reset_index(drop=True)
strain_freq_df_2 = strain_freq_df_2[strain_freq_df_2.sample_type == "Capsule"].reset_index(drop=True)

## 2. Present in at least two sites: >= 2 distinct capsules at a single (date, time).
##    Both frames are capsule-only at this point, so no extra sample_type filter is needed.
spatial_groups = strain_freq_df[
    strain_freq_df.groupby(["species", "subject", "date", "time"])["sample_type_number"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)
spatial_groups_2 = strain_freq_df_2[
    strain_freq_df_2.groupby(["species", "subject", "date", "time"])["sample_type_number"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)
### keep only spatially sampled species-subject pairs
strain_freq_df = strain_freq_df.merge(spatial_groups, on=["species", "subject"], how="inner")
strain_freq_df_2 = strain_freq_df_2.merge(spatial_groups_2, on=["species", "subject"], how="inner")

## 3. Present in at least two timepoints
### timepoint = 1-based dense rank of (date, time) within each species-subject pair.
### The approach:
###   - code each (date, time) pair in globally sorted order, then
###   - dense-rank that code within each pair.
### This gives the same numbering as a per-pair groupby(["date", "time"]).ngroup() + 1.
### It avoids groupby.apply, whose result index (with or without group keys) depends on
### the pandas version.
time_slot = strain_freq_df.groupby(["date", "time"]).ngroup()
strain_freq_df["timepoint"] = (
    time_slot.groupby([strain_freq_df["species"], strain_freq_df["subject"]])
    .rank(method="dense")
    .astype(int)
)
time_slot_2 = strain_freq_df_2.groupby(["date", "time"]).ngroup()
strain_freq_df_2["timepoint"] = (
    time_slot_2.groupby([strain_freq_df_2["species"], strain_freq_df_2["subject"]])
    .rank(method="dense")
    .astype(int)
)
### temporal groups: keep a pair if any capsule site is observed at >= 2 timepoints
### TODO: consider keeping only the capsules that are present at multiple timepoints
temporal_groups = strain_freq_df[
    strain_freq_df.groupby(["species", "subject", "sample_type_number"])["timepoint"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)
temporal_groups_2 = strain_freq_df_2[
    strain_freq_df_2.groupby(["species", "subject", "sample_type_number"])["timepoint"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)
### keep only temporally sampled species-subject pairs
strain_freq_df = strain_freq_df.merge(temporal_groups, on=["species", "subject"], how="inner")
strain_freq_df_2 = strain_freq_df_2.merge(temporal_groups_2, on=["species", "subject"], how="inner")


## 4. subset stool samples to the pairs kept by the (strain 1) spatial and temporal filters
stool_samples_df = stool_samples_df.merge(spatial_groups, on=["species", "subject"], how="inner")
stool_samples_df = stool_samples_df.merge(temporal_groups, on=["species", "subject"], how="inner")

##### Reordering columns (stool rows carry no timepoint)
strain_columns = ['species', 'strain', 'subject', 'sample_type', 'sample_type_number', 'tissue',
                  'date', 'time', 'timepoint', 'sample', 'freq', 'quantile_25', 'quantile_75',
                  'upper_ci', 'lower_ci']
strain_freq_df = strain_freq_df[strain_columns].reset_index(drop=True)
strain_freq_df_2 = strain_freq_df_2[strain_columns].reset_index(drop=True)
stool_samples_df = stool_samples_df[[col for col in strain_columns if col != 'timepoint']].reset_index(drop=True)

In [None]:
##### Getting strain freq to equal 1
strain_freq_df_bothstrains = pd.concat([strain_freq_df, strain_freq_df_2], axis = 0)
grouped_sum = strain_freq_df_bothstrains.groupby(
    ['sample','species']
)['freq'].sum().round(6)
all_1_groups = grouped_sum[grouped_sum == 1.0].reset_index()
strain_freq_df_bothstrains = strain_freq_df_bothstrains.merge(
    all_1_groups[['sample','species']],
    on=['sample','species'],
    how='inner'
)


In [None]:
##### minimum strain frequency
msf = 0.05
filtered_df = strain_freq_df_bothstrains.groupby(['species', 'subject', 'strain']).filter(
    lambda group: ((group['freq'] > msf) & (group['freq'] < (1-msf))).any()
)


In [None]:
# Persist the sum-to-1, intermediate-frequency filtered table
out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/strain_phasing/strain_freqs_filtered.csv"
filtered_df.to_csv(path_or_buf=out_path, index=False, sep=",")

In [None]:
anaerostipes_s9_mask = (strain_freq_df_bothstrains["species"] == "Anaerostipes_hadrus_55206") & (strain_freq_df_bothstrains["subject"] == 9)
strain_freq_df_bothstrains.loc[anaerostipes_s9_mask]

In [None]:
strain_freq_df_bothstrains.loc[strain_freq_df_bothstrains["species"].eq("Ruminococcus_obeum_61472")]

In [None]:
strain_freq_df["strain"].unique()

In [None]:
strain_freq_df["species"].unique()

In [None]:
# species_subject_list =  strain_freq_df[['species','subject']].drop_duplicates().reset_index(drop = True)
# out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/metadata/strain_species_subject_df.txt"
# species_subject_list.to_csv(out_path, sep = "\t", index=False)

In [None]:
# Persist the retained strain-1 capsule table
out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/strain_phasing/strain_freq_GoodSpeciesSubjectPairs.csv"
strain_freq_df.to_csv(path_or_buf=out_path, index=False, sep=",")

In [None]:
### subject 1 only (TODO: filter for species with multiple capsules per timepoint)
strain_freq_df_S1 = strain_freq_df.loc[strain_freq_df["subject"] == 1]
### capsule + stool rows together, sorted for inspection; stool rows have no timepoint column
strain_freq_df_stool = (
    pd.concat([strain_freq_df, stool_samples_df])
    .sort_values(["species", "subject", "date", "time", "sample_type", "sample_type_number"])
    .drop(columns=['timepoint'])
)

In [None]:
# Summary counts for the current strain_freq_df
species_count = len(strain_freq_df["species"].unique())
subject_count = len(strain_freq_df["subject"].unique())
print("Number of species: %s\n" % species_count)
print("Number of subjects: %s\n" % subject_count)

In [None]:
# Example species-subject pair used to illustrate the two-way ANOVA
is_example_pair = (strain_freq_df["species"] == "Bacteroides_vulgatus_57955") & (strain_freq_df["subject"] == 11)
subset_data = strain_freq_df.loc[is_example_pair]

In [None]:
# Two-way ANOVA (type II sums of squares) with capsule site and timepoint as categorical factors
formula = "freq ~ C(sample_type_number) + C(timepoint)"
anova_table = anova_lm(ols(formula, data=subset_data).fit(), typ=2)
anova_table

In [None]:
# explained variance: capsule-site share of the ANOVA table's summed sum of squares
space_sum_sq = anova_table.at["C(sample_type_number)", "sum_sq"]
explained_variance_space = space_sum_sq / anova_table["sum_sq"].sum()
explained_variance_space

In [None]:
# explained variance: timepoint share of the ANOVA table's summed sum of squares
time_sum_sq = anova_table.at["C(timepoint)", "sum_sq"]
explained_variance_time = time_sum_sq / anova_table["sum_sq"].sum()
explained_variance_time

## 3-strain

In [None]:
# Load strain files iteratively: one table per species-subject pair, read tab-separated
# (despite the .csv extension). Paths are built from strain_frequency_dir rather than
# repeating the directory literal.
strain_freq_df_list = []

for species in species_list:
    for subject_id in subjects:
        file_path = os.path.join(
            strain_frequency_dir,
            species,
            "%s_subject_%s_strain_frequency.csv" % (species, subject_id),
        )
        if not os.path.exists(file_path):
            continue  # no strain table for this species-subject pair

        strain_freq_df_list.append(pd.read_csv(file_path, sep="\t"))

# pd.concat([]) raises an unhelpful "No objects to concatenate"; name the directory instead.
if not strain_freq_df_list:
    raise FileNotFoundError("No strain frequency files found under %s" % strain_frequency_dir)

strain_freq_df = pd.concat(strain_freq_df_list, ignore_index=True)
# sample_type_number is field [2] of each sample's metadata record
strain_freq_df['sample_type_number'] = strain_freq_df['sample'].map(lambda sample: sample_metadata_map[sample][2])

In [None]:
## 1. have > 2 strains
strains_per_pair = strain_freq_df.groupby(["species", "subject"])['strain'].transform("nunique")
strain_freq_df = strain_freq_df.loc[strains_per_pair.gt(2)]
##### Filter for only capsule
# strain_freq_df = strain_freq_df[strain_freq_df.sample_type == "Capsule"].reset_index(drop=True)
# strain_freq_df_2 = strain_freq_df_2[strain_freq_df_2.sample_type == "Capsule"].reset_index(drop=True)


In [None]:
# strain_freq_df[['species','subject']].drop_duplicates().to_csv("/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/tables/species_subjects_with_three_strains_FULL.tsv")

In [None]:
# Persist rows for species-subject pairs with more than two strains
out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/tables/three_strain_frequencies.tsv"
strain_freq_df.to_csv(path_or_buf=out_path, index=False, sep="\t")

## 2 strains that might be 3

In [None]:
# Load strain files iteratively: one table per species-subject pair, read tab-separated
# (despite the .csv extension). Paths are built from strain_frequency_dir rather than
# repeating the directory literal.
strain_freq_df_list = []

for species in species_list:
    for subject_id in subjects:
        file_path = os.path.join(
            strain_frequency_dir,
            species,
            "%s_subject_%s_strain_frequency.csv" % (species, subject_id),
        )
        if not os.path.exists(file_path):
            continue  # no strain table for this species-subject pair

        strain_freq_df_list.append(pd.read_csv(file_path, sep="\t"))

# pd.concat([]) raises an unhelpful "No objects to concatenate"; name the directory instead.
if not strain_freq_df_list:
    raise FileNotFoundError("No strain frequency files found under %s" % strain_frequency_dir)

strain_freq_df = pd.concat(strain_freq_df_list, ignore_index=True)
# sample_type_number is field [2] of each sample's metadata record
strain_freq_df['sample_type_number'] = strain_freq_df['sample'].map(lambda sample: sample_metadata_map[sample][2])

In [None]:
## Keep species-subject pairs phased into exactly 2 strains, restricted to capsule samples from
## pairs sampled at >= 2 sites and >= 2 timepoints (no stool subset in this section).

## 1. have only 2 strains
strains_per_pair = strain_freq_df.groupby(["species", "subject"])['strain'].transform("nunique")
strain_freq_df = strain_freq_df.loc[strains_per_pair == 2, :]

##### Split strains: strain 1 and strain 2 are filtered in parallel
strain_freq_df_2 = strain_freq_df[strain_freq_df.strain == 2]
strain_freq_df = strain_freq_df[strain_freq_df.strain == 1]


##### Filter for only capsule
strain_freq_df = strain_freq_df[strain_freq_df.sample_type == "Capsule"].reset_index(drop=True)
strain_freq_df_2 = strain_freq_df_2[strain_freq_df_2.sample_type == "Capsule"].reset_index(drop=True)


## 2. Present in at least two sites: >= 2 distinct capsules at a single (date, time).
##    Both frames are capsule-only at this point, so no extra sample_type filter is needed.
spatial_groups = strain_freq_df[
    strain_freq_df.groupby(["species", "subject", "date", "time"])["sample_type_number"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)
spatial_groups_2 = strain_freq_df_2[
    strain_freq_df_2.groupby(["species", "subject", "date", "time"])["sample_type_number"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)

### keep only spatially sampled species-subject pairs
strain_freq_df = strain_freq_df.merge(spatial_groups, on=["species", "subject"], how="inner")
strain_freq_df_2 = strain_freq_df_2.merge(spatial_groups_2, on=["species", "subject"], how="inner")


## 3. Present in at least two timepoints
### timepoint = 1-based dense rank of (date, time) within each species-subject pair.
### The approach:
###   - code each (date, time) pair in globally sorted order, then
###   - dense-rank that code within each pair.
### This gives the same numbering as a per-pair groupby(["date", "time"]).ngroup() + 1.
### It avoids groupby.apply, whose result index (with or without group keys) depends on
### the pandas version.
time_slot = strain_freq_df.groupby(["date", "time"]).ngroup()
strain_freq_df["timepoint"] = (
    time_slot.groupby([strain_freq_df["species"], strain_freq_df["subject"]])
    .rank(method="dense")
    .astype(int)
)
time_slot_2 = strain_freq_df_2.groupby(["date", "time"]).ngroup()
strain_freq_df_2["timepoint"] = (
    time_slot_2.groupby([strain_freq_df_2["species"], strain_freq_df_2["subject"]])
    .rank(method="dense")
    .astype(int)
)

### temporal groups: keep a pair if any capsule site is observed at >= 2 timepoints
### TODO: consider keeping only the capsules that are present at multiple timepoints
temporal_groups = strain_freq_df[
    strain_freq_df.groupby(["species", "subject", "sample_type_number"])["timepoint"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)
temporal_groups_2 = strain_freq_df_2[
    strain_freq_df_2.groupby(["species", "subject", "sample_type_number"])["timepoint"].transform("nunique") >= 2
][["species", "subject"]].drop_duplicates().reset_index(drop=True)

### keep only temporally sampled species-subject pairs
strain_freq_df = strain_freq_df.merge(temporal_groups, on=["species", "subject"], how="inner")
strain_freq_df_2 = strain_freq_df_2.merge(temporal_groups_2, on=["species", "subject"], how="inner")



##### Reordering columns
strain_columns = ['species', 'strain', 'subject', 'sample_type', 'sample_type_number', 'tissue',
                  'date', 'time', 'timepoint', 'sample', 'freq', 'quantile_25', 'quantile_75',
                  'upper_ci', 'lower_ci']
strain_freq_df = strain_freq_df[strain_columns].reset_index(drop=True)
strain_freq_df_2 = strain_freq_df_2[strain_columns].reset_index(drop=True)


In [None]:
##### Flag (sample, species) pairs whose two strain frequencies do NOT sum to 1 (possible mis-phasing)
strain_freq_df_misphased = pd.concat([strain_freq_df, strain_freq_df_2], axis=0)
# Round before comparing: an exact float test (`!= 1.0`) would flag sums such as
# 0.9999999999 that differ from 1 only by floating-point noise.
grouped_sum = strain_freq_df_misphased.groupby(
    ['sample', 'species']
)['freq'].sum().round(6)
misphased_groups = grouped_sum[grouped_sum != 1.0].reset_index()
# Keep only the rows belonging to the flagged (sample, species) pairs
strain_freq_df_misphased = strain_freq_df_misphased.merge(
    misphased_groups[['sample', 'species']],
    on=['sample', 'species'],
    how='inner'
)

In [None]:
anaerostipes_misphased_freq = misphased_groups.loc[misphased_groups["species"] == "Anaerostipes_hadrus_55206", "freq"]
anaerostipes_misphased_freq.loc[34]

In [None]:
# Species-subject pairs that survived the sum-to-1 / intermediate-frequency filters
filtered_species_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/strain_phasing/strain_freqs_filtered.csv"
filtered_species_df = (
    pd.read_csv(filtered_species_path)
    .loc[:, ['species', 'subject']]
    .drop_duplicates()
)

In [None]:
# Drop mis-phased rows for pairs that already appear in the filtered table
pairs_to_remove = set(filtered_species_df[['species', 'subject']].itertuples(index=False, name=None))

misphased_pair_index = strain_freq_df_misphased.set_index(['species', 'subject']).index
strain_freq_df_misphased_filtered = strain_freq_df_misphased[~misphased_pair_index.isin(pairs_to_remove)]

In [None]:
strain_freq_df_misphased_filtered.loc[:, ['species', 'subject']].drop_duplicates()

## All subjects

In [None]:
# For each species-host pair, fit freq_clt ~ C(timepoint) + C(sample_type_number), where
# freq_clt = log(freq + 1). Record the share of the ANOVA table's summed sum of squares
# attributed to space (capsule site), time, and residual.
# NOTE(review): with type II sums of squares the rows need not add up to the total sum of
# squares; these shares are relative to the table's summed SS.

ols_results = []

for species in strain_freq_df.species.unique():

    # Extracting hosts
    host_list = strain_freq_df[strain_freq_df['species'] == species].subject.unique()

    for host in host_list:

        # subsetting the data
        subset_data = strain_freq_df.loc[(strain_freq_df.species == species) & (strain_freq_df.subject == host), :].copy()
        print(f"Iteration: Species={species}, Host={host}, Data Size={subset_data.shape}")
        subset_data['sample_type_number'] = subset_data['sample_type_number'].astype('category')
        subset_data['timepoint'] = subset_data['timepoint'].astype('category')
        subset_data['freq_clt'] = np.log(subset_data['freq'] + 1)

        # skip pairs that cannot support both factors: need >= 2 sites, >= 2 timepoints, > 2 rows
        if len(subset_data['sample_type_number'].unique()) <= 1 or len(subset_data['timepoint'].unique()) <= 1 or subset_data.shape[0] <= 2:
            continue

        try:
            model = ols("freq_clt ~ C(timepoint) + C(sample_type_number)", data=subset_data).fit()
            anova_table = anova_lm(model, typ=2)

            # explained variance (share of the table's summed sum of squares)
            total_sum_sq = anova_table["sum_sq"].sum()
            explained_variance_space = anova_table.loc["C(sample_type_number)", "sum_sq"] / total_sum_sq
            explained_variance_time = anova_table.loc["C(timepoint)", "sum_sq"] / total_sum_sq
            residual_variance = anova_table.loc["Residual", "sum_sq"] / total_sum_sq

            ols_results.append([species, host, explained_variance_space, explained_variance_time, residual_variance])

        except Exception as e:
            # best-effort: report the failing pair and continue with the rest
            print(f"Error for species: {species}, host: {host} - {e}")

# Build the DataFrame directly from the row list. Going through np.array would coerce every
# column (subject and the variance shares included) to strings, and would fail to reshape
# when no pair produced a result.
ols_results_df = pd.DataFrame(ols_results, columns=["species",
                                                    "subject",
                                                    "explained_variance_space",
                                                    "explained_variance_time",
                                                    "residual_variance"])

variance_columns = ["explained_variance_space", "explained_variance_time", "residual_variance"]
ols_results_df[variance_columns] = ols_results_df[variance_columns].astype(float)







In [None]:
# Number of species and of hosts (subjects) with at least one ANOVA result.
# The host count must use the subject column, not species.
print("Number of species: " + str(len(ols_results_df.species.unique())))
print("Number of hosts: " + str(len(ols_results_df.subject.unique())))

In [None]:
# Long form: one row per (species, subject, variance component); value expressed in percent
variance_labels = {
    "explained_variance_space": "spatial effect",
    "explained_variance_time": "temporal effect",
    "residual_variance": "residual variance",
}
ols_results_long = ols_results_df.melt(
    id_vars=["species", "subject"],
    value_vars=list(variance_labels),
    var_name="variance_type",
    value_name="value",
)
ols_results_long["variance_type"] = ols_results_long["variance_type"].replace(variance_labels)
ols_results_long["value"] = ols_results_long["value"] * 100


In [None]:
# Persist the long-form variance partition table
out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/strain_phasing/strain_anova.csv"
ols_results_long.to_csv(path_or_buf=out_path, index=False, sep=",")

In [None]:
# subplots() lives in matplotlib.pyplot. The top-of-notebook `import matplotlib as plt` binds the
# bare package, which has no subplots(), unless a star-imported helper happens to rebind `plt`.
# Import pyplot explicitly so this cell works on its own.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))  # Define figure and axes with a specific size

# Boxes summarize each variance component; points are colored by species
sns.boxplot(data=ols_results_long, x="variance_type", y="value",
            showfliers=False,
            boxprops={'facecolor': 'none', 'alpha': 0.5},
            ax=ax)
sns.stripplot(data=ols_results_long, x="variance_type", y="value",
              hue="species",
              palette="colorblind",
              jitter=False,
              size=7,
              ax=ax)

# Connect the three components of each species-subject pair. x positions follow the order of
# first appearance (seaborn's default order for string categories); computed once, not per group.
variance_order = list(ols_results_long['variance_type'].unique())
for (species, subject), group in ols_results_long.groupby(['species', 'subject']):
    x_positions = [variance_order.index(v) for v in group['variance_type']]
    ax.plot(x_positions, group['value'], marker='', linestyle='-', color='gray', alpha=0.7)

# Labels
ax.set_xlabel("")
ax.set_ylabel("Percent variance explained", fontsize=18)
# tick parameters
ax.tick_params(axis='x', labelsize=14)
ax.tick_params(axis='y', labelsize=14)
# Place the legend outside the right edge of the plot
ax.legend(title="Species", fontsize=14, title_fontsize=16, loc='center left', bbox_to_anchor=(1, 0.5))



In [None]:
out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/figures/revisions/strain_ANOVA_species.png"
fig.savefig(out_path, facecolor="white", bbox_inches="tight", dpi=300)

In [None]:
# subplots() lives in matplotlib.pyplot. The top-of-notebook `import matplotlib as plt` binds the
# bare package, which has no subplots(), unless a star-imported helper happens to rebind `plt`.
# Import pyplot explicitly so this cell works on its own.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))  # Define figure and axes with a specific size

# Define a custom palette for the variance_type column
custom_palette = {"spatial effect": "purple",
                  "temporal effect": "red",
                  "residual variance": "gray"}

sns.boxplot(data=ols_results_long, x="variance_type", y="value",
            showfliers=False,
            boxprops={'facecolor': 'none', 'alpha': 0.5},
            ax=ax)
sns.stripplot(data=ols_results_long, x="variance_type", y="value",
              palette=custom_palette,
              jitter=False,
              size=7,
              ax=ax)

# Connect the three components of each species-subject pair. x positions follow the order of
# first appearance (seaborn's default order for string categories); computed once, not per group.
variance_order = list(ols_results_long['variance_type'].unique())
for (species, subject), group in ols_results_long.groupby(['species', 'subject']):
    x_positions = [variance_order.index(v) for v in group['variance_type']]
    ax.plot(x_positions, group['value'], marker='', linestyle='-', color='gray', alpha=0.7)

# Labels
ax.set_xlabel("")
ax.set_ylabel("Percent variance explained", fontsize=18)
# tick parameters
ax.tick_params(axis='x', labelsize=14)
ax.tick_params(axis='y', labelsize=14)




In [None]:
out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/figures/revisions/strain_ANOVA.png"
fig.savefig(out_path, facecolor="white", bbox_inches="tight", dpi=300)

## Subject 1 (including technical variation)

In [None]:
subset_data = strain_freq_df_S1.loc[strain_freq_df_S1["species"].eq("Bacteroides_ovatus_58035")]

In [None]:
# Subject 1 only: per-species fit of freq_clt ~ C(timepoint) + C(sample_type_number)
# (freq_clt = log(freq + 1)) on timepoints >= 5. Records the space / time / residual share
# of the ANOVA table's summed sum of squares.
# NOTE(review): the rationale for the `timepoint >= 5` cutoff is not documented here — confirm.

ols_results = []
host = 1  # this section analyzes subject 1 only (strain_freq_df_S1)

for species in strain_freq_df_S1.species.unique():

    # subsetting the data
    subset_data = strain_freq_df_S1.loc[(strain_freq_df_S1.species == species) & (strain_freq_df_S1.timepoint >= 5), :].copy()
    # print after subsetting so the reported size belongs to this species' data
    print(f"Iteration: Species={species}, Host={host}, Data Size={subset_data.shape}")
    subset_data['sample_type_number'] = subset_data['sample_type_number'].astype('category')
    subset_data['timepoint'] = subset_data['timepoint'].astype('category')
    subset_data['freq_clt'] = np.log(subset_data['freq'] + 1)

    # skip subsets that cannot support both factors: need >= 2 sites, >= 2 timepoints, > 2 rows
    if len(subset_data['sample_type_number'].unique()) <= 1 or len(subset_data['timepoint'].unique()) <= 1 or subset_data.shape[0] <= 2:
        continue

    try:
        model = ols("freq_clt ~ C(timepoint) + C(sample_type_number)", data=subset_data).fit()
        anova_table = anova_lm(model, typ=2)

        # explained variance (share of the table's summed sum of squares)
        total_sum_sq = anova_table["sum_sq"].sum()
        explained_variance_space = anova_table.loc["C(sample_type_number)", "sum_sq"] / total_sum_sq
        explained_variance_time = anova_table.loc["C(timepoint)", "sum_sq"] / total_sum_sq
        residual_variance = anova_table.loc["Residual", "sum_sq"] / total_sum_sq

        ols_results.append([species, host, explained_variance_space, explained_variance_time, residual_variance])

    except Exception as e:
        # best-effort: report the failing species and continue with the rest
        print(f"Error for species: {species}, host: {host} - {e}")

# Build the DataFrame directly from the row list. Going through np.array would coerce every
# column to strings, and would fail to reshape when no species produced a result.
ols_results_df = pd.DataFrame(ols_results, columns=["species",
                                                    "subject",
                                                    "explained_variance_space",
                                                    "explained_variance_time",
                                                    "residual_variance"])

variance_columns = ["explained_variance_space", "explained_variance_time", "residual_variance"]
ols_results_df[variance_columns] = ols_results_df[variance_columns].astype(float)







In [None]:
# Long form: one row per (species, subject, variance component); value expressed in percent
variance_labels = {
    "explained_variance_space": "spatial effect",
    "explained_variance_time": "temporal effect",
    "residual_variance": "residual variance",
}
ols_results_long = ols_results_df.melt(
    id_vars=["species", "subject"],
    value_vars=list(variance_labels),
    var_name="variance_type",
    value_name="value",
)
ols_results_long["variance_type"] = ols_results_long["variance_type"].replace(variance_labels)
ols_results_long["value"] = ols_results_long["value"] * 100


In [None]:
# subplots() lives in matplotlib.pyplot. The top-of-notebook `import matplotlib as plt` binds the
# bare package, which has no subplots(), unless a star-imported helper happens to rebind `plt`.
# Import pyplot explicitly so this cell works on its own.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))  # Define figure and axes with a specific size

# Define a custom palette for the variance_type column
custom_palette = {"spatial effect": "purple",
                  "temporal effect": "red",
                  "residual variance": "gray"}

sns.boxplot(data=ols_results_long, x="variance_type", y="value",
            showfliers=False,
            boxprops={'facecolor': 'none', 'alpha': 0.5},
            ax=ax)
sns.stripplot(data=ols_results_long, x="variance_type", y="value",
              palette=custom_palette,
              jitter=False,
              size=7,
              ax=ax)

# Connect the three components of each species-subject pair. x positions follow the order of
# first appearance (seaborn's default order for string categories); computed once, not per group.
variance_order = list(ols_results_long['variance_type'].unique())
for (species, subject), group in ols_results_long.groupby(['species', 'subject']):
    x_positions = [variance_order.index(v) for v in group['variance_type']]
    ax.plot(x_positions, group['value'], marker='', linestyle='-', color='gray', alpha=0.7)

# Labels
ax.set_xlabel("")
ax.set_ylabel("Percent variance explained", fontsize=18)
# tick parameters
ax.tick_params(axis='x', labelsize=14)
ax.tick_params(axis='y', labelsize=14)




In [None]:
out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/figures/revisions/strain_ANOVA_technical.png"
fig.savefig(out_path, facecolor="white", bbox_inches="tight", dpi=300)

# How does coverage impact the number of samples that make it through my filters?

In [None]:
# Load the strain frequency tables produced under each coverage-filter setting, tagging every
# table with a human-readable description of its filters.
strain_phasing_dir = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/strain_phasing"


def _load_strain_clusters(cluster_subdir, coverage_filters):
    """Concatenate every per-species, per-subject strain frequency file in one cluster directory.

    Parameters
    ----------
    cluster_subdir : str
        Directory name under `strain_phasing_dir`, e.g. "strain_clusters_avg4_min10".
    coverage_filters : str
        Label written to the 'coverage_filters' column of every returned row.

    Returns
    -------
    pandas.DataFrame
        Rows from all files found (read tab-separated), with 'sample_type_number' taken from
        field [2] of the sample's metadata record and the 'coverage_filters' label added.

    Raises
    ------
    FileNotFoundError
        If no strain frequency file exists for any species-subject pair.
    """
    cluster_dir = os.path.join(strain_phasing_dir, cluster_subdir)
    frames = []
    for species in species_list:
        for subject_id in subjects:
            file_path = os.path.join(
                cluster_dir, species, "%s_subject_%s_strain_frequency.csv" % (species, subject_id)
            )
            if os.path.exists(file_path):
                frames.append(pd.read_csv(file_path, sep="\t"))
    if not frames:
        raise FileNotFoundError("No strain frequency files found under %s" % cluster_dir)

    cluster_df = pd.concat(frames, ignore_index=True)
    cluster_df['sample_type_number'] = cluster_df['sample'].map(lambda sample: sample_metadata_map[sample][2])
    cluster_df['coverage_filters'] = coverage_filters
    return cluster_df


print("Processing normal strain clusters")
strain_freq_df = _load_strain_clusters(
    "strain_clusters",
    "average 5 reads,\nper SNV minimum 10 reads,\n100 SNVs per sample\n(default)",
)

print("Processing strain clusters with average coverage of 4 and min per site coverage of 10")
strain_freq_df_min4_cov10 = _load_strain_clusters(
    "strain_clusters_avg4_min10",
    "average 4 reads,\nper SNV minimum 10 reads,\n100 SNVs per sample",
)

print("Processing strain clusters with average coverage of 5 and min per site coverage of 5")
strain_freq_df_min5_cov5 = _load_strain_clusters(
    "strain_clusters_avg5_min5",
    "average 5 reads,\nper SNV minimum 5 reads,\n100 SNVs per sample",
)

print("Processing strain clusters with average coverage of 4 and min per site coverage of 5")
strain_freq_df_min4_cov5 = _load_strain_clusters(
    "strain_clusters_avg4_min5",
    "average 4 reads,\nper SNV minimum 5 reads,\n100 SNVs per sample",
)

# The message matches the directory (_50snvs) and label (50 SNVs) of this run.
print("Processing strain clusters with average coverage of 5 and min per site coverage of 5, and 50 SNVs per cluster")
strain_freq_df_min5_cov5_50snvs = _load_strain_clusters(
    "strain_clusters_avg5_min5_50snvs",
    "average 5 reads,\nper SNV minimum 5 reads,\n50 SNVs per sample",
)



In [None]:
# Stack every filter-set run into one long frame; rows from different runs remain
# distinguishable through their `coverage_filters` label.
filter_set_frames = [
    strain_freq_df,  # presumably the "(default)" filter set loaded earlier -- confirm
    strain_freq_df_min4_cov10,
    strain_freq_df_min5_cov5,
    strain_freq_df_min4_cov5,
    strain_freq_df_min5_cov5_50snvs,
]
all_strain_freq_df = pd.concat(filter_set_frames, ignore_index=True)



In [None]:
# Number of distinct samples per species / subject / filter set. Counting unique
# `sample` values (rather than rows via `.size()`) stays correct even if a
# strain-frequency table holds more than one row per sample (e.g. one per strain --
# TODO confirm the table layout); with one row per sample both give the same count.
sample_no_dataframe = (
    all_strain_freq_df
    .groupby(['species', 'subject', 'coverage_filters'])['sample']
    .nunique()
    .reset_index(name='number_of_samples')
)

In [None]:
# Reshape to one row per species/subject and one column per filter set. Column
# labels are cast to str, then species/subject are restored as ordinary columns.
sample_no_dataframe_wide = (
    sample_no_dataframe
    .pivot(index=['species', 'subject'], columns='coverage_filters', values='number_of_samples')
    .pipe(lambda wide: wide.set_axis(wide.columns.astype(str), axis=1))
    .reset_index()
)


In [None]:
# Show per-species/subject sample counts, one column per filter set
# (NaN where a filter set produced no samples for that species/subject pair).
sample_no_dataframe_wide

In [None]:
# Flag species/subject pairs where any relaxed filter set retains more samples
# than the default filter set.
default_filter_label = 'average 5 reads,\nper SNV minimum 10 reads,\n100 SNVs per sample\n(default)'
relaxed_filter_labels = [
    'average 4 reads,\nper SNV minimum 10 reads,\n100 SNVs per sample',
    'average 5 reads,\nper SNV minimum 5 reads,\n100 SNVs per sample',
    'average 4 reads,\nper SNV minimum 5 reads,\n100 SNVs per sample',
    'average 5 reads,\nper SNV minimum 5 reads,\n50 SNVs per sample',
]

# The pivot leaves NaN where a filter set yielded no samples for a pair, and any
# comparison against NaN is False -- so a pair with zero samples under the default
# but some under a relaxed set would be missed. Compare with missing counts as 0.
sample_counts = sample_no_dataframe_wide[[default_filter_label] + relaxed_filter_labels].fillna(0)
has_increase = (
    sample_counts[relaxed_filter_labels]
    .gt(sample_counts[default_filter_label], axis=0)
    .any(axis=1)
)

# Rows keep their original (un-filled) counts for display.
conditions_with_an_increase = sample_no_dataframe_wide[has_increase]





In [None]:
# Display the species/subject pairs where at least one alternative filter set
# retained more samples than the default filter set.
conditions_with_an_increase

In [None]:
# Drill into one species/subject pair to inspect its per-filter-set sample counts.
condition_1 = sample_no_dataframe_wide['species'].eq("Adlercreutzia_equolifaciens_60310")
condition_2 = sample_no_dataframe_wide['subject'].eq(12)
sample_no_dataframe_wide.loc[condition_1 & condition_2]


In [None]:
# The notebook's first cell does `import matplotlib as plt` (the package, which has
# no `subplots`/`show`); bind pyplot explicitly so this cell does not rely on a
# star import happening to re-export it.
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))  # Define figure and axes with a specific size

# x-axis order: default filter set first, then the alternatives.
desired_order = [
    'average 5 reads,\nper SNV minimum 10 reads,\n100 SNVs per sample\n(default)',
    'average 4 reads,\nper SNV minimum 10 reads,\n100 SNVs per sample',
    'average 5 reads,\nper SNV minimum 5 reads,\n100 SNVs per sample',
    'average 4 reads,\nper SNV minimum 5 reads,\n100 SNVs per sample',
    'average 5 reads,\nper SNV minimum 5 reads,\n50 SNVs per sample'
]

sample_no_dataframe['coverage_filters'] = pd.Categorical(
    sample_no_dataframe['coverage_filters'],
    categories=desired_order,
    ordered=True
)
# A label missing from `desired_order` would silently become NaN and vanish from the plot.
assert sample_no_dataframe['coverage_filters'].notna().all(), "unexpected coverage_filters label"

sns.boxplot(data=sample_no_dataframe, x="coverage_filters", y="number_of_samples",
            showfliers=False,
            palette="colorblind",
            boxprops={'alpha': 0.5},  # Correct transparency
            ax=ax)
sns.stripplot(data=sample_no_dataframe, x="coverage_filters", y="number_of_samples",
              hue="coverage_filters",
              palette="colorblind",
              jitter=False,
              size=7,
              ax=ax)

# Connect each species/subject pair across filter sets. Seaborn places categories
# at 0..n-1 in category order, which is exactly the categorical code; indexing into
# `.unique()` would instead use order of appearance in the frame and misplace lines.
for _, pair_counts in sample_no_dataframe.groupby(['species', 'subject']):
    pair_counts = pair_counts.sort_values('coverage_filters')  # draw left to right
    ax.plot(pair_counts['coverage_filters'].cat.codes, pair_counts['number_of_samples'],
            marker='', linestyle='-', color='gray', alpha=0.7)

# Labels
ax.set_xlabel("")
ax.set_ylabel("number of samples", fontsize=18)
# Tick parameters
ax.tick_params(axis='x', labelsize=14, rotation=45)  # Rotate x-axis labels by 45 degrees
ax.tick_params(axis='y', labelsize=14)

# Align x-axis labels to the right
ax.set_xticklabels(ax.get_xticklabels(), ha='right')

# Remove the hue legend (colours duplicate the x-axis labels), if one was drawn.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

plt.show()

In [None]:
# Write the figure from the plotting cell as a high-resolution PNG with an opaque background.
figure_dir = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/figures/revisions"
out_path = os.path.join(figure_dir, "effect_of_coverage_on_the_number_of_samples.png")
fig.savefig(out_path, facecolor="white", bbox_inches="tight", dpi=300)