# Notebook for estimating probabilistic fitness

## Import packages

In [19]:
import pandas as pd
import sys
import os
import json

In [2]:
# Adding module folder to system paths
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
from modules import probfit

## Import cladewise mutations table

In [4]:
muts_by_clade = pd.read_csv('../results/mut_counts_by_clade.csv', low_memory=False)

In [5]:
muts_by_clade

Unnamed: 0,clade,nt_site,nt_mutation,exclude,masked_in_usher,expected_count,actual_count,clade_founder_nt,gene,clade_founder_codon,...,mut_type,mut_class,pre_omicron_or_omicron,nt_site_before_boundary,ss_prediction,unpaired,motif,ref_motif,predicted_count,tau_squared
0,20A,1,A1C,True,True,0.96873,0,A,noncoding,noncoding,...,AC,noncoding,pre_omicron,False,unpaired,1,AAT,,1.738421,0.433274
1,20A,1,A1G,True,True,3.60910,0,A,noncoding,noncoding,...,AG,noncoding,pre_omicron,False,unpaired,1,AAT,,6.308562,0.736113
2,20A,1,A1T,True,True,1.27820,0,A,noncoding,noncoding,...,AT,noncoding,pre_omicron,True,unpaired,1,AAT,,0.629081,0.635531
3,20A,2,T2A,True,True,0.90342,0,T,noncoding,noncoding,...,TA,noncoding,pre_omicron,False,unpaired,1,ATT,ATT,0.205190,0.520788
4,20A,2,T2C,True,True,3.58890,0,T,noncoding,noncoding,...,TC,noncoding,pre_omicron,False,unpaired,1,ATT,ATT,6.277356,0.719985
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2422138,23I,29902,A29902G,True,True,3.66450,0,A,noncoding,noncoding,...,AG,noncoding,omicron,False,nd,0,AAA,AAA,2.526564,0.983941
2422139,23I,29902,A29902T,True,True,1.28520,0,A,noncoding,noncoding,...,AT,noncoding,omicron,False,nd,0,AAA,AAA,1.421710,0.640234
2422140,23I,29903,A29903C,True,True,0.41880,0,A,noncoding,noncoding,...,AC,noncoding,omicron,False,nd,0,AAA,,0.559207,0.341904
2422141,23I,29903,A29903G,True,True,3.66450,0,A,noncoding,noncoding,...,AG,noncoding,omicron,False,nd,0,AAA,,2.526564,0.983941


Ignore sites that are annotated as being masked in any clade of the UShER tree (`masked_in_usher == True`), or are annotated for exclusion (`exclude == True`)

In [6]:
muts_by_clade = muts_by_clade.query('not exclude').query('not masked_in_usher')

## Clades grouping

Here it is possible to define a customized grouping of clades. The statistics associated to the grouped clades will be aggregated, namely:
* **Predicted counts**
* **Actual counts**
* **Residual variance** of the synonymous counts model.
  
A separate table containing fitness effects and uncertainties of nucleotide mutations will be defined for each group.

The clade clustering is encoded into the `clades_cluster` dictionary, which is defined from:
* `clusters`: an array whose elements are arrays of clades to be grouped together.
* `clust_names`: an array containing reference names for each group.

By default, we aggregate together:
* *Ancestral*: 20A, 20B, 20C, 20E, 20G, 21C.
* *BA.1*: 21L, 22C.
* *22A-B*: 22A, 22B.
* *BA.2.75*: 22D, 23C.
* *XBB*: 22F, 23A, 23B, 23D, 23E, 23F.
* *Late*: 23I, 24A, 24B, 24C, 24E (so far only 23I is included in the tree).

In [7]:
clades = muts_by_clade.clade.unique()

In [8]:
clades

array(['20A', '20B', '20C', '20E', '20G', '20H', '20I', '20J', '21C',
       '21I', '21J', '21K', '21L', '22A', '22B', '22C', '22D', '22E',
       '22F', '23A', '23B', '23C', '23D', '23E', '23F', '23H', '23I'],
      dtype=object)

In [12]:
clusters = [['20A', '20B', '20C', '20E', '20G', '21C'],
            ['20H'],
            ['20I'],
            ['20J'],
            ['21I'],
            ['21J'],
            ['21K'],
            ['21L', '22C'],
            ['22A', '22B'],
            ['22D', '23C'],
            ['22F', '23A', '23B', '23D', '23E', '23F', '23H'],
            ['23I']]

In [13]:
clust_names = ['ancestral', '20H', '20I', '20J', '21I', '21J', '21K', 'BA.2', '22A-B', 'BA.2.75', 'XBB', 'late']

In [21]:
assert len(clusters) == len(clust_names)

12


In [17]:
clade_cluster = dict(zip(clust_names, clusters))

In [18]:
clade_cluster

{'ancestral': ['20A', '20B', '20C', '20E', '20G', '21C'],
 '20H': ['20H'],
 '20I': ['20I'],
 '20J': ['20J'],
 '21I': ['21I'],
 '21J': ['21J'],
 '21K': ['21K'],
 'BA.2': ['21L', '22C'],
 '22A-B': ['22A', '22B'],
 'BA.2.75': ['22D', '23C'],
 'XBB': ['22F', '23A', '23B', '23D', '23E', '23F', '23H'],
 'late': ['23I']}

In [20]:
# Convert `clade_cluster` dictionary to JSON and write to file
with open("../results/clade_clustering.json", "w") as outfile: 
    json.dump(clade_cluster, outfile)

## Compute probabilistic fitness estimates

In [19]:
group_cols = ['nt_mutation', 'gene', 'codon_site', 'aa_mutation', 'synonymous', 'noncoding']

For demonstrative purposes we only include here a subset of clusters for fitness determination

In [23]:
keys_to_include = {'ancestral', '21J', '21K'}
clade_cluster_subset = dict(filter(lambda item: item[0] in keys_to_include, clade_cluster.items()))
print(clade_cluster_subset)


{'ancestral': ['20A', '20B', '20C', '20E', '20G', '21C'], '21J': ['21J'], '21K': ['21K']}


In [24]:
# Cycling over dictionary entries
for key, val in clade_cluster_subset.items():
    muts_by_clade_cluster = (muts_by_clade
        .query("clade.isin(@val)")              # Selecting clades
        .groupby(group_cols, as_index=False)    # Columns not be aggregated
        .aggregate(                             # Aggregating counts
            expected_count = pd.NamedAgg('expected_count', 'sum'),
            predicted_count = pd.NamedAgg('predicted_count', 'sum'),
            actual_count = pd.NamedAgg('actual_count', 'sum'),
            tau_squared = pd.NamedAgg('tau_squared', 'sum')
        )
    )
    muts_by_clade_cluster.insert(0, 'nt_site', muts_by_clade_cluster['nt_mutation'].apply(lambda x: int(x[1:-1])))
    muts_by_clade_cluster.insert(0, 'cluster', key)
    muts_by_clade_cluster = muts_by_clade_cluster.sort_values('nt_site').reset_index(drop=True) # Ordering by nucleotide site
    # Add probabilistic fitness estimates to dataframe
    probfit.add_probabilistic_estimates(muts_by_clade_cluster, N_f=300)
    # Save dataframe
    muts_by_clade_cluster.to_csv(f'../results/ntmut_fitness/{key}_ntmut_fitness.csv', index=False)