In [4]:
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.metrics import mean_absolute_error,r2_score
from sklearn.model_selection import train_test_split

## One-hot

## MFF

In [1]:
import numpy as np
import pandas as pd
from rdkit import RDLogger, Chem
from rdkit.Chem import rdFingerprintGenerator, AllChem, Descriptors
from rdkit.ML.Descriptors import MoleculeDescriptors
from rdkit.Avalon.pyAvalonTools import GetAvalonFP, GetAvalonCountFP
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, HistGradientBoostingRegressor, StackingRegressor, RandomForestClassifier
from sklearn.metrics import r2_score
import seaborn as sns
from tqdm.auto import tqdm
# Notebook-wide setup: enable tqdm's pandas integration and silence RDKit's
# 'rdApp.*' log messages (e.g. per-molecule parse warnings).
tqdm.pandas()
RDLogger.DisableLog('rdApp.*')

# NOTE(review): `randomseed` is not referenced by any cell shown here; the
# models below use `random_seed`. Confirm it is unused elsewhere before removing.
randomseed = 3407

# Molecular featurizers. Each maps an RDKit Mol to a 1-D numpy array.
# Name pattern: <family>[_<maxPath or radius>]_<length>. The rdFingerprintGenerator
# featurizers return hashed count vectors. MACCS, layered and Morgan-feature
# featurizers return 0/1 bit vectors. Avalon uses its count variant.


def _count_fp(generator):
    """Wrap an rdFingerprintGenerator generator as a ``mol -> count array`` featurizer.

    The generator is built once by the caller and reused for every molecule,
    instead of being reconstructed on each call.
    """
    return lambda mol: generator.GetCountFingerprintAsNumPy(mol=mol)


avalon_256 = lambda mol: np.array(GetAvalonCountFP(mol, nBits=256).ToList())
atompair_256 = _count_fp(rdFingerprintGenerator.GetAtomPairGenerator(fpSize=256))
topo_256 = _count_fp(rdFingerprintGenerator.GetTopologicalTorsionGenerator(fpSize=256))
maccs_167 = lambda mol: np.array(AllChem.GetMACCSKeysFingerprint(mol).ToList())
# RDKit path fingerprints (branched paths) up to maxPath bonds.
rdkit_2_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=2, fpSize=256))
rdkit_4_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=4, fpSize=256))
rdkit_6_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=6, fpSize=256))
rdkit_8_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=8, fpSize=256))
# Same, restricted to linear (unbranched) paths.
rdkit_linear_2_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=2, branchedPaths=False, fpSize=256))
rdkit_linear_4_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=4, branchedPaths=False, fpSize=256))
rdkit_linear_6_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=6, branchedPaths=False, fpSize=256))
rdkit_linear_8_256 = _count_fp(rdFingerprintGenerator.GetRDKitFPGenerator(maxPath=8, branchedPaths=False, fpSize=256))
rdkit_layered_2_256 = lambda mol: np.array(Chem.LayeredFingerprint(mol, maxPath=2, fpSize=256).ToList())
rdkit_layered_4_256 = lambda mol: np.array(Chem.LayeredFingerprint(mol, maxPath=4, fpSize=256).ToList())
rdkit_layered_6_256 = lambda mol: np.array(Chem.LayeredFingerprint(mol, maxPath=6, fpSize=256).ToList())
rdkit_layered_8_256 = lambda mol: np.array(Chem.LayeredFingerprint(mol, maxPath=8, fpSize=256).ToList())
morgan_0_256 = _count_fp(rdFingerprintGenerator.GetMorganGenerator(radius=0, fpSize=256))
morgan_2_256 = _count_fp(rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=256))
morgan_4_256 = _count_fp(rdFingerprintGenerator.GetMorganGenerator(radius=4, fpSize=256))
morgan_6_256 = _count_fp(rdFingerprintGenerator.GetMorganGenerator(radius=6, fpSize=256))
# Morgan bit vectors on feature-based atom invariants (useFeatures=True).
# Each variant passes the radius in its own name; with a shared radius the
# four featurizers would be identical.
morgan_feature_0_256 = lambda mol: np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius=0, useFeatures=True, nBits=256).ToList())
morgan_feature_2_256 = lambda mol: np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, useFeatures=True, nBits=256).ToList())
morgan_feature_4_256 = lambda mol: np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius=4, useFeatures=True, nBits=256).ToList())
morgan_feature_6_256 = lambda mol: np.array(AllChem.GetMorganFingerprintAsBitVect(mol, radius=6, useFeatures=True, nBits=256).ToList())

# All RDKit descriptors, in Descriptors._descList order. The names are passed
# as a list, not a set: set iteration order of strings depends on the
# per-process hash seed, which would reshuffle the descriptor columns between
# kernel sessions. The calculator is built once and reused.
_rdkit_desc_calc = MoleculeDescriptors.MolecularDescriptorCalculator(
    [desc_name[0] for desc_name in Descriptors._descList]
)
rdkit_desc = lambda mol: np.array(_rdkit_desc_calc.CalcDescriptors(mol))


# Featurizers whose outputs are concatenated, in this order, into each
# molecule's MFF vector (see embedding_mol). The order fixes the column layout
# of every feature matrix built from it.
# Defined above but currently disabled: atompair_256, topo_256,
# rdkit_{2,4,6}_256, rdkit_linear_{2,4,6,8}_256, rdkit_layered_{2,4,6}_256,
# morgan_{0,2,4}_256, morgan_feature_{0,2,4,6}_256.
embedders = [
    rdkit_desc,           # RDKit descriptors
    avalon_256,           # Avalon count fingerprint
    maccs_167,            # MACCS keys
    rdkit_8_256,          # RDKit path fingerprint, maxPath=8
    rdkit_layered_8_256,  # layered fingerprint, maxPath=8
    morgan_6_256,         # Morgan count fingerprint, radius=6
]

def embedding_mol(mol):
    """Return the concatenated outputs of every featurizer in ``embedders`` for ``mol``.

    ``embedders`` is looked up at call time, so editing that list changes the
    feature layout of subsequent calls.
    """
    return np.concatenate([embedder(mol) for embedder in embedders])


def embedding_smi(smi):
    """Parse a SMILES string with RDKit and featurize it with ``embedding_mol``.

    Raises:
        ValueError: if ``Chem.MolFromSmiles`` cannot parse ``smi`` (returns
            None). Without this check, None would be passed to every featurizer.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smi!r}")
    return embedding_mol(mol)


def embedding_series(series):
    """Featurize the SMILES at the first position of ``series`` (``series.iloc[0]``)."""
    return embedding_smi(series.iloc[0])

  from .autonotebook import tqdm as notebook_tqdm


### Aryl-Scope

In [6]:
random_seed: int = 42  # random_state of the ExtraTreesRegressor models below

In [2]:
rxn_data = pd.read_csv("../dataset/rxn_data/aryl_scope_ligand/aryl-scope-ligand.csv")
# One SMILES list per reaction component, in CSV row order.
lig_smi_lst = rxn_data['ligand_smiles'].to_list()
rct1_smi_lst = rxn_data['electrophile_smiles'].to_list()
rct2_smi_lst = rxn_data['nucleophile_smiles'].to_list()
pdt_smi_lst = rxn_data['product_smiles'].to_list()
label = rxn_data['yield'].to_numpy()  # regression target

In [3]:
# Featurize every component SMILES, then place the per-component blocks side
# by side: [ligand | electrophile | nucleophile | product].
lig_mff, rct1_mff, rct2_mff, pdt_mff = (
    np.array([embedding_smi(smi) for smi in smi_lst])
    for smi_lst in (lig_smi_lst, rct1_smi_lst, rct2_smi_lst, pdt_smi_lst)
)
rxn_mff = np.concatenate([lig_mff, rct1_mff, rct2_mff, pdt_mff], axis=1)


In [7]:
# Repeated random hold-out: for split seeds 0-9, refit one ExtraTrees model on
# 80% of the reactions and score it on the held-out 20%.
model = ExtraTreesRegressor(n_estimators=500, random_state=random_seed, n_jobs=-1)
all_test_y, all_test_p = [], []  # per-split targets / predictions
r2_per_split, mae_per_split = [], []
for seed in range(10):
    train_x, test_x, train_y, test_y = train_test_split(
        rxn_mff, label, test_size=0.2, random_state=seed
    )
    test_p = model.fit(train_x, train_y).predict(test_x)
    r2, mae = r2_score(test_y, test_p), mean_absolute_error(test_y, test_p)
    print(f"seed: {seed}, r2: {r2:.4f}, mae: {mae:.4f}")
    all_test_y.append(test_y)
    all_test_p.append(test_p)
    r2_per_split.append(r2)
    mae_per_split.append(mae)
# Averages over the 10 splits.
r2_ave, mae_ave = np.mean(r2_per_split), np.mean(mae_per_split)
print(f"r2_ave: {r2_ave:.4f}, mae_ave: {mae_ave:.4f}")

seed: 0, r2: 0.7772, mae: 7.7200
seed: 1, r2: 0.7628, mae: 8.6322
seed: 2, r2: 0.8230, mae: 7.6255
seed: 3, r2: 0.7748, mae: 8.4890
seed: 4, r2: 0.7858, mae: 8.2471
seed: 5, r2: 0.7457, mae: 8.5237
seed: 6, r2: 0.8111, mae: 7.8573
seed: 7, r2: 0.7972, mae: 8.0087
seed: 8, r2: 0.7174, mae: 9.0374
seed: 9, r2: 0.7857, mae: 8.2839
r2_ave: 0.7781, mae_ave: 8.2425


In [9]:
## Replace the ligand MFF block with precomputed ligand embeddings, after
## checking that the embedding rows line up with this dataset's ligand SMILES.
lig_cat_emb = np.loadtxt("./gen_desc/aryl_scope_cat_emb.txt", ndmin=2)
# comments=None: '#' is the SMILES triple-bond symbol, and loadtxt's default
# comments='#' would silently truncate any SMILES containing it.
lig_smi4chk = np.loadtxt("./gen_desc/aryl_scope_cat_emb_smi.txt", dtype=str, comments=None, ndmin=1)
# zip() stops at the shorter input, so compare row counts explicitly first.
assert len(lig_smi4chk) == len(lig_smi_lst) == len(lig_cat_emb), (
    f"row-count mismatch: {len(lig_smi_lst)} reactions, "
    f"{len(lig_smi4chk)} embedding SMILES, {len(lig_cat_emb)} embedding rows"
)
for row_idx, (lig, ligc) in enumerate(zip(lig_smi_lst, lig_smi4chk)):
    assert lig == ligc, f"row {row_idx}: dataset SMILES {lig!r} != embedding SMILES {ligc!r}"
rxn_cat_emb_mff = np.concatenate([lig_cat_emb, rct1_mff, rct2_mff, pdt_mff], axis=1)

In [11]:
# Same repeated 80/20 hold-out as the plain-MFF run, but with the ligand block
# replaced by the precomputed embeddings (rxn_cat_emb_mff).
model = ExtraTreesRegressor(n_estimators=500, random_state=random_seed, n_jobs=-1)
all_test_y, all_test_p = [], []  # per-split targets / predictions
r2_per_split, mae_per_split = [], []
for seed in range(10):
    train_x, test_x, train_y, test_y = train_test_split(
        rxn_cat_emb_mff, label, test_size=0.2, random_state=seed
    )
    test_p = model.fit(train_x, train_y).predict(test_x)
    r2, mae = r2_score(test_y, test_p), mean_absolute_error(test_y, test_p)
    print(f"seed: {seed}, r2: {r2:.4f}, mae: {mae:.4f}")
    all_test_y.append(test_y)
    all_test_p.append(test_p)
    r2_per_split.append(r2)
    mae_per_split.append(mae)
# Averages over the 10 splits.
r2_ave, mae_ave = np.mean(r2_per_split), np.mean(mae_per_split)
print(f"r2_ave: {r2_ave:.4f}, mae_ave: {mae_ave:.4f}")

seed: 0, r2: 0.7767, mae: 8.1334
seed: 1, r2: 0.8028, mae: 8.3176
seed: 2, r2: 0.7973, mae: 8.4721
seed: 3, r2: 0.7639, mae: 9.1234
seed: 4, r2: 0.8103, mae: 8.0148
seed: 5, r2: 0.7495, mae: 8.7316
seed: 6, r2: 0.8258, mae: 8.2482
seed: 7, r2: 0.7745, mae: 8.6926
seed: 8, r2: 0.6891, mae: 9.5586
seed: 9, r2: 0.7953, mae: 8.2090
r2_ave: 0.7785, mae_ave: 8.5501


### Denmark

In [12]:
# Denmark N,S-acetal dataset, one row per reaction. The columns used below are
# Imine, Thiol, Catalyst, Product (SMILES) and the ΔΔG target.
rxn_data = pd.read_csv(
    "../dataset/rxn_data/denmark/NS_acetal_dataset_with_pdt.csv"
)
rxn_data

Unnamed: 0.1,Unnamed: 0,Imine,Thiol,Catalyst,ΔΔG,Product
0,0,O=C(/N=C/c1ccccc1)c1ccccc1,Sc1ccccc1,O=P1(O)Oc2c(-c3ccccc3)cc3ccccc3c2-c2c(c(-c3ccc...,1.179891,O=C(NC(Sc1ccccc1)c1ccccc1)c1ccccc1
1,1,O=C(/N=C/c1ccccc1)c1ccccc1,CCS,O=P1(O)Oc2c(-c3ccccc3)cc3ccccc3c2-c2c(c(-c3ccc...,0.501759,CCSC(NC(=O)c1ccccc1)c1ccccc1
2,2,O=C(/N=C/c1ccccc1)c1ccccc1,SC1CCCCC1,O=P1(O)Oc2c(-c3ccccc3)cc3ccccc3c2-c2c(c(-c3ccc...,0.650584,O=C(NC(SC1CCCCC1)c1ccccc1)c1ccccc1
3,3,O=C(/N=C/c1ccccc1)c1ccccc1,COc1ccc(S)cc1,O=P1(O)Oc2c(-c3ccccc3)cc3ccccc3c2-c2c(c(-c3ccc...,1.238109,COc1ccc(SC(NC(=O)c2ccccc2)c2ccccc2)cc1
4,4,O=C(/N=C/c1ccc(C(F)(F)F)cc1)c1ccccc1,Sc1ccccc1,O=P1(O)Oc2c(-c3ccccc3)cc3ccccc3c2-c2c(c(-c3ccc...,1.179891,O=C(NC(Sc1ccccc1)c1ccc(C(F)(F)F)cc1)c1ccccc1
...,...,...,...,...,...,...
1070,1070,O=C(/N=C/c1ccccc1)c1ccccc1,Sc1ccccc1,O=P1(O)Oc2c(-c3cc(C(F)(F)F)cc(C(F)(F)F)c3)cc3c...,1.531803,O=C(NC(Sc1ccccc1)c1ccccc1)c1ccccc1
1071,1071,O=C(/N=C/c1ccccc1)c1ccccc1,Cc1ccccc1S,O=P1(O)Oc2c(-c3cc(C(F)(F)F)cc(C(F)(F)F)c3)cc3c...,1.531803,Cc1ccccc1SC(NC(=O)c1ccccc1)c1ccccc1
1072,1072,O=C(/N=C/c1ccc(C(F)(F)F)cc1)c1ccccc1,Cc1ccccc1S,O=P1(O)Oc2c(-c3cc(C(F)(F)F)cc(C(F)(F)F)c3)cc3c...,1.370104,Cc1ccccc1SC(NC(=O)c1ccccc1)c1ccc(C(F)(F)F)cc1
1073,1073,O=C(/N=C/c1cccc2ccccc12)c1ccccc1,Sc1ccccc1,O=P1(O)Oc2c(-c3cc(C(F)(F)F)cc(C(F)(F)F)c3)cc3c...,1.301167,O=C(NC(Sc1ccccc1)c1cccc2ccccc12)c1ccccc1


In [13]:
# Per-reaction component SMILES (row order preserved) and the ΔΔG target.
imine_lst, thiol_lst, cat_lst, pdt_lst = (
    rxn_data[column].to_list() for column in ('Imine', 'Thiol', 'Catalyst', 'Product')
)
label = rxn_data['ΔΔG'].to_numpy()

In [14]:
# Featurize each component, then stack the blocks column-wise:
# [imine | thiol | catalyst | product].
imine_mff, thiol_mff, cat_mff, pdt_mff = (
    np.array([embedding_smi(smi) for smi in smi_lst])
    for smi_lst in (imine_lst, thiol_lst, cat_lst, pdt_lst)
)
rxn_mff = np.concatenate([imine_mff, thiol_mff, cat_mff, pdt_mff], axis=1)


In [15]:
# Repeated random hold-out on the Denmark set: for split seeds 0-9, hold out a
# 475/1075 fraction of the reactions for testing and refit one ExtraTrees model.
model = ExtraTreesRegressor(n_estimators=500, random_state=random_seed, n_jobs=-1)
all_test_y, all_test_p = [], []  # per-split targets / predictions
r2_per_split, mae_per_split = [], []
for seed in range(10):
    train_x, test_x, train_y, test_y = train_test_split(
        rxn_mff, label, test_size=475/1075, random_state=seed
    )
    test_p = model.fit(train_x, train_y).predict(test_x)
    all_test_y.append(test_y)
    all_test_p.append(test_p)
    r2, mae = r2_score(test_y, test_p), mean_absolute_error(test_y, test_p)
    print(f"seed: {seed}, r2: {r2:.4f}, mae: {mae:.4f}")
    r2_per_split.append(r2)
    mae_per_split.append(mae)
# Averages over the 10 splits.
r2_ave, mae_ave = np.mean(r2_per_split), np.mean(mae_per_split)
print(f"r2_ave: {r2_ave:.4f}, mae_ave: {mae_ave:.4f}")

seed: 0, r2: 0.8725, mae: 0.1554
seed: 1, r2: 0.8839, mae: 0.1508
seed: 2, r2: 0.8941, mae: 0.1479
seed: 3, r2: 0.8760, mae: 0.1625
seed: 4, r2: 0.8919, mae: 0.1539
seed: 5, r2: 0.9031, mae: 0.1478
seed: 6, r2: 0.8957, mae: 0.1499
seed: 7, r2: 0.9074, mae: 0.1416
seed: 8, r2: 0.9006, mae: 0.1490
seed: 9, r2: 0.8796, mae: 0.1562
r2_ave: 0.8905, mae_ave: 0.1515


In [16]:
## Replace the catalyst MFF block with precomputed catalyst embeddings, after
## checking that the embedding rows line up with this dataset's catalyst SMILES.
cat_cat_emb = np.loadtxt("./gen_desc/denmark_cat_emb.txt", ndmin=2)
# comments=None: '#' is the SMILES triple-bond symbol, and loadtxt's default
# comments='#' would silently truncate any SMILES containing it.
cat_smi4chk = np.loadtxt("./gen_desc/denmark_cat_emb_smi.txt", dtype=str, comments=None, ndmin=1)
# zip() stops at the shorter input, so compare row counts explicitly first.
assert len(cat_smi4chk) == len(cat_lst) == len(cat_cat_emb), (
    f"row-count mismatch: {len(cat_lst)} reactions, "
    f"{len(cat_smi4chk)} embedding SMILES, {len(cat_cat_emb)} embedding rows"
)
for row_idx, (cat, catc) in enumerate(zip(cat_lst, cat_smi4chk)):
    assert cat == catc, f"row {row_idx}: dataset SMILES {cat!r} != embedding SMILES {catc!r}"
rxn_cat_emb_mff = np.concatenate([imine_mff, thiol_mff, cat_cat_emb, pdt_mff], axis=1)

In [23]:
# Plain MFF features with the precomputed catalyst embedding appended as extra
# columns (the catalyst MFF block is kept rather than replaced).
rxn_mff_cat_emb_all = np.concatenate(
    [imine_mff, thiol_mff, cat_mff, pdt_mff, cat_cat_emb],
    axis=1,
)

In [17]:
# Same repeated hold-out (475/1075 test fraction, split seeds 0-9) as the
# plain-MFF run, but with the catalyst block replaced by the precomputed
# embeddings (rxn_cat_emb_mff).
model = ExtraTreesRegressor(n_estimators=500, random_state=random_seed, n_jobs=-1)
all_test_y, all_test_p = [], []  # per-split targets / predictions
r2_per_split, mae_per_split = [], []
for seed in range(10):
    train_x, test_x, train_y, test_y = train_test_split(
        rxn_cat_emb_mff, label, test_size=475/1075, random_state=seed
    )
    test_p = model.fit(train_x, train_y).predict(test_x)
    all_test_y.append(test_y)
    all_test_p.append(test_p)
    r2, mae = r2_score(test_y, test_p), mean_absolute_error(test_y, test_p)
    print(f"seed: {seed}, r2: {r2:.4f}, mae: {mae:.4f}")
    r2_per_split.append(r2)
    mae_per_split.append(mae)
# Averages over the 10 splits.
r2_ave, mae_ave = np.mean(r2_per_split), np.mean(mae_per_split)
print(f"r2_ave: {r2_ave:.4f}, mae_ave: {mae_ave:.4f}")

seed: 0, r2: 0.8892, mae: 0.1526
seed: 1, r2: 0.8932, mae: 0.1489
seed: 2, r2: 0.9029, mae: 0.1446
seed: 3, r2: 0.8881, mae: 0.1521
seed: 4, r2: 0.9047, mae: 0.1438
seed: 5, r2: 0.8955, mae: 0.1597
seed: 6, r2: 0.8940, mae: 0.1491
seed: 7, r2: 0.9062, mae: 0.1464
seed: 8, r2: 0.9039, mae: 0.1480
seed: 9, r2: 0.8842, mae: 0.1664
r2_ave: 0.8962, mae_ave: 0.1512


#### Out-of-sample evaluation (样本外)

In [18]:
# Out-of-sample split definition. A reaction whose imine is in `oos_imine` or
# whose thiol is in `oos_thiol` counts as a "new substrate" reaction; one whose
# catalyst is in `oos_cat` counts as a "new catalyst" reaction. Membership is
# an exact string match against the dataset's SMILES columns.
oos_imine = ['O=C(/N=C/c1ccc(Cl)cc1Cl)c1ccccc1']
oos_thiol = ['Cc1ccccc1S']
oos_cat = ['O=P1(O)Oc2c(-c3c(C4CCCCC4)cc(C4CCCCC4)cc3C3CCCCC3)cc3ccccc3c2-c2c(c(-c3c(C4CCCCC4)cc(C4CCCCC4)cc3C3CCCCC3)cc3ccccc23)O1',
 'CC(C)c1cc(C(C)C)c(-c2cc3ccccc3c3c2OP(=O)(O)Oc2c(-c4c(C(C)C)cc(C(C)C)cc4C(C)C)cc4ccccc4c2-3)c(C(C)C)c1',
 'COc1cccc(OC)c1-c1cc2ccccc2c2c1OP(=O)(O)Oc1c(-c3c(OC)cccc3OC)cc3ccccc3c1-2',
 'Cc1cc(C)c(-c2cc3ccccc3c3c2OP(=O)(O)Oc2c(-c4c(C)cc(C)cc4C)cc4ccccc4c2-3)c(C)c1',
 'O=P1(O)Oc2c(-c3c4ccccc4cc4ccccc34)cc3ccccc3c2-c2c(c(-c3c4ccccc4cc4ccccc34)cc3ccccc23)O1',
 'O=P1(O)Oc2c(-c3ccc4ccc5cccc6ccc3c4c56)cc3ccccc3c2-c2c(c(-c3ccc4ccc5cccc6ccc3c4c56)cc3ccccc23)O1', 
 'O=P1(O)Oc2c(-c3ccccc3OC(F)(F)F)cc3ccccc3c2-c2c(c(-c3ccccc3OC(F)(F)F)cc3ccccc23)O1',
 'CC(C)(C)c1cc(-c2cc3ccccc3c3c2OP(=O)(O)Oc2c(-c4cc(C(C)(C)C)cc(C(C)(C)C)c4)cc4ccccc4c2-3)cc(C(C)(C)C)c1',
 'CC(C)(C)c1cc(-c2cc3c(c4c2OP(=O)(O)Oc2c(-c5cc(C(C)(C)C)cc(C(C)(C)C)c5)cc5c(c2-4)CCCC5)CCCC3)cc(C(C)(C)C)c1',
 'Cc1ccc(-c2cc3ccccc3c3c2OP(=O)(O)Oc2c(-c4ccc(C)cc4)cc4ccccc4c2-3)cc1',
 'CC(C)(C)c1ccc(-c2cc3ccccc3c3c2OP(=O)(O)Oc2c(-c4ccc(C(C)(C)C)cc4)cc4ccccc4c2-3)cc1',
 'O=P1(O)Oc2c(-c3ccc(-c4ccc5ccccc5c4)cc3)cc3ccccc3c2-c2c(c(-c3ccc(-c4ccc5ccccc5c4)cc3)cc3ccccc23)O1',
 'COc1ccc(-c2cc3ccccc3c3c2OP(=O)(O)Oc2c(-c4ccc(OC)cc4)cc4ccccc4c2-3)cc1',
 'COCc1cccc(-c2cc3c(c4c2OP(=O)(O)Oc2c(-c5cccc(COC)c5)cc5c(c2-4)CCCC5)CCCC3)c1',
 'O=P1(O)Oc2c(-c3ccccc3)cc3ccccc3c2-c2c(c(-c3ccccc3)cc3ccccc23)O1',
 'C[Si](c1ccccc1)(c1ccccc1)c1cc2ccccc2c2c1OP(=O)(O)Oc1c([Si](C)(c3ccccc3)c3ccccc3)cc3ccccc3c1-2',
 'O=P1(O)Oc2c(Br)cc3c(c2-c2c4c(cc(Br)c2O1)CCCC4)CCCC3',
 'O=P1(O)Oc2c([Si](c3ccccc3)(c3ccccc3)c3ccccc3)cc3ccccc3c2-c2c(c([Si](c3ccccc3)(c3ccccc3)c3ccccc3)cc3ccccc23)O1',
 'O=P1(O)Oc2c(Cc3ccc(C(F)(F)F)cc3C(F)(F)F)cc3ccccc3c2-c2c(c(Cc3ccc(C(F)(F)F)cc3C(F)(F)F)cc3ccccc23)O1']
# Assign every reaction index to exactly one group:
#   sub_cat_test - new substrate AND new catalyst
#   sub_test     - new imine or thiol, known catalyst
#   cat_test     - new catalyst, known imine and thiol
#   train        - everything else
train_data_idx_lst, sub_test_data_idx_lst = [], []
cat_test_data_idx_lst, sub_cat_test_data_idx_lst = [], []
for i, imine_smi in enumerate(imine_lst):
    thiol_smi, cat_smi = thiol_lst[i], cat_lst[i]
    new_substrate = imine_smi in oos_imine or thiol_smi in oos_thiol
    new_catalyst = cat_smi in oos_cat
    if new_substrate and new_catalyst:
        sub_cat_test_data_idx_lst.append(i)
    elif new_substrate:
        sub_test_data_idx_lst.append(i)
    elif new_catalyst:
        cat_test_data_idx_lst.append(i)
    else:
        train_data_idx_lst.append(i)
# Group sizes: (train, sub_test, cat_test, sub_cat_test).
tuple(len(idx) for idx in (train_data_idx_lst, sub_test_data_idx_lst, cat_test_data_idx_lst, sub_cat_test_data_idx_lst))

(384, 216, 304, 171)

In [19]:
# Out-of-sample evaluation with plain MFF features (rxn_mff): fit once on the
# in-distribution reactions, then score each held-out group separately.
# `model` is the ExtraTreesRegressor created in an earlier cell.
train_x, train_y = rxn_mff[train_data_idx_lst], label[train_data_idx_lst]
sub_test_x, sub_test_y = rxn_mff[sub_test_data_idx_lst], label[sub_test_data_idx_lst]
cat_test_x, cat_test_y = rxn_mff[cat_test_data_idx_lst], label[cat_test_data_idx_lst]
sub_cat_test_x, sub_cat_test_y = rxn_mff[sub_cat_test_data_idx_lst], label[sub_cat_test_data_idx_lst]
model.fit(train_x, train_y)
sub_test_p, cat_test_p, sub_cat_test_p = (
    model.predict(sub_test_x),
    model.predict(cat_test_x),
    model.predict(sub_cat_test_x),
)
# R² per group is kept in the r2_* names but not printed; MAE is reported.
r2_sub_test, r2_cat_test, r2_sub_cat_test = (
    r2_score(sub_test_y, sub_test_p),
    r2_score(cat_test_y, cat_test_p),
    r2_score(sub_cat_test_y, sub_cat_test_p),
)
mae_sub_test, mae_cat_test, mae_sub_cat_test = (
    mean_absolute_error(sub_test_y, sub_test_p),
    mean_absolute_error(cat_test_y, cat_test_p),
    mean_absolute_error(sub_cat_test_y, sub_cat_test_p),
)
print(f"mae_sub_test: {mae_sub_test:.4f}, mae_cat_test: {mae_cat_test:.4f}, mae_sub_cat_test: {mae_sub_cat_test:.4f}")

mae_sub_test: 0.1349, mae_cat_test: 0.2586, mae_sub_cat_test: 0.2756


In [20]:
# Out-of-sample evaluation with the catalyst block replaced by precomputed
# embeddings (rxn_cat_emb_mff): fit once on the in-distribution reactions,
# then score each held-out group separately.
# `model` is the ExtraTreesRegressor created in an earlier cell.
train_x, train_y = rxn_cat_emb_mff[train_data_idx_lst], label[train_data_idx_lst]
sub_test_x, sub_test_y = rxn_cat_emb_mff[sub_test_data_idx_lst], label[sub_test_data_idx_lst]
cat_test_x, cat_test_y = rxn_cat_emb_mff[cat_test_data_idx_lst], label[cat_test_data_idx_lst]
sub_cat_test_x, sub_cat_test_y = rxn_cat_emb_mff[sub_cat_test_data_idx_lst], label[sub_cat_test_data_idx_lst]
model.fit(train_x, train_y)
sub_test_p, cat_test_p, sub_cat_test_p = (
    model.predict(sub_test_x),
    model.predict(cat_test_x),
    model.predict(sub_cat_test_x),
)
# R² per group is kept in the r2_* names but not printed; MAE is reported.
r2_sub_test, r2_cat_test, r2_sub_cat_test = (
    r2_score(sub_test_y, sub_test_p),
    r2_score(cat_test_y, cat_test_p),
    r2_score(sub_cat_test_y, sub_cat_test_p),
)
mae_sub_test, mae_cat_test, mae_sub_cat_test = (
    mean_absolute_error(sub_test_y, sub_test_p),
    mean_absolute_error(cat_test_y, cat_test_p),
    mean_absolute_error(sub_cat_test_y, sub_cat_test_p),
)
print(f"mae_sub_test: {mae_sub_test:.4f}, mae_cat_test: {mae_cat_test:.4f}, mae_sub_cat_test: {mae_sub_cat_test:.4f}")

mae_sub_test: 0.1389, mae_cat_test: 0.4156, mae_sub_cat_test: 0.4493


In [24]:
# Out-of-sample evaluation with MFF features plus the appended catalyst
# embedding (rxn_mff_cat_emb_all): fit once on the in-distribution reactions,
# then score each held-out group separately.
# `model` is the ExtraTreesRegressor created in an earlier cell.
train_x, train_y = rxn_mff_cat_emb_all[train_data_idx_lst], label[train_data_idx_lst]
sub_test_x, sub_test_y = rxn_mff_cat_emb_all[sub_test_data_idx_lst], label[sub_test_data_idx_lst]
cat_test_x, cat_test_y = rxn_mff_cat_emb_all[cat_test_data_idx_lst], label[cat_test_data_idx_lst]
sub_cat_test_x, sub_cat_test_y = rxn_mff_cat_emb_all[sub_cat_test_data_idx_lst], label[sub_cat_test_data_idx_lst]
model.fit(train_x, train_y)
sub_test_p, cat_test_p, sub_cat_test_p = (
    model.predict(sub_test_x),
    model.predict(cat_test_x),
    model.predict(sub_cat_test_x),
)
# R² per group is kept in the r2_* names but not printed; MAE is reported.
r2_sub_test, r2_cat_test, r2_sub_cat_test = (
    r2_score(sub_test_y, sub_test_p),
    r2_score(cat_test_y, cat_test_p),
    r2_score(sub_cat_test_y, sub_cat_test_p),
)
mae_sub_test, mae_cat_test, mae_sub_cat_test = (
    mean_absolute_error(sub_test_y, sub_test_p),
    mean_absolute_error(cat_test_y, cat_test_p),
    mean_absolute_error(sub_cat_test_y, sub_cat_test_p),
)
print(f"mae_sub_test: {mae_sub_test:.4f}, mae_cat_test: {mae_cat_test:.4f}, mae_sub_cat_test: {mae_sub_cat_test:.4f}")

mae_sub_test: 0.1341, mae_cat_test: 0.2607, mae_sub_cat_test: 0.2774
