# Generating output figures

**A note on Birth Death Skyline Models**
The results in this notebook are from a phylodynamics pipeline using Birth Death Skyline Models. Reading material on Birth Death Skyline Models can be found at:
* [Taming the BEAST Tutorial: Skylineplots](https://taming-the-beast.org/tutorials/Skyline-plots/) 
* [Stadler et al. 2012 PNAS](https://www.pnas.org/doi/full/10.1073/pnas.1207965110)

Variants run are chosen from the internal Nowcasting results, which estimate recent variant proportions of major (often grouped) lineages and their relative growth coefficients (see RV Intel crunch). The fastest-growing variants were chosen for phylodynamic modelling, along with the main resident lineage that is being replaced.

Bayesian BD skyline serial models were run on each variant separately to infer variant specific parameters (e.g. variant specific effective reproductive number over time, Re).   

In [None]:
# Notebook parameters.
# save_dir: directory holding the pipeline run's outputs; the setup cell below
# falls back to the current working directory when this is None.
save_dir = None
# chosen_samples: 'all' to report on every sample sub-directory, otherwise the
# collection of sample (xml_set) directory names to use.
chosen_samples = 'all'
# collection_date_field: metadata column read to find each sample's youngest tip date.
collection_date_field = 'date'

In [None]:
from beast_pype.outputs import (read_strain_logs_for_plotting, plot_comparative_box_violine, plot_skyline, plot_comparative_origin, hdi_pivot)
from beast_pype.date_utilities import date_to_decimal
import scipy
import pandas as pd
from datetime import datetime
import json
import warnings
import os
import matplotlib.pylab as plt  

## Other setup things
# stop annoying matplotlib warnings
warnings.filterwarnings("ignore", module="matplotlib\*")
# Default to the directory the notebook is executed from when no run directory is given.
if save_dir is None:
    save_dir = os.getcwd()
# Figures and tables saved by the cells below are written to this sub-directory.
outputs4cw_dir = f"{save_dir}/outputs4CW"
# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(outputs4cw_dir, exist_ok=True)

### Date the pipeline that produced this report was launched:

In [None]:
# The text before the first "_" of the run directory's name is shown as the launch
# date (presumably the pipeline names its run directory "<date>_...": confirm).
# Normalising the path first keeps this working when save_dir has a trailing slash.
display(os.path.basename(os.path.normpath(save_dir)).split('_')[0])

The strains used, and their sample sizes, in the analyses of data from the chosen date are displayed below.

*Note* DR stands for Dominant Resident strain & VOI stands for Variant of Interest (a newly emerging strain, i.e. "invader").

In [None]:
# Load the run description written by the pipeline; later cells read its
# "strain sub-variants" and "Re change dates" entries.
with open(f"{save_dir}/pipeline_run_info.json", "r") as file:
    pipeline_run_info = json.load(file)

if chosen_samples == 'all':
    # Sample directories live inside save_dir, so list that directory (not the
    # process working directory) and test each entry relative to it. Sorting
    # makes the order of tables and figures reproducible between runs.
    xml_sets = sorted(
        entry for entry in os.listdir(save_dir)
        if os.path.isdir(os.path.join(save_dir, entry))
        and entry not in ('.ipynb_checkpoints', 'outputs4CW')
    )
else:
    xml_sets = chosen_samples

In [None]:
# Collect, per sample directory: the trace file to read, the youngest tip date and
# the sample sizes shown in the summary table.
records = []
youngest_tip_dates = {}
trace_path_dict = {}
for xml_set in xml_sets:
    sample_dir = f'{save_dir}/{xml_set}'

    # Prefer merged.log; fall back to the merged_log.csv path when it is absent.
    log_file = f'{sample_dir}/merged.log'
    csv_file = f'{sample_dir}/merged_log.csv'
    trace_path_dict[xml_set] = log_file if os.path.isfile(log_file) else csv_file

    # Parse the configured collection-date column (not a hard-coded 'date') so
    # that the .max() below compares real dates whatever the field is called.
    full_metadata_path = f'{sample_dir}/metadata.csv'
    downsampled_metadata_path = f'{sample_dir}/down_sampled_metadata.csv'
    if os.path.exists(downsampled_metadata_path):
        # Down-sampled run: the analysed set is smaller than what was available.
        metadata = pd.read_csv(downsampled_metadata_path, parse_dates=[collection_date_field])
        available_sample_size = len(pd.read_csv(full_metadata_path))
    else:
        metadata = pd.read_csv(full_metadata_path, parse_dates=[collection_date_field])
        available_sample_size = len(metadata)

    youngest_tip_dates[xml_set] = metadata[collection_date_field].max()

    # Directory names are split on "_" into "<type of strain>_<strain>".
    name_parts = xml_set.split('_')
    records.append(
        {'Type of Strain': name_parts[0],
         'Strain': name_parts[1],
         'Sample Size': len(metadata),
         'Available samples': available_sample_size})

sample_info = pd.DataFrame.from_records(records)
sample_info.to_csv(f"{outputs4cw_dir}/sample_info.csv", index=False)
sample_info

Note the sub-variants included within the strains in the table above:

In [None]:
# Keep only the sub-variant listings for strains that appear in the sample table,
# then export and show them.
reported_strains = set(sample_info['Strain'])
strain_sub_varients = {
    strain: sub_varients
    for strain, sub_varients in pipeline_run_info["strain sub-variants"].items()
    if strain in reported_strains
}
with open(f"{outputs4cw_dir}/strain_sub_varients.json", 'w') as fp:
    json.dump(strain_sub_varients, fp, sort_keys=True, indent=4)
display(strain_sub_varients)

In [None]:
# This cell retrieves all the log files for the samples you selected.
# `df` is the frame summarised by hdi_pivot/plot_skyline below and
# `df_melted_for_seaborn` is the form passed to the comparative plotting helpers.
# NOTE(review): convert_become_uninfectious_rate=True presumably derives the
# 'Infection period (per day)' column used later; confirm in beast_pype.
df, df_melted_for_seaborn = read_strain_logs_for_plotting(
    file_path_dict=trace_path_dict,
    convert_become_uninfectious_rate=True,
    youngest_tips_dict=youngest_tip_dates)

## Infection Period 

BD Skyline models estimate the rate of becoming uninfectious (whose inverse is the average infection period).

In [None]:
# Draw from the gamma prior placed on the yearly rate of becoming uninfectious,
# rescale the draws to a daily rate, and invert them to get infection periods in days.
gamma_prior_params = dict(a=5.921111111111111, loc=0, scale=12.32876712328767)
prior = scipy.stats.gamma(**gamma_prior_params)
yearly_rate_prior_draw = prior.rvs(size=100_000)
daily_rate_prior_draw = yearly_rate_prior_draw / 365.25
inf_period_draw = 1 / daily_rate_prior_draw

In [None]:
# Comparative box/violin plot of the infection period across samples, with the
# prior draws passed in for comparison; the figure is saved as a PNG.
ax = plot_comparative_box_violine(df_melted_for_seaborn, 'Infection period (per day)', prior_draws=inf_period_draw)
plt.savefig(f"{outputs4cw_dir}/infection_period_box_violin_with_prior.png")
display(ax)

In [None]:
# Same comparison without passing prior draws; saved as a separate PNG.
ax = plot_comparative_box_violine(df_melted_for_seaborn, 'Infection period (per day)')
plt.savefig(f"{outputs4cw_dir}/infection_period_box_violin.png")
display(ax)

In [None]:
# Summary table returned by hdi_pivot for the infection period, exported as CSV
# and shown inline.
infection_period_hdi_df = hdi_pivot(df, 'Infection period (per day)')
infection_period_hdi_df.to_csv(f"{outputs4cw_dir}/infection_period_hdi_df.csv", index=False)
display(infection_period_hdi_df)

## Clock Rate

The evolutionary substitution rate (i.e. clock rate) is estimated. The prior given here is based on a strict clock estimate for SARS-CoV-2. Note clock rate's unit is 'number of nucleotide substitutions per site per year'.

In [None]:
# Draw 1e5 values from the gamma prior on the clock rate and pass them to the
# comparative plot alongside the samples' 'clockRate' values; save the figure.
# Note: this rebinds `gamma_prior_params` and `prior` from the infection-period cell.
gamma_prior_params = {'a': 1.7777777777777781, 'scale': 0.00022499999999999994}
prior = scipy.stats.gamma(**gamma_prior_params)
subs_rate_prior_draw = prior.rvs(size=int(1e5))
ax = plot_comparative_box_violine(df_melted_for_seaborn, 'clockRate', prior_draws=subs_rate_prior_draw)
plt.savefig(f"{outputs4cw_dir}/clockRate_box_violin_with_prior.png")


In [None]:
# Clock-rate comparison without passing prior draws; saved as a separate PNG.
ax = plot_comparative_box_violine(df_melted_for_seaborn, 'clockRate')
plt.savefig(f"{outputs4cw_dir}/clockRate_box_violin.png")


In [None]:
# Summary table returned by hdi_pivot for the clock rate, exported as CSV and shown inline.
clock_rate_hdi_df = hdi_pivot(df, 'clockRate')
clock_rate_hdi_df.to_csv(f"{outputs4cw_dir}/clock_rate_hdi_df.csv", index=False)
display(clock_rate_hdi_df)

# Sampling Proportion

In [None]:
# Draw 1e5 values from a Beta(1, 999) distribution, used here as the prior for the
# sampling proportion, and pass them to the comparative plot; save the figure.
# Note: this rebinds `prior` from the earlier prior cells.
beta_prior_params = {'a': 1, 'b': 999}
prior = scipy.stats.beta(**beta_prior_params)
sampling_prior_draw= prior.rvs(size=int(1e5))
ax = plot_comparative_box_violine(df_melted_for_seaborn, 'samplingProportion_BDSKY_Serial', prior_draws=sampling_prior_draw)
plt.savefig(f"{outputs4cw_dir}/sampling_proportion_box_violin_with_prior.png")


In [None]:
# Sampling-proportion comparison without passing prior draws; saved as a separate PNG.
ax = plot_comparative_box_violine(df_melted_for_seaborn, 'samplingProportion_BDSKY_Serial')
plt.savefig(f"{outputs4cw_dir}/sampling_proportion_box_violin.png")


In [None]:
# Summary table returned by hdi_pivot for the sampling proportion, exported as CSV
# and shown inline.
sampling_proportion_hdi_df =hdi_pivot(df, 'samplingProportion_BDSKY_Serial')
sampling_proportion_hdi_df.to_csv(f"{outputs4cw_dir}/sampling_proportion_hdi_df.csv", index=False)
display(sampling_proportion_hdi_df)

# $R_T$


## True Skyline

The effective reproductive number, Re, is estimated in serial intervals for each variant. Note that, for computational speed, the resident variant is given less resolution prior to the arrival of the newly emerging lineages (this could be changed if of interest).

**Note** Lower values are 0.05 Highest Posterior Density (HPD), higher values are 0.95 HPD.

In [None]:
# Re change dates are stored per strain type as 'YYYY-MM-DD' strings; parse them once.
type_change_date_dicts = {
    strain_type: [datetime.strptime(text, '%Y-%m-%d') for text in date_strings]
    for strain_type, date_strings in pipeline_run_info['Re change dates'].items()
}
# Each sample uses the change dates of its strain type (the text before the first "_").
sample_change_date_dicts = {
    name: type_change_date_dicts[name.partition('_')[0]] for name in xml_sets
}
# Youngest tip dates converted with date_to_decimal, as passed to plot_skyline.
youngest_tip_year_decimals = {
    name: date_to_decimal(tip_date) for name, tip_date in youngest_tip_dates.items()
}

In [None]:
# Sanity check: number of Re change dates attached to each sample.
{name: len(change_dates) for name, change_dates in sample_change_date_dicts.items()}

In [None]:
# Rt skyline drawn with plot_skyline's default style, partitioned at each sample's
# Re change dates; saved as a PNG.
fig, axs = plot_skyline(df,
                        parameter_start='reproductiveNumber',
                        ylabel='$R_t$',
                        include_grid=True,
                        partition_dates=sample_change_date_dicts,
                        youngest_tip=youngest_tip_year_decimals)
plt.savefig(f"{outputs4cw_dir}/rt_true_skyline_fig.png")


## Mid points spline smoothed

In [None]:
# Same inputs as the true skyline, drawn with style='smooth spline'; saved as a PNG.
fig, axs = plot_skyline(df,
                        parameter_start='reproductiveNumber',
                        ylabel='$R_t$',
                        include_grid=True,
                        partition_dates=sample_change_date_dicts,
                        youngest_tip=youngest_tip_year_decimals,
                        style='smooth spline')
plt.savefig(f"{outputs4cw_dir}/rt_smoothed_fig.png")


## Mid points spline smoothed with points

In [None]:
# Same inputs again, drawn with style='smooth spline with mid-points'; saved as a PNG.
fig, axs = plot_skyline(df,
                        parameter_start='reproductiveNumber',
                        ylabel='$R_t$',
                        include_grid=True,
                        partition_dates=sample_change_date_dicts,
                        youngest_tip=youngest_tip_year_decimals,
                        style='smooth spline with mid-points')
plt.savefig(f"{outputs4cw_dir}/rt_smoothed_with_points_fig.png")


## Rt results in table form

In [None]:
from beast_pype.outputs import _set_dates_skyline_plotting_df
from beast_pype.date_utilities import decimal_to_date

# Build, per sample, one row per Rt period (start date, end date, lower/median/upper)
# from the same helper the skyline plots use, then stack the samples into one table.
parameter_start = 'reproductiveNumber'
style = 'skyline'
r_t_plot_dfs_dict = {}
for sample in df['xml_set'].unique():
    # NOTE(review): plot_skyline is given year-decimals (youngest_tip_year_decimals)
    # whereas this passes the value from youngest_tip_dates; confirm which form
    # _set_dates_skyline_plotting_df expects.
    temp_df = _set_dates_skyline_plotting_df(df[df['xml_set'] == sample],
                                             parameter_start=parameter_start,
                                             style=style,
                                             youngest_tip=youngest_tip_dates[sample],
                                             dates_of_change=sample_change_date_dicts[sample])

    temp_df['xml_set'] = sample
    # Even-indexed rows are taken as period starts and odd-indexed rows as period
    # ends. Copy the slice before adding columns so the assignment does not write
    # onto a view of temp_df (SettingWithCopyWarning).
    start_df = temp_df[temp_df.index % 2 == 0].copy()
    end_df = temp_df[temp_df.index % 2 != 0]
    start_df['Start of Period'] = start_df.year_decimal.map(decimal_to_date).map(lambda x: x.strftime('%Y-%m-%d'))
    # Assigned as a plain list so the values pair up by position, not by index label.
    start_df['End of Period'] = end_df.year_decimal.map(decimal_to_date).map(lambda x: x.strftime('%Y-%m-%d')).to_list()
    r_t_plot_dfs_dict[sample] = start_df

r_t_plot_df = pd.concat(r_t_plot_dfs_dict.values())
r_t_plot_df = r_t_plot_df[['xml_set', 'Start of Period', 'End of Period', 'lower', 'median', 'upper']]
# Split "<type>_<strain>" sample names into two leading columns.
r_t_plot_df = r_t_plot_df.set_index('xml_set')
r_t_plot_df.index = r_t_plot_df.index.str.split('_', expand=True)
r_t_plot_df.reset_index(inplace=True)
r_t_plot_df.columns = ['Type of Strain', 'Strain', 'Start of Period', 'End of Period', 'lower', 'median', 'upper']
# index=False, like the notebook's other exports: the index here is only a row counter.
r_t_plot_df.to_csv(f"{outputs4cw_dir}/rt_df.csv", index=False)
display(r_t_plot_df)

## Ratios of Rt in last time period


### Lower HPD (5%)
Rows are denominators and columns are numerators.

In [None]:
def _last_period_rt_ratio(rt_type, r_t_plot_dfs_dict):
    ratio_of_rts = []
    for xml_set in xml_sets:
        denominator = r_t_plot_dfs_dict[xml_set][rt_type].iloc[-1]
        ratio_of_rts.append({column:r_t_plot_dfs_dict[column][rt_type].iloc[-1]/denominator for column in xml_sets})

    ratio_of_rts = pd.DataFrame.from_records(ratio_of_rts)
    ratio_of_rts.index = xml_sets
    return ratio_of_rts

# Ratio matrix for the 'lower' Rt values, exported as CSV and shown inline.
# NOTE(review): index=False drops the row (denominator) labels from the CSV, and
# this file name lacks the "ratio" part used by the median/upper exports; confirm
# with the consumers of outputs4CW before changing either.
last_period_rt_ratio_lower_df = _last_period_rt_ratio('lower', r_t_plot_dfs_dict)
last_period_rt_ratio_lower_df.to_csv(f"{outputs4cw_dir}/last_period_rt_lower_df.csv", index=False)
display(last_period_rt_ratio_lower_df)

### Median
Rows are denominators and columns are numerators.

In [None]:
# Ratio matrix for the 'median' Rt values, exported as CSV and shown inline.
# NOTE(review): index=False drops the row (denominator) labels from the CSV.
last_period_rt_ratio_median_df = _last_period_rt_ratio('median', r_t_plot_dfs_dict)
last_period_rt_ratio_median_df.to_csv(f"{outputs4cw_dir}/last_period_rt_ratio_median_df.csv", index=False)
display(last_period_rt_ratio_median_df)

### Upper HPD (95%)
Rows are denominators and columns are numerators.

In [None]:
# Ratio matrix for the 'upper' Rt values, exported as CSV and shown inline.
# NOTE(review): index=False drops the row (denominator) labels from the CSV.
last_period_rt_ratio_upper_df = _last_period_rt_ratio('upper', r_t_plot_dfs_dict)
last_period_rt_ratio_upper_df.to_csv(f"{outputs4cw_dir}/last_period_rt_ratio_upper_df.csv", index=False)
display(last_period_rt_ratio_upper_df)

# Origin

The origin is the time at which the index case (the first Canadian case) became infected, which is slightly earlier than the time-to-the-most-recent-common-ancestor (tMRCA). This parameter is used to investigate the detection delay from emergence to first detection in Canada.

In [None]:
# Origin comparison requested as a single figure (one_figure=True); saved as a PNG.
fig = plot_comparative_origin(df_melted_for_seaborn, one_figure=True)
plt.savefig(f"{outputs4cw_dir}/orign_overlayed_fig.png")


In [None]:
# Origin comparison with plot_comparative_origin's default layout; saved as a PNG.
fig = plot_comparative_origin(df_melted_for_seaborn)
plt.savefig(f"{outputs4cw_dir}/orign_stacked_fig.png")

In [None]:
# Export the summary exactly as hdi_pivot returns it, then build the displayed
# version: indexed by sample, every value converted with decimal_to_date and
# formatted as a 'YYYY-MM-DD' string.
orign_hdi_df = hdi_pivot(df, 'Origin')
orign_hdi_df.to_csv(f"{outputs4cw_dir}/orign_hdi_df.csv", index=False)
orign_hdi_df = (
    orign_hdi_df
    .set_index('xml_set')
    .map(decimal_to_date)
    .apply(lambda dates: dates.dt.strftime('%Y-%m-%d'))
)
display(orign_hdi_df)