# Notebook for aggregating cladewise mutation counts per cluster (for estimating probabilistic fitness)

## Snakemake input

In [None]:
# File paths handed over by the Snakemake rule.
counts_df = snakemake.input.counts_df        # passed to pd.read_csv below (a CSV path, not a DataFrame)
output = snakemake.output.cluster_counts     # destination of the aggregated CSV

# Cluster definition handed over by the Snakemake rule.
cluster = snakemake.params.cluster           # value written to the `cluster` column of the output
clades = snakemake.params.clades             # clades whose rows are pooled for this cluster

## Import packages

In [None]:
import pandas as pd

## Import cladewise mutations table

In [None]:
# Load the per-clade mutation table from the path in `counts_df`.
# low_memory=False makes pandas parse the file in one pass, so each column
# gets a single inferred dtype instead of chunk-by-chunk (possibly mixed) guesses.
muts_by_clade = pd.read_csv(counts_df, low_memory=False)

In [None]:
# Bare expression: the notebook renders the table as this cell's output for inspection.
muts_by_clade

Ignore sites that are annotated as being masked in any clade of the UShER tree (`masked_in_usher == True`) or that are annotated for exclusion (`exclude == True`).

In [None]:
# Keep only rows that are neither flagged for exclusion nor masked in UShER.
muts_by_clade = muts_by_clade.query('not exclude and not masked_in_usher')

## Aggregate counts for clades in cluster

In [None]:
# Columns that identify a mutation; counts are aggregated within each
# unique combination of these values and the columns themselves are kept as-is.
group_cols = [
    'nt_mutation',
    'gene',
    'codon_site',
    'aa_mutation',
    'synonymous',
    'noncoding',
]

In [None]:
# Nucleotide order used when sorting mutations by wildtype and mutant base.
nucleotides = list('ACGT')

In [None]:
# Pool the per-clade counts over every clade in the cluster, producing one row
# per mutation (i.e. per unique combination of the `group_cols` values).
muts_by_clade_cluster = (
    muts_by_clade
    # Keep only the rows belonging to the clades of this cluster.
    .loc[lambda df: df['clade'].isin(clades)]
    # Columns not to be aggregated.
    # dropna=False: by default groupby silently discards every row that has a
    # NaN in any key column (e.g. a missing gene/codon annotation), which would
    # remove those mutations from the output without warning.
    .groupby(group_cols, as_index=False, dropna=False)
    .aggregate(
        # Counts are summed across the clades of the cluster ...
        expected_count=pd.NamedAgg('expected_count', 'sum'),
        predicted_count=pd.NamedAgg('predicted_count', 'sum'),
        actual_count=pd.NamedAgg('actual_count', 'sum'),
        # ... whereas tau_squared is averaged across them.
        tau_squared=pd.NamedAgg('tau_squared', 'mean'),
    )
)

# `nt_mutation` is read as <wildtype base><site><mutant base> (e.g. "A123G"):
# first character = wildtype, last character = mutant, the rest = site number.
nt_mutation = muts_by_clade_cluster['nt_mutation']

# Prepend the identifying columns; inserting at position 0 twice leaves the
# column order as: cluster, nt_site, nt_mutation, ...
muts_by_clade_cluster.insert(0, 'nt_site', nt_mutation.str[1:-1].astype(int))
muts_by_clade_cluster.insert(0, 'cluster', cluster)

# Order rows by site, then wildtype base, then mutant base. The two bases are
# turned into ordered categoricals so they sort in `nucleotides` order; these
# helper columns exist only for sorting and are dropped again before saving.
muts_by_clade_cluster = (
    muts_by_clade_cluster
    .assign(
        wt=pd.Categorical(nt_mutation.str[0], categories=nucleotides, ordered=True),
        mut=pd.Categorical(nt_mutation.str[-1], categories=nucleotides, ordered=True),
    )
    .sort_values(['nt_site', 'wt', 'mut'])
    .drop(columns=['wt', 'mut'])
    .reset_index(drop=True)
)

# Save dataframe
muts_by_clade_cluster.to_csv(output, index=False)