In [1]:
from utils import config, sample_utils as su
from utils import parse_midas_data, parse_patric, core_gene_utils, midas_db_utils
import numpy as np, pickle, sys
from numpy.random import choice
from collections import defaultdict

from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

In [2]:
# Directories derived from the project config, and the sweep type that
# selects which pickled tables are loaded below.
plot_dir = "%s/revs_genes/" % (config.analysis_directory)
ddir = config.data_directory
pdir = "%s/pickles" % ddir
sweep_type = 'full'

# Load the pickled tables inside context managers so the file handles are
# closed as soon as loading finishes (the bare open() calls leaked them).
# NOTE(review): pickle.load can execute arbitrary code; these files are
# presumably produced by this project's own pipeline -- never point this at
# untrusted pickles.
with open('%s/snp_modifications_by_site_%s.pkl' % (pdir, sweep_type), 'rb') as pkl_file:
    snp_modifications_by_site = pickle.load(pkl_file)
with open('%s/snp_modification_genes_%s.pkl' % (pdir, sweep_type), 'rb') as pkl_file:
    snp_modification_genes = pickle.load(pkl_file)

In [3]:
def order_by_tp(tp_dict):
    """Return the (timepoint, value) pairs of ``tp_dict`` in sorted order.

    Timepoint labels are expected to be a one-character prefix followed by
    an integer (e.g. 'M1', 'I12').

    - Mother ('M') timepoints come first, then infant ('I') timepoints,
      each group sorted by its integer suffix.
    - If at least one 'M' or 'I' label is present, labels with any other
      prefix are left out of the result.
    - If there are no 'M' or 'I' labels at all, every pair is returned,
      sorted by integer suffix only.

    Returns a list of (timepoint, value) tuples.
    """
    def numeric_part(pair):
        # pair is (label, value); strip the one-character prefix
        return int(pair[0][1:])

    pairs = list(tp_dict.items())

    mothers = [pair for pair in pairs if pair[0][0] == 'M']
    mothers.sort(key=numeric_part)
    infants = [pair for pair in pairs if pair[0][0] == 'I']
    infants.sort(key=numeric_part)

    ordered = mothers + infants
    if ordered:
        return ordered
    return sorted(pairs, key=numeric_part)

In [4]:
def is_diff(val1, val2, lower_threshold, upper_threshold):
    """Tell whether val1 and val2 sit on opposite sides of the threshold band.

    True when one value is at or below ``lower_threshold`` while the other
    is at or above ``upper_threshold``, in either order (both bounds are
    inclusive).
    """
    rose = val1 <= lower_threshold and val2 >= upper_threshold
    fell = val1 >= upper_threshold and val2 <= lower_threshold
    return rose or fell

In [5]:
def has_reversion(vals, lower_threshold, upper_threshold):
    """Report whether a sequence of values changes and then changes again.

    The first element of ``vals`` is the initial baseline. Walking the
    remaining values in order, the first one that ``is_diff`` from the
    baseline counts as the first change and becomes the new baseline; a
    later value that ``is_diff`` from that new baseline is the reversion.

    Returns True as soon as the second change is seen, otherwise False.
    Raises IndexError if ``vals`` is empty.
    """
    baseline = vals[0]
    changed_once = False
    for current in vals[1:]:
        if not is_diff(current, baseline, lower_threshold, upper_threshold):
            continue
        if changed_once:
            # Second change; relative to the reset baseline this is
            # necessarily in the opposite direction
            return True
        # First change: compare everything that follows against it
        baseline = current
        changed_once = True
    return False

In [9]:
reversions = []
lower_threshold = 0.2
upper_threshold = 0.8

# Scan every species / cohort / subject / site for frequency reversions
for species in snp_modifications_by_site:
    species_dict = snp_modifications_by_site[species]

    # Gene descriptions are only loaded when at least one cohort of this
    # species has any subject entries
    if any(len(species_dict[cohort].keys()) > 0 for cohort in species_dict):
        print("Getting gene info for %s..." % species)
        genome_ids = midas_db_utils.get_ref_genome_ids(species)
        non_shared_genes = core_gene_utils.parse_non_shared_pangenome_genes(species)
        gene_desc = parse_patric.load_patric_gene_descriptions(genome_ids, non_shared_genes)

    for cohort in species_dict:
        subdict = species_dict[cohort]
        for subject in subdict:
            for site in subdict[subject]:
                snp_changes = subdict[subject][site]
                # Sites with a single recorded change are skipped
                if len(snp_changes) <= 1:
                    continue

                # Timepoint -> A/D ratio (frequency info only); a timepoint
                # shared by several changes keeps the last ratio written
                tp_freq_dict = {}
                for tp_pair, gene_id, var_type, A1, D1, A2, D2 in snp_changes:
                    tp1, tp2 = tp_pair
                    tp_freq_dict[tp1] = A1 / float(D1)
                    tp_freq_dict[tp2] = A2 / float(D2)

                ordered_tp_freqs = order_by_tp(tp_freq_dict)
                # gene_id / var_type are those of the last change at this site
                gene_name = gene_desc[gene_id] if gene_id in gene_desc else 'No name'
                freqs = [freq for tp, freq in ordered_tp_freqs]
                if not has_reversion(freqs, lower_threshold, upper_threshold):
                    continue
                reversions.append((species, cohort, subject, site, gene_id, gene_name, var_type, ordered_tp_freqs))

In [18]:
# Collect every subject that appears under any species / cohort
all_subjects = set()
for species in snp_modifications_by_site:
    for cohort in snp_modifications_by_site[species]:
        subdict = snp_modifications_by_site[species][cohort]
        # Iterating subdict yields its subject keys
        all_subjects.update(subdict)

In [21]:
# Subjects with at least one detected reversion; the subject is the third
# field of each reversion record
subjects = {record[2] for record in reversions}

In [22]:
# Display the set of subjects collected from the reversion records
subjects

{'M0808-M', 'M1098-M', 'N1_009', 'N1_011', 'N1_018', 'N1_023', 'N4_097'}

In [None]:
cohorts = ['backhed', 'ferretti', 'yassour', 'shao', 'olm', 'hmp']

# Per cohort: (gene_id, gene_name) -> list of species / list of timepoint
# pairs in which that gene appears
genes_species = {cohort: defaultdict(list) for cohort in cohorts}
genes_tp_pairs = {cohort: defaultdict(list) for cohort in cohorts}

# Walk every species / cohort / timepoint pair / gene
for species in snp_modification_genes:
    species_dict = snp_modification_genes[species]

    # Gene descriptions are only loaded when at least one cohort of this
    # species has any entries
    if any(len(species_dict[cohort].keys()) > 0 for cohort in species_dict):
        print("Getting gene info for %s..." % species)
        genome_ids = midas_db_utils.get_ref_genome_ids(species)
        non_shared_genes = core_gene_utils.parse_non_shared_pangenome_genes(species)
        gene_desc = parse_patric.load_patric_gene_descriptions(genome_ids, non_shared_genes)

    for cohort in species_dict:
        subdict = species_dict[cohort]

        for tp_pair in subdict:
            for gene_id in subdict[tp_pair].keys():
                gene_name = gene_desc[gene_id] if gene_id in gene_desc else 'No name'
                gene_key = (gene_id, gene_name)
                genes_species[cohort][gene_key].append(species)
                genes_tp_pairs[cohort][gene_key].append(tp_pair)

In [None]:
# Per cohort: genes seen in at least one timepoint pair whose two labels
# both start with 'I'
genes_cohort = {cohort: set() for cohort in cohorts}

for cohort, pairs_by_gene in genes_tp_pairs.items():
    for gene_tuple, tp_pairs in pairs_by_gene.items():
        if any(tpa[0] + tpb[0] == 'II' for tpa, tpb in tp_pairs):
            genes_cohort[cohort].add(gene_tuple)

In [11]:
# Print one report per reversion: a header, the gene details, one
# tab-indented line per timepoint frequency, then a blank separator line
for record in reversions:
    species, cohort, subject, site, gene_id, gene_name, var_type, ordered_tp_freqs = record

    report_lines = [
        species + ' | ' + cohort + ' | Subject: ' + subject,
        "Gene id: " + gene_id + ' | Variant type: ' + var_type,
        "Gene name: " + gene_name,
    ]
    report_lines.extend('\t' + tp + ': %.03f' % freq for tp, freq in ordered_tp_freqs)
    # Trailing empty entry reproduces the blank line between reports
    report_lines.append('')
    print('\n'.join(report_lines))

Enterococcus_faecalis_56297 | olm | Subject: N1_018
Gene id: 1158976.3.peg.771 | Variant type: 4D
Gene name: hypothetical protein
	I18: 0.826
	I19: 0.074
	I20: 0.074
	I24: 0.801
	I37: 0.118

Klebsiella_pneumoniae_54788 | olm | Subject: N1_023
Gene id: 1328400.3.peg.232 | Variant type: 1D
Gene name: No name
	I15: 0.026
	I17: 0.897
	I18: 0.818
	I19: 0.088
	I21: 0.911
	I29: 0.172

Citrobacter_freundii_56148 | olm | Subject: N1_009
Gene id: 1114920.3.peg.2929 | Variant type: 4D
Gene name: Phosphoenolpyruvate-dihydroxyacetone phosphotransferase operon regulatory protein DhaR
	I10: 0.870
	I13: 0.164
	I15: 0.091
	I16: 0.191
	I17: 0.944

Bacteroides_vulgatus_57955 | olm | Subject: N4_097
Gene id: 435590.9.peg.1561 | Variant type: 4D
Gene name: Uncharacterized protein BT3327
	I12: 0.112
	I27: 0.826
	I34: 0.115

Bacteroides_vulgatus_57955 | yassour | Subject: M1098-M
Gene id: 435590.9.peg.52 | Variant type: 1D
Gene name: DNA internalization-related competence protein ComEC/Rec2
	M1: 1.000
	I2: 0