In [1]:
# import os
# # Navigate to the Open-Catalyst-Project directory
# os.chdir('/host_workspace/v9/ocp')
# # Install the package
# !pip install -e .

In [2]:
## 1. summary so far
## use the Jupyter notebooks with the 'v2_' prefix for newer work.
##
## CATHUB
## useful data - datasets/reactions.pickle
## # samples - 88587
## useful features - star/sc, facet, pkey, nre/re
## this dataset is from Mamun's resource, contains BEEF-vdW functional data
## needs to generate descriptors using (star/sc, facet, pkey) vs (nre/re)
##
## OCP
## useful data - datasets/ocp_reactions_info.pickle
## # samples - 446885
## useful features - bulk_symbols, miller_index, pkey, energy
## this dataset is from META Open Catalyst Project, contains RPBE functional data
## needs to generate descriptors using (bulk_symbols, miller_index, pkey) vs (energy)
##
## 2. use start_jupyter torch2
## 3. this file is inspired by script_9_cathub.ipynb
## 4. retrieving last layer representations
# for the following code
#         # Embedding block.
#         x = self.emb(data.atomic_numbers.long(), rbf, i, j)
#         P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))
#         print(x.shape, P.shape)
        
#         ## change
#         def capture_embeddings(module, input, output):
#             self.embs.append(output.detach())  # Store the output embeddings
            
#         ## change
#         hook_handle = self.interaction_blocks[-1].register_forward_hook(capture_embeddings)
            
#         # Interaction blocks.
#         for interaction_block, output_block in zip(
#             self.interaction_blocks, self.output_blocks[1:]
#         ):
#             x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
#             P += output_block(x, rbf, i, num_nodes=pos.size(0))
#             print(x.shape, P.shape)

## this produces the following output
# Atoms(symbols='H5Au48Co8O17', pbc=True, cell=[5.878883441, 20.365049623, 20.81515], calculator=OCPCalculator(...))
# torch.Size([3477, 256]) torch.Size([78, 1])
# torch.Size([3477, 256]) torch.Size([78, 1])
# torch.Size([3477, 256]) torch.Size([78, 1])
# torch.Size([3477, 256]) torch.Size([78, 1])
# Atoms(symbols='CH9Au48Co8O17', pbc=True, cell=[5.878883441, 20.365049623, 20.81515], calculator=OCPCalculator(...))
# torch.Size([3533, 256]) torch.Size([83, 1])
# torch.Size([3533, 256]) torch.Size([83, 1])
# torch.Size([3533, 256]) torch.Size([83, 1])
# torch.Size([3533, 256]) torch.Size([83, 1])
# -1.29420852661132811

In [3]:
import os
import numpy as np
import ase.io
from ase.constraints import FixAtoms
from ase.build import add_adsorbate, molecule, surface
from pymatgen.ext.matproj import MPRester
from pymatgen.core.surface import generate_all_slabs, SlabGenerator
from pymatgen.io.ase import AseAtomsAdaptor
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
import torch
import pickle
import json
import pandas as pd

In [4]:
# Model config and checkpoint used to build the OCP calculator.
config_yml_path = "ocp/configs/is2re/all/dimenet_plus_plus/dpp.yml"
checkpoint_path = "dimenetpp_all.pt"

try:
    # First attempt at constructing the calculator.
    calc = OCPCalculator(config_yml=config_yml_path, checkpoint=checkpoint_path)
except Exception as e:
    # Report the error, then retry once with exactly the same arguments.
    # NOTE(review): the recorded output shows the first attempt failing with
    # "No module named 'ocpmodels.models.'" and the retry succeeding --
    # presumably a first-import quirk; confirm before removing the retry.
    print(e)
    calc = OCPCalculator(config_yml=config_yml_path, checkpoint=checkpoint_path)

No module named 'ocpmodels.models.'


In [5]:
# Display the calculator object to confirm it was constructed.
calc

<ocpmodels.common.relaxation.ase_utils.OCPCalculator at 0x7f0dc3d82f80>

In [6]:
# Read the pickled CatHub reaction records from disk.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('datasets/reactions.pickle', 'rb') as reactions_fh:
    loaded_reactions = pickle.load(reactions_fh)

In [7]:
# Number of reaction records that were loaded.
len(loaded_reactions)

88587

In [8]:
# How many reactions the extraction loop will process; defaults to all of them.
# The trailing note records a quick-test variant (multiply by 0, add 2000) that
# limits the run to the first 2000 reactions.
num_sample = len(loaded_reactions) ## *0 +2000 ## change here

In [9]:
########################################################################
########################################################################

In [10]:
# list_rinfo = []
# for ir, r in enumerate(loaded_reactions[:1]):
#     if ir%1000==0:
#         print(ir, len(list_rinfo))

#     ## retrieve slab and adsorbate
#     ase_slab = r['reactionSystems']['star']
#     d_prod = json.loads(r['products'])
#     if len(d_prod.keys()) != 1:
#         continue
#     pkey = list(d_prod.keys())[0].replace('star', '')
#     ads_symbol = pkey    

#     ## get slab energy and embeddings
#     adslab = ase_slab.copy()
#     adslab.calc = calc
#     # print(adslab)
#     slab_energy = adslab.get_potential_energy()

#     adsorbate = molecule(ads_symbol)
#     add_adsorbate(adslab, adsorbate, 3, offset=(1, 1))
#     # print(adslab)
#     ads_energy = adslab.get_potential_energy()

In [11]:
# ase_slab.positions

In [12]:
# adsorbate

In [13]:
# adslab.positions

In [14]:
########################################################################
########################################################################

In [15]:
# Build one record per CatHub reaction: place the product adsorbate on the clean
# slab, evaluate it with the OCP calculator, pool the captured embeddings
# (mean/min/max/sum over dim 0) and store them together with reaction metadata.
#
# Reactions that raise are skipped (best effort), but each failure is recorded in
# `list_failures` and a summary is printed at the end instead of being dropped
# silently.

# Tag layout applied to every adslab (values unchanged from the original cell).
# NOTE(review): these index ranges are hard-coded; presumably they assume a
# 27-atom slab with the adsorbate atoms appended after it -- confirm for slabs
# of other sizes.
SURFACE_TAG_SLICE = slice(18, 27)   # atoms tagged 1
ADSORBATE_TAG_START = 27            # atoms from this index on are tagged 2

list_rinfo = []
list_failures = []          # (reaction index, repr of the exception)
n_skipped_multi_product = 0  # reactions whose 'products' has != 1 entry

for ir, r in enumerate(loaded_reactions[:num_sample]):
    if ir % 1000 == 0:
        print(ir, len(list_rinfo))

    try:
        ## retrieve slab and adsorbate
        ase_slab = r['reactionSystems']['star']
        d_prod = json.loads(r['products'])
        if len(d_prod) != 1:
            # Only single-product reactions are used.
            n_skipped_multi_product += 1
            continue
        product_name = list(d_prod.keys())[0]
        pkey = product_name.replace('star', '')
        ads_symbol = pkey

        ## slab + adsorbate
        adslab = ase_slab.copy()
        adsorbate = molecule(ads_symbol)
        add_adsorbate(adslab, adsorbate, 3, offset=(1, 1))

        ## set additional info
        tags = np.zeros(len(adslab))
        tags[SURFACE_TAG_SLICE] = 1
        tags[ADSORBATE_TAG_START:] = 2
        adslab.set_tags(tags)
        # Fix every atom that kept tag 0.
        cons = FixAtoms(indices=[atom.index for atom in adslab if (atom.tag == 0)])
        adslab.set_constraint(cons)
        adslab.center(vacuum=13.0, axis=2)
        adslab.set_pbc(True)

        ## calculator, energy, embeddings
        adslab.calc = calc
        # The energy is not stored in the record; presumably this call is what
        # triggers the model forward pass that fills `embs` -- confirm.
        ads_energy = adslab.get_potential_energy()
        # NOTE(review): this always reads embs[0]; verify that the model resets
        # `embs` on every forward pass, otherwise the entry would be stale.
        node_embs = calc.trainer.model.module.embs[0]
        mean_ads_embs = torch.mean(node_embs, dim=0)
        min_ads_embs = torch.min(node_embs, dim=0)   # (values, indices)
        max_ads_embs = torch.max(node_embs, dim=0)   # (values, indices)
        sum_ads_embs = torch.sum(node_embs, dim=0)

        ## additional info
        sc = r['surfaceComposition']
        facet = r['facet']
        pval = d_prod[product_name]
        re = r['reactionEnergy']
        # Reaction energy divided by the product's value in the 'products' dict.
        nre = re / pval

        eqn = r['Equation']
        ae = r['activationEnergy']
        cc = r['chemicalComposition']
        d_cvr = json.loads(r['coverages'])
        dft_code = r['dftCode']
        dft_func = r['dftFunctional']
        pubid = r['pubId']
        d_reactants = json.loads(r['reactants'])
        sites = json.loads(r['sites'])
        username = r['username']

        rinfo = {
            'ase_slab': ase_slab,
            'd_prod': d_prod,
            'pkey': pkey,
            'mean_ads_embs': mean_ads_embs,
            'min_ads_embs': min_ads_embs.values,
            'max_ads_embs': max_ads_embs.values,
            'sum_ads_embs': sum_ads_embs,

            'sc': sc,
            'facet': facet,
            'pval': pval,
            're': re,
            'nre': nre,

            'eqn': eqn,
            'ae': ae,
            'cc': cc,
            'd_cvr': d_cvr,
            'dft_code': dft_code,
            'dft_func': dft_func,
            'pubid': pubid,
            'd_reactants': d_reactants,
            'sites': sites,
            'username': username
        }
        list_rinfo.append(rinfo)
    except Exception as e:
        # Best effort: skip this reaction, but keep the reason for inspection.
        list_failures.append((ir, repr(e)))

print(
    'kept', len(list_rinfo),
    '| skipped (not exactly one product)', n_skipped_multi_product,
    '| failed', len(list_failures),
)

0 0
1000 814
2000 1234
3000 2004
4000 2378
5000 2709
6000 3079
7000 3558
8000 3861
9000 4119
10000 4407
11000 4828
12000 5148
13000 5494
14000 5936
15000 6233
16000 6461
17000 6731
18000 7096
19000 7505
20000 8243
21000 9151
22000 9923
23000 10901
24000 11886
25000 12884
26000 13884
27000 14884
28000 15884
29000 16874
30000 17863
31000 18863
32000 19863
33000 20863
34000 21855
35000 22849
36000 23844
37000 24841
38000 25841
39000 26841
40000 27841
41000 28841
42000 29841
43000 30841
44000 31271
45000 31391
46000 31503
47000 31620
48000 31750
49000 31920
50000 32248
51000 32354
52000 32480
53000 32609
54000 32720
55000 32849
56000 32953
57000 33103
58000 33209
59000 33325
60000 33421
61000 33527
62000 33661
63000 33800
64000 33970
65000 34156
66000 34269
67000 34385
68000 34501
69000 34649
70000 34806
71000 34920
72000 35028
73000 35130
74000 35263
75000 35391
76000 35505
77000 35734
78000 35971
79000 36268
80000 36701
81000 37289
82000 38058
83000 38901
84000 39746
85000 40720
86000 41

In [16]:
# Persist the collected reaction records to disk.
with open('datasets/dict_cathub_dpp_mean_min_max_sum.pickle', 'wb') as records_out:
    pickle.dump(list_rinfo, records_out)

In [17]:
# Reload the saved reaction records from the pickle file.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('datasets/dict_cathub_dpp_mean_min_max_sum.pickle', 'rb') as records_in:
    loaded_list_rinfo = pickle.load(records_in)

In [18]:
# Number of reaction records read back from the pickle.
len(loaded_list_rinfo)

43678

In [19]:
# For each pooling function, build a table of embedding features plus the
# target 'nre', keep one row per (sc, facet, pkey) group, and pickle it to
# 'datasets/df_cathub_dpp_<agg_func>.pickle'.
#
# The records in `loaded_list_rinfo` are only read, never modified, so this
# cell can be re-run without changing its inputs.
if not loaded_list_rinfo:
    raise ValueError('loaded_list_rinfo is empty; nothing to aggregate')

group_keys = ['sc', 'facet', 'pkey']
# All record fields as columns; built once and shared by every aggregation.
df_records = pd.DataFrame(loaded_list_rinfo)

for agg_func in ['mean', 'min', 'max', 'sum']:
    print()
    print(agg_func)

    # One row of plain floats per record for the selected pooling.
    emb_key = agg_func + '_ads_embs'
    emb_rows = [[float(x) for x in item[emb_key]] for item in loaded_list_rinfo]
    # Column names e0..e{n-1}; the width is taken from the data.
    emb_names = ['e' + str(i) for i in range(len(emb_rows[0]))]
    ads_embs_cols = pd.DataFrame(emb_rows, columns=emb_names)

    # Record fields followed by the embedding columns.
    df = pd.concat([df_records, ads_embs_cols], axis=1)

    ## keep the stable reaction: sort ascending by 'nre', then take the first
    ## entry of every (sc, facet, pkey) group.
    # NOTE(review): GroupBy.first() takes the first non-null value per column,
    # so a group's row could mix values from several records if any kept
    # column contains nulls -- confirm embeddings and 'nre' are never null.
    df = df.sort_values(by='nre')
    df = df.groupby(group_keys).first().reset_index()
    ##
    print(df.shape)
    print(df[['sc', 'facet', 'pkey', 'nre']].head(2))

    # Final table: embedding features and the target only.
    df = df[emb_names + ['nre']]
    print(df.shape)
    df.to_pickle('datasets/df_cathub_dpp_' + str(agg_func) + '.pickle')


mean
(10938, 278)
   sc facet pkey       nre
0  Ag   111    C  5.608979
1  Ag   111   CH  3.996914
(10938, 257)

min
(10938, 278)
   sc facet pkey       nre
0  Ag   111    C  5.608979
1  Ag   111   CH  3.996914
(10938, 257)

max
(10938, 278)
   sc facet pkey       nre
0  Ag   111    C  5.608979
1  Ag   111   CH  3.996914
(10938, 257)

sum
(10938, 278)
   sc facet pkey       nre
0  Ag   111    C  5.608979
1  Ag   111   CH  3.996914
(10938, 257)


In [20]:
## visualize 
agg_func = 'max'
df = pd.read_pickle('datasets/df_cathub_dpp_' + str(agg_func) + '.pickle')
print(df.shape)
df.head()

(10938, 257)


Unnamed: 0,e0,e1,e2,e3,e4,e5,e6,e7,e8,e9,...,e247,e248,e249,e250,e251,e252,e253,e254,e255,nre
0,7.336441,10.512039,4.940955,7.956531,9.691102,3.735639,16.076952,0.277445,5.527064,14.929154,...,1.775173,3.180848,9.447767,9.821554,10.599458,6.261724,3.143227,5.966093,9.493245,5.608979
1,2.475741,7.989372,3.852236,5.285998,7.627347,22.725266,11.491851,3.469543,17.968449,9.22589,...,3.267034,6.803737,6.796365,12.177052,17.952662,16.398821,3.917056,16.349478,11.494075,3.996914
2,3.114502,5.341427,6.251712,4.686482,4.443282,9.897226,3.351537,2.349534,4.290404,5.007502,...,1.81571,3.131814,2.582453,5.850073,5.613225,2.286682,3.940985,7.043963,5.547743,1.476339
3,0.444288,1.605584,3.471621,2.10997,4.423562,6.658717,5.854311,-0.048702,3.43838,1.848312,...,0.966942,3.170051,1.561559,6.017676,5.376443,1.803873,3.264681,2.892181,2.730005,0.338155
4,4.056249,3.879143,5.458462,9.691298,5.097898,7.205102,2.775909,3.750675,5.506991,1.865663,...,1.541541,3.014952,2.957124,5.853688,5.941875,3.863775,4.885923,9.841976,5.264493,-0.120299


In [21]:
## visualize 
agg_func = 'mean'
df = pd.read_pickle('datasets/df_cathub_dpp_' + str(agg_func) + '.pickle')
print(df.shape)
df.head()

(10938, 257)


Unnamed: 0,e0,e1,e2,e3,e4,e5,e6,e7,e8,e9,...,e247,e248,e249,e250,e251,e252,e253,e254,e255,nre
0,-0.051727,-0.286628,1.382931,-1.435718,2.597453,-0.519882,1.805876,-1.458474,-1.170391,0.392702,...,-0.281009,-1.516974,-0.710036,2.207848,-0.556048,0.106217,0.500164,0.067475,-1.460815,5.608979
1,-0.174674,-0.420665,1.293232,-1.630182,2.610617,0.04152,1.422713,-1.272242,-0.592301,0.131498,...,-0.154265,-1.227633,-0.841806,2.116063,-0.115456,0.22269,0.624747,0.54653,-1.335673,3.996914
2,-0.343314,-0.235674,1.56794,-1.390991,1.977526,0.139984,1.089224,-1.16658,-0.69296,-0.110241,...,-0.341958,-1.198122,-0.90351,1.975406,-0.591985,-0.113284,0.572909,0.548986,-1.033705,1.476339
3,-0.334801,-0.539823,1.276774,-1.881636,2.435134,-0.462365,1.572978,-1.456306,-1.060732,-0.164784,...,-0.24836,-1.591639,-1.124636,2.11642,-0.781833,-0.052016,0.452242,-0.112862,-1.649781,0.338155
4,-0.344107,-0.327845,1.366605,-1.291086,2.136956,-0.279175,1.17664,-1.202859,-0.808489,-0.20406,...,-0.250685,-1.344624,-0.927821,1.914095,-0.652575,-0.078963,0.684334,0.447531,-1.397458,-0.120299
