In [None]:
import os
import baltic as bt
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
from datetime import datetime, timedelta, date
from collections import OrderedDict
import numpy as np
import matplotlib.lines as mlines
import plotly.express as px
from matplotlib.ticker import ScalarFormatter
from collections import defaultdict, Counter

In [None]:
# Run date in ISO form (YYYY-MM-DD); used as a tag in plot titles and export names
todays_date = date.today().isoformat()
print(todays_date)

In [None]:
# Key settings -- tree_path and traitName need to be configured for each analysis
tree_path = 'output_title.json'
traitName = "TC_status"

# Label for titles and file names: last path component with '.json' removed
treename = tree_path.rsplit('/', 1)[-1].replace('.json', '')

# Allows baltic to find the correct attributes in the JSON;
# height and name are required at a minimum
json_translation = {
    'absoluteTime': lambda k: k.traits['node_attrs']['num_date']['value'],
    'name': 'name',
}
branchWidth = 2

In [None]:
# Load the tree JSON at tree_path with baltic; json_translation tells baltic
# where to find each attribute. loadJSON returns two values, kept as tree and meta.
tree,meta = bt.loadJSON(tree_path,json_translation=json_translation)

# Uncomment the loop below to inspect which attributes the internal nodes
# carry and how they are formatted
# for k in tree.Objects:
#         if k.branchType == "node":
#             print(k.traits)
# Sanity check: date of the root and the index of the root's parent
print(tree.root.absoluteTime,tree.root.parent.index)

In [None]:
# Collect the per-state confidence dict of every internal node that has one for traitName
confidence_data = []
for node in tree.Objects:
    # Only internal nodes (branchType "node") are considered
    if getattr(node, 'branchType', None) != "node":
        continue
    attrs = node.traits.get('node_attrs', {})
    if traitName not in attrs:
        continue
    trait_entry = attrs[traitName]
    if 'confidence' in trait_entry:
        confidence_data.append(trait_entry['confidence'])

# States expected as columns; a state absent from every node shows up as NaN
expected_cols = ['mcleod', 'greater-MN', 'non-MN', 'TC_county']

# One row per node: per-state confidences plus the highest of them, rounded to 2 decimals
confidence_df = (
    pd.DataFrame(confidence_data, columns=expected_cols)
    .assign(trait_conf=lambda frame: frame[expected_cols].max(axis=1))
    .round(2)
)

# Preview
print(confidence_df.tail(10))

In [None]:
# Rich display of the last 20 rows of the per-node confidence table
confidence_df.tail(20)

In [None]:
## Distribution of the confidence values of the inferred trait at internal nodes --
## helpful when considering subsampling

# Histogram of the 'trait_conf' column
plt.figure(figsize=(10, 6))
sns.histplot(confidence_df['trait_conf'], bins=20, kde=False, color='skyblue')

# Labels and title
plt.xlabel("Trait Confidence")
plt.ylabel("Frequency")
plt.title(f'Dist. of trait confidence for {treename}')

# Optional tweaks, left disabled:
# plt.yscale('log')
# # Set more x-axis ticks with evenly spaced values between the min and max of trait_conf
# min_conf, max_conf = confidence_df['trait_conf'].min(), confidence_df['trait_conf'].max()
# plt.xticks(np.linspace(min_conf, max_conf, 15))  # 15 evenly spaced ticks

# Show the plot
plt.show()

In [None]:
def extract_subtrees(tree, traitName):
    """
    Split a tree into subtrees at every branch whose trait state differs from its parent's.

    Parameters:
    - tree: tree object exposing `Objects` (its branches, each with a `traits` dict and a
      `parent`) and `subtree(branch, traverse_condition=...)`; in this notebook a baltic tree
    - traitName: key looked up in each branch's `traits`

    Returns:
    - subtype_trees: defaultdict(list) mapping a trait state to a list of
      (parent_state, subtree) tuples, one per subtree that starts in that state
    - transitions: dict mapping (parent_state, state) tuples to how often that change occurs

    A branch or parent without the trait (including a branch with no parent at all)
    is assigned the placeholder state 'ancestor'.
    """
    subtype_trees = defaultdict(list)
    transitions = Counter()

    for branch in tree.Objects:
        state = branch.traits.get(traitName, 'ancestor')

        # A missing parent is treated like a parent without the trait instead of crashing
        parent_traits = getattr(branch.parent, 'traits', None) or {}
        parent_state = parent_traits.get(traitName, 'ancestor')

        # Nothing to do while the state is inherited unchanged
        if state == parent_state:
            continue

        # Every state change is counted, even if the subtree extraction below fails
        transitions[(parent_state, state)] += 1

        # Keep traversing only while descendants stay in the new state
        subtree = tree.subtree(
            branch,
            traverse_condition=lambda w, state=state: w.traits.get(traitName, 'ancestor') == state,
        )
        if not subtree:
            print(f"Subtree extraction failed for {parent_state} -> {state}")
            continue

        subtree.traverse_tree()
        subtree.sortBranches()
        subtype_trees[state].append((parent_state, subtree))

    return subtype_trees, dict(transitions)

In [None]:
# Split the tree into trait-specific subtrees and tally the state changes
subtype_trees, transitions = extract_subtrees(tree, traitName)

# Print the transition counts to verify the result
print("transitions: " + str(transitions))


In [None]:
# print(f"Tree Strings: {dict(tree_strings)}")

In [None]:
# # Dictionary to hold string representations
# subtree_strings = defaultdict(list)

# # Loop through each trait and its associated subtrees in subtype_trees
# for trait, subtree_list in subtype_trees.items():
#     for origin, subtree in subtree_list:
#         # Convert subtree to string using .toString() and add to the new dictionary
#         subtree_strings[trait].append((origin, subtree.toString()))

# # Print or process subtree_strings as needed
# print(subtree_strings)


In [None]:
def get_transitions_df_plot(transitions):
    """
    Tabulate trait-state transition counts and draw them as a bar plot.

    Parameters:
    - transitions: mapping of (from_state, to_state) tuples to counts

    Returns:
    - DataFrame with one row per transition and the columns 'Transition'
      (labelled "<from_state>_<to_state>") and 'Count'. When there are no
      transitions a warning is printed, nothing is plotted, and the empty
      DataFrame is returned.

    Uses the notebook-level pd / plt / sns imports.
    """
    transitions_df = pd.DataFrame({
        'Transition': [f"{from_state}_{to_state}" for from_state, to_state in transitions],
        'Count': list(transitions.values()),
    })

    # All transitions are plotted; there is no filtering step
    if transitions_df.empty:
        print("⚠️ No transitions to plot.")
        return transitions_df

    plt.figure(figsize=(10, 6))
    sns.barplot(data=transitions_df, x='Transition', y='Count')
    plt.xticks(rotation=45, ha='right')
    plt.title('Transitions Bar Plot')
    plt.ylabel('Count')
    plt.xlabel('Transition')
    plt.tight_layout()
    plt.show()

    return transitions_df

In [None]:
## Plot the transition counts and keep the tabulated counts as a DataFrame

transitions_df = get_transitions_df_plot(transitions)
print(transitions_df)

In [None]:
# Export the transition counts as a tab-separated file without the index column
transitions_df.to_csv('transitions_parsedon_TCstatus.tsv', sep='\t', index=False)

In [None]:
def subtree_to_dataframe_with_origin(subtype_trees, traitName):
    """
    Flatten every node of every extracted subtree into one DataFrame row.

    Parameters:
    - subtype_trees: mapping of trait state -> list of (origin, subtree) tuples;
      each subtree exposes `root` and `Objects`, and each node `branchType`,
      `absoluteTime`, `length` and a `traits` dict
    - traitName: not used by this function; kept so existing calls keep working

    Returns:
    - DataFrame with the columns 'branchType', 'absoluteTime', 'length', 'origin',
      'origin_date', 'subtree_name', plus one column per key found in the nodes'
      `traits`. The nested 'node_attrs' column and the raw 'traits' dict are
      never included.
    """
    all_subtree_data = []

    for subtype, subtree_list in subtype_trees.items():
        for i, (origin, subtree) in enumerate(subtree_list):
            # Start of the branch leading into the subtree: root date minus root branch length
            origin_date = subtree.root.absoluteTime - subtree.root.length
            subtree_name = f"{subtype}_{i}"  # unique within a subtype

            for node in subtree.Objects:
                node_info = {
                    'branchType': node.branchType,  # the notebook tests for 'leaf' and 'node'
                    'absoluteTime': node.absoluteTime,
                    'length': node.length,
                    'origin': origin,  # trait state the subtree was seeded from
                    'origin_date': origin_date,
                    'subtree_name': subtree_name,
                }
                # Each trait becomes its own column; a trait that shares its name
                # with one of the fields above replaces it
                node_info.update(node.traits)
                all_subtree_data.append(node_info)

    df = pd.DataFrame(all_subtree_data)

    # Drop the nested columns independently, so the raw 'traits' dict cannot
    # survive just because there is no 'node_attrs' column
    return df.drop(columns=['node_attrs', 'traits'], errors='ignore')

In [None]:
# for subtype, subtree_list in subtype_trees.items():
#     print(f"Subtype: {subtype}")
#     for origin, subtree in subtree_list:
#         print(f"  Origin trait: {origin}")
        
#         for node in subtree.Objects:
#             print("    Node data:")
#             print(vars(node))
#             print("    --------------------")


In [None]:
# Flatten all subtrees into a single per-node table
combined_df = subtree_to_dataframe_with_origin(subtype_trees, traitName)

# NOTE(review): the file is tab-separated despite its '.csv' extension; the name
# is left unchanged in case other tooling reads this exact file
combined_df.to_csv('totalinfo_parsedon_TCstatus.csv', sep='\t', index=False)

# Kept as the last expression of the cell so the first 15 rows are actually displayed
combined_df.head(15)

In [None]:
# Summary of the columns, dtypes and non-null counts of the per-node table
combined_df.info()

In [None]:
def make_subtree_df(combined_df, trait_column='TC_status'):
    """
    Summarize the per-node table into one row per subtree.

    Parameters:
    - combined_df: DataFrame with the columns 'branchType', 'subtree_name',
      'origin_date', 'origin', 'absoluteTime' and the trait column
    - trait_column: name of the trait column carried into the summary
      (defaults to 'TC_status', the column this notebook uses)

    Returns:
    - DataFrame with 'subtree_name', 'leaf_count', 'origin_date', 'origin',
      the trait column, 'last_tip_date' and 'subtree_duration_days'.
      Both dates are converted from decimal years to datetimes, and the duration
      is the whole number of days between them. Only subtrees with at least one
      leaf appear. A subtree yields more than one row if its nodes disagree on
      origin_date / origin / trait.
    """
    def decimal_year_to_date(decimal_year):
        # Approximation: every year counts as 365.25 days and the fractional
        # part is truncated to whole days
        year = int(decimal_year)
        day_of_year = int((decimal_year - year) * 365.25)
        return datetime(year, 1, 1) + timedelta(days=day_of_year)

    leaf_nodes = combined_df[combined_df['branchType'] == 'leaf']

    # Number of tips per subtree
    tip_counts = leaf_nodes.groupby('subtree_name').size().reset_index(name='leaf_count')

    # Origin date, origin state and trait of each subtree (taken from all of its nodes)
    origin_info = combined_df[['subtree_name', 'origin_date', 'origin', trait_column]].drop_duplicates()
    tip_counts = pd.merge(tip_counts, origin_info, on='subtree_name')

    # Most recent tip of each subtree
    last_tip_dates = (
        leaf_nodes.groupby('subtree_name')['absoluteTime']
        .max()
        .reset_index()
        .rename(columns={'absoluteTime': 'last_tip_date'})
    )
    tip_counts = tip_counts.merge(last_tip_dates, on='subtree_name', how='left')

    # Decimal years -> datetimes
    for date_column in ('origin_date', 'last_tip_date'):
        tip_counts[date_column] = tip_counts[date_column].apply(decimal_year_to_date)

    # Persistence time: whole days from subtree origin to its last sampled tip
    tip_counts['subtree_duration_days'] = (
        tip_counts['last_tip_date'] - tip_counts['origin_date']
    ).dt.days

    return tip_counts

In [None]:
# Build the per-subtree summary table used by the subtree analyses
tip_counts = make_subtree_df(combined_df)

# Export it as a tab-separated file without the index column
tip_counts.to_csv('tipcounts_parsedon_TCstatus.tsv', sep='\t', index=False)

# Kept as the last expression of the cell so the preview is actually displayed
tip_counts.head(5)

In [None]:
def plot_clusters_over_time(tip_counts, traitName, treename, todays_date):
    """
    Scatter plot of cluster size against cluster origin date, colored by trait.

    Parameters:
    - tip_counts: DataFrame with 'origin_date', 'leaf_count' and the traitName column
    - traitName: name of the column used to color the points (e.g., 'TC_status')
    - treename: tree name shown in the plot title (and in the output file name)
    - todays_date: date string for the output file name

    Raises:
    - ValueError: if tip_counts has no column named traitName

    Rows whose trait is 'non-MN' are left out. The figure is only shown;
    the savefig call is disabled.
    """
    # Fail early when the trait column is missing
    if traitName not in tip_counts.columns:
        raise ValueError(f"Column '{traitName}' not found in tip_counts DataFrame.")

    # Leave out the non-MN clusters
    mn_clusters = tip_counts.loc[tip_counts[traitName] != 'non-MN']

    # Color per trait state
    palette = {
        'mcleod': 'darkseagreen',
        'greater-MN': 'hotpink',
        'TC_county': 'red',
        'non-MN': 'cornflowerblue'
    }

    plt.figure(figsize=(10, 6))
    scatter_style = {'s': 75, 'edgecolor': 'w'}
    sns.scatterplot(
        data=mn_clusters,
        x='origin_date',
        y='leaf_count',
        hue=traitName,
        palette=palette,
        **scatter_style
    )

    plt.title(f'Cluster size over time for {treename}')
    plt.xlabel('Cluster Origin Date')
    plt.ylabel('Cluster Size')
    plt.legend(title='Trait')
    plt.grid(True)

    # Target file for the (currently disabled) export
    output_filename = f"output/clusters_over_time_{treename}_{todays_date}.png"
    # plt.savefig(output_filename, format='png')
    plt.show()

In [None]:
# Use the configured tree name and run date rather than placeholder literals.
# The trait column stays 'TC_status' because that is the column make_subtree_df
# puts into tip_counts.
plot_clusters_over_time(tip_counts, traitName='TC_status', treename=treename, todays_date=todays_date)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_persistence_boxplots(tip_counts, trait_column, treename, todays_date, fig_size=(10, 6)):
    """
    Box plots of cluster persistence time per trait, overlaid with the individual clusters.

    Parameters:
    - tip_counts: DataFrame containing 'subtree_duration_days' and the trait column.
    - trait_column: name of the trait column in tip_counts (e.g., 'TC_status').
    - treename: tree name shown in the title (and in the output file name).
    - todays_date: date string for the output file name.
    - fig_size: tuple with the figure size.

    Rows whose trait is 'non-MN' are left out. The figure is only shown;
    the savefig call is disabled.
    """
    # Leave out the 'non-MN' rows
    keep = tip_counts[trait_column] != 'non-MN'
    mn_clusters = tip_counts[keep]

    # Same category order for both layers
    trait_order = sorted(mn_clusters[trait_column].dropna().unique())

    plt.figure(figsize=fig_size)

    # Arguments shared by the box layer and the point layer
    shared = dict(
        data=mn_clusters,
        x=trait_column,
        y="subtree_duration_days",
        order=trait_order,
    )

    # Boxes without outlier markers, since every point is drawn by the strip plot
    sns.boxplot(palette="Set2", showfliers=False, **shared)

    # Individual clusters as jittered points
    sns.stripplot(color='black', jitter=True, alpha=0.5, **shared)

    plt.title(f'Persistence Time by Trait for {treename}')
    plt.xlabel('Trait')
    plt.ylabel('Cluster Persistence Time (Days)')
    plt.tight_layout()

    # Target file for the (currently disabled) export
    output_filename = f"output/persistence_boxplot_{treename}_{todays_date}.png"
    # plt.savefig(output_filename, format='png')

    plt.show()


# Call the function with the configured tree name and run date rather than
# placeholder literals. The trait column stays 'TC_status' because that is the
# column make_subtree_df puts into tip_counts.
plot_persistence_boxplots(
    tip_counts=tip_counts,
    trait_column='TC_status',
    treename=treename,
    todays_date=todays_date
)

# NOTE: this text documents plot_tree, which is defined in the next cell. It is
# kept as comments because an indented string here is an unexpected-indent
# syntax error.
#
# Plots a phylogenetic tree with subtrees categorized by traits.
#
# Parameters:
# - subtype_trees: dict with subtrees categorized by traits (e.g., {'Urban': [...], 'Rural': [...]})
# - traitName: str, name of the trait used to categorize subtrees
# - tree: baltic tree object, the main tree; its node dates set the x-axis limits
# - tip_size: int, size of the tip markers in the plot
# - fig_size: tuple, size of the overall figure
# - legend_loc: str, location of the legend
#
# Returns:
# - Nothing; a Matplotlib plot showing the phylogenetic subtrees is displayed.


In [None]:
def plot_tree(subtype_trees, traitName, tree,
              tip_size=15, fig_size=(10, 10), legend_loc='lower left'):
    """
    Plot every extracted subtree, stacked vertically and colored by trait.

    Parameters:
    - subtype_trees: mapping of trait state -> list of (origin, subtree) tuples
    - traitName: str, key of the trait in each node's `traits`
    - tree: the full tree; only the dates of its nodes are used, for the x-axis limits
    - tip_size: size of the tip markers
    - fig_size: tuple, size of the figure
    - legend_loc: str, location of the legend

    Returns:
    - None; the figure is shown with plt.show() (the savefig call is disabled).

    Only states listed in the color table are drawn, in that order. Within a
    state, subtrees are sorted by latest root date first, then by fewest objects.
    The title reads the notebook-level variables `treename` and `todays_date`.
    """
    # Color per trait state
    colors = {
        'mcleod': 'darkseagreen',
        'greater-MN': 'hotpink',
        'TC_county': 'red',
        'non-MN': 'cornflowerblue'
    }

    def c_func(k):
        # Nodes without the trait, or with an unlisted state, fall back to 'whitesmoke'
        return colors.get(k.traits.get(traitName), 'whitesmoke')

    fig = plt.figure(figsize=fig_size, facecolor='w')
    gs = gridspec.GridSpec(1, 1, wspace=0.0)
    ax = plt.subplot(gs[0], facecolor='w')
    cumulative_y = 0  # vertical offset of the subtree currently being drawn

    for subtype in colors:
        # .get() avoids inserting empty lists into a caller's defaultdict and
        # also works with a plain dict
        ordered = sorted(
            subtype_trees.get(subtype, []),
            key=lambda x: (-x[1].root.absoluteTime, len(x[1].Objects)),
        )
        for origin, loc_tree in ordered:
            # Consumed immediately below, before cumulative_y is advanced
            y_attr = lambda k: k.y + cumulative_y

            # Branches and tips of the subtree
            loc_tree.plotTree(ax, x_attr=lambda k: k.absoluteTime, y_attr=y_attr, colour=c_func)
            loc_tree.plotPoints(ax, x_attr=lambda k: k.absoluteTime, y_attr=y_attr, size=tip_size, colour=c_func, zorder=100)

            # Origin circle at the start of the root branch, colored by the state
            # the subtree came from (grey when that state is the 'ancestor' placeholder)
            oriC = 'dimgrey' if origin == 'ancestor' else c_func(loc_tree.root.parent)
            oriX = loc_tree.root.absoluteTime - loc_tree.root.length
            oriY = loc_tree.root.y + cumulative_y
            ax.scatter(oriX, oriY, 50, facecolor=oriC, edgecolor='w', lw=1, zorder=200)

            cumulative_y += loc_tree.ySpan + 5  # 5-unit gap before the next subtree

    # Axis styling
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()
    for loc in ('top', 'right', 'left'):
        ax.spines[loc].set_visible(False)
    ax.tick_params(axis='x', size=0, labelsize=15)
    ax.tick_params(axis='y', size=0)
    ax.grid(axis='x', ls='--')

    # The y axis carries no meaning beyond stacking, so its labels are hidden
    ax.set_yticklabels([])
    ax.set_ylim(-35, cumulative_y + 35)

    # x limits span all node dates of the full tree, with some padding
    node_dates = [node.absoluteTime for node in tree.Objects]
    padding = 0.2  # Adjust padding as desired
    ax.set_xlim(min(node_dates) - padding, max(node_dates) + padding)

    # One legend entry per state in the color table
    handles = [
        plt.Line2D([0], [0], marker='o', color='w', label=label,
                   markersize=10, markerfacecolor=color)
        for label, color in colors.items()
    ]
    ax.legend(handles=handles, title='Traits', loc=legend_loc)

    fig.subplots_adjust(right=0.88)
    ax.set_title(treename + "_" + todays_date)
    # Target file for the (currently disabled) export
    output_filename = "output/subtree_plots_" + str(treename) + "_" + todays_date + ".png"
    # plt.savefig(output_filename, format='png')
    plt.show()


In [None]:
# Draw all extracted subtrees, stacked and colored by trait
plot_tree(subtype_trees, traitName, tree)

In [None]:
##you can update this to remove trees with tip trait non-MN if you want
def plot_tree_without_non_mcleod(subtype_trees, traitName, tree, tip_size=15, fig_size=(10, 12), legend_loc='lower left'):
    """
    Plot the extracted subtrees stacked vertically, skipping those whose tips are all 'non-MN'.

    Parameters:
    - subtype_trees: mapping of trait state -> list of (origin, subtree) tuples
    - traitName: str, key of the trait in each node's `traits`
    - tree: the full tree; only the dates of its nodes are used, for the x-axis limits
    - tip_size: size of the tip markers
    - fig_size: tuple, size of the figure
    - legend_loc: str, location of the legend

    Returns:
    - None; the figure is shown with plt.show() (the savefig call is disabled).

    Despite the function name, the filter tests for the state 'non-MN'.
    Only states listed in the color table are drawn, in that order. Within a
    state, subtrees are sorted by latest root date first, then by fewest objects.
    The title reads the notebook-level variables `treename` and `todays_date`.
    """
    # Color per trait state
    colors = {
        'mcleod': 'darkseagreen',
        'greater-MN': 'hotpink',
        'TC_county': 'red',
        'non-MN': 'cornflowerblue'
    }

    def c_func(k):
        # Nodes without the trait, or with an unlisted state, fall back to 'whitesmoke'
        return colors.get(k.traits.get(traitName), 'whitesmoke')

    fig = plt.figure(figsize=fig_size, facecolor='w')
    gs = gridspec.GridSpec(1, 1, wspace=0.0)
    ax = plt.subplot(gs[0], facecolor='w')
    cumulative_y = 0  # vertical offset of the subtree currently being drawn

    for subtype in colors:
        # .get() avoids inserting empty lists into a caller's defaultdict and
        # also works with a plain dict
        ordered = sorted(
            subtype_trees.get(subtype, []),
            key=lambda x: (-x[1].root.absoluteTime, len(x[1].Objects)),
        )
        for origin, loc_tree in ordered:
            # Skip subtrees whose tips are all 'non-MN' (also true when there are no tips)
            tip_traits = [node.traits.get(traitName) for node in loc_tree.Objects if node.branchType == 'leaf']
            if all(trait == 'non-MN' for trait in tip_traits):
                continue

            # Consumed immediately below, before cumulative_y is advanced
            y_attr = lambda k: k.y + cumulative_y

            # Branches and tips of the subtree
            loc_tree.plotTree(ax, x_attr=lambda k: k.absoluteTime, y_attr=y_attr, colour=c_func)
            loc_tree.plotPoints(ax, x_attr=lambda k: k.absoluteTime, y_attr=y_attr, size=tip_size, colour=c_func, zorder=100)

            # Origin circle at the start of the root branch, colored by the state
            # the subtree came from (grey when that state is the 'ancestor' placeholder)
            oriC = 'dimgrey' if origin == 'ancestor' else c_func(loc_tree.root.parent)
            oriX = loc_tree.root.absoluteTime - loc_tree.root.length
            oriY = loc_tree.root.y + cumulative_y
            ax.scatter(oriX, oriY, 50, facecolor=oriC, edgecolor='w', lw=1, zorder=200)

            cumulative_y += loc_tree.ySpan + 5  # 5-unit gap before the next subtree

    # Axis styling
    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()
    for loc in ('top', 'right', 'left'):
        ax.spines[loc].set_visible(False)
    ax.tick_params(axis='x', size=0, labelsize=15)
    ax.tick_params(axis='y', size=0)
    ax.grid(axis='x', ls='--')

    # The y axis carries no meaning beyond stacking, so its labels are hidden
    ax.set_yticklabels([])
    ax.set_ylim(-35, cumulative_y + 35)

    # x limits span all node dates of the full tree, with some padding
    node_dates = [node.absoluteTime for node in tree.Objects]
    padding = 0.2  # Adjust padding as desired
    ax.set_xlim(min(node_dates) - padding, max(node_dates) + padding)

    # One legend entry per state in the color table
    handles = [
        plt.Line2D([0], [0], marker='o', color='w', label=label,
                   markersize=10, markerfacecolor=color)
        for label, color in colors.items()
    ]
    ax.legend(handles=handles, title='Traits', loc=legend_loc)

    fig.subplots_adjust(right=0.88)
    ax.set_title(treename + "_" + todays_date)
    # Target file for the (currently disabled) export
    output_filename = "output/subtree_plots_2_" + str(treename) + "_" + todays_date + ".png"
    # plt.savefig(output_filename, format='png')
    plt.show()


In [None]:
# Same stacked subtree plot, leaving out subtrees whose tips are all 'non-MN'
plot_tree_without_non_mcleod(subtype_trees, traitName, tree)