In [10]:
#copied from Maria, not sure which libraries i need or don't need

import baltic as bt
import pandas as pd
import json
import os
import matplotlib as mpl
from matplotlib import pyplot as plt
import requests
from io import StringIO as sio
from matplotlib.patches import Patch
import matplotlib.ticker as ticker
import itertools
import re
import sys
import subprocess
from Bio import Phylo


In [12]:
newickPath='forbaltic.nwk' #give it the path name (since forbaltic.nwk is already in the same folder as the script, i just gave it the file name"
myTree = bt.loadNewick(newickPath) #load the tree (newick format) with baltic and give it the variable myTree

In [24]:
#this function loads trait data from a JSON file and processes it into three dictionaries 
#that will later be used to assign traits to nodes and leaves in a tree

def load_trait_data(traits_file): #this function takes in a traits_file
    with open(traits_file) as f:
        data = json.load(f) #reacs the contents of the file as a JSON object, parsed JSON data is stored in the bariable data
    region_dict = {} #dictionaries to hold the data of interest 
    host_dict = {}
    clade_dict = {} #we added the h5_megaclade column/info here
    for node_name, node_data in data['nodes'].items(): #iterates over the data['node'] dictionary, node_name = name of node, node_data = associated trait data
        region_dict[node_name] = node_data.get("region") #retrieves these values from node_data using .get()
        host_dict[node_name] = node_data.get("host")
        clade_dict[node_name] = node_data.get("h5_megaclade")
    return(region_dict, host_dict, clade_dict)

#takes a tree object and the dictionaries returned by load_trait_data function (which map node names to respective trait values)
def traits(tree, region_dict, host_dict, clade_dict):
    for k in tree.Objects: #iterates thorugh all objects in the tree
        if k.branchType == "node": #if k is an internal node (label is used to fetch the region, host, and clade info from respective dictionaries
            k.traits["region"] = region_dict.get(k.traits["label"])
            k.traits["host"] = host_dict.get(k.traits["label"])
            k.traits["h5_megaclade"] = clade_dict.get(k.traits["label"])
        if k.branchType == "leaf": #if k is a leaf/tip, name of leaf is used to getch the region, host, and clade info from respective dictionaries 
            k.traits["region"] = region_dict.get(k.name)
            k.traits["host"] = host_dict.get(k.name)
            k.traits["h5_megaclade"] = clade_dict.get(k.name)
    return(tree)

In [26]:
region_dict, host_dict, clade_dict = load_trait_data("traits/traits_h5nx_ha.json")
tree = traits(myTree, region_dict, host_dict, clade_dict)

In [28]:
region_dict, host_dict, clade_dict = load_trait_data("traits/traits_h5nx_ha.json")
tree = traits(myTree, region_dict, host_dict, clade_dict)

count_rea = 0
count_node = 0
count_leaf = 0
branch_length = 0

for k in myTree.Objects:
    
    if k.is_node():
        count_node += 1
        branch_length += k.length 
        if k.traits["is_reassorted"] == 1.0: 
            count_rea += 1
    # print(k.length)  
    #print(k.traits["label"], k.traits["is_reassorted"], k.traits["h5_megaclade"], k.length)
    if k.is_leaf():
        count_leaf += 1
        branch_length += k.length 
        if k.traits["is_reassorted"] == 1.0: 
            count_rea += 1
        # print(k.name, k.traits["is_reassorted"], k.traits["h5_megaclade"], k.length)
    else: 
        pass

clock_rate=3.11e-3

print("Total reassortment events:", count_rea)
print("Total number of nodes and leaves:", count_node+count_leaf)
print("Total branch length of the whole tree:", branch_length) 
print("Fake reassortment rate:", count_rea/(count_node+count_leaf))
print("Better, more accurate? reassortment rate,:", count_rea/branch_length) ##reasortment events/divergence units
print("Reassortment rate per year:", (count_rea/branch_length)*clock_rate)

Total reassortment events: 1591
Total number of nodes and leaves: 8879
Total branch length of the whole tree: 19.578090000000003
Fake reassortment rate: 0.17918684536546908
Better, more accurate? reassortment rate,: 81.26431127857721
Reassortment rate per year: 0.25273200807637514


In [34]:
# Load the tree and trait data
newickPath = 'forbaltic.nwk'
myTree = bt.loadNewick(newickPath)

# Load the trait data from JSON file
region_dict, host_dict, clade_dict = load_trait_data("traits/traits_h5nx_ha.json")
tree = traits(myTree, region_dict, host_dict, clade_dict)

# Initialize a dictionary to store count and branch length for each megaclade
megaclade_stats = {
    "2.3.2.1s": {'count_rea': 0, 'branch_length': 0},
    "2.3.4.4s": {'count_rea': 0, 'branch_length': 0},
    "2.3.4.4b": {'count_rea': 0, 'branch_length': 0},
    "Am-nonGsGD": {'count_rea': 0, 'branch_length': 0},
    "EA-nonGsGD": {'count_rea': 0, 'branch_length': 0},
    "other": {'count_rea': 0, 'branch_length': 0}
}

# Loop through all objects in the tree
for k in myTree.Objects:
    
    # Check if the object is either a node or a leaf
    if k.is_node() or k.is_leaf():
        megaclade_value = k.traits["h5_megaclade"]  # Get the megaclade value for the current object
       
        # Only process if the megaclade is in the list
        if megaclade_value in megaclade_stats:
            
            # Check if the object is reassorted
            if k.traits["is_reassorted"] == 1.0:
                
                # Update the count and branch length for the specific megaclade
                megaclade_stats[megaclade_value]['count_rea'] += 1
                megaclade_stats[megaclade_value]['branch_length'] += k.length

            if k.traits["is_reassorted"] == 0.0:
                    
                # Update the count and branch length for the specific megaclade
                megaclade_stats[megaclade_value]['branch_length'] += k.length

clock_rate=3.11e-3

# Print the results for each megaclade
for megaclade_name in megaclade_stats:

    # Retrieve the statistics for the current megaclade
    count_rea = megaclade_stats[megaclade_name]['count_rea']
    branch_length = megaclade_stats[megaclade_name]['branch_length']
    reassortment_rate = count_rea/branch_length
    per_year = reassortment_rate*clock_rate
    
    # Output the results in a readable format
    print(f"Megaclade: {megaclade_name}")
    print(f"  Number of reassorted branches: {count_rea}")
    print(f"  Total branch length for reassorted branches: {branch_length}")
    print(f" Reassortment rate: {reassortment_rate}") 
    print(f" Reassortment rate per year: {per_year}\n") 

Megaclade: 2.3.2.1s
  Number of reassorted branches: 114
  Total branch length for reassorted branches: 2.7191299999999985
 Reassortment rate: 41.92517459628633
 Reassortment rate per year: 0.13038729299445048

Megaclade: 2.3.4.4s
  Number of reassorted branches: 209
  Total branch length for reassorted branches: 2.2993399999999995
 Reassortment rate: 90.89564831647344
 Reassortment rate per year: 0.2826854662642324

Megaclade: 2.3.4.4b
  Number of reassorted branches: 138
  Total branch length for reassorted branches: 1.3832299999999929
 Reassortment rate: 99.76648858107525
 Reassortment rate per year: 0.310273779487144

Megaclade: Am-nonGsGD
  Number of reassorted branches: 541
  Total branch length for reassorted branches: 4.693749999999994
 Reassortment rate: 115.25965379494022
 Reassortment rate per year: 0.3584575233022641

Megaclade: EA-nonGsGD
  Number of reassorted branches: 482
  Total branch length for reassorted branches: 3.634989999999993
 Reassortment rate: 132.6000896838

In [44]:
def calculate_reassortment(newick_path, trait_data_path):
    """
    Process a phylogenetic tree and its associated trait data to compute reassortment rates by megaclade.

    Parameters:
    - newick_path: str, path to the Newick file containing the tree.
    - trait_data_path: str, path to the JSON file containing trait data.

    Returns:
    - None: Prints results for each megaclade, including reassortment rate and other statistics.
    """

    # Load the tree in Newick format
    myTree = bt.loadNewick(newick_path)

    # Load the trait data from the JSON file
    region_dict, host_dict, clade_dict = load_trait_data(trait_data_path)

    # Process the tree with the trait data
    tree = traits(myTree, region_dict, host_dict, clade_dict)

    # Initialize a dictionary to store count and branch length for each megaclade
    megaclade_stats = {
        "2.3.2.1s": {'count_rea': 0, 'branch_length': 0},
        "2.3.4.4s": {'count_rea': 0, 'branch_length': 0},
        "2.3.4.4b": {'count_rea': 0, 'branch_length': 0},
        "Am-nonGsGD": {'count_rea': 0, 'branch_length': 0},
        "EA-nonGsGD": {'count_rea': 0, 'branch_length': 0},
        "other": {'count_rea': 0, 'branch_length': 0}
    }

    # Loop through all objects in the tree
    for k in myTree.Objects:
        
        # Check if the object is either a node or a leaf
        if k.is_node() or k.is_leaf():
            megaclade_value = k.traits["h5_megaclade"]  # Get the megaclade value for the current object
           
            # Only process if the megaclade is in the list
            if megaclade_value in megaclade_stats:
                
                # Check if the object is reassorted
                if k.traits["is_reassorted"] == 1.0:
                    
                    # Update the count and branch length for the specific megaclade
                    megaclade_stats[megaclade_value]['count_rea'] += 1
                    megaclade_stats[megaclade_value]['branch_length'] += k.length
                    
                if k.traits["is_reassorted"] == 0.0:
                    
                # Update the count and branch length for the specific megaclade
                    megaclade_stats[megaclade_value]['branch_length'] += k.length

    # Print the results for each megaclade
    for megaclade_name in megaclade_stats:

        # Retrieve the statistics for the current megaclade
        count_rea = megaclade_stats[megaclade_name]['count_rea']
        branch_length = megaclade_stats[megaclade_name]['branch_length']
        
        # Handle division by zero for reassortment rate calculation
        if branch_length > 0:
            reassortment_rate = count_rea / branch_length
        else:
            reassortment_rate = 0

        clock_rate = 3.11e-3
        per_year = reassortment_rate*clock_rate
        
        # Output the results in a readable format
        print(f"Megaclade: {megaclade_name}")
        print(f"  Number of reassorted branches: {count_rea}")
        print(f"  Total branch length for reassorted branches: {branch_length}")
        print(f"  Reassortment rate: {reassortment_rate:.4f}")
        print(f" Reassortment rate per year: {per_year}\n") 

In [46]:
calculate_reassortment('forbaltic.nwk', 'traits/traits_h5nx_ha.json')


Megaclade: 2.3.2.1s
  Number of reassorted branches: 114
  Total branch length for reassorted branches: 2.7191299999999985
  Reassortment rate: 41.9252
 Reassortment rate per year: 0.13038729299445048

Megaclade: 2.3.4.4s
  Number of reassorted branches: 209
  Total branch length for reassorted branches: 2.2993399999999995
  Reassortment rate: 90.8956
 Reassortment rate per year: 0.2826854662642324

Megaclade: 2.3.4.4b
  Number of reassorted branches: 138
  Total branch length for reassorted branches: 1.3832299999999929
  Reassortment rate: 99.7665
 Reassortment rate per year: 0.310273779487144

Megaclade: Am-nonGsGD
  Number of reassorted branches: 541
  Total branch length for reassorted branches: 4.693749999999994
  Reassortment rate: 115.2597
 Reassortment rate per year: 0.3584575233022641

Megaclade: EA-nonGsGD
  Number of reassorted branches: 482
  Total branch length for reassorted branches: 3.634989999999993
  Reassortment rate: 132.6001
 Reassortment rate per year: 0.412386278

In [20]:
#okay now we do it by megaclade! 
#brainstorm
newickPath='forbaltic.nwk' #give it the path name (since forbaltic.nwk is already in the same folder as the script, i just gave it the file name"
myTree = bt.loadNewick(newickPath) #load the tree (newick format) with baltic and give it the variable myTree

region_dict, host_dict, clade_dict = load_trait_data("traits/traits_h5nx_ha.json")
tree = traits(myTree, region_dict, host_dict, clade_dict)

count_rea = 0
branch_length = 0

# List of megaclade names
megaclade_names = ["2.3.2.1s", "2.3.4.4s", "2.3.4.4b", "Am-nonGsGD", "EA-nonGsGD", "other"]

# Initialize the megaclade_stats dictionary
megaclade_stats = {}

# Loop through the list of megaclade names
for megaclade_name in megaclade_names:
    # Initialize the dictionary for each megaclade
    megaclade_stats[megaclade_name] = {'count_rea': 0, 'branch_length': 0}

# Now megaclade_stats is populated
# Loop through all objects in the tree
for k in myTree.Objects:
    # Check if the object is either a node or a leaf
    if k.is_node() or k.is_leaf():
        # Check if the object is reassorted and has a megaclade in the list
        if k.traits["is_reassorted"] == 1.0 and k.traits["h5_megaclade"] in megaclade_names:
            count_rea += 1  # Increment reassorted count
            branch_length += k.length  # Add branch length

# Output the results
print("Number of reassorted branches:", count_rea)
print("Total branch length for reassorted branches:", branch_length)


Number of reassorted branches: 1591
Total branch length for reassorted branches: 5.060599999999983


In [82]:
#this function reads a tree, detects which nodes/leaves are reassorted 
#counts and names the reassorted segments and outputs that info into a json (visualized by auspice)???

def reassortment_counter(counter_input, counter_output):
    #load the tree, absoluteTime = False (nodes are not placed using calendar time), verbose = False (suppresses console output during loading)
    mytree = bt.loadNewick(counter_input, absoluteTime= False, verbose=False) 
    rea_dict = {} #will hold per-node or per-tip reassortment metadata 
    segments = ["PB2", "PB1", "PA", "HA", "NP", "NA", "MP", "NS"]

    reassortment_count = 0 
    node_count = 0
    leaf_count = 0
    
    for k in mytree.Objects: #iterate thorugh all objects (nodes and leaves) in the tree
        # Count node and leaf types
        if k.is_node():
            node_count += 1
        elif k.is_leaf():
            leaf_count += 1   
            
        if k.traits["is_reassorted"]: #checks if k.traits is *reassorted* 
            reassortment_count += 1
            # Extract and clean reassorted segments
            raw_segments = k.traits["rea"] #raw string from the rea trait (ie "PB2(1)-PB1(2))
            segment_names = [seg.split("(")[0] for seg in raw_segments.split("-")] #cleans segment name, removes parantheses/extra info
            reassorted_segments = f"{len(segment_names)} ({', '.join(segment_names)})" #formats the count and names (ie "2 (PB2, PB1)")
            if k.is_node(): #if node, use its label as the key, store whether it is reassorted and what segments are involved
                rea_dict[k.traits.get("label")] = {
                    "Reassorted": "True",
                    "Reassorted Segments": reassorted_segments
                }
            elif k.is_leaf(): #if leaf, use its name as the key
                rea_dict[k.name] = {
                    "Reassorted": "True",
                    "Reassorted Segments": reassorted_segments
                }
        else: #for nonreassorted nodes/leaves, just mark "Reassorted" : "False"
            key = (k.traits["label"] if k.is_node() else k.name)
            rea_dict[key] = {"Reassorted": "False"}
    branch_dict = {} #holds per branch labels
    
    for k in mytree.Objects: #iterates through the tree nodes to generate labels
        #for each reassorted node or leaf, creates a labels entry with reassorted segments
        if k.traits["is_reassorted"]:
            # Use the cleaned reassorted segments from `rea_dict`
            if k.is_node():
                branch_dict[k.traits.get("label")] = {
                    "labels": {'Reassorted Segments': rea_dict[k.traits.get("label")]['Reassorted Segments']}
                }
            elif k.is_leaf():
                branch_dict[k.name] = {
                    "labels": {'Reassorted Segments': rea_dict[k.name]['Reassorted Segments']}
                }
    out_dict = {'nodes': rea_dict, 'branches': branch_dict, 'summary':{
            'total_reassortment_events': reassortment_count,
            'total_internal_nodes': node_count,
            'total_leaves': leaf_count,
            'total_objects': node_count + leaf_count
        }
    }#comvines all node and branch metadata into a single dictionary
    with open(counter_output, 'w') as f:
        json.dump(out_dict, f)
        
    # Optional: print the count for quick reference
    print(f"Total reassortment events: {reassortment_count}")
    print(f"Total internal nodes: {node_count}")
    print(f"Total leaves: {leaf_count}")
    print(f"Total objects in tree: {node_count + leaf_count}")

In [84]:
reassortment_counter("forbaltic.nwk", "rea.json")

Total reassortment events: 1591
Total internal nodes: 4439
Total leaves: 4440
Total objects in tree: 8879


In [22]:
#TEST for one megaclade (and to double check?)
newickPath='forbaltic.nwk' #give it the path name (since forbaltic.nwk is already in the same folder as the script, i just gave it the file name"
myTree = bt.loadNewick(newickPath) #load the tree (newick format) with baltic and give it the variable myTree

region_dict, host_dict, clade_dict = load_trait_data("traits/traits_h5nx_ha.json")
tree = traits(myTree, region_dict, host_dict, clade_dict)

count_rea = 0
branch_length = 0

for k in myTree.Objects:
    # Check if the object is either a node or a leaf
    if k.is_node() or k.is_leaf():
        # Check if the object is reassorted and has a megaclade in the list
        if k.traits["is_reassorted"] == 1.0 and k.traits["h5_megaclade"] == "other" :
            count_rea += 1  # Increment reassorted count
            branch_length += k.length  # Add branch length

# Output the results
print("Number of reassorted branches:", count_rea)
print("Total branch length for reassorted branches:", branch_length)
print("Better, more accurate? reassortment rate,:", count_rea/branch_length) ##reasortment events/divergence units


Number of reassorted branches: 107
Total branch length for reassorted branches: 0.3061199999999999
Better, more accurate? reassortment rate,: 349.5361296223704


In [124]:
 def reassortment_counter_megaclade(counter_input, counter_output, traits_file=None):
    # Load the tree
    mytree = bt.loadNewick(counter_input, absoluteTime=False, verbose=False)

    if traits_file:
        region_dict, host_dict, clade_dict = load_trait_data(traits_file)
        mytree = traits(mytree, region_dict, host_dict, clade_dict)
     
    rea_dict = {}
    branch_dict = {}
    segments = ["PB2", "PB1", "PA", "HA", "NP", "NA", "MP", "NS"]

    # Initialize counters
    reassortment_count = 0
    node_count = 0
    leaf_count = 0
    clade_reassortment_count = {}  # ← NEW: store reassortment count by clade

    for k in mytree.Objects:
        # Count node and leaf types
        if k.is_node():
            node_count += 1
        elif k.is_leaf():
            leaf_count += 1

        # Check and count reassortments
        if k.traits["is_reassorted"]:
            reassortment_count += 1

            # Clean segment names
            raw_segments = k.traits["rea"]
            segment_names = [seg.split("(")[0] for seg in raw_segments.split("-")]
            reassorted_segments = f"{len(segment_names)} ({', '.join(segment_names)})"

            # Get clade name
            clade = k.traits.get("h5_megaclade", "Unknown")

            # Increment clade-specific reassortment count
            clade_reassortment_count[clade] = clade_reassortment_count.get(clade, 0) + 1

            # Store reassortment details
            key = k.traits.get("label") if k.is_node() else k.name
            rea_dict[key] = {
                "Reassorted": "True",
                "Reassorted Segments": reassorted_segments,
                "Clade": clade
            }
        else:
            key = k.traits.get("label") if k.is_node() else k.name
            rea_dict[key] = {"Reassorted": "False"}

    # Create branch labels
    for k in mytree.Objects:
        if k.traits["is_reassorted"]:
            key = k.traits.get("label") if k.is_node() else k.name
            branch_dict[key] = {
                "labels": {
                    'Reassorted Segments': rea_dict[key]['Reassorted Segments']
                }
            }

    # Final output dictionary
    out_dict = {
        'nodes': rea_dict,
        'branches': branch_dict,
        'summary': {
            'total_reassortment_events': reassortment_count,
            'total_internal_nodes': node_count,
            'total_leaves': leaf_count,
            'total_objects': node_count + leaf_count,
            'reassortments_by_clade': clade_reassortment_count  # ← NEW
        }
    }

    # Write JSON
    with open(counter_output, 'w') as f:
        json.dump(out_dict, f, indent=2)

    # Print summary
    print(f"Total reassortment events: {reassortment_count}")
    print(f"Total internal nodes: {node_count}")
    print(f"Total leaves: {leaf_count}")
    print(f"Total objects in tree: {node_count + leaf_count}")
    print("\nReassortment events by H5 megaclade:")
    for clade, count in clade_reassortment_count.items():
        print(f"  {clade}: {count}")


In [126]:
reassortment_counter_megaclade("forbaltic.nwk", "rea.json")

Total reassortment events: 1591
Total internal nodes: 4439
Total leaves: 4440
Total objects in tree: 8879

Reassortment events by H5 megaclade:
  Unknown: 1591


In [108]:
for k in myTree.Objects:
    if k.traits:
        print(k.traits)
        break  # just print one to inspect


{'label': 'NODE_0000001', 'is_reassorted': 0.0, 'region': 'China', 'host': 'Avian', 'h5_megaclade': 'other'}
