In [None]:
!brew install graphviz

In [None]:
!pip install graphviz

In [None]:
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from graphviz import Digraph

taxonomy_path = "../data/metahit/taxonomy.tsv"

# Load the tab-separated taxonomy table (it has no header row), then map each
# row's positional index to the value in its third column (column 2), which is
# used as the genus name.
taxon = pd.read_csv(taxonomy_path, sep='\t', header=None)
id2genus = dict(enumerate(taxon[2]))

# Row ids of the 10 most frequent genera. Their genus names are presumably
# spelled exactly like leaf names in the tree edges below — confirm if the
# taxonomy file changes.
ids_interest = [
    92, 139, 183, 273, 104,
    169, 260, 234, 15, 214,
]

# Tree structure, written as one path per leaf genus from the root "Bacteria"
# downwards. Ancestors shared between paths are repeated here and collapsed
# below, so each (parent, child) pair appears only once.
_lineages = [
    ("Bacteria", "Bacillati", "Actinomycetota", "Actinomycetes",
     "Bifidobacteriales", "Bifidobacteriaceae", "Bifidobacterium"),
    ("Bacteria", "Bacillati", "Bacillota", "Bacilli",
     "Lactobacillales", "Lactobacillaceae", "Lactobacillus"),
    ("Bacteria", "Bacillati", "Bacillota", "Bacilli",
     "Lactobacillales", "Streptococcaceae", "Streptococcus"),
    ("Bacteria", "Bacillati", "Bacillota", "Clostridia",
     "Eubacteriales", "Eubacteriaceae", "Eubacterium"),
    ("Bacteria", "Bacillati", "Bacillota", "Clostridia",
     "Eubacteriales", "Clostridiaceae", "Clostridium"),
    ("Bacteria", "Bacillati", "Bacillota", "Clostridia",
     "Eubacteriales", "Oscillospiraceae", "Ruminococcus"),
    ("Bacteria", "Pseudomonadati", "Bacteroidota", "Bacteroidia",
     "Bacteroidales", "Tannerellaceae", "Parabacteroides"),
    ("Bacteria", "Pseudomonadati", "Bacteroidota", "Bacteroidia",
     "Bacteroidales", "Rikenellaceae", "Alistipes"),
    ("Bacteria", "Pseudomonadati", "Bacteroidota", "Bacteroidia",
     "Bacteroidales", "Bacteroidaceae", "Bacteroides"),
    ("Bacteria", "Pseudomonadati", "Pseudomonadota", "Gammaproteobacteria",
     "Enterobacterales", "Enterobacteriaceae", "Escherichia"),
]

# Consecutive names on each path form (parent, child) edges; dict.fromkeys
# drops repeats while keeping first-seen order, so the edge list comes out in
# depth-first order.
edges = list(dict.fromkeys(
    pair for path in _lineages for pair in zip(path, path[1:])
))

# Give each genus id of interest its own colour sampled from the tab10
# palette, converted to a hex string so it can be passed to Graphviz as a
# node fillcolor.
colors = plt.cm.tab10(np.linspace(0, 1, len(ids_interest)))
hex_colors = list(map(mcolors.to_hex, colors))
color_dict = dict(zip(ids_interest, hex_colors))

# Build the Graphviz directed graph; format='png' selects the output format
# used when the graph is rendered.
dot = Digraph(comment='Phylogenetic Tree', format='png')

# Default look for every node: filled, rounded white boxes with a black outline.
dot.attr(
    'node',
    fontname='Helvetica',
    fontsize='12',
    shape='box',
    style='filled,rounded',
    color='black',
    fillcolor='white',
    margin='0.5,0.1',
    fixedsize='false',
)

# Default font for edges.
dot.attr('edge', fontname='Helvetica', fontsize='10')

# Every node name that appears in the tree. Graphviz matches nodes by name
# only, so a genus of interest whose label is not spelled exactly like a tree
# node would be drawn as a separate, disconnected coloured box while the real
# leaf stays white. Warn instead of producing that figure silently.
tree_nodes = {name for edge in edges for name in edge}

# Override the fill colour of each genus of interest.
for label in ids_interest:
    genus = id2genus[label]
    if genus not in tree_nodes:
        warnings.warn(
            f"Genus {genus!r} (id {label}) does not appear in the tree edges; "
            "it will be drawn as a disconnected node."
        )
    node_fillcolor = color_dict.get(label)
    dot.node(genus, fillcolor=node_fillcolor)

# Add one edge per (parent, child) pair. Names not declared above are created
# implicitly and use the default node style set earlier.
for parent, child in edges:
    dot.edge(parent, child)

# Render the graph at this path: graphviz saves the DOT source here and the
# image, in the Digraph's output format, next to it. view=True then opens the
# image in the system's default viewer.
dot.render(filename='../figures/phylogenetic_tree', view=True)
