In [1]:
import pandas as pd
import numpy as np
import re
import baltic as bt

from functools import reduce

In [2]:
# Input files for this notebook. Neither name has a directory component, so both
# are resolved against the kernel's current working directory when opened.

# Trees file: scanned line by line for "tree STATE_<n>" entries.
tree_paths = "segment4_comp.trees.txt"

# XML file: scanned for <state code="..."/> entries to build the location list.
xml_path = "segment4_final.xml"

In [None]:
# Extract locations from XML
#
# Scanning is switched on by the first line that mentions 'state.dataType' and
# stays on until the end of the file. From that line onward, every line holding a
# <state code="..."/> tag contributes its code (letters and underscores only).
state_tag = re.compile(r'<state code="([A-Za-z_]+)"/>')

locations = []
read_loc = False

with open(xml_path, 'r') as xml_file:
    for xml_line in xml_file:
        # Once set, the flag is never cleared again.
        read_loc = read_loc or 'state.dataType' in xml_line
        if not read_loc:
            continue
        hit = state_tag.search(xml_line)
        if hit is not None:
            locations.append(hit.group(1))

print(f"Found {len(locations)} locations.")


# Trees whose MCMC state number is below this value are skipped as burn-in.
burnin = 10000000


Found 14 locations.


In [4]:
# CALCULATE TRANSITION RATES
#
# For every tree in the trees file whose MCMC state number is >= `burnin`:
#   1. sum the branch lengths of all objects in the tree,
#   2. count parent -> child changes of the `trait` annotation for each ordered
#      (origin, destination) pair of known locations,
#   3. divide each count by the summed branch length (NaN if that sum is not > 0).
# The per-tree tables are then stacked into `transition_df`, one row per
# (mcmc_state, origin, destination).

trait = "state"
all_states = []  # one DataFrame per retained MCMC sample

tree_file = tree_paths  # tree_paths is now ONE file
print(f"Processing: {tree_file} ...")

# The ordered origin/destination pairs are the same for every tree, so build
# them once instead of re-deriving them inside the per-tree loop.
location_pairs = [
    (origin, dest) for origin in locations for dest in locations if origin != dest
]

# Use a context manager so the file handle is closed even if a tree fails to parse.
with open(tree_file, "r") as tree_handle:
    for line in tree_handle:
        # identify MCMC tree line
        tree_match = re.search(r"^tree STATE_(\d+).*", line)
        if not tree_match:
            continue

        # skip samples that fall inside the burn-in window
        state_number = int(tree_match.group(1))
        if state_number < burnin:
            continue

        # extract treestring: everything after the leading "[...]" annotation,
        # up to the last ";" on the line
        treestring_match = re.search(r"= \[.*?\] (.*);", line)
        if not treestring_match:
            continue
        treestring = treestring_match.group(1) + ";"

        # build baltic tree
        ll = bt.make_tree(treestring)

        # Compute total branch length (objects without a `length` attribute are ignored)
        total_branch_length = sum(
            k.length for k in ll.Objects if hasattr(k, "length")
        )

        # Count transitions
        tree_migration_count = {
            origin: {dest: 0 for dest in locations} for origin in locations
        }

        for k in ll.Objects:
            # An object with no parent, or a parent with empty traits, has no
            # parent state and therefore cannot contribute a transition.
            par_state = k.parent.traits.get(trait) if k.parent and k.parent.traits else None
            cur_state = k.traits.get(trait)
            if par_state is None or cur_state is None:
                continue
            if par_state != cur_state:
                # States that are not in `locations` are silently ignored.
                if par_state in locations and cur_state in locations:
                    tree_migration_count[par_state][cur_state] += 1

        # Compute rate per branch length
        has_length = total_branch_length > 0
        data = []
        for origin, dest in location_pairs:
            mig_count = tree_migration_count[origin][dest]
            mig_rate = mig_count / total_branch_length if has_length else np.nan
            data.append({
                "mcmc_state": state_number,
                "origin": origin,
                "destination": dest,
                "migration_count": mig_count,
                "total_branch_length": total_branch_length,
                "transition_rate": mig_rate
            })

        # store data for this MCMC sample
        all_states.append(pd.DataFrame(data))

# pd.concat raises an opaque "No objects to concatenate" on an empty list; fail
# with the same exception type but a message that points at the likely cause.
if not all_states:
    raise ValueError(
        f"No trees with state >= {burnin} could be parsed from {tree_file}"
    )

# Combine into single DataFrame
transition_df = pd.concat(all_states, ignore_index=True)

print("Total MCMC states extracted:", transition_df['mcmc_state'].nunique())
print(transition_df.head())


Processing: segment4_comp.trees.txt ...
Total MCMC states extracted: 90001
   mcmc_state     origin destination  migration_count  total_branch_length  \
0    10000000  Australia      Canada                0          1259.016861   
1    10000000  Australia       China                0          1259.016861   
2    10000000  Australia     Denmark                0          1259.016861   
3    10000000  Australia      France                0          1259.016861   
4    10000000  Australia       Italy                0          1259.016861   

   transition_rate  
0              0.0  
1              0.0  
2              0.0  
3              0.0  
4              0.0  


In [None]:
# save as csv
# The output path lives in one variable so the confirmation message can never
# drift from the file that is actually written.
output_csv = 'transition_rates_fixed.csv'

# to_csv is called with its defaults, so the row index is written as well
# (pandas' default is index=True).
transition_df.to_csv(output_csv)
print(f"Saved transition rates to {output_csv}")

Saved transition rates to transition_rates_fixed.csv
