---
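# Setup

Export the IQ-TREE maximum-likelihood phylogeny to augur and auspice JSON. The steps are:
- load the filtered metadata and tree;
- attach per-node attributes;
- write the colour and latitude/longitude mapping files;
- export and validate the JSON.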

## Modules

In [1]:
import os
import pandas as pd
import copy
from Bio import Phylo
from functions import *
import subprocess
import matplotlib.pyplot as plt
from matplotlib import colors

## Paths

In [2]:
# Pipeline wildcards: reads origin, locus, pruning mode, missing-data filter value.
# They select the iqtree/augur/auspice sub-directories used below.
WILDCARDS = ["all", "chromosome", "full", "30"]

# Project root. Set the PLAGUE_PROJECT_DIR environment variable to override the
# machine-specific default instead of editing this cell. os.path.join(..., "")
# guarantees a trailing separator, so results_dir can be safely concatenated
# with relative paths.
project_dir = os.path.join(
    os.environ.get(
        "PLAGUE_PROJECT_DIR",
        "/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/denmark/",
    ),
    "",
)
results_dir = project_dir

READS_ORIGIN, LOCUS_NAME, PRUNE, MISSING_DATA = WILDCARDS

In [3]:
# ------------------------------------------
# Inputs: IQ-TREE results for the selected wildcards
missing_data_dir = "filter{}".format(MISSING_DATA)
iqtree_dir = os.path.join(
    results_dir, "iqtree", READS_ORIGIN, LOCUS_NAME, PRUNE, missing_data_dir, "filter-taxa"
)
metadata_path = os.path.join(iqtree_dir, "metadata.tsv")
tree_path = os.path.join(iqtree_dir, "iqtree.treefile")
# os.path.join works whether or not results_dir ends with a separator
auspice_config_path = os.path.join(results_dir, "config", "auspice_config.json")

# ------------------------------------------
# Output directories (the trailing "" keeps a trailing separator, as before)
auspice_dir = os.path.join(
    results_dir, "auspice", READS_ORIGIN, LOCUS_NAME, PRUNE, missing_data_dir, "ml", ""
)
augur_dir = os.path.join(
    results_dir, "augur", READS_ORIGIN, LOCUS_NAME, PRUNE, missing_data_dir, "ml", ""
)

# Create output directories; unlike an unchecked `mkdir -p` subprocess,
# os.makedirs raises if the directory cannot be created.
for out_dir in (auspice_dir, augur_dir):
    os.makedirs(out_dir, exist_ok=True)

In [4]:
# Missing-value placeholder (fillna value), label given the grey colour in
# colors.tsv, and indentation width of the written JSON files.
NO_DATA_CHAR, UNKNOWN_CHAR, JSON_INDENT = "NA", "?", 2

## Metadata

In [5]:
# Load the tab-separated metadata. The first column ("sample" in the output
# below) becomes the index, and missing values become NO_DATA_CHAR.
raw_metadata = pd.read_csv(metadata_path, sep="\t")
metadata_df = raw_metadata.set_index(raw_metadata.columns[0]).fillna(NO_DATA_CHAR)

display(metadata_df)

Unnamed: 0_level_0,strain,date,date_bp,country,province,country_lat,country_lon,province_lat,province_lon,biovar,...,date_bp_mean,date_err,lat,lon,host_human,branch_major_color,geometry_size,geometry,root_rtt_dist,clade_rtt_dist
sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SAMEA5818830,STN021,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.21946e-05,1.21946e-05
SAMEA5818829,STN020,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.19382e-05,1.19382e-05
SAMEA5818828,STN019,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.20081e-05,1.20081e-05
SAMEA5818826,STN014,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.19003e-05,1.19003e-05
SAMEA5818825,STN013,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.24044e-05,1.24044e-05
SAMEA5818822,STN008,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.27904e-05,1.27904e-05
SAMEA5818821,STN007,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.21946e-05,1.21946e-05
SAMEA5818818,STN002,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.798562,8.231974,46.942756,8.411977,Second Pandemic,...,461.0,75.0,46.942756,8.411977,Human,#8000ff,8.0,POINT (8.4119773 46.942756),1.20579e-05,1.20579e-05
SAMEA5818817,STA001,[1420:1630],[-601:-391],Germany,Bavaria,51.08342,10.423447,48.946756,11.403872,Second Pandemic,...,496.0,105.0,48.946756,11.403872,Human,#8000ff,4.0,POINT (11.4038717 48.9467562),5.6598e-06,5.6598e-06
SAMEA5818815,NMS002,[1475:1536],[-546:-485],England,East of England,52.531021,-1.264906,52.219977,0.487578,Second Pandemic,...,515.5,30.5,52.219977,0.487578,Human,#8000ff,1.0,POINT (0.4875777469166293 52.2199774),9.3109e-06,9.3109e-06


## Phylogeny

### Import Tree

In [6]:
tree = Phylo.read(tree_path, format="newick")
tree.ladderize(reverse=False)

# Label every unnamed clade NODE0, NODE1, ... in find_clades() traversal order,
# so each node has a name usable as a dataframe index. Unnamed clades are
# presumably the internal nodes; tips carry sample names.
unnamed_clades = [clade for clade in tree.find_clades() if not clade.name]
for node_i, clade in enumerate(unnamed_clades):
    clade.name = "NODE{}".format(node_i)

### Add Tree Metadata to Dataframe

In [7]:
# Tree-derived columns, listed in the order they are written for each node
parameters = [
    "branch_length",
    "node_type",
    "branch_support",
    "branch_support_conf_category",
    "branch_support_conf_char",
    # Custom
    "country_date_strain",
    "province_date_strain",
]

# Start every new column at the no-data placeholder
for param in parameters:
    metadata_df[param] = [NO_DATA_CHAR] * len(metadata_df)

for clade in tree.find_clades():
    is_tip = clade.is_terminal()

    # A missing or zero support / branch length both collapse to 0
    support = float(clade.confidence) if clade.confidence else 0
    well_supported = support >= 95

    node_values = {
        "branch_length": clade.branch_length if clade.branch_length else 0,
        "node_type": "terminal" if is_tip else "internal",
        "branch_support": support,
        "branch_support_conf_category": "HIGH" if well_supported else "LOW",
        "branch_support_conf_char": "*" if well_supported else "",
        "country_date_strain": NO_DATA_CHAR,
        "province_date_strain": NO_DATA_CHAR,
    }

    if is_tip:
        tip_country = metadata_df.loc[clade.name, "country"]
        tip_province = metadata_df.loc[clade.name, "province"]
        tip_date = metadata_df.loc[clade.name, "date"]
        tip_strain = metadata_df.loc[clade.name, "strain"]
        node_values["country_date_strain"] = "{} {} {}".format(tip_country, tip_date, tip_strain)
        node_values["province_date_strain"] = "{} {} {}".format(tip_province, tip_date, tip_strain)

    # Cell-by-cell writes; for internal nodes, which are not yet rows, the
    # first write appends a new row to metadata_df.
    for param in parameters:
        metadata_df.at[clade.name, param] = node_values[param]

# Newly appended internal-node rows are empty in all metadata columns
metadata_df = metadata_df.fillna(NO_DATA_CHAR)
display(metadata_df)

Unnamed: 0_level_0,strain,date,date_bp,country,province,country_lat,country_lon,province_lat,province_lon,biovar,...,geometry,root_rtt_dist,clade_rtt_dist,branch_length,node_type,branch_support,branch_support_conf_category,branch_support_conf_char,country_date_strain,province_date_strain
sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SAMEA5818830,STN021,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.7986,8.23197,46.9428,8.41198,Second Pandemic,...,POINT (8.4119773 46.942756),1.21946e-05,1.21946e-05,2.330000e-08,terminal,0.0,LOW,,Switzerland [1485:1635] STN021,Nidwalden [1485:1635] STN021
SAMEA5818829,STN020,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.7986,8.23197,46.9428,8.41198,Second Pandemic,...,POINT (8.4119773 46.942756),1.19382e-05,1.19382e-05,2.330000e-08,terminal,0.0,LOW,,Switzerland [1485:1635] STN020,Nidwalden [1485:1635] STN020
SAMEA5818828,STN019,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.7986,8.23197,46.9428,8.41198,Second Pandemic,...,POINT (8.4119773 46.942756),1.20081e-05,1.20081e-05,2.330000e-08,terminal,0.0,LOW,,Switzerland [1485:1635] STN019,Nidwalden [1485:1635] STN019
SAMEA5818826,STN014,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.7986,8.23197,46.9428,8.41198,Second Pandemic,...,POINT (8.4119773 46.942756),1.19003e-05,1.19003e-05,2.330000e-08,terminal,0.0,LOW,,Switzerland [1485:1635] STN014,Nidwalden [1485:1635] STN014
SAMEA5818825,STN013,[1485:1635],[-536:-386],Switzerland,Nidwalden,46.7986,8.23197,46.9428,8.41198,Second Pandemic,...,POINT (8.4119773 46.942756),1.24044e-05,1.24044e-05,2.331000e-07,terminal,0.0,LOW,,Switzerland [1485:1635] STN013,Nidwalden [1485:1635] STN013
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
NODE43,,,,,,,,,,,...,,,,5.391000e-07,internal,95.0,HIGH,*,,
NODE44,,,,,,,,,,,...,,,,9.323000e-07,internal,99.0,HIGH,*,,
NODE45,,,,,,,,,,,...,,,,4.662000e-07,internal,97.0,HIGH,*,,
NODE46,,,,,,,,,,,...,,,,2.330000e-08,internal,93.0,LOW,,,


### Colors

In [8]:
# Build a categorical colour map (tip value -> hex colour) for each attribute
# and write it to colors.tsv as tab-separated "attribute, value, colour" rows.
out_path_colors = os.path.join(augur_dir, "colors.tsv")
states = ["country", "province", "branch_major"]
colors_dict = {}

for attr in states:
    attr_key = attr.lower()

    # Distinct tip values in first-seen order (dict used as an ordered set)
    attr_values = {}
    for t in tree.get_terminals():
        attr_val = metadata_df.loc[t.name, attr]
        # Missing data gets no colour. Check this before any suffix stripping,
        # otherwise "NA" would be stripped to "" and then indexed (IndexError).
        if attr_val == NO_DATA_CHAR:
            continue
        if attr == "branch_minor":
            # Remove the letter suffix from branch_minor. Guard the loop so an
            # all-letter value cannot be consumed entirely; keep it unchanged.
            stripped = attr_val
            while stripped and stripped[-1].isalpha():
                stripped = stripped[:-1]
            attr_val = stripped or attr_val
        attr_values[attr_val] = None

    # One colour per distinct value, taken from a len(values)-entry resampling
    # of the "rainbow" colormap; skipped when there are no values.
    attr_hex = []
    if attr_values:
        cmap = plt.get_cmap("rainbow", len(attr_values))
        attr_hex = [colors.to_hex(cmap(i)) for i in range(cmap.N)]
    colors_dict[attr_key] = dict(zip(attr_values, attr_hex))

    # Add unknown (grey)
    colors_dict[attr_key][UNKNOWN_CHAR] = "#969696"

print(colors_dict)

# Explicit UTF-8: values include non-ASCII place names (e.g. "Baden-Württemberg")
with open(out_path_colors, "w", encoding="utf-8") as outfile:
    for attr_key, value_colors in colors_dict.items():
        for attr_val, hex_col in value_colors.items():
            outfile.write(str(attr_key) + "\t" + str(attr_val) + "\t" + str(hex_col) + "\n")

{'country': {'Russia': '#8000ff', 'Denmark': '#5148fc', 'England': '#238af5', 'Norway': '#0cc1e8', 'Spain': '#3ae8d7', 'France': '#68fcc1', 'The Netherlands': '#97fca7', 'Germany': '#c5e88a', 'Italy': '#f3c16a', 'Switzerland': '#ff8a48', 'Lithuania': '#ff4824', 'Poland': '#ff0000', '?': '#969696'}, 'province': {'Tatarstan': '#8000ff', 'Region of Southern Denmark': '#632cfe', 'Greater London': '#4757fb', 'Oslo': '#2b7ff6', 'Catalonia': '#0ea4f0', 'Occitanie': '#0ec3e7', 'North Brabant': '#2adddd', 'Bavaria': '#47f0d1', 'Lazio': '#63fbc3', 'Central Denmark Region': '#80ffb4', 'East of England': '#9cfba4', 'Brandenburg': '#b8f092', 'Nidwalden': '#d4dd80', 'Baden-Württemberg': '#f1c36c', 'Vilnius County': '#ffa457', 'Pomeranian Voivodeship': '#ff8042', 'Rostov Oblast': '#ff572c', 'Chechnya': '#ff2c16', "Provence-Alpes-Côte d'Azur": '#ff0000', '?': '#969696'}, 'branch_major': {'1.PRE': '#8000ff', '?': '#969696'}}


### Latitude and Longitude

In [9]:
# Map each geographic name to its coordinates and the number of rows located
# there. Internal nodes are skipped. For provinces, a terminal with no province
# falls back to its country's name and coordinates, except for Russian samples.
# NOTE(review): the reason for the Russia exception is not recorded here; confirm.
attr_list = ["country", "province"]
geo_records = {attr: {} for attr in attr_list}

# Single pass over the metadata; counts are accumulated in plain dicts rather
# than by incrementing dataframe cells (the previous `df["size"][name] += 1`
# was chained assignment, which does not write back under pandas Copy-on-Write).
for _, row in metadata_df.iterrows():
    node_type = row["node_type"]
    if node_type == "internal":
        continue
    country = row["country"]

    for attr in attr_list:
        name = row[attr]
        if attr == "province" and name == NO_DATA_CHAR and node_type == "terminal" and country != "Russia":
            # Use country instead
            name = country
            lat, lon = row["country_lat"], row["country_lon"]
        else:
            lat, lon = row[attr + "_lat"], row[attr + "_lon"]

        if name in geo_records[attr]:
            geo_records[attr][name]["size"] += 1
        else:
            geo_records[attr][name] = {"lat": lat, "lon": lon, "size": 1}

latlon_country_df = pd.DataFrame.from_dict(
    geo_records["country"], orient="index", columns=["lat", "lon", "size"]
)
latlon_province_df = pd.DataFrame.from_dict(
    geo_records["province"], orient="index", columns=["lat", "lon", "size"]
)

# Mapping file for auspice: all countries first, then all provinces.
# Explicit UTF-8 because place names include non-ASCII characters.
out_path_latlon = os.path.join(augur_dir, "latlon.tsv")
with open(out_path_latlon, "w", encoding="utf-8") as outfile:
    for geo_res, geo_df in (("country", latlon_country_df), ("province", latlon_province_df)):
        for name in geo_df.index:
            lat = str(geo_df.at[name, "lat"])
            lon = str(geo_df.at[name, "lon"])
            outfile.write(geo_res + "\t" + name + "\t" + lat + "\t" + lon + "\n")

---
# Export

## Create Sub Dataframe

In [10]:
# Remember, order matters when dealing with confidence!
columns = [
    # Required
    "branch_length",
    "node_type",
    # Geo
    "country",
    "province",
    # Colors and filters
    "branch_support",
    "branch_support_conf_category",
    "branch_support_conf_char",
    "continent",
    # Text description
    "biosample_accession",
    "strain",
    "country_date_strain",
    "province_date_strain",
    "host_human",
    # Tip dates
    "date_mean",
    "date_err",
    "date_bp_mean",
    # Stats
    "root_rtt_dist",
    "clade_rtt_dist",
]

# Independent copy of the export columns, plus the node name and a
# single-space "blank" column appended at the end.
auspice_df = metadata_df[columns].copy().assign(
    node_name=lambda df: list(df.index),
    blank=" ",
)

display(auspice_df)

Unnamed: 0_level_0,branch_length,node_type,country,province,branch_support,branch_support_conf_category,branch_support_conf_char,continent,biosample_accession,strain,country_date_strain,province_date_strain,host_human,date_mean,date_err,date_bp_mean,root_rtt_dist,clade_rtt_dist,node_name,blank
sample,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
SAMEA5818830,2.330000e-08,terminal,Switzerland,Nidwalden,0.0,LOW,,Europe,SAMEA5818830,STN021,Switzerland [1485:1635] STN021,Nidwalden [1485:1635] STN021,Human,1560,75,461,1.21946e-05,1.21946e-05,SAMEA5818830,
SAMEA5818829,2.330000e-08,terminal,Switzerland,Nidwalden,0.0,LOW,,Europe,SAMEA5818829,STN020,Switzerland [1485:1635] STN020,Nidwalden [1485:1635] STN020,Human,1560,75,461,1.19382e-05,1.19382e-05,SAMEA5818829,
SAMEA5818828,2.330000e-08,terminal,Switzerland,Nidwalden,0.0,LOW,,Europe,SAMEA5818828,STN019,Switzerland [1485:1635] STN019,Nidwalden [1485:1635] STN019,Human,1560,75,461,1.20081e-05,1.20081e-05,SAMEA5818828,
SAMEA5818826,2.330000e-08,terminal,Switzerland,Nidwalden,0.0,LOW,,Europe,SAMEA5818826,STN014,Switzerland [1485:1635] STN014,Nidwalden [1485:1635] STN014,Human,1560,75,461,1.19003e-05,1.19003e-05,SAMEA5818826,
SAMEA5818825,2.331000e-07,terminal,Switzerland,Nidwalden,0.0,LOW,,Europe,SAMEA5818825,STN013,Switzerland [1485:1635] STN013,Nidwalden [1485:1635] STN013,Human,1560,75,461,1.24044e-05,1.24044e-05,SAMEA5818825,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
NODE43,5.391000e-07,internal,,,95.0,HIGH,*,,,,,,,,,,,,NODE43,
NODE44,9.323000e-07,internal,,,99.0,HIGH,*,,,,,,,,,,,,NODE44,
NODE45,4.662000e-07,internal,,,97.0,HIGH,*,,,,,,,,,,,,NODE45,
NODE46,2.330000e-08,internal,,,93.0,LOW,,,,,,,,,,,,,NODE46,


## Augur JSON

In [11]:
# Convert the tree plus auspice_df attributes into an augur node-data dict
# ("geometry" columns are excluded from colourings).
augur_dict = augur_export(
    tree_path=None,
    aln_path=None,
    tree=tree,
    tree_df=auspice_df,
    color_keyword_exclude=["geometry"],
    type_convert={"branch_number": lambda value: str(value)},
)

# Sanity check: show the attributes of the first exported node
node_names = list(augur_dict["nodes"])
first_node = node_names[0]
print(augur_dict["nodes"][first_node])

out_path_augur_json = os.path.join(augur_dir, "all.json")
utils.write_json(data=augur_dict, file_name=out_path_augur_json, indent=JSON_INDENT)

{'branch_length': 0.0, 'node_type': 'internal', 'country': 'NA', 'province': 'NA', 'branch_support': 100.0, 'branch_support_conf_category': 'HIGH', 'branch_support_conf_char': '*', 'continent': 'NA', 'biosample_accession': 'NA', 'strain': 'NA', 'country_date_strain': 'NA', 'province_date_strain': 'NA', 'host_human': 'NA', 'date_mean': 'NA', 'date_err': 'NA', 'date_bp_mean': 'NA', 'root_rtt_dist': 'NA', 'clade_rtt_dist': 'NA', 'node_name': 'NODE0', 'blank': ' '}


## Auspice JSON

In [13]:
# Build the auspice JSON from the augur node-data JSON plus the config,
# colour and lat/lon mapping files written above.
auspice_dict = auspice_export(
    tree=tree,
    augur_json_paths=[out_path_augur_json],
    auspice_config_path=auspice_config_path,
    auspice_colors_path=out_path_colors,
    auspice_latlons_path=out_path_latlon,
    auspice_geo_res=["country", "province"],
)

# Attach every auspice_df column as a node attribute
label_col = list(auspice_df.columns)

# Recursively add branch attrs
branch_attributes(
    tree_dict=auspice_dict["tree"],
    sub_dict=auspice_dict["tree"],
    df=auspice_df,
    label_col=label_col,
)

# Last manual changes: fixed colour scales for selected colorings, matched on
# each coloring's "key". The colorings are modified in place. Only the "scale"
# entry is replaced, so no deep copy of the whole dict (tree included) is needed.
for coloring in auspice_dict["meta"]["colorings"]:
    coloring_key = str(coloring.get("key", ""))
    # Node type as internal or terminal
    if coloring_key == "node_type":
        coloring["scale"] = [["internal", "#FFFFFF"], ["terminal", "#000000"]]
    # Confidence category
    if "conf_category" in coloring_key:
        coloring["scale"] = [["LOW", "#FFFFFF"], ["HIGH", "#000000"]]
    # Host Human binary
    if "host_human" in coloring_key:
        coloring["scale"] = [["Human", "#CBB742"], ["Non-Human", "#60B6F2"], ["NA", "#D6D6D6"]]

# Write outputs - For Local Rendering
out_path_auspice_local_json = os.path.join(auspice_dir, "all.json")
utils.write_json(data=auspice_dict, file_name=out_path_auspice_local_json, indent=JSON_INDENT, include_version=False)
# NOTE(review): the success message below assumes validate_data_json raises
# (or exits) on failure; confirm against the installed augur version.
export_v2.validate_data_json(out_path_auspice_local_json)
print("Validation successful for local JSON.\n")

# TODO: a "remote" export (file named with an AUSPICE_PREFIX) is not implemented;
# the earlier commented-out draft wrote to the local path and used an undefined prefix.

DEPRECATED: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.





























Validating schema of '/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/denmark/config/auspice_config.json'...
Validation success.
Validating produced JSON
Validating schema of '/mnt/c/Users/ktmea/Projects/plague-phylogeography-projects/denmark/auspice/all/chromosome/full/filter30/ml/all.json'...
Validating that the JSON is internally consistent...
Validation successful for local JSON.



