In [None]:
import baltic as bt
import pandas as pd
import arviz as az

from datetime import datetime as dt
from datetime import timedelta
import time
from io import StringIO
import altair as alt
from zipfile import ZipFile
import math
import re
import random


import sys, subprocess, glob, os, shutil, re, importlib
from subprocess import call
import imp
from scipy.stats import gaussian_kde
import geopandas

%matplotlib inline
import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib.patheffects as path_effects
import matplotlib.lines as mlines
from matplotlib.font_manager import FontProperties
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as clr
from matplotlib import rc
import textwrap as textwrap
from textwrap import wrap

import numpy as np
from scipy.special import binom

from altair import datum
import arviz as az
from scipy.stats import gaussian_kde

# Lift Altair's max-rows check so large DataFrames can be passed to Altair charts.
alt.data_transformers.disable_max_rows()


In [None]:
# Convert decimal-year dates (e.g. 2024.5) back to calendar date strings.
def convert_partial_year(number):
    """Convert a decimal year to a 'YYYY-MM-DD' calendar date string.

    The fractional part is scaled by the real length of that calendar year
    (366 days in leap years, 365 otherwise), so every leap year is handled,
    not just 2024.

    Parameters
    ----------
    number : float
        Decimal year, e.g. 2024.6025.

    Returns
    -------
    str
        The date formatted as '%Y-%m-%d'; the time of day is dropped.
    """
    year = int(number)
    day_one = dt(year, 1, 1)
    # length of this calendar year in days (365 or 366), derived from the calendar itself
    days_in_year = (dt(year + 1, 1, 1) - day_one).days
    date = day_one + timedelta(days=(number - year) * days_in_year)
    return date.strftime('%Y-%m-%d')

In [None]:
# Convert a duration expressed in decimal years into seconds.
def convert_persistence(number):
    """Express a span of ``number`` years (flat 365-day years) in seconds.

    Returns the float produced by ``timedelta.total_seconds``.
    """
    span = timedelta(days=365 * number)
    return span.total_seconds()


In [None]:
def is_leap(number):
    """Return 1 if ``number`` is a Gregorian leap year, else 0.

    Applies the full Gregorian rule (divisible by 4, except centuries not
    divisible by 400), so 2020, 2028, ... are recognised as well as 2024.
    The int return type is kept so callers can write ``365 + is_leap(year)``.
    """
    import calendar
    return int(calendar.isleap(int(number)))


In [None]:
def convert_format(number):
    """Parse a 'YYYY-MM-DD' string and snap it back to the Monday of its week.

    Returns a ``datetime`` at midnight on that Monday (used for weekly bins).
    """
    parsed = dt.strptime(number, '%Y-%m-%d')
    days_since_monday = parsed.weekday()
    return parsed - timedelta(days=days_since_monday)

In [None]:
def convert_format_month(number):
    """Reduce a 'YYYY-MM-DD' date string to its 'YYYY-MM' year-month."""
    return dt.strptime(number, '%Y-%m-%d').strftime('%Y-%m')

In [None]:
def convert_format_month_only(number):
    """Return only the zero-padded month ('01'..'12') of a 'YYYY-MM-DD' string."""
    parsed = dt.strptime(number, '%Y-%m-%d')
    return parsed.strftime('%m')

In [None]:
def convertDate(x,start,end):
    """Re-format the date string ``x`` from strptime format ``start`` to strftime format ``end``."""
    parsed = dt.strptime(x, start)
    return parsed.strftime(end)

In [None]:
def decimal_to_days(decimal_value):
    """Convert a span in decimal years to whole days (365-day years, truncated toward zero)."""
    return int(365 * decimal_value)

In [None]:
def get_taxa_lines(tree_path):
    """Return the header of a NEXUS ``.trees`` file as a single string.

    Every line that does NOT contain 'state' (case-insensitive) is kept, in
    order and with its line ending. The result can be prepended to one
    'tree STATE_...' line to build a minimal NEXUS document for a single tree.

    Parameters
    ----------
    tree_path : str or path-like
        Path to the .trees file.

    Returns
    -------
    str
        Concatenated non-'state' lines.
    """
    # Read the path we were given (not a notebook global). Plain 'r' text mode
    # already uses universal newlines; the 'rU' mode was removed in Python 3.11.
    with open(tree_path, 'r') as infile:
        header_lines = [line for line in infile if 'state' not in line.lower()]
    # join once instead of repeated string concatenation (quadratic on large files)
    return "".join(header_lines)


In [None]:
def get_burnin_value(tree_path, burnin_percent):
    """Return the number of trees to discard as burn-in.

    Counts the lines containing 'state' (case-insensitive) -- i.e. the
    'tree STATE_...' lines of the .trees file -- and scales by ``burnin_percent``.

    Parameters
    ----------
    tree_path : str or path-like
        Path to the .trees file.
    burnin_percent : float
        Fraction of trees to discard, e.g. 0.3.

    Returns
    -------
    float
        ``numtrees * burnin_percent``, not rounded; run_mig_counts skips trees
        while its counter is not greater than this value.
    """
    with open(tree_path, 'r') as infile:  # 'rU' mode was removed in Python 3.11
        numtrees = sum(1 for line in infile if 'state' in line.lower())
    return numtrees * burnin_percent



In [None]:
# Convert calendar dates to decimal years (adapted from a StackOverflow answer).
def toYearFraction(date):
    """Convert a datetime (or date) to a decimal year, e.g. 2023-07-02 12:00 -> 2023.5.

    The fraction is the time elapsed since 1 January divided by the length of
    that calendar year, so leap years are handled. Differences are taken
    directly between naive datetimes instead of via ``time.mktime``, so the
    result does not depend on the machine's local time zone / DST rules and
    works for dates before 1970.

    Parameters
    ----------
    date : datetime.datetime, pandas.Timestamp or datetime.date

    Returns
    -------
    float
    """
    if not isinstance(date, dt):
        # plain datetime.date: promote to midnight so it can be subtracted from datetimes
        date = dt(date.year, date.month, date.day)

    start_of_this_year = dt(year=date.year, month=1, day=1)
    start_of_next_year = dt(year=date.year + 1, month=1, day=1)

    year_elapsed = (date - start_of_this_year).total_seconds()
    year_duration = (start_of_next_year - start_of_this_year).total_seconds()
    return date.year + year_elapsed / year_duration


In [None]:
def enumerate_migration_events(tree, traitType, rng=None):
    """Enumerate every change of ``traitType`` between a node and its parent in one tree.

    For each object ``k`` in ``tree.Objects`` whose trait value differs from
    its parent's, one migration event is recorded. Objects lacking the trait
    are labelled "root"; events whose *parent* is labelled "root" are skipped.

    Parameters
    ----------
    tree : baltic tree
        Needs ``Objects`` and ``getExternal()``. Each object needs ``traits``,
        ``parent``, ``absoluteTime`` and ``length``; parent nodes also need
        ``leaves`` (collection of leaf names) and leaves need ``name``.
    traitType : str
        Key into ``k.traits`` (e.g. "obs").
    rng : random.Random, optional
        Source of the uniform draw that places the event on its branch.
        Defaults to the global ``random`` module (same stream as before);
        pass a seeded ``random.Random`` for reproducible dates.

    Returns
    -------
    dict
        1-based event number -> dict with keys "type", "date", "parent_tmrca",
        "chain_tmrca", "chain_latest_tip", "size_of_chain", "parent_host",
        "child_host", "tree_length".
    """
    if rng is None:
        rng = random

    # loop-invariant per tree: computed once instead of once per event
    tree_length = sum(x.length for x in tree.Objects)
    leaf_by_name = {leaf.name: leaf for leaf in tree.getExternal()}

    output_dict = {}
    migration_events_counter = 0
    for k in tree.Objects:
        trait = str(k.traits[traitType]) if traitType in k.traits else "root"

        parent_node = k.parent
        if traitType in parent_node.traits:
            parent_trait = str(parent_node.traits[traitType])
        else:
            parent_trait = "root"

        # only record changes whose origin is a labelled deme (skip root-to-deme)
        if trait == parent_trait or parent_trait == "root":
            continue

        migration_events_counter += 1
        # place the jump uniformly at random along the parent -> k branch
        migration_date = parent_node.absoluteTime + (k.absoluteTime - parent_node.absoluteTime) * rng.uniform(0, 1)

        # NOTE(review): chain size and latest tip are taken from the *parent* node's
        # leaves, which include sibling lineages of k; confirm k.leaves was not intended.
        chain_leaves = [leaf_by_name[name] for name in parent_node.leaves if name in leaf_by_name]
        chain_latest_tip = max(leaf.absoluteTime for leaf in chain_leaves)

        output_dict[migration_events_counter] = {
            "type": parent_trait + "-to-" + trait,
            "date": migration_date,
            "parent_tmrca": parent_node.absoluteTime,
            "chain_tmrca": k.absoluteTime,
            "chain_latest_tip": chain_latest_tip,
            "size_of_chain": len(list(parent_node.leaves)),
            "parent_host": parent_trait,
            "child_host": trait,
            "tree_length": tree_length,
        }

    return(output_dict)


In [None]:
# Count all migration events in every post-burn-in tree and record parent/child details.
def run_mig_counts(all_trees, traitType, burnin=None, taxa_lines=None, most_recent_tip=2024.6025):
    """Parse each post-burn-in tree of a .trees file and tabulate its migration events.

    Parameters
    ----------
    all_trees : str or path-like
        Path to the .trees file; tree lines contain 'tree STATE_'.
    traitType : str
        Trait key passed to enumerate_migration_events (e.g. "obs").
    burnin : float, optional
        Trees whose 1-based index is <= burnin are skipped. Defaults to the
        notebook-level global ``burnin`` (previous behaviour).
    taxa_lines : str, optional
        NEXUS header prepended to each tree line before loading. Defaults to
        the notebook-level global ``taxa_lines`` (previous behaviour).
    most_recent_tip : float, optional
        Decimal date given to ``tree.setAbsoluteTime``; default 2024.6025 as before.

    Returns
    -------
    pandas.DataFrame
        One row per (tree, event): 'tree_number', 'migration_event_number'
        plus the keys produced by enumerate_migration_events.
    """
    # fall back to the notebook globals so existing calls keep working unchanged
    if burnin is None:
        burnin = globals()["burnin"]
    if taxa_lines is None:
        taxa_lines = globals()["taxa_lines"]

    start_time = time.time()
    tree_counter = 0
    trees_processed = 0
    migrations_dict = {}

    with open(all_trees, "r") as infile:
        for line in infile:
            if 'tree STATE_' not in line:
                continue
            tree_counter += 1
            if tree_counter <= burnin:
                continue

            # header + one tree line is a complete single-tree NEXUS document
            tree = bt.loadNexus(StringIO(taxa_lines + line), absoluteTime = False)
            tree.setAbsoluteTime(most_recent_tip)
            trees_processed += 1

            # iterate through the tree and pull out all migration events
            migrations_dict[tree_counter] = enumerate_migration_events(tree, traitType)

    total_time_seconds = time.time() - start_time
    total_time_minutes = total_time_seconds/60
    print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", trees_processed, "trees")

    # (tree_number, event_number) tuple keys become a two-level index, flattened into columns below
    migrations_df = pd.DataFrame.from_dict(
        {(i, j): event for i, events in migrations_dict.items() for j, event in events.items()},
        orient='index')
    migrations_df = migrations_df.reset_index().rename(
        columns={'level_0': 'tree_number', 'level_1': 'migration_event_number'})

    return(migrations_df)

In [None]:
def return_proportions_dataframe(input_df, time_unit):
    """Per tree and migration direction, the share of transitions falling in each time bin.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Needs columns 'tree_number', 'type' and ``time_unit``.
    time_unit : str
        Name of the column holding the time bin (e.g. 'year-week').

    Returns
    -------
    pandas.DataFrame
        One row for every (tree_number, direction, time bin) combination,
        including combinations with zero transitions. Columns:
        'migration_direction', <time_unit>, 'tree_number', 'total_transitions',
        'transitions_in_time_interval', 'proportion_transitions_in_time_interval'.
        The proportion is 0 when the direction has no transitions in that tree.
    """
    from collections import Counter

    columns = ["migration_direction", time_unit, "tree_number", "total_transitions",
               "transitions_in_time_interval", "proportion_transitions_in_time_interval"]

    # distinct values in order of first appearance (deterministic, unlike set iteration)
    tree_numbers = list(dict.fromkeys(input_df['tree_number'].tolist()))
    directions = list(dict.fromkeys(input_df['type'].tolist()))
    time_bins = list(dict.fromkeys(input_df[time_unit].tolist()))

    records = []
    for tree_number in tree_numbers:
        tree_df = input_df[input_df['tree_number'] == tree_number]
        for direction in directions:
            direction_bins = tree_df.loc[tree_df['type'] == direction, time_unit].tolist()
            total_transitions = len(direction_bins)
            bin_counts = Counter(direction_bins)
            for item in time_bins:
                transitions_in_time_unit = bin_counts.get(item, 0)
                if total_transitions != 0:
                    prop_transitions_in_time_unit = transitions_in_time_unit/total_transitions
                else:
                    prop_transitions_in_time_unit = 0
                records.append({
                    "migration_direction": direction,
                    time_unit: item,
                    "tree_number": tree_number,
                    "total_transitions": total_transitions,
                    "transitions_in_time_interval": transitions_in_time_unit,
                    "proportion_transitions_in_time_interval": prop_transitions_in_time_unit,
                })

    # build the frame once: DataFrame.append was removed in pandas 2.0, and
    # growing a frame row by row is quadratic
    return(pd.DataFrame(records, columns=columns))


In [None]:
# Read in the introduction (immigration) rate samples from the log.
def read_in_forward_migration_rates_mascot(log_file_path):
    """Read the 'Sample' column and every 'immigrationRate' column of a tab-separated log.

    Lines starting with '#' are ignored (LogCombiner may prepend the whole
    XML), blank lines are skipped, the line containing 'posterior' is the
    header, and every other line is one posterior sample.

    Parameters
    ----------
    log_file_path : str or path-like

    Returns
    -------
    dict
        {"sample": [...], "<column name>": [...], ...}. Values are the raw
        strings from the file (convert with ``astype(float)``).

    Raises
    ------
    ValueError
        If a sample line appears before any header line.
    """
    mig_rates_dict = {"sample": []}
    mig_key = None  # column index -> column name, filled from the header line

    with open(log_file_path, "r") as infile:
        for raw_line in infile:
            if raw_line.startswith("#"):
                continue
            # drop the line terminator so the last column's name/values carry no '\n'
            line = raw_line.rstrip("\r\n")
            if not line.strip():
                continue
            fields = line.split("\t")

            if "posterior" in line:
                mig_key = {i: col for i, col in enumerate(fields) if "immigrationRate" in col}
                for name in mig_key.values():
                    # setdefault: a repeated header must not wipe earlier samples
                    mig_rates_dict.setdefault(name, [])
            else:
                if mig_key is None:
                    raise ValueError("found a sample line before the header line in " + str(log_file_path))
                mig_rates_dict["sample"].append(fields[0])
                for index, name in mig_key.items():
                    mig_rates_dict[name].append(fields[index])

    return(mig_rates_dict)

In [None]:
# Summarise each migration-rate column: mean, median and 95% / 50% HPD, on the log scale and back-transformed with exp().
def generate_summary_mig_df(input_df):
    """Summarise every 'immigrationRate.' column of a posterior-sample DataFrame.

    Column values are treated as log-scale and exponentiated for the
    "linear" columns. NOTE(review): confirm the logged immigration rates are
    in log space; if they are already linear, the *_linear columns are exp(rate).

    Parameters
    ----------
    input_df : pandas.DataFrame
        One row per posterior sample; values convertible to float.

    Returns
    -------
    pandas.DataFrame
        One row per matching column, in column order, with 'interval' (the text
        after the first '.', as a string), mean/median on the log scale, the
        exp of the mean, and 95% / 50% HPD bounds (log and exp).
    """
    records = []
    for column in input_df.columns.tolist():
        if "immigrationRate." not in column:
            continue
        interval = column.split(".")[1]
        samples = input_df[column].astype('float').to_numpy()
        mean_log = samples.mean()
        lower_95, upper_95 = az.hdi(samples, 0.95)
        lower_50, upper_50 = az.hdi(samples, 0.50)
        records.append({
            "interval": interval,
            "mean_mig_log": mean_log,
            "mean_mig_linear": math.exp(mean_log),
            "median_mig_log": np.median(samples),
            "upper_hpd_log_95": upper_95,
            "lower_hpd_log_95": lower_95,
            "upper_hpd_log_50": upper_50,
            "lower_hpd_log_50": lower_50,
            "upper_hpd_linear": math.exp(upper_95),
            "lower_hpd_linear": math.exp(lower_95),
            "upper_hpd_linear_50": math.exp(upper_50),
            "lower_hpd_linear_50": math.exp(lower_50),
        })

    # build once; the old append-in-a-loop inside a bare `except: pass` silently
    # returned an empty frame on pandas >= 2 (DataFrame.append removed)
    return(pd.DataFrame(records))

In [None]:
def read_in_Ne_changes_mascot(log_file_path):
    """Read the 'Sample' column and every column containing 'Ne.' from a tab-separated log.

    Lines starting with '#' are ignored (LogCombiner may prepend the whole
    XML), blank lines are skipped, the line containing 'posterior' is the
    header, and every other line is one posterior sample.

    Parameters
    ----------
    log_file_path : str or path-like

    Returns
    -------
    dict
        {"sample": [...], "<Ne column name>": [...], ...}. Values are the raw
        strings from the file (convert with ``astype(float)``).

    Raises
    ------
    ValueError
        If a sample line appears before any header line.
    """
    Ne_skyline_dict = {"sample": []}
    Nes_key = None  # column index -> column name, filled from the header line

    with open(log_file_path, "r") as infile:
        for raw_line in infile:
            if raw_line.startswith("#"):
                continue
            # drop the line terminator so the last column's name/values carry no '\n'
            line = raw_line.rstrip("\r\n")
            if not line.strip():
                continue
            fields = line.split("\t")

            if "posterior" in line:
                Nes_key = {i: col for i, col in enumerate(fields) if "Ne." in col}
                for name in Nes_key.values():
                    # setdefault: a repeated header must not wipe earlier samples
                    Ne_skyline_dict.setdefault(name, [])
            else:
                if Nes_key is None:
                    raise ValueError("found a sample line before the header line in " + str(log_file_path))
                Ne_skyline_dict["sample"].append(fields[0])
                for index, name in Nes_key.items():
                    Ne_skyline_dict[name].append(fields[index])

    return(Ne_skyline_dict)

In [None]:
# Summarise each Ne column: mean, median and 95% / 50% HPD, on the log scale and back-transformed with exp().
def generate_summary_ne_df(input_df):
    """Summarise every 'Ne.' column of a posterior-sample DataFrame.

    Column values are treated as log-scale and exponentiated for the
    "linear" columns. NOTE(review): confirm the logged Ne values are in log
    space; if they are already linear, the *_linear columns are exp(Ne).

    Parameters
    ----------
    input_df : pandas.DataFrame
        One row per posterior sample; values convertible to float.

    Returns
    -------
    pandas.DataFrame
        One row per matching column, in column order, with 'interval' (the text
        after the first '.', as a string), mean/median on the log scale, the
        exp of the mean, and 95% / 50% HPD bounds (log and exp).
    """
    records = []
    for column in input_df.columns.tolist():
        if "Ne." not in column:
            continue
        interval = column.split(".")[1]
        samples = input_df[column].astype('float').to_numpy()
        mean_log = samples.mean()
        lower_95, upper_95 = az.hdi(samples, 0.95)
        lower_50, upper_50 = az.hdi(samples, 0.50)
        records.append({
            "interval": interval,
            "mean_Ne_log": mean_log,
            "mean_Ne_linear": math.exp(mean_log),
            "median_Ne_log": np.median(samples),
            "upper_hpd_log_95": upper_95,
            "lower_hpd_log_95": lower_95,
            "upper_hpd_log_50": upper_50,
            "lower_hpd_log_50": lower_50,
            "upper_hpd_linear": math.exp(upper_95),
            "lower_hpd_linear": math.exp(lower_95),
            "upper_hpd_linear_50": math.exp(upper_50),
            "lower_hpd_linear_50": math.exp(lower_50),
        })

    # build once; the old append-in-a-loop inside a bare `except: pass` silently
    # returned an empty frame on pandas >= 2 (DataFrame.append removed)
    return(pd.DataFrame(records))

In [None]:
# Inputs from the multitree_coalescent run: posterior tree sample (.trees) and
# parameter log (.log). Paths are relative to the notebook's working directory.
trees =  "../multitree_coalescent/results/trees_10_07_case_prior.trees"
log_file_path = "../multitree_coalescent/results/updated_multicoal_updated_case_prior_la_clusters_with_metadata_10_07.log"



In [None]:
# Pre-scan the tree file: the NEXUS header to prepend to each tree line, and the
# number of trees (30%) to discard as burn-in.
all_trees = trees
burnin_percent = 0.3
taxa_lines = get_taxa_lines(all_trees)
# NOTE: run_mig_counts reads `burnin` and `taxa_lines` from these notebook globals by default
burnin = get_burnin_value(all_trees, burnin_percent)
print(burnin)



In [None]:
# Identify every change of the "obs" trait (migration jump) in each post-burn-in posterior tree.
migrations_df = run_mig_counts(all_trees, traitType = "obs")

In [None]:
# Preview the per-tree migration-event table.
migrations_df.head()

In [None]:
# Calendar date of each (randomly placed on its branch) migration event, and the
# Monday of its week for weekly binning.
migrations_df['calendar_date'] = migrations_df.date.map(convert_partial_year)
migrations_df['year-week'] = migrations_df['calendar_date'].map(convert_format)


In [None]:
# Inspect the full migration table with the added date columns.
migrations_df

In [None]:
# Median over posterior trees of each event's timing and chain size, keyed by the
# per-tree event counter, then ordered by the parent node's TMRCA.
# NOTE(review): migration_event_number is only a per-tree counter assigned by
# enumerate_migration_events, so rows sharing a number in different trees are not
# necessarily the same introduction -- confirm this pooling is intended.
# Columns are selected with a list: tuple-style selection after groupby raises in pandas >= 2.0.
migrations_for_plot = migrations_df.groupby(["migration_event_number"])[["date", 'parent_tmrca', "chain_tmrca", "chain_latest_tip", "size_of_chain"]].median().reset_index()
migrations_for_plot = migrations_for_plot.sort_values(by=['parent_tmrca']).reset_index()

In [None]:
# Calendar date and two-digit month ('01'..'12') of each median introduction date;
# the month is used by the seasonal plot.
migrations_for_plot["calendar_date"] = migrations_for_plot.date.map(convert_partial_year)
migrations_for_plot['month'] = migrations_for_plot['calendar_date'].map(convert_format_month_only)

In [None]:
# Inspect the per-introduction summary used for plotting.
migrations_for_plot

In [None]:
# Inspect the grouping used above. Columns are selected with a list: tuple-style
# selection after groupby raises in pandas >= 2.0.
migrations_df.groupby(["migration_event_number"])[["date", 'parent_tmrca', "chain_tmrca", "chain_latest_tip", "size_of_chain"]]





In [None]:
# Re-display the per-introduction summary.
migrations_for_plot

In [None]:
# Work on a copy so the export columns added below do not modify migrations_for_plot.
migrations_for_export = migrations_for_plot.copy()

In [None]:
# Export table: one row per (median) introduction with its calendar date, the date of
# the chain's latest sampled tip, the chain size, and persistence in days
# (latest tip minus introduction date, 365-day years, truncated by decimal_to_days).
migrations_for_export['length_of_chain'] = migrations_for_export.chain_latest_tip - migrations_for_export.date
migrations_for_export['length_of_chain_days'] = migrations_for_export.length_of_chain.apply(decimal_to_days)
migrations_for_export['calendar_date_of_import'] = migrations_for_export.date.map(convert_partial_year)
migrations_for_export['latest_case_of_chain_calendar_date'] = migrations_for_export.chain_latest_tip.map(convert_partial_year)
migrations_for_export= migrations_for_export[["migration_event_number", "calendar_date_of_import","latest_case_of_chain_calendar_date", "size_of_chain", "length_of_chain_days"]]


In [None]:
# Save the per-introduction table next to the notebook (relative path).
migrations_for_export.to_csv("importation_transmission_chains.csv")


In [None]:
# Inspect the exported table.
migrations_for_export

In [None]:
# Posterior immigration-rate samples as a DataFrame; drop the first 30% of rows as
# burn-in (same fraction used for the trees).
migration_rates_f = read_in_forward_migration_rates_mascot(log_file_path)
mig_df_f = pd.DataFrame.from_dict(migration_rates_f)

burnin_percent = 0.3
print(len(mig_df_f))
rows_to_remove = int(len(mig_df_f)* burnin_percent)
mig_df_f = mig_df_f.iloc[rows_to_remove:]

print(len(mig_df_f))
mig_df_f = mig_df_f.reset_index()
mig_df_f.head()

In [None]:
# Mean / median / HPD summary per immigration-rate interval.
mig_summary = generate_summary_mig_df(mig_df_f)


In [None]:
# Date each immigration-rate interval: interval i is placed i*7 days before
# 2024-09-12, then converted to a decimal year for plotting.
# NOTE(review): run_mig_counts anchors trees at 2024.6025 (~2024-08-08); confirm
# 2024-09-12 is the intended reference date for the rate intervals.
test_mig = mig_summary.copy()  # copy so these added columns do not also land on mig_summary
test_mig['days'] = (test_mig.interval.astype(int))*7
test_mig['date'] = dt.strptime("2024-09-12",  "%Y-%m-%d") - test_mig.days.map(timedelta)
test_mig["decimal_date"]= test_mig.date.map(toYearFraction)

In [None]:
# Preview the dated immigration-rate summary.
test_mig.head()

In [None]:
# Raw Ne samples (as strings) from the same log file.
Ne_skyline = read_in_Ne_changes_mascot(log_file_path)

In [None]:
# Posterior Ne samples as a DataFrame; drop the first 30% of rows as burn-in.
Ne_df = pd.DataFrame.from_dict(Ne_skyline)
print(len(Ne_df))
# NOTE: a bare expression in the middle of a cell is not displayed; only the cell's last expression is
Ne_df

burnin_percent = 0.3
print(len(Ne_df))
rows_to_remove = int(len(Ne_df)* burnin_percent)
Ne_df = Ne_df.iloc[rows_to_remove:]

print(len(Ne_df))
Ne_df = Ne_df.reset_index()
Ne_df.head()

In [None]:
# Summarise Ne per interval and date each interval: interval i is placed i*7 days
# before 2024-09-12, then converted to a decimal year for plotting.
# NOTE(review): run_mig_counts anchors trees at 2024.6025 (~2024-08-08); confirm
# 2024-09-12 is the intended reference date for the Ne intervals.
ne_summary = generate_summary_ne_df(Ne_df)
test_ne = ne_summary.copy()  # copy so these added columns do not also land on ne_summary
test_ne['days'] = (test_ne.interval.astype(int))*7
test_ne['date'] = dt.strptime("2024-09-12",  "%Y-%m-%d") - test_ne.days.map(timedelta)
test_ne["decimal_date"]= test_ne.date.map(toYearFraction)

In [None]:
# Preview the dated Ne summary.
test_ne.head()

In [None]:
# Palette shared by the figures below. Roles as used in those cells:
#   0 -> introduction-rate curve, 2 -> Ne curve,
#   5 -> singleton chains, 1 -> small chains, 4 -> medium chains, 3 -> largest chains.
colors = ["#D0A854",
          "#2664A5",
          "#A76BB1",
          "#D07954",
          "#356D4C",
          "#B9B9B9"
         ]

In [None]:
fig,ax = plt.subplots(figsize=(16,12),facecolor='w')

# Figure: one row per introduction (in migrations_for_plot index order), a dot at the
# median introduction date and a dotted line to the chain's latest sampled tip,
# coloured by chain size; the introduction rate (mean + 50% HPD band) is on a twin y-axis.
ax.set_facecolor('white')
ax.grid(False)

for index, values in migrations_for_plot.iterrows():
    clust = index + 10
    # colour bins match the legend below: 1, 2-4, 5-9, 10+ (sizes are medians, may be fractional)
    if values.size_of_chain < 2:
        col = colors[5]
    elif values.size_of_chain < 5:
        col = colors[1]
    elif values.size_of_chain < 10:
        col = colors[4]
    else:
        col = colors[3]
    linewidth = 3

    ax.scatter([values.date, values.date], [clust, clust], color=col, linewidth=linewidth)
    ax.plot([values.date, values.chain_latest_tip], [clust, clust], color=col, linewidth=1, linestyle = ":")

# introduction rate: mean (dashed) with 50% HPD band and its bounds
fc = colors[0]
ec = colors[0]
ax2 = ax.twinx()
ax2.plot(test_mig.decimal_date,test_mig["mean_mig_linear"],color=fc,ls='--',lw=2)
ax2.fill_between(test_mig.decimal_date,test_mig.lower_hpd_linear_50,test_mig.upper_hpd_linear_50,alpha=0.05,facecolor=fc,edgecolor=ec,zorder=1000)
ax2.plot(test_mig.decimal_date,test_mig.lower_hpd_linear_50,color=fc,lw=1,zorder=1000)
ax2.plot(test_mig.decimal_date,test_mig.upper_hpd_linear_50,color=fc,lw=1,zorder=1000)
ax2.set_ylim(0,30)

legend_list = [mlines.Line2D([0], [0], color=colors[5], lw=4, label='Singletons'),
                mlines.Line2D([0], [0], color=colors[1], lw=4, label='2-4'),
                mlines.Line2D([0], [0], color=colors[4], lw=4, label='5-9'),
                mlines.Line2D([0], [0], color=colors[3], lw=4, label='10+')]
ax.legend(handles=legend_list, title='Size of Local transmission cluster', fontsize=13, title_fontsize=13, loc='center right')

# first day of every month 2022-2024
xDates=['%04d-%02d-01'%(y,m) for y in range(2022,2025) for m in range(1,13)]

every=1
# shade every other month
for x in range(0,len(xDates),2):
    ax.axvspan(bt.decimalDate(xDates[x]),bt.decimalDate(xDates[x])+1/float(12),facecolor='k',edgecolor='none',alpha=0.04)
# ticks at mid-month (+1/24 year): year label in January, month abbreviation otherwise
ax.set_xticks([bt.decimalDate(x)+1/24.0 for x in xDates if (int(x.split('-')[1])-1)%every==0])
ax.set_xticklabels([convertDate(x,'%Y-%m-%d','%Y') if x.split('-')[1]=='01' else convertDate(x,'%Y-%m-%d','%b') for x in xDates if (int(x.split('-')[1])-1)%every==0])
ax.tick_params(axis='x',labelsize=10,size=0)

ax.yaxis.tick_left()

for loc in ['top','left']:
    ax2.spines[loc].set_visible(False)
for loc in ['top','right','left']:
    ax.spines[loc].set_visible(False)

ax.tick_params(axis='y',size=0)
ax.set_yticklabels([])
# rows run from 10 to len+9; use the row count rather than the leftover loop variable,
# which is undefined for an empty table and not the top row once the table is re-sorted
ax.set_ylim(-5,len(migrations_for_plot)+9+80)
ax.set_xlim(2022,2025.1)

ax.xaxis.set_tick_params(which='both', top=False, bottom=True, labelbottom=True)
ax.yaxis.set_tick_params(which='both', right=False, left=True, labelleft=True)
#plt.savefig('../figures/mpox_la_introduction_rate.png',dpi=300,bbox_inches='tight')


In [None]:
fig,ax = plt.subplots(figsize=(25,10),facecolor='w')

# Figure: each introduction as a circle at its median introduction date with
# height = persistence in days (latest sampled tip minus introduction) and marker
# size growing with chain size; a dotted line runs to the latest tip. Twin axes show
# the introduction rate (ax2) and Ne (ax3), each as mean + 50% HPD band.
ax.set_facecolor('white')
ax.grid(False)

for index, values in migrations_for_plot.iterrows():
    # colour bins match the legend below: 1, 2-4, 5-9, 10+ (sizes are medians, may be fractional)
    if values.size_of_chain < 2:
        col = colors[5]
    elif values.size_of_chain < 5:
        col = colors[1]
    elif values.size_of_chain < 10:
        col = colors[4]
    else:
        col = colors[3]
    persist_days = decimal_to_days(values.chain_latest_tip - values.date)
    # passed as scatter `s` (marker area); grows with sqrt(chain size)
    radius = np.sqrt(values.size_of_chain/np.pi)*300.0

    ax.scatter(values.date,persist_days,s=radius,facecolor=col,edgecolor='k',lw=2,zorder=200)
    ax.plot([values.date, values.chain_latest_tip], [persist_days, persist_days], color=col, linewidth=1, linestyle = ":")

# introduction rate: mean (dashed) with 50% HPD band and its bounds
fc = colors[0]
ec = colors[0]
ax2 = ax.twinx()
ax2.plot(test_mig.decimal_date,test_mig["mean_mig_linear"],color=fc,ls='--',lw=2)
ax2.fill_between(test_mig.decimal_date,test_mig.lower_hpd_linear_50,test_mig.upper_hpd_linear_50,alpha=0.05,facecolor=fc,edgecolor=ec,zorder=1000)
ax2.plot(test_mig.decimal_date,test_mig.lower_hpd_linear_50,color=fc,lw=1,zorder=1000)
ax2.plot(test_mig.decimal_date,test_mig.upper_hpd_linear_50,color=fc,lw=1,zorder=1000)
ax2.set_ylim(0,30)

# Ne: mean (dashed) with 50% HPD band and its bounds
# NOTE(review): ax2 and ax3 are both twinx() axes on the right side with the same
# ylim (0, 30), so the two curves share one visible scale -- confirm that is intended.
fc = colors[2]
ec = colors[2]
ax3 = ax.twinx()
ax3.plot(test_ne.decimal_date,test_ne["mean_Ne_linear"],color=fc,ls='--',lw=2)
ax3.fill_between(test_ne.decimal_date,test_ne.lower_hpd_linear_50,test_ne.upper_hpd_linear_50,alpha=0.05,facecolor=fc,edgecolor=ec,zorder=1000)
ax3.plot(test_ne.decimal_date,test_ne.lower_hpd_linear_50,color=fc,lw=1,zorder=1000)
ax3.plot(test_ne.decimal_date,test_ne.upper_hpd_linear_50,color=fc,lw=1,zorder=1000)
ax3.set_ylim(0,30)

legend_list = [mlines.Line2D([0], [0], color=colors[5], lw=4, label='Singletons'),
                mlines.Line2D([0], [0], color=colors[1], lw=4, label='2-4'),
                mlines.Line2D([0], [0], color=colors[4], lw=4, label='5-9'),
                mlines.Line2D([0], [0], color=colors[3], lw=4, label='10+')]
ax.legend(handles=legend_list, title='Size of Local transmission cluster', fontsize=13, title_fontsize=13, loc='center right')

# first day of every month 2022-2024
xDates=['%04d-%02d-01'%(y,m) for y in range(2022,2025) for m in range(1,13)]

every=1
# shade every other month
for x in range(0,len(xDates),2):
    ax.axvspan(bt.decimalDate(xDates[x]),bt.decimalDate(xDates[x])+1/float(12),facecolor='k',edgecolor='none',alpha=0.04)
# ticks at mid-month (+1/24 year): year label in January, month abbreviation otherwise
ax.set_xticks([bt.decimalDate(x)+1/24.0 for x in xDates if (int(x.split('-')[1])-1)%every==0])
ax.set_xticklabels([convertDate(x,'%Y-%m-%d','%Y') if x.split('-')[1]=='01' else convertDate(x,'%Y-%m-%d','%b') for x in xDates if (int(x.split('-')[1])-1)%every==0])
ax.tick_params(axis='x',labelsize=10,size=0)

ax.yaxis.tick_left()

ax2.spines['top'].set_visible(False)
for loc in ['top','right']:
    ax.spines[loc].set_visible(False)

ax.tick_params(axis='y',size=0)
ax.set_ylim(0,500)
ax.set_xlim(2022,2025.1)

ax.xaxis.set_tick_params(which='both', top=False, bottom=True, labelbottom=True)
ax.yaxis.set_tick_params(which='both', right=False, left=True, labelleft=True)
#plt.savefig('../figures/mpox_la_introduction_persistence_with_ne.png',dpi=300,bbox_inches='tight')


In [None]:
fig,ax = plt.subplots(figsize=(16,12),facecolor='w')

ins_ax = ax.inset_axes([.3, .65, .5, .3])  # [x, y, width, height] w.r.t. ax

# Figure: one vertical bar per introduction at its median date, height = chain size,
# coloured by chain size; the inset repeats the bars for introductions after 2022.99.
ax.set_facecolor('white')
ax.grid(False)

for index, values in migrations_for_plot.iterrows():
    # colour bins match the legend below: 1, 2-4, 5-9, 10+ (sizes are medians, may be fractional)
    if values.size_of_chain < 2:
        col = colors[5]
    elif values.size_of_chain < 5:
        col = colors[1]
    elif values.size_of_chain < 10:
        col = colors[4]
    else:
        col = colors[3]
    linewidth = 3

    ax.plot([values.date, values.date], [0, values.size_of_chain], color=col, linewidth=linewidth)
    if values.date > 2022.99:
        ins_ax.plot([values.date, values.date], [0, values.size_of_chain], color=col, linewidth=linewidth)

legend_list = [mlines.Line2D([0], [0], color=colors[5], lw=4, label='Singletons'),
                mlines.Line2D([0], [0], color=colors[1], lw=4, label='2-4'),
                mlines.Line2D([0], [0], color=colors[4], lw=4, label='5-9'),
                mlines.Line2D([0], [0], color=colors[3], lw=4, label='10+')]
ax.legend(handles=legend_list, title='Size of Local transmission cluster', fontsize=15, title_fontsize=10, loc='center right')

# first day of every month 2022-2024; range(1,13) so December gets a tick too
xDates=['%04d-%02d-01'%(y,m) for y in range(2022,2025) for m in range(1,13)]

every=1
# shade every other month
for x in range(0,len(xDates),2):
    ax.axvspan(bt.decimalDate(xDates[x]),bt.decimalDate(xDates[x])+1/float(12),facecolor='k',edgecolor='none',alpha=0.04)
# ticks at mid-month (+1/24 year): year label in January, month abbreviation otherwise
ax.set_xticks([bt.decimalDate(x)+1/24.0 for x in xDates if (int(x.split('-')[1])-1)%every==0])
ax.set_xticklabels([convertDate(x,'%Y-%m-%d','%Y') if x.split('-')[1]=='01' else convertDate(x,'%Y-%m-%d','%b') for x in xDates if (int(x.split('-')[1])-1)%every==0])
ax.tick_params(axis='x',labelsize=10,size=0)

ax.yaxis.tick_left()

for loc in ['top','right','left']:
    ax.spines[loc].set_visible(False)

ax.tick_params(axis='y',size=0)
ax.set_yticklabels([])
ax.set_ylim(0,60)
ax.set_xlim(2022,2025.1)

ax.xaxis.set_tick_params(which='both', top=False, bottom=True, labelbottom=True)
ax.yaxis.set_tick_params(which='both', right=False, left=True, labelleft=True)

In [None]:
# Re-display the per-introduction summary before the seasonal plot.
migrations_for_plot

In [None]:
fig,ax = plt.subplots(figsize=(25,10),facecolor='w')

# Figure: persistence (days from introduction to the chain's latest sampled tip)
# against the month of introduction, marker size growing with chain size.

# set blank white face for background
ax.set_facecolor('white')
# remove grid
ax.grid(False)

# NOTE(review): this rebinds the notebook-global migrations_for_plot (row order changes,
# index labels kept), so re-running earlier plotting cells afterwards iterates in month order.
migrations_for_plot = migrations_for_plot.sort_values(by=['month'])

for index, values in migrations_for_plot.iterrows():
   # print(index)
    clust = index + 10  # not used in this cell
    # NOTE(review): sizes 10 (and fractional medians up to <11) get colors[4]; check that this
    # matches the legend built later in this cell (the other plots label colors[4] as '5-9').
    if values.size_of_chain <2:
        col = colors[5]
    elif (values.size_of_chain >1) & (values.size_of_chain <5):
        col = colors[1]
    elif (values.size_of_chain >4) & (values.size_of_chain <11):
        col = colors[4]
    else:
        col = colors[3]
    linewidth = 3
    persist_days = decimal_to_days(values.chain_latest_tip - values.date)
    # passed as scatter `s` (marker area); grows with sqrt(chain size)
    radius = np.sqrt(values.size_of_chain/np.pi)*300.0
    if values.date > 2022.5822: ## only introductions after ~1 Aug 2022: there are many intros in early summer 2022 and we want more stable seasonal dynamics
        ax.scatter(values.month,persist_days,s=radius,facecolor=col,edgecolor='k',lw=2,zorder=200) ## one circle per introduction; x is the month string '01'..'12'
    #ax.plot([values.date, values.chain_latest_tip], [persist_days, persist_days], color=col, linewidth=1, linestyle = ":")

#axs.violinplot(migrations_for_plot.month, positions=[index], widths=0.8,
#                                    showmedians=True, bw_method=0.6, showextrema=False)
    
# fc = colors[0]
# ec = colors[0]
# ax2 = ax.twinx()

# ax2.plot(test_mig.decimal_date,test_mig["mean_mig_linear"],color=fc,ls='--',lw=2)

# #ax.scatter(caseDates,[0.0]*len(caseDates),alpha=0.2,s=200,marker='|',lw=3,facecolor='k',zorder=100)
# ax2.fill_between(test_mig.decimal_date,test_mig.lower_hpd_linear_50,test_mig.upper_hpd_linear_50,alpha=0.05,facecolor=fc,edgecolor=ec,zorder=1000)
# ax2.plot(test_mig.decimal_date,test_mig.lower_hpd_linear_50,color=fc,lw=1,zorder=1000)
# ax2.plot(test_mig.decimal_date,test_mig.upper_hpd_linear_50,color=fc,lw=1,zorder=1000)
# ax2.set_ylim(0,30)

# fc = colors[2]
# ec = colors[2]
# ax3 = ax.twinx()

# ax3.plot(test_ne.decimal_date,test_ne["mean_Ne_linear"],color=fc,ls='--',lw=2)

# #ax.scatter(caseDates,[0.0]*len(caseDates),alpha=0.2,s=200,marker='|',lw=3,facecolor='k',zorder=100)
# ax3.fill_between(test_ne.decimal_date,test_ne.lower_hpd_linear_50,test_ne.upper_hpd_linear_50,alpha=0.05,facecolor=fc,edgecolor=ec,zorder=1000)
# ax3.plot(test_ne.decimal_date,test_ne.lower_hpd_linear_50,color=fc,lw=1,zorder=1000)
# ax3.plot(test_ne.decimal_date,test_ne.upper_hpd_linear_50,color=fc,lw=1,zorder=1000)
# ax3.set_ylim(0,30)
    #ax.scatter([values.parent_tmrca, values.chain_tmrca], [clust, clust], color=col, linewidth=linewidth)
    #ax.plot([values.chain_tmrca, values.chain_latest_tip], [clust, clust], color=col, linewidth=1, linestyle = ":")

   # ax.plot([mrca[0], mrca[1]], [clust, clust], color=col, linewidth=linewidth)
    # add small vertical lines at the start and end of each mrca
    # ax.plot([mrca[0], mrca[0]], [clust-0.2, clust+0.2], color=col, linewidth=1)
    # ax.plot([mrca[1], mrca[1]], [clust-0.2, clust+0.2], color=col, linewidth=1)

# # set ylabel, with a long arrow at the end
#ax.set_ylabel('Importation (from earliest to latest) →', fontsize=fontsize)

legend_list = [mlines.Line2D([0], [0], color=colors[5], lw=4, label='Singletons'),
                mlines.Line2D([0], [0], color=colors[1], lw=4, label='2-4'),
                mlines.Line2D([0], [0], color=colors[4], lw=4, label='5-9'),
                mlines.Line2D([0], [0], color=colors[3], lw=4, label='10+')]
ax.legend(handles=legend_list, title='Size of Local transmission cluster', fontsize=13, title_fontsize=13, loc='center right')

xDates=['%04d-%02d-01'%(y,m) for y in range(2022,2025) for m in range(1,13)]
xDates2=['%04d-%02d-01'%(y,m) for y in range(2022,2025) for m in range(1,13)]


every=1
#[ax.axvspan(bt.decimalDate(xDates2[x]),bt.decimalDate(xDates2[x])+1/float(12),facecolor='k',edgecolor='none',alpha=0.04) for x in range(0,len(xDates2),2)]
#ax.set_xticks([bt.decimalDate(x)+1/24.0 for x in xDates if (int(x.split('-')[1])-1)%every==0])

#ax.set_xticklabels([convertDate(x,'%Y-%m-%d','%Y') if x.split('-')[1]=='01' else convertDate(x,'%Y-%m-%d','%b') for x in xDates if (int(x.split('-')[1])-1)%every==0])
#ax.tick_params(axis='x',labelsize=10,size=0)  

#ax1.xaxis.tick_bottom()
ax.yaxis.tick_left()

[ax2.spines[loc].set_visible(False) for loc in ['top',]]
[ax.spines[loc].set_visible(False) for loc in ['top','right']]

ax.tick_params(axis='y',size=0)
#ax.set_yticklabels([])
#ax.set_ylim(0,500)
#ax.set_xlim(2022,2025.1)

ax.xaxis.set_tick_params(which='both', top=False, bottom=True, labelbottom=True)
ax.yaxis.set_tick_params(which='both', right=False, left=True, labelleft=True)
#plt.savefig('../figures/mpox_la_introduction_persistence_with_ne.png',dpi=300,bbox_inches='tight')


In [None]:
# Bin every chain by `size_of_chain` into the categories used by the legends:
# "Singletons" (< 2), "2-4", "5-9" and "10+".
# `.loc` writes into the frame directly. The previous chained form
# (`df.clust_cat[mask] = ...`) triggers a chained-assignment warning in pandas 2.x
# and does nothing under Copy-on-Write. Its `< 11` / `> 10` bounds also put
# chains of exactly 10 into "5-9".
chain_size = migrations_for_plot["size_of_chain"]
migrations_for_plot["clust_cat"] = "Singletons"
migrations_for_plot.loc[(chain_size > 1) & (chain_size < 5), "clust_cat"] = "2-4"
migrations_for_plot.loc[(chain_size > 4) & (chain_size < 10), "clust_cat"] = "5-9"
migrations_for_plot.loc[chain_size >= 10, "clust_cat"] = "10+"



In [None]:
# Seasonality of chain persistence: per-month half violins and box plots of
# persistence (days), with every chain overlaid as a point coloured by its
# size category ("clust_cat") and sized by `size_of_chain`.
sns.set_style('white')

# Only introductions after this decimal date (~early Aug 2022) are shown.
min_intro_date = 2022.5822

migrations_for_plot["persist_days"] = (migrations_for_plot.chain_latest_tip - migrations_for_plot.date).apply(decimal_to_days)
recent_chains = migrations_for_plot[migrations_for_plot.date > min_intro_date]

# Explicit category -> colour mapping. A plain list palette is assigned to
# hue levels in their order of first appearance in the data, so colours
# silently depended on row order.
cluster_palette = {"Singletons": "gray", "2-4": colors[1], "5-9": colors[4], "10+": colors[3]}

ax = sns.violinplot(x="month", y="persist_days", data=recent_chains, dodge=False,
                    palette=[colors[5]] * 12,
                    scale="width", inner=None, cut=0, saturation=1)
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# Clip every violin to the left half of its bounding box (half violins).
for violin in ax.collections:
    bbox = violin.get_paths()[0].get_extents()
    x0, y0, width, height = bbox.bounds
    violin.set_clip_path(plt.Rectangle((x0, y0), width / 2, height, transform=ax.transData))

sns.boxplot(x="month", y="persist_days", data=recent_chains, saturation=1, showfliers=False,
            width=0.75, boxprops={'zorder': 3, 'facecolor': 'none'}, ax=ax)
old_len_collections = len(ax.collections)
sns.scatterplot(x="month", y="persist_days", hue="clust_cat", size="size_of_chain",
                sizes=(40, 400), alpha=0.7, linewidth=1, edgecolor="k", palette=cluster_palette,
                data=recent_chains, ax=ax)
# Shift only the point collections just added 0.15 x-units to the right.
for dots in ax.collections[old_len_collections:]:
    dots.set_offsets(dots.get_offsets() + np.array([0.15, 0]))
ax.set_xlim(xlim)
ax.set_ylim(ylim)

# Legend built from the same palette as the points, so the two always agree.
legend_list = [mlines.Line2D([], [], color=cluster_palette[label], lw=0, label=label, marker="o")
               for label in ["Singletons", "2-4", "5-9", "10+"]]
ax.legend(handles=legend_list, title='Size of Local transmission cluster', fontsize=8, title_fontsize=10, loc='upper right')
ax.set_ylabel('Persistence of Transmission Chain (in days)', fontsize=12)
ax.set_xlabel('Month of Year', fontsize=12)
# NOTE(review): assumes all 12 months are still present after the date filter.
# Otherwise these labels will not line up with the categorical ticks.
ax.set_xticklabels(["Jan", "Feb", "Mar", "Apr", "May", "June", "July", "Aug", "Sep", "Oct", "Nov", "Dec"])
#plt.savefig('../figures/seasonality_persistance.png',dpi=300,bbox_inches='tight')
plt.show()


In [None]:
# Display colors[3], the colour used for the "10+" chain-size category in the plots.
colors[3]

In [None]:
# Build the per-"year-week" transitions table with `return_proportions_dataframe`
# and report how long that took, in minutes. `time.perf_counter` is a monotonic
# clock, so the elapsed time is unaffected by system clock adjustments (unlike
# `time.time`).
start_time = time.perf_counter()

mig = return_proportions_dataframe(migrations_df, "year-week")

total_time_seconds = time.perf_counter() - start_time
total_time_minutes = total_time_seconds / 60
print(total_time_minutes)

mig.head()

In [None]:
# Renumber rows 0..n-1 and discard the old index. Rebinding instead of
# `inplace=True` keeps the cell free of in-place mutation.
mig = mig.reset_index(drop=True)


In [None]:
# Preview the first five rows of the transitions table.
mig.head(n=5)

In [None]:
# Line of the mean `transitions_in_time_interval` for every "year-week".
mean_transitions_line = (
    alt.Chart(mig, width=750)
    .mark_line(opacity=1)
    .encode(
        x=alt.X('year-week:T'),
        y=alt.Y('mean(transitions_in_time_interval):Q'),
    )
)
mean_transitions_line

In [None]:
# Weekly introductions into LA: error bars spanning the confidence interval of
# the mean, the mean as filled points, and a smoothed line through the means.
chart_size = dict(width=800, height=300)

error_bars = (
    alt.Chart(mig)
    .mark_errorbar(size=100, extent='ci')
    .encode(
        x=alt.X('year-week:T', axis=alt.Axis(grid=False)),
        y=alt.Y('transitions_in_time_interval:Q', title="Viral introductions into LA", axis=alt.Axis(grid=False)),
    )
    .properties(**chart_size)
)

points = (
    alt.Chart(mig)
    .mark_point(filled=True, opacity=0.55)
    .encode(
        x=alt.X('year-week:T'),
        y=alt.Y('transitions_in_time_interval:Q', aggregate='mean'),
    )
    .properties(**chart_size)
)

lineplot4 = (
    alt.Chart(mig)
    .mark_line(interpolate='monotone', opacity=0.35)
    .encode(
        x=alt.X('year-week:T'),
        y=alt.Y('mean(transitions_in_time_interval)'),
    )
    .properties(**chart_size)
)

ave = error_bars + points + lineplot4
ave

In [None]:
# Weekly mean introductions with a normal-approximation 95% CI band
# (mean ± 1.96·SEM). The summary is computed here from `mig` rather than
# reusing `stats`, which is only defined in a later cell, so this cell also
# runs on a fresh kernel (Restart & Run All).
intro_stats = (
    mig.groupby("year-week")["transitions_in_time_interval"]
    .agg(['mean', 'sem'])
    .assign(ci95_hi=lambda d: d['mean'] + 1.96 * d['sem'],
            ci95_lo=lambda d: d['mean'] - 1.96 * d['sem'])
    .reset_index()
)

points = alt.Chart(intro_stats).mark_point(filled=True, opacity=0.55).encode(
    x=alt.X('year-week:O'),
    y=alt.Y('mean:Q'),
).properties(
    width=800,
    height=300
)

band = alt.Chart(intro_stats).mark_area(filled=True, opacity=0.55).encode(
    x=alt.X('year-week:O'),
    y=alt.Y('ci95_hi'),
    y2=alt.Y2('ci95_lo'),
).properties(
    width=800,
    height=300
)

# Layer the band first so the semi-transparent band does not cover the points.
band + points


In [None]:
# Display the per-row transitions counts.
mig["transitions_in_time_interval"]

In [None]:
# Keep only the columns written to the exported CSV, then preview them.
mig_to_export = mig.loc[:, ["year-week", "tree_number", "total_transitions", "transitions_in_time_interval"]]
mig_to_export.head()

In [None]:
# Write the per-tree, per-"year-week" transitions table (index included) to CSV.
mig_to_export.to_csv(path_or_buf="mpox_la_introductions_over_time.csv", sep=",")

In [None]:
# Per-"year-week" mean and standard error of `transitions_in_time_interval`
# across rows (presumably one per tree), with a normal-approximation 95% CI
# (mean ± 1.96·SEM).
stats = (
    mig_to_export
    .groupby("year-week")["transitions_in_time_interval"]
    .agg(['mean', 'sem'])
    .assign(
        ci95_hi=lambda s: s['mean'] + 1.96 * s['sem'],
        ci95_lo=lambda s: s['mean'] - 1.96 * s['sem'],
    )
)
print(stats)

In [None]:
# Move the "year-week" group key out of the index into a regular column.
# Guarded so re-running the cell does not keep inserting spurious index columns
# (the unguarded in-place version adds one per re-run and eventually raises).
if "year-week" not in stats.columns:
    stats = stats.reset_index()

In [None]:
# Weekly means as a plain Python list.
stats['mean'].to_list()

In [None]:
# Write the weekly summary statistics (index included) to CSV.
stats.to_csv(path_or_buf="mpox_la_introductions_stats.csv", sep=",")