In [None]:
import pandas as pd
import json
import os
import matplotlib as mpl
from matplotlib import pyplot as plt
import requests
from io import StringIO as sio
from matplotlib.patches import Patch
import matplotlib.ticker as ticker
import itertools
import re
import sys
import pprint
import statistics

# Make the project's helper modules (balticEdited, reassortment_rates_host) in
# ./scripts/ importable. The relative path is resolved against the kernel's
# current working directory.
module_dir = "./scripts/"
# Guard so that re-running this cell does not keep appending duplicate entries.
if module_dir not in sys.path:
    sys.path.append(module_dir)

import balticEdited as bt
import reassortment_rates_host as rea

# Workflow: first use treesort_prepper to prepare the input files and run TreeSort.

# Then run before_plotting on the TreeSort output trees so they are ready to be plotted with baltic.


In [None]:
def load_tree(filename):
    """Read a tree from a JSON file with ``bt.loadJSON`` and return only the tree.

    ``bt.loadJSON`` yields a two-element result; the first element (metadata)
    is discarded and the second (the tree object) is returned.
    """
    _meta, tree = bt.loadJSON(filename)
    return tree

# Colour for each region. The hex values, in this order, match ColorBrewer's
# 10-class "Spectral" scheme.
region_colors = dict([
    ('Europe', '#9e0142'),
    ('Japan Korea', '#d53e4f'),
    ('Southeast Asia', '#f46d43'),
    ('West Asia', '#fdae61'),
    ('South America', '#fee08b'),
    ('South Asia', '#e6f598'),
    ('Africa', '#abdda4'),
    ('China', '#66c2a5'),
    ('Oceania', '#3288bd'),
    ('North America', '#5e4fa2'),
])

# Colour for each host. The hex values, in this order, match ColorBrewer's
# 9-class "RdYlBu" scheme. Insertion order is also the order in which hosts
# appear in the legend drawn by plot_host.
# NOTE(review): 'Feline' (#ffffbf) and 'Seal' (#e0f3f8) are very pale and may be
# hard to see on a white background.
host_colors = dict([
    ('Avian', '#d73027'),
    ('Equine', '#f46d43'),
    ('Mink', '#fdae61'),
    ('Canine', '#fee090'),
    ('Feline', '#ffffbf'),
    ('Seal', '#e0f3f8'),
    ('Human', '#abd9e9'),
    ('Camel', '#74add1'),
    ('Swine', '#4575b4'),
])

In [None]:
# Load the tree that will be plotted. `filename` and `mytree` are kept as
# module-level names because other cells may refer to them.
filename = 'h3nx_ha.json'
mytree = load_tree(filename=filename)

In [None]:
def plot_host(mytree, output_path, fig_name, x_label="Divergence"):
    """Draw ``mytree`` coloured by host, save the figure, then display it.

    Branches and nodes are coloured from the module-level ``host_colors``
    mapping. Objects whose ``host`` trait is ``'ancestor'`` are drawn in red.
    A legend with one patch per ``host_colors`` entry is placed lower-left.

    Parameters
    ----------
    mytree : tree object
        A tree as returned by ``load_tree``. Every plotted object must expose
        ``absoluteTime`` and a ``traits`` dict containing ``'host'``; a
        missing ``'host'`` key raises ``KeyError``.
    output_path : str
        Directory the figure is written to. It is created if it does not
        exist. An empty string means the current working directory.
    fig_name : str
        File name of the saved figure. Its extension selects the format.
    x_label : str, optional
        Text for the x-axis label. Defaults to ``"Divergence"``.

    Notes
    -----
    * Sets ``plt.rcParams["font.family"]`` to Arial globally; the setting
      persists after this call.
    * NOTE(review): x positions come from ``absoluteTime``. If the tree is
      time-scaled, "Divergence" is likely the wrong label; pass e.g.
      ``x_label="Year"`` to confirm/override.
    * NOTE(review): the red ``'ancestor'`` colour has no legend entry, and
      hosts absent from ``host_colors`` get a colour of ``None``.
    """
    plt.rcParams["font.family"] = "Arial"

    fig, ax = plt.subplots(figsize=(15, 15))

    x_attr = lambda k: k.absoluteTime

    # 'ancestor' is not a key of host_colors, so it is given an explicit colour.
    color_by = lambda k: 'red' if k.traits['host'] == 'ancestor' else host_colors.get(k.traits['host'])

    mytree.plotTree(ax, x_attr=x_attr, colour=color_by, width=3)

    mytree.plotPoints(ax,
                      x_attr=x_attr,
                      size=100,
                      colour=color_by,
                      outline_colour='#3f3f3f',
                      zorder=3,
                      marker='o',
                      edgecolor='#3f3f3f')

    legend_handles = [Patch(color=color, label=host) for host, color in host_colors.items()]
    legend = ax.legend(handles=legend_handles, title="$\\bf{Hosts}$", loc="lower left", fontsize=25)
    plt.setp(legend.get_title(), fontsize=25)

    # Whole-number tick positions only.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

    # The y position of a tip carries no meaning, so hide the y axis and
    # every spine except the bottom one.
    ax.set_yticks([])
    ax.set_yticklabels([])
    for side, spine in ax.spines.items():
        if side != 'bottom':
            spine.set_visible(False)

    ax.tick_params(axis='x', labelsize=25, size=15, width=2, color='grey')
    ax.set_xlabel(x_label, fontsize=25, fontweight="bold")
    fig.tight_layout()

    # exist_ok avoids the check-then-create race. An empty path means the cwd,
    # which always exists; os.makedirs('') would raise.
    if output_path:
        os.makedirs(output_path, exist_ok=True)

    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(os.path.join(output_path, fig_name))
    plt.show()