In [1]:
%cd ~/cdv
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns

import rho_plus as rp

# Plot theming via the project's `rho_plus` helpers, applied to both matplotlib
# (`mpl_setup`, which also returns `theme` and `cs` — presumably the theme object
# and a colour sequence; TODO confirm) and plotly (`plotly_setup`).
# `is_dark=True` selects the dark variant (judging by the name).
is_dark = True
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

/home/nmiklaucic/cdv


In [2]:
from pymatgen.core import Structure

# Load the crystal structure to analyse. The path is relative to the working
# directory set by `%cd ~/cdv` in the first cell.
s = Structure.from_file('data/087.cif')
s

Structure Summary
Lattice
    abc : 3.46566108 5.912730870000001 5.45788908
 angles : 89.99986914 89.99991061 90.01299071999999
 volume : 111.84044720244309
      A : 3.465661079995782 0.0 5.406950504201276e-06
      B : -0.0013405984239926575 5.912730718007075 1.35043098847761e-05
      C : 0.0 0.0 5.45788908
    pbc : True True True
PeriodicSite: Mn0 (Mn) (1.731, 4.91, 2.721) [0.4999, 0.8303, 0.4985]
PeriodicSite: Mn1 (Mn) (1.733, 1.003, 5.45) [0.5001, 0.1697, 0.9985]
PeriodicSite: Mn2 (Mn) (3.465, 1.953, 2.72) [0.9998, 0.3304, 0.4984]
PeriodicSite: Mn3 (Mn) (-0.0003616, 3.959, 5.449) [0.0001547, 0.6696, 0.9984]
PeriodicSite: O4 (O) (-0.0002849, 1.977, 0.6282) [4.711e-05, 0.3343, 0.1151]
PeriodicSite: O5 (O) (1.732, 4.933, 0.6285) [0.5001, 0.8342, 0.1151]
PeriodicSite: O6 (O) (1.732, 0.9801, 3.357) [0.4999, 0.1658, 0.6151]
PeriodicSite: O7 (O) (3.465, 3.936, 3.357) [1.0, 0.6657, 0.6151]

In [3]:
from mace.calculators import mace_mp
from ase import build
from pathlib import Path

# Single-point MACE-MP ("medium" checkpoint) evaluation of the loaded structure.
# Runs in float32: per the log below, that is faster but less accurate than
# float64, which MACE recommends for geometry optimization.
# NOTE(review): cell 1 exposes GPU 0 via CUDA_VISIBLE_DEVICES, yet this runs on
# CPU — confirm that is intentional.
atoms = s.to_ase_atoms()
calc = mace_mp(
    model='medium',
    device='cpu',
    default_dtype="float32",
    dispersion=False,
)
calc.calculate(atoms=atoms)
calc.results

Using Materials Project MACE for MACECalculator with /home/nmiklaucic/.cache/mace/20231203mace128L1_epoch199model
Using float32 for MACECalculator, which is faster but less accurate. Recommended for MD. Use float64 for geometry optimization.
Default dtype float32 does not match model dtype float64, converting models to float32.


{'energy': -62.19428253173828,
 'free_energy': -62.19428253173828,
 'node_energy': array([ 0.6044607 ,  0.6044612 ,  0.6045575 ,  0.6045661 , -0.5645795 ,
        -0.5625615 , -0.5624423 , -0.56454134], dtype=float32),
 'forces': array([[ 0.00144893, -0.0083482 , -0.02244522],
        [-0.00140256,  0.00836148, -0.0228433 ],
        [ 0.00255934, -0.01262411, -0.01770091],
        [-0.00251153,  0.01247613, -0.01785059],
        [ 0.00115948, -0.02385738,  0.02021112],
        [ 0.00214557, -0.02359966,  0.02019673],
        [-0.00219709,  0.0236147 ,  0.02025928],
        [-0.00120194,  0.02397679,  0.02017296]], dtype=float32),
 'stress': array([-1.71648320e-02, -2.06205603e-02, -1.71729978e-02, -5.25974542e-07,
         6.88569287e-07, -8.41041474e-05])}

In [4]:
# Free energy per atom from the single-point calculation above (ASE energy
# units, i.e. eV/atom). `len(atoms)` replaces the deprecated
# `Atoms.get_number_of_atoms()`, whose deprecation warning was being emitted
# alongside this cell's output.
calc.results['free_energy'] / len(atoms)

  calc.results['free_energy'] / atoms.get_number_of_atoms()


-7.774285316467285

In [5]:
# ckpt_params = dict(seven_calc.model.named_parameters())
# ckpt_params = {k: v.detach().cpu().numpy() for k, v in ckpt_params.items()}
# np.save('precomputed/sevennet_ckpt.npy', ckpt_params)

In [6]:
from functools import partial
import torch
from copy import deepcopy


def serialize_atomgraph(data):
    """Return ``data`` unchanged.

    Appears to be a placeholder conversion step for recorded hook payloads.
    Because it is currently the identity, the recorder keeps the original
    objects by reference, with no copy or detach.
    """
    serialized = data
    return serialized


class Recorder:
    """Capture what hooked modules receive and return during a forward pass.

    ``pre_hook`` and ``post_hook`` have the signatures used by PyTorch forward
    pre-hooks and forward hooks. Bind ``name`` with ``functools.partial`` so each
    module is stored under its own key. Only the most recent call per name is
    kept, because each call overwrites the dict entry.
    """

    def __init__(self):
        # name -> ``args`` passed to the forward pre-hook
        # name -> module output passed to the forward hook
        self.inputs, self.outputs = {}, {}

    def _store(self, table, name, value):
        """Record ``value`` (via ``serialize_atomgraph``) in ``table`` under ``name``."""
        table[name] = serialize_atomgraph(value)

    def pre_hook(self, mod, args, name='module'):
        """Forward pre-hook: remember the arguments ``name`` was called with."""
        self._store(self.inputs, name, args)

    def post_hook(self, mod, args, output, name='module'):
        """Forward hook: remember what ``name`` returned."""
        self._store(self.outputs, name, output)


# Hook every submodule of the MACE model so that a single forward pass records
# what each module received and returned. Modules whose hook registration raises
# RuntimeError are skipped. If only the pre-hook fails, the already-registered
# post-hook is kept (and removed below).
rec = Recorder()
atoms = s.to_ase_atoms()
handles = []
for name, mod in calc.models[0].named_modules():
    try:
        hook_pairs = (
            (mod.register_forward_hook, rec.post_hook),
            (mod.register_forward_pre_hook, rec.pre_hook),
        )
        for register, hook in hook_pairs:
            handle = register(partial(hook, name=name))
            handles.append(handle)
    except RuntimeError:
        continue

# Remove the hooks even if the forward pass fails, so that later calculations
# with `calc` are not recorded.
try:
    calc.calculate(atoms=atoms)
finally:
    for handle in handles:
        handle.remove()
print(calc.results['energy'] / s.num_sites)
len(rec.inputs)

-7.774285316467285


64

In [7]:
rec.inputs.keys()

dict_keys(['atomic_energies_fn', '', 'node_embedding', 'node_embedding.linear', 'spherical_harmonics', 'radial_embedding', 'radial_embedding.cutoff_fn', 'radial_embedding.bessel_fn', 'interactions.0', 'interactions.0.skip_tp', 'interactions.0.linear_up', 'interactions.0.conv_tp_weights', 'interactions.0.conv_tp_weights.layer0', 'interactions.0.conv_tp_weights.layer0.act', 'interactions.0.conv_tp_weights.layer1', 'interactions.0.conv_tp_weights.layer2', 'interactions.0.conv_tp_weights.layer3', 'interactions.0.conv_tp', 'interactions.0.linear', 'interactions.0.reshape', 'products.0', 'products.0.symmetric_contractions', 'products.0.symmetric_contractions.contractions.0', 'products.0.symmetric_contractions.contractions.0.graph_opt_main', 'products.0.symmetric_contractions.contractions.0.contractions_weighting.0', 'products.0.symmetric_contractions.contractions.0.contractions_features.0', 'products.0.symmetric_contractions.contractions.0.contractions_weighting.1', 'products.0.symmetric_con

In [9]:
# Output of the `linear` layer in the first interaction block, as NumPy
# (`force=True` detaches and moves to CPU before converting).
# Element [0, 0] is presumably (first node, first feature channel) — TODO confirm layout.
y = rec.outputs['interactions.0.linear'].numpy(force=True)
y[0, 0]

-46.223297

In [None]:
# NOTE(review): '0_self_connection_intro.linear' looks like a SevenNet module
# name and does not appear in the (truncated) MACE key listing above, so this
# likely raises KeyError. Also, `Recorder.pre_hook` stores the hook's `args`
# (a tuple), which has no `.numpy()`; it would need indexing (e.g. `[0]`) first.
x = rec.inputs['0_self_connection_intro.linear'].numpy(force=True)
x

In [None]:
# NOTE(review): `seven_calc` is never defined in this notebook (it appears only
# in the commented-out checkpoint cell), so this raises NameError on a fresh
# kernel. The 128 x 224 reshape is presumably (in_features, out_features) of
# that layer — TODO confirm.
params = dict(seven_calc.model.get_submodule('0_self_connection_intro.linear').named_parameters())
w = params['weight'].numpy(force=True).reshape(128, 224)
w

In [None]:
# Recompute one output element of the linear layer by hand from `x` and `w`
# (the two cells above, which currently fail). The 1/sqrt(128) factor is
# presumably the layer's fan-in normalization — TODO confirm.
((x / np.sqrt(128)) @ w)[0, 0]

In [None]:
# NOTE(review): SevenNet-style module name; not among the recorded MACE keys
# shown above, so this likely raises KeyError.
rec.outputs['reduce_hidden_to_energy.linear']

In [None]:
# '' is the root module's name in `named_modules()`, so this is the whole
# model's forward output.
ag = rec.outputs['']

In [None]:
# Index of the edge from atom 48 to atom 0 (rows 0/1 of `edge_index` presumably
# being source/target — TODO confirm). `.item()` requires exactly one match.
# NOTE(review): the structure loaded in cell 2 has only 8 sites, so atom 48
# cannot exist here. It is also unverified that the root MACE output contains
# an 'edge_index' entry.
i = np.where((ag['edge_index'][0] == 48) & (ag['edge_index'][1] == 0))[0].item()
i

In [None]:
# NOTE(review): SevenNet-style module name (MACE records e.g. 'radial_embedding'
# instead), and it depends on `i` from the failing cell above.
rec.outputs['edge_embedding']['edge_embedding'][i]

In [None]:
# NOTE(review): `rec.inputs` values are the pre-hook `args` tuples, so ['x']
# raises TypeError even if the (SevenNet-style) key existed.
rec.inputs['0_self_connection_intro']['x']

In [None]:
# NOTE(review): SevenNet-style key; not among the recorded MACE keys.
rec.outputs['0_self_connection_intro.linear'][0, 1].item()

In [None]:
# NOTE(review): SevenNet-style key. This indexes the output directly with
# ['x'], whereas a later cell uses [0]['x'] on the same output; at most one of
# the two can be right.
sns.heatmap(rec.outputs['onehot_to_feature_x']['x'], robust=True, center=0)

In [None]:
# NOTE(review): `seven_calc` is undefined in this notebook (NameError). The
# reshape to (89, 128) and the row selection [[11]] are unexplained magic
# numbers — TODO document what the 89 rows and row 11 correspond to.
params = dict(seven_calc.model.onehot_to_feature_x.named_parameters())
w = params['linear.weight'].detach().cpu().numpy()
sns.heatmap(w.reshape(89, 128)[[11]], center=0, robust=True)

In [None]:
# NOTE(review): `seven_calc` is undefined in this notebook; the MACE model is
# `calc.models[0]`.
seven_calc.model

In [None]:
# NOTE(review): SevenNet-style key (MACE records 'spherical_harmonics'), and it
# depends on `i` from the failing edge-lookup cell.
rec.outputs['edge_embedding.spherical.sph'][i][6].item()

In [None]:
# NOTE(review): indexes [0]['x'] here but ['x'] in the earlier heatmap cell for
# the same output; one of them is wrong.
sns.heatmap(rec.outputs['onehot_to_feature_x'][0]['x'], robust=True)

In [None]:
# NOTE(review): this treats [0][0] as a tensor element, but the next cell treats
# [0] as a dict (`.items()`); both cannot hold for the same output.
rec.outputs['onehot_to_feature_x.linear'][0][0].item()

In [None]:
# Summarise the entries of the recorded output: shape for array-like values,
# the value itself otherwise. NOTE(review): assumes [0] is a dict, which
# contradicts the previous cell.
{k: getattr(v, 'shape', v) for k, v in rec.outputs['onehot_to_feature_x.linear'][0].items()}