In [1]:
# Work from the project root: later cells read relative paths such as 'data/087.cif'.
%cd ~/cdv
import os

# Expose only GPU 0 to CUDA-using libraries; this only takes effect if set before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns

import rho_plus as rp

# Configure matplotlib and plotly theming through rho_plus; `is_dark` selects the dark variant.
is_dark = True
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

/home/nmiklaucic/cdv


In [2]:
from pymatgen.core import Structure

# Load the test crystal (relative path, resolved against the working directory chosen with %cd).
s = Structure.from_file('data/087.cif')
s

Structure Summary
Lattice
    abc : 3.46566108 5.912730870000001 5.45788908
 angles : 89.99986914 89.99991061 90.01299071999999
 volume : 111.84044720244309
      A : 3.465661079995782 0.0 5.406950504201276e-06
      B : -0.0013405984239926575 5.912730718007075 1.35043098847761e-05
      C : 0.0 0.0 5.45788908
    pbc : True True True
PeriodicSite: Mn0 (Mn) (1.731, 4.91, 2.721) [0.4999, 0.8303, 0.4985]
PeriodicSite: Mn1 (Mn) (1.733, 1.003, 5.45) [0.5001, 0.1697, 0.9985]
PeriodicSite: Mn2 (Mn) (3.465, 1.953, 2.72) [0.9998, 0.3304, 0.4984]
PeriodicSite: Mn3 (Mn) (-0.0003616, 3.959, 5.449) [0.0001547, 0.6696, 0.9984]
PeriodicSite: O4 (O) (-0.0002849, 1.977, 0.6282) [4.711e-05, 0.3343, 0.1151]
PeriodicSite: O5 (O) (1.732, 4.933, 0.6285) [0.5001, 0.8342, 0.1151]
PeriodicSite: O6 (O) (1.732, 0.9801, 3.357) [0.4999, 0.1658, 0.6151]
PeriodicSite: O7 (O) (3.465, 3.936, 3.357) [1.0, 0.6657, 0.6151]

In [3]:
from sevenn.sevennet_calculator import SevenNetCalculator
# Pretrained SevenNet checkpoint, evaluated on CPU (device='cpu'), so the GPU selection above is not used here.
seven_calc = SevenNetCalculator("7net-0", device='cpu')  # 7net-0, SevenNet-0, 7net-0_22May2024, 7net-0_11July2024 ...

In [4]:
# ckpt_params = dict(seven_calc.model.named_parameters())
# ckpt_params = {k: v.detach().cpu().numpy() for k, v in ckpt_params.items()}
# np.save('precomputed/sevennet_ckpt.npy', ckpt_params)

In [5]:
from functools import partial
import torch
from copy import deepcopy


def serialize_atomgraph(data):
    """Snapshot the values seen by a forward / forward-pre hook.

    Each element of ``data`` is handled as follows:

    * Objects exposing a ``num_atoms`` attribute (graph containers that iterate
      as ``(key, value)`` pairs, so ``dict(value)`` works) become a plain
      ``dict``; every ``torch.Tensor`` field is copied into an independent
      NumPy array, all other fields are kept as-is.
    * Bare ``torch.Tensor`` values are detached from the autograd graph and
      cloned. The record then neither keeps the graph alive nor changes if the
      model later writes to that tensor in place. They remain tensors, so
      ``.numpy(force=True)`` / ``.item()`` keep working on them.
    * Anything else is passed through unchanged.

    Args:
        data: iterable of hook values (e.g. the hook's ``args`` tuple, or
            ``[output]``).

    Returns:
        list with one snapshot per element of ``data``, in the same order.
    """
    values = []
    for value in data:
        if hasattr(value, 'num_atoms'):
            values.append({
                # np.array copies, so the snapshot never aliases the tensor's storage.
                k: (np.array(v.numpy(force=True)) if isinstance(v, torch.Tensor) else v)
                for k, v in dict(value).items()
            })
        elif isinstance(value, torch.Tensor):
            # Drop grad_fn and decouple from later in-place updates.
            values.append(value.detach().clone())
        else:
            values.append(value)

    return values


class Recorder:
    """Collect per-module snapshots of inputs and outputs through forward hooks.

    Register ``partial(rec.pre_hook, name=...)`` with
    ``register_forward_pre_hook`` and ``partial(rec.post_hook, name=...)`` with
    ``register_forward_hook``. Snapshots are produced by ``serialize_atomgraph``.
    If a module runs more than once, the most recent call overwrites the record.
    """

    def __init__(self):
        self.inputs = {}   # module name -> snapshot of its first positional argument
        self.outputs = {}  # module name -> snapshot of its return value

    def pre_hook(self, mod, args, name='module'):
        """Forward-pre hook: store a snapshot of ``args[0]`` under ``name``.

        Modules invoked without positional arguments are skipped instead of
        raising IndexError inside the model's forward pass. Returns None, so
        the module's inputs are left unmodified.
        """
        if not args:
            return
        # Only the first argument is kept, so only it is serialized.
        self.inputs[name] = serialize_atomgraph(args[:1])[0]

    def post_hook(self, mod, args, output, name='module'):
        """Forward hook: store a snapshot of ``output`` under ``name``."""
        self.outputs[name] = serialize_atomgraph([output])[0]


# Attach a pre- and post-hook to every submodule, run one single-point calculation,
# and always detach the hooks afterwards (even if the calculation raises). Re-running
# this cell therefore never stacks duplicate hooks on the model.
rec = Recorder()
atoms = s.to_ase_atoms()
handles = []
try:
    for name, mod in seven_calc.model.named_modules():
        try:
            handles.append(mod.register_forward_hook(partial(rec.post_hook, name=name)))
            handles.append(mod.register_forward_pre_hook(partial(rec.pre_hook, name=name)))
        except RuntimeError:
            # Best-effort: skip modules that refuse hook registration
            # (NOTE(review): presumably TorchScript submodules — confirm).
            continue
    # Results are read from seven_calc.results below; the return value is kept only for
    # compatibility with any later use of `out`.
    out = seven_calc.calculate(atoms=atoms)
finally:
    for handle in handles:
        handle.remove()

print(seven_calc.results['free_energy'] / s.num_sites)

len(rec.inputs)

-7.8014678955078125


133

In [6]:
# Forces for `atoms` (presumably served from the cached results of the calculate() call above).
seven_calc.get_forces()

array([[ 0.00138224,  0.00842797,  0.02226343],
       [-0.00134953, -0.00842321,  0.02189642],
       [ 0.00248519,  0.00626346,  0.02657895],
       [-0.00245262, -0.00637611,  0.02643918],
       [ 0.0008379 , -0.0194728 , -0.02410344],
       [ 0.00157933, -0.01839091, -0.0245294 ],
       [-0.00161705,  0.0183937 , -0.02446074],
       [-0.00086542,  0.01957793, -0.02408434]], dtype=float32)

In [7]:
# Last 30 recorded module names, in the order their outputs were first recorded.
list(rec.outputs)[-30:]

['3_equivariant_gate.gate',
 '3_equivariant_gate',
 '4_self_connection_intro.linear',
 '4_self_connection_intro',
 '4_self_interaction_1.linear',
 '4_self_interaction_1',
 '4_convolution.weight_nn.layer0.act',
 '4_convolution.weight_nn.layer0',
 '4_convolution.weight_nn.layer1',
 '4_convolution.weight_nn.layer2',
 '4_convolution.weight_nn',
 '4_convolution.convolution',
 '4_convolution',
 '4_self_interaction_2.linear',
 '4_self_interaction_2',
 '4_self_connection_outro',
 '4_equivariant_gate.gate.sc.cut',
 '4_equivariant_gate.gate.sc',
 '4_equivariant_gate.gate.act_scalars.acts.0',
 '4_equivariant_gate.gate.act_scalars',
 '4_equivariant_gate.gate',
 '4_equivariant_gate',
 'reduce_input_to_hidden.linear',
 'reduce_input_to_hidden',
 'reduce_hidden_to_energy.linear',
 'reduce_hidden_to_energy',
 'rescale_atomic_energy',
 'reduce_total_enegy',
 'force_output',
 '']

In [8]:
# Output of the first self-connection linear layer (recorded as a tensor, hence .numpy(force=True)).
y = rec.outputs['0_self_connection_intro.linear'].numpy(force=True)
y[0, 0]

0.1413189

In [9]:
# Input fed to the same linear layer; used further down to reproduce y[0, 0] by hand.
x = rec.inputs['0_self_connection_intro.linear'].numpy(force=True)
x

array([[-0.02162305,  0.02764658,  0.03464852, ...,  0.02918951,
        -0.03034204,  0.09723409],
       [-0.02162305,  0.02764658,  0.03464852, ...,  0.02918951,
        -0.03034204,  0.09723409],
       [-0.02162305,  0.02764658,  0.03464852, ...,  0.02918951,
        -0.03034204,  0.09723409],
       ...,
       [ 0.05921005, -0.01971585,  0.00247053, ...,  0.1316716 ,
         0.07415292, -0.06879845],
       [ 0.05921005, -0.01971585,  0.00247053, ...,  0.1316716 ,
         0.07415292, -0.06879845],
       [ 0.05921005, -0.01971585,  0.00247053, ...,  0.1316716 ,
         0.07415292, -0.06879845]], dtype=float32)

In [25]:
# Node features after `reduce_input_to_hidden` (graph snapshot -> dict of numpy arrays), rounded for display.
rec.outputs['reduce_input_to_hidden']['x'].round(2)

array([[ 0.04, -0.01,  0.  , -0.07, -0.01,  0.02, -0.01, -0.05,  0.  ,
        -0.02,  0.01,  0.  , -0.04, -0.  , -0.  ,  0.02,  0.27,  0.02,
         0.75,  0.02,  0.03,  0.01, -0.  , -0.16,  0.09, -0.02,  0.01,
         0.01, -0.04, -0.  , -0.  ,  0.01,  0.03, -0.05,  0.  , -0.02,
        -0.  ,  0.02,  0.1 ,  0.02,  0.05, -0.08,  0.01, -0.  , -0.03,
         0.  ,  0.03,  0.02,  0.03, -0.02, -0.01,  0.01, -0.03,  0.01,
         0.  ,  0.02,  0.09,  0.01,  0.01, -0.08,  0.02, -0.01, -0.01,
        -0.01],
       [ 0.04, -0.01,  0.  , -0.07, -0.01,  0.02, -0.01, -0.05,  0.  ,
        -0.02,  0.01,  0.  , -0.04, -0.  , -0.  ,  0.02,  0.27,  0.02,
         0.75,  0.02,  0.03,  0.01, -0.  , -0.16,  0.09, -0.02,  0.01,
         0.01, -0.04, -0.  , -0.  ,  0.01,  0.03, -0.05,  0.  , -0.02,
        -0.  ,  0.02,  0.1 ,  0.02,  0.05, -0.08,  0.01, -0.  , -0.03,
         0.  ,  0.03,  0.02,  0.03, -0.02, -0.01,  0.01, -0.03,  0.01,
         0.  ,  0.02,  0.09,  0.01,  0.01, -0.08,  0.02, -0.0

In [22]:
# Raw tensor from the inner linear layer, for comparison with the rounded wrapper output.
rec.outputs['reduce_input_to_hidden.linear']

tensor([[ 4.3396e-02, -7.3099e-03,  1.2003e-03, -6.8865e-02, -9.5513e-03,
          1.6791e-02, -7.4942e-03, -4.8386e-02,  3.4664e-03, -2.2581e-02,
          1.1399e-02,  1.7949e-03, -3.9539e-02, -4.9298e-03, -5.1056e-04,
          1.8400e-02,  2.7434e-01,  2.0446e-02,  7.4928e-01,  1.5916e-02,
          2.6746e-02,  1.2863e-02, -1.0949e-03, -1.6486e-01,  8.9331e-02,
         -2.0863e-02,  1.4980e-02,  1.4410e-02, -4.0151e-02, -1.3759e-03,
         -1.1532e-03,  6.8857e-03,  3.3827e-02, -5.1540e-02,  3.4535e-03,
         -1.9746e-02, -4.7797e-03,  2.2620e-02,  1.0333e-01,  2.4986e-02,
          5.1410e-02, -7.8219e-02,  9.5045e-03, -2.5534e-03, -3.0774e-02,
          5.5592e-04,  2.5200e-02,  1.9146e-02,  2.5337e-02, -2.3996e-02,
         -7.5140e-03,  8.3058e-03, -2.7945e-02,  1.3871e-02,  2.0098e-04,
          2.4092e-02,  9.1418e-02,  1.4491e-02,  8.6250e-03, -7.9950e-02,
          2.0490e-02, -5.3512e-03, -8.1044e-03, -1.2528e-02],
        [ 4.3397e-02, -7.3111e-03,  1.2007e-03, -6

In [19]:
# Output of `reduce_hidden_to_energy.linear`: one row per atom of `s`.
rec.outputs['reduce_hidden_to_energy.linear']

tensor([[-0.0417],
        [-0.0417],
        [-0.0416],
        [-0.0416],
        [ 0.0339],
        [ 0.0340],
        [ 0.0340],
        [ 0.0339]], grad_fn=<ReshapeAliasBackward0>)

In [None]:
# Flat weight of the first self-connection linear layer, reshaped into a matrix for the manual check.
params = dict(seven_calc.model.get_submodule('0_self_connection_intro.linear').named_parameters())
# NOTE(review): the (128, 224) shape is hard-coded — confirm it matches the layer's input/output irreps.
w = params['weight'].numpy(force=True).reshape(128, 224)
w

In [None]:
# Hand-computed counterpart of y[0, 0]; needs `x` and `w` from the cells above (re-run them if either was reassigned).
# NOTE(review): the 1/sqrt(128) factor presumably mirrors e3nn's fan-in path normalisation — confirm.
((x / np.sqrt(128)) @ w)[0, 0]

In [None]:
# NOTE(review): duplicates the earlier `reduce_hidden_to_energy.linear` display — candidate for removal.
rec.outputs['reduce_hidden_to_energy.linear']

In [11]:
# Snapshot of the root module's output (module name ''), i.e. the final graph as a dict of arrays.
ag = rec.outputs['']

In [None]:
# Find the position of the edge with edge_index[0] == SRC_NODE and edge_index[1] == DST_NODE.
# With periodic boundaries one atom pair can be joined by several edges (one per lattice
# image), and stale node ids can match nothing. Fail with a readable message instead of
# `.item()`'s size error, and use the first match when there are several.
SRC_NODE, DST_NODE = 48, 0
edge_matches = np.flatnonzero((ag['edge_index'][0] == SRC_NODE) & (ag['edge_index'][1] == DST_NODE))
assert edge_matches.size > 0, (
    f"no edge with edge_index == ({SRC_NODE}, {DST_NODE}); node ids may be stale for this structure"
)
if edge_matches.size > 1:
    print(f"{edge_matches.size} edges match ({SRC_NODE}, {DST_NODE}); using the first")
i = int(edge_matches[0])
i

In [None]:
# Embedding of the edge selected by `i` in the edge-lookup cell.
rec.outputs['edge_embedding']['edge_embedding'][i]

In [None]:
# Node features entering the first self-connection block (graph snapshot, field 'x').
rec.inputs['0_self_connection_intro']['x']

In [None]:
# Single entry [0, 1] of the first self-connection linear output, as a Python float.
rec.outputs['0_self_connection_intro.linear'][0, 1].item()

In [None]:
# Heatmap of the initial node features produced by `onehot_to_feature_x` (diverging around 0).
sns.heatmap(rec.outputs['onehot_to_feature_x']['x'], robust=True, center=0)

In [None]:
# Weights of `onehot_to_feature_x.linear`. Distinct names are used so this cell does not
# overwrite the `params` / `w` globals that the manual self-connection check relies on.
onehot_params = dict(seven_calc.model.onehot_to_feature_x.named_parameters())
onehot_w = onehot_params['linear.weight'].numpy(force=True)
# NOTE(review): assumes the flat weight reshapes to 89 one-hot types x 128 channels and that
# row 11 is the species of interest — confirm against the model's type map.
sns.heatmap(onehot_w.reshape(89, 128)[[11]], center=0, robust=True)

In [None]:
# Module tree of the loaded model (source of the names used as Recorder keys).
seven_calc.model

In [None]:
# Component 6 of the (presumably spherical-harmonic) edge features for edge `i`.
rec.outputs['edge_embedding.spherical.sph'][i][6].item()

In [None]:
# Heatmap of the node features produced by `onehot_to_feature_x`. The Recorder stores this
# graph output as a dict keyed by field name, so index 'x' directly ([0] raised KeyError).
sns.heatmap(rec.outputs['onehot_to_feature_x']['x'], robust=True)

In [None]:
# Entry [0][0] of the inner linear layer's output (recorded as a raw tensor, so this is row 0, column 0).
rec.outputs['onehot_to_feature_x.linear'][0][0].item()

In [None]:
# Shapes of what was recorded for `onehot_to_feature_x.linear`. Graph outputs are stored as
# dicts (show each field's shape); plain tensor outputs are stored as tensors (show their
# shape). The old `[0].items()` failed on tensors, which have no `.items()`.
linear_out = rec.outputs['onehot_to_feature_x.linear']
(
    {k: getattr(v, 'shape', v) for k, v in linear_out.items()}
    if isinstance(linear_out, dict)
    else getattr(linear_out, 'shape', linear_out)
)