# Objectives

1. Run Treemmer (Complete)
1. Plot Taxa vs. RTL
1. Run Treemmer (Target)
1. Prune Tree, Dataframe, Alignment
1. Plot Tree Comparison

---
# Setup

## Module Imports

In [None]:
import copy
import glob
import os
import shutil
import subprocess

import matplotlib.pyplot as plt
from matplotlib import gridspec
from Bio import Phylo, Align, AlignIO
import seaborn as sns

## Input Paths

In [None]:
tree_path = "../../docs/results/latest/parse_tree/parse_tree.nwk"
aln_path = "../../docs/results/latest/snippy_multi/snippy-core_chromosome.snps.filter5.aln"
tree_df_path = "../../docs/results/latest/metadata/metadata.tsv"
treemmer_path = "../scripts/Treemmer.py"

auspice_config_path = "../../config/auspice_config.json"
auspice_colors_path = "../../docs/results/latest/parse_tree/parse_tree_colors.tsv"
auspice_latlons_path = "../../docs/results/latest/parse_tree/parse_tree_latlon.tsv"
auspice_remote_dir_path = "../../auspice/"

outdir = "../../docs/results/latest/treemmer"

# Create the output directory (and any missing parents) if it doesn't exist.
# exist_ok avoids the check-then-create race and re-run failures.
os.makedirs(outdir, exist_ok=True)

## Variables

In [None]:
from config import *

# Prefix for this notebook's output files (RTL plot, tree, table, alignment, augur/auspice JSONs)
SCRIPT_NAME = "treemmer"

## Import Tree and Dataframe

In [None]:
# Full (pre-treemmer) phylogeny, ladderized the same way as the trimmed tree later on
tree_pre = Phylo.read(tree_path, "newick")
tree_pre.ladderize(reverse=False)

# Metadata table: the file mixes several forms of NA, so collapse them all to
# NO_DATA_CHAR before indexing.
# NOTE(review): this cell indexes by "Sample" while the pruning cell indexes the
# same file by "Name" -- confirm which column holds the tip names.
tree_df = (
    pd.read_csv(tree_df_path, sep='\t')
    .fillna(NO_DATA_CHAR)
    .set_index("Sample")
)

---
# 1. Run Treemmer (Complete)

In [None]:
# Remove old treemmer results. glob replaces shelling out to `ls`, which broke
# on paths with spaces and printed an error when nothing matched.
for old_file in sorted(glob.glob(os.path.join(outdir, "treemmer") + "*")):
    # Only files were ever meant to be deleted; os.remove fails on directories
    if not os.path.isfile(old_file):
        continue
    print("Deleting old treemmer file:", old_file)
    os.remove(old_file)

# Copy over the tree to the new directory
out_path_tree = os.path.join(outdir, "treemmer.nwk")
shutil.copy(tree_path, out_path_tree)

# Run treemmer over the full pruning trajectory.
# An argument list avoids shell quoting issues; check=True stops here if
# Treemmer fails instead of failing later on a missing results file.
subprocess.run(
    ["python3", treemmer_path, out_path_tree, "--verbose", "0", "--plot_complete"],
    check=True,
)

## Create a dataframe for the output

In [None]:
# Treemmer's RTL table is tab-separated with no header row; label the two
# columns here (set_axis raises if the column count is not exactly two).
treemmer_df_path = os.path.join(outdir, "treemmer_res_1_LD.txt")
treemmer_df = (
    pd.read_csv(treemmer_df_path, sep='\t', header=None)
    .set_axis(["rtl", "taxa"], axis="columns")
)
treemmer_df

---
# 2. Plot Taxa vs. RTL

In [None]:
def _find_rtl_target(df, target_rtl):
    """Locate the last row of the leading run of rows with ``rtl >= target_rtl``.

    Rows are scanned in table order and the scan stops at the first row whose
    RTL is not >= ``target_rtl`` (including NaN).

    Returns:
        (position, taxa): the positional row index (safe to use with ``.iloc``)
        and that row's taxa count. If the very first row already fails the
        threshold, ``(0, max(df["taxa"]))`` is returned.
    """
    target_pos = 0
    target_taxa = max(df["taxa"])
    # Walk the columns instead of iterrows(): iterrows() packs each row into a
    # single Series, upcasting the integer taxa count to float ("250.0").
    for pos, (rtl, taxa) in enumerate(zip(df["rtl"], df["taxa"])):
        if rtl >= target_rtl:
            target_pos, target_taxa = pos, taxa
        else:
            break
    return target_pos, target_taxa


# Identify the number of taxa at the RTL threshold
target_iloc, target_taxa = _find_rtl_target(treemmer_df, TARGET_RTL)
max_taxa = max(treemmer_df["taxa"])

# Setup the figure
fig, ax1 = plt.subplots(1, dpi=dpi, figsize=figsize)

# Plot Taxa vs. RTL
sns.scatterplot(data=treemmer_df,
                x="taxa",
                y="rtl",
                color="black",
                s=2,
                edgecolor="black",
                ax=ax1)

# Plot the candidate point (explicit ax rather than the implicit current axes)
sns.scatterplot(data=treemmer_df.iloc[[target_iloc]],
                x="taxa",
                y="rtl",
                ax=ax1)

ax1.set_xlabel("Number of Tips")
ax1.set_ylabel("Relative Tree Length (RTL)")

ax1.set_xticks(list(range(0, max_taxa, 50)))
ax1.set_yticks([t / 10 for t in range(0, 11)])

# Reversed x-limits: the full tree (most tips) is on the left
x_buffer = max_taxa * 0.01
ax1.set_xlim(max_taxa + x_buffer, 0 - x_buffer)

ax1.axhline(TARGET_RTL, linestyle="--")
ax1.axvline(target_taxa, linestyle="--")

# Annotation
ax1.annotate("Taxa: {}  RTL: {}".format(target_taxa, TARGET_RTL),
             xy=(target_taxa / 1.01, TARGET_RTL * 1.01),
             xycoords="data",
             va="bottom",
             ha="left",
             xytext=(target_taxa / 1.10, TARGET_RTL * 1.05),
             textcoords='data',
             arrowprops=dict(arrowstyle="->", connectionstyle="arc3"),
             bbox=dict(boxstyle="round", fc="w", lw=0.5))

# Export
out_path = os.path.join(outdir, SCRIPT_NAME + "_rtl." + FMT)
fig.savefig(out_path, dpi=dpi, bbox_inches="tight")

---
# 3. Run Treemmer (Target)

In [None]:
# Re-run treemmer, stopping once int(target_taxa) leaves remain.
# An argument list avoids shell quoting issues; check=True surfaces failures
# here instead of as a missing trimmed-tree file in the next cell.
subprocess.run(
    [
        "python3", treemmer_path, out_path_tree,
        "--stop_at_X_leaves", str(int(target_taxa)),
        "--verbose", "0",
    ],
    check=True,
)

---
# 4. Prune Tree, Dataframe, Alignment

In [None]:
tree = Phylo.read(tree_path, "newick")
tree.ladderize(reverse=False)

tree_df = pd.read_csv(tree_df_path, sep='\t')
# Fix the problem with multiple forms of NA in the table
# Consolidate missing data to the NO_DATA_CHAR
tree_df.fillna(NO_DATA_CHAR, inplace=True)
# NOTE(review): the setup cell indexes the same table by "Sample"; confirm
# which column holds the tip names used by tree_df.drop() below.
tree_df.set_index("Name", inplace=True)

# Trimmed tree from the targeted treemmer run. Kept in its own variable so
# treemmer_path (the Treemmer.py script) is not overwritten and the treemmer
# cells can still be re-run.
treemmer_tree_path = os.path.join(outdir, "treemmer_trimmed_tree_X_{}.nwk".format(int(target_taxa)))
treemmer_tree = Phylo.read(treemmer_tree_path, "newick")
treemmer_tree.ladderize(reverse=False)

align = AlignIO.read(aln_path, format="fasta")

# Tips kept by treemmer. Skip internal nodes that could have become tips
# ("NODE" names) and unnamed terminals, which have no sample to match.
treemmer_tips = [
    t.name for t in treemmer_tree.get_terminals()
    if t.name and "NODE" not in t.name
]
retained_tips = set(treemmer_tips)  # O(1) membership tests below

# -----------------------------
# Remove dropped tips from the tree and dataframe
for t in tree.get_terminals():
    if t.name not in retained_tips:
        tree.prune(t)
        tree_df.drop(t.name, inplace=True)

# Keep only the alignment records of retained tips (alignment order kept).
# Records are matched by their string id against the tip names.
treemmer_seq = [rec for rec in align if rec.id in retained_tips]
treemmer_align = Align.MultipleSeqAlignment(treemmer_seq)

missing_tips = retained_tips.difference(rec.id for rec in treemmer_seq)
if missing_tips:
    print("Warning: retained tips missing from the alignment:", sorted(missing_tips))

# -----------------------------
# Draw tree
Phylo.draw(tree,
        show_confidence=False,
        label_func = lambda x: '',
        do_show=False,
        )

---
# 5. Plot Tree Comparison

In [None]:
# Compare the full phylogeny with the treemmer-pruned one side by side.
# This cell was previously disabled (wrapped in a string) because it referenced
# an undefined `tree_dict`. It now uses the pruned `tree` and `treemmer_tips`
# from the pruning cell.

# Colour tips dropped by treemmer red on a copy, leaving tree_pre untouched
tree_pre_colored = copy.deepcopy(tree_pre)
retained_tips = set(treemmer_tips)
for c in tree_pre_colored.get_terminals():
    if c.name not in retained_tips:
        c.color = "red"

tree_post = tree

# Setup the figure
fig = plt.figure(dpi=dpi, figsize=figsize)

# One row, two panels: pre-filter (left) and post-filter (right)
gs = gridspec.GridSpec(1, 2, figure=fig, wspace=0.1)
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])

for ax, plot_tree, stage in ((ax1, tree_pre_colored, "Pre"), (ax2, tree_post, "Post")):
    Phylo.draw(plot_tree,
            show_confidence=False,
            label_func = lambda x: '',
            do_show=False,
            axes=ax)

    # Labels
    ax.set_xlabel("Branch Length")
    ax.set_yticks([])
    ax.set_ylabel('')
    num_tips = len(plot_tree.get_terminals())
    ax.set_title("Phylogeny {} Filter (N={})".format(stage, num_tips))

# Export
out_path_comparison = os.path.join(outdir, SCRIPT_NAME + "_comparison." + FMT)
fig.savefig(out_path_comparison, dpi=dpi, bbox_inches="tight")

---
# 6. Add Treemmer Info to Dataframe

---
# Export

## Save Tree, Dataframe, and Alignment

In [None]:
# Tree
out_path_xml = os.path.join(outdir, SCRIPT_NAME + ".xml")
out_path_nwk = os.path.join(outdir, SCRIPT_NAME + ".nwk")
out_path_nexus = os.path.join(outdir, SCRIPT_NAME + ".nexus")

# Same branch-length precision for both text tree formats
branch_len_fmt = '%1.{}f'.format(BRANCH_LEN_SIG_DIG)

Phylo.write(tree, out_path_xml, 'phyloxml')
Phylo.write(tree, out_path_nwk, 'newick', format_branch_length=branch_len_fmt)
Phylo.write(tree, out_path_nexus, 'nexus', format_branch_length=branch_len_fmt)

# Dataframe
out_path_df = os.path.join(outdir, SCRIPT_NAME + ".tsv")
tree_df.to_csv(out_path_df, sep="\t")

# Alignment (previously written to an undefined name `outfile`, raising NameError)
out_path_fasta = os.path.join(outdir, SCRIPT_NAME + ".fasta")
AlignIO.write(treemmer_align, out_path_fasta, "fasta")

## Augur JSON

In [None]:
# Columns whose values must be stringified before export
augur_type_convert = {
    "Branch_Number": (lambda value: str(value)),
}

# NOTE(review): augur_export receives the original, unpruned alignment
# (aln_path) rather than the pruned FASTA written in the export cell --
# confirm this is intended.
augur_dict = augur_export(
    tree=tree,
    tree_df=tree_df,
    tree_path=out_path_nwk,
    aln_path=aln_path,
    color_keyword_exclude=["color", "coord"],
    type_convert=augur_type_convert,
)

out_path_augur_json = os.path.join(outdir, SCRIPT_NAME + "_augur.json")
utils.write_json(
    data=augur_dict,
    file_name=out_path_augur_json,
    indent=JSON_INDENT,
)

## Auspice JSON

In [None]:
auspice_dict = auspice_export(
    tree=tree,
    augur_json_paths=[out_path_augur_json],
    auspice_config_path=auspice_config_path,
    auspice_colors_path=auspice_colors_path,
    auspice_latlons_path=auspice_latlons_path,
)

# Destinations: one JSON in outdir (local rendering) and one in the auspice
# directory (remote rendering), named with AUSPICE_PREFIX and dashes.
out_path_auspice_local_json = os.path.join(outdir, SCRIPT_NAME + "_auspice.json")
out_path_auspice_remote_json = os.path.join(
    auspice_remote_dir_path,
    AUSPICE_PREFIX + SCRIPT_NAME.replace("_", "-") + ".json",
)

# Write, validate and report each destination in turn (local first)
for json_path, success_msg in (
    (out_path_auspice_local_json, "Validation successful for local JSON."),
    (out_path_auspice_remote_json, "Validation successful for remote JSON."),
):
    utils.write_json(data=auspice_dict, file_name=json_path, indent=JSON_INDENT, include_version=False)
    export_v2.validate_data_json(json_path)
    print(success_msg)