In [1]:
# Work from the repository root so the relative 'configs/' and 'precomputed/'
# paths used later in this notebook resolve.
# NOTE(review): hardcoded, user-specific path; consider a configurable root.
%cd ~/cdv
import os

# Make only CUDA device 1 visible. This is set *before* `import jax` below so
# it is already in effect when JAX starts up; do not move it after the import.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import numpy as np
import jax.numpy as jnp
import jax
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns

import rho_plus as rp

# Configure matplotlib and plotly styling through rho_plus (light mode).
# `theme` and `cs` are whatever rp.mpl_setup returns — presumably a theme
# object and a color sequence; TODO confirm against rho_plus.
is_dark = False
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

  bkms = self.shell.db.get('bookmarks', {})
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/nmiklaucic/cdv


In [2]:
from pathlib import Path
from facet.data.databatch import CrystalGraphs
from facet.layers import Context
from facet.utils import load_pytree
import pyrallis
from facet.config import MainConfig
from copy import deepcopy
import orbax.checkpoint as ocp
from facet.data.dataset import load_file
from facet.training_state import TrainingRun
from facet.checkpointing import best_ckpt

# Number of radial basis functions used by the trimmed ("radial*") variants.
num_basis = 8

configs: dict[str, MainConfig] = {}

# Only the parse needs the open file handle.
with open('configs/sevennet.toml') as f:
    configs['original'] = pyrallis.cfgparsing.load(MainConfig, f)

# 'radial': radial-weight MLP with no hidden layers and `num_basis` basis functions.
configs['radial'] = deepcopy(configs['original'])
configs['radial'].checkpoint_params = f'precomputed/sevennet-trimmed-emb{num_basis}.ckpt'
configs['radial'].model.interaction.message.radial_weight.inner_dims = []
configs['radial'].model.edge_embed.radial_basis.num_basis = num_basis

# 'radial-head': 'radial' plus a head with no hidden layers.
configs['radial-head'] = deepcopy(configs['radial'])
configs['radial-head'].checkpoint_params = f'precomputed/sevennet-trimmed-emb{num_basis}-linhead.ckpt'
configs['radial-head'].model.head.inner_dims = []

# 'radial-head-norm': 'radial-head' with the neighbour-normalization exponent
# (`radial_power`, used as avg_num_neighbors(r_max) ** radial_power) set to 0.7;
# r_max is left unchanged. The "-07" checkpoint suffix refers to this exponent.
configs['radial-head-norm'] = deepcopy(configs['radial-head'])
configs['radial-head-norm'].model.interaction.message.radial_power = 0.7
# This path belongs to 'radial-head-norm'; assigning it to 'radial' would
# clobber that variant's checkpoint and leave this one pointing at linhead.ckpt.
configs['radial-head-norm'].checkpoint_params = f'precomputed/sevennet-trimmed-emb{num_basis}-linhead-07.ckpt'

# Every variant is evaluated on the MPTrj dataset.
for config in configs.values():
    config.data.dataset_name = 'mptrj'

# Evaluation subset: the first `num_eval_files` files of data group `eval_group_num`.
eval_group_num = 15
num_eval_files = 1
cgs = [
    load_file(configs['original'], group_num=eval_group_num, file_num=file_num)
    for file_num in range(num_eval_files)
]

# Combine the loaded batches with `+` (presumably concatenation for
# CrystalGraphs — TODO confirm); requires num_eval_files >= 1.
cg: CrystalGraphs = sum(cgs[1:], start=cgs[0])

# One regressor per config variant.
models = {name: conf.build_regressor() for name, conf in configs.items()}

# Parameters for each variant. 'radial-head-norm' deliberately starts from the
# same linear-head checkpoint as 'radial-head'; its radial weights are rescaled
# later in this cell to compensate for the changed normalization exponent.
# The checkpoint is loaded twice (not shared) so that in-place rescaling of
# 'radial-head-norm' cannot leak into 'radial-head'.
params = {
    'original': load_pytree('precomputed/sevennet.ckpt'),
    'radial': load_pytree(f'precomputed/sevennet-trimmed-emb{num_basis}.ckpt'),
    'radial-head': load_pytree(f'precomputed/sevennet-trimmed-emb{num_basis}-linhead.ckpt'),
    'radial-head-norm': load_pytree(f'precomputed/sevennet-trimmed-emb{num_basis}-linhead.ckpt'),
}

# Bind parameters so submodule attributes (e.g. edge_embedding.r_max) can be read.
bound = {name: models[name].bind(param) for name, param in params.items()}

def normalization_constant(name):
    """Return the neighbour-normalization constant of model variant `name`.

    Computed as ``avg_num_neighbors(r_max) ** radial_power`` using the variant's
    dataset metadata and message config, with ``r_max`` read from the bound
    model's edge embedding (so the module-level ``bound`` must be populated).
    """
    cfg = configs[name]
    r_max = bound[name].edge_embedding.r_max
    return cfg.data.metadata.avg_num_neighbors(r_max) ** cfg.model.interaction.message.radial_power


# Rescale the final layer of each radial-weight MLP by the ratio of normalization
# constants, so a variant with a different constant keeps (approximately) the
# original outputs; compared further below via 'radial-head' vs 'radial-head-norm'.
# This mutates `params` in place: re-run the whole cell (which reloads params)
# rather than this section alone, or the factor is applied twice.
reference_norm = normalization_constant('original')
for name in configs:
    conversion_factor = normalization_constant(name) / reference_norm
    print(name, conversion_factor)
    mace = params[name]['params']['mace']
    for layer in mace:
        mlp = mace[layer]['interaction']['SimpleInteraction_0']['SevenNetConv_0']['LazyInMLP_0']
        # NOTE(review): max() over key strings is lexicographic; this picks the
        # output layer only while layer names sort correctly (e.g. < 10 layers).
        last_layer = max(mlp.keys())
        mlp[last_layer]['kernel'] = mlp[last_layer]['kernel'] * conversion_factor

# Rebind so the bound modules use the rescaled parameters.
bound = {name: models[name].bind(param) for name, param in params.items()}

# Parameter counts per variant.
sizes = {k: sum(x.size for x in jax.tree.leaves(v)) for k, v in params.items()}
sizes

original 1.0
radial 1.0
radial-head 1.0
radial-head-norm 0.34415728


[1m{[0m[32m'original'[0m: [1;36m842748[0m, [32m'radial'[0m: [1;36m624380[0m, [32m'radial-head'[0m: [1;36m616252[0m, [32m'radial-head-norm'[0m: [1;36m616252[0m[1m}[0m

In [3]:
# This checkpoint is written by the export cell at the end of the notebook, so
# on a fresh kernel it only exists if an earlier run already exported it.
p1_path = f'precomputed/sevennet-trimmed-emb{num_basis}-linhead-07.ckpt'
if not Path(p1_path).exists():
    raise FileNotFoundError(
        f'{p1_path} not found; run the export cell (save_pytree) first to create it.'
    )
p1 = load_pytree(p1_path)

In [4]:
from facet.utils import debug_stat


# Run every bound variant on the same batch with a non-training context and
# flatten each prediction so the variants can be compared element-wise.
ctx = Context(training=False)
# Not used by the cells visible here.
rng = jax.random.key(29205)
results = {
    variant: bound_model(cg=cg, ctx=ctx).reshape(-1)
    for variant, bound_model in bound.items()
}

debug_stat(results=results);

In [49]:
# Evaluate the exported "-07" checkpoint using the 'radial-head-norm' architecture
# and add it to the comparison table.
p1_model = models['radial-head-norm']
p1_preds = p1_model.apply(p1, cg=cg, ctx=ctx)
results['p1'] = p1_preds.reshape(-1)
debug_stat(results=results);

In [47]:
# Re-displays the summary statistics for `results`; identical to the call at the
# end of the previous cell (candidate for removal).
debug_stat(results=results);

In [38]:
# Mean absolute difference between the 'radial-head' predictions and the
# renormalized + rescaled 'radial-head-norm' predictions; near zero means the
# weight rescaling compensated for the changed normalization constant.
abs_diff = jnp.abs(results['radial-head'] - results['radial-head-norm'])
abs_diff.mean()

[1;35mArray[0m[1m([0m[1;36m0.00023297[0m, [33mdtype[0m=[35mfloat32[0m[1m)[0m

In [5]:
from facet.utils import save_pytree


# Export the normalized linear-head variant (config + parameters). Only the
# 8-basis configuration is written to the shared 'configs/sevennet-trimmed.toml'.
if num_basis == 8:
    ckpt_path = f'precomputed/sevennet-trimmed-emb{num_basis}-linhead-07.ckpt'

    # Make the exported config point at the checkpoint saved here; the in-memory
    # config's checkpoint_params may reference a different file. A copy is used
    # so configs['radial-head-norm'] itself is not mutated.
    export_config = deepcopy(configs['radial-head-norm'])
    export_config.checkpoint_params = ckpt_path

    with open('configs/sevennet-trimmed.toml', 'w') as f:
        pyrallis.cfgparsing.dump(export_config, f)

    save_pytree(params['radial-head-norm'], ckpt_path)