In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np

# Location of the lineage-graph JSON. Set the LINEAGE_GRAPH_PATH environment
# variable to point the notebook at another dataset without editing this cell.
import os
dataset_path = os.environ.get(
    'LINEAGE_GRAPH_PATH',
    '/mnt/ceph/users/lbrown/MouseData/David4EPI/LineageGraph.json',
)

# The dataset is named after the directory that contains the JSON file,
# e.g. .../David4EPI/LineageGraph.json -> "David4EPI".
dataset_name = os.path.basename(os.path.dirname(dataset_path))
print(f"Dataset name: {dataset_name}")

# Read as bytes so json.load detects the UTF-8/16/32 encoding itself instead of
# relying on the platform's default text encoding.
with open(dataset_path, 'rb') as f:
    data = json.load(f)

KeyboardInterrupt: 

In [None]:
nodes_data = data['Nodes']
edges_data = data['Edges']

# The part of each node name before the first '_' is parsed as its frame number.
nodes = {node['Name']: {'frame': int(node['Name'].split('_')[0])} for node in nodes_data}

# Build both directions of the parent/child relationship in a single pass.
# EndNodes[0] is the parent and EndNodes[1] the child. If a child appears in
# several edges, the last edge's parent wins in parent_of.
parent_of = {}
children_of = {}
for edge in edges_data:
    parent_name = edge['EndNodes'][0]
    child_name = edge['EndNodes'][1]
    parent_of[child_name] = parent_name
    children_of.setdefault(parent_name, []).append(child_name)

# Roots are nodes with no incoming edge. Order them by frame number (then by
# name) so trees are laid out in temporal order; a plain string sort would put
# "10_..." before "2_...".
root_nodes = sorted(
    (name for name in nodes if name not in parent_of),
    key=lambda name: (nodes[name]['frame'], name),
)

# Layout state shared by all trees. `pos` maps node name -> (x, frame).
# `x_offset` counts the leaves placed so far; colliding children and new roots
# are pushed to the right of it.
pos = {}
x_offset = 0
# Secondary index of `pos` grouped by frame ({frame: {name: x}}), kept in sync by
# `_place`. A collision check then scans only nodes in the same frame instead of
# every node placed so far, which made the layout quadratic in graph size.
_x_by_frame = {}


def _place(node_name, x):
    """Set `node_name`'s position to (x, frame) in both `pos` and `_x_by_frame`."""
    frame = nodes[node_name]['frame']
    pos[node_name] = (x, frame)
    _x_by_frame.setdefault(frame, {})[node_name] = x


def _is_occupied(frame, x):
    """Return True if a node already placed in `frame` lies within 0.5 of `x`."""
    return any(abs(other_x - x) < 0.5 for other_x in _x_by_frame.get(frame, {}).values())


def _enter(node_name, x):
    """Place `node_name` at `x` and return its work item for the layout stack.

    A leaf (no children) advances the global `x_offset` and returns None,
    because there is nothing left to visit below it.
    """
    global x_offset
    _place(node_name, x)
    children = children_of.get(node_name, [])
    if not children:
        x_offset += 1
        return None
    return {
        'name': node_name,
        'children': sorted(children),
        'next': 0,  # index of the next child to visit
        # Children are spread one unit apart, centred on this node's x.
        'start_x': x - (len(children) - 1) / 2.0,
        'child_xs': [],  # final x of each child, used to re-centre this node
    }


def layout_tree(node_name, current_x):
    """Assign (x, frame) positions in `pos` to `node_name` and its descendants.

    Working top-down, each node's children are spread one unit apart around
    the node's x. A child whose slot lies within 0.5 of an already-placed node
    in the same frame is moved to `x_offset + 1`. Every leaf placed advances
    the global `x_offset`. Once all of a node's children are placed, the node
    is re-centred over the mean x of its children. Nodes already in `pos` are
    not moved.

    The traversal uses an explicit stack instead of recursion. This lets a
    track longer than Python's recursion limit (default 1000) be laid out
    without RecursionError. Visiting order and resulting positions are the
    same as in the recursive depth-first version.
    """
    if node_name in pos:
        return

    stack = []
    root_item = _enter(node_name, current_x)
    if root_item is not None:
        stack.append(root_item)

    while stack:
        item = stack[-1]
        i = item['next']
        if i < len(item['children']):
            item['next'] = i + 1
            child = item['children'][i]
            temp_x = item['start_x'] + i
            if _is_occupied(nodes[child]['frame'], temp_x):
                temp_x = x_offset + 1
            if child in pos:
                # Already placed (e.g. listed twice, or reached via another
                # parent): keep it where it is.
                item['child_xs'].append(pos[child][0])
                continue
            child_item = _enter(child, temp_x)
            if child_item is None:
                item['child_xs'].append(pos[child][0])
            else:
                stack.append(child_item)
        else:
            # All children placed: re-centre this node, then report its final
            # x to its parent (the item below it on the stack).
            stack.pop()
            if item['child_xs']:
                _place(item['name'], np.mean(item['child_xs']))
            if stack:
                stack[-1]['child_xs'].append(pos[item['name']][0])


# Lay out every tree, starting each new root at the current x_offset.
for root in root_nodes:
    if root not in pos:
        layout_tree(root, x_offset)


from matplotlib.lines import Line2D


def _node_color(name):
    """Return the marker colour for a node.

    Red if the node has no entry in `children_of` (no outgoing edge), green if
    it has more than one child, lightblue otherwise.
    """
    kids = children_of.get(name)
    if kids is None:
        return 'red'
    return 'green' if len(kids) > 1 else 'lightblue'


fig, ax = plt.subplots(figsize=(30, 22))

# Edges: a thin black segment from each parent to its child, skipping any pair
# where either endpoint was never laid out.
for child_node, parent_node in parent_of.items():
    if parent_node not in pos or child_node not in pos:
        continue
    (x_from, y_from), (x_to, y_to) = pos[parent_node], pos[child_node]
    ax.plot([x_from, x_to], [y_from, y_to], 'k-', lw=0.7, alpha=0.6)

# Nodes: a coloured marker with the node's name printed just to its right.
for node_label, (x, frame) in pos.items():
    ax.plot(x, frame, 'o', markersize=7, color=_node_color(node_label),
            markeredgecolor='black', markeredgewidth=0.5)
    ax.text(x + 0.1, frame, node_label, fontsize=5, ha='left', va='center')

ax.set_xlabel("Individual Cell Lineages", fontsize=16)
ax.set_ylabel("Frame Number", fontsize=16)
ax.set_title(f"Cell Lineage Tree with Node Labels - {dataset_name}", fontsize=20)
ax.invert_yaxis()  # earliest frame at the top
ax.set_xticks([])  # x positions are layout-only; they carry no data
ax.grid(axis='y', linestyle='--', linewidth=0.5)

# Legend proxies mirror the colours chosen by _node_color.
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label=label,
           markerfacecolor=face, markersize=12)
    for label, face in (
        ('Stasis (Continues)', 'lightblue'),
        ('Mitosis (Division)', 'green'),
        ('Death / End of Track', 'red'),
    )
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=14)
fig.tight_layout()

# Name the output after the dataset so runs on different data do not overwrite each other.
output_filename = f'lineage_tree_with_labels_{dataset_name}.svg'
fig.savefig(output_filename, dpi=800, bbox_inches='tight')
plt.close(fig)

print(f"Lineage tree with node labels saved as {output_filename}")

Lineage tree with node labels saved as lineage_tree_with_labels.png
