In [None]:
"""Introduction to pytorch computation graph"""
from functools import wraps
from decorator import decorator
from collections.abc import Iterable, Callable

import nographs as nog
from graphviz import Digraph

import torch
# import torch.nn as nn
# import torch.nn.functional as F
from torchvision import models

In [None]:
# ConvNeXt-Tiny from torchvision, built with default arguments; only its module
# structure (``nn_model.children()``) is explored by the traversal below.
nn_model = models.convnext_tiny()
#nn_model

In [None]:
def inge(x, limit=1000):
    """Bounded counter used as a ``zip`` companion to number / cap iteration.

    Yields ``0, 1, ..., limit - 1`` and then raises ``RuntimeError``, so a
    ``zip(source, inge(...))`` loop fails loudly (instead of running forever)
    when ``source`` produces more than ``limit`` items. If ``source`` runs
    out first, ``zip`` stops normally and nothing is raised.

    Args:
        x: currently unused; kept so existing calls such as ``inge(20)`` keep
            working. NOTE(review): call sites pass 20 / 200 / 300, which look
            like intended per-call limits — confirm whether ``x`` should be
            used as the limit instead.
        limit: how many values to yield before giving up (default 1000, the
            value that used to be hard-coded).

    Raises:
        RuntimeError: when value number ``limit`` is requested.
    """
    yield from range(limit)
    # Raise explicitly (the bare name ``throw`` is not Python and only failed
    # by accident with a NameError) so exceeding the bound is clearly reported.
    raise RuntimeError(f"inge: more than {limit} items requested")

In [None]:
# Graphviz attributes applied to every node of the traversal graph.
# NOTE(review): Graphviz documents "ranksep" as a graph attribute and "align"
# does not appear in its node attribute list, so both are probably ignored
# here — confirm and move "ranksep" into gr_style if it is wanted.
node_style = {
    "style": 'filled',
    "shape": 'box',
    "align": 'left',
    "fontsize": '12',
    "ranksep": '0.01',
    "height": '0.02',
    "width": '0.04',
}

# Graph-level attributes: lay the graph out left-to-right.
gr_style = {
    "linelength": '16',
    "rankdir": 'LR',
    # "landscape": 'True',
}

def adapter_graphviz( wrapped ):
    """Wrap an edge generator so every visited item is also drawn with Graphviz.

    ``wrapped(item, traversal)`` yields tuples whose first element is the
    neighbour and whose second element is the edge label. The returned
    generator function re-yields those tuples unchanged. As a side effect it
    adds one node for ``item`` and one edge per yielded tuple to a
    ``Digraph``, which is exposed as the ``.dot`` attribute of the wrapper.
    Labels are truncated to 20 characters.
    """
    graph = Digraph( node_attr = node_style, graph_attr = gr_style, )

    def node_id( obj ):
        # Graphviz node names are strings; the object's hash identifies it.
        return str(hash(obj))

    @wraps( wrapped )
    def wrapper_bg( _item, _trav ):
        source_id = node_id(_item)
        type_text = str(type(_item))[:20]
        value_text = str(_item)[:20]
        graph.node( source_id, f"{type_text}\n{value_text}" )
        for edge in wrapped(_item, _trav):
            graph.edge( source_id, node_id(edge[0]), str(edge[1])[:20], )
            yield edge

    wrapper_bg.dot = graph
    return wrapper_bg

#### The translation stage: apply the list of transformations to each node repeatedly, until none of them changes it any more.

In [None]:
# Values of these types are treated as leaves by filter_boring.
stop_types = {bool, int, torch.Tensor, str}

# Attribute names never followed: everything a plain ``object`` already has,
# plus a few torch / typing attributes that only add noise to the graph.
stop_words = set(dir( object() )).union((
    '_is_full_backward_hook', '__module__',
    '__covariant__', '__contravariant__',
    'training', 'output_size', 'zero_grad',
))

def filter_boring( node ):
    """Stage transform: drop nodes that are not worth drawing.

    ``node`` is a ``(value, label)`` pair. Nothing is yielded if the label is
    in ``stop_words`` or the value is an instance of one of ``stop_types``.
    Otherwise the node is passed through unchanged as ``(node, False)``.

    ``isinstance`` (rather than an exact ``type(...) in stop_types`` lookup)
    also stops subclasses such as ``torch.nn.Parameter`` (a ``torch.Tensor``
    subclass). An exact-type check lets them through to ``expand_iterables``,
    which iterates them element by element, and that fails for 0-d tensors.
    """
    # print(f"1111{str(node[1])[:100]}--11--{str(type(node[0]))[:100]}<111>")
    if node[1] in stop_words:
        return
    if isinstance(node[0], tuple(stop_types)):
        return
    yield node, False

def expand_iterables( item ):
    """Stage transform: replace an iterable value by its elements.

    ``item`` is a ``(value, label)`` pair.

    * dict values: nothing is yielded, so dicts are dropped (expanding them
      into ``dict[key]`` children is currently disabled).
    * any other iterable except ``str``: one ``((element, "[i]"), True)`` per
      element, numbered via ``inge`` and flagged as changed so the stage
      examines the elements again.
    * anything else: passed through as ``(item, False)``.
    """
    # print(f"222222{str(item)[:100]}")
    value = item[0]
    if isinstance(value, dict):
        # Dict expansion is disabled; dicts produce no nodes at all.
        # for i,n in zip(item[0].items(),inge(200)):
        #     yield (i[1], f"dict[{i[0]}]"),True
        return
    if not isinstance(value, Iterable) or type(value) == str:
        yield item, False
        return
    for element, index in zip(value, inge(20)):
        yield (element, f"[{index}]"), True

def expand_call( item ):
    """Stage transform: replace a callable value by the result of calling it.

    ``item`` is a ``(value, label)`` pair.

    * Callable value: it is invoked with no arguments and
      ``((result, "label()"), True)`` is yielded.
    * The call raises an ``Exception``: a TROUBLE line is printed and the
      item is dropped.
    * Non-callable value: passed through as ``(item, False)``.

    NOTE(review): this calls *every* zero-argument callable it meets,
    including methods that may mutate their owner (e.g. torch's
    ``nn.Module.half()`` / ``.train()``). Confirm that is acceptable for the
    model being explored.
    """
    value, label = item[0], item[1]
    if not isinstance(value, Callable):
        yield item, False
        return
    try:
        result = value()
    except Exception:  # best-effort: most callables need arguments
        print(f"!!!!!!!!!!!!!!!!!!{str(value)[:100]}-{label} TROUBLE!")
        return
    # Yield outside the ``try`` so exceptions thrown into this generator
    # (e.g. GeneratorExit on close()) are not mistaken for a failed call.
    yield (result, f"{label}()"), True

def do_stage( nodes, stage_transforms ):
    """Apply ``stage_transforms`` to ``nodes`` until they stop changing.

    Each transform is a generator function that takes one node and yields
    ``(new_node, changed)`` pairs: zero, one or many per input node.

    One round pipes every node through all transforms in order. A node counts
    as changed if any transform flagged it during that round. Unchanged nodes
    are yielded to the caller; changed ones go through another round.

    Args:
        nodes: iterable of nodes; the transforms in this notebook expect
            ``(value, label)`` pairs.
        stage_transforms: sequence of transform generator functions.

    Yields:
        Nodes that no transform changed during their last round, in order.

    NOTE(review): there is no round limit. A transform that keeps flagging
    its output as changed makes this loop forever.
    """
    # Annotate every node with "changed during this round" = False.
    pending = [(node, False) for node in nodes]
    while pending:
        # One round: pipe all pending nodes through every transform in turn.
        for transform in stage_transforms:
            pending = [
                (new_node, new_changed or old_changed)
                for old_node, old_changed in pending
                for new_node, new_changed in transform(old_node)
            ]
        # Emit the nodes that settled; reset the flag on the rest and repeat.
        next_round = []
        for node, changed in pending:
            if changed:
                next_round.append((node, False))
            else:
                yield node
        pending = next_round
        

In [None]:
@decorator
def stage_1( wrapped, item, _trav ):
    """``decorator`` caller: post-process the decorated edge generator.

    The ``(neighbour, label)`` tuples produced by the decorated function are
    run through ``do_stage`` with filter_boring -> expand_iterables ->
    expand_call, and the settled results are re-yielded.
    """
    raw_edges = wrapped( item, _trav )
    pipeline = [filter_boring, expand_iterables, expand_call]
    yield from do_stage( raw_edges, pipeline )

In [None]:
# Histogram of attribute names seen during the traversal (name -> count).
stat = {}

@adapter_graphviz
@stage_1
def all_attr(var, _):
    """Yield ``(attribute_value, attribute_name)`` for non-None attributes of ``var``.

    Used as nographs' ``next_labeled_edges``; the second argument (the
    traversal object) is ignored. Every yielded name is counted in the
    module-level ``stat`` dict.

    Attributes whose lookup raises are skipped. The traversal may reach
    arbitrary objects (e.g. results of calling methods), whose properties can
    raise on access; such attributes no longer abort the whole traversal.
    """
    for attr_name, _ in zip(dir(var), inge(200)):
        try:
            attr = getattr(var, attr_name)
        except Exception:  # e.g. raising properties, type.__abstractmethods__
            continue
        if attr is None:
            continue
        # print( f"0000000, {str(attr)[:100]}  {attr_name=} {type(attr)}")
        key = str(attr_name)
        stat[key] = stat.get(key, 0) + 1
        yield ( attr, key)

# Breadth-first traversal over "attribute" edges, starting from the model's
# top-level children. Iterating it drives all_attr, which fills ``stat`` and
# ``all_attr.dot`` as side effects.
trav_forward = nog.TraversalBreadthFirst(next_labeled_edges=all_attr)
trav_forward.start_from( start_vertices=list(nn_model.children()), build_paths=True )

# Drain the traversal; the loop ends when either the traversal or inge(...)
# is exhausted. NOTE(review): inge ignores its argument, so the effective
# bound is inge's own limit (1000, then it raises), not 300.
for _, _ in zip(trav_forward, inge(300)):
    pass

In [None]:
# Print the DOT source, then let the cell's last expression display the
# Digraph through Jupyter's rich repr.
print(all_attr.dot.source)
all_attr.dot              # pylint: disable=pointless-statement

In [None]:
#dir(list(nn_model.children())[0])
# type(nn_model)
# type(nn_model).__bases__

In [None]:
# Children of the model's first top-level child module.
list(list(nn_model.children())[0].children())
#list(nn_model.children())

In [None]:
issubclass(type(nn_model), torch.nn.modules.module.Module)  # is the model class a torch Module subclass?

In [None]:
stat  # attribute-name counts collected by all_attr during the traversal

In [None]:
isinstance( {1:1}, dict)  # a dict literal hits expand_iterables' dict branch

In [None]:
# Print each key/value pair of a small dict, separated by a space.
for key_value in {1:1}.items():
    print(*key_value)

In [None]:
set(dir( object() ))  # attribute names every object has: the base of stop_words

In [None]:
list(range(0,1,-1))[:100]  # negative step with start < stop: empty range, shows []