In [2]:
from catemb.data import CatDataset
from catemb.utils import symbol_pos_to_xyz_file,calculate_gaussview_spin_multiplicity
from rdkit import Chem
from rdkit.Chem import AllChem
import warnings,os
import numpy as np
from tqdm import tqdm
warnings.filterwarnings("ignore")
pt = Chem.GetPeriodicTable()

In [3]:
#dataset_old = CatDataset(root="../dataset/processed",name="lig_cat_dataset",seed=42,trunc=5000,save_smiles=True)
dataset = CatDataset(root="../dataset/processed",name="lig_cat_dataset_new",seed=42,trunc=0,save_smiles=True)
#dataset_no_smiles = CatDataset(root="../dataset/processed",name="lig_cat_dataset",seed=42,trunc=0,save_smiles=False)

In [5]:
dataset[0]

Data(x=[109, 10], edge_index=[2, 230], edge_attr=[230, 5], mol_coords=[109, 3], idx=[1], smiles='CN(C)c1cc2[C@@H](c3ccc4c(c3)C3c5ccccc5C4c4ccccc43)N(C)[S@](=O)(->[Ni]<-[P](c2cc1N(C)C)(C(C)(C)C)C(C)(C)C)C(C)(C)C', E=[1])

In [10]:
dest_dir = "../dataset/processed/xtb_opt_new"
for idx in tqdm(range(len(dataset))):
    data = dataset[idx]
    _tmp_dir = f"{dest_dir}/{idx}"
    if os.path.exists(_tmp_dir):
        continue
    atom_idx_lst = data.x[:,0]
    atom_sym_lst = [pt.GetElementSymbol(int(x)+1) for x in atom_idx_lst]
    mol_coords = data.mol_coords.numpy()
    smiles = data.smiles
    mol = AllChem.AddHs(Chem.MolFromSmiles(smiles))
    dative_bonds = [[x.GetBeginAtomIdx()+1,x.GetEndAtomIdx()+1] for x in mol.GetBonds() if x.GetBondType() == Chem.BondType.DATIVE]
    assert [at.GetSymbol() for at in mol.GetAtoms()] == atom_sym_lst
    chrg = sum([atom.GetFormalCharge() for atom in mol.GetAtoms()])
    mult = calculate_gaussview_spin_multiplicity(atom_sym_lst,chrg)
    title = f"smiles {smiles} charge {chrg} multiplicity {mult} uhf {mult-1} dative_bonds {dative_bonds}"
    os.makedirs(_tmp_dir,exist_ok=True)
    symbol_pos_to_xyz_file(atom_sym_lst,mol_coords,f"{_tmp_dir}/{idx}.xyz",title=title)
    constrain_inf = ["$constrain"]
    for x in dative_bonds:
        constrain_inf.append(f"    distance: {x[0]},{x[1]},auto")
    constrain_inf.append("    force constant=1.0")
    constrain_inf.append("$end")
    with open(f"{_tmp_dir}/constrain.inp","w") as f:
        f.write("\n".join(constrain_inf))


100%|██████████| 63059/63059 [05:16<00:00, 199.15it/s]


In [8]:
smiles

'CN(C)c1cc2[C@@H](c3ccc4c(c3)C3c5ccccc5C4c4ccccc43)N(C)[S@](=O)(->[Ni]<-[P](c2cc1N(C)C)(C(C)(C)C)C(C)(C)C)C(C)(C)C'

In [21]:
for idx in range(len(dataset)):
    diff = (dataset[idx].mol_coords-dataset_no_smiles[idx].mol_coords).abs().mean()
    if diff > 1e-3:
        print(idx)

In [22]:
for idx in range(len(dataset_old)):
    diff = (dataset[idx].mol_coords-dataset_old[idx].mol_coords).abs().mean()
    if diff > 1e-3:
        print(idx)

In [23]:
dest_dir = "../dataset/processed/xtb_opt"
for idx,data in enumerate(dataset):
    _tmp_dir = f"{dest_dir}/{idx}"
    if os.path.exists(_tmp_dir):
        continue
    atom_idx_lst = data.x[:,0]
    atom_sym_lst = [pt.GetElementSymbol(int(x)+1) for x in atom_idx_lst]
    mol_coords = data.mol_coords.numpy()
    smiles = data.smiles
    mol = AllChem.AddHs(Chem.MolFromSmiles(smiles))
    dative_bonds = [[x.GetBeginAtomIdx()+1,x.GetEndAtomIdx()+1] for x in mol.GetBonds() if x.GetBondType() == Chem.BondType.DATIVE]
    assert [at.GetSymbol() for at in mol.GetAtoms()] == atom_sym_lst
    
    mult = calculate_gaussview_spin_multiplicity(atom_sym_lst,0)
    title = f"smiles {smiles} charge 0 multiplicity {mult} dative_bonds {dative_bonds}"
        
    
    os.makedirs(_tmp_dir,exist_ok=True)
    symbol_pos_to_xyz_file(atom_sym_lst,mol_coords,f"{_tmp_dir}/{idx}.xyz",title=title)
    constrain_inf = ["$constrain"]
    for x in dative_bonds:
        constrain_inf.append(f"    distance: {x[0]},{x[1]},auto")
    constrain_inf.append("    force constant=1.0")
    constrain_inf.append("$end")
    with open(f"{_tmp_dir}/constrain.inp","w") as f:
        f.write("\n".join(constrain_inf))

    

In [24]:

for idx,data in enumerate(dataset):

    atom_idx_lst = data.x[:,0]
    atom_sym_lst = [pt.GetElementSymbol(int(x)+1) for x in atom_idx_lst]
    mol_coords = data.mol_coords.numpy()
    smiles = data.smiles
    mol = AllChem.AddHs(Chem.MolFromSmiles(smiles))
    dative_bonds = [[x.GetBeginAtomIdx()+1,x.GetEndAtomIdx()+1] for x in mol.GetBonds() if x.GetBondType() == Chem.BondType.DATIVE]
    assert [at.GetSymbol() for at in mol.GetAtoms()] == atom_sym_lst
    
    mult = calculate_gaussview_spin_multiplicity(atom_sym_lst,0)
    if len(dative_bonds) == 0:
        print(idx)

    

11
53
61
88
94
100
115
118
208
211
230
234
236
243
251
283
290
302
315
332
334
353
382
393
418
429
434
460
491
508
510
539
551
573
577
580
585
592
634
636
638
660
670
671
719
732
735
738
764
771
773
774
782
794
827
828
847
865
891
900
906
914
919
924
933
967
975
980
1032
1038
1046
1099
1132
1187
1189
1193
1200
1216
1224
1228
1232
1260
1269
1274
1277
1283
1314
1339
1364
1390
1394
1416
1445
1497
1502
1514
1515
1547
1548
1553
1565
1570
1580
1588
1590
1614
1632
1651
1705
1745
1756
1758
1787
1790
1823
1836
1849
1881
1887
1904
1907
1926
1951
1957
1960
1970
1973
1978
1980
1994
2021
2089
2094
2105
2118
2120
2126
2131
2175
2214
2279
2285
2313
2325
2339
2341
2343
2361
2364
2416
2419
2424
2443
2448
2449
2471
2474
2485
2489
2490
2491
2492
2498
2515
2544
2555
2560
2571
2586
2587
2609
2613
2642
2657
2665
2670
2688
2695
2699
2702
2707
2712
2716
2752
2764
2786
2804
2816
2819
2821
2864
2867
2902
2905
2923
2929
2947
2966
2968
2973
2974
2978
2981
2991
2996
2997
2998
3022
3026
3048
3066
3094
3111
3123
314

KeyboardInterrupt: 

### Dev xtb opt code

In [60]:
import os
from subprocess import run,PIPE

tgt_dir = "/inspire/ssd/tenant_predefaa-9a1b-4522-bb10-8850f313be13/global_user/8359-xulicheng/CatEmb/dataset/processed/xtb_opt"
start_idx = 0
end_idx = 2000


xyz_folders = sorted(os.listdir(tgt_dir),key=lambda x:int(x))
for idx in range(start_idx,end_idx):
    xyz_folder = xyz_folders[idx]
    xyz_folder = os.path.join(tgt_dir,xyz_folder)
    os.chdir(xyz_folder)
    cmd = f"xtb {idx}.xyz --opt --input constrain.inp > xtboptlog.out"
    run(cmd,stdout=PIPE,stderr=PIPE,universal_newlines=True,cwd=None,shell=True,executable='/bin/bash',check=False)

In [61]:
xyz_folders

['0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '20',
 '21',
 '22',
 '23',
 '24',
 '25',
 '26',
 '27',
 '28',
 '29',
 '30',
 '31',
 '32',
 '33',
 '34',
 '35',
 '36',
 '37',
 '38',
 '39',
 '40',
 '41',
 '42',
 '43',
 '44',
 '45',
 '46',
 '47',
 '48',
 '49',
 '50',
 '51',
 '52',
 '53',
 '54',
 '55',
 '56',
 '57',
 '58',
 '59',
 '60',
 '61',
 '62',
 '63',
 '64',
 '65',
 '66',
 '67',
 '68',
 '69',
 '70',
 '71',
 '72',
 '73',
 '74',
 '75',
 '76',
 '77',
 '78',
 '79',
 '80',
 '81',
 '82',
 '83',
 '84',
 '85',
 '86',
 '87',
 '88',
 '89',
 '90',
 '91',
 '92',
 '93',
 '94',
 '95',
 '96',
 '97',
 '98',
 '99',
 '100',
 '101',
 '102',
 '103',
 '104',
 '105',
 '106',
 '107',
 '108',
 '109',
 '110',
 '111',
 '112',
 '113',
 '114',
 '115',
 '116',
 '117',
 '118',
 '119',
 '120',
 '121',
 '122',
 '123',
 '124',
 '125',
 '126',
 '127',
 '128',
 '129',
 '130',
 '131',
 '132',
 '133',
 '134',
 '135',
 '136',
 '137',
 '138'