In [2]:
import numpy as np
from rdkit import Chem
from units.data import MultiDatasetV2


In [69]:
# Load the raw library. allow_pickle=True is needed because the file stores
# Python objects (each entry is a dict); NOTE: unpickling executes arbitrary
# code, so only load this file from a trusted source.
dataset = np.load(
    '../dataset/UniTS_Lib.npy',
    allow_pickle=True,
)
# Show which fields a single record provides.
dataset[0].keys()

dict_keys(['atoms', 'inpt_orien', 'std_orien', 'forces', 'energy', 'freqs', 'modes', 'TCH', 'TCG', 'chrg', 'mult', 'rdmol', 'rct_idx', 'graph_feat'])

In [95]:
# Re-pack every raw record (a dict) into a positional six-item list:
#   [atoms, inpt_orien, graph_feat, rdmol, fragment atom indices, rct_idx]
# Chem.GetMolFrags returns, for each disconnected fragment of the molecule,
# a tuple with the indices of the atoms belonging to it.
processed_dataset = [
    [
        record['atoms'],
        record['inpt_orien'],
        record['graph_feat'],
        record['rdmol'],
        Chem.GetMolFrags(record['rdmol']),
        record['rct_idx'],
    ]
    for record in dataset
]

# dtype=object keeps the heterogeneous per-record items as Python objects.
packed = np.array(processed_dataset, dtype=object)
np.save("../dataset/units_lib_0.npy", packed)

In [76]:
# Build the dataset object from the re-packed file saved under ../dataset.
# (`name_regrex` is the keyword as spelled by MultiDatasetV2's signature.)
pyg_dataset = MultiDatasetV2(
    root='../dataset',
    name_regrex='units_lib_0.npy',
)

Processing...
Done!


## Build the out-of-sample (OOS) test sets

In [8]:
from units.utils import atoms_to_formula
from rdkit import Chem
# Reload the re-packed library; each row is
# [atoms, coord, graph_feat, rdmol, blk_idxs, reacting_atoms].
# NOTE: allow_pickle=True unpickles Python objects -- trusted files only.
dataset = np.load(
    '../dataset/units_lib_0.npy',
    allow_pickle=True,
)
# Periodic table, used below to look up element symbols for the atoms entries.
pt = Chem.GetPeriodicTable()

### Formula-OOS

In [12]:
# Formula string of every entry, in dataset order (item 0 of a row is `atoms`).
tot_formula = [
    atoms_to_formula([pt.GetElementSymbol(z) for z in entry[0]])
    for entry in dataset
]

# How many entries share each formula.
formula_ct_map = {}
for name in tot_formula:
    if name in formula_ct_map:
        formula_ct_map[name] += 1
    else:
        formula_ct_map[name] = 1

# Sort first so the seeded shuffle is reproducible regardless of dataset order.
formula_keys = sorted(formula_ct_map)
np.random.seed(42)
np.random.shuffle(formula_keys)

# Take whole formulas, in shuffled order, until at least 400 entries are
# covered; the final count can overshoot because a formula is never split.
sel_formula_keys = []
sel_data_ct = 0
for candidate in formula_keys:
    sel_formula_keys.append(candidate)
    sel_data_ct += formula_ct_map[candidate]
    if not sel_data_ct < 400:
        break
print(len(sel_formula_keys),sel_data_ct)

152 403


In [13]:
# Split entry indices by formula: every entry whose formula was selected goes
# to the test set, everything else to train/val.
# A set gives O(1) membership checks (the list lookup was O(len(list)) per entry).
selected_formulas = set(sel_formula_keys)
test_data_idx = [
    i for i, formula in enumerate(tot_formula) if formula in selected_formulas
]
print(len(test_data_idx))

# Build the complement in explicit ascending order instead of relying on set
# iteration order, so the row order of the saved train/val file is deterministic.
test_idx_lookup = set(test_data_idx)
train_val_idx = [i for i in range(len(tot_formula)) if i not in test_idx_lookup]
print(len(train_val_idx))

403
3988


In [14]:
# Materialise both partitions by integer-list (fancy) indexing of the array.
test_ts_data = dataset[test_data_idx]
train_val_ts_data = dataset[train_val_idx]

# File names encode <total entries>_<entries in this file>_0.
trainval_path = f"../dataset/formula_oos_trainval_{len(train_val_ts_data)+len(test_ts_data)}_{len(train_val_ts_data)}_0.npy"
test_path = f"../dataset/formula_oos_test_{len(train_val_ts_data)+len(test_ts_data)}_{len(test_ts_data)}_0.npy"

np.save(trainval_path, train_val_ts_data)
np.save(test_path, test_ts_data)

### Element-OOS

In [15]:
# Held-out elements, all from the 6th period. This is a subset of the period:
# the lanthanides after La (Ce-Lu) and Bi, Po, At, Rn are not listed --
# presumably because they do not occur in the library; TODO confirm.
test_elements = ["Cs","Ba","La","Hf","Ta","W","Re","Os","Ir","Pt","Au","Hg","Tl","Pb"]
held_out_symbols = set(test_elements)

# An entry belongs to the test set as soon as any of its atoms is a held-out
# element (item 0 of a row is `atoms`; symbols are looked up via the periodic table).
test_data_idx = [
    idx
    for idx, data in enumerate(dataset)
    if any(pt.GetElementSymbol(at) in held_out_symbols for at in data[0])
]
print(len(test_data_idx))

# Size the index range from `dataset` itself rather than from `tot_formula`,
# which is only defined by the Formula-OOS section. Ascending order keeps the
# row order of the saved train/val file deterministic.
test_idx_lookup = set(test_data_idx)
train_val_idx = [idx for idx in range(len(dataset)) if idx not in test_idx_lookup]
print(len(train_val_idx))

351
4040


In [16]:
# Materialise both partitions by integer-list (fancy) indexing of the array.
test_ts_data = dataset[test_data_idx]
train_val_ts_data = dataset[train_val_idx]

# File names encode <total entries>_<entries in this file>_0.
trainval_path = f"../dataset/element_oos_trainval_{len(train_val_ts_data)+len(test_ts_data)}_{len(train_val_ts_data)}_0.npy"
test_path = f"../dataset/element_oos_test_{len(train_val_ts_data)+len(test_ts_data)}_{len(test_ts_data)}_0.npy"

np.save(trainval_path, train_val_ts_data)
np.save(test_path, test_ts_data)

### Maximum-OOS

In [17]:
# Size-based hold-out: entries with strictly more than `max_atom` atoms form
# the test set (item 0 of a row is `atoms`).
max_atom = 120
test_data_idx = [idx for idx, data in enumerate(dataset) if len(data[0]) > max_atom]
print(len(test_data_idx))

# Size the index range from `dataset` itself rather than from `tot_formula`,
# which is only defined by the Formula-OOS section. Ascending order keeps the
# row order of the saved train/val file deterministic.
test_idx_lookup = set(test_data_idx)
train_val_idx = [idx for idx in range(len(dataset)) if idx not in test_idx_lookup]
print(len(train_val_idx))

219
4172


In [18]:
# Materialise both partitions by integer-list (fancy) indexing of the array.
test_ts_data = dataset[test_data_idx]
train_val_ts_data = dataset[train_val_idx]

# File names encode <total entries>_<entries in this file>_0.
trainval_path = f"../dataset/maximum_oos_trainval_{len(train_val_ts_data)+len(test_ts_data)}_{len(train_val_ts_data)}_0.npy"
test_path = f"../dataset/maximum_oos_test_{len(train_val_ts_data)+len(test_ts_data)}_{len(test_ts_data)}_0.npy"

np.save(trainval_path, train_val_ts_data)
np.save(test_path, test_ts_data)