In [1]:
import numpy as np
import pandas as pd
import os
import statsmodels.api as sm
from statsmodels.formula.api import ols
import string
import itertools
from datetime import datetime
import random


# plotting functions

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns


# predefined functions
os.sys.path.append('/u/project/ngarud/michaelw/microbiome_evolution/py3.8/microbiome_evolution_SHALON/postprocessing_scripts/')
from parse_midas_data import *
from diversity_utils import *
from calculate_intersample_changes import *
from core_gene_utils import *
from parse_patric import *
import parse_patric

In [2]:
# load species list
species_list_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/metadata/species_snps.txt"

with open(species_list_path, 'r') as file:
    species_list = file.readlines()
species_list = [species.strip() for species in species_list]



In [3]:
# Load maps
sample_metadata_map = parse_sample_metadata_map()
sample_list = list(sample_metadata_map.keys())
subject_sample_map = parse_subject_sample_map()

# Calcuating opportunities, rates of change etc.

## SNV change rates

In [None]:
# load species list
species_list_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/metadata/species_snps.txt"

with open(species_list_path, 'r') as file:
    species_list = file.readlines()
species_list = [species.strip() for species in species_list]


In [None]:
# Load SNV changes dataframe
in_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/snp_changes.txt.bz2"
snv_changes_df = pd.read_csv(in_path, sep = ",")

In [None]:
snv_summary_df = pd.DataFrame(snv_changes_df.groupby(['species', 'sample1', 'sample2']).size()).reset_index().rename(columns = {0:"no_of_changes"})

In [None]:
snv_summary_df

## QP pairs (even with 0 changes)

In [4]:
# path to output
haploid_summary_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/QP_status.csv"
generate = False
save = True
if  os.path.exists(haploid_summary_path) and (not generate):
    print("Loading preexisting QP summary file")
    samples_df = pd.read_csv(haploid_summary_path, sep = ",")

    haploid_samples = samples_df[samples_df['haploid']]

else:

    print("Creating QP summary file from scratch.")
    # load species list
    species_list_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/metadata/species_snps.txt"

    with open(species_list_path, 'r') as file:
        species_list = file.readlines()
    species_list = [species.strip() for species in species_list]

    metadata_map = parse_midas_data.parse_sample_metadata_map()
    samples_df = pd.DataFrame(columns = ['accession_id', 'species', 'haploid', 'subject_id'])

    accession_id_vec = []
    species_vec = []
    haploid_vec = []
    subject_id_vec = []

    print("\nCalculating lineage structure.")

    counter = 0
    no_of_species = len(species_list)
    for species in species_list:
        counter += 1
        print("Processing " + species + " ({}/{})".format(counter, no_of_species))
        high_coverage_samples = diversity_utils.calculate_highcoverage_samples(species)
        haploid_samples = diversity_utils.calculate_haploid_samples(species, use_HMP_freqs = True)
        
        haploid_boolean_vec = [True if sample in haploid_samples else False for sample in high_coverage_samples]
        subject_id_vec_temp = [metadata_map[sample][0] for sample in high_coverage_samples]
        
        species_vec = species_vec + [species]*len(high_coverage_samples)
        accession_id_vec = accession_id_vec + list(high_coverage_samples)
        haploid_vec = haploid_vec + haploid_boolean_vec
        subject_id_vec = subject_id_vec + subject_id_vec_temp

    samples_df['accession_id'] = accession_id_vec
    samples_df['species'] = species_vec
    samples_df['haploid'] = haploid_vec
    samples_df['subject_id'] = subject_id_vec

    if save:
        samples_df.to_csv(haploid_summary_path, sep = ",", index = False)



Loading preexisting QP summary file


In [5]:
# path to output
change_summary_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/change_summary_Full.csv"
generate = False
save = True

if os.path.exists(change_summary_path) and (not generate):
    print("Loading preexisting change summary file")
    change_summary_df = pd.read_csv(change_summary_path, sep = ",")
    
else:

    haploid_samples = samples_df[samples_df['haploid']]
        
    change_summary_array = []
    haploid_species = haploid_samples.species.unique()

    print("\nCalculating change rate.")

    for i,species in enumerate(haploid_species):
        print("Processing %s (%d/%d)" % (species, i+1, len(haploid_species)))
        intersample_change_map = load_intersample_change_map(species)
        samples = haploid_samples[haploid_samples['species'] == species]['accession_id'].unique()
        sample_pairs = list(itertools.combinations(samples, 2))
        for sample_pair in sample_pairs:
            if (sample_pair[1], sample_pair[0]) in intersample_change_map:
                sample_pair = (sample_pair[1], sample_pair[0])

            if sample_pair not in intersample_change_map:
                print("%s not in intersample change map for %s" % (str(sample_pair), species))
                continue

            opportunities = intersample_change_map[sample_pair]['snps'][0]
            gene_opportunities = intersample_change_map[sample_pair]['genes'][0]
            snv_changes = len(intersample_change_map[sample_pair]['snps'][2])
            gene_changes = len(intersample_change_map[sample_pair]['genes'][2])
            rate_of_change = snv_changes/opportunities
            gene_rate_of_change = gene_changes/gene_opportunities

            change_summary_array.append([species, sample_pair[0], sample_pair[1], snv_changes, opportunities, rate_of_change, gene_changes, gene_opportunities, gene_rate_of_change])
            

    change_summary_df = pd.DataFrame(change_summary_array, columns=['species', 'sample_1', 'sample_2', 'snv_changes', 'opportunities', 'rate_of_change', 'gene_changes', 'gene_opportunities', 'gene_rate_of_change'])

    # annotate
    print("\n\nAnnotating.")
    change_summary_df['subject_1'] = change_summary_df['sample_1'].apply(lambda x: sample_metadata_map[x][0])
    change_summary_df['subject_2'] = change_summary_df['sample_2'].apply(lambda x: sample_metadata_map[x][0])
    change_summary_df['device_type_1'] = change_summary_df['sample_1'].apply(lambda x: sample_metadata_map[x][2])
    change_summary_df['device_type_2'] = change_summary_df['sample_2'].apply(lambda x: sample_metadata_map[x][2])
    timestamp_format = '%Y-%m-%dT%H:%M:%SZ'
    change_summary_df['day_1'] = change_summary_df[['sample_1','device_type_1']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample_1']][5], timestamp_format).strftime('%Y-%m-%d') if row['device_type_1'] == "Stool" or row['device_type_1'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample_1']][3], timestamp_format).strftime('%Y-%m-%d'), axis = 1)
    change_summary_df['day_2'] = change_summary_df[['sample_2','device_type_2']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample_2']][5], timestamp_format).strftime('%Y-%m-%d') if row['device_type_2'] == "Stool" or row['device_type_2'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample_2']][3], timestamp_format).strftime('%Y-%m-%d'), axis = 1)
    change_summary_df['time_1'] = change_summary_df[['sample_1','device_type_1']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample_1']][5], timestamp_format).strftime('%H:%M:%S') if row['device_type_1'] == "Stool" or row['device_type_1'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample_1']][3], timestamp_format).strftime('%H:%M:%S'), axis = 1)
    change_summary_df['time_2'] = change_summary_df[['sample_2','device_type_2']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample_2']][5], timestamp_format).strftime('%H:%M:%S') if row['device_type_2'] == "Stool" or row['device_type_2'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample_2']][3], timestamp_format).strftime('%H:%M:%S'), axis = 1)

    # Label as within timepoint or between time point
    change_summary_df['timepoint_orientation'] = change_summary_df.apply(lambda row: "Within timepoint" if row['day_1'] == row['day_2'] and row['time_1'] == row['time_2'] else "Between timepoint", axis = 1)
    change_summary_df['device_orientation'] = change_summary_df.apply(lambda row: "Same device" if row['device_type_1'] == row['device_type_2'] else "Different device", axis = 1)
    change_summary_df['datetime_1'] = pd.to_datetime(change_summary_df['day_1'] + ' ' + change_summary_df['time_1'])
    change_summary_df['datetime_2'] = pd.to_datetime(change_summary_df['day_2'] + ' ' + change_summary_df['time_2'])
    change_summary_df['time_difference_hours'] = (change_summary_df['datetime_2'] - change_summary_df['datetime_1']).dt.total_seconds() / 3600

    if save:
        change_summary_df.to_csv(change_summary_path, sep = ",", index=False)
        

        

Loading preexisting change summary file


In [None]:
change_summary_df

In [6]:
# Filter for <= 20 SNV changes
change_summary_df_full = change_summary_df.copy()
change_summary_df = change_summary_df[change_summary_df['snv_changes'] <= 20]
# Filter for within host
change_summary_df = change_summary_df[change_summary_df['subject_1'] == change_summary_df['subject_2']]
# Filter for between capsule
change_summary_df = change_summary_df[(change_summary_df['device_type_1'] != "Stool") &
                                      (change_summary_df['device_type_1'] != "Saliva") &
                                      (change_summary_df['device_type_2'] != "Stool") &
                                      (change_summary_df['device_type_2'] != "Saliva")]
# Filter for not the same device in the same timepoint (which only happens in subject 1)
change_summary_df = change_summary_df[~((change_summary_df['timepoint_orientation'] == "Within timepoint") &
                                      (change_summary_df['device_orientation'] == "Same device"))]
# Reset index
change_summary_df = change_summary_df.reset_index(drop=True)


In [None]:
# # Saving change summary
# change_summary_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/change_summary_CapsulesOnly.csv"

# change_summary_df.to_csv(change_summary_path, sep = ",")


## Estimating text numbers

### Estimate 1: How many species-subjects show along the gut differentiation comparisons have changes (i.e., > 0 changes).  


In [7]:
within_timepoint_change_pairs = change_summary_df[change_summary_df.timepoint_orientation == "Within timepoint"][['species','subject_1']].drop_duplicates().rename(columns = {"subject_1":"subject"}).reset_index(drop = True)

In [8]:
len(within_timepoint_change_pairs)

56

In [9]:
within_timepoint_change_pairs_with_changes = change_summary_df[(change_summary_df.timepoint_orientation == "Within timepoint") & (change_summary_df.snv_changes > 0)]

In [10]:
within_timepoint_change_pairs_with_changes

Unnamed: 0,species,sample_1,sample_2,snv_changes,opportunities,rate_of_change,gene_changes,gene_opportunities,gene_rate_of_change,subject_1,...,device_type_2,day_1,day_2,time_1,time_2,timepoint_orientation,device_orientation,datetime_1,datetime_2,time_difference_hours
231,Anaerostipes_hadrus_55206,SRR18585057,SRR18585058,8,875267.0,9.140068e-06,3,2360.0,0.001271,8,...,Capsule 3,2020-08-18,2020-08-18,05:00:00,05:00:00,Within timepoint,Different device,2020-08-18 05:00:00,2020-08-18 05:00:00,0.0
396,Bacteroides_massiliensis_44749,SRR18585022,SRR18585025,2,2419000.0,8.267879e-07,0,2963.0,0.0,2,...,Capsule 1,2020-06-27,2020-06-27,06:55:00,06:55:00,Within timepoint,Different device,2020-06-27 06:55:00,2020-06-27 06:55:00,0.0
889,Bacteroides_vulgatus_57955,SRR18585193,SRR18585194,5,2738090.0,1.82609e-06,0,8473.0,0.0,11,...,Capsule 3,2020-09-16,2020-09-16,04:15:00,04:15:00,Within timepoint,Different device,2020-09-16 04:15:00,2020-09-16 04:15:00,0.0
909,Bacteroides_vulgatus_57955,SRR18585218,SRR18585220,5,2820210.0,1.772918e-06,0,8473.0,0.0,8,...,Capsule 1,2020-08-19,2020-08-19,05:15:00,05:15:00,Within timepoint,Different device,2020-08-19 05:15:00,2020-08-19 05:15:00,0.0
1069,Blautia_wexlerae_56130,SRR18584991,SRR18584993,3,1300190.0,2.307355e-06,1,3628.0,0.000276,10,...,Capsule 1,2020-08-30,2020-08-30,22:30:00,22:30:00,Within timepoint,Different device,2020-08-30 22:30:00,2020-08-30 22:30:00,0.0
1111,Desulfovibrio_piger_61475,SRR18585200,SRR18585202,6,660160.0,9.088706e-06,1,2476.0,0.000404,9,...,Capsule 1,2020-08-25,2020-08-25,23:45:00,23:45:00,Within timepoint,Different device,2020-08-25 23:45:00,2020-08-25 23:45:00,0.0
1472,Ruminococcus_gnavus_57638,SRR18585019,SRR18585020,5,1453270.0,3.440517e-06,1,4633.0,0.000216,2,...,Capsule 3,2020-06-30,2020-06-30,21:30:00,21:30:00,Within timepoint,Different device,2020-06-30 21:30:00,2020-06-30 21:30:00,0.0


In [11]:
within_timepoint_change_pairs_with_changes

Unnamed: 0,species,sample_1,sample_2,snv_changes,opportunities,rate_of_change,gene_changes,gene_opportunities,gene_rate_of_change,subject_1,...,device_type_2,day_1,day_2,time_1,time_2,timepoint_orientation,device_orientation,datetime_1,datetime_2,time_difference_hours
231,Anaerostipes_hadrus_55206,SRR18585057,SRR18585058,8,875267.0,9.140068e-06,3,2360.0,0.001271,8,...,Capsule 3,2020-08-18,2020-08-18,05:00:00,05:00:00,Within timepoint,Different device,2020-08-18 05:00:00,2020-08-18 05:00:00,0.0
396,Bacteroides_massiliensis_44749,SRR18585022,SRR18585025,2,2419000.0,8.267879e-07,0,2963.0,0.0,2,...,Capsule 1,2020-06-27,2020-06-27,06:55:00,06:55:00,Within timepoint,Different device,2020-06-27 06:55:00,2020-06-27 06:55:00,0.0
889,Bacteroides_vulgatus_57955,SRR18585193,SRR18585194,5,2738090.0,1.82609e-06,0,8473.0,0.0,11,...,Capsule 3,2020-09-16,2020-09-16,04:15:00,04:15:00,Within timepoint,Different device,2020-09-16 04:15:00,2020-09-16 04:15:00,0.0
909,Bacteroides_vulgatus_57955,SRR18585218,SRR18585220,5,2820210.0,1.772918e-06,0,8473.0,0.0,8,...,Capsule 1,2020-08-19,2020-08-19,05:15:00,05:15:00,Within timepoint,Different device,2020-08-19 05:15:00,2020-08-19 05:15:00,0.0
1069,Blautia_wexlerae_56130,SRR18584991,SRR18584993,3,1300190.0,2.307355e-06,1,3628.0,0.000276,10,...,Capsule 1,2020-08-30,2020-08-30,22:30:00,22:30:00,Within timepoint,Different device,2020-08-30 22:30:00,2020-08-30 22:30:00,0.0
1111,Desulfovibrio_piger_61475,SRR18585200,SRR18585202,6,660160.0,9.088706e-06,1,2476.0,0.000404,9,...,Capsule 1,2020-08-25,2020-08-25,23:45:00,23:45:00,Within timepoint,Different device,2020-08-25 23:45:00,2020-08-25 23:45:00,0.0
1472,Ruminococcus_gnavus_57638,SRR18585019,SRR18585020,5,1453270.0,3.440517e-06,1,4633.0,0.000216,2,...,Capsule 3,2020-06-30,2020-06-30,21:30:00,21:30:00,Within timepoint,Different device,2020-06-30 21:30:00,2020-06-30 21:30:00,0.0


In [12]:
len(within_timepoint_change_pairs_with_changes)

7

In [13]:
between_timepoint_change_pairs = change_summary_df[change_summary_df.timepoint_orientation == "Between timepoint"][['species','subject_1']].drop_duplicates().rename(columns = {"subject_1":"subject"}).reset_index(drop = True)

In [14]:
len(between_timepoint_change_pairs)

92

In [15]:
len(between_timepoint_change_pairs.subject.unique())

12

In [16]:
between_timepoint_change_pairs_with_changes = change_summary_df[(change_summary_df.timepoint_orientation == "Between timepoint") & (change_summary_df.snv_changes > 0)][['species','subject_1']].drop_duplicates().reset_index(drop=True)

In [17]:
len(between_timepoint_change_pairs_with_changes)

30

In [18]:
30/92

0.32608695652173914

### Estimate 2: How many species-subjects meet the following requirements:  
- at least two timepoints with at least two capsule devices

In [19]:
# Load maps
sample_metadata_map = parse_sample_metadata_map()

In [20]:
# annotate haploid_samples
## device type
haploid_samples.loc[:, 'device_type'] = haploid_samples['accession_id'].apply(lambda x: sample_metadata_map[x][2])
timestamp_format = '%Y-%m-%dT%H:%M:%SZ'
haploid_samples.loc[:, 'day'] = haploid_samples[['accession_id', 'device_type']].apply(
    lambda row: datetime.strptime(sample_metadata_map[row['accession_id']][5], timestamp_format).strftime('%Y-%m-%d')
    if row['device_type'] == "Stool" or row['device_type'] == "Saliva"
    else datetime.strptime(sample_metadata_map[row['accession_id']][3], timestamp_format).strftime('%Y-%m-%d'),
    axis=1
)
haploid_samples.loc[:, 'time'] = haploid_samples[['accession_id', 'device_type']].apply(
    lambda row: datetime.strptime(sample_metadata_map[row['accession_id']][5], timestamp_format).strftime('%H:%M:%S')
    if row['device_type'] == "Stool" or row['device_type'] == "Saliva"
    else datetime.strptime(sample_metadata_map[row['accession_id']][3], timestamp_format).strftime('%H:%M:%S'),
    axis=1
)
haploid_samples['datetime'] = pd.to_datetime(haploid_samples['day'] + ' ' + haploid_samples['time'])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  haploid_samples.loc[:, 'device_type'] = haploid_samples['accession_id'].apply(lambda x: sample_metadata_map[x][2])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  haploid_samples.loc[:, 'day'] = haploid_samples[['accession_id', 'device_type']].apply(
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ha

In [21]:
# number of unique device types per haploid samples 
device_types_per_tmpt = haploid_samples.groupby(['species', 'subject_id', 'datetime'])['device_type'].nunique().reset_index()

# Filter timepoints with more than 2 unique device types
timepoints_with_gt2 = device_types_per_tmpt[device_types_per_tmpt['device_type'] >= 2]

# Count species-subject pairs with at least two such timepoints
species_subject_pairs = timepoints_with_gt2.groupby(['species', 'subject_id']).filter(lambda x: len(x) >= 2).reset_index(drop = True)

# Optionally, re-group the filtered data for further analysis
grouped_species_subject_pairs = species_subject_pairs[['species', 'subject_id']].drop_duplicates().reset_index(drop = True)


In [22]:
grouped_species_subject_pairs

Unnamed: 0,species,subject_id
0,Alistipes_onderdonkii_55464,9
1,Alistipes_onderdonkii_55464,11
2,Alistipes_putredinis_61533,9
3,Alistipes_putredinis_61533,11
4,Bacteroides_massiliensis_44749,2
5,Bacteroides_vulgatus_57955,2
6,Bacteroides_vulgatus_57955,8
7,Bacteroides_vulgatus_57955,11
8,Bifidobacterium_adolescentis_56815,12
9,Bilophila_wadsworthia_57364,11


### Estimate 3: How many of these have (1) temporal changes (2) within host changes

In [23]:
within_vec = []
between_vec = []
for i,row in grouped_species_subject_pairs.iterrows():
    species = row["species"]
    subject_id = row["subject_id"]

    orientations_with_changes = change_summary_df[(change_summary_df['species'] == species) & (change_summary_df['subject_1'] == subject_id) & (change_summary_df['snv_changes'] > 0)].timepoint_orientation.unique().tolist()

    if "Within timepoint" in orientations_with_changes:
        within_vec.append(True)
    else:
        within_vec.append(False)
    
    if "Between timepoint" in orientations_with_changes:
        between_vec.append(True)
    else:
        between_vec.append(False)


change_type_summary_df = grouped_species_subject_pairs.copy()

change_type_summary_df["within_timepoint_changes"] = within_vec
change_type_summary_df["between_timepoint_changes"] = between_vec

    


    

In [24]:
change_type_summary_df[change_type_summary_df.between_timepoint_changes].reset_index(drop = True
                                                                                    )

Unnamed: 0,species,subject_id,within_timepoint_changes,between_timepoint_changes
0,Alistipes_onderdonkii_55464,11,False,True
1,Alistipes_putredinis_61533,11,False,True
2,Bacteroides_massiliensis_44749,2,True,True
3,Bacteroides_vulgatus_57955,8,True,True
4,Bacteroides_vulgatus_57955,11,True,True
5,Bilophila_wadsworthia_57364,11,False,True
6,Desulfovibrio_piger_61475,9,True,True
7,Escherichia_coli_58110,3,False,True
8,Parabacteroides_distasonis_56985,1,False,True


In [25]:
change_type_summary_df

Unnamed: 0,species,subject_id,within_timepoint_changes,between_timepoint_changes
0,Alistipes_onderdonkii_55464,9,False,False
1,Alistipes_onderdonkii_55464,11,False,True
2,Alistipes_putredinis_61533,9,False,False
3,Alistipes_putredinis_61533,11,False,True
4,Bacteroides_massiliensis_44749,2,True,True
5,Bacteroides_vulgatus_57955,2,False,False
6,Bacteroides_vulgatus_57955,8,True,True
7,Bacteroides_vulgatus_57955,11,True,True
8,Bifidobacterium_adolescentis_56815,12,False,False
9,Bilophila_wadsworthia_57364,11,False,True


### Estimate 4: getting gene descriptsion of certain species

In [26]:
import parse_patric
species = "Bacteroides_vulgatus_57955"
genome_ids = parse_midas_data.get_ref_genome_ids(species)
non_shared_genes = core_gene_utils.parse_non_shared_reference_genes(species)
gene_descriptions = parse_patric.load_patric_gene_descriptions(genome_ids, non_shared_genes)
centroid_gene_map = parse_midas_data.load_centroid_gene_map(species)

In [27]:
gene_descriptions['435590.9.peg.777']

'TPR-domain-containing protein'

## SNV_freq changes for within-host changes

In [None]:
site_id_gene_dict = dict({})

only_haploid = False
only_high_coverage = False
min_depth = 1
final_line_number = 1e10
timestamp_format = '%Y-%m-%dT%H:%M:%SZ'

sample_metadata_map = parse_sample_metadata_map()
subject_sample_map = parse_subject_sample_map()


# for species_subject_i,row in change_type_summary_df[change_type_summary_df.within_timepoint_changes].reset_index(drop = False).iterrows():
for species_subject_i,row in within_timepoint_change_pairs_with_changes.rename(columns={"subject_1":"subject_id"}).reset_index(drop = False).iterrows():
    species = row['species']
    subject_id = row['subject_id']

    # sys.stderr.write("Processing species %s in subject %d (%d / %d species-subject pairs)\n" % (species, subject_id, species_subject_i + 1, len(change_type_summary_df[change_type_summary_df.within_timepoint_changes])))
    sys.stderr.write("Processing species %s in subject %d (%d / %d species-subject pairs)\n" % (species, subject_id, species_subject_i + 1, len(within_timepoint_change_pairs_with_changes)))

    # Identify sample pair
    species_subject_withinhost_sample_pair = change_summary_df[(change_summary_df['species'] == species) & (change_summary_df['subject_1'] == subject_id) & (change_summary_df['timepoint_orientation'] == "Within timepoint") & (change_summary_df['snv_changes'] > 0)][['sample_1', 'sample_2', 'snv_changes']].sort_values(by = ['snv_changes'], ascending = False)
    sample_pair = (species_subject_withinhost_sample_pair.iloc[0,0], species_subject_withinhost_sample_pair.iloc[0,1])

    # Extract SNVs
    intersample_change_map = load_intersample_change_map(species)

    if sample_pair in intersample_change_map:
        snvs = intersample_change_map[sample_pair]['snps'][2]
    elif (sample_pair[1], sample_pair[0]) in intersample_change_map:
        snvs = intersample_change_map[(sample_pair[1], sample_pair[0])]['snps'][2]
    else:
        sys.stderr.write("%s not in intersample change map for %s\n" % (str(sample_pair), species))
        continue

    snvs = [(snv[1], snv[2]) for snv in snvs]
    
    # Extract SNV frequencies
    ## Define paths
    snps_freq_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_ref_freq.txt.bz2")
    snps_depth_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_depth.txt.bz2")
    snps_info_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_info.txt.bz2")

    ## define desired samples

    if only_haploid:
        sys.stderr.write("Only considering haploid samples\n")
        haploid_samples = calculate_haploid_samples(species, use_HMP_freqs=True)
        desired_samples = [desired_sample for desired_sample in desired_samples if (desired_sample in haploid_samples) & (sample_metadata_map[desired_sample][0] == str(subject_id))]
    elif only_high_coverage:
        sys.stderr.write("Only considering high coverage samples\n")
        high_coverage_samples = calculate_highcoverage_samples(species)
        desired_samples = [desired_sample for desired_sample in desired_samples if (desired_sample in high_coverage_samples) & (sample_metadata_map[desired_sample][0] == str(subject_id))]
    else:
        sys.stderr.write("Considering all samples\n")
        desired_samples = list(subject_sample_map[str(subject_id)].keys())

    
    # Looping through the files
    line_counter = 0
    loci_extracted = 0
    site_id_vec = []
    allele_frequencies = []

    with bz2.open(snps_freq_path, 'rt') as file,  bz2.open(snps_depth_path, 'rt') as depth_file, bz2.open(snps_info_path, 'rt') as info_file:
        for line, depth_line, info_line in zip(file, depth_file, info_file):
            if line_counter == 0:
                items = line.split()[1:]    
                samples_in_file = parse_merged_sample_names(items)
                desired_sample_idxs = []
                for sample in desired_samples:
                    if sample not in samples_in_file:
                        continue
                    desired_sample_idxs.append( np.nonzero(samples_in_file==sample)[0][0] )
                desired_sample_idxs = np.array(desired_sample_idxs)    
                desired_samples = samples_in_file[desired_sample_idxs]

                line_counter += 1
            else:
                if line_counter == final_line_number:
                    sys.stderr.write("Breaking at line " + str(final_line_number) + ". " + str(loci_extracted)+ " loci extracted.\n")
                    break

                items = line.split()
                depths =  depth_line.split()
                infos = info_line.split()
                if len(infos) >= 7:
                    gene_id = infos[6]
                else: 
                    gene_id = ""
                site_id_original = items[0]
                site_id = items[0].split("|")
                site_id = (site_id[0], int(site_id[1]))
                items = items[1:]
                depths = [int(depth) for depth in depths[1:]]     

                if site_id in snvs:
                    sys.stderr.write("EXTRACTING: " + site_id[0] + "|" + str(site_id[1]) + "\n")
                    
                    if species not in site_id_gene_dict:
                        site_id_gene_dict[species] = dict({})
                    
                    if site_id not in site_id_gene_dict[species]:
                        site_id_gene_dict[species][site_id] = gene_id
                    
                    site_id_vec.append(site_id_original)
                    loci_extracted += 1
                    af = []
                    dp = [] 
                    for idx in desired_sample_idxs:    
                        item = items[idx]
                        depth = depths[idx]
                        if depth < min_depth:
                            subitems = [np.nan]
                        else:
                            subitems = item.split(",")
                        af.append(subitems[0])
                        dp.append(depth)

                    if loci_extracted == 1:
                        allele_frequencies = np.array(af)
                        site_depths = np.array(dp)
                        if len(snvs) == 1:
                            allele_frequencies = allele_frequencies[np.newaxis, :]
                            site_depths = site_depths[np.newaxis, :]
                    else:
                        allele_frequencies = np.vstack((allele_frequencies, np.array(af)))
                        site_depths = np.vstack((site_depths, np.array(dp)))
                        
                    
                if (loci_extracted == len(snvs)):
                    sys.stderr.write("Extracted all loci. Annotating.\n")
                    break
    ## ANNOTATION
    ### FREQS
    snv_freqs = pd.DataFrame(columns = desired_samples, index=site_id_vec, data = allele_frequencies)
    snv_freqs = snv_freqs.reset_index().rename(columns = {"index":"site_id"})
    snv_freqs[['contig', 'site_pos', 'nucleotide']] = snv_freqs['site_id'].str.split("|", expand=True)
    ### DEPTHS
    snv_depths = pd.DataFrame(columns=desired_samples,index=site_id_vec,data=site_depths)
    snv_depths = snv_depths.reset_index().rename(columns={"index": "site_id"})
    snv_freqs[['contig', 'site_pos', 'nucleotide']] = snv_freqs['site_id'].str.split("|", expand=True)
    snv_depths[['contig', 'site_pos', 'nucleotide']] = snv_depths['site_id'].str.split("|", expand=True)
    # convert allele frequency to numeric
    snv_freqs[list(desired_samples)] = snv_freqs[list(desired_samples)].apply(pd.to_numeric, errors='coerce')
    snv_depths[list(desired_samples)] = snv_depths[list(desired_samples)].apply(pd.to_numeric, errors='coerce')
    # convert to longform
    snv_freqs =pd.melt(snv_freqs,
                    id_vars=['site_id', 'contig', 'site_pos', 'nucleotide'],  # columns to keep as identifier variables
                    var_name='sample',  # the name for the sample column
                    value_name='allele_frequency'  # the name for the values from sample columns
                    )
    snv_depths = pd.melt(
        snv_depths,
        id_vars=['site_id'],
        var_name='sample',
        value_name='depth'
    )
    # Merge depth and AFs
    snv_freqs = snv_freqs.merge(snv_depths[['site_id', 'sample', 'depth']],
                                on=['site_id', 'sample'],
                                how='left')

    # species and subject
    snv_freqs['species'] = species
    snv_freqs['subject_id'] = subject_id
    # annotate with sample type (e.g., "Stool", "Saliva", "Capsule 2", etc.)
    snv_freqs['sample_type'] = snv_freqs['sample'].apply(lambda x: sample_metadata_map[x][2])
    # annotate with date
    snv_freqs['date'] = snv_freqs[['sample','sample_type']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample']][5], timestamp_format).strftime('%Y-%m-%d') if row['sample_type'] == "Stool" or row['sample_type'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample']][3], timestamp_format).strftime('%Y-%m-%d'), axis = 1)
    # annotate with time
    snv_freqs['time'] = snv_freqs[['sample','sample_type']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample']][5], timestamp_format).strftime('%H:%M:%S') if row['sample_type'] == "Stool" or row['sample_type'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample']][3], timestamp_format).strftime('%H:%M:%S'), axis = 1)
    # annotate with timepoint
    snv_freqs['timepoint'] = pd.to_datetime(snv_freqs['date'] + ' ' + snv_freqs['time'])
    if species_subject_i == 0:
        snv_freqs_combined = snv_freqs.copy()
    else:
        snv_freqs_combined = pd.concat([snv_freqs_combined, snv_freqs], ignore_index=True)




snv_freqs_combined = snv_freqs_combined[['species', 
                                         'subject_id', 
                                         'sample_type', 
                                         'date',
                                         'time', 
                                         'timepoint',
                                         'site_id', 
                                         'contig', 
                                         'site_pos', 
                                         'nucleotide', 
                                         'sample',
                                         'allele_frequency',
                                         'depth']]

sys.stderr.write("\n\nFinished. Combined output in snv_freqs_combined\n")

                
                        
                        
                        
                        
                        
                    
                    
                    
                    
            
                
            
        
    
        
        

    

In [None]:
out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/SNV_freqs_WithinTimepoint.tsv"
# out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/SNV_freqs_WithinTimepoint_atleast1timepoint.tsv"
snv_freqs_combined.to_csv(out_path, sep = "\t", index=False)

In [None]:
change_type_summary_df[change_type_summary_df.between_timepoint_changes]

In [None]:
only_haploid = False
only_high_coverage = False
min_depth = 1
final_line_number = 1e10
timestamp_format = '%Y-%m-%dT%H:%M:%SZ'

sample_metadata_map = parse_sample_metadata_map()
subject_sample_map = parse_subject_sample_map()


for species_subject_i,row in change_type_summary_df[change_type_summary_df.between_timepoint_changes].reset_index(drop = False).iterrows():
    species = row['species']
    subject_id = row['subject_id']

    sys.stderr.write("Processing species %s in subject %d (%d / %d species-subject pairs)\n" % (species, subject_id, species_subject_i + 1, len(change_type_summary_df[change_type_summary_df.between_timepoint_changes])))

    # Identify sample pair
    species_subject_between_sample_pair = change_summary_df[(change_summary_df['species'] == species) & (change_summary_df['subject_1'] == subject_id) & (change_summary_df['timepoint_orientation'] == "Between timepoint") & (change_summary_df['snv_changes'] > 0)][['sample_1', 'sample_2', 'snv_changes']].sort_values(by = ['snv_changes'], ascending = False)
    sample_pair = (species_subject_between_sample_pair.iloc[0,0], species_subject_between_sample_pair.iloc[0,1])

    # Extract SNVs
    intersample_change_map = load_intersample_change_map(species)

    snvs = []

    # Loop through all sample pairs
    for _, row in species_subject_between_sample_pair.iterrows():
        sample_pair = (row['sample_1'], row['sample_2'])

        # Try both orientations
        if sample_pair in intersample_change_map:
            snvs_temp = intersample_change_map[sample_pair]['snps'][2]
        elif (sample_pair[1], sample_pair[0]) in intersample_change_map:
            snvs_temp = intersample_change_map[(sample_pair[1], sample_pair[0])]['snps'][2]
        else:
            sys.stderr.write(
                f"{sample_pair} not in intersample change map for {species}\n"
            )
            continue

        # Append raw SNV list to final list
        snvs = snvs + snvs_temp

    snvs = list(set([(snv[1], snv[2]) for snv in snvs]))
    
    # Extract SNV frequencies
    ## Define paths
    snps_freq_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_ref_freq.txt.bz2")
    snps_depth_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_depth.txt.bz2")
    snps_info_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_info.txt.bz2")

    ## define desired samples

    if only_haploid:
        sys.stderr.write("Only considering haploid samples\n")
        haploid_samples = calculate_haploid_samples(species, use_HMP_freqs=True)
        desired_samples = [desired_sample for desired_sample in desired_samples if (desired_sample in haploid_samples) & (sample_metadata_map[desired_sample][0] == str(subject_id))]
    elif only_high_coverage:
        sys.stderr.write("Only considering high coverage samples\n")
        high_coverage_samples = calculate_highcoverage_samples(species)
        desired_samples = [desired_sample for desired_sample in desired_samples if (desired_sample in high_coverage_samples) & (sample_metadata_map[desired_sample][0] == str(subject_id))]
    else:
        sys.stderr.write("Considering all samples\n")
        desired_samples = list(subject_sample_map[str(subject_id)].keys())

    
    # Looping through the files
    line_counter = 0
    loci_extracted = 0
    site_id_vec = []
    allele_frequencies = []

    with bz2.open(snps_freq_path, 'rt') as file,  bz2.open(snps_depth_path, 'rt') as depth_file, bz2.open(snps_info_path, 'rt') as info_file:
        for line, depth_line, info_line in zip(file, depth_file, info_file):
            if line_counter == 0:
                items = line.split()[1:]    
                samples_in_file = parse_merged_sample_names(items)
                desired_sample_idxs = []
                for sample in desired_samples:
                    if sample not in samples_in_file:
                        continue
                    desired_sample_idxs.append( np.nonzero(samples_in_file==sample)[0][0] )
                desired_sample_idxs = np.array(desired_sample_idxs)    
                desired_samples = samples_in_file[desired_sample_idxs]

                line_counter += 1
            else:
                if line_counter == final_line_number:
                    sys.stderr.write("Breaking at line " + str(final_line_number) + ". " + str(loci_extracted)+ " loci extracted.\n")
                    break

                items = line.split()
                depths =  depth_line.split()
                infos = info_line.split()
                if len(infos) >= 7:
                    gene_id = infos[6]
                else: 
                    gene_id = ""
                site_id_original = items[0]
                site_id = items[0].split("|")
                site_id = (site_id[0], int(site_id[1]))
                items = items[1:]
                depths = [int(depth) for depth in depths[1:]]     

                if site_id in snvs:
                    sys.stderr.write("EXTRACTING: " + site_id[0] + "|" + str(site_id[1]) + "\n")
                    site_id_vec.append(site_id_original)
                    loci_extracted += 1
                    af = []
                    dp = [] 
                    for idx in desired_sample_idxs:    
                        item = items[idx]
                        depth = depths[idx]
                        if depth < min_depth:
                            subitems = [np.nan]
                        else:
                            subitems = item.split(",")
                        af.append(subitems[0])
                        dp.append(depth)

                    if loci_extracted == 1:
                        allele_frequencies = np.array(af)
                        site_depths = np.array(dp)
                        if len(snvs) == 1:
                            allele_frequencies = allele_frequencies[np.newaxis, :]
                            site_depths = site_depths[np.newaxis, :]
                    else:
                        allele_frequencies = np.vstack((allele_frequencies, np.array(af)))
                        site_depths = np.vstack((site_depths, np.array(dp)))
                        
                    
                if (loci_extracted == len(snvs)):
                    sys.stderr.write("Extracted all loci. Annotating.\n")
                    break
    ## ANNOTATION
    ### FREQS
    snv_freqs = pd.DataFrame(columns = desired_samples, index=site_id_vec, data = allele_frequencies)
    snv_freqs = snv_freqs.reset_index().rename(columns = {"index":"site_id"})
    snv_freqs[['contig', 'site_pos', 'nucleotide']] = snv_freqs['site_id'].str.split("|", expand=True)
    ### DEPTHS
    snv_depths = pd.DataFrame(columns=desired_samples,index=site_id_vec,data=site_depths)
    snv_depths = snv_depths.reset_index().rename(columns={"index": "site_id"})
    snv_freqs[['contig', 'site_pos', 'nucleotide']] = snv_freqs['site_id'].str.split("|", expand=True)
    snv_depths[['contig', 'site_pos', 'nucleotide']] = snv_depths['site_id'].str.split("|", expand=True)
    # convert allele frequency to numeric
    snv_freqs[list(desired_samples)] = snv_freqs[list(desired_samples)].apply(pd.to_numeric, errors='coerce')
    snv_depths[list(desired_samples)] = snv_depths[list(desired_samples)].apply(pd.to_numeric, errors='coerce')
    # convert to longform
    snv_freqs =pd.melt(snv_freqs,
                    id_vars=['site_id', 'contig', 'site_pos', 'nucleotide'],  # columns to keep as identifier variables
                    var_name='sample',  # the name for the sample column
                    value_name='allele_frequency'  # the name for the values from sample columns
                    )
    snv_depths = pd.melt(
        snv_depths,
        id_vars=['site_id'],
        var_name='sample',
        value_name='depth'
    )
    # Merge depth and AFs
    snv_freqs = snv_freqs.merge(snv_depths[['site_id', 'sample', 'depth']],
                                on=['site_id', 'sample'],
                                how='left')

    # species and subject
    snv_freqs['species'] = species
    snv_freqs['subject_id'] = subject_id
    # annotate with sample type (e.g., "Stool", "Saliva", "Capsule 2", etc.)
    snv_freqs['sample_type'] = snv_freqs['sample'].apply(lambda x: sample_metadata_map[x][2])
    # annotate with date
    snv_freqs['date'] = snv_freqs[['sample','sample_type']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample']][5], timestamp_format).strftime('%Y-%m-%d') if row['sample_type'] == "Stool" or row['sample_type'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample']][3], timestamp_format).strftime('%Y-%m-%d'), axis = 1)
    # annotate with time
    snv_freqs['time'] = snv_freqs[['sample','sample_type']].apply(lambda row: datetime.strptime(sample_metadata_map[row['sample']][5], timestamp_format).strftime('%H:%M:%S') if row['sample_type'] == "Stool" or row['sample_type'] == "Saliva" else datetime.strptime(sample_metadata_map[row['sample']][3], timestamp_format).strftime('%H:%M:%S'), axis = 1)
    # annotate with timepoint
    snv_freqs['timepoint'] = pd.to_datetime(snv_freqs['date'] + ' ' + snv_freqs['time'])
    if species_subject_i == 0:
        snv_freqs_combined = snv_freqs.copy()
    else:
        snv_freqs_combined = pd.concat([snv_freqs_combined, snv_freqs], ignore_index=True)




snv_freqs_combined = snv_freqs_combined[['species', 
                                         'subject_id', 
                                         'sample_type', 
                                         'date',
                                         'time', 
                                         'timepoint',
                                         'site_id', 
                                         'contig', 
                                         'site_pos', 
                                         'nucleotide', 
                                         'sample',
                                         'allele_frequency',
                                         'depth']]

sys.stderr.write("\n\nFinished. Combined output in snv_freqs_combined\n")

                
                        
                        
                        
                        
                        
                    
                    
                    
                    
            
                
            
        
    
        
        

    

In [None]:
snv_freqs_combined

In [None]:
out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/SNV_freqs_BetweenTimepoint.tsv"
snv_freqs_combined.to_csv(out_path, sep = "\t", index=False)

## SNV annotations

### METHOD 1

In [None]:
species_list = change_summary_df[change_summary_df.snv_changes > 0][['species','sample_1','sample_2']].drop_duplicates().reset_index(drop = True)


In [None]:
species_vec = []
contig_vec = []
site_pos_vec = []
contig_site_pos_vec = []
variant_type_vec = []
gene_id_vec = []
gene_description_vec = []

last_species = ""


for i,row in species_list.iterrows():
    
    species = row['species']
    sample_1 = row['sample_1']
    sample_2 = row['sample_2']
    
    sys.stderr.write("Processing %s in %s - %s sample pair (%d / %d species)\n" % (species, sample_1, sample_2, i+1, len(species_list)))

    if species != last_species:
        last_species = species
        intersample_change_map = load_intersample_change_map(species)
    

    if (sample_1,sample_2) in intersample_change_map:
        sample_pair = (sample_1,sample_2)
        snvs = intersample_change_map[sample_pair]['snps'][2]
    elif (sample_2,sample_1) in intersample_change_map:
        sample_pair = (sample_2,sample_1)
        snvs = intersample_change_map[sample_pair]['snps'][2]
    else:
        sys.stderr.write("%s not in intersample change map for %s\n" % (str(sample_pair), species))
        continue


    for snv in snvs:
        gene_id = snv[0]
        contig = snv[1]
        site_pos = snv[2]
        variant_type = snv[3]
        contig_site_pos = contig + "|" + str(site_pos)
        if contig_site_pos not in contig_site_pos_vec:

            species_vec.append(species)
            contig_vec.append(contig)
            site_pos_vec.append(site_pos)
            contig_site_pos_vec.append(contig_site_pos)
            gene_id_vec.append(gene_id)
            variant_type_vec.append(variant_type)

            genome_ids = parse_midas_data.get_ref_genome_ids(species)
            non_shared_genes = core_gene_utils.parse_non_shared_reference_genes(species)
            gene_descriptions = parse_patric.load_patric_gene_descriptions(genome_ids, non_shared_genes)
            centroid_gene_map = parse_midas_data.load_centroid_gene_map(species)

            if gene_id in gene_descriptions:
                gene_description_vec.append(gene_descriptions[gene_id])
            elif gene_id in centroid_gene_map:
                if centroid_gene_map[gene_id] in gene_descriptions:
                    gene_description_vec.append(gene_descriptions[centroid_gene_map[gene_id]])
            else:
                gene_description_vec.append("")




    

In [None]:
# saving
out_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/SNV_gene_descriptions.tsv"
snv_gene_descriptions_df.to_csv(out_path, sep = "\t", index = False)


### METHOD 2

In [None]:
# Loading dataframes
between_timepoint_freqs_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/SNV_freqs_BetweenTimepoint.tsv"
within_timepoint_freqs_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/SNV_freqs_WithinTimepoint.tsv"
snv_freqs_1timepoint_path = "/u/project/ngarud/Garud_lab/metagenomic_fastq_files/Shalon_2023/evolutionary_changes/SNV_freqs_WithinTimepoint_atleast1timepoint.tsv"

df_1 = pd.read_csv(between_timepoint_freqs_path, sep = "\t")
df_2 = pd.read_csv(within_timepoint_freqs_path, sep = "\t")
df_3 = pd.read_csv(snv_freqs_1timepoint_path, sep = "\t")

df_all = pd.concat([df_1, df_2, df_3], axis=0, ignore_index=True)



In [None]:
df_all[['species','site_id','contig','site_pos']].drop_duplicates()

## $F_{st}$ pre-processing

In [None]:
# Get all within-host sample qp pairs
haploid_sample_dict = dict({})
for species in species_list:
    haploid_samples = calculate_haploid_samples(species)
    sample_pairs = list(itertools.combinations(haploid_samples, 2))

    sample_pairs = [sample_pair for sample_pair in sample_pairs if sample_metadata_map[sample_pair[0]][0] == sample_metadata_map[sample_pair[1]][0]]

    # Store the sample pairs in the dictionary
    if len(sample_pairs) > 0:
        haploid_sample_dict[species] = sample_pairs

In [None]:
# Make fst dataframe
fst_processing = []
for species in haploid_sample_dict.keys():
    
    for sample_pair in haploid_sample_dict[species]:
        
        sample_1 = sample_pair[0]
        sample_2 = sample_pair[1]

        fst_processing.append([species,sample_1,sample_2])

fst_processing_df = pd.DataFrame(fst_processing, columns = ['species','sample_1', 'sample_2'])

# saving 
out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/metadata/Fst_processing_df.txt"
fst_processing_df.to_csv(out_path, sep=',', index=False)


## Analyzing pre-identified changes

[Within-host change](https://diversityalonggutshared.blogspot.com/2024/04/shalon-results-all-within-host-snvs.html)  
[Within-set changes](https://diversityalonggutshared.blogspot.com/2024/04/do-within-host-changes-spread-shalon-et.html)

In [None]:
grouped_species_subject_pairs

In [None]:
species = "Alistipes_onderdonkii_55464"
subject_id = 11


In [None]:
# load intersample changes
intersample_changes = load_intersample_change_map(species)

In [None]:
# Extracting all within-host SNVs

subject_samples = list(subject_sample_map[str(subject_id)].keys())
haploid_samples = calculate_haploid_samples(species, use_HMP_freqs=True)
subject_samples = [sample_ for sample_ in subject_samples if sample_ in haploid_samples]
# sample_pair_list = list(itertools.combinations(subject_samples, 2))
sample_pair_list = list(change_summary_df.loc[(change_summary_df['species'] == species) & (change_summary_df['subject_1'] == str(subject_id)),['sample_1', 'sample_2']].itertuples(index=False, name=None))

# paring intersample changes to only include within-host snvs
intersample_changes = {key: value for key, value in intersample_changes.items() if (key in sample_pair_list) or ((key[1], key[0]) in sample_pair_list)}


### SNV trajectories

In [None]:
snvs = set({})
# for key in intersample_changes.keys():
for key in sample_pair_list:
    snv_list = set(intersample_changes[key]['snps'][2])
    snvs.update(snv_list)

In [None]:
# Filter for core or non-shared genes
core_genes_only = False

if core_genes_only:
    sys.stderr.write("Filtering for core genes\n")
else:
    sys.stderr.write("Filtering for non-shared genes\n")
    

if core_genes_only:
    core_genes = parse_core_genes(species)
    centroid_gene_map = load_centroid_gene_map(species)
    desired_gene_ids = {value if key in core_genes else key for key, value in centroid_gene_map.items() if key in core_genes or value in core_genes}.union(core_genes)
    snvs = [snv for snv in snvs if snv[0] in desired_gene_ids]
else:
    non_shared_genes = parse_non_shared_reference_genes(species)
    centroid_gene_map = load_centroid_gene_map(species)
    desired_gene_ids = {value if key in non_shared_genes else key for key, value in centroid_gene_map.items() if key in non_shared_genes or value in non_shared_genes}.union(non_shared_genes)
    snvs = [snv for snv in snvs if snv[0] in desired_gene_ids]

# snvs = [(snv[1], snv[2], snv[0]) for snv in snvs]
snvs = list(set([(snv[1], snv[2]) for snv in sorted(snvs, key=lambda x: (x[1], x[2]))]))

In [None]:
# separate snvs into clusters

clust_dist = 1e4
contigs = {snv[0] for snv in snvs}
snv_clusters = []
start = 0
for i,contig in enumerate(contigs):
    snvs_contig = [snv for snv in snvs if snv[0] == contig]
    snvs_contig = [snv for snv in sorted(snvs_contig, key=lambda x: (x[0], x[1]))]
    loci = np.array([snv[1] for snv in snvs_contig])
    gaps = np.diff(loci)
    breaks = np.where(gaps > clust_dist)[0]
    start = 0
    for end in breaks:
        end = end+1
        snv_clusters.append(list(snvs_contig[start:end]))
        start = end
    snv_clusters.append(list(snvs_contig[start:]))
    

In [None]:
# reloading subject samples
subject_samples = list(subject_sample_map[str(subject_id)].keys())
snps_freq_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_ref_freq.txt.bz2")
# Open the compressed file in binary mode
with bz2.open(snps_freq_path, 'rb') as file:
    first_line = file.readline().strip() 
snps_samples = first_line.decode('utf-8').split("\t")[1:]

subject_samples = [sample_ for sample_ in subject_samples if sample_ in snps_samples]


# Creating sample indices and sorting
sample_type = [sample_metadata_map[f][2] for f in subject_samples]
stool_idx = [i for i, element in enumerate(sample_type) if element == "Stool"]
saliva_idx = [i for i, element in enumerate(sample_type) if element == "Saliva"]
timestamp_format = '%Y-%m-%dT%H:%M:%SZ'

# Create a list of tuples containing sorting keys
sorting_keys = [
(
    datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%Y-%m-%d') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%Y-%m-%d'),
    datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%H:%M:%S') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%H:%M:%S'),
    sample_metadata_map[f][2]  # sample_type
)
for i, f in enumerate(subject_samples)
]

# Sort subject_samples based on the sorting keys
desired_samples = [x for _, x in sorted(zip(sorting_keys, subject_samples))]

# Creating sample indices with sorted
sample_type = [sample_metadata_map[f][2] for f in desired_samples]
stool_idx = [i for i, element in enumerate(sample_type) if element == "Stool"]
saliva_idx = [i for i, element in enumerate(sample_type) if element == "Saliva"]

# creating a dictionary
sample_sorting_dict = {
    f: (
        datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%Y-%m-%d') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%Y-%m-%d'),
        datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%H:%M:%S') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%H:%M:%S'),
        sample_metadata_map[f][2],  # sample_type
        i
    )
    for i, f in enumerate(desired_samples)
}


In [None]:
# Extracting SNV frequencies

only_haploid = True
only_high_coverage = False
min_depth = 5

x_axis_dictionary = {}

snps_freq_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_ref_freq.txt.bz2")
snps_depth_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_depth.txt.bz2")
snps_info_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_info.txt.bz2")


final_line_number = 1e10

if only_haploid:
    haploid_samples = calculate_haploid_samples(species, use_HMP_freqs=True)
    desired_samples = [desired_sample for desired_sample in desired_samples if desired_sample in haploid_samples]
elif only_high_coverage:
    high_coverage_samples = calculate_highcoverage_samples(species)
    desired_samples = [desired_sample for desired_sample in desired_samples if desired_sample in high_coverage_samples]

# if (len(baseline_sample) == 0) | (baseline_sample not in desired_samples):
#     if len(recipient_samples) > 0:
#         baseline_sample = recipient_samples[0]
#     else:
#         baseline_sample = desired_samples[0]

line_counter = 0
current_cluster = 0
clusters_finished = False
loci_extracted = [0]*len(snv_clusters)
site_id_vec = [[] for _ in snv_clusters]
allele_frequencies = [[] for _ in snv_clusters]

with bz2.open(snps_freq_path, 'rt') as file,  bz2.open(snps_depth_path, 'rt') as depth_file, bz2.open(snps_info_path, 'rt') as info_file:
    for line, depth_line, info_line in zip(file, depth_file, info_file):
        if line_counter == 0:
            items = line.split()[1:]    
            samples_in_file = parse_merged_sample_names(items)
            desired_sample_idxs = []
            for sample in desired_samples:
                if sample not in samples_in_file:
                    continue
                desired_sample_idxs.append( np.nonzero(samples_in_file==sample)[0][0] )
                # if plot_strain_snvs:
                #     strain_tracking_sample_idx = np.nonzero(samples_in_file==strain_tracking_sample)[0][0]
                #     strain_tracking_sample_foil_idx = np.nonzero(samples_in_file==strain_tracking_sample_foil)[0][0]

            desired_sample_idxs = np.array(desired_sample_idxs)    
            desired_samples = samples_in_file[desired_sample_idxs]
            sample_sorting_dict = {key:val for key,val in sample_sorting_dict.items() if key in desired_samples}
            sample_sorting_dict = {k: sample_sorting_dict[k][:3] + (i,) for i, k in enumerate(sample_sorting_dict.keys())}

            line_counter += 1
        else:
            if line_counter == final_line_number:
                sys.stderr.write("Breaking at line " + str(final_line_number) + "\n")
                break
            if (sum(loci_extracted) == len(snvs)):
                sys.stderr.write("All loci extracted\n")
                break
            
            items = line.split()
            depths =  depth_line.split()
            infos = info_line.split()
            if len(infos) >= 7:
                gene_id = infos[6]
            else: 
                gene_id = ""
            site_id_original = items[0]
            site_id = items[0].split("|")
            site_id = (site_id[0], int(site_id[1]))
            items = items[1:]
            depths = [int(depth) for depth in depths[1:]]     

            for cluster_i,cluster in enumerate(snv_clusters):
                if site_id in cluster:
                    sys.stderr.write("EXTRACTING: " + site_id[0] + "|" + str(site_id[1]) + " in cluster " + str(cluster_i+1) + "\n")
                    site_id_vec[cluster_i].append(site_id_original)
                    loci_extracted[cluster_i] += 1
                    af = []
                    for idx in desired_sample_idxs:    
                        item = items[idx]
                        depth = depths[idx]
                        if depth < min_depth:
                            subitems = [np.nan]
                        else:
                            subitems = item.split(",")
                        af.append(subitems[0])
                    if loci_extracted[cluster_i] == 1:
                        current_cluster += 1
                        allele_frequencies[cluster_i] = np.array(af)
                        if len(snv_clusters[cluster_i]) == 1:
                            allele_frequencies[cluster_i] = allele_frequencies[cluster_i][np.newaxis, :]
                    else:
                        allele_frequencies[cluster_i] = np.vstack((allele_frequencies[cluster_i], np.array(af)))
                    break
            if (loci_extracted[cluster_i] == len(cluster)) & (not clusters_finished):
                if (sum(loci_extracted) == len(snvs)):
                    clusters_finished = True
                sys.stderr.write("Extracted " + str(loci_extracted[cluster_i]) + " SNVs in cluster " + str(cluster_i+1) + "\n")
            
            line_counter += 1
            # if line_counter % 1000 == 0:
            #     print(site_id)  



In [None]:
## ANNOTATION
snv_freqs = []
sorting_dict = {key: idx for idx, key in enumerate(sample_sorting_dict.keys())}
polarize_sample = [sample_ for sample_ in sample_sorting_dict.keys() if (sample_sorting_dict[sample_][2] != "Stool") & (sample_sorting_dict[sample_][2] != "Saliva")][0]
for cluster_i, cluster in enumerate(snv_clusters):
    # creating dataframe
    snv_freqs.append(pd.DataFrame(columns = desired_samples, index=site_id_vec[cluster_i], data = allele_frequencies[cluster_i]))
    snv_freqs[cluster_i] = snv_freqs[cluster_i].reset_index().rename(columns = {"index":"site_id"})
    snv_freqs[cluster_i][['contig', 'site_pos', 'nucleotide']] = snv_freqs[cluster_i]['site_id'].str.split("|", expand=True)
    # convert allele frequency to numeric
    snv_freqs[cluster_i][list(desired_samples)] = snv_freqs[cluster_i][list(desired_samples)].apply(pd.to_numeric, errors='coerce')
    # polarize
    snv_freqs[cluster_i][list(desired_samples)] = snv_freqs[cluster_i].apply(lambda row: 1 - row[list(desired_samples)] if row[polarize_sample] > 0.5 else row[list(desired_samples)], axis = 1)
    # convert to longform
    snv_freqs[cluster_i] =pd.melt(snv_freqs[cluster_i],
                                  id_vars=['site_id', 'contig', 'site_pos', 'nucleotide'],  # columns to keep as identifier variables
                                  var_name='sample',  # the name for the sample column
                                  value_name='allele_frequency'  # the name for the values from sample columns
                                  )

    # annotate with sample type (e.g., "Stool", "Saliva", "Capsule 2", etc.)
    snv_freqs[cluster_i]['sample_type'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][2])
    # annotate with date
    snv_freqs[cluster_i]['date'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][0])
    # annotate with time
    snv_freqs[cluster_i]['time'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][1])
    # annotate with timepoint
    snv_freqs[cluster_i]['timepoint'] = snv_freqs[cluster_i]['sample'].apply(lambda x: "%s,\n%s" % (sample_sorting_dict[x][0],sample_sorting_dict[x][1]))
    # sample order
    # snv_freqs[cluster_i]['sample_order'] = snv_freqs[cluster_i]['sample'].map(sorting_dict)
    snv_freqs[cluster_i]['sample_order'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][3])

snv_freqs_combined = pd.concat(
    [df.assign(cluster=i) for i, df in enumerate(snv_freqs)],
    ignore_index=True
)



In [None]:
# plotting
fig, ax = plt.subplots(figsize=(20, 8))


color_palette = sns.color_palette("husl", len(snv_freqs))


for cluster_i,cluster in enumerate(snv_freqs):
    for site_id in cluster['site_id'].unique():
        snv_freq_individual = cluster[cluster['site_id'] == site_id]
        sns.lineplot(data=snv_freq_individual, x = "sample_order", y = "allele_frequency", color = color_palette[cluster_i])

# Add legend manually
cluster_labels = [f"Cluster {i}" for i in range(len(snv_freqs))]
handles = [plt.Line2D([0], [0], color=color_palette[i], lw=2) for i in range(len(snv_freqs))]
ax.legend(
    handles,
    cluster_labels,
    title="Genomic clusters",
    loc="center left",
    bbox_to_anchor=(1, 0.5),  # Position the legend to the right of the plot
    ncol=1  # Arrange legend items in a single column
)



ax.set_xticks(range(len(sample_sorting_dict.keys())))

major_ticks = []
major_tick_labels = []
minor_ticks = []
minor_tick_labels = []
time_point = ""
x_ticks_loc = ax.get_xticks()
vspan_counter = 0
vspan_vec = []
for i, sample_ in enumerate(sample_sorting_dict.keys()):
    # Major ticks: timepoint
    new_time_point = "%s,\n%s" % (sample_sorting_dict[sample_][0],sample_sorting_dict[sample_][1])
    if (time_point != new_time_point) & (i != len(sample_sorting_dict) - 1):
        time_point = new_time_point
        # major_ticks.append(x_ticks_loc[i])
        major_tick_labels.append(time_point)

        # add vspan
        vspan_counter += 1

        if vspan_counter == 1:
            xmin = 0
        else: 
            xmax = x_ticks_loc[i] - 0.5
            vspan_vec.append([xmin, xmax])
            xmin = xmax
    elif (time_point != new_time_point) & (i == len(sample_sorting_dict) - 1):  
        time_point = new_time_point    
        major_tick_labels.append(time_point)
        
        xmax = x_ticks_loc[i] - 0.5
        vspan_vec.append([xmin, xmax])
        xmin = xmax
        xmax = x_ticks_loc[i]
        vspan_vec.append([xmin, xmax])
    elif (i == len(sample_sorting_dict) - 1):
        xmax = x_ticks_loc[i]
        vspan_vec.append([xmin, xmax])

    # Minor ticks: sample type
    sample_type = sample_sorting_dict[sample_][2]
    minor_ticks.append(x_ticks_loc[i])
    minor_tick_labels.append(sample_type)

# adding vspan and creating minor ticks
for i,v in enumerate(vspan_vec):
    # adding vspan
    if i % 2 == 1:
        ax.axvspan(v[0],v[1],alpha=.2,color='grey') 
    
    if (i == 0) & (v[1] == 0.5):
        major_ticks.append(0)
    elif (i == len(vspan_vec) - 1):
        major_ticks.append(np.mean([v[0] + 0.5, v[1]]))
    else:
        major_ticks.append(np.mean(v))


ax.set_xticks(major_ticks)
ax.set_xticklabels(major_tick_labels)
ax.set_xticks(minor_ticks, minor=True)
ax.set_xticklabels(minor_tick_labels, minor=True)

ax.xaxis.remove_overlapping_locs = False

ax.set_ylim(0,1)

# Tick parameters
plt.tick_params(axis='x',which='major',bottom=False,left=False,top=False) 
ax.tick_params(axis='x', which='major', pad=15)  # Move major ticks down by increasing the padding
ax.tick_params(axis='x', which='minor', pad=2) 

# Titles

ax.set_title("%s%s%s%s" % ("SNV frequencies in ", species, " in subject ", str(subject_id)), size = 20)
ax.set_xlabel("Sample", size = 20)
ax.set_ylabel("Allele frequency", size = 20)


plt.tight_layout()


In [None]:
snv_freqs[0]

In [None]:
out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/figures/evolutionary_changes/SNV_freqs_%s_%s.png" % (species,str(subject_id))
fig.savefig(out_path, dpi = 300, facecolor = "white")


In [None]:
out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/figures/revisions/SNV_freqs_%s_%s.png" % (species,str(subject_id))
fig.savefig(out_path, dpi = 300, facecolor = "white")

In [None]:
# saving 
fst_processing_df.to_csv(out_path, sep=',', index=False)

## Analyzing pre-identified changes (for loop)

In [None]:
only_haploid = True
only_high_coverage = False
core_genes_only = False
min_depth = 10


number_of_snvs_vec = []
snvs_set_vec = []


for species_subject_i,species_subject_row in grouped_species_subject_pairs.iterrows():
    # Extracting species-subject pair
    species = species_subject_row['species']
    subject_id = int(species_subject_row['subject_id'])

    sys.stderr.write("Processig %s in subject %s (%d / %d)\n" % (species, str(subject_id), species_subject_i + 1, len(grouped_species_subject_pairs)))

    # load intersample changes
    intersample_changes = load_intersample_change_map(species)

    # Extracting all within-host SNVs

    subject_samples = list(subject_sample_map[str(subject_id)].keys())
    haploid_samples = calculate_haploid_samples(species, use_HMP_freqs=True)
    # subject_samples = [sample_ for sample_ in subject_samples if sample_ in haploid_samples]
    # sample_pair_list = list(itertools.combinations(subject_samples, 2))
    sample_pair_list = list(change_summary_df.loc[(change_summary_df['species'] == species) & (change_summary_df['subject_1'] == str(subject_id)),['sample_1', 'sample_2']].itertuples(index=False, name=None))
    
    ## paring intersample changes to only include within-host snvs
    intersample_changes = {key: value for key, value in intersample_changes.items() if (key in sample_pair_list) or ((key[1], key[0]) in sample_pair_list)}

    ## building snv list 
    snvs = set({})
    # for key in intersample_changes.keys():
    for key in sample_pair_list:
        snv_list = set(intersample_changes[key]['snps'][2])
        snvs.update(snv_list)

    # Filter for core or non-shared genes
    
    if core_genes_only:
        sys.stderr.write("Filtering for core genes\n")
    else:
        sys.stderr.write("Filtering for non-shared genes\n")
        
    if core_genes_only:
        core_genes = parse_core_genes(species)
        centroid_gene_map = load_centroid_gene_map(species)
        desired_gene_ids = {value if key in core_genes else key for key, value in centroid_gene_map.items() if key in core_genes or value in core_genes}.union(core_genes)
        snvs = [snv for snv in snvs if snv[0] in desired_gene_ids]
    else:
        non_shared_genes = parse_non_shared_reference_genes(species)
        centroid_gene_map = load_centroid_gene_map(species)
        desired_gene_ids = {value if key in non_shared_genes else key for key, value in centroid_gene_map.items() if key in non_shared_genes or value in non_shared_genes}.union(non_shared_genes)
        snvs = [snv for snv in snvs if snv[0] in desired_gene_ids]

    # snvs = [(snv[1], snv[2], snv[0]) for snv in snvs]
    snvs = list(set([(snv[1], snv[2]) for snv in sorted(snvs, key=lambda x: (x[1], x[2]))]))

    if len(snvs) == 0:
        sys.stderr.write("No within-host SNVs in %s in %s \n" % (species, str(subject_id)))
        number_of_snvs_vec.append(0)
        snvs_set_vec.append(set())
        continue
    else:
        number_of_snvs_vec.append(len(snvs))
        snvs_set_vec.append(set(snvs))


    # separate snvs into clusters

    clust_dist = 1e4
    contigs = {snv[0] for snv in snvs}
    snv_clusters = []
    start = 0
    for i,contig in enumerate(contigs):
        snvs_contig = [snv for snv in snvs if snv[0] == contig]
        snvs_contig = [snv for snv in sorted(snvs_contig, key=lambda x: (x[0], x[1]))]
        loci = np.array([snv[1] for snv in snvs_contig])
        gaps = np.diff(loci)
        breaks = np.where(gaps > clust_dist)[0]
        start = 0
        for end in breaks:
            end = end+1
            snv_clusters.append(list(snvs_contig[start:end]))
            start = end
        snv_clusters.append(list(snvs_contig[start:]))
        

    # reloading subject samples
    subject_samples = list(subject_sample_map[str(subject_id)].keys())
    snps_freq_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_ref_freq.txt.bz2")
    # Open the compressed file in binary mode
    with bz2.open(snps_freq_path, 'rb') as file:
        first_line = file.readline().strip() 
    snps_samples = first_line.decode('utf-8').split("\t")[1:]

    subject_samples = [sample_ for sample_ in subject_samples if sample_ in snps_samples]


    # Creating sample indices and sorting
    sample_type = [sample_metadata_map[f][2] for f in subject_samples]
    stool_idx = [i for i, element in enumerate(sample_type) if element == "Stool"]
    saliva_idx = [i for i, element in enumerate(sample_type) if element == "Saliva"]
    timestamp_format = '%Y-%m-%dT%H:%M:%SZ'

    # Create a list of tuples containing sorting keys
    sorting_keys = [
    (
        datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%Y-%m-%d') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%Y-%m-%d'),
        datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%H:%M:%S') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%H:%M:%S'),
        sample_metadata_map[f][2]  # sample_type
    )
    for i, f in enumerate(subject_samples)
    ]

    # Sort subject_samples based on the sorting keys
    desired_samples = [x for _, x in sorted(zip(sorting_keys, subject_samples))]

    # Creating sample indices with sorted
    sample_type = [sample_metadata_map[f][2] for f in desired_samples]
    stool_idx = [i for i, element in enumerate(sample_type) if element == "Stool"]
    saliva_idx = [i for i, element in enumerate(sample_type) if element == "Saliva"]

    # creating a dictionary
    sample_sorting_dict = {
        f: (
            datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%Y-%m-%d') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%Y-%m-%d'),
            datetime.strptime(sample_metadata_map[f][5], timestamp_format).strftime('%H:%M:%S') if (i in stool_idx) | (i in saliva_idx) else datetime.strptime(sample_metadata_map[f][3], timestamp_format).strftime('%H:%M:%S'),
            sample_metadata_map[f][2],  # sample_type
            i
        )
        for i, f in enumerate(desired_samples)
    }

        
    # Extracting SNV frequencies
    x_axis_dictionary = {}

    snps_freq_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_ref_freq.txt.bz2")
    snps_depth_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_depth.txt.bz2")
    snps_info_path = "%s%s%s%s" % (config.data_directory,"snps/", species, "/snps_info.txt.bz2")


    final_line_number = 1e10

    if only_haploid:
        haploid_samples = calculate_haploid_samples(species, use_HMP_freqs=True)
        desired_samples = [desired_sample for desired_sample in desired_samples if desired_sample in haploid_samples]
    elif only_high_coverage:
        high_coverage_samples = calculate_highcoverage_samples(species)
        desired_samples = [desired_sample for desired_sample in desired_samples if desired_sample in high_coverage_samples]

    # if (len(baseline_sample) == 0) | (baseline_sample not in desired_samples):
    #     if len(recipient_samples) > 0:
    #         baseline_sample = recipient_samples[0]
    #     else:
    #         baseline_sample = desired_samples[0]

    line_counter = 0
    current_cluster = 0
    clusters_finished = False
    loci_extracted = [0]*len(snv_clusters)
    site_id_vec = [[] for _ in snv_clusters]
    allele_frequencies = [[] for _ in snv_clusters]

    with bz2.open(snps_freq_path, 'rt') as file,  bz2.open(snps_depth_path, 'rt') as depth_file, bz2.open(snps_info_path, 'rt') as info_file:
        for line, depth_line, info_line in zip(file, depth_file, info_file):
            if line_counter == 0:
                items = line.split()[1:]    
                samples_in_file = parse_merged_sample_names(items)
                desired_sample_idxs = []
                for sample in desired_samples:
                    if sample not in samples_in_file:
                        continue
                    desired_sample_idxs.append( np.nonzero(samples_in_file==sample)[0][0] )
                    # if plot_strain_snvs:
                    #     strain_tracking_sample_idx = np.nonzero(samples_in_file==strain_tracking_sample)[0][0]
                    #     strain_tracking_sample_foil_idx = np.nonzero(samples_in_file==strain_tracking_sample_foil)[0][0]

                desired_sample_idxs = np.array(desired_sample_idxs)    
                desired_samples = samples_in_file[desired_sample_idxs]
                sample_sorting_dict = {key:val for key,val in sample_sorting_dict.items() if key in desired_samples}
                sample_sorting_dict = {k: sample_sorting_dict[k][:3] + (i,) for i, k in enumerate(sample_sorting_dict.keys())}

                line_counter += 1
            else:
                if line_counter == final_line_number:
                    sys.stderr.write("Breaking at line " + str(final_line_number) + "\n")
                    break
                if (sum(loci_extracted) == len(snvs)):
                    sys.stderr.write("All loci extracted\n")
                    break
                
                items = line.split()
                depths =  depth_line.split()
                infos = info_line.split()
                if len(infos) >= 7:
                    gene_id = infos[6]
                else: 
                    gene_id = ""
                site_id_original = items[0]
                site_id = items[0].split("|")
                site_id = (site_id[0], int(site_id[1]))
                items = items[1:]
                depths = [int(depth) for depth in depths[1:]]     

                for cluster_i,cluster in enumerate(snv_clusters):
                    if site_id in cluster:
                        sys.stderr.write("EXTRACTING: " + site_id[0] + "|" + str(site_id[1]) + " in cluster " + str(cluster_i+1) + "\n")
                        site_id_vec[cluster_i].append(site_id_original)
                        loci_extracted[cluster_i] += 1
                        af = []
                        for idx in desired_sample_idxs:    
                            item = items[idx]
                            depth = depths[idx]
                            if depth < min_depth:
                                subitems = [np.nan]
                            else:
                                subitems = item.split(",")
                            af.append(subitems[0])
                        if loci_extracted[cluster_i] == 1:
                            current_cluster += 1
                            allele_frequencies[cluster_i] = np.array(af)
                            if len(snv_clusters[cluster_i]) == 1:
                                allele_frequencies[cluster_i] = allele_frequencies[cluster_i][np.newaxis, :]
                        else:
                            allele_frequencies[cluster_i] = np.vstack((allele_frequencies[cluster_i], np.array(af)))
                        break
                if (loci_extracted[cluster_i] == len(cluster)) & (not clusters_finished):
                    if (sum(loci_extracted) == len(snvs)):
                        clusters_finished = True
                    sys.stderr.write("Extracted " + str(loci_extracted[cluster_i]) + " SNVs in cluster " + str(cluster_i+1) + "\n")
                
                line_counter += 1
                # if line_counter % 1000 == 0:
                #     print(site_id)  



    ## ANNOTATION
    snv_freqs = []
    sorting_dict = {key: idx for idx, key in enumerate(sample_sorting_dict.keys())}
    polarize_sample = [sample_ for sample_ in sample_sorting_dict.keys() if (sample_sorting_dict[sample_][2] != "Stool") & (sample_sorting_dict[sample_][2] != "Saliva")][0]
    for cluster_i, cluster in enumerate(snv_clusters):
        # creating dataframe
        snv_freqs.append(pd.DataFrame(columns = desired_samples, index=site_id_vec[cluster_i], data = allele_frequencies[cluster_i]))
        snv_freqs[cluster_i] = snv_freqs[cluster_i].reset_index().rename(columns = {"index":"site_id"})
        snv_freqs[cluster_i][['contig', 'site_pos', 'nucleotide']] = snv_freqs[cluster_i]['site_id'].str.split("|", expand=True)
        # convert allele frequency to numeric
        snv_freqs[cluster_i][list(desired_samples)] = snv_freqs[cluster_i][list(desired_samples)].apply(pd.to_numeric, errors='coerce')
        # polarize
        snv_freqs[cluster_i][list(desired_samples)] = snv_freqs[cluster_i].apply(lambda row: 1 - row[list(desired_samples)] if row[polarize_sample] > 0.5 else row[list(desired_samples)], axis = 1)
        # convert to longform
        snv_freqs[cluster_i] =pd.melt(snv_freqs[cluster_i],
                                    id_vars=['site_id', 'contig', 'site_pos', 'nucleotide'],  # columns to keep as identifier variables
                                    var_name='sample',  # the name for the sample column
                                    value_name='allele_frequency'  # the name for the values from sample columns
                                    )

        # annotate with sample type (e.g., "Stool", "Saliva", "Capsule 2", etc.)
        snv_freqs[cluster_i]['sample_type'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][2])
        # annotate with date
        snv_freqs[cluster_i]['date'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][0])
        # annotate with time
        snv_freqs[cluster_i]['time'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][1])
        # annotate with timepoint
        snv_freqs[cluster_i]['timepoint'] = snv_freqs[cluster_i]['sample'].apply(lambda x: "%s,\n%s" % (sample_sorting_dict[x][0],sample_sorting_dict[x][1]))
        # sample order
        # snv_freqs[cluster_i]['sample_order'] = snv_freqs[cluster_i]['sample'].map(sorting_dict)
        snv_freqs[cluster_i]['sample_order'] = snv_freqs[cluster_i]['sample'].apply(lambda x: sample_sorting_dict[x][3])

    snv_freqs_combined = pd.concat(
        [df.assign(cluster=i) for i, df in enumerate(snv_freqs)],
        ignore_index=True
    )


    # plotting
    fig, ax = plt.subplots(figsize=(20, 8))


    color_palette = sns.color_palette("husl", len(snv_freqs))


    for cluster_i,cluster in enumerate(snv_freqs):
        for site_id in cluster['site_id'].unique():
            snv_freq_individual = cluster[cluster['site_id'] == site_id]
            sns.lineplot(data=snv_freq_individual, x = "sample_order", y = "allele_frequency", color = color_palette[cluster_i])

    # Add legend manually
    cluster_labels = [f"Cluster {i}" for i in range(len(snv_freqs))]
    handles = [plt.Line2D([0], [0], color=color_palette[i], lw=2) for i in range(len(snv_freqs))]
    ax.legend(
        handles,
        cluster_labels,
        title="Genomic clusters",
        loc="center left",
        bbox_to_anchor=(1, 0.5),  # Position the legend to the right of the plot
        ncol=1  # Arrange legend items in a single column
    )



    ax.set_xticks(range(len(sample_sorting_dict.keys())))

    major_ticks = []
    major_tick_labels = []
    minor_ticks = []
    minor_tick_labels = []
    time_point = ""
    x_ticks_loc = ax.get_xticks()
    vspan_counter = 0
    vspan_vec = []
    for i, sample_ in enumerate(sample_sorting_dict.keys()):
        # Major ticks: timepoint
        new_time_point = "%s,\n%s" % (sample_sorting_dict[sample_][0],sample_sorting_dict[sample_][1])
        if (time_point != new_time_point) & (i != len(sample_sorting_dict) - 1):
            time_point = new_time_point
            # major_ticks.append(x_ticks_loc[i])
            major_tick_labels.append(time_point)

            # add vspan
            vspan_counter += 1

            if vspan_counter == 1:
                xmin = 0
            else: 
                xmax = x_ticks_loc[i] - 0.5
                vspan_vec.append([xmin, xmax])
                xmin = xmax
        elif (time_point != new_time_point) & (i == len(sample_sorting_dict) - 1):  
            time_point = new_time_point    
            major_tick_labels.append(time_point)
            
            xmax = x_ticks_loc[i] - 0.5
            vspan_vec.append([xmin, xmax])
            xmin = xmax
            xmax = x_ticks_loc[i]
            vspan_vec.append([xmin, xmax])
        elif (i == len(sample_sorting_dict) - 1):
            xmax = x_ticks_loc[i]
            vspan_vec.append([xmin, xmax])

        # Minor ticks: sample type
        sample_type = sample_sorting_dict[sample_][2]
        minor_ticks.append(x_ticks_loc[i])
        minor_tick_labels.append(sample_type)

    # adding vspan and creating minor ticks
    for i,v in enumerate(vspan_vec):
        # adding vspan
        if i % 2 == 1:
            ax.axvspan(v[0],v[1],alpha=.2,color='grey') 
        
        if (i == 0) & (v[1] == 0.5):
            major_ticks.append(0)
        elif (i == len(vspan_vec) - 1):
            major_ticks.append(np.mean([v[0] + 0.5, v[1]]))
        else:
            major_ticks.append(np.mean(v))


    ax.set_xticks(major_ticks)
    ax.set_xticklabels(major_tick_labels)
    ax.set_xticks(minor_ticks, minor=True)
    ax.set_xticklabels(minor_tick_labels, minor=True)

    ax.xaxis.remove_overlapping_locs = False

    ax.set_ylim(0,1)

    # Tick parameters
    plt.tick_params(axis='x',which='major',bottom=False,left=False,top=False) 
    ax.tick_params(axis='x', which='major', pad=15)  # Move major ticks down by increasing the padding
    ax.tick_params(axis='x', which='minor', pad=2) 

    # Titles

    ax.set_title("%s%s%s%s" % ("SNV frequencies in ", species, " in subject ", str(subject_id)), size = 20)
    ax.set_xlabel("Sample", size = 20)
    ax.set_ylabel("Allele frequency", size = 20)


    plt.tight_layout()

    out_path = "/u/project/ngarud/michaelw/Diversity-Along-Gut/Shalon_2023/figures/evolutionary_changes/SNV_freqs_%s_%s_nov13.png" % (species,str(subject_id))
    fig.savefig(out_path, dpi = 300, facecolor = "white")


In [None]:
change_summary_df[(change_summary_df['species'] == "Alistipes_putredinis_61533") & ((change_summary_df['subject_1'] == "9"))]