### Enrichment calculation function

To calculate enrichment, I will follow some of the general principles used to calculate [DMS differential selection](https://jbloomlab.github.io/dms_tools2/diffsel.html#formula-for-differential-selection).

Conceptually, I calculate the ratio of each codon's frequency in the selected condition to its frequency in the reference condition, then take the log2 of that ratio. I add a pseudocount to each codon count to accommodate codons that are missing from one of the two conditions.

**I am not normalizing codon frequency to the WT codon**, because I want to detect and display strong enrichment of the WT codon in the selected condition.
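To see why this choice matters, here is a minimal sketch with made-up counts for the WT codon at a single site. The variable names are illustrative only, and the pseudocount is left out to keep the arithmetic simple.

```python
import numpy as np

# Made-up counts at one site: the WT codon rises from 50% to 90% of reads.
wt_selected, total_selected = 900, 1000
wt_reference, total_reference = 500, 1000

# Normalizing to the WT codon: the WT codon's own ratio is 1 in both
# conditions, so its log2 value is 0 no matter how the counts change.
wt_normalized = np.log2((wt_selected / wt_selected) /
                        (wt_reference / wt_reference))

# Normalizing to the total codon count at the site: WT enrichment stays visible.
total_normalized = np.log2((wt_selected / total_selected) /
                           (wt_reference / total_reference))

print(wt_normalized)      # 0.0
print(total_normalized)   # positive
```

With WT normalization the WT codon can never appear enriched or depleted, which is the behavior I want to avoid here.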

```
Enrichment = log2( ((codon_count_selected + pseudocount) / total_codon_count_selected) /
                   ((codon_count_reference + pseudocount) / total_codon_count_reference) )
```
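As a quick check of the formula, the sketch below evaluates it for single codons. `log2_enrichment` is a hypothetical helper written only for this illustration, and all counts are invented.

```python
import math

def log2_enrichment(count_selected, total_selected,
                    count_reference, total_reference,
                    pseudocount=0.1):
    """Log2 ratio of pseudocounted codon frequencies (selected / reference)."""
    freq_selected = (count_selected + pseudocount) / total_selected
    freq_reference = (count_reference + pseudocount) / total_reference
    return math.log2(freq_selected / freq_reference)

# Codon seen 40 times in the selected sample but never in the reference:
# without the pseudocount this ratio would divide by zero.
print(log2_enrichment(40, 1000, 0, 1000))   # large positive, finite

# Codon absent from both conditions: the frequencies match, so enrichment is 0.
print(log2_enrichment(0, 1000, 0, 1000))    # 0.0
```

The pseudocount keeps the result finite when a codon is missing from one condition, at the cost of shrinking ratios for low-count codons.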

In [1]:
import pandas as pd
import numpy as np

In [2]:
def calculate_enrichment(input_df,
                         site,
                         wildtype_AA,
                         selected_sample,
                         reference_sample,
                         pseudocount=0.1):
    """Calculate per-codon log2 enrichment at `site`, selected vs. reference.

    Pseudocounted codon counts are divided by each sample's total codon count
    at the site, matching the formula above. `wildtype_AA` is kept so existing
    calls still work, but it is unused because frequencies are not normalized
    to the WT codon.
    """
    sample_names = input_df['name'].unique()
    assert selected_sample in sample_names, \
        f"{selected_sample} not in input dataframe"
    assert reference_sample in sample_names, \
        f"{reference_sample} not in input dataframe"

    # Codon counts at this site in the selected sample; .copy() avoids a
    # SettingWithCopyWarning when the pseudocount column is added.
    selected_freqs = (input_df
                      .query(f'name == "{selected_sample}" and '
                             f'site == {site}')
                      [['codon', 'wildtype', 'letter', 'count']]
                      .copy())
    selected_freqs['count_pseudo'] = selected_freqs['count'] + pseudocount
    selected_total = selected_freqs['count'].sum()

    # Same for the reference sample.
    reference_freqs = (input_df
                       .query(f'name == "{reference_sample}" and '
                              f'site == {site}')
                       [['codon', 'wildtype', 'letter', 'count']]
                       .copy())
    reference_freqs['count_pseudo'] = reference_freqs['count'] + pseudocount
    reference_total = reference_freqs['count'].sum()

    # Outer merge keeps codons observed in only one of the two samples.
    enrichment_df = pd.merge(
        left=selected_freqs,
        right=reference_freqs,
        on=['codon', 'letter'],
        how='outer',
        validate='one_to_one',
        suffixes=('_selected', '_reference'))
    # A codon missing from one sample has a raw count of 0 there, so its
    # pseudocounted count is just the pseudocount.
    enrichment_df['count_pseudo_selected'] = (enrichment_df['count_pseudo_selected']
                                              .fillna(pseudocount))
    enrichment_df['count_pseudo_reference'] = (enrichment_df['count_pseudo_reference']
                                               .fillna(pseudocount))

    # log2 ratio of pseudocounted codon frequencies (selected / reference).
    enrichment_df['enrichment'] = (
        np.log2((enrichment_df['count_pseudo_selected'] / selected_total) /
                (enrichment_df['count_pseudo_reference'] / reference_total)))

    return enrichment_df
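A minimal usage sketch on made-up data, assuming the long-format columns the function reads (`name`, `site`, `codon`, `wildtype`, `letter`, `count`). `site_enrichment` is a hypothetical, self-contained helper that applies the enrichment formula with a pivot table rather than calling `calculate_enrichment`; the sample names `sel` and `ref` and all counts are invented.

```python
import numpy as np
import pandas as pd

# Toy long-format counts for one site. All values are made up.
toy_df = pd.DataFrame({
    'name':     ['sel', 'sel', 'ref', 'ref', 'ref'],
    'site':     [1, 1, 1, 1, 1],
    'codon':    ['ATG', 'CTG', 'ATG', 'CTG', 'GTG'],
    'wildtype': ['M'] * 5,
    'letter':   ['M', 'L', 'M', 'L', 'V'],
    'count':    [900, 100, 500, 300, 200],
})

def site_enrichment(df, site, selected, reference, pseudocount=0.1):
    """Hypothetical helper: log2 enrichment per codon at one site."""
    # Rows are codons, columns are samples; codons missing from a sample get 0.
    counts = (df[df['site'] == site]
              .pivot_table(index=['codon', 'letter'], columns='name',
                           values='count', aggfunc='sum', fill_value=0))
    totals = counts.sum(axis=0)                # total codon count per sample
    freqs = (counts + pseudocount) / totals    # pseudocounted frequencies
    return np.log2(freqs[selected] / freqs[reference]).rename('enrichment')

result = site_enrichment(toy_df, site=1, selected='sel', reference='ref')
print(result)
```

In this toy example the WT codon (`ATG`) comes out positive, `CTG` negative, and `GTG`, which is absent from the selected sample, strongly negative but finite thanks to the pseudocount.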