# Objectives

1. Run mugration analysis on select traits.
1. Add mugration results to dataframe.
1. Plot confidence boxplots.
1. Plot colored trees.
1. Export:
   - Tree Dataframe (tsv)
   - Augur JSON
   - Auspice JSON

---
# Setup

## Module Imports

In [None]:
from Bio import Phylo
from treetime.utils import parse_dates
from treetime import wrappers

import copy
import os
import sys
import IPython
import io

import pandas as pd
import numpy as np
import math

import matplotlib.pyplot as plt
from matplotlib import colors, lines, patches, gridspec
import seaborn as sns

import dill

## Variables

In [None]:
from config import *

# Custom script variables
SCRIPT_NAME = "mugration"
PREV_DIR_NAME = "clock"
PREV_SCRIPT_NAME = "clock_model"

# When executed by Snakemake, a `snakemake` object is injected into the
# namespace and the working directory is the project root. When run
# interactively that name is undefined, so fall back to default wildcards and
# take the project root to be two directories above the current one.
try:
    WILDCARDS = snakemake.wildcards
except NameError:
    WILDCARDS = ["all", "chromosome", "50"]
    project_dir = os.path.dirname(os.path.dirname(os.getcwd()))
else:
    project_dir = os.getcwd()

# Wildcard order: reads origin, locus name, missing-data filter value
# (the latter is used in "<locus>_filter<N>" path components)
READS_ORIGIN, LOCUS_NAME, MISSING_DATA = WILDCARDS[0], WILDCARDS[1], WILDCARDS[2]

NAME_COL = "Name"

# Draw legends without a frame
plt.rc('legend', frameon=False)

## Input and Output File Paths

In [None]:
results_dir = os.path.join(project_dir, "results")
config_dir = os.path.join(results_dir, "config")

# Per-run results are nested under <step>/<reads origin>/<locus>_filter<N>
locus_filter_dir = LOCUS_NAME + "_filter{}".format(MISSING_DATA)
prev_step_dir = os.path.join(results_dir, PREV_DIR_NAME, READS_ORIGIN, locus_filter_dir)
parse_tree_dir = os.path.join(results_dir, "parse_tree", READS_ORIGIN, locus_filter_dir)

# Output directory of this notebook, and inputs produced by earlier steps
outdir       = os.path.join(results_dir, SCRIPT_NAME, READS_ORIGIN, locus_filter_dir)
tree_dill    = os.path.join(prev_step_dir, PREV_SCRIPT_NAME + "_timetree.treetime.obj")
tree_df_dill = os.path.join(prev_step_dir, PREV_SCRIPT_NAME + ".df.obj")
tree_df_path = os.path.join(prev_step_dir, PREV_SCRIPT_NAME + ".tsv")
aln_path     = os.path.join(
    results_dir,
    "snippy_multi",
    READS_ORIGIN,
    "snippy-core_{}.snps.filter{}.aln".format(LOCUS_NAME, MISSING_DATA),
)

# Auspice
auspice_latlon_path = os.path.join(parse_tree_dir, "parse_tree" + "_latlon.tsv")
auspice_colors_path = os.path.join(parse_tree_dir, "parse_tree" + "_colors.tsv")
auspice_config_path = os.path.join(config_dir, "auspice_config.json")
auspice_remote_dir_path = os.path.join(project_dir, "auspice/")

print("tree_dill:\t", tree_dill)
print("tree_df_dill:\t", tree_df_dill)
print("aln path:\t", aln_path)
print("auspice_latlon_path:", auspice_latlon_path)
print("auspice_colors_path:", auspice_colors_path)
print("auspice_config_path:", auspice_config_path)
print("auspice_remote_dir_path:", auspice_remote_dir_path)
print("outdir:", outdir)

# Create the output directory (and any missing parents) if it doesn't exist
os.makedirs(outdir, exist_ok=True)

# The output directory above is named "mugration"; the files written into it
# below are prefixed with "mugration_model".
SCRIPT_NAME = "mugration_model"

## Import Tree

In [None]:
# Load the TreeTime object saved by the previous (clock) step, then ladderize
# its tree in place (Bio.Phylo default: deepest clades last).
# dill can execute arbitrary code on load; only load trusted pipeline outputs.
with open(tree_dill, "rb") as tree_file:
    tt = dill.load(tree_file)
tt.tree.ladderize(reverse=False)

## Import Dataframe

In [None]:
# Load the node dataframe pickled by the previous (clock) step and display it
with open(tree_df_dill, "rb") as df_file:
    tree_df = dill.load(df_file)
tree_df

## Import Latitude/Longitude Table

In [None]:
# Headerless, tab-separated lat/lon table from the parse_tree step.
# Columns: Geo (matched against lower-cased attribute names), Name, Lat, Lon.
latlon_df = pd.read_csv(auspice_latlon_path, sep='\t', header=None)
latlon_df.columns = ["Geo", "Name", "Lat", "Lon"]

# Several spellings of "missing" are parsed by pandas as NaN; collapse them
# all to the project-wide NO_DATA_CHAR.
latlon_df = latlon_df.fillna(NO_DATA_CHAR)
latlon_df

---
# 1. Run mugration analysis

## Setup Mugration Dictionary

In [None]:
mug_dict = {}

# Divergence tree: a copy of the time tree whose branch lengths are replaced
# by mutation lengths, leaving `tt` untouched. Mugration runs on one
# independent copy of it per attribute.
tt_copy = copy.deepcopy(tt)
for n in tt_copy.tree.find_clades():
    n.branch_length = n.mutation_length
tree_div = tt_copy.tree

for attr in ATTRIBUTE_LIST:
    mug_dict[attr] = {
        # Private copy of the divergence tree for this attribute
        "tree_div": copy.deepcopy(tree_div),
        # Tip name -> observed state (as a string), skipping rows whose value
        # is NO_DATA_CHAR
        "leaf_to_attr": {
            sample: str(data[attr])
            for sample, data in tree_df.iterrows()
            if data[attr] != NO_DATA_CHAR
        },
        # Placeholders, filled in after mugration runs
        "tree_mug": {},
        "letter_to_state": {},
        "reverse_alphabet": {},
        "unique_states": {},
    }

# Perform necessary type conversions: Branch_Number values are rounded up to
# whole numbers and stored as strings (e.g. "2.0" -> "2")
if "Branch_Number" in mug_dict:
    branch_leaves = mug_dict["Branch_Number"]["leaf_to_attr"]
    branch_leaves.update({
        sample: str(math.ceil(float(value)))
        for sample, value in branch_leaves.items()
    })

## Run Mugration and Capture Its Log

In [None]:
import contextlib

# Run mugration (TreeTime discrete-trait reconstruction) once per attribute.
# TreeTime's verbose stdout output is captured into a per-attribute log file
# instead of the notebook. contextlib.redirect_stdout restores sys.stdout even
# if reconstruct_discrete_traits raises, so a failed run cannot leave all
# subsequent notebook output silenced.
for attr in ATTRIBUTE_LIST:
    print("Running mugration for attribute: {}".format(attr))
    new_stdout = io.StringIO()
    out_path = os.path.join(outdir, SCRIPT_NAME + "_{}.log".format(attr.lower()))

    try:
        with contextlib.redirect_stdout(new_stdout):
            # Other optional arguments (pc, sampling_bias_correction, weights)
            # are left at TreeTime's defaults.
            mug, letter_to_state, reverse_alphabet = wrappers.reconstruct_discrete_traits(
                mug_dict[attr]["tree_div"],
                traits=mug_dict[attr]["leaf_to_attr"],
                missing_data=NO_DATA_CHAR,
                verbose=4,
            )
    finally:
        # Save the captured log, including partial output from a failed run
        with open(out_path, "w") as log_file:
            log_file.write(new_stdout.getvalue())

    mug_dict[attr]["tree_mug"] = mug
    mug_dict[attr]["letter_to_state"] = letter_to_state
    mug_dict[attr]["reverse_alphabet"] = reverse_alphabet
    mug_dict[attr]["unique_states"] = sorted(letter_to_state.values())

    print("Standard output restored, logging to file disabled.")

---
# 2. Add mugration to dataframe

In [None]:
for attr in ATTRIBUTE_LIST:
    # Column names for this attribute's mugration results
    state_col = "Mugration_" + attr
    conf_col = state_col + "_Confidence"
    entropy_col = state_col + "_Entropy"
    lat_col = state_col + "_Lat"
    lon_col = state_col + "_Lon"

    # Initialize empty values for the assigned state, its confidence and the
    # entropy of the state distribution (every row starts as missing)
    tree_df[state_col] = [NO_DATA_CHAR] * len(tree_df)
    tree_df[conf_col] = [NO_DATA_CHAR] * len(tree_df)
    tree_df[entropy_col] = [NO_DATA_CHAR] * len(tree_df)

    tree_mug = mug_dict[attr]["tree_mug"].tree
    letter_to_state = mug_dict[attr]["letter_to_state"]
    # marginal_profile columns follow the mugration GTR alphabet order, so
    # label them by mapping each alphabet letter back to its state name.
    # A name-sorted list of states would not line up with those columns when
    # the missing-data state sorts between observed states, or when letters
    # were not assigned in sorted-name order.
    profile_states = [
        letter_to_state[letter]
        for letter in mug_dict[attr]["tree_mug"].gtr.alphabet
    ]

    # If this attribute is associated with lat,lon
    # NOTE(review): this checks for a column named attr + "Lat" (no
    # underscore); confirm that matches the upstream column naming.
    has_latlon = attr + "Lat" in tree_df.columns
    if has_latlon:
        tree_df[lat_col] = [NO_DATA_CHAR] * len(tree_df)
        tree_df[lon_col] = [NO_DATA_CHAR] * len(tree_df)
        # Lat/lon rows for this attribute, filtered once outside the node loop
        attr_latlon_df = latlon_df[latlon_df["Geo"] == attr.lower()]

    # Iterate through the nodes in the tree
    for c in tree_mug.find_clades():
        # Per-state marginal probabilities at this node, plus the reconstructed
        # state and its probability
        state_conf_list = c.marginal_profile[0]
        state_max_name = letter_to_state[c.cseq[0]]
        state_max_conf = max(state_conf_list)

        # Store all the states and confidence values, then the assigned state,
        # as PhyloXML "Other" annotations on the node
        c.other = [
            Phylo.PhyloXML.Other(tag=state_name, value="{:0.4f}".format(state_conf), namespace=attr)
            for state_name, state_conf in zip(profile_states, state_conf_list)
        ]
        c.other.append(
            Phylo.PhyloXML.Other(tag=state_max_name, value="{:0.4f}".format(state_max_conf), namespace=attr + "_assign")
        )

        # Everything below writes into tree_df, so nodes without a row are skipped
        if c.name not in tree_df.index:
            continue

        tree_df.at[c.name, state_col] = state_max_name
        tree_df.at[c.name, conf_col] = state_max_conf

        # Shannon entropy of the state distribution; MUG_TINY (presumably a
        # small constant) keeps log() away from zero
        entropy = -np.sum(state_conf_list * np.log(state_conf_list + MUG_TINY))
        tree_df.at[c.name, entropy_col] = entropy

        # Add mugration lat/lon of the assigned state when the table has it;
        # otherwise keep the NO_DATA_CHAR placeholders
        if has_latlon:
            geo_match = attr_latlon_df[attr_latlon_df["Name"] == state_max_name]
            if not geo_match.empty:
                tree_df.at[c.name, lat_col] = geo_match["Lat"].values[0]
                tree_df.at[c.name, lon_col] = geo_match["Lon"].values[0]

tree_df

## Add Metadata as Comments

In [None]:
# Copy tree_df metadata onto the nodes of the time tree. The helper is defined
# outside this notebook (presumably via `from config import *`) and presumably
# writes into each clade's `comment`; the export cells strip `comment` before
# writing Newick.
# NOTE(review): this annotates `tt.tree`, not the per-attribute mugration trees
# in mug_dict — confirm that is intended.
metadata_to_comment(tt.tree, tree_df)

---
# 3. Plot confidence boxplot

In [None]:
"""for attr in ATTRIBUTE_LIST:
    fig, (ax1, ax2) = plt.subplots(2, 
                               sharex=False, 
                               gridspec_kw={'hspace': 0},
                               figsize=figsize, 
                               dpi=dpi,
                               constrained_layout=True,
                              )
    
    # --------------------------
    # Axis 1 - Number of tips per state
    # Exclude samples with no attribute recorded
    data = tree_df[tree_df[attr] != NO_DATA_CHAR]
    label_order = list(data[attr].value_counts().index)
    label_order = [lab for lab in label_order if lab != "nan"] + ["nan"]
    
    sns.countplot(data=data, 
                  x=attr, 
                  #color="blue", 
                  edgecolor="black",
                  ax=ax1, 
                  order=label_order
                 )
    plt.setp(ax1.get_xticklabels(), visible=False)
    xticklabels = [item.get_text() for item in ax1.get_xticklabels()]
    ax1.set_xticklabels(xticklabels, rotation = 90, ha="right")
    ax1.set_xlabel("")
    ax1.set_ylabel("Number of Samples (Tips)")
    ax1.set_xlim(-1,len(label_order)) 
    ax1.set_title(attr.replace("_"," "))
    
    # --------------------------
    # Axis 2 - Mugration Confidence
    # Exclude nodes that are terminals (ie. branch support is grey)
    data = tree_df[tree_df["Branch_Support_Color"] != "grey"]
    
    # Fix typing
    if attr == "Branch_Number":
        label_order = [str(math.ceil(lab)) for lab in label_order if lab != "nan"] + ["nan"]

    # Customize outlier style
    flierprops = dict(marker='o', markerfacecolor='white', markersize=2,
                      linestyle='none', markeredgecolor='black')
    # Create a boxplot
    sns.boxplot(data=data, 
                  x="Mugration_" + attr, 
                  y="Mugration_" + attr + "_Confidence",
                  #color="blue", 
                  ax=ax2, 
                  order=label_order,
                  flierprops=flierprops)
    #plt.setp(ax2.get_xticklabels(), visible=False)
    xticklabels = [item.get_text() for item in ax2.get_xticklabels()]
    ax2.set_xticklabels(xticklabels, rotation = 90, ha="center")
    #ax2.axhline(y=MUG_CONF_THRESH, color=THRESH_COL, linewidth=0.5, linestyle='--')
    ax2.set_xlabel("")
    ax2.set_ylabel("Mugration Confidence (Internal Nodes)")
    ax2.set_xlim(-1,len(label_order))
    ax2.set_ylim(0, 1.1)
    
    fig.suptitle("Sampling Distribution and Mugration Analysis Confidence")
    
    out_path = os.path.join(outdir, SCRIPT_NAME + "_boxplot_{}.{}".format(attr.lower(), FMT)) 
    plt.savefig(out_path, 
            dpi=dpi, 
            bbox_inches = "tight")"""

---
# 4. Plot colored trees

## Color branches on mug/div tree according to state

In [None]:
"""out_path_colors = os.path.join(outdir, SCRIPT_NAME + "_colors.tsv")
file_colors = open(out_path_colors, "w")

for attr in ATTRIBUTE_LIST:  
    if attr != "Branch_Major" and attr != "Branch_Number":
        continue
    
    # --------------------------------------------   
    # Canvas
    fig, ax1 = plt.subplots(1, dpi=dpi, figsize=figsize, constrained_layout=True,)

    # --------------------------------------------
    # Color Tree
    hex_dict = color_tree(tree=tree_div, 
                          df=tree_df, 
                          attribute="Mugration_" + attr,
                          attribute_confidence="Mugration_" + attr + "_Confidence",
                          threshold_confidence=MUG_CONF_THRESH,                          
                          color_pal=CONT_COLOR_PAL)
    
    # Add the hex color dict to the dict
    mug_dict[attr]["hex_color"] = hex_dict  

    # Write to color file
    for state,color in hex_dict.items():
        file_colors.write(attr.lower() + "\t" + state + "\t" + color + "\n")    
        
    # Draw tree
    Phylo.draw(tree_div,
               axes=ax1, 
               show_confidence=False, 
               label_func = lambda x:'', 
               do_show=False)
    # --------------------------------------------
    # Draw tips
    colors = [mug_dict[attr]["hex_color"][state] for state in tree_df["Mugration_" + attr]]    
    ax1.scatter(data=tree_df, 
                x="coord_x", 
                y="coord_y", 
                s=0.5, 
                c=colors,
               )
    # --------------------------------------------
    # Ticks
    x_buffer = max(tree_df["coord_x"]) * 0.01
    y_buffer = math.ceil(len(tree_div.get_terminals()) * 0.01)   
    # --------------------------------------------
    # Limits
    ax1.set_xlim(0 - x_buffer, max(tree_df["coord_x"]) + x_buffer)    
    ax1.set_ylim(len(tree_div.get_terminals()) + y_buffer, 0 - y_buffer)
    ax1.set_yticks([])    
    # --------------------------------------------
    # Labels
    ax1.set_xlabel("Branch Length")
    ax1.set_ylabel("")
    fig.suptitle("{} Mugration".format(attr.replace("_"," ")))    
    # --------------------------------------------
    # Legend
    legend_elements = [patches.Patch(facecolor=value, edgecolor=value,) for value in mug_dict[attr]["hex_color"].values()]
    legend_labels = list(mug_dict[attr]["hex_color"].keys())
    ax1.legend(legend_elements, 
              legend_labels,
              bbox_to_anchor=(1.0, 1.0), 
              loc='upper left',
              frameon=False,
             )
    # --------------------------------------------
    # Export
    out_path = os.path.join(outdir, SCRIPT_NAME + "_tree_{}.{}".format(attr.lower(), FMT))  
    plt.savefig(out_path, dpi=dpi, bbox_inches = "tight")
    
    # Save Tree!
    #out_path_dill_tree = os.path.join(outdir,  SCRIPT_NAME + "_{}.phylo.obj".format(attr.lower()))
    #with open(out_path_dill_tree,"wb") as outfile:
    #    dill.dump(tree_div, outfile)
    
file_colors.close()"""

---
# 5. Export

## Dataframe

In [None]:
# Dataframe
out_path_df = os.path.join(outdir, SCRIPT_NAME + ".tsv" )
out_path_pickle_df = os.path.join(outdir,  SCRIPT_NAME + ".df.obj" )

tree_df.to_csv(out_path_df, sep="\t")
with open(out_path_pickle_df,"wb") as outfile:
    dill.dump(tree_df, outfile)

## Timetrees

In [None]:
# Export the time tree as Nexus, a dill-pickled TreeTime object, and Newick.
# A deep copy is used because the Newick step clears node comments in place.
tt_copy = copy.deepcopy(tt)
branch_length_fmt = '%1.{}f'.format(BRANCH_LEN_SIG_DIG)

out_path_nexus = os.path.join(outdir, SCRIPT_NAME + "_timetree.nexus")
out_path_dill_tree = os.path.join(outdir, SCRIPT_NAME + "_timetree.treetime.obj")
out_path_nwk = os.path.join(outdir, SCRIPT_NAME + "_timetree.nwk")

# Nexus and dill exports are written while node comments are still present
Phylo.write(tt_copy.tree, out_path_nexus, 'nexus', format_branch_length=branch_length_fmt)
with open(out_path_dill_tree, "wb") as outfile:
    dill.dump(tt_copy, outfile)

# Newick export without comments
for clade in tt_copy.tree.find_clades():
    clade.comment = None
Phylo.write(tt_copy.tree, out_path_nwk, 'newick', format_branch_length=branch_length_fmt)

## Divtrees

In [None]:
# Export the divergence tree: a copy of the time tree whose branch lengths are
# replaced by mutation lengths. It is written as Nexus, a dill-pickled TreeTime
# object, and Newick (the Newick step clears node comments in place).
tt_copy = copy.deepcopy(tt)
for clade in tt_copy.tree.find_clades():
    clade.branch_length = clade.mutation_length

branch_length_fmt = '%1.{}f'.format(BRANCH_LEN_SIG_DIG)

out_path_nexus = os.path.join(outdir, SCRIPT_NAME + "_divtree.nexus")
out_path_dill_tree = os.path.join(outdir, SCRIPT_NAME + "_divtree.treetime.obj")
out_path_nwk = os.path.join(outdir, SCRIPT_NAME + "_divtree.nwk")

# Nexus and dill exports are written while node comments are still present
Phylo.write(tt_copy.tree, out_path_nexus, 'nexus', format_branch_length=branch_length_fmt)
with open(out_path_dill_tree, "wb") as outfile:
    dill.dump(tt_copy, outfile)

# Newick export without comments
for clade in tt_copy.tree.find_clades():
    clade.comment = None
Phylo.write(tt_copy.tree, out_path_nwk, 'newick', format_branch_length=branch_length_fmt)

## Augur JSON

  - alignment (empty)
  - input_tree (tree_path)
  - nodes (node_dict)
  

In [None]:
# Build the augur node-data dictionary. `augur_export` is defined outside this
# notebook (presumably via `from config import *`).
# Column names containing any of these keywords are presumably excluded from
# the exported colorings — confirm in augur_export.
augur_color_exclude = ["color", "coord", "reg", "lat", "lon"]
# Branch_Number values are converted to strings on export
augur_type_convert = {"Branch_Number": (lambda x: str(x))}

# NOTE(review): out_path_nwk is the divergence-tree Newick written by the
# preceding export cell, while `tree=` is the time tree — confirm the mix is
# intended.
augur_dict = augur_export(
    tree_path=out_path_nwk,
    aln_path=aln_path,
    tree=tt.tree,
    tree_df=tree_df,
    color_keyword_exclude=augur_color_exclude,
    type_convert=augur_type_convert,
)

# Spot-check one exported node
print(augur_dict["nodes"]["NODE0"])

### Add Mugration Model to JSON

In [None]:
# Summarize each attribute's mugration model for the augur JSON, using the
# "models" layout that `augur traits` writes: rate, state alphabet,
# equilibrium probabilities and transition matrix.
mug_models_dict = {}

for attr in ATTRIBUTE_LIST:
    # Mugration GTR model
    gtr = mug_dict[attr]["tree_mug"].gtr
    # Letter -> state mapping; states are listed in letter order
    alphabet = mug_dict[attr]["letter_to_state"]

    mug_models_dict[attr] = {
        # Mugration rate
        "rate": gtr.mu,
        # Mugration states
        "alphabet": [alphabet[k] for k in sorted(alphabet.keys())],
        # Equilibrium state frequencies
        "equilibrium_probabilities": list(gtr.Pi),
        # GTR W matrix, kept under its own key, separate from the
        # equilibrium probabilities
        "transition_matrix": [list(x) for x in gtr.W],
    }

augur_dict["models"] = mug_models_dict

In [None]:
# Write the augur node-data JSON (`utils` is not imported in this notebook's
# visible cells; presumably provided by `config`)
out_path_augur_json = os.path.join(outdir, SCRIPT_NAME + "_augur.json")
utils.write_json(
    data=augur_dict,
    file_name=out_path_augur_json,
    indent=JSON_INDENT,
)

# Spot-check the "Reference" node
reference_node = augur_dict["nodes"]["Reference"]
print(reference_node)

## Auspice JSON

Adapted, with manual edits, from https://github.com/nextstrain/augur/blob/master/augur/export_v2.py

This can then be used for auspice via:

```
HOST="localhost" auspice view --datasetDir .
```

In [None]:
# Assemble the Auspice dataset from the divergence tree (`tree_div`), the augur
# node-data JSON, and the parse_tree color / lat-lon tables.
auspice_inputs = dict(
    tree=tree_div,
    augur_json_paths=[out_path_augur_json],
    auspice_config_path=auspice_config_path,
    auspice_colors_path=auspice_colors_path,
    auspice_latlons_path=auspice_latlon_path,
)
auspice_dict = auspice_export(**auspice_inputs)

# Write outputs - For Local Rendering
out_path_auspice_local_json = os.path.join(outdir, SCRIPT_NAME + "_auspice.json")
utils.write_json(
    data=auspice_dict,
    file_name=out_path_auspice_local_json,
    indent=JSON_INDENT,
    include_version=False,
)
# The success message is printed only if validation returns normally
export_v2.validate_data_json(out_path_auspice_local_json)
print("Validation successful for local JSON.")

# Write outputs - For Remote Rendering (disabled)
#out_path_auspice_remote_json = os.path.join(auspice_remote_dir_path, AUSPICE_PREFIX + SCRIPT_NAME.replace("_","-") + ".json" )
#utils.write_json(data=auspice_dict, file_name=out_path_auspice_remote_json, indent=JSON_INDENT, include_version=False)
#export_v2.validate_data_json(out_path_auspice_remote_json)
#print("Validation successful for remote JSON.")