In [None]:
import pandas   as pd
import polars   as pl
import numpy    as np
import networkx as nx
import time
import sys
sys.path.insert(1, '../framework')
from racetrack import *
rt = RACETrack()  # RACETrack helper instance; its methods (ontologyFrameworkInstance, createNetworkXGraph, linkNode, ...) are used throughout

# Load the IMDB 600k international-movies ontology instance; its triples table is read below as ofi.df_triples.
ofi  = rt.ontologyFrameworkInstance(base_filename='../../data/kaggle_imdb_600k/imdb_600k_international_movies')

# Kept only as an unevaluated string: a recipe that appears to have produced _filter_ below.
# It keeps the nodes of the _src_ -> _dst_ shortest path, removes that path's edges, and then adds
# the nodes of any remaining shortest paths between pairs of path nodes.
# NOTE(review): it references _src_/_dst_ (assigned in a later cell) and uses a bare except;
# nothing in this cell executes it.
_figures_out_filter_ = '''
g_nx = rt.createNetworkXGraph(df, [('sbj','obj','vrb')])
print(f'{len(g_nx)=} | {len(g_nx.edges)=}')
to_keep = set(nx.shortest_path(g_nx, _src_, _dst_))
_path_  = nx.shortest_path(g_nx, _src_, _dst_)
for i in range(len(_path_)-1): 
    g_nx.remove_edge(_path_[i],_path_[i+1])
print(f'{len(g_nx)=} | {len(g_nx.edges)=} | {len(to_keep)=}')
for n0 in _path_:
    for n1 in _path_:
        if n0 == n1: continue
        try:
            _p_ = nx.shortest_path(g_nx, n0, n1)
            to_keep = to_keep | set(_p_)
        except:
            pass
print(f'{len(to_keep)=}')
'''

# Node ids to keep. Presumably this is the `to_keep` output of the recipe above (confirm).
# It includes 1891900 and 1939878, the _src_/_dst_ pair used by the path view in the next cell.
_filter_ = {467458, 2983428, 352774, 5190151, 956937, 1873418, 1873420, 610830, 4527122, 2225684, 3892244, 5252128, 3370535, 3588136, 
            6330927, 233519, 1389103, 1461815, 324153, 6438969, 1891900, 2176575, 1937990, 2799176, 120910, 245839, 3163219, 7760468, 
            770644, 116822, 649307, 264295, 335976, 335977, 329840, 7055472, 1860218, 492671, 5746830, 1862294, 499866, 407200, 121508, 
            501419, 6417068, 5052588, 2226863, 4174018, 1621189, 467663, 6429909, 406744, 226522, 490203, 410845, 214749, 6302444, 
            1391343, 1369338, 611067, 3402499, 6714118, 212231, 503566, 215321, 6438681, 6331163, 6809887, 3147039, 603944, 5917483, 
            1583412, 5259071, 249665, 5902660, 6729546, 2452812, 439628, 256336, 4828499, 214869, 2642781, 2390373, 316774, 324089, 
            6729582, 927087, 5343089, 7745394, 1199987, 631157, 497526, 496503, 645505, 241539, 3066244, 514437, 6646662, 2271624, 
            601481, 742283, 2223509, 5339543, 406938, 3559324, 4130717, 500126, 514469, 1939878, 1939879, 1713576, 5889449, 1939880, 
            5339561, 5953968, 2444210, 310198, 6327734, 6305720, 241593, 1371586, 214980, 2223559, 4529097, 333771, 8009169, 767961, 
            332252, 4862943, 121318, 3146727, 1193960, 5346281, 7494636, 6156782, 295408, 982000, 5339637, 146425, 1889786, 7760382}
# Keep every triple whose subject OR object is one of the filtered node ids.
df   = ofi.df_triples.filter((pl.col('sbj').is_in(_filter_)) | (pl.col('obj').is_in(_filter_)))
len(ofi.df_triples), len(df)  # (all triples, filtered triples) -- sanity check on the filter's reduction

In [None]:
def linkNodeShortest(df,
                     relationships,
                     pairs,
                     view_path_index=1,
                     use_digraph=False,
                     y_path_gap=15,
                     x_ins=10,
                     y_ins=10,
                     txt_h=10,
                     w=512):
    """Convenience factory: build an RTLinkNodeShortest bound to the notebook's ``rt`` instance.

    Every argument is forwarded unchanged as the like-named keyword argument.
    The returned object defines ``_repr_svg_``, so it renders inline when it is
    the last expression of a Jupyter cell.  See RTLinkNodeShortest for the meaning
    of each argument.
    """
    options = {
        'df':              df,
        'relationships':   relationships,
        'pairs':           pairs,
        'view_path_index': view_path_index,
        'use_digraph':     use_digraph,
        'y_path_gap':      y_path_gap,
        'x_ins':           x_ins,
        'y_ins':           y_ins,
        'txt_h':           txt_h,
        'w':               w,
    }
    return RTLinkNodeShortest(rt_self=rt, **options)

class RTLinkNodeShortest(object):
    """Render the shortest path between each (src, dst) node pair as an SVG strip.

    Each pair's shortest path is drawn as a horizontal row of nodes running from
    x_ins to w - x_ins.  When ``view_path_index`` selects an interior node of the
    path (the "centre" node), the following happens:

    * The path's own edges are removed from a private copy of the graph.
    * For every other node on the path, the node count of the shortest
      *alternative* path between that node and the centre is written above the
      row ('0' when no alternative exists).
    * Labels for nodes nearer the centre sit nearer the row.
    * Left-side labels are anchored 'end' and right-side labels 'start'.

    Normally constructed through ``linkNodeShortest(...)``.  Keyword arguments:
        df              -- edge table handed to rt_self.createNetworkXGraph
        relationships   -- [(fm, to) | (fm, to, third), ...]; fm/to may be a list/tuple
                           of columns, which are concatenated into a new column
        pairs           -- [(src_node, dst_node), ...]
        view_path_index -- index (into each shortest path) of the centre node
        use_digraph     -- build a directed graph
        y_path_gap      -- vertical pixels per label row
        x_ins, y_ins    -- horizontal / vertical insets in pixels
        txt_h           -- text size argument passed to rt_self.svgText
        w               -- SVG width in pixels
    """

    def __init__(self, rt_self, **kwargs):
        self.rt_self       = rt_self
        self.df            = kwargs['df']
        self.relationships = kwargs['relationships']    # [('fm','to'), (('fm1','fm2'),('to1','to2')), ('fm','to',x)]
        self.pairs         = kwargs['pairs']            # [('node_0', 'node_1'), ('node_2', 'node_3'), ...]
        self.vpi           = kwargs['view_path_index']  # path index for the centered view
        self.use_digraph   = kwargs['use_digraph']      # use a directed graph
        self.y_path_gap    = kwargs['y_path_gap']
        self.x_ins         = kwargs['x_ins']
        self.y_ins         = kwargs['y_ins']
        self.txt_h         = kwargs['txt_h']
        self.w             = kwargs['w']
        self.time_lu       = {}                         # step name -> elapsed seconds
        self.node_size_px  = 4                          # circle radius for path nodes

        # If either side of a relationship is a list/tuple of columns, concatenate those
        # columns into a per-relationship column and refer to that column instead.
        # (Identical tuples used by several relationships are still concatenated once each.)
        _ts_ = time.time()
        new_relationships = []
        for i, rel in enumerate(self.relationships):
            _fm_ = self._concatIfMulti(rel[0], f'__fmcat{i}__')
            _to_ = self._concatIfMulti(rel[1], f'__tocat{i}__')
            new_relationships.append((_fm_, _to_) + tuple(rel[2:3]))  # carry the optional third element
        self.relationships = new_relationships
        self.time_lu['concat_columns'] = time.time() - _ts_

    def _concatIfMulti(self, cols, new_col):
        """Return ``cols`` unchanged unless it is a list/tuple of columns.

        In that case, concatenate the columns into ``new_col`` (replacing self.df
        with rt_self.createConcatColumn's result) and return ``new_col``.
        """
        if isinstance(cols, (list, tuple)):
            self.df = self.rt_self.createConcatColumn(self.df, cols, new_col)
            return new_col
        return cols

    def _repr_svg_(self):
        """Jupyter rich-display hook: render this object as inline SVG."""
        return self.renderSVG()

    def renderSVG(self):
        """Return SVG markup with one strip per pair, stacked top to bottom."""
        svg    = []
        y_base = self.y_ins
        if self.pairs:
            # Build the graph once. Each pair gets its own copy because
            # _renderPair removes path edges from the graph it is given.
            g_all = self.rt_self.createNetworkXGraph(self.df, self.relationships, use_digraph=self.use_digraph)
            for _pair_ in self.pairs:
                y_base = self._renderPair(g_all.copy(), _pair_, y_base, svg)
        y_base += self.y_ins
        svg.insert(0, f'<svg width="{self.w}" height="{y_base}">')
        svg.append('</svg>')
        return ''.join(svg)

    def _renderPair(self, g, pair, y_top, svg):
        """Append the SVG for one (src, dst) pair to ``svg``.

        ``g`` must be a private graph copy, since the path edges may be removed from it.
        ``y_top`` is the top of this pair's strip.  Returns the y at which the next strip starts.
        """
        src, dst = pair[0], pair[1]
        try:
            p = nx.shortest_path(g, src, dst)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # Report the failure inside the picture rather than aborting the whole SVG.
            svg.append(self.rt_self.svgText(f'no path: {src} -> {dst}', self.x_ins, y_top + self.txt_h, self.txt_h, anchor='start'))
            return y_top + self.txt_h + self.y_path_gap

        interior = 0 < self.vpi < len(p) - 1
        # Labels at distance k from the centre go k rows above the path. Reserve rows for
        # the longer side: vpi rows on the left, len(p)-1-vpi on the right. Reserve none
        # when the centre is not interior, because then no labels are drawn.
        y_floors   = max(self.vpi, len(p) - 1 - self.vpi) if interior else 0
        y_base     = y_top + self.y_path_gap*y_floors      # y of the path row itself
        y_bot      = y_base + self.y_path_gap*y_floors
        # Spread the nodes evenly across the width; a single-node path (src == dst) sits at x_ins.
        x_path_gap = (self.w - 2*self.x_ins)/(len(p) - 1) if len(p) > 1 else 0
        node_to_xy = {n: (self.x_ins + x_path_gap*i, y_base) for i, n in enumerate(p)}

        for n0, n1 in zip(p, p[1:]):
            (x0, y0), (x1, y1) = node_to_xy[n0], node_to_xy[n1]
            svg.append(f'<line x1="{x0}" y1="{y0}" x2="{x1}" y2="{y1}" stroke="black" stroke-width="2" />')
        for cx, cy in node_to_xy.values():
            svg.append(f'<circle cx="{cx}" cy="{cy}" r="{self.node_size_px}" stroke="black" stroke-width="2" />')

        if interior:
            center = p[self.vpi]
            x_c    = node_to_xy[center][0]
            svg.append(f'<line x1="{x_c}" y1="{y_top}" x2="{x_c}" y2="{y_bot}" stroke="gray" stroke-width="0.5" stroke-dasharray="1 5 1" />')
            # Drop the path's own edges so the counts below describe alternative routes only.
            g.remove_edges_from(list(zip(p, p[1:])))
            for offset in range(1, len(p)):
                y = y_base - self.y_path_gap*offset
                j = self.vpi - offset
                if j >= 0:            # left of the centre, down to and including src
                    svg.append(self._altPathLabel(g, p[j], center, x_c - x_path_gap/4, y, 'end'))
                j = self.vpi + offset
                if j < len(p):        # right of the centre, up to and including dst
                    svg.append(self._altPathLabel(g, center, p[j], x_c + x_path_gap/4, y, 'start'))
        return y_bot + self.y_path_gap

    def _altPathLabel(self, g, n_fm, n_to, x, y, anchor):
        """Return SVG text with the node count of the shortest n_fm -> n_to path in ``g`` ('0' if none)."""
        try:
            label = f'{len(nx.shortest_path(g, n_fm, n_to))}'
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            label = '0'
        return self.rt_self.svgText(label, x, y, self.txt_h, anchor=anchor)
# Shortest-path view between two node ids in the filtered frame. The centred view sits at
# index 5 of the path; alternative-route counts are drawn relative to that node.
_src_ = 1891900
_dst_ = 1939878
linkNodeShortest(df, relationships=[('sbj','obj','vrb')], pairs=[(_src_,_dst_)], view_path_index=5)

In [None]:
# Force-directed overview of the filtered subgraph. spring_layout starts from random
# positions, so a fixed seed keeps the drawing identical across Restart & Run All.
LAYOUT_SEED = 42
g_nx = rt.createNetworkXGraph(df, [('sbj','obj','vrb')])
pos  = nx.spring_layout(g_nx, seed=LAYOUT_SEED)
rt.linkNode(df, [('sbj','obj','vrb')], pos, draw_labels=True, label_links=True, color_by='vrb', link_color='vary', link_size=4, w=1024, h=768)