In [1]:
from utils import config, parse_midas_data, sample_utils as su, temporal_changes_utils, stats_utils, midas_db_utils, parse_patric
from collections import defaultdict
import math, random, numpy as np
import pickle, sys, bz2
import matplotlib.pyplot as plt

# Cohort list
cohorts = ['backhed', 'ferretti', 'yassour', 'shao', 'olm', 'hmp']

# Plot directory
plot_dir = "%s/" % (config.analysis_directory)

# Species list
good_species_list = parse_midas_data.load_pickled_good_species_list()

# Sample-subject-order maps
sys.stderr.write("Loading sample metadata...\n")
subject_sample_map = su.parse_subject_sample_map()
sample_order_map = su.parse_sample_order_map()
sample_subject_map = su.parse_sample_subject_map()
same_mi_pair_dict = su.get_same_mi_pair_dict(subject_sample_map)
sys.stderr.write("Done!\n")

# Timepoint pair types
tp_pair_names = ['MM', 'MI', 'II', 'AA']

# Cohorts
cohorts = ['backhed', 'ferretti', 'yassour', 'shao', 'hmp']
mi_cohorts = ['backhed', 'ferretti', 'yassour', 'shao']

# Samples for each cohort
samples = {cohort: su.get_sample_names(cohort) for cohort in cohorts}
hmp_samples = su.get_sample_names('hmp')
mother_samples = su.get_sample_names('mother')
infant_samples = su.get_sample_names('infant')
olm_samples = su.get_sample_names('olm')
infant_samples_no_olm = [sample for sample in infant_samples if sample not in olm_samples]
mi_samples_no_olm = [sample for sample in (mother_samples + infant_samples) if sample not in olm_samples]

# Sample-cohort map
sample_cohort_map = su.parse_sample_cohort_map()

# Sample-timepoint map
mi_sample_day_dict = su.get_mi_sample_day_dict(exclude_cohorts=['olm'])
mi_tp_sample_dict = su.get_mi_tp_sample_dict(exclude_cohorts=['olm']) # no binning
mi_tp_sample_dict_binned, mi_tp_binned_labels = su.get_mi_tp_sample_dict(exclude_cohorts=['olm'], binned=True)

Loading sample metadata...
Done!


In [2]:
# ======================================================================
# Load pickled data
# ======================================================================

# Parameters
sweep_type = 'full' # assume full for now
pp_prev_cohort = 'all'
min_coverage = 0

ddir = config.data_directory
pdir = "%s/pickles/cov%i_prev_%s/" % (ddir, min_coverage, pp_prev_cohort)

snp_changes = pickle.load(open('%s/big_snp_changes_%s.pkl' % (pdir, sweep_type), 'rb'))
gene_changes = pickle.load(open('%s/big_gene_changes_%s.pkl' % (pdir, sweep_type), 'rb'))
snp_change_freqs = pickle.load(open('%s/snp_change_freqs_%s.pkl' % (pdir, sweep_type), 'rb'))
snp_change_null_freqs = pickle.load(open('%s/snp_change_null_freqs_%s.pkl' % (pdir, sweep_type), 'rb'))
gene_gain_freqs = pickle.load(open('%s/gene_gain_freqs_%s.pkl' % (pdir, sweep_type), 'rb'))
gene_loss_freqs = pickle.load(open('%s/gene_loss_freqs_%s.pkl' % (pdir, sweep_type), 'rb'))
gene_loss_null_freqs = pickle.load(open('%s/gene_loss_null_freqs_%s.pkl' % (pdir, sweep_type), 'rb'))
between_snp_change_counts = pickle.load(open('%s/between_snp_change_counts_%s.pkl' % (pdir, sweep_type), 'rb'))
between_gene_change_counts = pickle.load(open('%s/between_gene_change_counts_%s.pkl' % (pdir, sweep_type), 'rb'))

In [3]:
# Relative abundance file
relab_fpath = "%s/species/relative_abundance.txt.bz2" % (config.data_directory)
relab_file = open(relab_fpath, 'r')
decompressor = bz2.BZ2Decompressor()
raw = decompressor.decompress(relab_file.read())
data = [row.split('\t') for row in raw.split('\n')]
data.pop() # Get rid of extra element due to terminal newline
header = su.parse_merged_sample_names(data[0]) # species_id, samples...

# Load species presence/absence information
sample_species_dict = defaultdict(set)

for row in data[1:]:
    species = row[0]
    for relab_str, sample in zip(row[1:], header[1:]):
        relab = float(relab_str)
        if relab > 0:
            sample_species_dict[sample].add(species)

In [4]:
# Custom sample pair cohorts [not just sample!]

is_mi = lambda sample_i, sample_j: ((sample_i in mother_samples and sample_j in infant_samples_no_olm) and mi_sample_day_dict[sample_i] >= 0 and mi_sample_day_dict[sample_i] <= 7 and mi_sample_day_dict[sample_j] <= 7)

In [13]:
num_transmission = 0 # Number of MI QP pairs which are strain transmissions
num_transmission_shared_species = []
num_replacement = 0 # Number of MI QP pairs which are strain replacements
num_replacement_shared_species = []
num_total = 0 # Total number of MI QP pairs (sanity check)
num_shared_species_per_dyad = {}
shared_highcov_species_per_dyad = defaultdict(set)
existing_hosts = set()

# For every mother-infant QP pair, also count number of shared species
for species in snp_changes:
    for sample_i, sample_j in snp_changes[species]:
        
        # Only consider mother-infant QP pairs
        if not is_mi(sample_i, sample_j):
            continue
        
        # Make sure only one sample pair per host
        host = sample_order_map[sample_i][0][:-2]
        if host in existing_hosts:
            continue
        existing_hosts.add(host)
        
        num_total += 1
        
        # Get number of shared species
        shared_species = sample_species_dict[sample_i].intersection(sample_species_dict[sample_j])
        num_shared_species = len(shared_species)
        
        num_shared_species_per_dyad[(sample_i, sample_j)] = num_shared_species
        shared_highcov_species_per_dyad[(sample_i, sample_j)].add(species)
        
        # Get number of SNP differences
        val = snp_changes[species][(sample_i, sample_j)]
                
        if (type(val) == type(1)): # Replacement
            num_replacement += 1
            num_replacement_shared_species.append(num_shared_species)
        else: # Modification or no change
            num_transmission += 1
            num_transmission_shared_species.append(num_shared_species)

In [7]:
hosts = defaultdict(int)
for s1, s2 in num_shared_species_per_dyad:
    if sample_order_map[s1][0][:-2] != sample_order_map[s2][0][:-2]:
        print("Weird")
    hosts[sample_order_map[s1][0][:-2]] += 1

In [18]:
print("%i transmissions" % num_transmission)
print("%i shared species aggregated over transmissions" % sum([nss for nss in num_transmission_shared_species]))
print("%i replacements" % num_replacement)
print("%i shared species aggregated over replacements" % sum([nss for nss in num_replacement_shared_species]))
print("%i total mother-infant QP pairs" % num_total)
print("%i total shared species aggregated over dyads" % sum(num_shared_species_per_dyad.values()))
print("%i dyads" % len(shared_highcov_species_per_dyad))
print("%i total shared highcov species aggregated over dyads" % sum([len(shared_highcov_species_per_dyad[dyad]) for dyad in shared_highcov_species_per_dyad]))

40 transmissions
1397 shared species aggregated over transmissions
6 replacements
247 shared species aggregated over replacements
46 total mother-infant QP pairs
1644 total shared species aggregated over dyads
46 dyads
46 total shared highcov species aggregated over dyads


In [34]:
float(num_transmission)/(sum([len(nss) for nss in num_transmission_shared_species])*2)

0.013814832767813864