In [1]:
%load_ext autoreload
%autoreload 2
from rxn_modeling import train
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import json
import pickle as pkl
from sklearn.metrics import classification_report, precision_recall_fscore_support, f1_score, confusion_matrix
from scipy.stats import gaussian_kde
import matplotlib.cm as cm
import time

In [None]:
# Ablation sweep: call `train` for each (objective, model, featurization,
# feature-subset) combination, aggregate precision / recall / F1 over the
# per-split predictions it returns, and checkpoint the running results to disk
# after every configuration.
objectives = [
    "3_class",
    "4_class"
]
# These two lists are paired element-wise by the loop below, giving the runs
# (precursor_only=False, target_only=True) and
# (precursor_only=True, target_only=False).
precursor_only = [False, True]
target_only = [True, False]
models = ["lr", "xgb", "rf", "nn"]
featurizations=["pca", "mp_fraction"]

# zip() would silently drop trailing entries if the two flag lists ever got out
# of sync, so fail fast before any training starts.
assert len(precursor_only) == len(target_only), \
    "precursor_only and target_only must be the same length"

# Output locations.
# NOTE(review): the results file name says "micro", but the scores below are
# computed with average='macro' -- confirm which one is intended.
RESULTS_PATH = 'data/rxn_condition_log_ablation_micro.json'
TIME_LOG_PATH = 'data/time_log.txt'

start = time.time()

results = []
for objective in objectives:
    for model in models:
        for featurization in featurizations:
            for prec_only, tar_only in zip(precursor_only, target_only):
                # The target-only run is skipped for the "pca" featurization.
                if tar_only and featurization == "pca":
                    continue
                print(objective, model, featurization, prec_only, tar_only)
                best_params, best_estimators, y_pred_train, y_pred_test, X_train_k, X_test_k, y_train_k, y_test_k = train(
                    model=model,
                    objective=objective,
                    featurization=featurization,
                    precursor_only=prec_only,
                    target_only=tar_only,
                )

                # y_test_k and y_pred_test are iterated in parallel, one
                # (true, predicted) pair per element -- presumably one per
                # cross-validation fold; confirm against `train`.
                precision_scores, recall_scores, f1_scores = [], [], []
                for y_true_split, y_pred_split in zip(y_test_k, y_pred_test):
                    split_precision, split_recall, split_f1, _support = precision_recall_fscore_support(
                        y_true_split, y_pred_split, average='macro'
                    )
                    precision_scores.append(split_precision)
                    recall_scores.append(split_recall)
                    f1_scores.append(split_f1)

                # Values are stored as strings, exactly as before; readers of
                # this file are not visible here, so the format is preserved.
                result = {
                    "objective": objective,
                    "model": model,
                    "precursor_only": str(prec_only),
                    "target_only": str(tar_only),
                    "featurization": featurization,
                    "precision": str(np.mean(precision_scores)),
                    "precision_std": str(np.std(precision_scores)),
                    "recall": str(np.mean(recall_scores)),
                    "recall_std": str(np.std(recall_scores)),
                    "f1_score": str(np.mean(f1_scores)),
                    "f1_score_std": str(np.std(f1_scores)),
                }
                results.append(result)

                # Checkpoint: rewrite the full results list after every
                # configuration so an interrupted sweep keeps what finished.
                with open(RESULTS_PATH, 'w') as f:
                    json.dump(results, f, indent=4)

                # Append the elapsed time in minutes since `start`, one line
                # per completed configuration. write() is the right call for
                # a single string; writelines() expects an iterable of lines
                # and only worked by iterating the string character by
                # character.
                with open(TIME_LOG_PATH, 'a') as f:
                    f.write(str((time.time() - start)/60) + "\n")

3_class lr pca True False
Returning extracted data of 26787/31782 reactions.
After removing duplicates, remaining extracted data contains 10854/26787 reactions.
Returning extracted data of 6854/9518 reactions.
After removing duplicates, remaining extracted data contains 2972/6854 reactions.


No electronegativity for Hs. Setting to NaN. This has no physical meaning, and is mainly done to avoid errors caused by the code expecting a float.


Returning extracted data of 14116/35675 reactions.
After removing duplicates, remaining extracted data contains 4752/14116 reactions.
Shape of X: (17537, 0)
Shape of precursor features: (17537, 515)
Shape of y: (17537,)
0.7941083306936967 0.01
0.7944250871080138 0.1
0.7950585999366487 0.01
0.7931580614507443 0.01
0.7953753563509661 0.01
0.7925245486221096 0.01
0.7956921127652835 0.01
0.792841305036427 0.01
0.8007602153943617 0.01
0.7979094076655052 0.01
3_class lr mp_fraction False True
Returning extracted data of 26787/31782 reactions.
After removing duplicates, remaining extracted data contains 14328/26787 reactions.
Returning extracted data of 6854/9518 reactions.
After removing duplicates, remaining extracted data contains 3364/6854 reactions.


No electronegativity for Hs. Setting to NaN. This has no physical meaning, and is mainly done to avoid errors caused by the code expecting a float.


Returning extracted data of 14116/35675 reactions.
After removing duplicates, remaining extracted data contains 3087/14116 reactions.


HBox(children=(HTML(value='StrToComposition'), FloatProgress(value=0.0, max=19511.0), HTML(value='')))




HBox(children=(HTML(value='MultipleFeaturizer'), FloatProgress(value=0.0, max=19511.0), HTML(value='')))


Shape of X: (19511, 103)
Shape of y: (19511,)
0.590261958997722 0.01
0.5837129840546698 0.01
0.5936788154897494 0.01
0.5908314350797267 0.01
0.5925398633257403 0.1
0.5899772209567198 0.01
0.591116173120729 0.01
0.5894077448747153 0.01
0.592255125284738 0.01
0.5746013667425968 0.01
3_class lr mp_fraction True False
