In [3]:
import os
import re
from datetime import *
import pandas as pd
import baltic as bt

In [4]:
# I want to pull traits information from the DTA and match to Continuous phylogeography
# Luckily, the cont phyl is a strictly bifurcating tree, which means each node1 is seen twice
# So I could parse the DTA by object and then have a new row for each node.children

In [5]:
# for each node, make a list of the children
# For each child, add a row
# node 1 is the node
# node 2 is the child
# start date
# end date
# start node flyway, etc
# end node flyway, etc

In [36]:
# Fetch tree from the provided URL
# treeFile = "../../test_continuous/dates_added/wi_824_allwidt_MCC.tree"
# treeFile = "../SERAPHIM/nov2021-2024_no-suliformes/process_dta/NAm_wild-aves_nov2021-2024_no-suliformes_ha_emptree-contphyl_flyway-order-latgrp_MCC.tree"
# treeFile = "../BEAST/continuous_phylogeography/results/2025-09-24_nov2021-2024_no-suliformes_ha_contphyl_order-DTA_thorney/NAm_wild-aves_nov2021-2024_contphyl_order-DTA_2025-09-24_MCC.tree"
# treeFile = "../SERAPHIM/nov2021-2024_no-suliformes_contphyl-orderDTA/NAm_wild-aves_nov2021-2024_contphyl_order-DTA_tinytest.tree"
# treeFile = "../SERAPHIM/nov2021-2024_no-suliformes_contphyl-orderDTA/NAm_wild-aves_nov2021-2024_contphyl_order-DTA_mediumtest.tree"
# treeFile = "../SERAPHIM/nov2021-2024_no-suliformes_contphyl-orderDTA/NAm_wild-aves_nov2021-2024_contphyl_order-DTA_2025-09-24_MCC.tree"
# treeFile="../SERAPHIM/nov2021-2024_no-suliformes/NAm_wild-aves_nov2021-2024_no-suliformes_ha_contphyl_2025-08-18_MCC.tree"
treeFile = "../SERAPHIM/nov2021-2024_no-suliformes_contphyl-duckgoose-DTA/NAm_wild-aves_nov2021-2024_no-suliformes_ha_contphyl_duckgoose-DTA_MCC.tree"


# Load the tree using Baltic
tree = bt.loadNexus(treeFile)
tree.treeStats()  # Report stats about the tree


Tree height: 3.262405
Tree length: 725.812239
strictly bifurcating tree
annotations present

Numbers of objects in tree: 10995 (5497 nodes and 5498 leaves)



In [37]:
label_count = 0

for k in tree.Objects:
    k.label = label_count
    label_count += 1

In [39]:
metadata = []
most_recent_sampling_date = 2024.9262295081967
# count = 0

for obj in tree.Objects:
    children = []

    # if count < 10:
    if obj.is_node():
        children.extend(obj.children) # Use .extend instead of .append, the former makes a flat list while the latter appends a list into the list
        node1 = obj.label
        startYear = obj.absoluteTime
        length = obj.traits["length"]
        # startFly = obj.traits["Flyway"]
        # startLatGrp = obj.traits["lat_group"]
        # startOrder = obj.traits["corrected_order"]
        startOrder = obj.traits["final_ddg_grouping"]
        startLat = obj.traits["location1"]
        startLong = obj.traits["location2"]
        
        for child in children:
            node2 = child.label
            endYear = child.absoluteTime
            # endFly = child.traits["Flyway"]
            # endLatGrp = child.traits["lat_group"]
            # endOrder = child.traits["corrected_order"]
            endOrder = child.traits["final_ddg_grouping"]
            endLat = child.traits["location1"]
            endLong = child.traits["location2"]
            tip = getattr(child, 'name', pd.NA)
            
            metadata.append({
                "node1": node1,
                "node2": node2,
                "length": length,
                "startYear": startYear,
                "endYear": endYear,
                "startLat": startLat,
                "startLon": startLong,
                "endLat": endLat,
                "endLon": endLong,
                # "startFlyway": startFly,
                # "endFlyway": endFly,
                # "startLatitudeGroup": startLatGrp,
                # "endLatitudeGroup": endLatGrp,
                "startOrder": startOrder,
                "endOrder": endOrder,
                "tipLabel": tip
                }) 
        # count += 1
                
    else:
        pass

In [40]:
df = pd.DataFrame(metadata)

In [41]:
print(len(df))
df.head(15)

10994


Unnamed: 0,node1,node2,length,startYear,endYear,startLat,startLon,endLat,endLon,startOrder,endOrder,tipLabel
0,0,10670,0.0,2021.663824,2021.813835,41.63728,-73.827024,46.144388,-62.4924,duck,geese_swans,
1,0,1,0.0,2021.663824,2021.717547,41.63728,-73.827024,38.215773,-81.171769,duck,duck,
2,1,10587,0.143709,2021.717547,2021.810269,38.215773,-81.171769,39.872703,-75.587129,duck,duck,
3,1,2,0.143709,2021.717547,2021.797453,38.215773,-81.171769,37.193614,-79.296969,duck,duck,
4,2,3,0.059374,2021.797453,2021.833064,37.193614,-79.296969,36.443422,-77.72091,duck,duck,
5,2,388,0.059374,2021.797453,2021.816497,37.193614,-79.296969,37.488521,-79.718483,duck,duck,
6,3,4,0.035612,2021.833064,2021.842945,36.443422,-77.72091,36.160853,-79.846864,duck,duck,
7,3,135,0.035612,2021.833064,2021.865895,36.443422,-77.72091,37.893318,-83.57344,duck,duck,
8,4,70,0.00988,2021.842945,2021.881822,36.160853,-79.846864,34.807159,-79.822489,duck,duck,
9,4,5,0.00988,2021.842945,2021.872405,36.160853,-79.846864,35.367761,-82.223008,duck,duck,


In [42]:
# df.to_csv("../SERAPHIM/nov2021-2024_no-suliformes/process_dta/NAm_wild-aves_nov2021-2024_no-suliformes_ha_flyway-latgrp-order_MCC.csv", sep = ',')
# df.to_csv("../BEAST/continuous_phylogeography/results/2025-09-24_nov2021-2024_no-suliformes_ha_contphyl_order-DTA_thorney/NAm_wild-aves_nov2021-2024_no-suliformes_contphyl-order_MCC.csv", sep = ',')
# df.to_csv("../SERAPHIM/nov2021-2024_no-suliformes_contphyl-orderDTA/NAm_wild-aves_nov2021-2024_no_suliformes_ha_contphyl_order-DTA_MCC.csv", sep = ',')
df.to_csv("../SERAPHIM/nov2021-2024_no-suliformes_contphyl-duckgoose-DTA/NAm_wild-aves_nov2021-2024_no-suliformes_ha_contphyl_duckgoose-DTA_MCC.csv", sep = ',')

In [29]:
# Feel free to ignore all the code below! 
# Its me troubleshooting
# And I don't wanna delete it cause What If

In [8]:
# This code chunck is good for checking which objects are nodes and which are tips in a small tree

# doublecheckdict = {}

# for k in tree.Objects:
#     if k.is_node():
#         doublecheckdict.update({k.label:"node"})
#     elif k.is_leaf():
#         doublecheckdict.update({k.label:k.name})
#     else:
#         print("what")

# print(doublecheckdict)

In [10]:
# print(list(tree.Objects[0].traits.keys()))
print(list(tree.Objects[2].traits.keys()))
# print(list(tree.Objects[5].traits.keys()))
# print(tree.Objects[4].height)

['corrected_order.rate_median', 'corrected_order', 'location.rate_median', 'height_median', 'location.rate', 'length_median', 'height', 'location2_median', 'corrected_order.prob', 'location1_median', 'length', 'posterior', 'corrected_order.rate', 'location1', 'location2', 'length_range', 'height_range', 'location1_range', 'corrected_order.set.prob', 'corrected_order.rate_range', 'location.rate_range', 'location1_80%HPD_3', 'location1_80%HPD_2', 'location1_80%HPD_1', 'length_95%_HPD', 'corrected_order.rate_95%_HPD', 'height_95%_HPD', 'location.rate_95%_HPD', 'location2_80%HPD_1', 'corrected_order.set', 'location2_80%HPD_2', 'location2_80%HPD_3', 'location2_range']


In [11]:
# This code chunck was where I was troubleshooting how to get the date of each object

most_recent_sampling_date = 2024.9262295081967

test_meta = []

for obj in tree.Objects:
    children = []

    tree_max_height = max(obj.height for obj in tree.Objects)

    if obj.is_node():
        children.extend(obj.children) # Use .extend instead of .append, the former makes a flat list while the latter appends a list into the list
        # node1 = obj
        node1 = obj.label
        height_trait = obj.traits["height"] # Height relative to absolute time
        height = obj.height # Height relative to the other nodes and tips?
        length_trait = obj.traits["length"]
        length = obj.length
        startYear1 = most_recent_sampling_date - obj.traits["height"]
        startYear2 = most_recent_sampling_date - (tree_max_height - obj.height)
        startYear3 = obj.absoluteTime
        
        for child in children:
            node2 = child.label
            node2_height_trait = child.traits["height"]
            node2_height = child.height
            node2_length_trait = child.traits["length"]
            node2_length = child.length
            total_length = length + node2_length
            endYear_1 = most_recent_sampling_date - child.traits["height"]
            endYear_3 = child.absoluteTime # still not the exact same as endYear_1, but VERY close

            # print(child, child.name, child.traits["height"], child.traits["length"], endYear_1, endYear_2)
            # print(child.parent, child.parent.traits["height"], child.parent.traits["length"])
            negative_time = endYear_1 - startYear1
           
            test_meta.append({
                "node1": node1,
                "node2": node2,
                # "node1_heighttrait": height_trait,
                # "node1_height": height,
                # "node1_lengthtrait": length_trait,
                # "node1_length": length,
                # "node2_heightrait": node2_height_trait,
                # "node2_height": node2_height,
                # "node2_lengthtrait": node2_length_trait,
                # "node2_length": node2_length,
                # "length": total_length,
                "startYear_good": startYear1,
                "startYear_testing": startYear3,
                "endYear_good": endYear_1,
                "endYear_testing": endYear_3,
                "negativeTime": negative_time
                }) 
                
    else:
        pass

In [12]:
test_df = pd.DataFrame(test_meta)

pd.set_option('display.max_rows', None)
test_df_tips = test_df[test_df["node2"].isin([19, 20, 22, 23, 24, 25, 26, 29, 30, 31, 33, 34, 35, 36, 37, 39, 40, 44, 45, 46, 47,
                                             54, 55, 56, 58, 59, 60, 63, 64, 65, 67, 68, 70, 71, 72, 73, 75, 76, 78, 79, 83, 84, 85, 86])]
print(test_df[0:20])

#test_df.to_csv("./testing_heightlength.tsv", sep = '\t')

    node1  node2  startYear_good  startYear_testing  endYear_good  \
0       0     80     2023.625109        2023.625109   2023.687715   
1       0      1     2023.625109        2023.625109   2023.663333   
2       1     77     2023.663333        2023.638340   2023.731289   
3       1      2     2023.663333        2023.638340   2023.731763   
4       2     74     2023.731763        2023.725194   2023.765866   
5       2      3     2023.731763        2023.725194   2023.836136   
6       3     73     2023.836136        2023.893871   2024.120219   
7       3      4     2023.836136        2023.893871   2023.889390   
8       4     72     2023.889390        2023.916366   2023.923288   
9       4      5     2023.889390        2023.916366   2024.122246   
10      5     69     2024.122246        2024.071469   2024.276561   
11      5      6     2024.122246        2024.071469   2024.169386   
12      6     48     2024.169386        2024.171942   2024.198201   
13      6      7     2024.169386  

In [13]:
#print(len(test_df[test_df["node2_length"] < 0]))

In [15]:
# test_df_negtime = test_df[test_df["endYear_testing"] < test_df["startYear_testing"]]

#test_df_negtime["change"] = test_df_negtime["endYear_good"] - test_df_negtime["startYear"]

#print(test_df_negtime)
# print(len(test_df_negtime))

In [16]:
# This code chunk prints the parent and child (and corresponding heights) for each pair where the child height is greater than parent height

# for obj in tree.Objects:
#     if obj.is_node():
#         for child in obj.children:
#             if child.traits["height"] > obj.traits["height"]:
#                 print(f"{obj.label} → {child.label}: PARENT ({obj.traits['height']}) < CHILD ({child.traits['height']}) ❌")

In [17]:
# import matplotlib.pyplot as plt

# def plot_nexus(mytree, save_path):

#     plt.rcParams["font.family"] = "Arial"

#     fig,ax = plt.subplots(figsize=(10,20),facecolor='w')
        
#     x_attr=lambda k: k.absoluteTime
    
#     # for k in mytree.getExternal(): # Extract the flyway for each tip name using the name_to_flyway dictionary
#     #     flyway = name_to_flyway.get(k.name, None)  # None if no match
#     #     k.flyway = flyway

#     def color_by(k): # New color-by function 
#         if k.is_node():
#             for child in k.children:
#                 if child.traits["height"] > k.traits["height"]:
#                     return "#FF0303"
#         return "#5E5E5E"
#         # if hasattr(k, 'flyway') and k.flyway in flyway_colors:
#         #     return flyway_colors[k.flyway]
#         # return '#898989'  # default color
    
#     mytree.plotTree(ax,x_attr=x_attr, colour=color_by) ## tree
#     mytree.plotPoints(ax,
#                       x_attr=x_attr,
#                       size=30,
#                       colour="#5E5E5E",
#                       zorder=100)
#     ax.plot() ## need to call plot when only drawing the tree to force drawing of line collections

#     # # # If you want tree annotations
#     # # # target_func=lambda k: k.is_leaf() ## which branches will be annotated
#     # # # text_func=lambda k: k.traits['rea'] if k.traits['is_reassorted'] else "" ## what text is plotted
#     # # # text_x_attr=lambda k: k.x+0.003 ## where x coordinate for text is
#     # # # mytree.addText(ax,x_attr=text_x_attr,target=target_func,text=text_func, size='10')

#     # If you want the legend in the tree
#     # legend_handles = [Patch(color=color, label=NAflyway) for NAflyway, color in flyway_colors.items()]
#     # legend = ax.legend(handles=legend_handles, title="$\\bf{Flyway}$", loc="center right", fontsize='15')
#     # plt.setp(legend.get_title(),fontsize=25)

#     # # # If you want a scale bar
#     # # # scale_length = 0.01
#     # # # x_min, x_max = ax.get_xlim()
#     # # # scale_x_start = x_min + (x_max - x_min) * 0.1
#     # # # scale_x_end = scale_x_start + scale_length * (x_max - x_min)
#     # # # scale_y = 0.05 * (ax.get_ylim()[1] - ax.get_ylim()[0]) + ax.get_ylim()[0]
    
#     # ax.plot([scale_x_start, scale_x_end], [scale_y, scale_y], color='black', lw=2)
#     # ax.text((scale_x_start + scale_x_end) / 2, scale_y, f'{scale_length}', 
#     #         ha='center', va='bottom', fontsize=15)
    
#     ax.set_yticks([])
#     ax.set_yticklabels([])
#     [ax.spines[loc].set_visible(False) for loc in ax.spines if loc not in ['bottom']]
#     ax.tick_params(axis='x',labelsize=10,size=15, width=2,color='grey')
#     fig.tight_layout()
#     # plt.savefig(save_path)
#     # plt.show()

In [18]:
# plot_nexus(tree, "../SERAPHIM/nov2021-2024_no-suliformes_contphyl-orderDTA/tree_negativetimebranches.pdf")