# Australian AA Variants

## Part 01: Create Visualizations

**Run all of `00_ExtractTransform.ipynb` before using this notebook.**

Plots the Quarterly Incidences of Substitutions per Australian Region.

In [3]:
from pathlib import Path

import pandas as pd
import numpy as np

import helpers

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from IPython import display
%matplotlib inline

## Definitions

### Filepaths

In [None]:
# Root folders, relative to the notebook's working directory.
DATA_DIR = Path('./data')
FIGURES_DIR = Path('./figures')

# Summary tables live in a sub-folder of the data directory.
SUMMARY_DIR = DATA_DIR.joinpath('summaries')

# Individual summary files read (and, for the proportions file, written) by this notebook.
aa_subs_by_seq_fpath = SUMMARY_DIR.joinpath('aa_subs_by_seq.csv')
aa_sub_counts_fpath = SUMMARY_DIR.joinpath('aa_sub_counts_by_regional_period.csv')
seq_counts_fpath = SUMMARY_DIR.joinpath('seq_counts_per_regional_period.csv')
aa_sub_proportions_fpath = SUMMARY_DIR.joinpath('aa_sub_proportions.csv.xz')

#### DataFrame Column Names, Index Levels, and Ordinal Positions

In [1]:
# Column header levels of the substitution tables, and their row positions in the CSV.
header_names = ['gene', 'position', 'aa_ref', 'aa_sub']
header_multiindex = list(range(len(header_names)))

# Row index levels of the per-sequence table, and their column positions in the CSV.
index_names = ['region', 'collection_date', 'virus_code', 'accession_id', 'time_period']
index_multiindex = list(range(len(index_names)))

# Index levels kept once the per-sequence levels are dropped.
short_index_names = ['region', 'time_period']

#### Visualizations

In [None]:

# Regions listed in north-to-south order; this order also drives the colour gradient.
north_to_south_regions = [
    'Northern Territory',
    'Queensland',
    'Western Australia',
    'New South Wales',
    'South Australia',
    'Australian Capital Territory',
    'Victoria',
    'Tasmania',
]

# Quarter labels used on the x-axis: every quarter of 2020 and 2021, then 2022 Q1.
period_names = [
    f'{year} Q{quarter}'
    for year in (2020, 2021)
    for quarter in (1, 2, 3, 4)
]
period_names.append('2022 Q1')

fig_height = 4

### Load Australian Sequence **AA Substitution Counts** by Regional Period

Returns the counts of each **AA Substitution** found in Australian sequences, partitioned by a combined Region and Period (e.g., Queensland Q1).

In [4]:
## Fallback: only needed if the substitution counts dataframe was not constructed
##      when running `00_ExtractTransform.ipynb`.

# if aa_subs_by_seq_fpath.exists():
#     aa_subs_df = pd.read_csv(
#         aa_subs_by_seq_fpath, 
#         header=header_multiindex, 
#         index_col=index_multiindex,
#     ).fillna(0).astype('int32')
#     aa_subs_df.index.names = index_names

# else:
#     emsg = f'`{aa_subs_by_seq_fpath}` cannot be found. Run `00_ExtractTransform.ipynb` first.'
#     raise FileNotFoundError(emsg)

# aa_sub_counts_df = aa_subs_df.droplevel([1,2,3]).groupby(short_index_names).sum()

In [None]:

## Load pre-made version instead
# Fail fast when the pre-computed counts file is missing.
if not aa_sub_counts_fpath.exists():
    emsg = f'`{aa_sub_counts_fpath}` cannot be found. Run `00_ExtractTransform.ipynb` first.'
    raise FileNotFoundError(emsg)

# Four header rows form the column MultiIndex; the first two columns form the row index.
# Missing cells are treated as zero counts.
aa_sub_counts_df = (
    pd.read_csv(aa_sub_counts_fpath, header=header_multiindex, index_col=[0, 1])
    .fillna(0)
    .astype('int32')
)
aa_sub_counts_df.index.names = short_index_names


In [None]:
# Preview a few rows. The fixed seed keeps the preview identical across "Restart & Run All",
# and capping `n` avoids a ValueError when the table has fewer than five rows.
aa_sub_counts_df.sample(n=min(5, len(aa_sub_counts_df)), random_state=0)

### Get Australian **Sequence Counts** by Regional Period

Load the total number of *sequences* available for each Australian region over a given period.\
This includes sequences which match the reference sequence (`NC_045512`),\
which are not included in the aa substitution dataframe.

In [None]:
%%time

seq_counts_df = pd.read_csv(
    seq_counts_fpath, 
    index_col=short_index_names
)

In [None]:
# Preview a few rows. The fixed seed keeps the preview identical across "Restart & Run All",
# and capping `n` avoids a ValueError when the table has fewer than five rows.
seq_counts_df.sample(n=min(5, len(seq_counts_df)), random_state=0)

### Get Proportion of Amino Acid Substitutions per Time-Period

In [None]:
%%time

# Load the cached proportions table when it exists; otherwise derive it from the counts
# and write the cache. Both branches bind `aa_sub_props_df` in the transposed layout
# (substitution levels on the rows, region/period on the columns).
if aa_sub_proportions_fpath.exists():
    print(f'`{aa_sub_proportions_fpath}` already exists. Loading...')
    # index_multiindex[:-1] is [0, 1, 2, 3] (four index columns) and
    # header_multiindex[:-2] is [0, 1] (two header rows), i.e. the layout written by the else-branch.
    aa_sub_props_df = pd.read_csv(aa_sub_proportions_fpath, index_col=index_multiindex[:-1], header=header_multiindex[:-2])
else:
    print(f'`{aa_sub_proportions_fpath}` does not exist. Creating...')
    # NOTE(review): dividing by `seq_counts_df.values` bypasses pandas label alignment, so this
    # presumably relies on both frames holding the same (region, time_period) rows in the same
    # order, and on `seq_counts_df` having a single column — confirm against `00_ExtractTransform.ipynb`.
    aa_sub_props_df = aa_sub_counts_df.div(seq_counts_df.values, axis='columns').astype('float64')
    # Replacing the 0s and 1s as strings minimizes the space by removing decimalized versions (ex.: 0.000000, 1.000000)
    # `replace` returns a new frame, so only the written file gets the string values; the in-memory frame stays float.
    aa_sub_props_df.replace({0: '0', 1: '1'}).T.to_csv(aa_sub_proportions_fpath)
    # Transpose the in-memory frame too, so it has the same orientation as the file just written.
    aa_sub_props_df = aa_sub_props_df.T

### Accessor Functions for AA Substitution Data

In [None]:
def get_gene_name_list(gene_df):
    """ Get the list of distinct gene names in a gene dataframe.

        Parameters:
            gene_df: DataFrame whose index has a level named 'gene'.

        Returns:
            list of the distinct values of the 'gene' index level, in order of
            first appearance. (A plain `set` would give an order that can change
            between kernel sessions, making the figure-generation order
            non-reproducible.)
    """
    return gene_df.index.get_level_values('gene').unique().tolist()


def get_by_gene(gene_df, gene):
    """ Given a gene dataframe, get the entries for one gene (case-insensitive).

        Parameters:
            gene_df: DataFrame whose index has a level named 'gene'.
            gene: gene name to look up; letter case is ignored.

        Returns:
            The rows of `gene_df` whose 'gene' level matches `gene`, or None
            when no row matches.
    """
    gene_level = gene_df.index.get_level_values('gene')
    # Use the same case-insensitive comparison for the existence check and the
    # row filter, so that e.g. 's' finds rows labelled 'S'.
    matches = gene_level.str.lower() == gene.lower()
    if matches.any():
        return gene_df.loc[matches]
    return None


def filter_for_relevancy(_df, threshold=0.1):
    """ Remove variants that do not pass threshold in relevance.

        Parameters:
            _df: DataFrame with one row per variant, or None.
            threshold: a row is kept only when the sum of its values across all
                columns is strictly greater than this (default 0.1).

        Returns:
            The kept rows, sorted by index; None when `_df` is None (so a failed
            gene lookup passes through instead of raising AttributeError).
    """
    if _df is None:
        return None
    row_totals = _df.sum(axis=1)
    return _df[row_totals > threshold].sort_index()


### Organize Dataframes by Gene 

In [None]:
# Map each gene name to its relevancy-filtered slice of the proportions table.
gene_dfs = dict(
    (name, filter_for_relevancy(get_by_gene(aa_sub_props_df, name)))
    for name in get_gene_name_list(aa_sub_props_df)
)

## Visualization

#### 1. Configure visual elements of plots


In [None]:
# One mathtext marker string per distinct substituted amino acid.
marker_map = dict(
    (aa_sub, f"${aa_sub}$")
    for aa_sub in aa_sub_props_df.index.get_level_values('aa_sub').unique()
)

# One colour per region, stepping evenly from 'r' to 'c' along the north-to-south ordering.
colors_by_region = {
    region: helpers.Plotting.color_fader('r', 'c', rank / (len(north_to_south_regions) - 1))
    for rank, region in enumerate(north_to_south_regions)
}

#### 2. Build functions for plot creation

In [None]:
def draw_aa_sub_by_region_plots(i, fig, grid_spec, aa_sub_name, sub_df):
    """
        Draw plots comparing regional differences by Amino Acid Substitution.

        Adds one axes to `fig`, spanning grid columns 2*i and 2*i + 1 of the first
        grid row, and draws one line per region showing the substitution's
        proportion over the quarters in `period_names`.

        Parameters:
            i: zero-based slot of this substitution within the figure.
            fig: matplotlib figure to draw into.
            grid_spec: GridSpec attached to `fig`.
            aa_sub_name: substituted amino acid, shown in the axes title.
            sub_df: rows for this substitution; the index must carry the levels
                'position' and 'aa_ref', and the columns a level named 'region'.
                Only the first row's values are plotted.

        Returns:
            The newly created axes.
    """
    ax = fig.add_subplot(grid_spec[0, (i*2):((i+1)*2)])

    aa_ref = sub_df.index.get_level_values('aa_ref').to_numpy()[0]
    aa_position = sub_df.index.get_level_values('position').to_numpy()[0]

    # Transpose so regions can be grouped along the index: `groupby(..., axis=1)`
    # is deprecated in recent pandas releases.
    by_region_period = sub_df.droplevel(['position', 'aa_ref']).T

    for region, reg_df in by_region_period.groupby(level='region'):
        # After dropping 'region' the remaining index is the time period; keep the first column only.
        period_values = reg_df.droplevel('region').astype(float).iloc[:, 0]

        # Quarters without data stay NaN so they show up as gaps in the line.
        incidence = np.full(len(period_names), np.nan)

        # NOTE(review): assumes the period labels are 1-based integers no larger than
        # len(period_names) — a label of 0 would wrap around to the last quarter; confirm upstream.
        for idx, val in period_values.items():
            incidence[int(idx) - 1] = val

        ax.plot(period_names, incidence, color=colors_by_region.get(region), label=region)

    # Raw f-string: `\m` and `\;` are mathtext commands, not Python escape sequences.
    ax.set_title(rf"$\mathregular{{{aa_ref}\;{aa_position}\;{aa_sub_name}}}$", fontname='monospace', size=14)

    ax.set_xlim((-0.1, (len(period_names) - 0.9)))
    ax.set_xlabel('Quarter')

    ax.set_ylim((-0.01,1.01))
    ax.set_ylabel('Proportion')

    ax.set_aspect(4)
    ax.apply_aspect()

    ax.grid(alpha=0.5)

    return ax


def draw_gene_aa_substitution_plots(gene_name, positional_df, save=False):
    """
        Builds Region-Period Plots for each gene and position

        One figure is created per value of the 'position' level. It holds one
        panel per substitution at that position plus a final slot for the
        region legend.

        Parameters:
            gene_name: gene label used in the figure title and output path.
            positional_df: the gene's rows, groupable by 'position' and 'aa_sub';
                nothing is drawn when it is None or empty.
            save: when True, write each figure to
                FIGURES_DIR/by_gene/<gene_name>/<gene_name>_<position>.pdf.

        Returns:
            None. Every figure is closed before moving on to the next position.
    """
    if positional_df is None or positional_df.empty:
        return

    for pos, pos_df in positional_df.groupby('position'):
        aa_sub_group = pos_df.groupby('aa_sub')
        # One slot per substitution, plus one for the legend.
        plot_count = len(aa_sub_group) + 1
        title = f'Amino Acid Substitutions for {gene_name} {pos}'

        fig = plt.figure(constrained_layout=True, figsize=((plot_count * 7), 5))
        try:
            # Each substitution panel spans two grid columns; the legend takes the last single column.
            grid_spec = gridspec.GridSpec(ncols=(plot_count * 2) - 1, nrows=1, figure=fig)

            for i, (aa_sub_name, sub_df) in enumerate(aa_sub_group):
                draw_aa_sub_by_region_plots(i, fig, grid_spec, aa_sub_name, sub_df)

            fig.suptitle(title, fontsize=16, weight='semibold')

            # Legend-only axes: colour patches for every region, with the axes frame hidden.
            legend_ax = fig.add_subplot(grid_spec[-1, -1])
            legend_ax.legend(handles=[mpl.patches.Patch(color=colors_by_region.get(region), label=region) for region in north_to_south_regions], loc='center')
            legend_ax.axis('off')

            if save:
                output_img_filename = FIGURES_DIR / 'by_gene' / f'{gene_name}' / f'{gene_name}_{pos}.pdf'
                output_img_filename.parent.mkdir(exist_ok=True, parents=True)
                fig.savefig(output_img_filename)
        finally:
            # Close (not just clear) the figure so pyplot drops its reference; otherwise
            # one figure per position accumulates in memory over the whole run.
            plt.close(fig)
        

#### 3. Plot Incidence of AA Substitutions across Region-Periods

In [None]:
# Start from a clean pyplot state so figures left over from earlier runs are not carried along.
plt.close('all')

for gene_name, gene_df in gene_dfs.items():
    draw_gene_aa_substitution_plots(gene_name, gene_df, save=True)
    # Show a single, continuously replaced progress line. The figures themselves are
    # written to disk by the call above; displaying `plt.gcf()` here would only show a blank canvas.
    display.clear_output(wait=True)
    print(f'Processed gene {gene_name}')

display.clear_output(wait=True)
print(f'Finished processing {len(gene_dfs)} genes.')

# Make sure no empty figure is left behind for the inline backend to render.
plt.close('all')