In [None]:
import random
import time
import warnings
from datetime import datetime as dt
from datetime import timedelta
from io import StringIO

import altair as alt
import baltic as bt
import matplotlib as mpl
import matplotlib.lines as mlines
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
# Lift Altair's default row limit so the large per-tree tables plotted below are not rejected.
alt.data_transformers.disable_max_rows()

In [None]:
# Typed tree file; the loading cell reads it line by line, one tree per "tree STATE_" line.
trees = (
    "../../mascot_glm/results/"
    "glm_mcc_map_randomkc_clusters_combined_new.typed.trees"
)

### Calculate the ancestral root states of clusters


In [None]:
def to_year_fraction(date):
    """Convert a datetime-like value to a decimal year (e.g. 2021-07-02 12:00 -> ~2021.5).

    The fraction is the elapsed share of that calendar year, measured against the
    year's actual length, so leap years are handled correctly.
    """
    start_of_year = dt(date.year, 1, 1)
    start_of_next_year = dt(date.year + 1, 1, 1)
    elapsed = (date - start_of_year).total_seconds()
    year_length = (start_of_next_year - start_of_year).total_seconds()
    return date.year + elapsed / year_length


# One row per calendar month from 2020-01 through 2022-03. The interval for each month is
# the half-open range [first_day, last_day) in decimal years. last_day is the FIRST day of
# the following month, so the months tile the timeline with no gaps. Ending at the last
# day's midnight would silently drop each month's final day.
months = pd.period_range('2020-01', '2022-03', freq='M')
date_df = pd.DataFrame({
    'yearmonth': months.strftime('%Y-%m'),
    'first_day': months.start_time.map(to_year_fraction),
    'last_day': (months + 1).start_time.map(to_year_fraction),
})

In [None]:
def bl_overtime(date_df, tree, condition=lambda x: True):
    """Total branch length per month below North/South King County parent nodes.

    For every row of ``date_df``, each object ``k`` in ``tree.Objects`` qualifies when all
    of the following hold:
    - its parent's ``absoluteTime`` lies in ``[row.first_day, row.last_day)``;
    - ``condition(k)`` is truthy;
    - the parent's ``typeTrait`` is "North_King_County" or "South_King_County".

    A qualifying object contributes its branch length, with the child end truncated at
    ``row.last_day``, to that region's total. Objects whose parent has no time
    (``None``) or no 'typeTrait' key are skipped.

    NOTE(review): branches whose parent lies in an earlier month are never counted toward
    later months, even if they span them. Confirm this is the intended quantity.

    Parameters
    ----------
    date_df : pd.DataFrame
        Rows with ``yearmonth``, ``first_day`` and ``last_day`` (decimal years).
    tree
        A baltic tree with absolute times set.
    condition : callable, optional
        Extra per-object filter; defaults to accepting everything.

    Returns
    -------
    dict
        ``{date_df index: {"yearmonth", "total_bl_n", "total_bl_s"}}``.
    """
    region_keys = {
        "North_King_County": "total_bl_n",
        "South_King_County": "total_bl_s",
    }
    output_dict = {}
    for index, row in date_df.iterrows():
        totals = {"total_bl_n": 0, "total_bl_s": 0}
        for k in tree.Objects:
            parent = k.parent
            parent_time = parent.absoluteTime
            if parent_time is None or not (row.first_day <= parent_time < row.last_day):
                continue
            try:
                if not condition(k):
                    continue
                key = region_keys.get(parent.traits['typeTrait'])
            except KeyError:
                # Parents without a 'typeTrait' (e.g. an empty traits dict) are skipped.
                continue
            if key is None:
                continue
            # Only the portion of the branch inside this month is counted.
            child_time = min(k.absoluteTime, row.last_day)
            totals[key] += child_time - parent_time
        output_dict[index] = {"yearmonth": row.yearmonth, **totals}
    return output_dict

In [None]:
def enumerate_root_states(tree):
    """Record every parent->child change of ``typeTrait`` in ``tree`` as a migration event.

    Before comparing, each object's ``traits`` dict is normalised IN PLACE:
    - an empty dict becomes ``{'obs': 0.0, 'typeTrait': 'root'}``;
    - an object with ``obs == 0.0`` has its typeTrait replaced by ``"none"``.
    Code that later reads ``typeTrait`` from the same tree sees these rewritten labels.

    Parameters
    ----------
    tree
        A baltic tree. Every object in ``tree.Objects`` needs ``traits`` and ``parent``;
        parents need ``absoluteTime``.

    Returns
    -------
    dict
        ``{event_number: record}`` with 1-based event numbers. Each record contains:
        - "type": "<parent>-to-<child>";
        - "date": the parent's absoluteTime;
        - "lineages" and "branch_lengths";
        - "parent_host" and "child_host".
        "lineages" and "branch_lengths" come from ``countLineages_ns`` / ``countBL_ns``
        for events into North/South King County. They are NaN for other events, or when
        a lookup fails (a RuntimeWarning is emitted).
    """
    regions = ("South_King_County", "North_King_County")
    output_dict = {}
    migration_events_counter = 0

    for k in tree.Objects:
        if k.traits == {}:
            k.traits = {'obs': 0.0, 'typeTrait': "root"}
        elif k.traits['obs'] == 0.0:
            k.traits = {'obs': 0.0, 'typeTrait': "none"}

        trait = k.traits['typeTrait']
        parent_node = k.parent

        # A parent with empty traits (or a 'root' key) is treated as the root: no event.
        # NOTE(review): `'root' in traits` tests for a *key* named 'root'. The
        # normalisation above never creates that key, so only the empty-dict test can
        # fire. Confirm whether `parent_node.traits.get('typeTrait') == 'root'` was meant.
        if ('root' in parent_node.traits) or (parent_node.traits == {}):
            continue

        parent_trait = parent_node.traits['typeTrait']
        if trait == parent_trait:
            continue

        migration_events_counter += 1
        migration_date = parent_node.absoluteTime

        # Reset for every event so a value from an earlier event can never leak into
        # this record, and the names are always bound.
        concurrent_lineages = np.nan
        concurrent_bl = np.nan
        if trait in regions:
            try:
                concurrent_lineages = countLineages_ns(tree, migration_date, region=trait)
                concurrent_bl = countBL_ns(tree, migration_date, region=trait)
            except Exception as exc:
                # Best effort: keep NaN for whatever was not computed, but surface the
                # failure. NOTE(review): countLineages_ns / countBL_ns are not defined
                # anywhere in this notebook as shown.
                warnings.warn(
                    f"concurrent lineage/branch-length lookup failed: {exc!r}",
                    RuntimeWarning,
                )

        output_dict[migration_events_counter] = {
            "type": parent_trait + "-to-" + trait,
            "date": migration_date,
            'lineages': concurrent_lineages,
            "branch_lengths": concurrent_bl,
            "parent_host": parent_trait,
            "child_host": trait,
        }

    return output_dict

In [None]:
# Decimal-year anchor date passed to baltic's setAbsoluteTime. baltic expects the date of
# the most recent tip and uses it to place every node on the decimal-year time axis.
MOST_RECENT_TIP_DATE = 2022.1753424657534

# perf_counter is monotonic, so the measured duration is unaffected by system clock changes.
start_time = time.perf_counter()

with open(trees, "r") as infile:

    tree_counter = 0
    trees_processed = 0
    migrations_dict = {}
    time_dict = {}

    for line in infile:
        # Every line containing "tree STATE_" is parsed as one complete tree.
        if 'tree STATE_' not in line:
            continue
        tree_counter += 1

        tree = bt.loadNexus(StringIO(line), absoluteTime=False)
        tree.setAbsoluteTime(MOST_RECENT_TIP_DATE)
        trees_processed += 1

        # Order matters: enumerate_root_states rewrites node traits in place
        # (empty -> 'root', obs == 0.0 -> 'none') before bl_overtime reads typeTrait.
        migrations_dict[tree_counter] = enumerate_root_states(tree)
        time_dict[tree_counter] = bl_overtime(date_df, tree)

# print the amount of time this took
total_time_seconds = time.perf_counter() - start_time
total_time_minutes = total_time_seconds/60
print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", trees_processed, "trees")

In [None]:
# Flatten {tree: {event: record}} into one row per (tree, migration event).
flat_migrations = {
    (tree_number, event_number): record
    for tree_number, events in migrations_dict.items()
    for event_number, record in events.items()
}
migrations_df = (
    pd.DataFrame.from_dict(flat_migrations, orient='index')
    .reset_index()
    .rename(columns={'level_0': 'tree_number', 'level_1': 'migration_event_number'})
)
migrations_df.head()

In [None]:
# Flatten {tree: {month_row: totals}} into one row per (tree, month).
flat_times = {
    (tree_number, event_number): totals
    for tree_number, per_tree in time_dict.items()
    for event_number, totals in per_tree.items()
}
time_df = (
    pd.DataFrame.from_dict(flat_times, orient='index')
    .reset_index()
    .rename(columns={'level_0': 'tree_number', 'level_1': 'event_number', 'yearmonth': 'year-month'})
)
time_df

In [None]:
# Distribution of concurrent-lineage counts recorded for migration events.
migrations_df['lineages'].value_counts()

In [None]:
# Tree node times (absoluteTime) are decimal years; convert them back to calendar dates.
def convert_partial_year(number):
    """Convert a decimal year (e.g. 2020.5) to a 'YYYY-MM-DD' string.

    The fractional part is scaled by the year's length in days (365 plus ``is_leap``).
    """
    year = int(number)
    days_in_year = 365 + is_leap(year)
    as_datetime = dt(year, 1, 1) + timedelta(days=(number - year) * days_in_year)
    return as_datetime.strftime('%Y-%m-%d')

In [None]:
def is_leap(number):
    """Return 1 if ``number`` is a Gregorian leap year, otherwise 0.

    The int (rather than bool) return lets callers write ``365 + is_leap(year)``.
    Applies the full Gregorian rule, so every leap year is handled, not only 2020:
    a year is a leap year if divisible by 4, except centuries not divisible by 400.
    """
    return int(number % 4 == 0 and (number % 100 != 0 or number % 400 == 0))

In [None]:
def convert_format(number):
    """Reduce a 'YYYY-MM-DD' date string to its 'YYYY-MM' month (raises ValueError if malformed)."""
    return dt.strptime(number, '%Y-%m-%d').strftime('%Y-%m')

In [None]:
# Calendar date ('YYYY-MM-DD') of each migration, from its decimal-year parent time.
migrations_df['calendar_date'] = migrations_df['date'].map(convert_partial_year)

In [None]:
# Month bucket ('YYYY-MM') used to join migrations to the monthly branch-length totals.
migrations_df['year-month'] = migrations_df.calendar_date.map(convert_format)

In [None]:
# Preview the first rows rather than rendering the whole per-event table.
migrations_df.head()

In [None]:
# Attach each tree's monthly branch-length totals to its migration events (left join keeps every event).
merged_mr_df = migrations_df.merge(time_df, how='left', on=['tree_number', 'year-month'])

In [None]:
# Preview the first rows of the joined table.
merged_mr_df.head()

In [None]:

def return_proportions_dataframe(input_df, time_unit):
    """Summarise migration events per tree, migration type and time unit.

    A row is produced for every combination of tree_number, migration ``type`` and
    ``time_unit`` value in ``input_df``, including combinations with zero events.
    Each row reports:
    - total_transitions: events of that type in that tree;
    - transitions_in_time_interval: events of that type in that time unit;
    - proportion_transitions_in_time_interval: the ratio of the two (0 if no events);
    - mig_per_lineage: events divided by the mean 'lineages' (0 if no events);
    - mig_per_bl: events divided by the mean branch length of the destination region
      ('total_bl_n' for migrations into North KC, 'total_bl_s' for South KC). It is 0
      if there are no events, and NaN for any other migration type.

    Parameters
    ----------
    input_df : pd.DataFrame
        One row per migration event with columns 'tree_number', 'type', ``time_unit``,
        'lineages', 'total_bl_n' and 'total_bl_s'.
    time_unit : str
        Name of the column to bucket events by (e.g. "year-month").

    Returns
    -------
    pd.DataFrame
        One row per combination, with a fresh RangeIndex.
    """
    north_kc = {"South_King_County-to-North_King_County", "none-to-North_King_County"}
    south_kc = {"North_King_County-to-South_King_County", "none-to-South_King_County"}
    columns = [
        "migration_direction", time_unit, "tree_number", "total_transitions",
        "transitions_in_time_interval", "proportion_transitions_in_time_interval",
        "mig_per_lineage", "mig_per_bl",
    ]

    # Loop-invariant: computed once instead of once per tree / per type.
    all_types = set(input_df['type'].tolist())
    all_time_units = set(input_df[time_unit].tolist())

    # Collect plain dicts and build the frame once. DataFrame.append was removed in
    # pandas 2.0, and growing a frame row by row is quadratic.
    records = []
    for tree_number in set(input_df['tree_number'].tolist()):
        tree_df = input_df[input_df['tree_number'] == tree_number]

        for v in all_types:
            type_df = tree_df[tree_df['type'] == v]
            total_transitions = len(type_df)

            # Decide the branch-length column per type. Without the explicit None case,
            # other types would reuse a stale value from a previous type (or hit an
            # unbound name).
            if v in north_kc:
                bl_column = 'total_bl_n'
            elif v in south_kc:
                bl_column = 'total_bl_s'
            else:
                bl_column = None

            for item in all_time_units:
                unit_df = type_df[type_df[time_unit] == item]
                transitions_in_time_unit = len(unit_df)

                if total_transitions != 0:
                    prop_transitions_in_time_unit = transitions_in_time_unit / total_transitions
                else:
                    prop_transitions_in_time_unit = 0

                if transitions_in_time_unit != 0:
                    average_lin = unit_df['lineages'].mean()
                    average_bl = unit_df[bl_column].mean() if bl_column is not None else np.nan
                    mig_per_lineage = transitions_in_time_unit / average_lin
                    mig_per_bl = transitions_in_time_unit / average_bl
                else:
                    mig_per_lineage = 0
                    mig_per_bl = 0

                records.append({
                    "migration_direction": v,
                    time_unit: item,
                    "tree_number": tree_number,
                    "total_transitions": total_transitions,
                    "transitions_in_time_interval": transitions_in_time_unit,
                    "proportion_transitions_in_time_interval": prop_transitions_in_time_unit,
                    "mig_per_lineage": mig_per_lineage,
                    "mig_per_bl": mig_per_bl,
                })

    return pd.DataFrame(records, columns=columns)

In [None]:
# perf_counter is monotonic, so the measured duration is unaffected by system clock changes.
start_time = time.perf_counter()

mig = return_proportions_dataframe(merged_mr_df, "year-month")

total_time_seconds = time.perf_counter() - start_time
total_time_minutes = total_time_seconds / 60
print(total_time_minutes)  # elapsed minutes

mig.head()

In [None]:

# Give the two none-to-<region> migration directions a readable region label; every other
# direction maps to NaN and is dropped in the next step.
region_labels = {
    "none-to-South_King_County": "South King County",
    "none-to-North_King_County": "North King County",
}
# assign() returns a new frame. The previous `mig_coll = mig` was only an alias, so adding
# the column also mutated `mig`. Mapping directly also avoids writing strings into a
# float (NaN) column.
mig_coll = mig.assign(region=mig['migration_direction'].map(region_labels))
mig_coll = mig_coll[mig_coll['year-month'] > "2020-01"]



In [None]:
# Keep only rows that received a region label.
mig_coll = mig_coll[mig_coll['region'].notna()]

In [None]:
# Preview the first rows of the plotting table.
mig_coll.head()

In [None]:
# Drop months from 2022-04 onward (outside the plotted window).
mig_coll = mig_coll.loc[mig_coll['year-month'] < "2022-04"]
# Optional export: mig_coll.to_csv("../data-files/root_states_df.csv")

In [None]:
# Key Washington State non-pharmaceutical interventions (NPIs), drawn as vertical
# reference lines with text labels over the time-series charts.
data = {
    'date': ["2020-03-23", "2020-06-01", "2020-11-18", "2021-02-14"],
    'event': ["Stay at home", "Stay at home lifted", "Closing restaurants", "Reopening restaurants"],
}

npidf = pd.DataFrame(data)
npidf['date'] = pd.to_datetime(npidf['date'])

# Faint vertical line at each NPI date.
rule = (
    alt.Chart(npidf)
    .mark_rule(color="black", strokeWidth=2, opacity=0.3)
    .encode(alt.X('date:T', axis=alt.Axis(title=None)))
    .properties(width=800, height=300)
)

# Label each line. Negative dy shifts text upward; -135 is tuned to the 300px chart height.
text = (
    alt.Chart(npidf)
    .mark_text(align='left', baseline='middle', dx=2, dy=-135, size=11)
    .encode(
        alt.X('date:T', axis=alt.Axis(title=None)),
        text='event',
        color=alt.value('#000000'),
    )
    .properties(width=800, height=300)
)

In [None]:
# Mean mig_per_bl per month and region (line) with a confidence band (extent='ci')
# over the per-tree values. The width is set once via properties (800px).
lineplot = (
    alt.Chart(mig_coll)
    .mark_line(interpolate='monotone')
    .encode(
        x=alt.X('year-month:T', axis=alt.Axis(grid=False)),
        y=alt.Y('mean(mig_per_bl)', axis=alt.Axis(title="Cluster Root States", grid=False)),
        color=alt.Color('region:N'),
    )
    .properties(width=800, height=300)
)

band = (
    alt.Chart(mig_coll)
    .mark_errorband(extent='ci', interpolate='monotone')
    .encode(
        x=alt.X('year-month:T'),
        y=alt.Y('mig_per_bl', axis=alt.Axis(title="", grid=False)),
        color=alt.Color('region:N'),
    )
    .properties(width=800, height=300)
)

In [None]:
# Layer the regional mean line, its CI band, the NPI reference lines and their labels
# (later layers are drawn on top).
root_states = lineplot + band + rule + text

In [None]:
# Render the layered time-series chart.
root_states 

In [None]:
# Stacked bars normalized per month: each region's share of mean mig_per_bl.
chart = (
    alt.Chart(mig_coll)
    .mark_bar()
    .encode(
        x=alt.X('year-month:O'),
        color=alt.Color('region:N'),
        y=alt.Y("mean(mig_per_bl)", stack="normalize", title='root states per bl'),
    )
)



In [None]:
# Render the normalized stacked bar chart.
chart

In [None]:
#chart.save("root_states_normalized.png")

In [None]:
# Per-region summary of mig_per_bl: confidence-interval whiskers plus a black mean point.
error_bars = (
    alt.Chart(mig_coll)
    .mark_errorbar(extent='ci')
    .encode(
        x=alt.X('mig_per_bl:Q', scale=alt.Scale(zero=False)),
        y=alt.Y('region:N'),
    )
)

points = (
    alt.Chart(mig_coll)
    .mark_point(filled=True, color='black')
    .encode(
        x=alt.X('mig_per_bl:Q', aggregate='mean'),
        y=alt.Y('region:N'),
    )
)

ave = error_bars + points
ave