In [17]:
%cd ~/cdv
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'
import numpy as np
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
import seaborn as sns

import rho_plus as rp

is_dark = False
theme, cs = rp.mpl_setup(is_dark)
rp.plotly_setup(is_dark)

/home/nmiklaucic/cdv


In [2]:
from pymatgen.core import Structure

s = Structure.from_file('data/087.cif')
s

Structure Summary
Lattice
    abc : 3.46566108 5.912730870000001 5.45788908
 angles : 89.99986914 89.99991061 90.01299071999999
 volume : 111.84044720244309
      A : 3.465661079995782 0.0 5.406950504201276e-06
      B : -0.0013405984239926575 5.912730718007075 1.35043098847761e-05
      C : 0.0 0.0 5.45788908
    pbc : True True True
PeriodicSite: Mn0 (Mn) (1.731, 4.91, 2.721) [0.4999, 0.8303, 0.4985]
PeriodicSite: Mn1 (Mn) (1.733, 1.003, 5.45) [0.5001, 0.1697, 0.9985]
PeriodicSite: Mn2 (Mn) (3.465, 1.953, 2.72) [0.9998, 0.3304, 0.4984]
PeriodicSite: Mn3 (Mn) (-0.0003616, 3.959, 5.449) [0.0001547, 0.6696, 0.9984]
PeriodicSite: O4 (O) (-0.0002849, 1.977, 0.6282) [4.711e-05, 0.3343, 0.1151]
PeriodicSite: O5 (O) (1.732, 4.933, 0.6285) [0.5001, 0.8342, 0.1151]
PeriodicSite: O6 (O) (1.732, 0.9801, 3.357) [0.4999, 0.1658, 0.6151]
PeriodicSite: O7 (O) (3.465, 3.936, 3.357) [1.0, 0.6657, 0.6151]

In [3]:
from sevenn.sevennet_calculator import SevenNetCalculator
seven_calc = SevenNetCalculator("7net-0", device='cpu')  # 7net-0, SevenNet-0, 7net-0_22May2024, 7net-0_11July2024 ...

In [4]:
from functools import partial
import torch
from copy import deepcopy


def serialize_atomgraph(data):    
    values = []
    for value in data:
        if hasattr(value, 'num_atoms'):
            values.append({
                k: (np.array(torch.clone(v).numpy(force=True)) if isinstance(v, torch.Tensor) else v)
                for k, v in dict(value).items()
            })
        else:
            values.append(value)

    return values


class Recorder:    
    def __init__(self):
        self.inputs = {}
        self.outputs = {}

    def pre_hook(self, mod, args, name='module'):        
        self.inputs[name] = serialize_atomgraph(args)[0]

    def post_hook(self, mod, args, output, name='module'):        
        self.outputs[name] = serialize_atomgraph([output])[0]


rec = Recorder()
atoms = s.to_ase_atoms()
handles = []
for name, mod in seven_calc.model.named_modules():
    try:
        handle = mod.register_forward_hook(partial(rec.post_hook, name=name))
        handles.append(handle)
        handle = mod.register_forward_pre_hook(partial(rec.pre_hook, name=name))
        handles.append(handle)
    except RuntimeError:
        continue
out = seven_calc.calculate(atoms=atoms)
print(seven_calc.results['free_energy'] / s.num_sites)

for handle in handles:
    handle.remove()

len(rec.inputs)

-7.8014678955078125


133

In [35]:
y = rec.outputs['0_convolution.convolution'].numpy(force=True)
y.shape

(304, 1152)

In [36]:
y_out[0, 0]

0.0031199902

In [51]:
rec.inputs['0_convolution.convolution']

tensor([[-0.6473,  0.3526,  0.3866,  ...,  0.3685,  0.6307, -1.1181],
        [-0.6473,  0.3526,  0.3866,  ...,  0.3685,  0.6307, -1.1181],
        [-0.0191,  0.0787,  0.0885,  ..., -0.3683, -0.0444, -0.5610],
        ...,
        [-0.6473,  0.3526,  0.3866,  ...,  0.3685,  0.6307, -1.1181],
        [-0.6473,  0.3526,  0.3866,  ...,  0.3685,  0.6307, -1.1181],
        [-0.6473,  0.3526,  0.3866,  ...,  0.3685,  0.6307, -1.1181]],
       grad_fn=<IndexBackward0>)

In [46]:
y_out = rec.inputs['0_convolution']
y_out.shape

(8, 128)

In [12]:
x1 = rec.inputs['edge_preprocess']
x1['edge_index'].shape

(2, 304)

In [62]:
y[x1['edge_index'][0] == 0][:, 0].sum() / 35.989574

0.003119990430604515

In [76]:
np.repeat(np.arange(y.shape[0]), y.shape[1])

array([  0,   0,   0, ..., 303, 303, 303])

In [127]:
from scipy.stats import pearsonr

def tile_to(arr):
    return np.repeat(arr, y.shape[1])

df = pd.DataFrame({'dst': tile_to(x1['edge_index'][0]), 'src': tile_to(x1['edge_index'][1]), 'd': np.tile(np.arange(y.shape[0]), y.shape[1]), 'y': y.reshape(-1)})
df

Unnamed: 0,dst,src,d,y
0,0,1,0,0.001217
1,0,1,1,-0.005464
2,0,1,2,-0.001705
3,0,1,3,-0.002379
4,0,1,4,-0.004055
...,...,...,...,...
350203,7,1,299,0.026355
350204,7,1,300,0.022971
350205,7,1,301,-0.003749
350206,7,1,302,-0.044581


In [129]:
sums = df.groupby(['dst', 'd']).sum()[['y']].reset_index().pivot(index='dst', columns='d', values='y')
sums

d,0,1,2,3,4,5,6,7,8,9,...,294,295,296,297,298,299,300,301,302,303
dst,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.080341,0.979445,-0.369224,-1.599165,3.413868,-0.449158,0.480734,-11.813493,-9.691558,-4.668252,...,-3.41693,1.067544,-3.527453,-1.841094,5.459202,-3.752878,-1.598834,-3.493857,-0.074179,3.58556
1,-1.600157,2.210276,-1.879552,0.386093,0.079122,-0.684628,10.979606,5.504765,-6.483073,6.763945,...,1.899035,-5.401458,-1.655047,-1.867161,2.487915,-0.176466,2.8589,-3.925158,0.008496,1.060535
2,1.33543,1.991293,0.980064,2.90757,-2.295511,-3.258919,-1.081009,-2.087975,-1.711585,0.616451,...,1.190295,-1.03798,0.003805,-0.457985,0.535972,-0.678583,-0.485962,0.109434,0.581956,0.786045
3,2.136575,0.29341,-2.356462,-1.981838,-1.3095,3.80939,-1.122025,5.053125,4.640914,-8.811065,...,-1.369886,-2.648954,5.682726,-1.973662,1.921569,-4.683675,1.24559,-0.069146,-2.784544,-0.288168
4,0.73822,-1.711397,-4.586409,4.402134,2.952676,1.023712,-4.465078,0.557161,0.480293,2.625098,...,-4.57993,-4.153,2.727923,-0.061252,-3.250031,-0.709231,4.845757,-5.654696,-2.809213,5.542956
5,2.886852,-1.444471,0.689829,-3.380885,1.831948,2.564552,1.4676,1.956264,1.621551,-1.381358,...,0.143661,-3.129117,0.829262,1.277241,-1.660523,-2.152118,-9.063655,2.481149,2.106292,-0.052443
6,3.237519,-0.160961,-0.356338,-0.976993,1.715193,2.347855,0.57293,4.818274,-0.197501,0.542776,...,1.154601,-0.534233,-2.144239,-3.351512,4.185036,-2.476043,-6.353804,-0.043793,1.806428,2.842532
7,10.089384,-3.583895,0.972734,6.628081,0.117261,-0.484851,6.096695,3.307767,-10.135487,2.160851,...,-1.321827,-2.84025,-0.283021,0.243471,-2.188114,-1.913223,1.402384,-1.848845,-6.954618,0.56823


In [130]:
vdf = df.groupby(['dst', 'd']).agg(np.var)[['y']].reset_index().pivot(index='dst', columns='d', values='y')
vdf

  vdf = df.groupby(['dst', 'd']).agg(np.var)[['y']].reset_index().pivot(index='dst', columns='d', values='y')


d,0,1,2,3,4,5,6,7,8,9,...,294,295,296,297,298,299,300,301,302,303
dst,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.02001,0.027048,0.007321,0.104947,0.172726,0.042048,0.11612,0.838443,0.411758,0.145701,...,0.046657,0.021413,0.03492,0.020889,0.094205,0.035478,0.013434,0.0771,0.082708,0.019845
1,0.035061,0.135647,0.028447,0.010662,0.035403,0.289032,0.485957,0.148217,0.100089,0.316014,...,0.03453,0.115148,0.058637,0.007451,0.23881,0.019939,0.148962,0.025476,0.012604,0.058114
2,0.021057,0.085726,0.024866,0.075847,0.152884,0.033136,0.02841,0.089957,0.010757,0.018154,...,0.014177,0.018665,0.131783,0.030692,0.024864,0.021751,0.013786,0.088626,0.040587,0.01464
3,0.021969,0.013763,0.022131,0.061664,0.059761,0.051813,0.036058,0.245197,0.101397,0.729209,...,0.083168,0.080923,0.094901,0.027755,0.079215,0.063574,0.015391,0.008754,0.022468,0.025084
4,0.011739,0.068881,0.114905,0.087892,0.038059,0.052947,0.064496,0.013304,0.049784,0.040832,...,0.073465,0.095589,0.120055,0.046921,0.064423,0.164236,0.128684,0.098195,0.062353,0.231045
5,0.055471,0.097626,0.050332,0.093934,0.024671,0.077029,0.015654,0.035604,0.055971,0.029755,...,0.014748,0.090546,0.053898,0.026284,0.012041,0.020232,0.312575,0.022799,0.037279,0.04833
6,0.038553,0.03603,0.05919,0.023035,0.064062,0.024277,0.022175,0.079558,0.050773,0.035854,...,0.12272,0.014636,0.072923,0.352885,0.125984,0.092364,0.229748,0.096869,0.063507,0.030487
7,0.412916,0.064023,0.046937,0.101193,0.033194,0.089603,0.129671,0.026891,0.3713,0.037604,...,0.041587,0.044893,0.066334,0.075095,0.013273,0.037474,0.246706,0.114548,0.131696,0.133634


In [141]:
kdf = df.groupby(['dst', 'd']).count()[['y']].reset_index().pivot(index='dst', columns='d', values='y')
kdf

d,0,1,2,3,4,5,6,7,8,9,...,294,295,296,297,298,299,300,301,302,303
dst,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144
1,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144
2,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144
3,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144
4,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144
5,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144
6,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144
7,144,144,144,144,144,144,144,144,144,144,...,144,144,144,144,144,144,144,144,144,144


In [138]:
(np.log(abs(sums) / vdf) / np.log(144)).mean().mean()

0.69031763