In [198]:
import pandas as pd
import numpy as np
import altair as alt
from altair import datum
import baltic as bt
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.patheffects as path_effects
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Polygon
import random

from datetime import datetime as dt
from datetime import timedelta
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [158]:
def enumerate_migration_events(tree):
        
    output_dict = {}
    migration_events_counter = 0
        
    for k in tree.Objects:
#         if k.traits == {}:
#             k.traits = {'obs': 0.0, 'ns_kc':"root"}
#         elif k.traits['obs'] ==  0.0:
#             k.traits = {'obs': 0.0, 'ns_kc':"none"}
        
        trait = k.traits['ns_kc']
        parent_node = k.parent
        if ('root' in parent_node.traits) or (parent_node.traits == {}) :
            parent_trait = "root"
        
        # only write out migration events that are not from root to deme
        elif (trait == "South_King_County") or (trait == "North_King_County"):


        
            parent_trait = parent_node.traits['ns_kc']
            
            if (trait != parent_trait):
                if (parent_trait != "South_King_County") and (parent_trait != "North_King_County"):
                    migration_events_counter += 1
                    migration_event = parent_trait + "-to-" + trait
                    migration_date = parent_node.absoluteTime + (k.absoluteTime - parent_node.absoluteTime) *random.uniform(0,1)



                    # write to output dictionary
                    output_dict[migration_events_counter] = {"type":migration_event, "date":migration_date, 
                                                            "parent_host":parent_trait,
                                                            "child_host": trait}
    
    return(output_dict)

In [159]:
def enumerate_migration_events_other(tree):
    output_dict = {}
    migration_events_counter = 0
    clades_to_remove = [ '21L (Omicron)', '20J (Gamma, V3)', '21J (Delta)', '21M (Omicron)',  '21K (Omicron)', '20H (Beta, V2)' ]
    list_clades = []   
    for k in tree.Objects:
#         if k.traits == {}:
#             k.traits = {'obs': 0.0, 'ns_kc':"root"}
#         elif k.traits['obs'] ==  0.0:
#             k.traits = {'obs': 0.0, 'ns_kc':"none"}
        #clade = k.traits[]
        trait = k.traits['ns_kc']
        list_clades.append(k.traits["clade_membership"])
        parent_node = k.parent
        if ('root' in parent_node.traits) or (parent_node.traits == {}) :
            parent_trait = "root"
        
        # only write out migration events that are not from root to deme
        elif (trait == "South_King_County") or (trait == "North_King_County"):

            if k.traits["clade_membership"] in clades_to_remove:
                #print(k.traits["clade_membership"])
                pass
            else:
        
                parent_trait = parent_node.traits['ns_kc']

                if (trait != parent_trait):
                    if (parent_trait != "South_King_County") and (parent_trait != "North_King_County"):
                        migration_events_counter += 1
                        migration_event = parent_trait + "-to-" + trait
                        migration_date = parent_node.absoluteTime + (k.absoluteTime - parent_node.absoluteTime) *random.uniform(0,1)  



                        # write to output dictionary
                        output_dict[migration_events_counter] = {"type":migration_event, "date":migration_date, 
                                                                "parent_host":parent_trait,
                                                                "child_host": trait}
                        
    #myset = set(list_clades)
    #print(myset)
    
    return(output_dict)

In [160]:
#need to convert the decimal dates back to calendar dates 
def convert_partial_year(number):

    year = int(number)
    d = timedelta(days=(number - year)*(365 + is_leap(year)))
    day_one = dt(year,1,1)
    date = d + day_one
    date = dt.strftime(date, '%Y-%m-%d')
    return date

In [161]:
def is_leap(number):
    if number == 2020:
        leap = 1
    else:
        leap = 0
    return leap

In [162]:
def convert_format(number):
    date = dt.strptime(number, '%Y-%m-%d')
    date = dt.strftime(date, '%Y-%m')
    return date

In [163]:
 def plot_quick_tree(tree):   
    fig,ax = plt.subplots(figsize=(10,20),facecolor='w')

    x_attr=lambda k: k.absoluteTime ## x coordinate of branches will be absoluteTime attribute
    #c_func=lambda k: 'indianred' if k.traits['PB1']=='V' else 'steelblue' ## colour of branches
    s_func=lambda k: 50-30*k.height/ll.treeHeight ## size of tips

    ll.plotTree(ax,x_attr=x_attr,colour=c_func) ## plot branches
    ll.plotPoints(ax,x_attr=x_attr,size=s_func,zorder=100) ## plot circles at tips

    #ax.set_ylim(-5,ll.ySpan+5)
    plt.show()

In [164]:
def source_of_intro(tree):

    json_translation={'absoluteTime':lambda k: k.traits['node_attrs']['num_date']['value'],'name':'name'} ## allows baltic to find correct attributes in JSON, height and name are required at a minimum
    #json_meta={'file':meta,'traitName':'region'} ## if you want auspice stylings you can import the meta file used on nextstrain.org
    ll,meta=bt.loadJSON(tree,json_translation=json_translation) ## give loadJSON the name of the tree file, the translation dictionary and (optionally) the meta file

    
    
    #pull out migration events labeling parent and child nodes
    if "ncov-king-county_other.json" in tree: #tree == "/Users/mparedes/Desktop/new_variants_pull/auspice/ncov_other_sub.json":
        migrations_dict = enumerate_migration_events_other(ll)
    else:
        migrations_dict = enumerate_migration_events(ll)
    migrations_df = pd.DataFrame.from_dict(migrations_dict).T
    migrations_df.reset_index(inplace=True)

    migrations_df.child_host[migrations_df.child_host == "North_King_County"] = "North King County" 
    migrations_df.child_host[migrations_df.child_host == "South_King_County"] = "South King County" 

    #convert decimal year into year-month
    migrations_df['calendar_date'] = migrations_df.date.map(convert_partial_year)
    migrations_df['year-month'] = migrations_df['calendar_date'].map(convert_format)
    
    if tree == "/Users/mparedes/Desktop/new_variants_pull/auspice/ncov_alpha_sub.json":
        migrations_df= migrations_df[migrations_df["year-month"] > "2020-10"]
    elif tree == "/Users/mparedes/Desktop/new_variants_pull/auspice/ncov_delta_sub.json":
        migrations_df= migrations_df[migrations_df["year-month"] > "2020-11"]
    elif tree == "/Users/mparedes/Desktop/new_variants_pull/auspice/ncov_omicron_sub.json":
        migrations_df= migrations_df[migrations_df["year-month"] > "2021-08"]   
    elif tree == "/Users/mparedes/Desktop/new_variants_pull/auspice/ncov_other_sub.json":
        migrations_df= migrations_df[migrations_df["year-month"] > "2020-01"]



    return(migrations_df)

In [187]:
overall_df = pd.DataFrame()
overall_df

In [189]:
list_trees = ["../nextstrain_build/auspice/ncov-king-county_other.json", "../nextstrain_build/auspice/ncov-king-county_alpha.json", "../nextstrain_build/auspice/ncov-king-county_delta.json", "../nextstrain_build/auspice/ncov-king-county_omicron.json" ]
variant_dict = {}
overall_df = pd.DataFrame()
for tree in list_trees:
    name = tree.split("/")[-1].split(".")[0]
    variant_dict[name] = source_of_intro(tree)
    overall_df = pd.concat([overall_df, variant_dict[name]])
    


    




Tree height: 2.686368
Tree length: 1783.811191
annotations present

Numbers of objects in tree: 34840 (16533 nodes and 18307 leaves)



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "North_King_County"] = "North King County"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "South_King_County"] = "South King County"



Tree height: 2.402729
Tree length: 1483.016192
annotations present

Numbers of objects in tree: 35446 (17006 nodes and 18440 leaves)



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "North_King_County"] = "North King County"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "South_King_County"] = "South King County"



Tree height: 2.633705
Tree length: 2326.109713
annotations present

Numbers of objects in tree: 42905 (20537 nodes and 22368 leaves)



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "North_King_County"] = "North King County"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "South_King_County"] = "South King County"



Tree height: 2.658406
Tree length: 1513.402866
annotations present

Numbers of objects in tree: 40455 (19496 nodes and 20959 leaves)



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "North_King_County"] = "North King County"
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  migrations_df.child_host[migrations_df.child_host == "South_King_County"] = "South King County"


In [190]:
overall_df

Unnamed: 0,index,type,date,parent_host,child_host,calendar_date,year-month
0,1,North America-to-North_King_County,2020.264174,North America,North King County,2020-04-06,2020-04
1,2,North America-to-South_King_County,2020.24735,North America,South King County,2020-03-31,2020-03
2,3,Asia-to-North_King_County,2020.267525,Asia,North King County,2020-04-07,2020-04
3,4,Asia-to-North_King_County,2020.233866,Asia,North King County,2020-03-26,2020-03
4,5,Asia-to-North_King_County,2020.282787,Asia,North King County,2020-04-13,2020-04
...,...,...,...,...,...,...,...
1805,1806,Washington-to-South_King_County,2022.069863,Washington,South King County,2022-01-26,2022-01
1806,1807,North America-to-South_King_County,2022.10529,North America,South King County,2022-02-08,2022-02
1807,1808,North America-to-South_King_County,2022.103493,North America,South King County,2022-02-07,2022-02
1808,1809,Europe-to-North_King_County,2022.123449,Europe,North King County,2022-02-15,2022-02


In [206]:

# Assuming variant_dict is a dictionary of DataFrames

charts = []

chart = alt.Chart(overall_df).transform_aggregate(
        count='count()',
        groupby=["child_host",'year-month', "parent_host"]
        ).transform_filter(
        'datum.count >= 3'
        ).mark_area(interpolate="monotone").encode(
            alt.X("year-month:T", title="",  axis=alt.Axis(title=None, grid=False, tickCount = "month",format="%b %Y")),  # Assuming "year-month" is a time field
            alt.Y("count:Q", stack="normalize", axis=alt.Axis(title="Percent of Total Introductions", grid=False, format='%')),  # Assuming "count()" is a quantitative field
            alt.Color("parent_host:N", title="Source of Introduction", scale = alt.Scale(scheme="tableau10")),
            alt.Column("child_host:N", title = "", spacing =50)

        ).properties(
            title="Overall"
        ) 

charts.append(chart)

for df_name, df in variant_dict.items():
    if df_name != "ncov-king-county_other":
        chart = alt.Chart(df).transform_aggregate(
        count='count()',
        groupby=["child_host",'year-month', "parent_host"]
        ).transform_filter(
        'datum.count >= 3'
        ).mark_area(interpolate="monotone").encode(
            alt.X("year-month:T", title="",  axis=alt.Axis(title=None, grid=False, tickCount = "month",format="%b %Y")),  # Assuming "year-month" is a time field
            alt.Y("count:Q", stack="normalize", axis=alt.Axis(title="Percent of Total Introductions", grid=False, format='%')),  # Assuming "count()" is a quantitative field
            alt.Color("parent_host:N", title="Source of Introduction", scale = alt.Scale(scheme="tableau10")),
            alt.Column("child_host:N", title = "", spacing =50)

        ).properties(
            title=df_name.split("_")[1].capitalize()
        ) 


        charts.append(chart)
        


# Concatenate charts horizontally
concatenated_chart = alt.vconcat(*charts, spacing=10).resolve_scale(x = "shared")

# Display the concatenated chart
concatenated_chart.configure_title(
    anchor='start', fontSize= 20
).configure_axis(
    labelFontSize=16,
    titleFontSize=16
).configure_legend(
    labelFontSize = 16,
    titleFontSize = 16
).configure_header(
    labelFontSize =16)
