In [1]:
## 1. extract all slabs from ocp and cathub(mamun)
## 2. generate slab descriptor with dimenet++
## 3. extract all product from ocp and cathub(mamun)
## 4. generate product descriptor with chEMBL
## 5. run experiments-
##    - (cathub, ocp) x (xgboost) x (original 1024+1024, pca ncomponents, imr ncomponents)
##    - solves n^2 to n, solves descriptor generation for slab/surface

In [2]:
import pickle
import json
from transformers import AutoTokenizer, AutoModel, pipeline
import numpy as np
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
from ase.constraints import FixAtoms
from ase.build import add_adsorbate, molecule, surface
import torch
import pandas as pd

In [3]:
with open('datasets/reactions.pickle', 'rb') as f:
    loaded_reactions = pickle.load(f)

In [4]:
#####################################################################
## data from cathub
#####################################################################

In [5]:
list_rinfo = []
for i, react in enumerate(loaded_reactions):
    if i%20000==0: 
        print(i)
    try:
        d_prod = json.loads(react['products'])
        if len(d_prod.keys()) != 1:
            continue
        sc = react['surfaceComposition']
        facet = react['facet']
        star = react['reactionSystems']['star']
        pkey = list(d_prod.keys())[0].replace('star', '')
        pval = d_prod[list(d_prod.keys())[0]]
        re = react['reactionEnergy']
        nre = re/pval
        rinfo = {
            'sc': sc,
            'slab': star,
            'star': str(star.symbols),           
            'facet': facet,
            'pkey': pkey,
            'pval': pval,            
            're': re,
            'nre': nre,                        
        }
        list_rinfo.append(rinfo)
    except Exception as e:
        pass

0
20000
40000
60000
80000


In [6]:
len(list_rinfo)

45136

In [7]:
filtered_dict = {}
for item in list_rinfo:
    key = (item['star'], item['facet'], item['pkey'])
    if key not in filtered_dict or item['nre'] < filtered_dict[key]['nre']:
        filtered_dict[key] = item
list_rinfo = list(filtered_dict.values())

In [8]:
len(list_rinfo)

11257

In [9]:
list_rinfo[0]

{'sc': 'V3Sc',
 'slab': Atoms(symbols='V3ScV3ScV3Sc', pbc=True, cell=[[5.6401896008612615, 0.0, 0.0], [-2.8200948004306308, 4.884546654335772, 0.0], [0.0, 0.0, 24.60519325876673]], calculator=SinglePointCalculator(...)),
 'star': 'V3ScV3ScV3Sc',
 'facet': '111',
 'pkey': 'N',
 'pval': 1,
 're': -2.810391181412342,
 'nre': -2.810391181412342}

In [10]:
#####################################################################
## product descriptors
#####################################################################

In [11]:
set_pkey = set()
for rinfo in list_rinfo:
    if rinfo['pkey'] not in set_pkey:
        set_pkey.add(rinfo['pkey'])
list_pkey = list(set_pkey)

In [12]:
list_pkey

['N', 'C', 'CH', 'S', 'CH3', 'H', 'SH', 'CH2', 'O', 'NH', 'OH', 'H2O']

In [13]:
d_pkey_vs_smiles = {
    ## pubchem: one entry/ multi entry/ previous datasets
    'S': '[S]',
    'O': '[O]',
    'CH': '[CH]', 
    'H': '[H]',
    'SH': '[SH]',    
    'NH': '[NH]',
    'CH3': '[CH3]',   
    'H2O': 'O',
    'C': '[C]',   
    'N': '[N]',    
    'OH': '[OH]',
    'CH2': '[CH2]', 
}

In [14]:
tokenizer = AutoTokenizer.from_pretrained("mrm8488/chEMBL26_smiles_v2")
model = AutoModel.from_pretrained("mrm8488/chEMBL26_smiles_v2")
fe = pipeline('feature-extraction', model=model, tokenizer=tokenizer, device=-1) ## device= (0 for GPU, -1 for CPU)

Some weights of the model checkpoint at mrm8488/chEMBL26_smiles_v2 were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
list_pkey= list(d_pkey_vs_smiles.keys())
list_smiles= list(d_pkey_vs_smiles.values())
list_emb_1 = fe(list_smiles)

list_emb_2 = []
for emb_1 in list_emb_1:
    emb_2 = np.mean(np.array(emb_1[0]), axis=0)
    list_emb_2.append(emb_2)
arr_emb = np.array(list_emb_2)
print(arr_emb.shape)

d_pkey_vs_desc = {}
for i, pkey in enumerate(list_pkey):
    d_pkey_vs_desc[pkey] = arr_emb[i]

(12, 768)


In [16]:
for rinfo in list_rinfo:
    rinfo['pdesc'] = d_pkey_vs_desc[rinfo['pkey']]

In [17]:
list_rinfo[0].keys()

dict_keys(['sc', 'slab', 'star', 'facet', 'pkey', 'pval', 're', 'nre', 'pdesc'])

In [18]:
#####################################################################
## slab descriptors
#####################################################################

In [19]:
try:
    config_yml_path = "ocp/configs/is2re/all/dimenet_plus_plus/dpp.yml"
    checkpoint_path = "dimenetpp_all.pt"
    # Define the calculator
    calc = OCPCalculator(config_yml=config_yml_path, checkpoint=checkpoint_path)
except Exception as e:
    print(e)
    config_yml_path = "ocp/configs/is2re/all/dimenet_plus_plus/dpp.yml"
    checkpoint_path = "dimenetpp_all.pt"
    # Define the calculator
    calc = OCPCalculator(config_yml=config_yml_path, checkpoint=checkpoint_path)    

No module named 'ocpmodels.models.'


In [21]:
calc

In [24]:
len(list_rinfo)

11257

In [25]:
list_rinfo2 = []
for i, rinfo in enumerate(list_rinfo):
    if i%500==0:
        print(i)
    try:
        ase_slab = rinfo['slab']
        adslab = ase_slab.copy()  

        ads_symbol = rinfo['pkey']    
        adsorbate = molecule(ads_symbol)
        add_adsorbate(adslab, adsorbate, 3, offset=(1, 1))

        ## set additional info
        tags = np.zeros(len(adslab))
        tags[18:27] = 1
        tags[27:] = 2
        adslab.set_tags(tags)
        cons= FixAtoms(indices=[atom.index for atom in adslab if (atom.tag == 0)])
        adslab.set_constraint(cons)
        adslab.center(vacuum=13.0, axis=2)
        adslab.set_pbc(True)

        ## calculator, energy, embeddings
        adslab.calc = calc
        ads_energy = adslab.get_potential_energy()
        min_ads_embs = torch.min(calc.trainer.model.module.embs[0], dim=0).values        
        mean_ads_embs = torch.mean(calc.trainer.model.module.embs[0], dim=0)    
        max_ads_embs = torch.max(calc.trainer.model.module.embs[0], dim=0).values    
        # sum_ads_embs = torch.sum(calc.trainer.model.module.embs[0], dim=0)                
        # print(min_ads_embs.shape, mean_ads_embs.shape, max_ads_embs.shape)
        combined_embs = torch.cat((min_ads_embs, mean_ads_embs, max_ads_embs))    
        rinfo['sdesc'] = combined_embs.numpy()
        list_rinfo2.append(rinfo)
    except Exception as e:
        pass

0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
8000
8500
9000
9500
10000
10500
11000


In [27]:
len(list_rinfo2)

10938

In [28]:
list_rinfo = list_rinfo2

In [29]:
len(list_rinfo)

10938

In [30]:
list_rinfo[-1].keys()

dict_keys(['sc', 'slab', 'star', 'facet', 'pkey', 'pval', 're', 'nre', 'pdesc', 'sdesc'])

In [31]:
sdesc_columns = [f's{i}' for i in range(768)]
pdesc_columns = [f'p{i}' for i in range(768)]
columns = sdesc_columns + pdesc_columns + ['energy']
data = []
for i, rinfo in enumerate(list_rinfo):
    try:
        sdesc_flat = rinfo['sdesc'].flatten()
        pdesc_flat = rinfo['pdesc'].flatten()
        energy = rinfo['nre']
        row = list(sdesc_flat) + list(pdesc_flat) + [energy]
        data.append(row)
    except Exception as e:
        print(i, e)
df = pd.DataFrame(data, columns=columns)

In [32]:
df.to_pickle('v4/cathub_df.pickle')
print(df.shape)
df.head(2)

(10938, 1537)


Unnamed: 0,s0,s1,s2,s3,s4,s5,s6,s7,s8,s9,...,p759,p760,p761,p762,p763,p764,p765,p766,p767,energy
0,-1.981828,-1.770542,-1.320971,-2.183548,-1.858627,-0.831343,-2.508203,-1.690223,-1.849663,-1.821207,...,0.912192,0.652321,0.835792,-2.122195,-0.916445,-1.751576,0.783765,-1.780947,-0.90128,-2.810391
1,-1.587637,-2.088126,-2.159897,-1.888337,-1.780846,-1.802099,-1.936221,-1.551155,-2.214272,-0.360621,...,0.133178,0.967194,0.530776,-0.38646,0.284255,-1.429858,1.077861,-0.831704,-0.738524,-4.468474
