In [1]:
## 1. Extract all slabs from OCP and CatHub (Mamun).
## 2. Generate slab descriptors with DimeNet++.
## 3. Extract all products from OCP and CatHub (Mamun).
## 4. Generate product descriptors with a ChEMBL-pretrained SMILES model.
## 5. Run experiments:
##    - multitask learner -> (CatHub, OCP) x (XGBoost) x (original 1024+1024, PCA n_components, IMR n_components)
##      (note: this notebook produces 768 slab + 768 product descriptor dimensions)
##    - reduces the problem from n^2 to n, and solves descriptor generation for slabs/surfaces

In [2]:
import json
import pickle
import numpy as np
from ocpmodels.datasets import SinglePointLmdbDataset 
import os
import ase.io
from ase.constraints import FixAtoms
from ase.build import add_adsorbate, molecule, surface
from pymatgen.ext.matproj import MPRester
from pymatgen.core.surface import generate_all_slabs, SlabGenerator
from pymatgen.io.ase import AseAtomsAdaptor
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
from pprint import pprint
import pubchempy as pcp
from transformers import AutoTokenizer, AutoModel, pipeline
import pandas as pd
import ray
import torch
import random
from concurrent.futures import ThreadPoolExecutor, as_completed

2024-06-20 11:14:13,697	INFO util.py:159 -- Outdated packages:
  ipywidgets==7.6.5 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [3]:
# Raw per-adsorption-configuration records are split across the pickled shards
# listed below; concatenate them, in listed order, into one flat list.
folder_path = 'datasets/slabs/'

file_names = [
    "list_rinfo_0to100000.pkl",
    "list_rinfo_100000to200000.pkl",
    "list_rinfo_200000to300000.pkl",
    "list_rinfo_300000to400000.pkl",
    "list_rinfo_400000to500000.pkl"
]


def _load_shard(shard_name):
    """Unpickle one shard file from ``folder_path`` and return its records.

    NOTE: pickle.load can execute arbitrary code — only load trusted shard files.
    """
    with open(os.path.join(folder_path, shard_name), 'rb') as shard_fh:
        return pickle.load(shard_fh)


list_rinfo_all = []
for shard_name in file_names:
    list_rinfo_all.extend(_load_shard(shard_name))

In [4]:
# Total number of records concatenated from all shards (displayed as cell output)
len(list_rinfo_all)

443368

In [5]:
# list_rinfo_all[0]

In [6]:
# Product key = adsorbate formula with the '*' markers removed (e.g. '*OH' -> 'OH');
# stored on each record in place. Used below to look up the product SMILES.
for record in list_rinfo_all:
    record['pkey'] = record['ads_symbols'].replace('*', '')

In [7]:
# Distinct product keys present in the data. Sorted so the order is reproducible:
# iterating a bare set of strings follows hash order, which changes between runs
# under Python's string-hash randomization.
list_pkey = sorted({rec['pkey'] for rec in list_rinfo_all})

len(list_pkey)

68

In [8]:
# Hand-curated mapping: product key (adsorbate formula with '*' removed) -> SMILES
# string. The SMILES values are embedded by the ChEMBL feature-extraction pipeline
# in a later cell to produce each product descriptor.
# NOTE(review): many keys map to the *same* SMILES and therefore receive identical
# product descriptors, e.g. 'CH2O'/'CHOH' -> 'C=O';
# 'CH2CO'/'CHCOH'/'CCHOH'/'CHCHO' -> 'C=C=O'; 'CH2OH'/'OCH3' -> 'C[O-]';
# 'COHCHO'/'COHCOH'/'CHOCHO'/'COCH2O' -> 'C(=O)C=O'. Several entries also use
# charged species (e.g. 'OH' -> '[OH-]', 'H' -> '[H+]', 'CH3' -> '[CH3+]') —
# presumably as stand-ins for surface-bound fragments. Confirm this collapsing is
# acceptable for the downstream models.
d_pkey_vs_smiles = {
    'CH2CO': 'C=C=O',
    'COHCHO': 'C(=O)C=O',
    'OH2': 'O',
    'CH2O': 'C=O',
    'COCHO': 'C(=C=O)[O-]',
    'CHOH': 'C=O',
    'ONN(CH3)2': 'ONN(C)(C)',
    'OHCH2CH3': 'CCO',
    'NONH': 'N=N[O-]',
    'NO2': 'N(=O)[O-]',
    'CCHO': 'C#C[O-]',
    'OHCH3': 'CO',
    'NH': '[NH]',
    'COHCOH': 'C(=O)C=O',
    'OCH2CH3': 'CC[O-]',
    'NNO': '[N-]=[N+]=O',
    'CHCH2OH': 'C1CO1',
    'CHOCH2OH': 'CC(=O)O',
    'CHCOH': 'C=C=O',
    'NO': '[N]=O',
    'CHOHCHOH': 'CC(=O)O',
    'CH2OH': 'C[O-]',
    'NNH': '[NH+]#N',
    'CCH2': 'C#C',
    'OCH3': 'C[O-]',
    'CCH3': 'C=[CH]',
    'CH2CH3': 'C[CH2+]',
    'CHOHCH2': 'C1CO1',
    'COHCH2OH': 'CC(=O)O',
    'OCH2CHOH': 'CC(=O)O',
    'CN': '[C-]#N',
    'CHOHCH3': 'CC[O-]',
    'CCH2OH': 'C[C]=O',
    'CHOCHO': 'C(=O)C=O',
    'CHCH': 'C#C',
    'CH4': 'C',
    'COCH2O': 'C(=O)C=O',
    'CHOCHOH': 'CC(=O)[O-]',
    'CCH': 'C#[C-]',
    'CHOHCH2OH': 'C([CH]O)O',
    'N2': 'N#N',
    'ONNH2': 'NN=O',
    'N': '[N]',
    'OCHCH3': 'C1CO1',
    'C': '[C]',
    'ONH': 'N=O',
    'CHCO': 'C#C[O-]',
    'CCO': 'C1#CO1',
    'CH2CH2OH': 'CC[O-]',
    'NO3': '[N+](=O)([O-])[O-]',
    'O': '[O]',
    'NO2NO2': '[N+](=O)([N+](=O)[O-])[O-]',
    'CH2': '[CH2]',
    'CHCH2': 'C=[CH]',
    'OH': '[OH-]',
    'CC': '[C-]#[C+]',
    'NHNH': 'N=N',
    'H': '[H+]',
    'COHCHOH': 'CC(=O)[O-]',
    'CH3': '[CH3+]',
    'NH3': 'N',
    'OHNNCH3': 'C(=O)(N)N',
    'CHCHOH': 'C[C]=O',
    'COCH3': 'C[C]=O',
    'CCHOH': 'C=C=O',
    'OHNH2': 'NO',
    'COHCH3': 'C1CO1',
    'CHCHO': 'C=C=O'
}

In [9]:
# Pretrained ChEMBL26 SMILES model (name kept in one place), wrapped in a
# Hugging Face feature-extraction pipeline used to embed product SMILES.
chembl_checkpoint = "mrm8488/chEMBL26_smiles_v2"
tokenizer = AutoTokenizer.from_pretrained(chembl_checkpoint)
model = AutoModel.from_pretrained(chembl_checkpoint)
fe = pipeline(
    'feature-extraction',
    model=model,
    tokenizer=tokenizer,
    device=-1,  # -1 = CPU, 0 = first GPU
)

Some weights of the model checkpoint at mrm8488/chEMBL26_smiles_v2 were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
# Embed every product SMILES once, then mean-pool to one vector per product key.
# NOTE: this rebinds list_pkey to the mapping's key order (the data-derived list
# from the earlier cell is discarded).
list_pkey = list(d_pkey_vs_smiles)
list_smiles = [d_pkey_vs_smiles[key] for key in list_pkey]
list_emb_1 = fe(list_smiles)

# Element [0] of each pipeline output is a 2-D (positions x hidden) array here —
# presumably the single-sequence batch axis; averaging over axis 0 pools all
# positions, including any special tokens the tokenizer adds.
list_emb_2 = [np.asarray(token_embs[0]).mean(axis=0) for token_embs in list_emb_1]
arr_emb = np.array(list_emb_2)
print(arr_emb.shape)

d_pkey_vs_desc = {key: arr_emb[row] for row, key in enumerate(list_pkey)}

(68, 768)


In [11]:
# Attach each record's product descriptor by product key. Coverage is checked
# first so a missing SMILES mapping fails loudly *before* any record is mutated
# (previously a KeyError mid-loop left the records half-updated).
missing_pkeys = {rinfo['pkey'] for rinfo in list_rinfo_all} - set(d_pkey_vs_desc)
if missing_pkeys:
    raise KeyError(f"no product descriptor for pkey(s): {sorted(missing_pkeys)}")

# Records with the same pkey share (do not copy) the same descriptor array.
for rinfo in list_rinfo_all:
    rinfo['pdesc'] = d_pkey_vs_desc[rinfo['pkey']]

In [12]:
# Inspect the fields of one record after adding 'pkey' and 'pdesc'
list_rinfo_all[0].keys()

dict_keys(['bulk_id', 'ads_id', 'bulk_mpid', 'bulk_symbols', 'ads_symbols', 'miller_index', 'shift', 'top', 'adsorption_site', 'class', 'anomaly', 'adslab_slab_key', 'energy', 'slab', 'pkey', 'pdesc'])

In [13]:
# list_rinfo_all[0]

In [14]:
# Materials Project client (currently disabled). Never hardcode the API key in the
# notebook; read it from the environment instead. The key previously inlined here
# has been exposed in the notebook/history and should be revoked.
# m = MPRester(os.environ["MP_API_KEY"])

In [15]:
# DimeNet++ (IS2RE, "all" split) checkpoint used as the slab featurizer.
config_yml_path = "ocp/configs/is2re/all/dimenet_plus_plus/dpp.yml"
checkpoint_path = "dimenetpp_all.pt"

# The first construction has been seen to fail (logged as
# "No module named 'ocpmodels.models.'") and succeed on the second try, so
# retry exactly once; the error from a second failure propagates.
_calc_max_attempts = 2
for _calc_attempt in range(1, _calc_max_attempts + 1):
    try:
        calc = OCPCalculator(config_yml=config_yml_path, checkpoint=checkpoint_path)
        break
    except Exception as e:
        if _calc_attempt == _calc_max_attempts:
            raise
        print(e)

No module named 'ocpmodels.models.'


In [16]:
# Confirm the calculator object was constructed
calc

<ocpmodels.common.relaxation.ase_utils.OCPCalculator at 0x7fa7446c7ca0>

In [17]:
# Sanity check: convert the first record's pymatgen slab to ASE Atoms and display it
AseAtomsAdaptor.get_atoms(list_rinfo_all[0]['slab'])

Atoms(symbols='Hf12Ge8Hf12Ge8', pbc=True, cell=[[3.70111906, 0.0, 2.2662818050461305e-16], [2.2836646906400526e-15, 14.20081168, -7.10040584], [0.0, 0.0, 28.40162336]], bulk_equivalent=..., bulk_wyckoff=..., initial_magmoms=...)

In [18]:
###############################
## work-in-progress
###############################

In [19]:
# Optional cap on how many records the featurization loop below processes.
# None = use every record; set an int for a quick subset run.
MAX_SAMPLES = None

num_sample = len(list_rinfo_all) if MAX_SAMPLES is None else min(MAX_SAMPLES, len(list_rinfo_all))
print(num_sample)
# Only slice (which copies the list) when a cap actually removes records.
if num_sample < len(list_rinfo_all):
    list_rinfo_all = list_rinfo_all[:num_sample]

443368


In [20]:
# DISABLED: Ray-parallel variant of the sequential featurization loop in the next
# cell, kept for reference only (the `ray` import at the top is used only here).
# NOTE(review): calc_id / aaa_id are created with ray.put but never used — the
# tasks are submitted with `calc` and `aaa` directly, so the calculator would be
# re-serialized for every task. Pass the object refs instead if re-enabling.
# NOTE(review): embeddings are read from calc.trainer.model.module.embs after the
# forward pass (mutable model state); confirm this is safe per Ray worker.
# ray.shutdown()
# ray.init(num_cpus=52)

# calc_id = ray.put(calc)
# aaa = AseAtomsAdaptor()
# aaa_id = ray.put(aaa)


# @ray.remote
# def process_rinfo(r, calc, aaa):
#     try:
#         rinfo = {}
#         ## retrieve slab and adsorbate
#         slab = r['slab']
#         ase_slab = aaa.get_atoms(slab)
#         adslab = ase_slab.copy()
        
#         ## set additional info
#         tags = np.zeros(len(adslab))
#         tags[18:27] = 1
#         tags[27:] = 2
#         adslab.set_tags(tags)
#         cons = FixAtoms(indices=[atom.index for atom in adslab if (atom.tag == 0)])
#         adslab.set_constraint(cons)
#         adslab.center(vacuum=13.0, axis=2)
#         adslab.set_pbc(True)

#         ## Use the shared calculator directly
#         adslab.calc = calc

#         ## calculator, energy, embeddings
#         ads_energy = adslab.get_potential_energy()
        
#         ## Embeddings
#         min_ads_embs = torch.min(calc.trainer.model.module.embs[0], dim=0).values        
#         mean_ads_embs = torch.mean(calc.trainer.model.module.embs[0], dim=0)    
#         max_ads_embs = torch.max(calc.trainer.model.module.embs[0], dim=0).values    
#         combined_embs = torch.cat((min_ads_embs, mean_ads_embs, max_ads_embs))    
        
#         rinfo['sdesc'] = combined_embs.numpy()    
#         rinfo['pdesc'] = r['pdesc']
#         rinfo['energy'] = r['energy']
        
#         return rinfo
        
#     except Exception as e:
#         print(e)
#         return None

# def parallel_processing_ray(list_rinfo_all, calc, aaa, num_sample):
#     futures = [process_rinfo.remote(r, calc, aaa) for r in list_rinfo_all[:num_sample]]
#     list_rinfo = ray.get(futures)
#     list_rinfo = [rinfo for rinfo in list_rinfo if rinfo is not None]
#     return list_rinfo

# list_rinfo = parallel_processing_ray(list_rinfo_all, calc, aaa, num_sample)

# ray.shutdown()

# print(len(list_rinfo))

In [21]:
# Featurize each slab with the DimeNet++ calculator. One forward pass per structure,
# then the model's per-row embeddings are pooled (min | mean | max along dim 0) into
# a fixed-size slab descriptor 'sdesc'. That descriptor is stored with the record's
# product descriptor and target energy.

# NOTE(review): hard-coded atom-index tag boundaries:
#   atoms [0, 18)  -> tag 0 (fixed)
#   atoms [18, 27) -> tag 1
#   atoms [27, ..) -> tag 2
# These only fit structures whose atom ordering follows that layout. For example,
# the first record's slab has 40 Hf/Ge atoms, so atoms 27-39 get tag 2.
# Confirm against how 'slab' was built.
SURFACE_TAG_START = 18
ADSORBATE_TAG_START = 27

list_rinfo = []
failed_indices = []  # indices whose featurization raised (logged and skipped)
aaa = AseAtomsAdaptor()
for ir, r in enumerate(list_rinfo_all[:num_sample]):
    if ir % 1000 == 0:
        print(ir)
    try:
        ## Optional filter (disabled): keep only records with anomaly == 0
        # if int(r['anomaly']) != 0:
        #     continue

        ## pymatgen slab -> independent ASE Atoms copy
        adslab = aaa.get_atoms(r['slab']).copy()

        ## tags, constraints (every tag-0 atom is fixed), vacuum along z, PBC
        tags = np.zeros(len(adslab))
        tags[SURFACE_TAG_START:ADSORBATE_TAG_START] = 1
        tags[ADSORBATE_TAG_START:] = 2
        adslab.set_tags(tags)
        adslab.set_constraint(FixAtoms(indices=[atom.index for atom in adslab if atom.tag == 0]))
        adslab.center(vacuum=13.0, axis=2)
        adslab.set_pbc(True)

        ## The energy value is not stored. The call runs the model, and the
        ## embeddings are read from model state right afterwards — presumably
        ## `embs` is a custom attribute set during the forward pass; confirm.
        adslab.calc = calc
        ads_energy = adslab.get_potential_energy()
        node_embs = calc.trainer.model.module.embs[0]
        combined_embs = torch.cat((
            torch.min(node_embs, dim=0).values,
            torch.mean(node_embs, dim=0),
            torch.max(node_embs, dim=0).values,
        ))

        list_rinfo.append({
            # detach/cpu so .numpy() also works for grad-tracking or GPU tensors
            'sdesc': combined_embs.detach().cpu().numpy(),
            'pdesc': r['pdesc'],
            'energy': r['energy'],
        })

    except Exception as e:
        # best-effort: log and skip this record, but keep track of it
        failed_indices.append(ir)
        print(ir, e)

print(f"featurized {len(list_rinfo)} records; {len(failed_indices)} failed")

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
55000
56000
57000
58000
59000
60000
61000
62000
63000
64000
65000
66000
67000
68000
69000
70000
71000
72000
73000
74000
75000
76000
77000
78000
79000
80000
81000
82000
83000
84000
85000
86000
87000
88000
89000
90000
91000
92000
93000
94000
95000
96000
97000
98000
99000
100000
101000
102000
103000
104000
105000
106000
107000
108000
109000
110000
111000
112000
113000
114000
115000
116000
117000
118000
119000
120000
121000
122000
123000
124000
125000
126000
127000
128000
129000
130000
131000
132000
133000
134000
135000
136000
137000
138000
139000
140000
141000
142000
143000
144000
145000
146000
147000
148000
149000
150000
151000
152000
153000
154000
155000
156000
157000
158000


In [22]:
# Persist the featurized records so the featurization loop need not be re-run.
output_path = 'v3/ocp_list.pickle'
# Create the parent directory if needed; otherwise open() fails when it is missing.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(output_path, 'wb') as handle:
    pickle.dump(list_rinfo, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [23]:
# Read the pickle file into a list
with open('v3/ocp_list.pickle', 'rb') as handle:
    list_rinfo_all = pickle.load(handle)

In [24]:
# Number of featurized records reloaded from disk
len(list_rinfo_all)

443368

In [25]:
# Fields of a reloaded (featurized) record
list_rinfo_all[0].keys()

dict_keys(['sdesc', 'pdesc', 'energy'])

In [27]:
# list_rinfo_all[0]

In [28]:
def _records_to_frame(records, default_dim=768):
    """Flatten descriptor records into one wide numeric DataFrame.

    Each record provides 'sdesc' and 'pdesc' (array-likes, flattened to 1-D)
    and 'energy'. Columns are s0..s{n-1}, p0..p{m-1}, energy. n and m are
    taken from the first usable record (``default_dim`` if there is none)
    rather than hard-coded.

    Records that raise while being read are reported as "<index> <error>" and
    skipped. Descriptors of unequal length raise ValueError from np.vstack.
    """
    s_rows, p_rows, energies = [], [], []
    for i, rec in enumerate(records):
        try:
            s_vec = np.asarray(rec['sdesc']).ravel()
            p_vec = np.asarray(rec['pdesc']).ravel()
            energy = rec['energy']
        except Exception as e:
            print(i, e)
            continue
        s_rows.append(s_vec)
        p_rows.append(p_vec)
        energies.append(energy)

    n_s = len(s_rows[0]) if s_rows else default_dim
    n_p = len(p_rows[0]) if p_rows else default_dim
    columns = [f's{i}' for i in range(n_s)] + [f'p{i}' for i in range(n_p)] + ['energy']
    if not s_rows:
        return pd.DataFrame(columns=columns)

    # One dense float block, instead of a Python list of per-element scalars
    # for every row (prohibitively memory-hungry for ~443k x 1537 values).
    values = np.hstack([
        np.vstack(s_rows),
        np.vstack(p_rows),
        np.asarray(energies, dtype=float).reshape(-1, 1),
    ])
    return pd.DataFrame(values, columns=columns)


df = _records_to_frame(list_rinfo_all)

In [29]:
# Cache the feature table (s* slab columns, p* product columns, energy target),
# then show its shape and first rows.
df.to_pickle('v3/ocp_df.pickle')
print(df.shape)
df.head(2)

(443368, 1537)


Unnamed: 0,s0,s1,s2,s3,s4,s5,s6,s7,s8,s9,...,p759,p760,p761,p762,p763,p764,p765,p766,p767,energy
0,-2.15989,-1.966015,-2.367016,-2.370347,-1.434321,-1.697237,-2.534137,-2.047184,-2.312323,-1.361503,...,0.273536,0.742442,0.61816,0.486411,-0.93435,0.204278,0.679431,0.341144,-0.852883,-9.992999
1,-2.319036,-2.295187,-2.499794,-2.311768,-2.493567,-1.913282,-1.957203,-1.638834,-2.300856,-1.837182,...,-0.456795,0.304594,1.134157,-0.394405,-1.171003,-0.648138,-0.02155,-1.443121,0.004959,-9.982733
