# Data analysis workflow for the study "Pairs of amino acids at the P- and A-sites of the ribosome predictably and causally modulate translation-elongation rates"

### List of data used in this study

1. Jan, C. H., Williams, C. C., & Weissman, J. S. (2014). Principles of ER cotranslational translocation revealed by proximity-specific ribosome profiling. Science, 346(6210).

2. Williams, C. C., Jan, C. H., & Weissman, J. S. (2014). Targeting and plasticity of mitochondrial proteins revealed by proximity-specific ribosome profiling. Science, 346(6210), 748-751.

3. Young, D. J., Guydosh, N. R., Zhang, F., Hinnebusch, A. G., & Green, R. (2015). Rli1/ABCE1 recycles terminating ribosomes and controls translation reinitiation in 3′ UTRs in vivo. Cell, 162(4), 872-884.

4. Weinberg, D. E., Shah, P., Eichhorn, S. W., Hussmann, J. A., Plotkin, J. B., & Bartel, D. P. (2016). Improved ribosome-footprint and mRNA measurements provide insights into dynamics and regulation of yeast translation. Cell reports, 14(7), 1787-1799.

5. Nissley, D. A., Sharma, A. K., Ahmed, N., Friedrich, U. A., Kramer, G., Bukau, B., & O’Brien, E. P. (2016). Accurate prediction of cellular co-translational folding indicates proteins can switch from post- to co-translational folding. Nature communications, 7(1), 1-13.



In [None]:
## How to run the analysis

In [8]:
from __future__ import division
import matplotlib
import numpy as np
import math
import os
import pickle as Pickle
from optparse import OptionParser
import matplotlib.backends.backend_pdf as pdf
from scipy import stats
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib as mpl
from matplotlib.table import Table
import sys
import statsmodels.sandbox.stats.multicomp as mc
import operator as op
from matplotlib.ticker import FormatStrFormatter
from time import localtime, strftime
import itertools
from matplotlib import rcParams
import matplotlib.pyplot as plt



# All 64 RNA codons: the 61 sense codons followed by the three stop codons (UAA, UAG, UGA).
CODON_TYPES = ['UUU', 'UUC', 'UUA', 'UUG', 'CUU', 'CUC', 'CUA', 'CUG', 'AUU', 'AUC', 'AUA', 'AUG', 'GUU', 'GUC', 'GUA',
               'GUG', 'UCU', 'UCC', 'UCA', 'UCG', 'CCU', 'CCC', 'CCA', 'CCG', 'ACU', 'ACC', 'ACA', 'ACG', 'GCU', 'GCC',
               'GCA', 'GCG', 'UAU', 'UAC', 'CAU', 'CAC', 'CAA', 'CAG', 'AAU', 'AAC', 'AAA', 'AAG', 'GAU', 'GAC', 'GAA',
               'GAG', 'UGU', 'UGC', 'UGG', 'CGU', 'CGC', 'CGA', 'CGG', 'AGU', 'AGC', 'AGA', 'AGG', 'GGU', 'GGC', 'GGA',
               'GGG', 'UAA', 'UAG', 'UGA']

# Standard genetic code: RNA codon -> one-letter amino-acid code ('*' marks the three stop codons).
genetic_code = {'UUU': 'F', 'UCU': 'S', 'UAU': 'Y', 'UGU': 'C', 'UUC': 'F', 'UCC': 'S', 'UAC': 'Y', 'UGC': 'C',
                'UUA': 'L', 'UCA': 'S', 'UAA': '*', 'UGA': '*', 'UUG': 'L', 'UCG': 'S', 'UAG': '*', 'UGG': 'W',
                'CUU': 'L', 'CCU': 'P', 'CAU': 'H', 'CGU': 'R', 'CUC': 'L', 'CCC': 'P', 'CAC': 'H', 'CGC': 'R',
                'CUA': 'L', 'CCA': 'P', 'CAA': 'Q', 'CGA': 'R', 'CUG': 'L', 'CCG': 'P', 'CAG': 'Q', 'CGG': 'R',
                'AUU': 'I', 'ACU': 'T', 'AAU': 'N', 'AGU': 'S', 'AUC': 'I', 'ACC': 'T', 'AAC': 'N', 'AGC': 'S',
                'AUA': 'I', 'ACA': 'T', 'AAA': 'K', 'AGA': 'R', 'AUG': 'M', 'ACG': 'T', 'AAG': 'K', 'AGG': 'R',
                'GUU': 'V', 'GCU': 'A', 'GAU': 'D', 'GGU': 'G', 'GUC': 'V', 'GCC': 'A', 'GAC': 'D', 'GGC': 'G',
                'GUA': 'V', 'GCA': 'A', 'GAA': 'E', 'GGA': 'G', 'GUG': 'V', 'GCG': 'A', 'GAG': 'E', 'GGG': 'G'}

# In the following dict, the synonymous codons of each amino acid are grouped into lists such that every list is
# decoded by the same kind of tRNA.
# For example, for amino acid 'A', GCU and GCC are decoded by one tRNA species while GCA and GCG are decoded by another.
# NOTE(review): the grouping reflects a specific organism's tRNA pool (presumably S. cerevisiae) — confirm before reuse.
synonymous = {'A': [['GCU', 'GCC'], ['GCA', 'GCG']],
              'C': [['UGU', 'UGC']],
              'D': [['GAU', 'GAC']],
              'E': [['GAA'], ['GAG']],
              'F': [['UUU', 'UUC']],
              'G': [['GGU', 'GGC'], ['GGA'], ['GGG']],
              'H': [['CAU', 'CAC']],
              'I': [['AUU', 'AUC'], ['AUA']],
              'K': [['AAG'], ['AAA']],
              'L': [['UUG'], ['UUA'], ['CUC', 'CUU'], ['CUA', 'CUG']],
              'M': [['AUG']],
              'N': [['AAU', 'AAC']],
              'P': [['CCA', 'CCG'], ['CCU', 'CCC']],
              'Q': [['CAA'], ['CAG']],
              'R': [['AGA'], ['CGU', 'CGC'], ['CGG', 'CGA'], ['AGG']],
              'S': [['UCU', 'UCC'], ['AGU', 'AGC'], ['UCA'], ['UCG']],
              'T': [['ACU', 'ACC'], ['ACA'], ['ACG']],
              'V': [['GUU', 'GUC'], ['GUG'], ['GUA']],
              'W': [['UGG']],
              'Y': [['UAU', 'UAC']],
              '*': [['UAA', 'UAG', 'UGA']]
              }

# The 20 standard amino acids plus '*' (stop), in a fixed iteration order.
AMINO_ACIDS = ['A', 'R', 'D', 'N', 'C', 'E', 'Q', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V', '*']

# For each amino acid, its codons split into those read by wobble pairing and those read by Watson-Crick pairing
# with their decoding tRNA. Amino acids with an empty 'Wobble' list have only Watson-Crick-read codons here.
base_pairing = {'A': {'Wobble': ['GCC', 'GCG'], 'Watson-Crick': ['GCA', 'GCU']},
                'C': {'Wobble': ['UGU'], 'Watson-Crick': ['UGC']},
                'D': {'Wobble': ['GAU'], 'Watson-Crick': ['GAC']},
                'E': {'Wobble': [], 'Watson-Crick': ['GAA', 'GAG']},
                'F': {'Wobble': ['UUU'], 'Watson-Crick': ['UUC']},
                'G': {'Wobble': ['GGU'], 'Watson-Crick': ['GGA', 'GGC', 'GGG']},
                'H': {'Wobble': ['CAU'], 'Watson-Crick': ['CAC']},
                'I': {'Wobble': ['AUC'], 'Watson-Crick': ['AUA', 'AUU']},
                'K': {'Wobble': [], 'Watson-Crick': ['AAA', 'AAG']},
                'L': {'Wobble': ['CUG', 'CUU'], 'Watson-Crick': ['CUA', 'CUC', 'UUA', 'UUG']},
                'M': {'Wobble': [], 'Watson-Crick': ['AUG']},
                'N': {'Wobble': ['AAU'], 'Watson-Crick': ['AAC']},
                'P': {'Wobble': ['CCC', 'CCG'], 'Watson-Crick': ['CCA', 'CCU']},
                'Q': {'Wobble': [], 'Watson-Crick': ['CAA', 'CAG']},
                'R': {'Wobble': ['CGA', 'CGC'], 'Watson-Crick': ['AGA', 'AGG', 'CGG', 'CGU']},
                'S': {'Wobble': ['UCC', 'AGU'], 'Watson-Crick': ['UCA', 'UCG', 'UCU', 'AGC']},
                'T': {'Wobble': ['ACC'], 'Watson-Crick': ['ACA', 'ACU', 'ACG']},
                'V': {'Wobble': ['GUC'], 'Watson-Crick': ['GUA', 'GUG', 'GUU']},
                'W': {'Wobble': [], 'Watson-Crick': ['UGG']},
                'Y': {'Wobble': ['UAU'], 'Watson-Crick': ['UAC']},
                '*': {'Wobble': [], 'Watson-Crick': ['UAA', 'UAG', 'UGA']}}

# Optimal codons selected based on their corresponding tRNA abundance (measured by RNA-Seq in Weinberg et al).
# Codons decoded only through wobble pairing are scored as 0.64 * the cognate tRNA concentration.
# Corrected mistake for G. Earlier it was 'G': {'Non-optimal': ['GGC', 'GGG'], 'Optimal': ['GGA', 'GGU']},
# 'M' and 'W' have no non-optimal codons, and there is no entry for stop codons ('*').
optimal_codon_usage = {'A': {'Non-optimal': ['GCC', 'GCG'], 'Optimal': ['GCA', 'GCU']},
                       'C': {'Non-optimal': ['UGU'], 'Optimal': ['UGC']},
                       'D': {'Non-optimal': ['GAU'], 'Optimal': ['GAC']},
                       'E': {'Non-optimal': ['GAG'], 'Optimal': ['GAA']},
                       'F': {'Non-optimal': ['UUU'], 'Optimal': ['UUC']},
                       'G': {'Non-optimal': ['GGA', 'GGG'], 'Optimal': ['GGC', 'GGU']},
                       'H': {'Non-optimal': ['CAU'], 'Optimal': ['CAC']},
                       'I': {'Non-optimal': ['AUA'], 'Optimal': ['AUC', 'AUU']},
                       'K': {'Non-optimal': ['AAA'], 'Optimal': ['AAG']},
                       'L': {'Non-optimal': ['CUA', 'CUC', 'CUG', 'CUU'], 'Optimal': ['UUA', 'UUG']},
                       'M': {'Non-optimal': [], 'Optimal': ['AUG']},
                       'N': {'Non-optimal': ['AAU'], 'Optimal': ['AAC']},
                       'P': {'Non-optimal': ['CCC', 'CCU'], 'Optimal': ['CCA', 'CCG']},
                       'Q': {'Non-optimal': ['CAG'], 'Optimal': ['CAA']},
                       'R': {'Non-optimal': ['AGG', 'CGG', 'CGA', 'CGC'], 'Optimal': ['AGA',  'CGU']},
                       'S': {'Non-optimal': ['UCA', 'UCG', 'AGU', 'AGC'], 'Optimal': ['UCC', 'UCU']},
                       'T': {'Non-optimal': ['ACA', 'ACG'], 'Optimal': ['ACC', 'ACU']},
                       'V': {'Non-optimal': ['GUA', 'GUG'], 'Optimal': ['GUC', 'GUU']},
                       'W': {'Non-optimal': [], 'Optimal': ['UGG']},
                       'Y': {'Non-optimal': ['UAU'], 'Optimal': ['UAC']}}

# Single most optimal codon for every amino acid; UAA is used for stop ('*').
most_optimal_codon = {'A': 'GCU', 'C': 'UGC', 'D': 'GAC', 'E': 'GAA', 'F': 'UUC', 'G': 'GGC', 'H': 'CAC', 'I': 'AUU', 'K': 'AAG', 'L': 'UUG', 'M': 'AUG', 'N': 'AAC', 'P': 'CCA',
                      'Q': 'CAA', 'R': 'AGA', 'S': 'UCU', 'T': 'ACU', 'V': 'GUU', 'W': 'UGG', 'Y': 'UAC', '*': 'UAA'}

# Optimal and non-optimal codons based on the Pechmann & Frydman tAI cutoff of 0.47, the same split used for codon
# optimality in Jeff Coller's paper. Covers the 61 sense codons (22 optimal, 39 non-optimal); stop codons are absent.
optimal_dict = {'Optimal': ['GCU', 'GCC', 'GAC', 'GAA', 'UUC', 'GGC', 'AUU', 'AUC', 'AAG', 'UUG', 'AUG', 'AAC', 'CCA', 'CAA', 'AGA', 'UCU', 'UCC', 'ACU', 'ACC', 'GUU', 'GUC', 'UAC'],
                'Non-optimal': ['GCA', 'GCG', 'UGC', 'UGU', 'GAU', 'GAG', 'UUU', 'GGU', 'GGA', 'GGG', 'CAC', 'CAU', 'AUA', 'AAA', 'UUA', 'CUA', 'CUC', 'CUG', 'CUU', 'AAU', 'CCG',
                                'CCU', 'CCC', 'CAG', 'CGU', 'AGG', 'CGC', 'CGG', 'CGA', 'UCA', 'AGC', 'UCG', 'AGU', 'ACA', 'ACG', 'GUG', 'GUA', 'UGG', 'UAU']}

# Yeast chromosome names chrI-chrXVI plus chrM (presumably the mitochondrial genome), sacCer3-style naming.
CHROMOSOMES = ['chrI', 'chrII', 'chrIII', 'chrIV', 'chrV', 'chrVI', 'chrVII', 'chrVIII', 'chrIX', 'chrX', 'chrXI', 'chrXII', 'chrXIII', 'chrXIV', 'chrXV', 'chrXVI', 'chrM']


### Python functions that are called below to run the data analysis on the different datasets

In [None]:
def parse_sequence(transcript_file, out_loc='/gpfs/group/epo2/default/nxa176/reference/sacCer3/Reference_files/'):
    """Parse a transcriptome table into a dictionary of per-gene codon lists.

    Each tab-separated line of ``transcript_file`` is read as
    ``gene, start_index, cds_length, sequence``. The CDS is
    ``sequence[abs(start_index):abs(start_index) + cds_length]``. Here
    ``abs(start_index)`` is the offset of the CDS inside the stored sequence
    (presumably the 5' UTR length; confirm against the file producer). Every
    'T' becomes 'U', and the CDS is cut into consecutive triplets. A trailing
    partial triplet, if any, is kept as a shorter string.

    Genes whose name starts with 'Q' are skipped. Genes with an empty CDS are
    reported and skipped. Missing AUG start or UAA/UAG/UGA stop codons are only
    reported.

    :param transcript_file: path to the tab-separated transcriptome table.
    :param out_loc: directory in which ``codon_type_dict.p`` is written
        (a trailing path separator is optional).
    :return: dict mapping gene name -> list of RNA codon strings; the same
        dict is pickled to ``<out_loc>/codon_type_dict.p``.
    """
    codon_type_dict = {}
    with open(transcript_file) as f:
        for line in f:
            fields = line.strip().split('\t')
            gene = fields[0]
            if gene.startswith('Q'):
                continue
            offset = abs(int(fields[1]))
            length = int(fields[2])
            # We are looking at mRNA, so convert every 'T' to 'U'.
            cds = fields[3][offset:offset + length].replace('T', 'U')
            codons = [cds[i:i + 3] for i in range(0, len(cds), 3)]
            if not codons:
                # An empty CDS has no start/stop codon to check; report it instead of crashing.
                print('Gene ' + gene + ' has an empty CDS and is skipped')
                continue
            codon_type_dict[gene] = codons

            # Sanity checks for start and stop codons
            if codons[0] != "AUG":
                print(gene + " does not have a AUG start codon and the start codon is " + codons[0])
            if codons[-1] not in ("UAA", "UAG", "UGA"):
                print(gene + " does not have UAA/UAG/UGA stop codon and stop codon is " + codons[-1])

    with open(os.path.join(out_loc, 'codon_type_dict.p'), 'wb') as out:
        Pickle.dump(codon_type_dict, out)

    return codon_type_dict


def parse_dms_codon_level(dms_file):
    """Load per-gene DMS classifier values from a tab-separated file.

    These values are used to judge the in vivo mRNA secondary-structure status
    of individual nucleotides.

    The first line is a header and is discarded. Every following line is read
    as ``gene, <unused column>, comma-separated integer classifiers``. When a
    gene appears more than once, its last entry wins.

    :param dms_file: path to the DMS table.
    :return: dict mapping gene name -> ``map(int, ...)`` over the classifier
        column (a list under Python 2).
    """
    with open(dms_file) as handle:
        handle.readline()  # discard the header row
        rows = (raw.strip().split('\t') for raw in handle)
        return dict((cols[0], map(int, cols[2].split(','))) for cols in rows)


_DEFAULT_OVERLAP_GENES_PICKLE = '/gpfs/group/epo2/default/nxa176/Yeast_Ribo-seq_datasets/yeast_annotation_files/Pickle_dicts/overlap_genes.p'
_DEFAULT_INTRONIC_GENES_PICKLE = '/gpfs/group/epo2/default/nxa176/Yeast_Ribo-seq_datasets/yeast_annotation_files/Pickle_dicts/intronic_genes.p'


def _read_asite_codon_counts(asite_file):
    """Parse an A-site file and collapse reads per nucleotide into reads per codon.

    Each tab-separated line is ``gene, gene_length, comma-separated reads per nucleotide``.
    Genes whose name starts with 'Q' are skipped, after the length quality check.

    :return: ``(codon_dict, dict_len, unique_mapped_reads, total_read_count)``
    """
    codon_dict = {}
    dict_len = {}
    unique_mapped_reads = {}
    total_read_count = 0
    with open(asite_file) as file_asite_table:
        for line in file_asite_table:
            line_list = line.strip().split('\t')
            gene = line_list[0]
            gene_length = int(line_list[1])
            count_list = [int(x) for x in line_list[2].split(',')]
            # Quality check
            if len(count_list) % 3 != 0:
                print('QUALITY CHECK NOT MET: Gene ' + gene + ' have a length not a multiple of 3. The length is ' + str(len(count_list)))
            if gene.startswith('Q'):
                continue
            codon_dict[gene] = [sum(count_list[i:i + 3]) for i in range(0, len(count_list), 3)]
            dict_len[gene] = gene_length
            unique_mapped_reads[gene] = sum(count_list)
            total_read_count += sum(count_list)
    return codon_dict, dict_len, unique_mapped_reads, total_read_count


def _read_mul_mapped_counts(mul_map_file):
    """Return ``{gene: total multiply-mapped reads}`` from a tab-separated ``gene, counts...`` file."""
    mul_mapped_reads = {}
    with open(mul_map_file) as f:
        for line in f:
            line_list = line.strip().split('\t')
            mul_mapped_reads[line_list[0]] = sum(int(x) for x in line_list[1:])
    return mul_mapped_reads


def gene_codon_filter(asite_file, mul_map_file, mul_map_threshold=1.0, filter_threshold=0.1, read_threshold=1, strict=True, window=False, relaxed=False,
                      overlap_genes_file=_DEFAULT_OVERLAP_GENES_PICKLE, intronic_genes_file=_DEFAULT_INTRONIC_GENES_PICKLE):
    """Collapse A-site reads to codons, write per-gene profiles and select well-covered genes.

    ``asite_file`` lines are ``gene, gene_length, comma-separated reads per nucleotide``.
    ``mul_map_file`` lines are ``gene, counts...`` of multiply-mapped reads. Genes whose name
    starts with 'Q' are ignored. A gene is flagged as multi-mapped when more than
    ``mul_map_threshold`` percent of its reads are multiply mapped. Flagged genes, genes in the
    overlap pickle and genes in the intron pickle are never selected.

    Exactly one selection mode is applied. The precedence is the same one used to name the
    output file:

    * ``relaxed``: at most ``ceil(filter_threshold * n_codons)`` codons have zero reads.
    * otherwise ``strict``: every codon after the first two has more than ``read_threshold`` reads.
    * otherwise ``window``: drop the first 15 and last 5 codons, then sum reads over consecutive
      non-overlapping windows of up to 5 codons. The median of those sums must exceed
      ``read_threshold``.

    Files written to the current directory:

    * Summary_stats.tab
    * Codon_reads_all_genes.tab
    * Expression_levels_genes.tab (gene rows only in relaxed/strict mode)
    * the filtered per-codon read file

    :param overlap_genes_file: pickle with the collection of overlapping genes.
    :param intronic_genes_file: pickle with the collection of intron-containing genes.
    :return: name of the filtered per-codon read file.
    """
    with open(overlap_genes_file, 'rb') as f:
        overlap_genes = Pickle.load(f)
    with open(intronic_genes_file, 'rb') as f:
        intronic_genes = Pickle.load(f)

    codon_dict, dict_len, unique_mapped_reads, total_read_count = _read_asite_codon_counts(asite_file)
    print('Parsed the A-site file.')
    # Multiply-mapped read totals decide whether a gene is excluded (more than mul_map_threshold percent).
    mul_mapped_reads = _read_mul_mapped_counts(mul_map_file)

    if relaxed:
        filtered_name = "Codon_reads_filtered_genes_relaxed_" + str((1 - filter_threshold) * 100) + ".tab"
    elif strict:
        filtered_name = "Codon_reads_filtered_genes_strict_threshold_" + str(read_threshold) + ".tab"
    else:
        filtered_name = "Codon_reads_filtered_genes_window_median_greater_than_" + str(read_threshold) + ".tab"

    # Genes with too many multiply-mapped reads (set: only used for membership tests).
    mul_map_genes = set()
    # Count variables to determine the statistics of gene counts in each category
    no_of_genes = 0
    filtered_multistatus = 0
    filtered_1 = 0
    filtered_2 = 0
    filtered_overlap = 0
    filtered_intron = 0

    with open('Summary_stats.tab', 'w') as stats_file, \
            open("Codon_reads_all_genes.tab", 'w') as codon_raw_file, \
            open(filtered_name, 'w') as codon_filtered_file, \
            open("Expression_levels_genes.tab", "w") as exp_file:
        stats_file.write('Number of genes with A-site profiles: ' + str(len(dict_len)) + '\n')
        stats_file.write('Number of genes containing introns: ' + str(len(intronic_genes)) + '\n')
        stats_file.write('Number of genes containing overlaps: ' + str(len(overlap_genes)) + '\n')
        stats_file.write('Number of genes containing multiple aligned reads: ' + str(len(mul_mapped_reads)) + '\n')
        stats_file.write('Number of reads mapped to the yeast transcriptome: ' + str(total_read_count) + '\n')
        exp_file.write('Gene\tLength(codons)\tAverage reads (per codon)\tSum of reads\n')
        codon_raw_file.write('Gene\tLength(codons)\tRaw read profile\n')

        for gene, gene_len in dict_len.items():
            if gene in mul_mapped_reads:
                total_reads = mul_mapped_reads[gene] + unique_mapped_reads[gene]
                if total_reads == 0:
                    print('ZeroDivisionError for mul map calculation for gene ' + str(gene))
                    print(str(mul_mapped_reads[gene]) + ' ' + str(unique_mapped_reads[gene]))
                    continue
                perc_mul_map = float(mul_mapped_reads[gene]) * 100 / float(total_reads)
                if perc_mul_map > mul_map_threshold:
                    mul_map_genes.add(gene)

            codon_counts = codon_dict[gene]
            cod_len = len(codon_counts)
            # Sanity check
            if cod_len != gene_len / 3.0:
                print('Discrepancy in populating codon dicts. Length of codon dict (' + str(cod_len) + ') not equal to one-third of gene length (' + str(gene_len) + ')')

            profile_line = gene + '\t' + str(cod_len) + '\t' + ','.join(map(str, codon_counts)) + '\n'
            codon_raw_file.write(profile_line)
            avg_reads = np.mean(codon_counts)
            sum_reads = np.sum(codon_counts)
            exp_line = gene + '\t' + str(cod_len) + '\t' + str(avg_reads) + '\t' + str(sum_reads) + '\n'

            is_overlap = gene in overlap_genes
            is_intronic = gene in intronic_genes
            is_mul_map = gene in mul_map_genes

            if relaxed:
                # Select genes with at most ceil(filter_threshold * length) zero-read codon positions
                # (default 0.1, i.e. at most 10%).
                if codon_counts.count(0) <= math.ceil(filter_threshold * cod_len) and not is_overlap and not is_intronic:
                    if is_mul_map:
                        filtered_multistatus += 1
                        continue
                    no_of_genes += 1
                    exp_file.write(exp_line)
                    codon_filtered_file.write(profile_line)
            elif strict:
                # The first two codons are excluded: the start codon is expected to have no reads and the
                # second codon's density is influenced by initiation. Every remaining codon needs more
                # than read_threshold reads.
                tail = codon_counts[2:]
                passes = all(v > read_threshold for v in tail)
                if passes and not is_overlap and not is_intronic:
                    if is_mul_map:
                        filtered_multistatus += 1
                        continue
                    no_of_genes += 1
                    exp_file.write(exp_line)
                    codon_filtered_file.write(profile_line)
                elif read_threshold >= 1 and all(v > read_threshold - 1 for v in tail) and not is_overlap and not is_intronic and not is_mul_map:
                    filtered_2 += 1
                elif read_threshold >= 2 and all(v > read_threshold - 2 for v in tail) and not is_overlap and not is_intronic and not is_mul_map:
                    filtered_1 += 1
                elif passes and is_overlap and not is_intronic and not is_mul_map:
                    filtered_overlap += 1
                elif passes and not is_overlap and is_intronic:
                    filtered_intron += 1
            elif window:
                # Excluding first 15 codons and last 5 codons
                read_list = codon_counts[15:-5]
                if not read_list:
                    continue
                window_sums = [np.sum(read_list[i:i + 5]) for i in range(0, len(read_list), 5)]
                if np.median(window_sums) > read_threshold:
                    if is_mul_map:
                        filtered_multistatus += 1
                        continue
                    no_of_genes += 1
                    codon_filtered_file.write(profile_line)

        if relaxed:
            message = "No. of genes with at most " + str(filter_threshold * 100) + "% of codon positions with zero reads: " + str(no_of_genes)
            print(message)
            stats_file.write(message + '\n')
        elif strict:
            print("No. of genes with greater than " + str(read_threshold) + " reads in each codon position is: " + str(no_of_genes))
            stats_file.write("No. of genes with greater than " + str(read_threshold) + " reads in each codon: " + str(no_of_genes) + '\n')
            stats_file.write("No. of genes missed out earlier with greater than " + str(read_threshold - 2) + " reads in each codon: " + str(filtered_1) + '\n')
            stats_file.write("No. of genes missed out earlier with greater than " + str(read_threshold - 1) + " reads in each codon: " + str(filtered_2) + '\n')
            stats_file.write("No. of genes missed out earlier with greater than " + str(read_threshold) + " reads in each codon but overlapping genes: " + str(filtered_overlap) + '\n')
            stats_file.write("No. of genes missed out earlier with greater than " + str(read_threshold) + " reads in each codon but intronic genes: " + str(filtered_intron) + '\n')
            stats_file.write("No. of genes missed out earlier with greater than " + str(read_threshold) + " reads in each codon but contains multiple mapped reads: " + str(
                filtered_multistatus) + '\n')
        elif window:
            print("No. of genes whose median read sum over non-overlapping 5-codon windows is greater than " + str(read_threshold) + ": " + str(no_of_genes))

    return filtered_name


def reads_to_translation_times(codon_file, genelist=''):
    """Convert per-codon read counts into per-codon translation times.

    Each tab-separated line of ``codon_file`` is read as
    ``gene, length, comma-separated reads per codon``; this is the layout of the
    filtered file written by ``gene_codon_filter``. For every gene, a total
    synthesis time of ``number_of_codons * 200`` is distributed over its codons
    in proportion to each codon's share of the gene's reads. All codons are
    converted, including the first two. Genes with zero total reads are
    reported and skipped.

    :param codon_file: per-codon read file (no header line).
    :param genelist: optional path to a file with one gene name per line. When
        given, only those genes are converted; ``times_dict`` consumers may
        then see genes that did not meet the filtering criteria.
    :return: dict mapping gene -> list of per-codon translation times. The same
        data is written to ``Translation_times_profiles.tab`` and pickled to
        ``Translation_times.p`` in the current directory.
    """
    list_of_genes = []
    if genelist:
        with open(genelist) as f:
            list_of_genes = [line.strip() for line in f]
        print('Running analysis for set of ' + str(len(list_of_genes)) + ' genes specified in file ' + str(genelist))
    # Set for O(1) membership tests; the list above is kept only for the reported count.
    wanted_genes = set(list_of_genes)

    translation_times = {}
    with open(codon_file) as f, open("Translation_times_profiles.tab", 'w') as time_file:
        for line in f:
            line_list = line.strip().split('\t')
            gene = line_list[0]
            if genelist and gene not in wanted_genes:
                continue
            count_list = [int(x) for x in line_list[2].split(',')]
            # Total synthesis time assumed to be 200 time units per codon (units not stated here).
            synthesis_time = len(count_list) * 200
            summed_reads = sum(count_list)
            if summed_reads == 0:
                print(gene + ' does not have any mapped reads')
                continue
            ttime = [(float(reads) / float(summed_reads)) * synthesis_time for reads in count_list]
            translation_times[gene] = ttime
            time_file.write(gene + '\t' + str(len(ttime)) + '\t' + ','.join(map(str, ttime)) + '\n')

    with open("Translation_times.p", "wb") as pickle_file:
        Pickle.dump(translation_times, pickle_file)

    return translation_times

### Python functions used to compare normalized ribosome-density distributions and build the A-site/P-site matrix

In [10]:
def psite_asite_matrix(times_dict, codon_type_dict, time=True, do_perc_change=True, genelist=''):
    """Compare per-codon values of every (P-site aa, A-site aa) pair against all other P-site aa.

    For each gene in ``times_dict`` (optionally restricted to the genes listed one per line in
    ``genelist``), every codon position from the third onward is classified by the amino acid of
    its own codon (A-site) and of the preceding codon (P-site), using ``codon_type_dict`` and the
    module-level ``genetic_code``. Positions whose value is not > 0 are ignored. Values are used
    as-is when ``time`` is True, otherwise divided by 200.

    For A-site aa X and P-site aa Y with at least 5 instances, the values of pair Y-X are compared
    with a Mann-Whitney U test against X paired with every other P-site aa except 'P'. The effect
    size is the percent change of the medians (``do_perc_change=True``), or else the percent
    difference relative to the mean of the two medians. P-values are Benjamini-Hochberg adjusted
    across all pairs.

    Files written to the current directory:
    - asite_psite.log
    - stat_measures.tab
    - Asite_Psite_times_coordinates.tab
    - Asite_Psite_matrix_perc_change.tab (or Asite_Psite_matrix_perc_diff.tab)
    - two pickles per tested pair under pickle_dicts/ (the directory must already exist, since
      open() does not create it)
    - R_N_instances.p and RnotN_instances.p for the N (P-site) / R (A-site) pair.

    :return: ``(effect_size, adjusted_pvals, pair_instances)``. The first two are keyed
        ``[asite_aa][psite_aa]``. ``pair_instances[asite_aa][psite_aa]`` is a list of tuples
        ``(value, gene, 1-based A-site codon number, P-site codon, A-site codon)``.
    """
    # If the Matrix is to be calculated using instances from a constant set of genes, that list will be parsed here.
    # Make sure the times_dict passed here contain info for all genes and not just for filtered genes.
    list_of_genes = []
    if genelist:
        with open(genelist) as f:
            for lines in f:
                list_of_genes.append(lines.strip())
        print '\nRunning matrix analysis for set of '+str(len(list_of_genes))+' genes specified in file '+str(genelist)

    # dict_amino_acids[asite_aa] -> list of values (filled below but not used later in this function).
    # dict_aa_class_psite[asite_aa][psite_aa] -> list of instance tuples.
    dict_amino_acids = {}
    dict_aa_class_psite = {}

    for aa in AMINO_ACIDS:
        dict_amino_acids[aa] = []
        dict_aa_class_psite[aa] = {}
        for psite_aa in AMINO_ACIDS:
            # Stop codon cannot be in P-site
            if psite_aa == '*':
                continue
            dict_aa_class_psite[aa][psite_aa] = []

    # NOTE(review): psite_aa_dict / asite_aa_dict get an empty dict per gene below but are never filled or returned.
    psite_aa_dict = {}
    asite_aa_dict = {}
    # NOTE(review): log_file and stat_measure_file (and outf / the pickle handles further below) are never closed explicitly.
    log_file = open('asite_psite.log', 'w')
    stat_measure_file = open('stat_measures.tab', 'w')
    # Get all the aa info by translating codon_type_dict codons to corresponding amino acids
    for gene, dict_time in times_dict.iteritems():
        if list_of_genes and gene not in list_of_genes:
            continue
        psite_aa_dict[gene] = {}
        asite_aa_dict[gene] = {}
        for codon, trans_time in enumerate(dict_time):
            try:
                # Ignoring the first two codons
                if codon in [0, 1]:
                    continue
                # Get the P-site aa for that codon (the preceding codon)
                psite_aa = genetic_code[codon_type_dict[gene][codon-1]]
                # Get the A-site aa for that codon
                asite_aa = genetic_code[codon_type_dict[gene][codon]]
                # dict_amino_acids will have trans time for each amino acid
                # Ignore instances which have zero reads. This will most likely happen when we are using instances from constant set of genes which may not have necessarily met the filtering criteria.
                if trans_time > 0:
                    # Time based on translation time calculation. Otherwise normalized ribosome density will be used.
                    if time:
                        dict_amino_acids[asite_aa].append(float(trans_time))
                    else:
                        dict_amino_acids[asite_aa].append(float(trans_time)/200)
                    # dict_aa_class_psite will have a dict of p-site and t-times for all a-site aa. This is a dictionary initialized before for all combo of aa
                    # dict_aa_class_psite[Asite_AA][P-site_AA] = [trans_time, gene, A-site codon number, P-site codon type, A-site codon type]
                    if time:
                        dict_aa_class_psite[asite_aa][psite_aa].append((float(trans_time), gene, codon+1, codon_type_dict[gene][codon-1], codon_type_dict[gene][codon]))
                    else:
                        dict_aa_class_psite[asite_aa][psite_aa].append((float(trans_time)/200, gene, codon+1, codon_type_dict[gene][codon-1], codon_type_dict[gene][codon]))

            except KeyError:
                # KeyError is raised when any of these hold:
                # - the gene is missing from codon_type_dict;
                # - a codon is missing from genetic_code;
                # - a stop codon sits in the P-site ('*' has no slot in dict_aa_class_psite). In this case
                #   dict_amino_acids was already appended to.
                # An IndexError (codon_type_dict[gene] shorter than dict_time) is not caught here.
                print gene, codon

    # Initializing a dict for metrics to store for each pair of a-site and p-site
    dict_aa_psite_effect_size = {}
    dict_aa_psite_pval = {}
    dict_aa_psite_sample_size = {}

    # outf = open("Mean_skew_stats.tab", 'w')
    # NOTE(review): psite_sig_pairs is defined outside this view. It reads the perc-change matrix file, which this
    # function overwrites further below when do_perc_change is True. So the Significant/Insignificant labels in
    # stat_measures.tab presumably come from a previous run's matrix, which must already exist — confirm this is intended.
    sig_pair_list = psite_sig_pairs('Asite_Psite_matrix_perc_change.tab')

    # Initializing the inner dict for each a-site aa as key
    for aa in AMINO_ACIDS:
        dict_aa_psite_effect_size[aa] = {}
        dict_aa_psite_pval[aa] = {}
        dict_aa_psite_sample_size[aa] = {}

    # For each amino acid in A-site,
    for aa, dict_psite in dict_aa_class_psite.iteritems():
        # for all combinations of aa in P-site and their list of trans time
        for psite_aa, trans_list in dict_psite.iteritems():
            times_list = []
            # the times_list will contain only float values of translation times. trans_list contain many other values like gene name, codon number, codon type etc
            for ttime in trans_list:
                times_list.append(ttime[0])
            # Create the list of trans time for all other amino acids in the P-site excluding the one being compared
            # NOTE(review): proline ('P') in the P-site is always left out of the comparison set as well; the reason is not stated here.
            alt_times_list = []
            for alt_aa, list_trans in dict_psite.iteritems():
                if alt_aa != psite_aa and alt_aa != 'P':
                    for ttime in list_trans:
                        alt_times_list.append(ttime[0])
                else:
                    continue
            log_file.write('A-site Amino acid '+aa+'\n')
            log_file.write('P-site Amino acid '+psite_aa+'\n')
            log_file.write('Length of trans_list'+str(len(times_list))+'\n')
            # Columns: P-site aa, A-site aa, n(pair), n(others), mean, mean(others), median, median(others), skew, skew(others), significance label
            stat_measure_file.write(psite_aa+'\t'+aa+'\t'+str(len(times_list))+'\t'+str(len(alt_times_list))+'\t'+str(np.mean(times_list))+'\t'+str(np.mean(alt_times_list))+'\t'+str(np.median(times_list))+'\t' +
                                    str(np.median(alt_times_list))+'\t'+str(stats.skew(np.asarray(times_list)))+'\t'+str(stats.skew(np.asarray(alt_times_list))))
            if (psite_aa, aa) in sig_pair_list:
                stat_measure_file.write('\tSignificant\n')
            else:
                stat_measure_file.write('\tInsignificant\n')

            if len(times_list) >= 5:
                # NOTE(review): no `alternative` is passed, so the p-value depends on the SciPy version's default
                # (older releases report a one-sided p-value, newer ones two-sided) — pin it if results must be reproducible.
                u, p = stats.mannwhitneyu(times_list, alt_times_list)
                log_file.write('Mann Whitney U test is:\n')
                log_file.write(str(u)+'\t'+str(p)+'\n')
                if aa == 'R' and psite_aa == 'N':
                    print 'Median(N-R)\tMedian(Others)'
                    print np.median(times_list), np.median(alt_times_list)
                    # fig, ax2 = plt.subplots()
                    # sns.distplot(times_list, ax=ax2, kde=True, label=psite_aa + '_' + aa)
                    # sns.distplot(times_list_new, ax=ax2, kde=True, label=psite_new + '_' + aa)
                    # N = max(max(set(times_list)), max(set(alt_times_list)))
                    Pickle.dump(times_list, open(aa + '_' + psite_aa + '_instances.p', 'wb'))
                    Pickle.dump(alt_times_list, open(aa + 'not'+psite_aa+'_instances.p', 'wb'))
                Pickle.dump(times_list, open('pickle_dicts/' + aa + '_' + psite_aa + '_instances.p', 'wb'))
                Pickle.dump(alt_times_list, open('pickle_dicts/' + aa + '_~' + psite_aa + '_instances.p', 'wb'))
                perc_change = ((np.median(times_list) - np.median(alt_times_list))/np.median(alt_times_list))*100
                perc_diff = ((np.median(times_list) - np.median(alt_times_list)) / ((np.median(times_list)+np.median(alt_times_list))/2)) * 100
                # Mostly use perc change as you want to know how much X in P-site causes slowdown/speedup in X-Y pair relative to when X is not present
                if do_perc_change:
                    dict_aa_psite_effect_size[aa][psite_aa] = perc_change
                else:
                    # Use perc_diff when you are comparing two specific AA pairs X-Y and Z-Y
                    dict_aa_psite_effect_size[aa][psite_aa] = perc_diff
                dict_aa_psite_pval[aa][psite_aa] = p
                dict_aa_psite_sample_size[aa][psite_aa] = [len(times_list), len(alt_times_list)]
                # outf.write(psite_aa+'\t'+aa+'\t'+str(np.mean(times_list))+'\t'+str(np.median(times_list))+'\t'+str(stats.skew(times_list)))
            else:
                # Pairs with fewer than 5 instances are not tested: effect size 0, p-value 1.
                dict_aa_psite_effect_size[aa][psite_aa] = 0  # 'Sample_less_than_5'
                dict_aa_psite_pval[aa][psite_aa] = 1  # 'Sample_less_than_5'
                dict_aa_psite_sample_size[aa][psite_aa] = [len(times_list), len(alt_times_list)]

    # One row per instance, sorted by descending value within each (A-site, P-site) pair.
    times_file = open("Asite_Psite_times_coordinates.tab", "w")
    times_file.write("Asite\tPsite\tTranslation_time_Asite_codon\tGene\tAsite_codon_number\tPsite_codon\tAsite_codon\n")
    for aa, dict_psite in dict_aa_class_psite.iteritems():
        for psite, trans_list in dict_psite.iteritems():
            for codon in sorted(trans_list, reverse=True):
                times_file.write(aa+'\t'+psite+'\t'+str(codon[0])+'\t'+str(codon[1])+'\t'+str(codon[2])+'\t'+str(codon[3])+'\t'+codon[4]+'\n')
    times_file.close()

    if do_perc_change:
        outf = open('Asite_Psite_matrix_perc_change.tab', 'w')
    else:
        outf = open('Asite_Psite_matrix_perc_diff.tab', 'w')

    # Benjamini-Hochberg correction. We get all the p-values and pool them together in a list and adjust it
    # NOTE(review): no p-value is ever set to 2 in this function (small samples get 1 above). So this skip never
    # happens, and small-sample pairs take part in the BH correction, making it more conservative. Confirm which
    # sentinel was intended.
    list_of_pval = []
    for aa, data in sorted(dict_aa_psite_effect_size.iteritems()):
        for p_site in sorted(data):
            if dict_aa_psite_pval[aa][p_site] == 2:
                continue
            else:
                list_of_pval.append(dict_aa_psite_pval[aa][p_site])
    hyp_test, pval_adj, alpsidac, alpbonf = mc.multipletests(list_of_pval, method='fdr_bh')   # bonferonni

    dict_aa_psite_pval_adj = {}
    outf.write('A-site AA\tP-site AA\tPercent change\tp-value\tAdjusted p-value\tN(P-A pair)\t(Alt P-A pair)\n')
    # Extract back the adjusted p-values into corresponding cells making sure that the keys are sorted according to when it was put in the list for multiple test correction
    # (this loop must visit keys in the same sorted order as the loop above so that pval_adj[i] lines up)
    i = 0
    for aa, data in sorted(dict_aa_psite_effect_size.iteritems()):
        dict_aa_psite_pval_adj[aa] = {}
        for p_site, perc_change in sorted(data.iteritems()):
            if dict_aa_psite_pval[aa][p_site] == 2:
                dict_aa_psite_pval_adj[aa][p_site] = 1
            else:
                dict_aa_psite_pval_adj[aa][p_site] = pval_adj[i]
                i += 1
            outf.write(aa+'\t'+p_site+'\t'+str(perc_change)+'%\t'+str(dict_aa_psite_pval[aa][p_site])+'\t'+str(dict_aa_psite_pval_adj[aa][p_site])+'\t' +
                       str(dict_aa_psite_sample_size[aa][p_site][0]) + '\t'+str(dict_aa_psite_sample_size[aa][p_site][1])+'\n')

    return dict_aa_psite_effect_size, dict_aa_psite_pval_adj, dict_aa_class_psite


def parse_psite_matrix_info(matrix_file):
    """Parse an A-site/P-site matrix table into lookup dictionaries.

    The input is a tab-separated table with one header line and the columns
    A-site AA, P-site AA, percent change (with a trailing '%'), p-value,
    adjusted p-value, N(P-A pair), N(Alt P-A pair).

    :param matrix_file: path to the matrix table.
    :returns: tuple ``(dict_pairs, dict_perc_change, dict_pval)``:
        * dict_pairs -- key is the P-site AA followed by the A-site AA (e.g. 'GA'),
          value is ``[perc_change, adj_pval, sample_size, compared_size]``;
        * dict_perc_change -- ``{asite_aa: {psite_aa: perc_change}}``;
        * dict_pval -- ``{asite_aa: {psite_aa: adj_pval}}``.
        An empty or header-only file yields three empty dicts.
    """
    dict_pairs = {}
    dict_perc_change = {}
    dict_pval = {}
    with open(matrix_file) as f:
        # Skip the header; next() with a default works on Python 2 and 3 and
        # does not raise StopIteration on an empty file.
        next(f, None)
        for line in f:
            # Blank lines (e.g. a trailing newline) carry no data
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            asite_aa, psite_aa = fields[0], fields[1]
            perc = float(fields[2].split('%')[0])
            adj_p_val = float(fields[4])
            samp_size = int(fields[5])
            compared_size = int(fields[6])
            # Pair key is the P-site AA followed by the A-site AA
            dict_pairs[psite_aa + asite_aa] = [perc, adj_p_val, samp_size, compared_size]
            dict_perc_change.setdefault(asite_aa, {})[psite_aa] = perc
            dict_pval.setdefault(asite_aa, {})[psite_aa] = adj_p_val

    return dict_pairs, dict_perc_change, dict_pval


def _classify_robust_pair(percs, pvals, threshold, alpha=0.05):
    """Classify one amino-acid pair as 'Slow', 'Fast', 'Mixed' or 'Not Robust'.

    ``percs`` and ``pvals`` are the pair's per-dataset percent changes and adjusted
    p-values, in the same dataset order. A positive percent change counts as slow,
    a negative one as fast.
    """
    significant_percs = [perc for perc, pval in zip(percs, pvals) if pval < alpha]
    # The pair must be significant (p < alpha) in at least `threshold` datasets
    if len(significant_percs) < threshold:
        return 'Not Robust'
    # Same direction of speed change in every dataset
    if all(p > 0 for p in percs):
        return 'Slow'
    if all(p < 0 for p in percs):
        return 'Fast'
    # Directions disagree. If at least `threshold` datasets (significant or not) share a
    # direction, let only the significant datasets decide; otherwise use all of them.
    if sum(p > 0 for p in percs) >= threshold or sum(p < 0 for p in percs) >= threshold:
        considered = significant_percs
    else:
        considered = percs
    if all(p > 0 for p in considered):
        return 'Slow'
    if all(p < 0 for p in considered):
        return 'Fast'
    return 'Mixed'


def _compare_with_uncontrolled(uncontrolled_file, robust_pairs, control_factor):
    """Tally how the Fast/Slow pairs of an earlier robust-pairs table are classified now.

    ``uncontrolled_file`` is read as a tab-separated table with one header line whose
    first two columns are the pair and its type (presumably a Robust_aminoacid_pairs
    table from an uncontrolled run -- confirm against the caller).
    Returns ``(fast_consistent, fast_insignificant, slow_consistent, slow_insignificant,
    inconsistent)``.
    """
    fast_consistent = fast_insignificant = 0
    slow_consistent = slow_insignificant = 0
    inconsistent = 0
    with open(uncontrolled_file) as f:
        next(f, None)  # skip the header
        for line in f:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            pair, pair_type = fields[0], fields[1]
            if pair_type == 'Fast' and pair in robust_pairs['Fast']:
                fast_consistent += 1
            elif pair_type == 'Slow' and pair in robust_pairs['Slow']:
                slow_consistent += 1
            # A flipped direction must be tested before the "insignificant" fall-throughs,
            # otherwise these branches are unreachable and flips are miscounted.
            elif pair_type == 'Fast' and pair in robust_pairs['Slow']:
                inconsistent += 1
                print(pair + ' is Slow when controlled for '+control_factor+' but fast when not controlled')
            elif pair_type == 'Slow' and pair in robust_pairs['Fast']:
                inconsistent += 1
                print(pair + ' is Fast when controlled for '+control_factor+' but slow when not controlled')
            elif pair_type == 'Fast':
                fast_insignificant += 1
            elif pair_type == 'Slow':
                slow_insignificant += 1
            else:
                # Rows that are neither Fast nor Slow in the reference table also count as inconsistent
                inconsistent += 1
    return fast_consistent, fast_insignificant, slow_consistent, slow_insignificant, inconsistent


def measure_robustness_of_pairs(dict_of_pairs, datsets, threshold=4, control_factor='', compare_uncontrolled='', condition=''):
    """Merge per-dataset A-/P-site pair statistics and classify pairs by robustness.

    :param dict_of_pairs: ``{dataset: {pair: [perc_change, adj_pval, sample_size, compared_size]}}``
        (the ``dict_pairs`` layout produced by parse_psite_matrix_info, one per dataset).
    :param datsets: keys of ``dict_of_pairs`` to merge, in output-column order.
    :param threshold: minimum number of datasets in which a pair must have adjusted p < 0.05.
    :param control_factor: optional label used in output file names and as the plot title.
    :param compare_uncontrolled: optional path to a previously written robust-pairs table;
        when given, counts how many of its Fast/Slow pairs keep, lose or flip their class.
    :param condition: optional suffix for output file names.
    :returns: ``(matrix_plot, dict_robust_perc, dict_robust_pval)``; the dicts are keyed
        ``{asite_aa: {psite_aa: value}}`` and hold the mean percent change and a constant
        p-value (0 for Fast/Slow pairs, 0.5 for Mixed/Not Robust pairs).

    Side effects: writes the merged-statistics table, the robust-pairs table and the
    control-factor stats file to the working directory, and draws the matrix via
    plot_asite_psite_matrix.
    """
    stat_keys = ('perc', 'adj_pval', 'samp_size', 'other_size')
    # Populating a dict with all the statistics: one entry per dataset, in `datsets` order.
    # NOTE(review): a pair absent from some datasets gets shorter lists, so its output
    # columns no longer line up with `datsets`.
    pair_stats = {}
    for data_set in datsets:
        for pair in dict_of_pairs[data_set]:
            if pair not in pair_stats:
                pair_stats[pair] = {'perc': [], 'adj_pval': [], 'samp_size': [], 'other_size': []}
            for idx, key in enumerate(stat_keys):
                pair_stats[pair][key].append(dict_of_pairs[data_set][pair][idx])

    if control_factor:
        merged_name = 'Merged_pair_statistics_'+control_factor+condition+'.tab'
        stat_name = 'Control_factor_stats_'+control_factor+'.tab'
        robust_name = 'Robust_aminoacid_pairs_' + control_factor+'_'+str(threshold) + condition + '_datasets.tab'
    else:
        merged_name = 'Merged_pairs_statistics'+condition+'.tab'
        stat_name = 'Control_factor_stats_' + condition + '.tab'
        robust_name = 'Robust_aminoacid_pairs_'+str(threshold)+condition+'_datasets.tab'

    # Putting all the stats together for all datasets
    with open(merged_name, 'w') as outfile:
        for pair, stats_for_pair in pair_stats.items():
            columns = [pair] + stats_for_pair['perc'] + stats_for_pair['adj_pval'] + \
                stats_for_pair['samp_size'] + stats_for_pair['other_size']
            outfile.write('\t'.join(map(str, columns)) + '\n')

    robust_pairs = {'Fast': [], 'Slow': [], 'Mixed': [], 'Not Robust': []}
    dict_robust_perc = {}
    dict_robust_pval = {}
    for pair, stats_for_pair in pair_stats.items():
        percs = stats_for_pair['perc']
        pair_type = _classify_robust_pair(percs, stats_for_pair['adj_pval'], threshold)
        robust_pairs[pair_type].append(pair)
        # The matrix is keyed by the 2nd character (A-site AA) then the 1st (P-site AA).
        # Constant p-values drive the plot colouring: 0 = robust, 0.5 = not robust.
        dict_robust_perc.setdefault(pair[1], {})[pair[0]] = np.mean(percs)
        dict_robust_pval.setdefault(pair[1], {})[pair[0]] = 0 if pair_type in ('Fast', 'Slow') else 0.5

    with open(robust_name, 'w') as outf:
        # One header column per data column (the former extra '\t' shifted the p-value names)
        header = ['Pair', 'Pair_type', 'Avg(Perc change)']
        header += [' Perc change('+str(data_set)+')' for data_set in datsets]
        header += [' p-value('+str(data_set)+')' for data_set in datsets]
        outf.write('\t'.join(header) + '\n')
        for pair_type in robust_pairs:
            for pair in robust_pairs[pair_type]:
                stats_for_pair = pair_stats[pair]
                columns = [pair, pair_type, np.mean(stats_for_pair['perc'])] + \
                    stats_for_pair['perc'] + stats_for_pair['adj_pval']
                outf.write('\t'.join(map(str, columns)) + '\n')

    print('Length of the dict robust perc is '+str(len(dict_robust_perc)))
    print('Length of the dict robust pvalue is '+str(len(dict_robust_pval)))

    # The stats file is always (re)created, even when there is nothing to compare
    with open(stat_name, 'w') as statfile:
        if compare_uncontrolled:
            (fast_consistent, fast_insignificant, slow_consistent,
             slow_insignificant, inconsistent) = _compare_with_uncontrolled(compare_uncontrolled, robust_pairs, control_factor)
            fast_total = fast_consistent + fast_insignificant
            slow_total = slow_consistent + slow_insignificant
            # Avoid ZeroDivisionError when the reference table has no Fast (or Slow) pairs
            perc_cons = float(fast_consistent) * 100 / float(fast_total) if fast_total else float('nan')
            perc_cons_slow = float(slow_consistent) * 100 / float(slow_total) if slow_total else float('nan')
            summary = [control_factor, fast_consistent, fast_insignificant, perc_cons,
                       slow_consistent, slow_insignificant, perc_cons_slow, inconsistent]
            print('Control factor\tConsistent\tInconsistent\n')
            print(' '.join(map(str, summary)))
            statfile.write('\t'.join(map(str, summary)))

    matrix_plot, legend_plot = plot_asite_psite_matrix(dict_robust_perc, dict_robust_pval, control_factor)
    return matrix_plot, dict_robust_perc, dict_robust_pval


# PLOT THE MATRIX FIGURE 1B
def plot_asite_psite_matrix(dict_asite_psite_metric, dict_asite_psite_adj_pval, title, txt=False, title_fontsize=18):
    """Draw the A-site x P-site amino-acid matrix (Figure 1B) and log its summary counts.

    :param dict_asite_psite_metric: ``{asite_aa: {psite_aa: percent_change}}``.
    :param dict_asite_psite_adj_pval: ``{asite_aa: {psite_aa: adjusted_pvalue}}``; a pair
        missing here is drawn as value 0 with p-value 1.
    :param title: figure title.
    :param txt: if True, print each cell's value inside the cell.
    :param title_fontsize: font size of the title.
    :returns: ``(figure, legend_figure)``.

    Side effects: saves 'Matrix.png' and appends pair counts to 'Summary_stats.tab'.
    """
    # Converting the dict of dicts into list of lists to be converted to a numpy matrix.
    # Rows are reverse-alphabetical (Y ... A); '*' sorts last but is placed in the
    # pre-reserved first row so it appears above Y.
    list_of_lists = [[]]
    list_of_pval_adj = [[]]
    for aa, data in sorted(dict_asite_psite_metric.items(), reverse=True):
        psite_list = []
        psite_pval_adj_list = []
        for p_site, perc_change in sorted(data.items()):
            # The stop codon cannot be in the P-site, so it is not a column
            if p_site == '*':
                continue
            try:
                pval = dict_asite_psite_adj_pval[aa][p_site]
            except KeyError:
                # No p-value for this pair: draw an empty cell (value 0, p = 1). The value
                # must be replaced, not appended, so the row keeps one entry per column.
                perc_change, pval = 0, 1
            psite_list.append(perc_change)
            psite_pval_adj_list.append(pval)
        if aa == '*':
            list_of_lists[0] = psite_list
            list_of_pval_adj[0] = psite_pval_adj_list
        else:
            list_of_lists.append(psite_list)
            list_of_pval_adj.append(psite_pval_adj_list)

    ap_matrix_pval = np.array(list_of_pval_adj)
    ap_matrix_new = pd.DataFrame(np.array(list_of_lists),
                                 index=['*', 'Y', 'W', 'V', 'T', 'S', 'R', 'Q', 'P', 'N', 'M', 'L', 'K', 'I', 'H', 'G', 'F', 'E', 'D', 'C', 'A'],
                                 columns=['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I',  'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T',  'V', 'W', 'Y'])

    figure, pairs_stats = checkerboard_table(ap_matrix_new, ap_matrix_pval, txt=txt)
    # plt.xlabel/ylabel act on the current axes, i.e. the matrix axes just created
    plt.xlabel('P-site Amino Acid', fontsize=18)
    plt.ylabel('A-site Amino Acid', fontsize=18)
    figure2 = generate_legend_for_matrix()
    figure.suptitle(title, fontsize=title_fontsize)
    figure.savefig('Matrix.png', dpi=300)

    # Append the matrix stats; `with` guarantees the file is closed
    with open('Summary_stats.tab', 'a') as stats_file:
        stats_file.write('\n\nSTATISTICS FOR AMINO ACIDS PAIRS\n')
        stats_file.write('Total number of amino acid pairs: '+str(pairs_stats['total_pairs'])+'\n')
        stats_file.write('Total number of fast amino acid pairs: ' + str(pairs_stats['fast_sig_pair']) + '\n')
        stats_file.write('Total number of slow amino acid pairs: ' + str(pairs_stats['slow_sig_pair']) + '\n')
        stats_file.write('Total number of insignificant amino acid pairs: '+str(pairs_stats['insig_pair'])+'\n')
        stats_file.write('Total number of amino acid pairs with less than 5 instances: ' + str(pairs_stats['insufficient']) + '\n')
    return figure, figure2


def checkerboard_table(data, pval, fmt='{:.2f}', txt=False, xlabel='P-site Amino Acid', ylabel='A-site Amino Acid'):
    """Render a colour-coded A-site/P-site matrix as a matplotlib table on a new figure.

    :param data: DataFrame of percent changes; its index and columns become the row and
        column labels. Any shape is supported: the column labels are placed one blank
        row below the last data row.
    :param pval: 2-D array of adjusted p-values with the same shape as ``data``.
        p < 0.05 -> coloured by sign and magnitude of the change (blue/green = faster,
        gold to maroon = slower); p == 1 -> grey, counted as 'insufficient';
        any other p -> silver, counted as 'insig_pair'.
    :param fmt: format string for cell text when ``txt`` is True.
    :param txt: if True, write the formatted value into every cell.
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    :returns: ``(fig, pairs_stats)``, where pairs_stats counts total, fast/slow significant,
        insignificant and insufficient cells.

    Side effect: saves the still-empty figure to 'Original.png'.
    """
    pairs_stats = {'total_pairs': 0, 'fast_sig_pair': 0, 'slow_sig_pair': 0, 'insig_pair': 0, 'insufficient': 0}
    # Colour bins checked from the most extreme value inwards:
    # faster pairs use `val < bound`, slower pairs use `val > bound`.
    fast_bins = ((-50, 'blue'), (-25, 'green'), (-10, 'mediumseagreen'), (0, 'lightgreen'))
    slow_bins = ((100, 'maroon'), (75, 'red'), (50, 'tomato'), (25, 'orange'), (0, 'gold'))

    def significant_color(val):
        """Colour for a significant cell; 'grey' when the change is exactly 0 (or NaN)."""
        for bound, color in fast_bins:
            if val < bound:
                return color
        for bound, color in slow_bins:
            if val > bound:
                return color
        return 'grey'

    fig, ax = plt.subplots()
    fig.savefig('Original.png')

    # Hide spines, ticks and grid: the table cells carry the labels
    plt.setp(ax.spines.values(), visible=False)
    ax.tick_params(left=False, labelleft=False, labelbottom=False, bottom=False, direction='out', length=4, width=1)
    ax.grid(False)
    tb = Table(ax, bbox=[0, 0, 1, 1])
    mpl.rcParams['grid.linewidth'] = 0.5
    nrows, ncols = data.shape
    # Cells are square: both dimensions use the column width
    width = 1.0 / ncols
    # Add cells
    for (i, j), val in np.ndenumerate(data):
        pairs_stats['total_pairs'] += 1
        if pval[i][j] < 0.05:
            if val < 0:
                pairs_stats['fast_sig_pair'] += 1
            else:
                pairs_stats['slow_sig_pair'] += 1
            color = significant_color(val)
        elif pval[i][j] == 1:
            color = 'grey'
            pairs_stats['insufficient'] += 1
        else:
            color = 'silver'
            pairs_stats['insig_pair'] += 1
        if txt:
            tb.add_cell(i, j, width, width, text=fmt.format(val), loc='center', facecolor=color)
        else:
            tb.add_cell(i, j, width, width, loc='center', facecolor=color)
    # Thin borders for the data cells only (labels are added afterwards)
    for key, cell in tb.get_celld().items():
        cell.set_linewidth(0.5)
    # Row Labels...
    for i, label in enumerate(data.index):
        tb.add_cell(i, -1, width, width, text=label, loc='right', edgecolor='none', facecolor='none')
    # Column Labels, one blank row below the last data row (row 22 for the 21-row matrix)
    for j, label in enumerate(data.columns):
        tb.add_cell(nrows + 1, j, width, width, text=label, loc='center', edgecolor='none', facecolor='none')
    tb.set_fontsize(8)
    ax.add_table(tb)
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.yaxis.set_label_coords(-0.05, 0.5)
    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    ax.set_aspect(abs(x1-x0)/abs(y1-y0))

    return fig, pairs_stats


def generate_legend_for_matrix():
    """Build the colour-key figure that accompanies the A-site/P-site matrix plot.

    The swatches sit on a 21 x 20 table grid. Rows 0 and 2 each hold a swatch in columns
    0, 4, 8, 12 and 16, and column 0 holds one swatch on every even row from 4 to 20.
    Every other cell is blank (white). Swatches with p = 0.01 are coloured by their
    example percent change; the p-values 2 and 3 are drawn silver.

    Side effect: saves the still-empty figure to 'Original.png' before drawing.

    :returns: the legend figure.
    """
    fig, ax = plt.subplots()
    fig.savefig('Original.png')
    ax.set_xlabel('P-site Amino Acid', fontsize=18)
    ax.set_ylabel('A-site Amino Acid', fontsize=18)
    ax.yaxis.set_label_coords(-0.05, 0.5)
    plt.setp(ax.spines.values(), visible=False)
    ax.tick_params(left=False, labelleft=False, labelbottom=False, bottom=False, direction='out', length=4, width=1)
    ax.patch.set_visible(False)
    ax.set_frame_on(False)
    ax.grid(False)
    tb = Table(ax, bbox=[0, 0, 1, 1])

    # Example percent changes and p-values laid out on the legend grid
    legend_values = np.zeros((21, 20), dtype=int)
    legend_pvals = np.ones((21, 20))
    swatch_cols = [0, 4, 8, 12, 16]
    legend_values[0, swatch_cols] = [-55, -20, 8, 60, 103]
    legend_values[2, swatch_cols] = [-44, -8, 30, 80, 0]
    legend_values[0, 19] = 1
    legend_values[2, 19] = 1
    legend_values[4:17:2, 0] = [-20, -8, 8, 30, 60, 80, 103]
    legend_values[20, :] = 1
    legend_pvals[0, swatch_cols] = 0.01
    legend_pvals[2, swatch_cols[:4]] = 0.01
    legend_pvals[2, 16] = 2
    legend_pvals[4:17:2, 0] = 0.01
    legend_pvals[18, 0] = 2
    legend_pvals[20, 0] = 3

    # Colour bins checked from the most extreme value inwards
    fast_bins = ((-50, 'blue'), (-25, 'green'), (-10, 'mediumseagreen'), (0, 'lightgreen'))
    slow_bins = ((100, 'maroon'), (75, 'red'), (50, 'tomato'), (25, 'orange'), (0, 'gold'))

    cell_size = 1.0 / legend_values.shape[1]
    for (row, col), value in np.ndenumerate(legend_values):
        p = legend_pvals[row][col]
        if p < 0.05:
            colour = 'white'
            for bound, candidate in fast_bins:
                if value < bound:
                    colour = candidate
                    break
            else:
                for bound, candidate in slow_bins:
                    if value > bound:
                        colour = candidate
                        break
        elif p == 1:
            colour = 'white'
        else:
            colour = 'silver'
        # Edge and face share the colour so the swatches have no visible borders
        tb.add_cell(row, col, cell_size, cell_size, loc='center', facecolor=colour, edgecolor=colour)

    tb.set_fontsize(14)
    ax.add_table(tb)
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    ax.set_aspect(abs(x_hi - x_lo) / abs(y_hi - y_lo))

    return fig


# R2 response for more stats related to skew of the distributions
def psite_sig_pairs(infile, alpha=0.05):
    """Return the (P-site AA, A-site AA) pairs whose adjusted p-value is below ``alpha``.

    :param infile: tab-separated A-site/P-site matrix table with one header line;
        column 0 is the A-site AA, column 1 the P-site AA and column 4 the adjusted p-value.
    :param alpha: significance cut-off (strict ``<``); defaults to 0.05 as before.
    :returns: list of ``(psite_aa, asite_aa)`` tuples in file order; empty for an empty file.
    """
    sig_pairs_list = []
    with open(infile) as f:
        # Skip the header; portable across Python 2/3 and safe on an empty file
        next(f, None)
        for line in f:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            psite_aa = fields[1]
            asite_aa = fields[0]
            if float(fields[4]) < alpha:
                sig_pairs_list.append((psite_aa, asite_aa))
    return sig_pairs_list


# PLOT THE MATRIX subplot for all 6 datasets  (FIGURE S1)
def plot_figS1_subplot_matrix_datasets(dict_plots):
    """Draw the six per-dataset A-site/P-site matrices as a 3x2 panel figure (Figure S1).

    :param dict_plots: maps 'Williams', 'Jan', 'Nissley1', 'Nissley2', 'Weinberg' and
        'Young' to a sequence whose first two items are passed to get_matrix_dataframe
        (presumably the percent-change and adjusted p-value dicts -- confirm at the caller).
    :returns: the matplotlib figure. The function previously returned None, so callers
        that ignore the result are unaffected.

    Side effects: sets ``rcParams['axes.titlepad']`` globally and saves 'FigureS1.png'.
    """
    # Creating a figure object
    fig = plt.figure(figsize=(5, 7))

    # Reducing the distance by half (default is 6) between the title and the subplot
    rcParams['axes.titlepad'] = 3

    # Per panel: subplot position, dict_plots key, title, panel letter, the letter's x/y
    # position as fractions of the axis limits (hand-tuned per panel), and whether to draw
    # the x-axis label (bottom row only). Every panel keeps its y label to avoid dead space.
    panels = (
        (321, 'Williams', 'Williams', 'a', -0.12, 1.025, False),
        (322, 'Jan', 'Jan', 'b', -0.1225, 1.0125, False),
        (323, 'Nissley1', 'Nissley Replicate 1', 'c', -0.12, 1.0275, False),
        (324, 'Nissley2', 'Nissley Replicate 2', 'd', -0.12, 1.015, False),
        (325, 'Weinberg', 'Weinberg', 'e', -0.12, 1.025, True),
        (326, 'Young', 'Young', 'f', -0.12, 1.0135, True),
    )
    # Panel c used to be created with adjustable='box-forced'. Matplotlib deprecated that
    # value in 2.2 and later removed it, and these axes are not shared, so the default
    # adjustable is used for every panel.
    for position, key, title, letter, x_frac, y_frac, show_xlabel in panels:
        ax = fig.add_subplot(position)
        ap_matrix, ap_matrix_pval = get_matrix_dataframe(dict_plots[key][0], dict_plots[key][1])

        ax, _ = checkerboard_table_subplot(ap_matrix, ap_matrix_pval, ax=ax, xlabel=show_xlabel)
        ax.tick_params(width=1, length=4, axis='both', which='major', labelsize=6, pad=2)
        ax.set_title(title, fontsize=8)
        ax.text(x_frac * ax.get_xlim()[1], y_frac * ax.get_ylim()[1], letter, fontsize=10, fontweight='bold')

    # Save this figure explicitly rather than whatever pyplot's current figure is
    fig.savefig('FigureS1.png', dpi=600, pad_inches=0)
    return fig


def checkerboard_table_subplot(data, pval, ax, xlabel_txt='P-site Amino Acid', ylabel_txt='A-site Amino Acid', xlabel=True, ylabel=True):
    """Render a colour-coded A-site/P-site matrix as a table on an existing axes.

    Intended for one panel of a multi-panel figure (small fonts, no cell text).

    :param data: DataFrame of percent changes; its index and columns become the row and
        column labels. Any shape is supported: the column labels are placed one blank
        row below the last data row.
    :param pval: 2-D array of adjusted p-values with the same shape as ``data``.
        p < 0.05 -> coloured by sign and magnitude of the change (blue/green = faster,
        gold to maroon = slower); p == 1 -> grey, counted as 'insufficient';
        any other p -> silver, counted as 'insig_pair'.
    :param ax: matplotlib axes to draw into.
    :param xlabel_txt: x-axis label text.
    :param ylabel_txt: y-axis label text.
    :param xlabel: draw the x-axis label.
    :param ylabel: draw the y-axis label.
    :returns: ``(ax, pairs_stats)``, where pairs_stats counts total, fast/slow significant,
        insignificant and insufficient cells.
    """
    pairs_stats = {'total_pairs': 0, 'fast_sig_pair': 0, 'slow_sig_pair': 0, 'insig_pair': 0, 'insufficient': 0}
    # Colour bins checked from the most extreme value inwards:
    # faster pairs use `val < bound`, slower pairs use `val > bound`.
    fast_bins = ((-50, 'blue'), (-25, 'green'), (-10, 'mediumseagreen'), (0, 'lightgreen'))
    slow_bins = ((100, 'maroon'), (75, 'red'), (50, 'tomato'), (25, 'orange'), (0, 'gold'))

    def significant_color(val):
        """Colour for a significant cell; 'grey' when the change is exactly 0 (or NaN)."""
        for bound, color in fast_bins:
            if val < bound:
                return color
        for bound, color in slow_bins:
            if val > bound:
                return color
        return 'grey'

    # Hide spines, ticks and grid: the table cells carry the labels
    plt.setp(ax.spines.values(), visible=False)
    ax.tick_params(left=False, labelleft=False, labelbottom=False, bottom=False, direction='out', length=4, width=1)
    ax.grid(False)
    tb = Table(ax, bbox=[0, 0, 1, 1])
    mpl.rcParams['grid.linewidth'] = 0.5
    nrows, ncols = data.shape
    # Cells are square: both dimensions use the column width
    width = 1.0 / ncols
    # Add cells
    for (i, j), val in np.ndenumerate(data):
        pairs_stats['total_pairs'] += 1
        if pval[i][j] < 0.05:
            if val < 0:
                pairs_stats['fast_sig_pair'] += 1
            else:
                pairs_stats['slow_sig_pair'] += 1
            color = significant_color(val)
        elif pval[i][j] == 1:
            color = 'grey'
            pairs_stats['insufficient'] += 1
        else:
            color = 'silver'
            pairs_stats['insig_pair'] += 1

        tb.add_cell(i, j, width, width, loc='center', facecolor=color)
    # Thin borders for the data cells only (labels are added afterwards)
    for key, cell in tb.get_celld().items():
        cell.set_linewidth(0.5)
    # Row Labels...
    for i, label in enumerate(data.index):
        tb.add_cell(i, -1, width, width, text=label, loc='right', edgecolor='none', facecolor='none')
    # Column Labels, one blank row below the last data row (row 22 for the 21-row matrix)
    for j, label in enumerate(data.columns):
        tb.add_cell(nrows + 1, j, width, width, text=label, loc='left', edgecolor='none', facecolor='none')
    tb.set_fontsize(8)
    ax.add_table(tb)
    if xlabel:
        ax.set_xlabel(xlabel_txt, fontsize=8)
    if ylabel:
        ax.set_ylabel(ylabel_txt, fontsize=8)
    ax.yaxis.set_label_coords(-0.05, 0.5)
    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    ax.set_aspect(abs(x1-x0)/abs(y1-y0))

    return ax, pairs_stats


def compare_individual_aa_pairs(dict_aa_psite, min_dist_length=5, perc_diff=True, plot_perc=True):
    """Compare translation times of P-site amino acids that share an A-site amino acid.

    For every A-site amino acid ``aa`` in ``dict_aa_psite`` and every ordered pair of
    distinct, non-stop P-site amino acids (``psite_aa``, ``psite_new``), the first
    elements of the instances in ``dict_aa_psite[aa][psite_aa]`` and
    ``dict_aa_psite[aa][psite_new]`` are compared with a Mann-Whitney U test when both
    groups hold at least ``min_dist_length`` values. All tested p-values are then
    adjusted together with the Benjamini-Hochberg procedure.

    Parameters
    ----------
    dict_aa_psite : dict
        ``{A-site aa: {P-site aa: [instance, ...]}}``; ``instance[0]`` is the
        translation time (the remaining fields are ignored here).
    min_dist_length : int
        Minimum number of instances required in BOTH groups for a test.
    perc_diff : bool
        True: effect size is the absolute difference of the medians as a percentage
        of their mean. False: signed change relative to the ``psite_aa`` median.
    plot_perc : bool
        Also draw the effect-size vs. adjusted-p-value scatter plot from the table
        written by this call.

    Side effects
    ------------
    Writes the comparison table (``Asite_Psite_perc_diff_new_Psite.tab`` or
    ``Asite_Psite_perc_change_new_Psite.tab``) and a matching log file, appends to
    ``Summary_stats.tab``, saves ``Odds_of_speed_change_significant.png`` and, for a
    few hard-coded Fig 1B pairs, pickles the raw values under ``pickle_dicts/`` and
    plots their distributions.
    """
    # Stored instead of a p-value when one of the two groups is too small to test.
    insufficient_pval = 2
    # (A-site, P-site, new P-site) triples printed/pickled/plotted; ('R', 'N', 'S') is Fig 1B.
    fig1b_cases = [('R', 'N', 'S'), ('Q', 'W', 'D'), ('T', 'W', 'D'), ('T', 'M', 'G')]
    # Triples whose statistics are additionally reported in Summary_stats.tab.
    candidates = {('G', 'P', 'E'), ('T', 'G', 'S'), ('G', 'G', 'S'), ('R', 'N', 'S'), ('G', 'D', 'F'),
                  ('R', 'S', 'N'), ('T', 'S', 'G'), ('G', 'S', 'G'), ('D', 'Q', 'P'), ('D', 'E', 'P')}

    if not os.path.exists('pickle_dicts/'):
        os.makedirs('pickle_dicts/')

    # The effect size of the difference between two distributions is either a percent
    # difference (symmetric, no reference pair) or a percent change (relative to psite_aa).
    if perc_diff:
        table_path = 'Asite_Psite_perc_diff_new_Psite.tab'
        log_path = 'asite_psite_perc_difference.log'
    else:
        table_path = 'Asite_Psite_perc_change_new_Psite.tab'
        log_path = 'asite_psite_perc_change.log'

    # dict_effect_size[aa][psite_aa][psite_new] =
    #   [effect_size, n(psite_aa), n(psite_new), var(psite_aa), var(psite_new), odds,
    #    median(psite_new) - median(psite_aa)]
    dict_effect_size = {}
    dict_pvalues = {}
    for aa in AMINO_ACIDS:
        dict_effect_size[aa] = {}
        dict_pvalues[aa] = {}
        for psite in AMINO_ACIDS:
            # Stop codon cannot be in P-site
            if psite != '*':
                dict_effect_size[aa][psite] = {}
                dict_pvalues[aa][psite] = {}

    palette_colors = [(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), (0.3333333333333333, 0.6588235294117647, 0.40784313725490196),
                      (0.8666666666666667, 0.5176470588235295, 0.3215686274509804)]
    sns.set()
    sns.set_palette(palette_colors)

    with open(log_path, 'w') as log_file:
        for aa, dict_psite in dict_aa_psite.items():
            for psite_aa, trans_list in dict_psite.items():
                # A stop codon in the P-site is impossible
                if psite_aa == '*':
                    continue
                times_list = [instance[0] for instance in trans_list]
                log_file.write('A-site Amino acid ' + aa + '\n')
                log_file.write('P-site Amino acid ' + psite_aa + '\n')
                log_file.write('Length of trans_list' + str(len(times_list)) + '\n')
                orig_large_enough = len(times_list) >= min_dist_length

                for psite_new, trans_list_new in dict_psite.items():
                    if psite_new == psite_aa or psite_new == '*':
                        continue
                    times_list_new = [instance[0] for instance in trans_list_new]

                    if orig_large_enough and len(times_list_new) >= min_dist_length:
                        # NOTE(review): the p-value depends on SciPy's default `alternative`
                        # for mannwhitneyu, which differs between SciPy versions — confirm.
                        u, p = stats.mannwhitneyu(times_list, times_list_new)
                        med_orig = np.median(times_list)
                        med_new = np.median(times_list_new)

                        if (aa, psite_aa, psite_new) in fig1b_cases:
                            print('Median(' + psite_aa + '-' + aa + ')\tMedian(' + psite_new + '-' + aa + ')')
                            print('%s %s' % (med_orig, med_new))
                            with open('pickle_dicts/' + aa + '_' + psite_aa + '_instances.p', 'wb') as fh:
                                Pickle.dump(times_list, fh)
                            with open('pickle_dicts/' + aa + '_' + psite_new + '_instances.p', 'wb') as fh:
                                Pickle.dump(times_list_new, fh)
                            plot_trans_distribution(aa, psite_aa, psite_new, 'pickle_dicts/')

                        var_orig = np.var(times_list)
                        var_mut = np.var(times_list_new)
                        log_file.write('Mann Whitney U test with ' + psite_new + ' in the P-site and ' + aa + ' in the A-site is:\n')
                        log_file.write(str(u) + '\t' + str(p) + '\n')

                        if perc_diff:
                            effect_size = math.fabs(med_new - med_orig) * 100 / ((med_orig + med_new) / 2)
                        else:
                            effect_size = ((med_new - med_orig) / med_orig) * 100

                        # Odds that a random pair of instances changes in the same direction as the medians.
                        odds = odds_speed_change_aa(times_list, times_list_new)
                        if odds == -1:
                            print('Zero division error for odds for comparison of ' + psite_aa + '_' + aa + ' with ' + psite_new + '_' + aa)

                        dict_effect_size[aa][psite_aa][psite_new] = [effect_size, len(times_list), len(times_list_new),
                                                                     var_orig, var_mut, odds, med_new - med_orig]
                        dict_pvalues[aa][psite_aa][psite_new] = p
                        continue

                    # At least one group is too small: record descriptive values only.
                    var_orig = np.var(times_list)
                    var_mut = np.var(times_list_new)
                    odds = odds_speed_change_aa(times_list, times_list_new)
                    if odds == -1:
                        small_group = 'new' if orig_large_enough else 'old'
                        print('Zero division error for odds for comparison of ' + psite_aa + '_' + aa + ' with ' + psite_new + '_' + aa +
                              '. Sample size of ' + small_group + ' psite is less than ' + str(min_dist_length))
                    dict_effect_size[aa][psite_aa][psite_new] = [0, len(times_list), len(times_list_new), var_orig, var_mut, odds, 0]
                    dict_pvalues[aa][psite_aa][psite_new] = insufficient_pval

    # One fixed ordering is used both to collect the p-values and to map the adjusted
    # values back, so pval_adj[k] always belongs to the k-th tested triple.
    ordered_keys = [(aa, p_site, psite_new)
                    for aa in sorted(dict_effect_size)
                    for p_site in sorted(dict_effect_size[aa])
                    for psite_new in sorted(dict_effect_size[aa][p_site])]

    # Benjamini-Hochberg correction over all tested pairs pooled together.
    list_of_pval = [dict_pvalues[aa][p_site][psite_new] for aa, p_site, psite_new in ordered_keys
                    if dict_pvalues[aa][p_site][psite_new] != insufficient_pval]
    if list_of_pval:
        pval_adj = mc.multipletests(list_of_pval, method='fdr_bh')[1]
    else:
        pval_adj = []
    adjusted_iter = iter(pval_adj)

    pairs_stats = {'total_pairs': 0, 'fast_sig_pair': 0, 'slow_sig_pair': 0, 'insig_pair': 0, 'insufficient': 0}
    odds_list = []

    print('Writing output file for P-site new psite comparisons')
    with open(table_path, 'w') as outf, open('Summary_stats.tab', 'a') as stats_file:
        outf.write('A-site\tP-site\tNew P-site\tPercent Difference\tOdds\tInstances(P-site)\tInstances(New P-site)\tVariance(P-site)\tVariance(New P-site)\tp-value\t'
                   'adjusted p-value\tSignificance\n')
        for aa, p_site, psite_new in ordered_keys:
            effect_size, len_orig, len_mut, var_orig, var_mut, odds, median_delta = dict_effect_size[aa][p_site][psite_new]
            pval = dict_pvalues[aa][p_site][psite_new]
            pairs_stats['total_pairs'] += 1
            if pval == insufficient_pval:
                pval_adjusted = 1
                pairs_stats['insufficient'] += 1
                significance = 'Not Significant'
            else:
                pval_adjusted = next(adjusted_iter)
                if pval_adjusted > 0.05:
                    pairs_stats['insig_pair'] += 1
                    significance = 'Not Significant'
                else:
                    significance = 'Significant'
                    # Direction comes from the medians because the percent difference
                    # (perc_diff=True) is an absolute value and carries no sign.
                    if median_delta < 0:
                        pairs_stats['fast_sig_pair'] += 1
                    else:
                        pairs_stats['slow_sig_pair'] += 1
                    odds_list.append(odds)
                if (aa, p_site, psite_new) in candidates:
                    stats_file.write(aa + '\t' + p_site + '\t' + psite_new + '\t' + str(effect_size) + '%\t' + str(pval) + '\t' + str(pval_adjusted) + '\n')
            outf.write(aa + '\t' + p_site + '\t' + psite_new + '\t' + str(effect_size) + '%\t' + str(odds) + '\t' + str(len_orig) + '\t' + str(len_mut) + '\t' +
                       str(var_orig) + '\t' + str(var_mut) + '\t' + str(pval) + '\t' + str(pval_adjusted) + '\t' + significance + '\n')

    # Odds of translation speed change for significant pairs (Fig 1D); the 10 largest
    # odds are dropped as outliers.
    fig, ax3 = plt.subplots()
    binsize = np.arange(0, 8, 0.25)
    ax3 = sns.distplot(sorted(odds_list)[:-10], ax=ax3, kde=False, label="Significant pairs", norm_hist=True, bins=binsize, hist_kws=dict(edgecolor="black", linewidth=1))
    ax3.tick_params(direction='out', axis='both', width=1, length=4, which='major', labelsize=16, pad=2, bottom=True, left=True)
    ax3.set_xlabel('Odds of translation rate change', fontsize=16)
    ax3.set_ylabel('Probability Density', fontsize=16)
    fig.tight_layout()
    fig.savefig('Odds_of_speed_change_significant.png', dpi=300)
    plt.close(fig)

    if plot_perc:
        # Plot the table that this call actually wrote (depends on perc_diff).
        plot_perc_change_aa(table_path, perc_diff=perc_diff)

    with open('Summary_stats.tab', 'a') as stats_file:
        stats_file.write('\n\nSTATISTICS FOR CHANGE IN TRNASLATION SPEED BETWEEN AMINO ACIDS PAIRS\n')
        stats_file.write('Total number of amino acid mutation pairs: ' + str(pairs_stats['total_pairs']) + '\n')
        stats_file.write('Total number of fast amino acid mutating pairs: ' + str(pairs_stats['fast_sig_pair']) + '\n')
        stats_file.write('Total number of slow amino acid mutating pairs: ' + str(pairs_stats['slow_sig_pair']) + '\n')
        stats_file.write('Total number of insignificant amino acid mutating  pairs: ' + str(pairs_stats['insig_pair']) + '\n')
        stats_file.write('Total number of amino acid mutating pairs with less than ' + str(min_dist_length) + ' instances: ' + str(pairs_stats['insufficient']) + '\n')


# Plotting the normalized ribosome density distributions for two pairs of amino acids
def plot_trans_distribution(aa, psite_aa, psite_new, infolder, plot_both=True):
    """Plot normalized ribosome density distributions for one or two amino-acid pairs.

    Loads the pickled value lists ``<infolder><aa>_<psite_aa>_instances.p`` and
    ``<infolder><aa>_<psite_new>_instances.p`` (both are always loaded). It draws a
    normalized histogram plus KDE for the ``{psite_aa-aa}`` pair and, if
    ``plot_both``, overlays the ``{psite_new-aa}`` pair. The figure is saved as
    ``<aa>_<psite_aa>_<psite_new>.png``, or ``<aa>_<psite_aa>.png`` when only one
    pair is drawn.

    ``infolder`` is concatenated directly with the file names, so it must end with a
    path separator (the caller in this file passes ``'pickle_dicts/'``).
    """
    palette_colors = [(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), (0.3333333333333333, 0.6588235294117647, 0.40784313725490196)]
    try:
        plt.style.use('seaborn-white')
    except (IOError, OSError):
        # Matplotlib >= 3.6 renamed the bundled seaborn styles; the old name is gone in 3.8.
        plt.style.use('seaborn-v0_8-white')
    sns.set_palette(palette_colors)

    with open(infolder + aa + '_' + psite_aa + '_instances.p', 'rb') as fh:
        times_list = Pickle.load(fh)
    with open(infolder + aa + '_' + psite_new + '_instances.p', 'rb') as fh:
        times_list_new = Pickle.load(fh)

    fig, ax2 = plt.subplots()
    # Upper end of the shared bin grid: the largest value of the first pair and of the
    # second pair after discarding its 10 largest values. If the second pair has 10 or
    # fewer values, only the first pair is used (the old max() of an empty set crashed).
    upper = max(list(times_list) + sorted(times_list_new)[:-10])
    binsize = np.arange(0, upper + 1, 0.25)
    ax2 = sns.distplot(times_list, ax=ax2, kde=True, norm_hist=True, label='{' + psite_aa + '-' + aa + '}', bins=binsize, hist_kws=dict(edgecolor="black", linewidth=1))
    print('Plotting trans distributions with changes')
    # Optionally overlay the distribution of the second amino-acid pair.
    if plot_both:
        ax2 = sns.distplot(times_list_new, ax=ax2, kde=True, norm_hist=True, label='{' + psite_new + '-' + aa + '}', bins=binsize, hist_kws=dict(edgecolor="black", linewidth=1))
    ax2.set_xlim(-0.8, 7)
    ax2.legend(fontsize=16, loc='upper right')
    ax2.tick_params(direction='out', axis='both', width=1, length=4, which='major', labelsize=16, pad=2)
    ax2.set_xlabel('Normalized Ribosome Density', fontsize=16)
    ax2.set_ylabel('Density', fontsize=16)
    if plot_both:
        fig.savefig(aa + '_' + psite_aa + '_' + psite_new + '.png', bbox_inches='tight', dpi=300)
    else:
        fig.savefig(aa + '_' + psite_aa + '.png', bbox_inches='tight', dpi=300)
    # Release the figure; this function is called repeatedly from a loop.
    plt.close(fig)


def plot_perc_change_aa(perc_file, perc_diff=True, outf='Perc_difference_of_P-site.png', compare_with_uncontrolled=False, white=False):
    """Scatter-plot effect size against -ln(adjusted p-value) for amino-acid pair comparisons.

    Parameters
    ----------
    perc_file : str
        Tab-separated table with columns ``Percent Difference`` (values such as
        ``'12.3%'``) and ``adjusted p-value``. It also needs ``Significance``, or
        ``Comparison with uncontrolled`` when ``compare_with_uncontrolled`` is True.
    perc_diff : bool
        Selects the x-axis label (percent difference vs. percent change).
    outf : str
        Output image path.
    compare_with_uncontrolled : bool
        Colour points by the ``Comparison with uncontrolled`` column instead of
        ``Significance``.
    white : bool
        Use the plain white seaborn style instead of the default seaborn theme.
    """
    # Apply the style before creating the figure so that the figure uses it.
    if white:
        try:
            plt.style.use('seaborn-white')
        except (IOError, OSError):
            # Matplotlib >= 3.6 renamed the bundled seaborn styles; the old name is gone in 3.8.
            plt.style.use('seaborn-v0_8-white')
    else:
        sns.set()
    palette_colors = [(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), (0.8666666666666667, 0.5176470588235295, 0.3215686274509804),
                      (0.5490196078431373, 0.5490196078431373, 0.5490196078431373)]
    sns.set_palette(palette_colors)

    plot_data = pd.read_table(perc_file)
    fig, ax1 = plt.subplots()

    list_perc_change = [float(x.split('%')[0]) for x in plot_data['Percent Difference']]
    list_pvals = [-np.log(val) for val in plot_data['adjusted p-value']]
    if compare_with_uncontrolled:
        wobble_label = 'Not Significant after filtering out Wobble decoding pairs'
        relabel = {
            '(Insufficient sample size) Not Significant after controlling for confounding factors': wobble_label,
            'Not Significant after controlling for confounding factor': wobble_label,
        }
        list_sig = [relabel.get(val, val) for val in plot_data['Comparison with uncontrolled']]
        order_hue = ['Significant', 'Not Significant', 'Not Significant after filtering out Wobble decoding pairs']
    else:
        list_sig = list(plot_data['Significance'])
        order_hue = ['Significant', 'Not Significant']
    print('%d %d %d' % (len(list_perc_change), len(list_pvals), len(list_sig)))

    # x/y passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
    ax1 = sns.scatterplot(x=list_perc_change, y=list_pvals, hue=list_sig, hue_order=order_hue, ax=ax1, s=5, linewidth=0.05)
    ax1.set_ylabel('-log (p-value)', fontsize=14)
    if perc_diff:
        ax1.set_xlabel('% difference in median ' + r'$\rho$', fontsize=14)
    else:
        ax1.set_xlabel('% change in translation rate', fontsize=14)
    ax1.tick_params(width=1, length=4, axis='x', which='major', labelsize=9, pad=2, direction='out', bottom=True)
    ax1.tick_params(width=1, length=2, axis='y', which='major', labelsize=9, pad=1, direction='out', left=True)
    # 2.9957 == -ln(0.05): the significance cut-off on the natural-log y axis.
    ax1.axhline(y=2.9957, linestyle='--', color='black')
    ax1.legend(fontsize=10, loc='upper left')
    ax1.xaxis.set_ticks(np.arange(0, 145, 10))
    ax1.yaxis.set_ticks(np.arange(0, 121, 20))
    fig.savefig(outf, dpi=300)
    plt.close(fig)


def odds_speed_change_aa(times1, times2):
    """Return the odds that instance pairs change in the same direction as the medians.

    Every value of ``times2`` is compared with every value of ``times1``. If the
    percent change of the median from ``times1`` to ``times2`` is positive, the
    result is (#pairs with times2 > times1) / (#pairs with times2 < times1).
    Otherwise it is the inverse ratio. Tied pairs and pairs involving NaN count
    towards neither side.

    The pair counts come from binary search in a sorted copy of ``times1``. This
    gives exactly the counts of enumerating all ``len(times1) * len(times2)``
    differences, without building that list.

    Parameters
    ----------
    times1, times2 : sequence of float
        Values of the reference group and of the compared group.

    Returns
    -------
    float
        The odds, or -1 when the denominator is zero (including when either
        input is empty).
    """
    arr1 = np.asarray(times1, dtype=float)
    arr2 = np.asarray(times2, dtype=float)
    if arr1.size == 0 or arr2.size == 0:
        return -1

    # Direction of the median change. A zero reference median gives +/-inf or NaN;
    # NaN falls through to the "decrease" branch.
    with np.errstate(divide='ignore', invalid='ignore'):
        perc_change = ((np.median(arr2) - np.median(arr1)) / np.median(arr1)) * 100

    # NaN compares False both ways, so drop it before counting.
    sorted1 = np.sort(arr1[~np.isnan(arr1)])
    arr2 = arr2[~np.isnan(arr2)]
    # side='left' counts times1 values strictly below each times2 value (times2 > times1);
    # size - side='right' counts values strictly above it (times2 < times1).
    n_increase = int(np.searchsorted(sorted1, arr2, side='left').sum())
    n_decrease = int((sorted1.size - np.searchsorted(sorted1, arr2, side='right')).sum())

    if perc_change > 0:
        numerator, denominator = n_increase, n_decrease
    else:
        numerator, denominator = n_decrease, n_increase
    if denominator == 0:
        return -1
    return float(numerator) / float(denominator)


def bin_norm_dens_plot_aa_psite_freq(times_dict, sig_pairs_file, codon_type_dict):
    """Tabulate P-site amino-acid frequencies within decile bins of per-codon values.

    For every gene in ``times_dict`` and every codon index >= 2 whose value is
    positive, the value divided by 200 is recorded together with its P-site amino
    acid. The P-site amino acid is ``genetic_code[codon_type_dict[gene][codon - 1]]``
    and the A-site amino acid is ``genetic_code[codon_type_dict[gene][codon]]``.
    Instances whose codons cannot be translated, or whose amino acids are not valid
    A-/P-site amino acids (stop codons excluded from the P-site), are reported as
    ``gene codon`` on stdout and skipped.

    The recorded values are split at their 0th, 10th, ..., 100th percentiles into
    ten bins; the last bin includes the maximum value. Two files are written:
    ``Percentile_ranges_Psite_aa_frequencies.tab`` holds rows of
    ``bin <TAB> P-site aa <TAB> percent of the bin``, and
    ``Total_AA_prob_across_gene_subset.tab`` holds the overall percentage of each
    non-stop amino acid in ``AMINO_ACIDS`` among the recorded P-site residues.

    Raises
    ------
    ValueError
        If no codon with a positive value is found.
    """
    from collections import Counter

    # Pairs flagged 'Fast'/'Slow' in sig_pairs_file.
    # NOTE(review): parsed (so a missing or malformed file still fails here) but not
    # currently used to filter anything below.
    sig_pairs = set()
    with open(sig_pairs_file) as s:
        for line in s:
            fields = line.strip().split('\t')
            if fields[1] in ('Fast', 'Slow'):
                sig_pairs.add(fields[0])

    valid_asite = set(AMINO_ACIDS)
    # Stop codon cannot be in P-site
    valid_psite = valid_asite - set(['*'])

    # (value, P-site amino acid) for every retained codon.
    instances = []
    for gene, dict_time in times_dict.items():
        for codon, trans_time in enumerate(dict_time):
            # The first two codons are skipped.
            if codon < 2:
                continue
            try:
                psite_aa = genetic_code[codon_type_dict[gene][codon - 1]]
                asite_aa = genetic_code[codon_type_dict[gene][codon]]
            except KeyError:
                print('%s %s' % (gene, codon))
                continue
            # Ignore instances with zero reads, e.g. from a constant gene set whose
            # genes did not necessarily meet the filtering criteria.
            if not trans_time > 0:
                continue
            if asite_aa not in valid_asite or psite_aa not in valid_psite:
                print('%s %s' % (gene, codon))
                continue
            # NOTE(review): the 1/200 scaling is carried over unchanged; its origin is not shown here.
            instances.append((float(trans_time) / 200, psite_aa))

    if not instances:
        raise ValueError('No codon with a positive value was found; cannot compute percentile bins.')

    values = np.asarray([value for value, _ in instances])
    edges = np.asarray([np.percentile(values, q) for q in range(0, 110, 10)])
    n_bins = len(edges) - 1
    bin_labels = [str(q) + '-' + str(q + 10) for q in range(0, 100, 10)]
    # Bin k holds edges[k] <= value < edges[k + 1]. The upper edge of the last bin is
    # inclusive, so the maximum value is counted as well.
    bin_idx = np.minimum(np.searchsorted(edges, values, side='right') - 1, n_bins - 1)
    bin_counts = [Counter() for _ in range(n_bins)]
    for (_, psite_aa), idx in zip(instances, bin_idx):
        bin_counts[idx][psite_aa] += 1

    print('Length of psite_bin is ' + str(len(bin_labels)))
    print('Psite bin contains the following keys: ' + str(bin_labels))

    with open("Percentile_ranges_Psite_aa_frequencies.tab", "w") as outf:
        for label, counts in zip(bin_labels, bin_counts):
            bin_total = sum(counts.values())
            for psite_aa in sorted(counts):
                psite_freq = counts[psite_aa] * 100.0 / bin_total
                outf.write(label + '\t' + str(psite_aa) + '\t' + str(psite_freq) + '\n')

    total_counts = Counter(psite_aa for _, psite_aa in instances)
    n_total = len(instances)
    with open('Total_AA_prob_across_gene_subset.tab', 'w') as aa_prob:
        for aa in AMINO_ACIDS:
            if aa == '*':
                continue
            aa_freq = total_counts[aa] * 100.0 / n_total
            aa_prob.write(aa + '\t' + str(aa_freq) + '\n')


def plot_lineplot(infile, prob_file):
    """Plot binned amino-acid frequencies per percentile range with reference lines.

    Parameters
    ----------
    infile : str
        Tab-separated table with a header containing ``Perc_range``, ``freq`` and
        ``AA``. Only rows for the amino acids S, V, E, D, N and P are plotted.
    prob_file : str
        Tab-separated ``aa <TAB> frequency`` lines. Each plotted amino acid's value is
        drawn as a horizontal line in the same colour as its line plot. Blank lines
        are ignored.

    Saves ``Binned_plot_AA_freq.png``.
    """
    # Single source of truth for which amino acids are plotted and their colour order.
    hue_order = ['S', 'V', 'E', 'D', 'N', 'P']

    dict_prob = {}
    with open(prob_file) as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            dict_prob[fields[0]] = float(fields[1])

    full_data = pd.read_table(infile, sep='\t')
    plot_data = full_data[full_data['AA'].isin(hue_order)]
    fig = plt.figure()
    palette_vals = sns.color_palette()
    ax = sns.lineplot(x='Perc_range', y='freq', data=plot_data, hue='AA', hue_order=hue_order, style='AA', dashes=False, markers=True)
    ax.set_xlabel('Percentile range', fontsize=7)
    ax.set_ylabel('Frequency (%)', fontsize=7)
    ax.tick_params(width=1, length=4, axis='x', which='major', labelsize=6, left=True, bottom=True)
    ax.tick_params(width=1, length=4, axis='y', which='major', labelsize=6, pad=2)
    # lineplot assigns palette colours in hue_order, so colour k belongs to hue_order[k].
    for colour, aa in zip(palette_vals, hue_order):
        ax.axhline(dict_prob[aa], color=colour, linestyle='-')
    plt.xticks(rotation=45)
    # Shrink the axis by 20% to leave room for the legend on the right.
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    fig.savefig('Binned_plot_AA_freq.png', bbox_inches='tight', dpi=300)
    plt.close(fig)


SyntaxError: invalid syntax (<ipython-input-10-d64749e2b55e>, line 9)