## Script for quantifying jumps between regions as well as persistence times

In [None]:
import baltic as bt
import pandas as pd
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from datetime import datetime as dt
from datetime import timedelta
import time
from io import StringIO
import altair as alt
from altair import datum
import arviz as az
alt.data_transformers.disable_max_rows()

In [None]:
# Path to the posterior tree file; the cells below scan it line by line for `tree STATE_` entries.
trees =  "../../mascot_glm/results/glm_mcc_map_randomkc_clusters_combined_new.typed.trees"

In [None]:
def get_burnin_value(tree_path, burnin_percent):
    """Return how many leading trees of a tree file count as burn-in.

    Parameters
    ----------
    tree_path : path of the tree file to scan.
    burnin_percent : fraction of the trees to discard (e.g. 0.1 for 10%).

    Returns
    -------
    ``number_of_trees * burnin_percent``; not rounded, so it may be
    fractional. Callers compare a running tree counter against it.
    """
    # Plain 'r': the legacy 'rU' mode was removed in Python 3.11 and universal
    # newlines are already the default for text files in Python 3.
    with open(tree_path, 'r') as infile:
        # Match the same `tree STATE_` marker the processing loops look for
        # (case-insensitively), so that other lines that merely contain the
        # word "state" (e.g. a taxon label) are not counted as trees.
        numtrees = sum(1 for line in infile if 'tree state_' in line.lower())

    burnin = numtrees * burnin_percent
    return burnin

In [None]:
#making decimal date from string dates adapted from stackoverflow (thank you coding geniuses)
def toYearFraction(date):
    """Convert a calendar date to a decimal year (e.g. 2020-07-02 -> 2020.5).

    ``date`` may be any object with a ``timetuple()`` method (``datetime``,
    ``date``, ``pandas.Timestamp``). The fraction is the share of the year
    elapsed, measured in seconds between 1 January of that year and
    1 January of the next, so leap years are handled.

    The arithmetic is done on naive datetimes rather than ``time.mktime`` so
    that the result does not depend on the machine's timezone or on
    daylight-saving shifts.
    """
    # timetuple() carries the calendar fields down to whole seconds.
    moment = dt(*date.timetuple()[:6])

    startOfThisYear = dt(year=moment.year, month=1, day=1)
    startOfNextYear = dt(year=moment.year + 1, month=1, day=1)

    yearElapsed = (moment - startOfThisYear).total_seconds()
    yearDuration = (startOfNextYear - startOfThisYear).total_seconds()
    fraction = yearElapsed / yearDuration

    return moment.year + fraction

In [None]:
# Spot check: decimal-year value of 2022-03-06.
toYearFraction(dt.strptime("2022-03-06",  "%Y-%m-%d"))

In [None]:
# Sums the lengths of the branches that span a specific point in time and whose
# parent is assigned to a given region. Adapted from baltic's lineage-counting function.
def countBL_ns(tree,t,attr='absoluteTime', region = 'North_King_County', condition=lambda x:True):
    """Return the summed length of branches spanning ``t`` in ``region``.

    A branch is included when its parent's ``attr`` is set, when
    ``parent.attr < t <= branch.attr``, when ``condition(branch)`` is truthy
    and when the parent's ``typeTrait`` equals ``region``. Each included
    branch contributes ``absoluteTime - parent.absoluteTime`` (always measured
    on ``absoluteTime``, whatever ``attr`` is). Returns 0 if nothing matches.
    """
    spanning_lengths = []
    for branch in tree.Objects:
        parent = branch.parent
        parent_value = getattr(parent, attr)
        if parent_value is None:
            continue
        if not (parent_value < t <= getattr(branch, attr)):
            continue
        if not condition(branch):
            continue
        parent_region = parent.traits['typeTrait']
        if parent_region is None or parent_region != region:
            continue
        spanning_lengths.append(branch.absoluteTime - parent.absoluteTime)
    return sum(spanning_lengths)


## Calculating branch lengths over time

In [None]:
# 'YYYY-MM' labels, one per month-end generated between 2020-01 and 2022-04.
dates1= pd.date_range('2020-01','2022-04' , freq='1M').strftime('%Y-%m')

In [None]:
# One row per month: its 'YYYY-MM' label plus decimal-year bounds used to bin branches.
date_df = pd.DataFrame(dates1)
date_df = date_df.rename(columns = {0: 'yearmonth'})

# Month-end timestamps shifted back by one MonthBegin offset give the start of each month.
date_df['first_day'] = pd.date_range('2020-01','2022-04' , freq='1M')-pd.offsets.MonthBegin(1)
# NOTE(review): these are month-end timestamps at 00:00, so the final calendar day of each
# month appears to fall outside every [first_day, last_day) window -- confirm this is intended.
date_df['last_day'] =  pd.date_range('2020-01','2022-04' , freq='1M')

# Convert both bounds to decimal years so they compare directly with node absoluteTime values.
date_df.first_day= date_df.first_day.map(toYearFraction)
date_df.last_day= date_df.last_day.map(toYearFraction)

date_df.head()

In [None]:
def bl_overtime(date_df, tree, condition=lambda x:True):
    """Sum branch lengths per month for the North and South King County demes.

    Parameters
    ----------
    date_df : DataFrame with columns ``yearmonth``, ``first_day`` and
        ``last_day`` (the two bounds as decimal years).
    tree : object exposing ``Objects``; each element needs ``absoluteTime``,
        ``parent.absoluteTime`` and ``parent.traits``.
    condition : optional predicate; branches for which it is falsy are skipped.

    Returns
    -------
    dict keyed by the ``date_df`` row index. Each value is
    ``{"yearmonth", "total_bl_n", "total_bl_s"}``.

    A branch is assigned to the month in which its *parent* node falls
    (``first_day <= parent time < last_day``) and to the region given by the
    parent's ``typeTrait``. Its length is cut off at ``last_day`` when the
    child node lies beyond the month. Branches whose parent has no
    ``typeTrait`` are ignored.
    """
    # Single table instead of two copy-pasted scans, one per region.
    region_columns = {
        "North_King_County": "total_bl_n",
        "South_King_County": "total_bl_s",
    }

    output_dict = {}
    for index, row in date_df.iterrows():
        first_day = row.first_day
        last_day = row.last_day
        lengths = {column: [] for column in region_columns.values()}

        for k in tree.Objects:
            parent_time = k.parent.absoluteTime
            if parent_time is None or not (first_day <= parent_time < last_day):
                continue
            try:
                if not condition(k):
                    continue
                column = region_columns.get(k.parent.traits['typeTrait'])
            except KeyError:
                # Parent carries no typeTrait annotation: nothing to attribute.
                continue
            if column is None:
                continue

            # Truncate branches that run past the end of the month.
            child_time = min(k.absoluteTime, last_day)
            lengths[column].append(child_time - parent_time)

        output_dict[index] = {
            "yearmonth": row.yearmonth,
            "total_bl_n": sum(lengths["total_bl_n"]),
            "total_bl_s": sum(lengths["total_bl_s"]),
        }
    return output_dict

In [1]:
def _concurrent_lineage_stats(tree, t, region):
    """Return (count, summed length) of branches spanning ``t`` whose parent is in ``region``.

    Uses the same spanning rule as countBL_ns: ``parent time < t <= child time``.
    """
    lengths = [
        k.absoluteTime - k.parent.absoluteTime
        for k in tree.Objects
        if k.parent.absoluteTime is not None
        and k.parent.absoluteTime < t <= k.absoluteTime
        and k.parent.traits.get('typeTrait') == region
    ]
    return len(lengths), sum(lengths)


def enumerate_migration_events(tree):
    """List every branch along which the deme (``typeTrait``) changes.

    Side effect: nodes with empty ``traits`` are relabelled ``"root"`` and
    nodes whose ``obs`` is 0.0 are relabelled ``"none"``.

    Returns a dict keyed by a running event counter (starting at 1). Each
    value holds ``type`` ("<parent>-to-<child>"), ``date`` (the parent node's
    absoluteTime), ``lineages``, ``branch_lengths``, ``parent_host`` and
    ``child_host``.

    NOTE(review): in the previous version of this cell the body of the
    King County ``if`` was empty (a syntax error) and ``concurrent_lineages``
    / ``concurrent_bl`` were never assigned. The tallies below are a
    reconstruction: they count and sum the branches spanning the migration
    date whose parent is in the destination deme. Confirm this matches the
    intended analysis. For other destinations both values are NaN.
    """
    king_county_demes = ("South_King_County", "North_King_County")

    # Relabel all nodes up front so the lineage tallies see final deme labels
    # regardless of where a node sits in tree.Objects.
    for k in tree.Objects:
        if k.traits == {}:
            k.traits = {'obs': 0.0, 'typeTrait': "root"}
        elif k.traits['obs'] == 0.0:
            k.traits = {'obs': 0.0, 'typeTrait': "none"}

    output_dict = {}
    migration_events_counter = 0

    for k in tree.Objects:
        trait = k.traits['typeTrait']
        parent_node = k.parent

        # A parent with empty traits is skipped: no event is written for that branch.
        # NOTE(review): relabelled roots carry typeTrait "root" rather than a
        # 'root' key, so "root-to-<deme>" events are still emitted for their
        # children, as before; they are filtered out later in the notebook.
        if ('root' in parent_node.traits) or (parent_node.traits == {}):
            continue

        parent_trait = parent_node.traits['typeTrait']
        if trait == parent_trait:
            continue

        migration_events_counter += 1
        migration_event = parent_trait + "-to-" + trait
        migration_date = parent_node.absoluteTime

        concurrent_lineages = float('nan')
        concurrent_bl = float('nan')
        if trait in king_county_demes:
            concurrent_lineages, concurrent_bl = _concurrent_lineage_stats(
                tree, migration_date, trait)

        # write to output dictionary
        output_dict[migration_events_counter] = {
            "type": migration_event,
            "date": migration_date,
            'lineages': concurrent_lineages,
            "branch_lengths": concurrent_bl,
            "parent_host": parent_trait,
            "child_host": trait,
        }

    return output_dict

In [None]:
# Fraction of the trees in the file to treat as burn-in.
burnin_percent = 0.1

# Number of leading trees that the migration loop below skips.
burnin = get_burnin_value(trees, burnin_percent)
print(burnin)

In [None]:
# Walk the tree file once; for every post-burn-in tree collect its migration
# events (migrations_dict) and its monthly branch-length totals (time_dict),
# both keyed by the tree's position in the file.
start_time = time.time()

with open(trees, "r") as infile:
    
    tree_counter = 0
    trees_processed = 0
    migrations_dict = {}
    time_dict = {}
    
    for line in infile:
       # print(line)
        if 'tree STATE_' in line:
            tree_counter += 1
            
            # Trees numbered up to the burn-in count are skipped.
            if tree_counter > burnin:
                # Each tree line is parsed on its own via an in-memory file object.
                temp_tree = StringIO(line)
                tree = bt.loadNexus(temp_tree, absoluteTime = False)
                # NOTE(review): 2022.1753... equals 2022 + 64/365, i.e. 2022-03-06 as a decimal
                # year -- presumably the most recent tip date; confirm against the data.
                tree.setAbsoluteTime(2022.1753424657534)
                trees_processed += 1

                # iterate through the tree and pull out all migration events
                migrations_dict[tree_counter] = enumerate_migration_events(tree)
                time_dict[tree_counter] = bl_overtime(date_df, tree)

# print the amount of time this took
total_time_seconds = time.time() - start_time
total_time_minutes = total_time_seconds/60
print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", trees_processed, "trees")

In [None]:
# Flatten the nested {tree: {date_df row: totals}} dict into one row per (tree, month).
time_df = pd.DataFrame.from_dict({(i,j): time_dict[i][j] 
                           for i in time_dict.keys() 
                           for j in time_dict[i].keys()},
                       orient='index')
time_df.head()
time_df.reset_index(inplace=True)
# 'level_1' holds the date_df row index here; 'yearmonth' is renamed so it can serve as a merge key later.
time_df.rename(columns={'level_0': 'tree_number', 'level_1': 'event_number', 'yearmonth': 'year-month'}, inplace=True)
time_df.head()

In [None]:

# Monthly North King County branch-length totals, one line per tree.
alt.Chart(time_df, width = 750).mark_line(size = 10, opacity = 0.2).encode(
    x=alt.X('year-month:O'),
    y=alt.Y('total_bl_n:Q'),
    color=alt.Color('tree_number:N'))

In [None]:

# Monthly South King County branch-length totals, one line per tree.
alt.Chart(time_df, width = 750).mark_line(size = 10, opacity = 0.2).encode(
    x=alt.X('year-month:O'),
    y=alt.Y('total_bl_s:Q'),
    color=alt.Color('tree_number:N'))

In [None]:
# Flatten the nested {tree: {event number: event}} dict into one row per migration event.
migrations_df = pd.DataFrame.from_dict({(i,j): migrations_dict[i][j] 
                           for i in migrations_dict.keys() 
                           for j in migrations_dict[i].keys()},
                       orient='index')
migrations_df
# Turn the (tree, event) index into ordinary columns with descriptive names.
migrations_df.reset_index(inplace=True)
migrations_df.rename(columns={'level_0': 'tree_number', 'level_1': 'migration_event_number'}, inplace=True)
migrations_df.head()

In [None]:
#need to convert the decimal dates back to calendar dates 
def convert_partial_year(number):
    """Turn a decimal year (e.g. 2020.5) into a 'YYYY-MM-DD' string.

    The fractional part is scaled by the length of that year in days
    (365 plus whatever ``is_leap`` returns) and added to 1 January.
    """
    whole_year = int(number)
    days_in_year = 365 + is_leap(whole_year)
    offset = timedelta(days=(number - whole_year)*days_in_year)
    return (offset + dt(whole_year, 1, 1)).strftime('%Y-%m-%d')

In [None]:
# Converts a duration given in (365-day) years into seconds.
def convert_persistence(number):
    """Return ``number`` years expressed in seconds, assuming 365-day years."""
    span = timedelta(days=number * 365)
    return span.total_seconds()

In [None]:
def is_leap(number):
    """Return 1 if ``number`` is a leap year in the Gregorian calendar, else 0.

    The integer result is added to 365 by callers to get the number of days in
    the year. Earlier this recognised only 2020, which gave other leap years
    (e.g. 2024) a 365-day length.
    """
    import calendar  # local import keeps this cell self-contained

    return int(calendar.isleap(number))

In [None]:
def convert_format(number):
    """Shorten a 'YYYY-MM-DD' date string to its 'YYYY-MM' year-month."""
    parsed = dt.strptime(number, '%Y-%m-%d')
    return parsed.strftime('%Y-%m')

In [None]:
# Decimal-year event dates -> 'YYYY-MM-DD' strings -> 'YYYY-MM' month labels.
migrations_df['calendar_date'] = migrations_df.date.map(convert_partial_year)
migrations_df['year-month'] = migrations_df['calendar_date'].map(convert_format)

In [None]:
# Attach each tree's monthly branch-length totals to its migration events; the left join
# keeps every event even when no matching (tree, month) row exists in time_df.
merged_mr_df = pd.merge(migrations_df, time_df, on = ['tree_number', 'year-month'], how = 'left')
merged_mr_df.head()

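### The function below tallies migration jumps for every combination of tree, migration direction and time unit (here: year-month), and normalizes them by lineage count and branch length.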

In [None]:

def return_proportions_dataframe(input_df, time_unit):
    """Tabulate migration counts per tree, migration direction and time unit.

    Parameters
    ----------
    input_df : DataFrame with columns ``tree_number``, ``type``, ``lineages``,
        ``total_bl_n``, ``total_bl_s`` and the column named by ``time_unit``.
    time_unit : name of the column that defines the time bins (e.g. "year-month").

    Returns
    -------
    DataFrame with one row for every (tree, direction, time unit) combination
    found in ``input_df``, including combinations with zero events. Columns:
    ``migration_direction``, ``<time_unit>``, ``tree_number``,
    ``total_transitions`` (events of that direction in the tree),
    ``transitions_in_time_interval``, ``proportion_transitions_in_time_interval``,
    ``mig_per_lineage`` and ``mig_per_bl``. Every row keeps index label 0, as
    when the frame was built by appending single-row frames; use
    ``reset_index(drop=True)`` for unique labels. An empty input gives an
    empty DataFrame.

    Rows are collected in a list and the frame is built once:
    ``DataFrame.append`` no longer exists in pandas 2.x and was quadratic.
    Groups are visited in order of first appearance so the row order is
    reproducible between runs.
    """
    north_kc = ["South_King_County-to-North_King_County", "none-to-North_King_County"]
    south_kc = ["North_King_County-to-South_King_County", "none-to-South_King_County"]

    tree_numbers = input_df['tree_number'].drop_duplicates().tolist()
    directions = input_df['type'].drop_duplicates().tolist()
    time_units = input_df[time_unit].drop_duplicates().tolist()

    rows = []
    for tree_number in tree_numbers:
        tree_df = input_df[input_df['tree_number'] == tree_number]

        for v in directions:
            direction_df = tree_df[tree_df['type'] == v]
            total_transitions = len(direction_df)

            for item in time_units:
                unit_df = direction_df[direction_df[time_unit] == item]
                transitions_in_time_unit = len(unit_df)
                average_lin = unit_df.lineages.mean()

                # Normalize by the branch length of the destination region;
                # directions ending elsewhere get no branch-length normalization.
                average_bl = 0
                if v in north_kc:
                    average_bl = unit_df.total_bl_n.mean()
                elif v in south_kc:
                    average_bl = unit_df.total_bl_s.mean()

                if total_transitions != 0:
                    prop_transitions_in_time_unit = transitions_in_time_unit/total_transitions
                else:
                    prop_transitions_in_time_unit = 0

                if transitions_in_time_unit != 0 and average_bl != 0:
                    mig_per_bl = transitions_in_time_unit/average_bl
                else:
                    mig_per_bl = 0

                if transitions_in_time_unit != 0:
                    mig_per_lineage = transitions_in_time_unit/average_lin
                else:
                    mig_per_lineage = 0

                rows.append({
                    "migration_direction": v,
                    time_unit: item,
                    "tree_number": tree_number,
                    "total_transitions": total_transitions,
                    "transitions_in_time_interval": transitions_in_time_unit,
                    "proportion_transitions_in_time_interval": prop_transitions_in_time_unit,
                    "mig_per_lineage": mig_per_lineage,
                    "mig_per_bl": mig_per_bl,
                })

    if not rows:
        return pd.DataFrame()
    return pd.DataFrame(rows, index=[0] * len(rows))

In [None]:
# Build the per-tree, per-direction, per-month migration table and time the run.
start_time = time.time()

mig = return_proportions_dataframe(merged_mr_df, "year-month")

total_time_seconds = time.time() - start_time
total_time_minutes = total_time_seconds/60
# Elapsed time in minutes.
print(total_time_minutes)

mig.head()

In [None]:
# Number of rows in `mig` for each migration direction.
mig.migration_direction.value_counts()

In [None]:
# Directions that are dropped before the North/South comparison.
notwanted = ["none-to-South_King_County", "none-to-North_King_County","root-to-none" ]
# Keep only the rows whose direction is not in the exclusion list.
mig_clean = mig[~mig['migration_direction'].isin(notwanted)]


In [None]:
# Collapse each migration direction onto its destination region; directions that do not
# end in North or South King County map to NaN and are dropped below.
region_by_direction = {
    "none-to-South_King_County": "South",
    "North_King_County-to-South_King_County": "South",
    "none-to-North_King_County": "North",
    "South_King_County-to-North_King_County": "North",
}
# Work on a copy: a plain `mig_coll = mig` would add the 'region' column to `mig` itself.
# Mapping the labels directly also avoids writing strings into a float (NaN) column.
mig_coll = mig.copy()
mig_coll['region'] = mig_coll.migration_direction.map(region_by_direction)
mig_coll = mig_coll.dropna(subset=['region'])
mig_coll

In [None]:
#plot total introductions over time for N and S

In [None]:
# Keep months strictly after 2020-01 ('YYYY-MM' strings compare chronologically).
mig_coll_short = mig_coll[mig_coll['year-month'] > "2020-01"] #before march it's mostly sparsely north 

## Subsetting to months after January 2020 (the filter above keeps `year-month > "2020-01"`)


In [None]:
#highlighting important NPIs in WA
data = {'date': [ "2020-03-23", "2020-06-01", "2020-11-18", "2021-02-14"], 'event':[ "Stay at home", "Stay at home lifted", "Closing restaurants", "Reopening restaurants"]}

npidf = pd.DataFrame(data)
npidf.date = pd.to_datetime(npidf.date)

rule = alt.Chart(npidf).mark_rule(
    color="black",
    strokeWidth=2, 
    opacity = 0.3
).encode(
    alt.X('date:T', axis=alt.Axis(title=None))
).properties(
    width=800,
    height=300
)

text = alt.Chart(npidf).mark_text(
    align='left',
    baseline='middle',
    dx=2,
    dy=-135,
    size=11
).encode(
    alt.X('date:T',axis=alt.Axis(title=None)),
    text='event',
    color=alt.value('#000000')
).properties(
    width=800,
    height=300
)

In [None]:
# Restrict the cleaned table to months strictly between 2020-01 and 2022-04 and save it.
mig_short = mig_clean[(mig_clean['year-month'] > "2020-01") & (mig_clean['year-month'] < "2022-04")] #before march it's mostly sparsely north
mig_short.to_csv("../data-files/migration_jumps_df.csv")

In [None]:
## jumps per BL

In [None]:
# NOTE(review): `domain` is not referenced by any chart in the visible cells.
domain = ['North_King_County-to-South_King_County', 'South_King_County-to-North_King_County']
#range_ = ['red', 'green']


# One point per (tree, month, direction); not part of the combined figure below.
stripplot3 =  alt.Chart(mig_short, width = 750).mark_circle(size = 10, opacity = 0.2).encode(
    x=alt.X('year-month:O'),
    y=alt.Y('mig_per_bl:Q'),
    color=alt.Color('migration_direction:N'))


# Median across trees of the branch-length-normalized jump count, per month and direction.
lineplot3 =  alt.Chart(mig_short, width = 750).mark_line(interpolate='monotone', clip = True).encode(
    x=alt.X('year-month:T',axis=alt.Axis( grid=False)),
    y=alt.Y('median(mig_per_bl)',axis=alt.Axis(title="Number of Migration Events (Normalized by Branch Length)", grid=False)),
    color=alt.Color('migration_direction:N')).properties(
    width=850,
    height=300
)

# Interquartile-range band around the same quantity.
band3 = alt.Chart(mig_short).mark_errorband(extent='iqr', interpolate='monotone', clip = True).encode(
    x=alt.X('year-month:T'),
    y=alt.Y('mig_per_bl',axis=alt.Axis(title="", grid=False)), 
    color =alt.Color('migration_direction:N')).properties(
    width=850,
    height=300
) 
# Layer the line, the band and the NPI annotations.
jumps_per_bl = lineplot3 + band3 + rule + text 

In [None]:
# Same layering as the last line of the previous cell (rebuilds the combined chart).
jumps_per_bl = lineplot3 + band3 + rule + text 

In [None]:
# Display the combined jumps-per-branch-length chart.
jumps_per_bl

In [None]:
#jumps_per_bl.save("jumps_per_bl.png")

In [None]:
# Rows for the "none-to-North" and "none-to-South" directions only.
outside_df = mig[(mig['migration_direction'] == "none-to-North_King_County") | (mig['migration_direction'] == "none-to-South_King_County") ]

In [None]:
# Mean jumps per lineage for the "none-to-<region>" directions, per month.
lineplot4 =  alt.Chart(outside_df, width = 750).mark_line(interpolate='monotone').encode(
    x=alt.X('year-month:T'),
    y=alt.Y('mean(mig_per_lineage)'),
    color=alt.Color('migration_direction:N')).properties(
    width=800,
    height=300
)

# Confidence-interval band (extent='ci') for the same quantity.
band4 = alt.Chart(outside_df).mark_errorband(extent='ci', interpolate='monotone').encode(
    x=alt.X('year-month:T'),
    y=alt.Y('mig_per_lineage'), 
    color =alt.Color('migration_direction:N')).properties(
    width=800,
    height=300
)
    

In [None]:
# Display the line, the band and the NPI annotations as one layered chart.
lineplot4 +band4 + rule + text 

In [None]:
# Normalized stacked bars: each direction's share of the mean mig_per_bl in every month.
chart = alt.Chart(mig_short).mark_bar().encode(
    alt.X('year-month:O'), 
    alt.Color('migration_direction:N'),
    alt.Y("mean(mig_per_bl)", stack="normalize", title='Mig_per_bl'))



In [None]:
# Display the stacked bar chart.
chart

In [None]:
# Monthly mean of mig_per_bl per direction with confidence-interval bars; every layer
# drops rows where mig_per_bl is 100 or more before aggregating.
error_bars = alt.Chart(mig_short).mark_errorbar(extent='ci').encode(
  x=alt.X('year-month:O'),
  y=alt.Y('mig_per_bl:Q'), color = alt.Color("migration_direction:N")
).properties(
    width=800,
    height=300
).transform_filter(
    (datum.mig_per_bl < 100)
)

points = alt.Chart(mig_short).mark_point(filled=True,  opacity = 0.55).encode(
  x=alt.X('year-month:O'),
  y=alt.Y('mig_per_bl:Q', aggregate='mean'),
    color = alt.Color("migration_direction:N")
).properties(
    width=800,
    height=300
).transform_filter(
    (datum.mig_per_bl < 100)
)

# NOTE: this rebinds `lineplot4`, replacing the chart built from outside_df in an earlier cell.
lineplot4 =  alt.Chart(mig_short).mark_line(interpolate='monotone', opacity = 0.35).encode(
    x=alt.X('year-month:O'),
    y=alt.Y('mean(mig_per_bl)'),
    color=alt.Color('migration_direction:N')).properties(
    width=800,
    height=300
).transform_filter(
    (datum.mig_per_bl < 100)
)


ave = error_bars + points +lineplot4
ave

In [None]:
# Overall mean of mig_per_bl per direction with confidence-interval bars.
# NOTE: rebinds `error_bars`, `points` and `ave` from the previous cell.
error_bars = alt.Chart(mig_short).mark_errorbar(extent='ci').encode(
  x=alt.X('mig_per_bl:Q', scale=alt.Scale(zero=False)),
  y=alt.Y('migration_direction:N')
)

points = alt.Chart(mig_short).mark_point(filled=True, color='black').encode(
  x=alt.X('mig_per_bl:Q', aggregate='mean'),
  y=alt.Y('migration_direction:N'),
)

ave = error_bars + points
ave

In [None]:
# Place the stacked bar chart and the overall-mean chart side by side.
summary = chart | ave
#summary.save("summary_jumps.png")

### Working on calculating persistence times

In [None]:
def estimate_persistence(tree):
    """For each tip, find the most recent deme change on its path towards the root.

    Side effect: nodes with empty ``traits`` are relabelled ``"root"`` and
    nodes whose ``obs`` is 0.0 are relabelled ``"none"``.

    For every leaf the ancestors are walked upwards until one carries a
    ``typeTrait`` different from the leaf's. That ancestor's ``absoluteTime``
    is recorded as the migration date, and the persistence is the tip date
    minus the migration date (both in decimal years, as stored on the nodes).
    Leaves whose parent has empty traits, and leaves for which no differing
    ancestor is found, yield no entry.

    Returns a dict keyed by a running counter (starting at 1). Each value has
    the keys ``type``, ``migration date``, ``tip date``, ``persistance``,
    ``tip_name``, ``parent_host`` and ``child_host``.
    """
    output_dict = {}
    persistence_counter = 0

    for k in tree.Objects:

        if k.traits == {}:
            k.traits = {'obs': 0.0, 'typeTrait': "root"}
        elif k.traits['obs'] == 0.0:
            k.traits = {'obs': 0.0, 'typeTrait': "none"}

        parent_node = k.parent
        # Branches whose parent carries no traits are skipped.
        if ('root' in parent_node.traits) or (parent_node.traits == {}):
            continue
        if k.branchType != 'leaf':
            continue

        trait = k.traits['typeTrait']
        tip_date = k.absoluteTime
        tip_name = k.name

        while True:
            try:
                parent_trait = parent_node.traits['typeTrait']
            except (KeyError, AttributeError):
                # Walked past the annotated part of the tree without finding a
                # differing deme. Only these lookup failures are swallowed;
                # the earlier bare `except` also hid interrupts and real bugs.
                break

            if trait == parent_trait:
                parent_node = parent_node.parent
                continue

            persistence_counter += 1
            migration_event = parent_trait + "-to-" + trait
            migration_date = parent_node.absoluteTime
            persistence = tip_date - migration_date

            # write to output dictionary
            output_dict[persistence_counter] = {
                "type": migration_event,
                "migration date": migration_date,
                "tip date": tip_date,
                "persistance": persistence,
                "tip_name": tip_name,
                "parent_host": parent_trait,
                "child_host": trait,
            }
            break

    return output_dict

In [None]:
# Walk the tree file again and collect, per tree, the persistence record of every tip.
# NOTE(review): unlike the migration loop earlier in the notebook, no burn-in trees are
# skipped here (there is no `tree_counter > burnin` check) -- confirm this is intended.
start_time = time.time()

with open(trees, "r") as infile:
    
    tree_counter = 0
    trees_processed = 0
    persistence_dict = {}
    
    for line in infile:
       # print(line)
        if 'tree STATE_' in line:
            tree_counter += 1
            

            # Each tree line is parsed on its own via an in-memory file object.
            temp_tree = StringIO(line)
            tree = bt.loadNexus(temp_tree, absoluteTime = False)
            # Same anchor date as in the migration loop (2022 + 64/365).
            tree.setAbsoluteTime(2022.1753424657534)
            trees_processed += 1

            # iterate through the tree and pull out all migration events
            persistence_dict[tree_counter] =  estimate_persistence(tree)
            
# print the amount of time this took
total_time_seconds = time.time() - start_time
total_time_minutes = total_time_seconds/60
print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", trees_processed, "trees")

In [None]:
# Flatten the nested {tree: {record number: record}} dict into one row per tip record.
persistence_df = pd.DataFrame.from_dict({(i,j): persistence_dict[i][j] 
                           for i in persistence_dict.keys() 
                           for j in persistence_dict[i].keys()},
                       orient='index')
persistence_df
# Turn the (tree, record) index into ordinary columns with descriptive names.
persistence_df.reset_index(inplace=True)
persistence_df.rename(columns={'level_0': 'tree_number', 'level_1': 'migration_event_number'}, inplace=True)
persistence_df.head(5)

In [None]:
# Number of persistence records per migration type, across all trees.
persistence_df.type.value_counts()#[persistence_df["tree_number"] == 1]

In [None]:
# Calendar-date versions of the decimal-year columns, plus the month of the migration.
persistence_df['migration_date'] = persistence_df['migration date'].map(convert_partial_year)
persistence_df['tip_date'] = persistence_df['tip date'].map(convert_partial_year)
persistence_df['year-month'] = persistence_df['migration_date'].map(convert_format)
# Persistence: years -> seconds (365-day years) -> days.
persistence_df['persistence_time'] = persistence_df['persistance'].map(convert_persistence)
persistence_df['persistence_time'] = persistence_df['persistence_time'].div(86400) #calculating number of days from seconds

In [None]:
# Keep records whose migration month is strictly after 2020-01.
temp = persistence_df[(persistence_df['year-month'] > "2020-01") ]

In [None]:
#temp.to_csv("../data-files/persistance_df.csv")

In [None]:
# One point per tip record; not part of the combined `persist` figure built below.
stripplot =  alt.Chart(temp, width = 750).mark_circle(size=8, opacity = 0.2).encode(
    x=alt.X('year-month:O'),
    y=alt.Y('persistence_time:Q'),
    color=alt.Color('child_host:N'))

# Mean persistence time (days) per migration month, coloured by the tip's deme.
lineplot =  alt.Chart(temp, width = 750).mark_line(interpolate='monotone').encode(
    x=alt.X('year-month:T',axis=alt.Axis(title="", grid=False)),
    y=alt.Y('mean(persistence_time)',axis=alt.Axis(title="Length of Local Transmission (in Days)", grid=False)),
    color=alt.Color('child_host:N')).properties(
    width=800,
    height=300
)

# Confidence-interval band (extent='ci') for the same quantity.
band = alt.Chart(temp).mark_errorband(extent='ci', interpolate='monotone').encode(
    x=alt.X('year-month:T'),
    y=alt.Y('persistence_time:Q',axis=alt.Axis(title="", grid=False)), 
    color =alt.Color('child_host:N')
    
).properties(
    width=800,
    height=300
)

In [None]:
# Layer the mean line, its band and the NPI annotations.
persist = lineplot + band +rule + text

In [None]:
# Display the combined persistence chart.
persist

In [None]:
#persist.save("persistance.png")