# Aggregate clade-wise mutation counts over a cluster of clades (input for probabilistic fitness estimation)

## Snakemake input

In [None]:
# Parameters and file paths supplied by the Snakemake rule running this notebook.
clades, cluster = snakemake.params.clades, snakemake.params.cluster
counts_df = snakemake.input.counts_df      # path of the clade-wise counts CSV (read below)
output = snakemake.output.cluster_counts   # path the aggregated CSV is written to

## Import packages

In [1]:
import pandas as pd

## Import cladewise mutations table

In [4]:
# Load the clade-wise table of expected / predicted / actual mutation counts.
# low_memory=False makes pandas infer each column's dtype from the whole file
# rather than chunk by chunk, avoiding mixed-type columns in this large table.
muts_by_clade = pd.read_csv(
    counts_df,
    low_memory=False,
)

In [5]:
# Inspect the loaded table; the notebook repr is truncated to the first and
# last rows, so this does not render the full frame.
muts_by_clade

Unnamed: 0,clade,nt_site,nt_mutation,exclude,masked_in_usher,expected_count,actual_count,clade_founder_nt,gene,clade_founder_codon,...,mut_type,mut_class,pre_omicron_or_omicron,nt_site_before_boundary,ss_prediction,unpaired,motif,ref_motif,predicted_count,tau_squared
0,20A,1,A1C,True,True,0.96873,0,A,noncoding,noncoding,...,AC,noncoding,pre_omicron,False,unpaired,1,AAT,,1.738421,0.433274
1,20A,1,A1G,True,True,3.60910,0,A,noncoding,noncoding,...,AG,noncoding,pre_omicron,False,unpaired,1,AAT,,6.308562,0.736113
2,20A,1,A1T,True,True,1.27820,0,A,noncoding,noncoding,...,AT,noncoding,pre_omicron,True,unpaired,1,AAT,,0.629081,0.635531
3,20A,2,T2A,True,True,0.90342,0,T,noncoding,noncoding,...,TA,noncoding,pre_omicron,False,unpaired,1,ATT,ATT,0.205190,0.520788
4,20A,2,T2C,True,True,3.58890,0,T,noncoding,noncoding,...,TC,noncoding,pre_omicron,False,unpaired,1,ATT,ATT,6.277356,0.719985
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2422138,23I,29902,A29902G,True,True,3.66450,0,A,noncoding,noncoding,...,AG,noncoding,omicron,False,nd,0,AAA,AAA,2.526564,0.983941
2422139,23I,29902,A29902T,True,True,1.28520,0,A,noncoding,noncoding,...,AT,noncoding,omicron,False,nd,0,AAA,AAA,1.421710,0.640234
2422140,23I,29903,A29903C,True,True,0.41880,0,A,noncoding,noncoding,...,AC,noncoding,omicron,False,nd,0,AAA,,0.559207,0.341904
2422141,23I,29903,A29903G,True,True,3.66450,0,A,noncoding,noncoding,...,AG,noncoding,omicron,False,nd,0,AAA,,2.526564,0.983941


Ignore sites that are annotated as masked in any clade of the UShER tree (`masked_in_usher == True`) or that are flagged for exclusion (`exclude == True`).

In [6]:
# Keep only rows that are neither flagged for exclusion nor masked in UShER.
muts_by_clade = muts_by_clade.query('not exclude and not masked_in_usher')

## Aggregate counts for clades in cluster

In [19]:
# Annotation columns identifying a mutation; rows sharing all of these values
# are collapsed into one row when counts are aggregated across clades.
group_cols = 'nt_mutation gene codon_site aa_mutation synonymous noncoding'.split()

In [None]:
# Sum the per-clade counts over the clades belonging to this cluster, giving one
# row per unique combination of the annotation columns in `group_cols`.
# dropna=False: by default groupby silently drops rows whose group keys contain
# NaN; keep them so no mutation's counts go missing from the cluster totals.
# NOTE(review): whether any key column (e.g. codon_site / aa_mutation) actually
# holds NaN is not visible from the displayed output -- confirm against the data.
muts_by_clade_cluster = (muts_by_clade
    .query("clade.isin(@clades)")                       # Select clades in cluster
    .groupby(group_cols, as_index=False, dropna=False)  # Columns not to be aggregated
    .aggregate(                                         # Sum counts across clades
        expected_count = pd.NamedAgg('expected_count', 'sum'),
        predicted_count = pd.NamedAgg('predicted_count', 'sum'),
        actual_count = pd.NamedAgg('actual_count', 'sum'),
        tau_squared = pd.NamedAgg('tau_squared', 'sum')
    )
)
# Parse the integer site from mutation strings such as 'A1C' -> 1 by dropping
# the leading reference nucleotide and the trailing mutant nucleotide.
muts_by_clade_cluster.insert(0, 'nt_site', muts_by_clade_cluster['nt_mutation'].str[1:-1].astype(int))
muts_by_clade_cluster.insert(0, 'cluster', cluster)
# Order by nucleotide site, then mutation. A stable sort on both keys makes the
# row order deterministic (the default quicksort does not preserve tie order).
muts_by_clade_cluster = (muts_by_clade_cluster
    .sort_values(['nt_site', 'nt_mutation'], kind='stable')
    .reset_index(drop=True))
# Save dataframe
muts_by_clade_cluster.to_csv(output, index=False)