In [6]:
import os
import sys

import networkx as nx
import numpy as np
import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# Make the parent directory importable so the local tracks_interactions package resolves.
sys.path.append('..')
from tracks_interactions.db.db_model import TrackDB

In [22]:
# Path to the tracks database; can be overridden with the TRACKS_DB_PATH env var.
new_db_path = os.environ.get('TRACKS_DB_PATH', r'D:\test_data\Exp6_gardener_v6.db')

# SQLite silently creates an empty database for a missing path, which would only
# surface later as a confusing "no such table" error -- fail fast instead.
if not os.path.isfile(new_db_path):
    raise FileNotFoundError(f'Database not found: {new_db_path}')

engine = create_engine(f'sqlite:///{new_db_path}')
session = sessionmaker(bind=engine)()

In [16]:
def render_tree_view(self, G, pos):
    """
    Render the hierarchical tree using NetworkX and PyQtGraph.

    For every non-root node of ``G`` this draws a horizontal segment from the
    node's ``start`` to ``stop`` at the node's y position, a text label at the
    segment's right end, and vertical segments at ``stop`` to each child.
    Finally the plot's x range is set to ``[0, self.t_max]`` and its y range
    to ``[0, 1.1 * y_max]``.

    Parameters
    ----------
    self : object
        Owner providing ``labels`` (with ``get_color(name)``), ``active_label``,
        ``plot_view`` (with ``plot``, ``addItem``, ``setXRange``, ``setYRange``)
        and ``t_max``.
    G : directed graph
        Nodes must carry ``name``, ``start``, ``stop`` and ``accepted``
        attributes (as produced by ``build_Newick_tree``).
    pos : dict
        Mapping node -> coordinate pair; only index 1 (y) is used.

    Notes
    -----
    ``mkColor``, ``mkPen``, ``TextItem`` and ``Qt`` are not imported in this
    notebook; they must already be in scope (presumably from pyqtgraph / Qt).
    """
    # y_max starts at 1 so the y range never collapses below [0, 1.1].
    y_max = 1

    for node in G.nodes():
        if G.in_degree(node) == 0:  # Skip root for now
            continue

        node_data = G.nodes[node]
        node_name = node_data['name']

        # Horizontal extent in time
        x1 = node_data['start']
        x2 = node_data['stop']
        x_signal = [x1, x2]

        # The segment is flat, so the node's y coordinate is repeated.
        y_node = pos[node][1]
        y_signal = np.array([y_node]).repeat(2)
        y_max = np.max([y_node, y_max])

        # Work on a float copy: get_color may return an array owned by the
        # labels object, and the alpha tweak below must not leak into it (or
        # into later nodes that get the same array back).
        label_color = np.array(self.labels.get_color(node_name), dtype=float)

        # The active label is drawn thick and opaque; others thin and dimmed.
        if node_name == self.active_label:
            pen_color = mkColor((label_color * 255).astype(int))
            pen = mkPen(color=pen_color, width=4)
        else:
            label_color[-1] = 0.4  # last channel, presumably alpha -- confirm
            pen_color = mkColor((label_color * 255).astype(int))
            pen = mkPen(color=pen_color, width=2)

        if not node_data['accepted']:
            pen.setStyle(Qt.DotLine)

        # Horizontal line for the node
        self.plot_view.plot(x_signal, y_signal, pen=pen)

        # Text label at the right end of the segment; accepted nodes in green
        if node_data['accepted']:
            text_item = TextItem(str(node_name), anchor=(1, 1), color='green')
        else:
            text_item = TextItem(str(node_name), anchor=(1, 1))

        text_item.setPos(x2, y_node)
        self.plot_view.addItem(text_item)

        # Vertical connectors from this node's end to each child
        for child in G.successors(node):
            x_signal = [x2, x2]
            y_signal = [y_node, pos[child][1]]
            self.plot_view.plot(x_signal, y_signal, pen=pen)

    self.plot_view.setXRange(0, self.t_max)
    self.plot_view.setYRange(0, 1.1 * y_max)


def _add_children(G, parent, df, n=2):
    """
    Recursively adds children to the NetworkX graph from a dataframe.
    
    G - NetworkX graph
    parent - parent node ID
    df - dataframe with information about children
    n - counter for numbering nodes
    """
    children = df[df['parent_track_id'] == parent]

    for _, row in children.iterrows():
        child_id = row['track_id']
        G.add_node(child_id, name=row['track_id'], start=row['t_begin'], stop=row['t_end'], accepted=row['accepted_tag'], num=n)
        G.add_edge(parent, child_id)

        n += 1
        n = _add_children(G, child_id, df, n)

    return n

def build_Newick_tree(session, root_id):
    """
    Build a NetworkX graph representing one family (lineage) of tracks.

    Despite the name, no Newick string is produced. The result is a
    ``networkx.DiGraph`` with one node per track and parent -> child edges.

    session - SQLAlchemy session bound to the tracks database
    root_id - track_id of the family's root; every row with
              ``TrackDB.root == root_id`` is loaded

    Returns the DiGraph. Each node carries ``name``, ``start`` (t_begin),
    ``stop`` (t_end), ``accepted``, ``num`` (pre-order index, root = 1) and
    ``pos`` (a ``spring_layout`` coordinate computed with seed 42).

    Raises AssertionError if the family has no rows, or if the root track's
    own row is missing from them.
    """
    # Get info about the family from the database
    query = session.query(TrackDB).filter(TrackDB.root == root_id)
    df = pd.read_sql(query.statement, session.bind)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    # The exception type is kept so existing callers are unaffected.
    if df.empty:
        raise AssertionError('No data for this root_id')

    trunk_row = df[df['track_id'] == root_id]
    if trunk_row.empty:
        # Otherwise the indexing below would fail with an opaque IndexError.
        raise AssertionError(f'No track row for root_id {root_id}')

    G = nx.DiGraph()

    # Add the root (trunk) node
    G.add_node(
        root_id,
        name=root_id,
        start=trunk_row['t_begin'].iloc[0],
        stop=trunk_row['t_end'].iloc[0],
        accepted=bool(trunk_row['accepted_tag'].iloc[0]),
        num=1,
    )

    # Recursively add children
    _add_children(G, root_id, df)

    # Force-directed (not tree) layout; stored per node under 'pos'.
    pos = nx.spring_layout(G, seed=42)
    nx.set_node_attributes(G, pos, 'pos')

    return G

In [24]:
G = build_Newick_tree(session=session, root_id=21274)  # family rooted at track 21274

In [26]:
# Print each node's time span (x) and the flat y segment taken from its stored layout position.
for node, node_data in G.nodes(data=True):
    node_name = node_data['name']
    x_signal = [node_data['start'], node_data['stop']]  # time extent
    y_signal = np.full(2, node_data['pos'][1])          # same y at both ends
    print(node_name, x_signal, y_signal)

21274 [0, 39] [-0.1999618 -0.1999618]
21275 [40, 155] [-0.69401799 -0.69401799]
21276 [40, 97] [0.30172145 0.30172145]
21277 [98, 99] [0.10866962 0.10866962]
21278 [98, 217] [0.48358871 0.48358871]
