# Variational simulation study
## Setup

In [None]:
import yaml
import json
import pickle
import newick
import pymc3 as pm
import numpy as np
import scipy
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import random
import sys
import os
import templating
import variational_analysis
import topology_inference
import subprocess
import process_results
import util
import Bio
import Bio.Phylo

In [None]:
# Load the run configuration. yaml.safe_load only constructs plain Python
# types and replaces the Loader-less yaml.load call, which PyYAML >= 5.1
# deprecates and PyYAML >= 6 rejects with a TypeError.
with open('config.yaml') as f:
    config = yaml.safe_load(f)

# Echo the parsed configuration as the cell output.
config

In [None]:
# NOTE(review): util.do_seeding presumably seeds the random number generators
# from the config — confirm in util. It runs before anything else in the cell.
util.do_seeding(config)
out_dir = config['out_dir']
# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists test.
os.makedirs(out_dir, exist_ok=True)
build_templates = templating.TemplateBuilder(out_dir)

## Simulation

### Population size

In [None]:
# Log-normal prior on the population size: scipy's `scale` is exp(m) and its
# shape parameter `s` is taken directly from the config.
pop_size_prior_params = config['prior_params']['pop_size']
pop_size_prior = scipy.stats.lognorm(
    s=pop_size_prior_params['s'],
    scale=np.exp(pop_size_prior_params['m']),
)
# Plot the density on a 0.001-spaced grid from 0 up to the 99.9% quantile.
xs = np.arange(0, pop_size_prior.ppf(0.999), 0.001)
plt.plot(xs, pop_size_prior.pdf(xs))
# Central 95% prior interval, shown as the cell output.
pop_size_prior.ppf([0.025, 0.975])

In [None]:
# Base command line shared by every BEAST invocation; run_beast appends the
# XML path to it.
beast_args = [
    'java',
    *util.cmd_kwargs(jar=config['beast_jar'], seed=config['seed']),
    '-overwrite',
]
# Build the tree-simulation template; the builder hands back the population
# size, the taxon names and the date trait string used by later cells.
pop_size, taxon_names, date_trait_string = build_templates.build_tree_sim(config)
pop_size

### Sampling times

In [None]:
import statsmodels
# Parse the date trait string into a mapping and keep only its values.
date_trait_dict = topology_inference.parse_date_trait_string(date_trait_string)
sampling_times = [*date_trait_dict.values()]
# Strip plot of the sampling times along a zero baseline.
plt.scatter(sampling_times, np.zeros_like(sampling_times), alpha=0.5);
# Span between the earliest and latest sampling time (cell output).
earliest_sample, latest_sample = np.min(sampling_times), np.max(sampling_times)
latest_sample - earliest_sample

### Tree simulation

In [None]:
from io import StringIO

def run_beast(xml_path, **kwargs):
    """Run BEAST on an XML file and raise if the process fails.

    The command is the module-level ``beast_args`` followed by ``xml_path``.

    Args:
        xml_path: Path of the BEAST XML file, appended to the command line.
        **kwargs: Forwarded unchanged to ``subprocess.run`` (for example
            ``stdout=subprocess.PIPE`` to capture output).

    Raises:
        RuntimeError: If BEAST exits with a non-zero return code.
    """
    result = subprocess.run(beast_args + [xml_path], **kwargs)
    if result.returncode != 0:
        # stderr/stdout are None unless the caller asked subprocess.run to
        # capture them, so only echo the streams that actually hold output.
        for captured in (result.stderr, result.stdout):
            if captured is not None:
                print(captured)
        raise RuntimeError('BEAST run failed')
    print('Ran BEAST ({0}) successfully'.format(xml_path))
    

# Run the tree simulation and pull the resulting Newick string from its output.
run_beast(build_templates.tree_sim_out_path)
newick_string = build_templates.extract_newick_string(build_templates.tree_sim_result_path)
# Parse the Newick string with Biopython and keep the first tree it yields.
newick_handle = StringIO(newick_string)
bio_tree = next(Bio.Phylo.parse(newick_handle, 'newick'))
Bio.Phylo.draw(bio_tree)
# Tree height is the largest depth among the tree's clades.
clade_depths = bio_tree.depths()
tree_height = max(clade_depths.values())
tree_height

In [None]:
# Show the Newick string extracted from the tree simulation result.
newick_string

In [None]:
# Collect the inputs and simulated quantities of this run and write them to
# the run summary file as YAML.
run_summary = dict(
    config=config,
    pop_size=pop_size,
    date_trait_string=date_trait_string,
    newick_string=newick_string,
)

with open(build_templates.run_summary_path, 'w') as summary_file:
    yaml.dump(run_summary, summary_file)

### Sequence simulation

In [None]:
# Build and run the sequence simulation, then read the sequences back.
build_templates.build_seq_sim(config, taxon_names, newick_string)
run_beast(build_templates.seq_sim_out_path)
sequence_dict = build_templates.extract_sequence_dict()
# One Series of single characters per sequence.
sequence_values = [pd.Series(list(sequence)) for sequence in sequence_dict.values()]
char_counts = pd.concat(sequence_values).value_counts()
# Relative frequency of each character over all sequences (cell output).
char_counts / char_counts.sum()

In [None]:
from itertools import combinations

# Proportion of differing sites for every unordered pair of sequences.
sequence_pairs = combinations(sequence_values, 2)
prop_differences = [np.mean(left != right) for left, right in sequence_pairs]
plt.hist(prop_differences);

## Inference

### Neighbour joining

In [None]:
# Build a neighbour-joining tree from the sequences and draw it with Biopython.
nj_tree = topology_inference.get_neighbor_joining_tree(sequence_dict)
Bio.Phylo.draw(nj_tree)

### Rooting & dating

In [None]:
# Write the LSD input files for the neighbour-joining tree and the date traits.
topology_inference.build_lsd_inputs(config, build_templates, nj_tree, date_trait_string)
# check=True raises CalledProcessError when LSD exits non-zero, so a failed run
# is reported here instead of surfacing later when its output is read back.
subprocess.run(
    [config['lsd_executable']] + topology_inference.get_lsd_args(build_templates),
    check=True,
)

In [None]:
lsd_tree = topology_inference.extract_lsd_tree(build_templates)

# Serialise the LSD tree to a Newick string via an in-memory buffer.
analysis_newick_io = StringIO()
Bio.Phylo.write([lsd_tree], analysis_newick_io, format='newick')
analysis_newick = analysis_newick_io.getvalue()

# Simulated tree on the left, NJ + LSD estimate on the right.
fig, axs = plt.subplots(ncols=2, figsize=(20, 10))
tree_panels = (
    (bio_tree, 'True tree', {}),
    (lsd_tree, 'Estimated tree - Neighbour joining + LSD', {'show_confidence': False}),
)
for ax, (panel_tree, panel_title, draw_options) in zip(axs, tree_panels):
    Bio.Phylo.draw(panel_tree, axes=ax, do_show=False, **draw_options)
    ax.set_title(panel_title)
analysis_newick

### BEAST analysis (estimating tree)

In [None]:
# Same config, but with topology estimation switched on for this analysis.
topology_config = util.update_dict(config, estimate_topology=True)
build_templates.build_beast_analysis(
    topology_config,
    analysis_newick,
    date_trait_string,
    sequence_dict,
)
run_beast(build_templates.beast_analysis_out_path)

In [None]:
# Read every tree from the BEAST NEXUS tree log.
with open(build_templates.beast_analysis_tree_path) as tree_file:
    beast_trees = [*Bio.Phylo.parse(tree_file, 'nexus')]

# Draw the last tree in the log.
Bio.Phylo.draw(beast_trees[-1])

In [None]:
# Load the BEAST trace log. NOTE(review): burn_in=False presumably keeps every
# sample rather than discarding a burn-in — confirm in process_results.
beast_trace_df = process_results.process_beast_trace(build_templates.beast_analysis_trace_path, config, burn_in=False)
# One subplot per trace column; the trailing ';' suppresses the notebook echo.
beast_trace_df.plot(subplots=True);

In [None]:
# TODO: Effective sample size

In [None]:
# Values used to generate the data, keyed by the variable names that
# plot_trace_hpd looks up.
true_values = dict(
    tree_height=tree_height,
    pop_size=pop_size,
    kappa=config['kappa'],
)

# Lower and upper probabilities of the central 95% interval.
p_limits = np.array((0.025, 0.975))

In [None]:
def get_beast_quantiles(trace, ps):
    """Return running quantiles of a trace after dropping a burn-in fraction.

    For every prefix length ``i`` from 1 to ``trace.shape[0] - 1`` the
    quantiles ``ps`` are taken over rows ``int(i * config['burn_in'])`` up to
    (but excluding) row ``i``. The burn-in fraction is read from the
    module-level ``config``.

    Args:
        trace: Object exposing ``.values`` and ``.shape`` (a pandas DataFrame
            in this notebook).
        ps: Quantile probabilities passed to ``np.quantile``.

    Returns:
        The per-prefix quantile arrays stacked along a new leading axis.
    """
    samples = trace.values
    burn_in_fraction = config['burn_in']
    running_quantiles = []
    for end in range(1, trace.shape[0]):
        start = int(end * burn_in_fraction)
        running_quantiles.append(np.quantile(samples[start:end], ps, axis=0))
    return np.stack(running_quantiles)

def plot_trace_hpd(quantiles, varnames, xs=None, plot_prior=False): # quantiles: iteration, p, var
    """Plot an interval band per variable together with its true value.

    One subplot row is drawn per variable. Each row shows the band between
    ``quantiles[:, 0, j]`` and ``quantiles[:, 1, j]``, optionally the prior
    interval, and a horizontal line at ``true_values[varname]``. Reads the
    module-level ``config``, ``true_values`` and ``p_limits``.

    Args:
        quantiles: Array indexed as (iteration, p, var); only the first two
            entries along the ``p`` axis are plotted.
        varnames: Sequence of variable names, indexed by position along the
            ``var`` axis.
        xs: X coordinates for the bands; defaults to the iteration index.
        plot_prior: When true, shade the log-normal prior interval at
            ``p_limits`` for variables listed in ``config['prior_params']``.

    Raises:
        KeyError: If a variable name is missing from ``true_values``.
    """
    n_vars = quantiles.shape[2]
    if xs is None:
        xs = np.arange(quantiles.shape[0])
    # squeeze=False always yields a 2-D array of Axes; without it a single
    # variable returns a bare Axes object and indexing it would fail.
    fig, axs = plt.subplots(nrows=n_vars, figsize=(20, 20), squeeze=False)
    for j, ax in enumerate(axs[:, 0]):
        varname = varnames[j]
        ax.set_ylabel(varname)

        ax.fill_between(xs, quantiles[:, 0, j], quantiles[:, 1, j], alpha=0.5, label='95% posterior interval')

        if plot_prior and varname in config['prior_params']:
            prior_params = config['prior_params'][varname]
            prior = scipy.stats.lognorm(scale=np.exp(prior_params['m']), s=prior_params['s'])
            ax.axhspan(*prior.ppf(p_limits), color='yellow', alpha=0.3, label='95% prior interval')

        ax.axhline(true_values[varname], color='green', label='True value')

        ax.legend()
# get_beast_quantiles yields one row fewer than the trace, hence shape[0] - 1;
# scaling by log_every converts sample index to MCMC iteration.
beast_trace_iterations = np.arange(beast_trace_df.shape[0] - 1) * config['log_every']
beast_trace_quantiles = get_beast_quantiles(beast_trace_df, p_limits)
plot_trace_hpd(beast_trace_quantiles, beast_trace_df.columns, xs=beast_trace_iterations)

### BEAST analysis (fixed tree)


In [None]:
# File names for the fixed-topology run, kept separate from the first analysis.
beast_fixed_out_file = 'beast-analysis-fixed.xml'
beast_fixed_trace_file = 'beast-log-fixed.log'
beast_fixed_tree_file = 'beast-log-fixed.trees'

# Same config, but with topology estimation switched off.
fixed_topology_config = util.update_dict(config, estimate_topology=False)
beast_fixed_files = dict(
    out_file=beast_fixed_out_file,
    trace_file=beast_fixed_trace_file,
    tree_file=beast_fixed_tree_file,
)
build_templates.build_beast_analysis(
    fixed_topology_config,
    analysis_newick,
    date_trait_string,
    sequence_dict,
    **beast_fixed_files,
)
run_beast(build_templates.out_path / beast_fixed_out_file)

In [None]:
# Load the trace log written by the fixed-topology run and plot each column.
beast_fixed_trace_path = build_templates.out_path / beast_fixed_trace_file
beast_fixed_trace_df = process_results.process_beast_trace(
    beast_fixed_trace_path,
    config,
    burn_in=False,
)
beast_fixed_trace_df.plot(subplots=True);

In [None]:
# get_beast_quantiles yields one row fewer than the trace, hence shape[0] - 1;
# scaling by log_every converts sample index to MCMC iteration.
beast_fixed_iterations = np.arange(beast_fixed_trace_df.shape[0] - 1) * config['log_every']
beast_fixed_quantiles = get_beast_quantiles(beast_fixed_trace_df, p_limits)
plot_trace_hpd(beast_fixed_quantiles, beast_fixed_trace_df.columns, xs=beast_fixed_iterations)

### Variational analysis (fixed tree)

In [None]:
# Parse the analysis tree (the Newick string written from the LSD tree) and
# keep the first tree returned by newick.loads.
tree = newick.loads(analysis_newick)[0]
# Build the model and the inference object from the config.
model = variational_analysis.construct_model(config, tree, sequence_dict)
inference = variational_analysis.construct_inference(config, model)
# Print the log-probability at the model's test point before fitting.
print(model.logp(model.test_point))
# Display the model as the cell output.
model

In [None]:
# Callback that evaluates the approximation's mean and std during fitting;
# the values end up in tracker.hist['mean'] and tracker.hist['std'].
tracker = pm.callbacks.Tracker(
   mean=inference.approx.mean.eval,
   std=inference.approx.std.eval
)

# Fit for config['n_iter'] iterations with the tracker attached.
approx = inference.fit(config['n_iter'], callbacks=[tracker])

# Persist the tracker so its history can be reloaded later.
with open(build_templates.pymc_analysis_result_path, 'wb') as f:
    pickle.dump(tracker, f)

# NOTE(review): approx.hist is presumably the per-iteration loss history —
# confirm against the pymc3 docs.
plt.plot(approx.hist)

In [None]:
# Map each deterministic variable to the slice its transformed counterpart
# occupies in the flattened approximation.
rvs_dict = {rv.name: rv for rv in model.deterministics}
approx_ordering = inference.approx.ordering
slices = {
    name: approx_ordering.by_name[rv.transformed.name].slc
    for name, rv in rvs_dict.items()
}
# Flat index of each plotted parameter: the last entry of the tree slice and
# the first entry of the pop_size and kappa slices.
indices_dict = dict(
    tree_height=slices['tree'].stop - 1,
    pop_size=slices['pop_size'].start,
    kappa=slices['kappa'].start,
)

means = np.stack(tracker.hist['mean'])
stds = np.stack(tracker.hist['std'])

varnames = list(indices_dict)
indices = np.array([indices_dict[key] for key in varnames])

# One panel per tracked statistic, one line per parameter.
fig, axs = plt.subplots(ncols=len(tracker.hist), figsize=(20, 6))

for ax, (name, param) in zip(axs, tracker.hist.items()):
    vals = np.stack(param)
    ax.set_title(name)
    for varname, index in zip(varnames, indices):
        ax.plot(vals[:, index], label=varname)
    ax.legend()

In [None]:
from pylo.topology import TreeTopology
topology = TreeTopology(tree)

# Normal quantiles in the transformed space, broadcast to (iteration, p, var).
transformed_quantiles = scipy.stats.norm.ppf(
    p_limits.reshape(1, -1, 1),
    loc=means[:, np.newaxis, indices],
    scale=stds[:, np.newaxis, indices],
)


def tree_height_from_transformed(x):
    # Exponentiate, then offset by the topology's maximum leaf height.
    return np.exp(x) + topology.get_max_leaf_height()


# Map each variable from the transformed space back to its own scale.
transforms = {
    'tree_height': tree_height_from_transformed,
    'kappa': np.exp,
    'pop_size': np.exp,
}

quantile_columns = []
for i, varname in enumerate(varnames):
    quantile_columns.append(transforms[varname](transformed_quantiles[:, :, i]))
quantiles = np.stack(quantile_columns, axis=-1)
plot_trace_hpd(quantiles, varnames)