# Objectives

1. Run clock analysis.
1. Add clock stats to dataframe.

Troubleshooting:
  - Disappearing nodes?

---
# Setup

## Import Modules

In [1]:
import os # Create directories and files
import random # Set seed for stats
import copy # copy objects to prevent permanent modification

# Logging output to file
import sys
import io

# Phylogenetics
from Bio import Phylo # Tree operations
from Bio import AlignIO # Add constant sites to alignment
import treetime # Timetree operations

# JSON
import json

## Input File Paths

In [2]:
# Input files read by this notebook and the directory that receives its
# outputs. All paths are relative to the notebook's working directory.
tree_path = "../../docs/results/latest/parse_tree/parse_tree.nwk"
tree_df_path = "../../docs/results/latest/mugration/mugration.tsv"
aln_path = "../../docs/results/latest/snippy_multi/snippy-core_chromosome.snps.filter5.aln"
constant_sites_path = "../../docs/results/latest/snippy_multi/snippy-core_chromosome.full.constant_sites.txt"
outdir = "../../docs/results/latest/timetree/"

# Create the output directory if it doesn't exist.
# makedirs also creates missing parent directories, and exist_ok=True makes
# the call a no-op on re-runs (no separate existence check needed).
os.makedirs(outdir, exist_ok=True)

## Variables

In [3]:
# Shared project settings. `np` and `pd` are not imported anywhere else in
# this notebook, so they are presumably provided by this star import, along
# with the upper-case constants used below (NO_DATA_CHAR, DATE_COL, ...).
from config import *

# Notebook-specific constants
SCRIPT_NAME = "timetree"
NAME_COL = "Name"

# Fix the seeds of both random number generators for reproducibility.
np.random.seed(70262122)
random.seed(1152342, version=2)

# Snapshot of the numpy generator state taken right after seeding
st0 = np.random.get_state()


In [4]:
# Read the FASTA alignment stored at aln_path
align = AlignIO.read(aln_path, "fasta")

In [5]:
# DISABLED: the code below is wrapped in a triple-quoted string so that it is
# not executed. If enabled, it would read four comma-separated counts from
# constant_sites_path (stored as A, C, G, T) and append that many copies of
# each nucleotide to every sequence in `align`.
# Because the string is the only expression in this cell, the notebook echoes
# it back as the cell output.
"""constant_sites_dict = {"A": 0, "C" : 0, "G" : 0, "T" : 0}

with open(constant_sites_path, "r") as infile:
    constant_sites_list = infile.read().strip().split(",")
    constant_sites_dict["A"] = int(constant_sites_list[0])
    constant_sites_dict["C"] = int(constant_sites_list[1])
    constant_sites_dict["G"] = int(constant_sites_list[2])
    constant_sites_dict["T"] = int(constant_sites_list[3])    

print(constant_sites_dict)
total_constant_sites = sum(constant_sites_dict.values())
print("Constant Sites:", total_constant_sites)

# Add the constant sites to each sample
# Iterate through each samples sequence
for rec in align:
    # Iterate through each nucleotide for constant sites
    for nucleotide,count in constant_sites_dict.items():
        rec.seq = rec.seq + (nucleotide * count)"""

'constant_sites_dict = {"A": 0, "C" : 0, "G" : 0, "T" : 0}\n\nwith open(constant_sites_path, "r") as infile:\n    constant_sites_list = infile.read().strip().split(",")\n    constant_sites_dict["A"] = int(constant_sites_list[0])\n    constant_sites_dict["C"] = int(constant_sites_list[1])\n    constant_sites_dict["G"] = int(constant_sites_list[2])\n    constant_sites_dict["T"] = int(constant_sites_list[3])    \n\nprint(constant_sites_dict)\ntotal_constant_sites = sum(constant_sites_dict.values())\nprint("Constant Sites:", total_constant_sites)\n\n# Add the constant sites to each sample\n# Iterate through each samples sequence\nfor rec in align:\n    # Iterate through each nucleotide for constant sites\n    for nucleotide,count in constant_sites_dict.items():\n        rec.seq = rec.seq + (nucleotide * count)'

## Import Divergence Tree

In [None]:
# Parse the Newick divergence tree, then sort its clades in place using
# ladderize's default ordering (reverse=False).
tree_div = Phylo.read(tree_path, format="newick")
tree_div.ladderize()

## Import Dataframe

In [None]:
# Load the tab-separated node table, consolidate every form of missing data
# into NO_DATA_CHAR, and index the rows by the node name column.
tree_df = (
    pd.read_csv(tree_df_path, sep="\t")
    .fillna(NO_DATA_CHAR)
    .set_index(NAME_COL)
)

## Remove Bad Branches

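These bad branches were identified in the next step (initializing the timetree object) and are pruned here before that step is run.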

In [None]:
# HOT FIX for bad branches
# NOTE(review): tmp_tree is not referenced again in this notebook.
tmp_tree = os.path.join(outdir, "temp.nwk")

bad_samples = [
    "GCA_008630375.1_ASM863037v1_genomic",
    "GCA_008630375.2_ASM863037v2_genomic",
    "GCA_003086075.1_ASM308607v1_genomic",
    "GCA_001613865.1_ASM161386v1_genomic",
]

# Prune bad samples from tree and dataframe.
# Both operations are guarded so that the cell can be re-run on an already
# pruned tree/dataframe: previously a second run raised ValueError (prune)
# and KeyError (drop) because the samples were no longer present.
# NOTE(review): Bio.Phylo's prune collapses the parent of a pruned tip when it
# came from a bifurcation, so that internal node's row may remain in tree_df
# without a matching clade -- possibly related to the "Disappearing nodes"
# troubleshooting item; confirm.
for sample in bad_samples:
    if tree_div.find_any(sample) is not None:
        tree_div.prune(sample)
    else:
        print("Sample not found in tree, skipping prune:", sample)
    tree_df.drop(index=sample, inplace=True, errors="ignore")

# Save temp files (read back by the timetree initialization below)
tmp_path_df = os.path.join(outdir, SCRIPT_NAME + ".tsv")
tree_df.to_csv(tmp_path_df, sep="\t")

tmp_path_nwk = os.path.join(outdir, SCRIPT_NAME + ".nwk")
Phylo.write(tree_div, tmp_path_nwk, 'newick', format_branch_length='%1.{}f'.format(BRANCH_LEN_SIG_DIG))

## Initialize Timetree Object

In [None]:
# Parse the sampling dates out of the metadata table with treetime's helper
dates_raw = treetime.utils.parse_dates(
    tmp_path_df,
    date_col=DATE_COL,
    name_col=NAME_COL,
)

# Drop nan entries (internal nodes); list-valued dates are always kept and
# are never passed to pd.isnull.
dates = {
    name: date
    for name, date in dates_raw.items()
    if type(date) == list or not pd.isnull(date)
}

# Build the treetime object from the pruned tree, the alignment and the dates.
# Remember, including the alignment is crucial!
tt = treetime.TreeTime(
    tree=tmp_path_nwk,
    aln=aln_path,
    dates=dates,
    seq_len=REF_LEN,
    fill_overhangs=False,
    verbose=4,
)

# Remove outliers
tt.clock_filter(reroot=None, n_iqd=N_IQD, plot=True)

# Check rtt
print(vars(tt.date2dist))

---
# 1. Clock Analysis

In [None]:
# Optional stdout capture (disabled). Enable together with the matching block
# at the end of this cell to write the run's console output to a log file.
# old_stdout = sys.stdout
# new_stdout = io.StringIO()
# sys.stdout = new_stdout

# Alternative invocations kept for reference:
# PARAM MIN: root='-4101-09-02'
# tt.run()
# tt.run(time_marginal=TIME_MARGINAL)

# PARAM FULL: root=''
run_options = dict(
    Tc="skyline",
    max_iter=MAX_ITER,
    relaxed_clock={"slack": 5.0, "coupling": 0},
    infer_gtr=True,
    time_marginal=TIME_MARGINAL,
    sequence_marginal=SEQ_MARGINAL,
    verbose=4,
    resolve_polytomies=False,
    n_iqd=N_IQD,
    # Further options that are currently switched off:
    # branch_length_mode="input",
    # root=None,
    # use_covariation=False,
    # vary_rate=False,
)
tt.run(**run_options)

# Save stdout to file (disabled)
# output = new_stdout.getvalue()
# out_path = os.path.join(outdir, SCRIPT_NAME + ".log")
# with open(out_path, "w") as file:
#     file.write(output)
# # Restore stdout
# sys.stdout = old_stdout
# print("Standard output restored.")

In [None]:
# Quick check 1: display the clade returned by common_ancestor for the single
# target "NODE0" in the time tree.
tt.tree.common_ancestor("NODE0")

In [None]:
# Quick check 2: display the clock model; its "slope" entry is used below as
# the mean rate and several entries are exported to JSON at the end.
tt.clock_model

## Ladderize Tree

In [None]:
# Sort the clades of the time tree in place (default ordering, reverse=False)
tt.tree.ladderize()

---
# 2. Add clock stats to data frame

- Rates
- Dates
- RTT Regression

## Rates

In [None]:
# Start every rate column filled with the missing-data marker
for rate_col in (
    "timetree_rate",
    "timetree_rate_fold_change",
    "timetree_mutation_length",
):
    tree_df[rate_col] = [NO_DATA_CHAR] * len(tree_df)

# The mean rate is the slope
mean_rate = tt.clock_model["slope"]

for clade in tt.tree.find_clades():
    tree_df.at[clade.name, "timetree_mutation_length"] = clade.mutation_length

    interpolator = getattr(clade, "branch_length_interpolator", None)
    if interpolator:
        # Relaxed clock: scale the mean rate by the branch-specific gamma
        fold_change = interpolator.gamma
        rate = mean_rate * fold_change
    else:
        # Strict clock: the branch evolves at exactly the mean rate
        fold_change = 1
        rate = mean_rate

    tree_df.at[clade.name, "timetree_rate_fold_change"] = fold_change
    tree_df.at[clade.name, "timetree_rate"] = rate

## Dates

In [None]:
# Create the new columns, filled with the missing-data marker:
#   timetree_date / timetree_numdate     - node date and numeric date
#   timetree_numdate_lower / _upper      - optional confidence interval,
#                                          only filled if marginal prob was run
#   timetree_clock_length                - branch length; the same as
#                                          branch_length until running
#                                          branch_length_to_years()
for new_col in (
    "timetree_date",
    "timetree_numdate",
    "timetree_numdate_lower",
    "timetree_numdate_upper",
    "timetree_clock_length",
):
    tree_df[new_col] = [NO_DATA_CHAR] * len(tree_df)

# Make a copy to change branch_length, so that `tt` itself is left unmodified
tt_copy = copy.deepcopy(tt)
tt_copy.branch_length_to_years()

for c in tt_copy.tree.find_clades():
    # Marginal Probability
    if hasattr(c, "marginal_inverse_cdf"):
        # Retrieve the region containing the confidence interval.
        # The clade belongs to tt_copy, so the interval is requested from
        # tt_copy as well (previously the clade was handed to `tt`).
        conf = tt_copy.get_max_posterior_region(c, fraction=CONFIDENCE)

        # Set as lower and upper bounds on date
        tree_df.at[c.name, "timetree_numdate_lower"] = conf[0]
        tree_df.at[c.name, "timetree_numdate_upper"] = conf[1]

    tree_df.at[c.name, "timetree_date"] = c.date
    tree_df.at[c.name, "timetree_numdate"] = c.numdate
    tree_df.at[c.name, "timetree_clock_length"] = c.branch_length

In [None]:
# Display the dataframe to inspect the timetree_* columns added above
tree_df

## Regression

In [None]:
# Make a copy of the tree so that converting branch lengths to years does not
# modify `tt`
tt_copy = copy.deepcopy(tt)
tt_copy.branch_length_to_years()

# New columns, filled with the missing-data marker:
#   timetree_coord_x / _y  - node coordinates for plotting the tree
#   timetree_reg_x / _y    - node coordinates for plotting the regression
#   timetree_reg_bad       - whether the node is flagged as a bad branch
for new_col in (
    "timetree_coord_x",
    "timetree_coord_y",
    "timetree_reg_x",
    "timetree_reg_y",
    "timetree_reg_bad",
):
    tree_df[new_col] = [NO_DATA_CHAR] * len(tree_df)

x_posns = get_x_positions(tt_copy.tree)
y_posns = get_y_positions(tt_copy.tree)
tt_reg = tt_copy.setup_TreeRegression()

# Index the coordinates by clade name once, instead of scanning both position
# dicts for every clade. setdefault keeps the first value seen for a name,
# matching the previous "first match" behaviour. A name that is missing from
# the positions now raises KeyError (previously IndexError).
x_by_name = {}
for node, x_value in x_posns.items():
    x_by_name.setdefault(node.name, x_value)
y_by_name = {}
for node, y_value in y_posns.items():
    y_by_name.setdefault(node.name, y_value)

# Add x and y coordinates
for c in tt_copy.tree.find_clades():

    # Tree Node Coordinates
    tree_df.at[c.name, 'timetree_coord_x'] = x_by_name[c.name]
    tree_df.at[c.name, 'timetree_coord_y'] = y_by_name[c.name]

    # Regression Node Coordinates
    # NOTE(review): `_v` is a private attribute that is not assigned in this
    # notebook; presumably it is set on the nodes by an earlier TreeTime step
    # and carried over by the deep copy -- confirm it is up to date here.
    reg_y = c._v
    # Tips take their x value from the regression object, internal nodes use
    # their numeric date
    reg_x = tt_reg.tip_value(c) if c.is_terminal() else c.numdate
    reg_bad = getattr(c, 'bad_branch', False)
    tree_df.at[c.name, 'timetree_reg_x'] = reg_x
    tree_df.at[c.name, 'timetree_reg_y'] = reg_y
    tree_df.at[c.name, 'timetree_reg_bad'] = reg_bad

# Fix up new values that could be none
tree_df.fillna(NO_DATA_CHAR, inplace=True)
tree_df

---
# Export

## Dataframe

In [None]:
# Format pattern for branch lengths, shared by the Newick and Nexus exports
branch_len_fmt = '%1.{}f'.format(BRANCH_LEN_SIG_DIG)

# Output locations: one table plus the tree in three formats
out_path_df = os.path.join(outdir, SCRIPT_NAME + ".tsv")
out_path_xml = os.path.join(outdir, SCRIPT_NAME + ".xml")
out_path_nwk = os.path.join(outdir, SCRIPT_NAME + ".nwk")
out_path_nexus = os.path.join(outdir, SCRIPT_NAME + ".nexus")

# Save tree dataframe with clock info
tree_df.to_csv(out_path_df, sep="\t")

# Save timetree trees (phyloxml is written without a branch length format)
Phylo.write(tt.tree, out_path_xml, 'phyloxml')
Phylo.write(tt.tree, out_path_nwk, 'newick', format_branch_length=branch_len_fmt)
Phylo.write(tt.tree, out_path_nexus, 'nexus', format_branch_length=branch_len_fmt)

## JSON

In [None]:
# Pick the clock model statistics that get exported
keys = ["slope", "intercept", "chisq", "r_val"]
clock_model_dict = {stat: tt.clock_model[stat] for stat in keys}

# Write them as indented JSON with sorted keys
out_path_json = os.path.join(outdir, SCRIPT_NAME + "_clock_model.json")
with open(out_path_json, "w") as outfile:
    outfile.write(json.dumps(clock_model_dict, indent=4, sort_keys=True))