In [1]:
%cd ~/cdv/
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import jax.numpy as jnp
import jax
import jax.random as jr
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns
from wat import wat
import rho_plus as rp

# Plot theme switch: True selects the dark variants (also used below to pick
# the diverging colormap). `cs` appears to be an indexable colour sequence —
# it is indexed by irrep degree `l` in the array visualiser.
is_dark = False
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

  bkms = self.shell.db.get('bookmarks', {})
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/nmiklaucic/cdv


In [2]:
import treescope as ts
import treescope.figures as tsf

# Make treescope the IPython renderer. Automatic array visualisation is off
# because a custom autovisualizer is passed explicitly when displaying.
ts.basic_interactive_setup(autovisualize_arrays=False)

# Smoke test that inline treescope figures render in this frontend.
tsf.inline(tsf.bolded('Hi'))

In [5]:
from pathlib import Path
import pyrallis
from facet.config import MainConfig
import orbax.checkpoint as ocp

from facet.training_state import TrainingRun
from facet.checkpointing import best_ckpt

# Training run to inspect. Earlier run directories are kept commented out so
# they can be switched back quickly.
# run_dir = Path('logs') / '06-30-19_406'
# run_dir = Path('logs') / 'enb-6'
run_dir = Path('logs') / '09-14-02-21_953'

# Config saved alongside the run (presumably the one it was trained with).
conf_file = run_dir / 'config.toml'

# conf_file = 'configs/skeleton.toml'  # alternative: a standalone config not tied to a run

with open(conf_file) as f:
    config = pyrallis.cfgparsing.load(MainConfig, f)


# Instantiate the regressor described by the config; its weights come from the
# run's checkpoint in a later cell.
model = config.build_regressor()

In [6]:
from facet.dataset import load_file
# Load the data file configured in `config`. Judging by later use
# (`cg.padding_mask`, `cg.nodes.graph_i`), this is a padded batch of graphs —
# TODO confirm.
cg = load_file(config)
cg

In [8]:
from facet.layers import Context

# Restore the checkpoint that `best_ckpt` selects for this run, and take the
# trained parameters from its saved state.
ckpt = best_ckpt(run_dir)
params = ckpt['state']['params']

# Alternative: random initialisation instead of trained weights. Note that the
# second return value of `init_with_output` is the full variables dict.
# out, params = model.init_with_output(jr.key(29205), cg=cg, ctx=Context(training=True))

In [12]:
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Sequence, Any

from flax import linen as nn
from treescope import rendering_parts as tsr

from facet.layers import Context, Identity
from facet.utils import debug_stat, debug_structure, flax_summary, intercept_stat, callable_name, signature
def add_with_duplicated_name(d: dict, k, v):
    """Store ``v`` in ``d`` under ``'<n>_<k>'`` using the smallest free ``n >= 0``.

    Lets repeated names coexist side by side, e.g. ``'0_dense'``, ``'1_dense'``.
    Mutates ``d`` in place and returns None.
    """
    # At most len(d) of the candidate keys can already be taken, so a free
    # slot is guaranteed to exist somewhere in 0..len(d).
    free_slot = next(n for n in range(len(d) + 1) if f'{n}_{k}' not in d)
    d[f'{free_slot}_{k}'] = v

class Params(dict):
    """A dict of parameters whose treescope rendering leads with a total count.

    When the summed ``.size`` of all pytree leaves is positive, the rendered
    attributes start with a ``'#'`` entry holding that total; otherwise the
    mapping is rendered as-is.
    """

    def __treescope_repr__(self, path, subtree_renderer):
        leaf_sizes = jax.tree.map(lambda leaf: leaf.size, dict(self))
        total = jax.tree.reduce(lambda acc, n: acc + n, leaf_sizes, initializer=0)
        shown = self if total <= 0 else {'#': total, **self}
        return ts.repr_lib.render_object_constructor(
            type(self),
            shown,
            path=path,
            subtree_renderer=subtree_renderer,
        )
  


@dataclass
class ModuleCall:
    """One recorded invocation of a Flax module method, as a call-tree node.

    treescope renders it constructor-style, titled with the module's type and
    coloured by that type's name. Empty sections are omitted.
    """
    # The module instance that was called.
    module: nn.Module
    # Bound call arguments keyed by parameter name.
    input: dict[str, Any]
    # Parameter entries recorded for this call.
    params: dict[str, Any]
    # Calls recorded beneath this one, keyed by name.
    children: dict[str, 'ModuleCall']
    # Return value; None when nothing has been recorded.
    output: Any

    def __treescope_repr__(self, path, subtree_renderer):
        module_type = type(self.module)
        # Candidate sections in display order; falsy entries are dropped.
        candidates = [
            ('input', self.input),
            ('children', self.children),
            ('params', Params(**self.params) if self.params else None),
            ('output', {'out': self.output} if self.output is not None else None),
        ]
        sections = {name: body for name, body in candidates if body}
        return ts.repr_lib.render_object_constructor(
            object_type=module_type,
            attributes=sections,
            path=path,
            subtree_renderer=subtree_renderer,
            color=ts.formatting_util.color_from_string(str(module_type)),
        )
  

def insert(stack, call, path):
    """Attach ``call`` to the call tree rooted at ``stack`` and return the root.

    ``path`` names the chain of children from the root down to ``call``'s slot.

    * Empty path: ``call`` becomes the new root, and the old tree is stored in
      ``call.children`` under the first free numeric key ('0', '1', ...).
    * Otherwise, follow ``path[0]``. If that child exists, recurse into it with
      the remaining path. At the end of the path this wraps the existing node,
      so a repeated name nests the earlier call under the new one. If the child
      is missing, ``call`` is stored there and the rest of the path is ignored.

    Nodes are mutated in place; any object exposing a ``children`` dict works.
    """
    if not path:
        slot = 0
        while str(slot) in call.children:
            slot += 1
        call.children[str(slot)] = stack
        return call

    head, rest = path[0], path[1:]
    kids = stack.children
    kids[head] = insert(kids[head], call, rest) if head in kids else call
    return stack



class FlowRecorder:
    """Flax method interceptor that records a model's nested module calls as a tree.

    Usage::

        recorder = FlowRecorder()
        with nn.intercept_methods(recorder):
            bound_model(...)
        recorder.stack  # root ModuleCall

    Each recorded call becomes a ``ModuleCall`` holding:

    * its bound arguments (without ``ctx``);
    * the module's non-nested ``params`` entries;
    * the calls made while it ran;
    * its return value.

    ``setup`` calls and ``Identity`` modules pass through without being recorded.
    """

    def __init__(self):
        # Root of the recorded call tree; None until the first recorded call.
        self.stack = None
        # Names of the recorded calls currently executing, outermost first.
        self.call_chain = []

    def __call__(self, next_fun, args, kwargs, context):
        """Record one intercepted call, run it, and return its result.

        Args:
          next_fun: the next interceptor or the bound method. Calling it, rather
            than ``context.orig_method``, keeps other active interceptors in effect.
          args: positional arguments of the intercepted call.
          kwargs: keyword arguments of the intercepted call.
          context: flax ``InterceptorContext`` for the call.

        Returns:
          Whatever the intercepted method returns.
        """
        if context.method_name == 'setup' or isinstance(context.module, Identity):
            return next_fun(*args, **kwargs)

        module_path = context.module.path
        if context.method_name == '__call__':
            path = module_path
        elif module_path:
            # Non-__call__ methods are named '<module>.<method>'.
            *head, tail = module_path
            path = (*head, tail + '.' + context.method_name)
        else:
            # The root module has an empty path, so there is nothing to
            # unpack; name its other methods by the method name alone.
            path = (context.method_name,)

        if path:
            self.call_chain.append(path[-1])
        try:
            bound = signature(next_fun).bind(*args, **kwargs)
            own_params = context.module.variables.get('params', {})
            call = ModuleCall(
                module=context.module,
                input={k: v for k, v in bound.arguments.items() if k != 'ctx'},
                # Nested mappings hold submodule params; those are captured
                # when the submodules are themselves called. Mapping (not dict)
                # so FrozenDict collections are filtered as well.
                params={k: v for k, v in own_params.items() if not isinstance(v, Mapping)},
                children={},
                output=None,
            )

            if self.stack is None:
                self.stack = call
            else:
                self.stack = insert(self.stack, call, self.call_chain)

            out = next_fun(*args, **kwargs)
            call.output = out
            return out
        finally:
            # Pop the innermost entry instead of removing by value, so the chain
            # stays a true stack when a name repeats at several nesting levels.
            # Doing it in `finally` means a raising call leaves no stale entry.
            if path:
                self.call_chain.pop()


# Run one forward pass with training=False and the recorder installed.
# Afterwards `obj.stack` holds the root of the recorded call tree.
# NOTE(review): `bind` takes a variables dict keyed by collection
# (e.g. {'params': ...}). This assumes the checkpoint's `params` already has
# that form, as the commented `init_with_output` alternative suggests — confirm.
obj = FlowRecorder()
ctx = Context(training=False)
mod = model.bind(params)
with nn.intercept_methods(obj):
    out = mod(cg=cg, ctx=ctx)

In [13]:
import e3nn_jax as e3nn
from typing import Any
from flax import struct
from pymatgen.core import Element

# Atomic number -> element symbol (Z = 1..99), used to label integer arrays
# that hold atomic numbers.
elements = {
   z: Element.from_Z(z).symbol
   for z in range(1, 100)
}

# Jmol per-element colours, one (R, G, B) tuple per CSV row. Used as the
# colormap for atomic-number arrays. Fetched over the network on every run.
# NOTE(review): confirm that the R/G/B scale matches what treescope expects
# (the diverging map below is scaled by 255), and that the row order lines up
# with the atomic-number values being coloured.
colors = pd.read_csv('https://raw.githubusercontent.com/CorySimon/JMolColors/master/jmolcolors.csv')
jmol_palette = [(row['R'], row['G'], row['B']) for i, row in colors.iterrows()]

# x = e3nn.normal('128x0e + 64x1e + 32x2e', leading_shape=(32,))


# Theme-matched diverging colormap for signed values. It is sampled at 20
# points, scaled by 255, and installed as treescope's global default.
if is_dark:
    div = rp.mpl_div_icefire_shift
else:
    div = rp.mpl_div_coolwarm_shift

ts.default_diverging_colormap.set_globally((255 * div(jnp.linspace(0, 1, 20))).tolist())

def render_tensor(arr, **kwargs):
    """Render ``arr`` as a foldable treescope figure labelled with its shape.

    Axis annotations are added automatically:

    * Axes whose length equals ``config.data.num_species`` are labelled
      ``'species'``, with item labels from ``config.data.metadata['elements']``.
    * An axis whose length equals ``config.data.batch_n_nodes`` gets a
      ``valid_mask`` built from ``cg.padding_mask[cg.nodes.graph_i]``,
      broadcast along that axis.
    * Non-empty int16 arrays with max <= 95 are treated as atomic numbers:
      values are labelled with element symbols and coloured with
      ``jmol_palette``.

    Args:
      arr: array to render.
      **kwargs: forwarded to ``ts.render_array``. Any ``axis_labels`` /
        ``axis_item_labels`` dicts are copied and extended (never mutated).
        ``pixels_per_cell`` defaults to 5 and ``truncate`` to True.

    Returns:
      A treescope figure.
    """
    # Copy so dicts supplied by the caller are not modified in place.
    axis_item_labels = dict(kwargs.get('axis_item_labels') or {})
    axis_labels = dict(kwargs.get('axis_labels') or {})
    for i, size in enumerate(arr.shape):
        if size == config.data.num_species:
            axis_labels[i] = 'species'
            axis_item_labels[i] = config.data.metadata['elements']
        elif size == config.data.batch_n_nodes:
            node_mask = cg.padding_mask[cg.nodes.graph_i]
            # Shape (1, ..., size, ..., 1) so the mask broadcasts along axis i.
            mask_shape = [1] * arr.ndim
            mask_shape[i] = size
            kwargs['valid_mask'] = node_mask.reshape(*mask_shape)

    kwargs['axis_item_labels'] = axis_item_labels
    kwargs['axis_labels'] = axis_labels

    # The size check avoids jnp.max raising on an empty array.
    if arr.dtype == np.int16 and arr.size > 0 and jnp.max(arr) <= 95:
        kwargs['value_item_labels'] = elements
        kwargs['colormap'] = jmol_palette

    kwargs.setdefault('pixels_per_cell', 5)
    kwargs.setdefault('truncate', True)

    array_part = ts.render_array(arr, **kwargs).treescope_part
    node = tsr.build_custom_foldable_tree_node(
        contents=array_part,
        label=tsr.text(str(arr.shape)),
    )
    return tsf.figure_from_treescope_rendering_part(tsr.build_full_line_with_annotations(node))

# def render_tensor(arr, **kwargs):
#   return tsf.inline(str(arr.shape))

def irrep_array_visualizer(
    value: Any,
    path: tuple[Any, ...] | None,
):
  """treescope autovisualizer for numpy/JAX arrays and e3nn IrrepsArrays.

  Plain arrays are squeezed and drawn with ``render_tensor``. An
  ``e3nn.IrrepsArray`` is drawn as one indented figure per non-None chunk:
  * all chunks share the symmetric colour range ``[-max|x|, max|x|]``;
  * each chunk is tinted by a palette colour chosen from its irrep degree ``l``.

  Any other value returns None, so treescope falls back to its default rendering.

  Args:
    value: the object treescope is about to render.
    path: treescope path to ``value`` (unused).
  """
  if isinstance(value, (np.ndarray, jax.Array)):
    return ts.IPythonVisualization(render_tensor(value.squeeze()), replace=True)
  if not isinstance(value, e3nn.IrrepsArray):
    return None

  # initial=0 keeps this defined for zero-size arrays; otherwise max|x| >= 0
  # already, so the result is unchanged.
  abs_max = jnp.max(jnp.abs(value.array), initial=0).item()
  vmin = -abs_max
  vmax = abs_max

  visualizations = []
  for ir_mul, chunk in zip(value.irreps, value.chunks):
    if chunk is None:
      continue
    # Wrap around the palette so high-l irreps don't raise IndexError
    # (assumes `cs` is a finite sequence of colours).
    color = cs[ir_mul.ir.l % len(cs)]
    ndim = chunk.ndim
    figure = render_tensor(
        chunk,
        rows=[ndim - 1],
        sliders=list(range(0, ndim - 2)),
        vmin=vmin,
        vmax=vmax,
        axis_labels={ndim - 1: str(ir_mul.ir)},
    )
    visualizations.append(tsf.indented(tsf.with_color(figure, color)))
  return ts.IPythonVisualization(
      tsf.inline(*visualizations),
      replace=True
  )
  
  

# Show the recorded call tree inline, drawing arrays with the custom visualiser.
ts.display(obj.stack, autovisualize=irrep_array_visualizer)

In [11]:
# Export the recorded call tree as a standalone HTML report. This needs `obj`
# and `irrep_array_visualizer` from the cells above, so run it after them.
report_path = Path('reports') / 'model_flow.html'
report_path.parent.mkdir(parents=True, exist_ok=True)
# Render first, so a rendering error cannot leave a truncated/empty report.
with ts.active_autovisualizer.set_scoped(irrep_array_visualizer):
    report_html = ts.render_to_html(obj.stack, compressed=False)
# Explicit UTF-8: the HTML may contain non-ASCII text, and the platform
# default encoding is not guaranteed to handle it.
report_path.write_text(report_html, encoding='utf-8')