In [None]:
# Work from the project root; relative paths used later (e.g. 'data/087.cif') resolve against it.
%cd ~/cdv
import os

# Restrict visible GPUs to device 0; set before torch is imported so it takes effect.
# NOTE(review): the SevenNet calculator below is built with device='cpu', so this may be unused.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns

import rho_plus as rp

# Plot theming via rho_plus for matplotlib and plotly (dark mode).
is_dark = True
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

In [None]:
from pymatgen.core import Structure

# Structure under study; the path is relative to the %cd target in the first cell.
s = Structure.from_file('data/087.cif')
s

In [44]:
from sevenn.sevennet_calculator import SevenNetCalculator
# Calculator for the "7net-0" model (the trailing comment lists alternative names), evaluated on CPU.
seven_calc = SevenNetCalculator("7net-0", device='cpu')  # 7net-0, SevenNet-0, 7net-0_22May2024, 7net-0_11July2024 ...

In [45]:
# ckpt_params = dict(seven_calc.model.named_parameters())
# ckpt_params = {k: v.detach().cpu().numpy() for k, v in ckpt_params.items()}
# np.save('precomputed/sevennet_ckpt.npy', ckpt_params)

In [None]:
from functools import partial
import torch
from copy import deepcopy


def serialize_atomgraph(data):
    """Snapshot values seen by a forward (pre-)hook so they stay valid afterwards.

    Args:
        data: iterable of hook values, e.g. the positional ``args`` tuple of a
            pre-hook or ``[output]`` from a forward hook.

    Returns:
        A list with one entry per element of ``data``, in the same order:

        * objects exposing a ``num_atoms`` attribute (presumably sevenn's graph
          data container) become a plain ``dict``; tensor fields are copied into
          NumPy arrays and other fields are kept as-is;
        * bare ``torch.Tensor`` values are detached and cloned. The record then no
          longer aliases the live tensor, so later in-place writes cannot change
          it, and it holds no autograd references. It stays a tensor, so
          ``.numpy(force=True)`` / ``.item()`` keep working on it;
        * anything else is passed through unchanged.
    """
    def _field_to_numpy(v):
        # numpy(force=True) detaches and moves to CPU; .copy() guarantees the
        # array never shares memory with a CPU tensor.
        return v.numpy(force=True).copy() if isinstance(v, torch.Tensor) else v

    snapshots = []
    for value in data:
        if hasattr(value, 'num_atoms'):
            snapshots.append({k: _field_to_numpy(v) for k, v in dict(value).items()})
        elif isinstance(value, torch.Tensor):
            snapshots.append(value.detach().clone())
        else:
            snapshots.append(value)

    return snapshots


class Recorder:
    """Collect per-module inputs and outputs captured by forward (pre-)hooks.

    The hooks are bound to a module name with ``functools.partial(..., name=...)``,
    and each recorded value is the snapshot produced by ``serialize_atomgraph``.
    """

    def __init__(self):
        self.inputs = {}   # module name -> snapshot of the first positional input (None if none)
        self.outputs = {}  # module name -> snapshot of the module's return value

    def pre_hook(self, mod, args, name='module'):
        """Forward pre-hook: record the first positional argument passed to ``mod``.

        Modules invoked with keyword arguments only receive ``args == ()``. For
        those, ``None`` is recorded instead of raising IndexError inside the model's
        forward pass. Only the first argument is serialized, because only it is kept.
        """
        snapshot = serialize_atomgraph(args[:1])
        self.inputs[name] = snapshot[0] if snapshot else None

    def post_hook(self, mod, args, output, name='module'):
        """Forward hook: record the value returned by ``mod``."""
        self.outputs[name] = serialize_atomgraph([output])[0]


# Record the input and output of every submodule during one evaluation of the model.
rec = Recorder()
atoms = s.to_ase_atoms()
handles = []
for name, mod in seven_calc.model.named_modules():
    # Hook registration can raise RuntimeError for some modules; skip those
    # instead of aborting the whole walk.
    try:
        handles.append(mod.register_forward_hook(partial(rec.post_hook, name=name)))
        handles.append(mod.register_forward_pre_hook(partial(rec.pre_hook, name=name)))
    except RuntimeError:
        continue

try:
    # Results are read from seven_calc.results, so the return value is not kept.
    seven_calc.calculate(atoms=atoms)
finally:
    # Always detach the hooks, even if the model call fails, so re-running this
    # cell never leaves stale or duplicate hooks on the model.
    for handle in handles:
        handle.remove()

print(seven_calc.results['free_energy'] / s.num_sites)
len(rec.inputs)

In [None]:
# Forces for the structure evaluated above (presumably served from the calculator's stored results).
seven_calc.get_forces()

In [None]:
# Last 30 recorded module names; dict order is the order in which the forward hooks fired.
list(rec.outputs.keys())[-30:]

In [None]:
# Output of the `0_`-prefixed self-connection linear layer (stored as a tensor, hence .numpy(force=True)).
y = rec.outputs['0_self_connection_intro.linear'].numpy(force=True)
y[0, 0]

In [None]:
# Input to the same linear layer.
x = rec.inputs['0_self_connection_intro.linear'].numpy(force=True)
x

In [None]:
# Flat weight vector reshaped to (128, 224) — presumably (fan_in, fan_out); TODO derive from the layer's irreps.
params = dict(seven_calc.model.get_submodule('0_self_connection_intro.linear').named_parameters())
w = params['weight'].numpy(force=True).reshape(128, 224)
w

In [None]:
# Manual recomputation of y[0, 0]: x @ w scaled by 1/sqrt(128) (assumed fan_in normalization) — should match above.
((x / np.sqrt(128)) @ w)[0, 0]

In [None]:
rec.outputs['reduce_hidden_to_energy.linear']

In [11]:
# named_modules() yields the root model under the name '', so this is the whole model's recorded output.
# NOTE(review): this cell ran as In [11], i.e. out of order — verify with Restart & Run All.
ag = rec.outputs['']

In [None]:
# Index of the single edge with edge_index[0] == 48 and edge_index[1] == 0 (atom indices chosen by hand);
# .item() raises unless exactly one edge matches.
i = np.where((ag['edge_index'][0] == 48) & (ag['edge_index'][1] == 0))[0].item()
i

In [None]:
# Edge embedding of that edge.
rec.outputs['edge_embedding']['edge_embedding'][i]

In [None]:
# Node features fed into the `0_`-prefixed self-connection module.
rec.inputs['0_self_connection_intro']['x']

In [None]:
rec.outputs['0_self_connection_intro.linear'][0, 1].item()

In [None]:
sns.heatmap(rec.outputs['onehot_to_feature_x']['x'], robust=True, center=0)

In [None]:
# NOTE: reassigns `params` and `w`, shadowing the self-connection weights from the earlier cell.
params = dict(seven_calc.model.onehot_to_feature_x.named_parameters())
w = params['linear.weight'].detach().cpu().numpy()
# 89 x 128 is presumably (one-hot species, feature channels); row 11 picks one species — TODO say which.
sns.heatmap(w.reshape(89, 128)[[11]], center=0, robust=True)

In [None]:
seven_calc.model

In [None]:
rec.outputs['edge_embedding.spherical.sph'][i][6].item()

In [None]:
# Recorder.post_hook stores the graph-data snapshot itself (a dict), so the field is indexed
# directly, as in the earlier centered heatmap cell; `[0]` raised KeyError.
sns.heatmap(rec.outputs['onehot_to_feature_x']['x'], robust=True)

In [None]:
# Element (0, 0) of the recorded linear-layer output as a Python scalar.
rec.outputs['onehot_to_feature_x.linear'][0, 0].item()

In [None]:
# Summarize a recorded output. Dict snapshots (graph data) map field -> shape; bare
# tensors/arrays report their own shape. post_hook stores the snapshot itself (not a
# list), so the old `[0].items()` indexed into a tensor row and raised AttributeError.
recorded = rec.outputs['onehot_to_feature_x.linear']
(
    {k: getattr(v, 'shape', v) for k, v in recorded.items()}
    if isinstance(recorded, dict)
    else getattr(recorded, 'shape', recorded)
)