In [1]:
import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
from collections import Counter
from data.data_utils import *
from atl_utils import *

In [2]:
# Load the list of pre-trained source models (iterated below as random forests
# exposing `.estimators_`).
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.
source_model_path = "joblib_files/source_model_100_trees.joblib"
list_of_source_models = joblib.load(source_model_path)

In [3]:
# Lookup tables for the additive dimension of the candidate experiments.
# Each additive is encoded as a 2-element id row; additive_names[k] labels
# additive_ids[k].
# NOTE(review): the rows look like (cation id, anion id), with the second entry
# 1/2/3 for Cl/Br/I -- TODO confirm against the descriptor tables.
additive_ids = np.array([
    [0, 0],   # None
    [5, 1],   # NaCl
    [5, 3],   # NaI
    [6, 1],   # MgCl2
    [6, 2],   # MgBr2
    [7, 1],   # KCl
    [7, 2],   # KBr
    [7, 3],   # KI
    [8, 1],   # ZnCl2
    [9, 9],   # succinimide
    [10, 1],  # TMSCl
    [11, 1],  # TBACl
    [11, 2],  # TBABr
    [11, 3],  # TBAI
])
additive_names = ["None", "NaCl", "NaI", "MgCl2", "MgBr2", "KCl", "KBr", "KI", "ZnCl2", "succinimide", "TMSCl", "TBACl", "TBABr", "TBAI"]

# The two tables are paired purely by position. A length mismatch would silently
# drop or mislabel additives, so fail loudly instead.
assert len(additive_names) == len(additive_ids), (
    f"additive_ids has {len(additive_ids)} rows but additive_names has {len(additive_names)} entries"
)

# (id pair as a tuple) -> name, and the inverse name -> id pair.
additive_id_to_name = {tuple(row): name for row, name in zip(additive_ids, additive_names)}
additive_name_to_id = {name: id_pair for id_pair, name in additive_id_to_name.items()}
    
# Enumerate every candidate condition (Ni source, ligand, additive, solvent) twice:
# - as id rows, which print_suggestions uses for labels;
# - as descriptor rows, which the models score.
# Row k of both arrays must describe the same candidate.
# NOTE(review): the id ranges (1-5 Ni sources, 1-29 ligands, 1-9 solvents) are
# hand-typed and must line up with the rows of Ni_source_onehot / ligand_desc /
# solvent_desc -- TODO confirm.
X_candidate_id = prep_array_of_enumerated_candidates([np.arange(1, 6), np.arange(1, 30), additive_ids, np.arange(1, 10)])
X_candidate_desc = prep_array_of_enumerated_candidates([Ni_source_onehot, ligand_desc, additive_ion_onehot, solvent_desc])
# Suggestions ranked on the descriptor array are printed via X_candidate_id, so a
# row-count mismatch would silently mislabel every suggestion.
assert len(X_candidate_id) == X_candidate_desc.shape[0], (
    f"{len(X_candidate_id)} candidate id rows vs {X_candidate_desc.shape[0]} descriptor rows"
)

# Substrate descriptors come from column positions 2-13 of each sheet.
# NOTE(review): data row 1 of "NHPI" is taken as Ala, and the last row of
# "Katritzky" as 4-MeBn -- confirm against descriptors.xlsx.
Ala = pd.read_excel("./data/descriptors.xlsx", sheet_name="NHPI", usecols=list(np.arange(2, 14))).to_numpy()[1, :]
MeBnKat = pd.read_excel("./data/descriptors.xlsx", sheet_name="Katritzky", usecols=list(np.arange(2, 14))).to_numpy()[-1, :]

# The substrate pair is fixed, so prepend the same substrate descriptors to every candidate row.
X_candidate_desc_Ala4MeBn = np.hstack((
    np.tile(np.hstack((Ala, MeBnKat)), (X_candidate_desc.shape[0], 1)),
    X_candidate_desc,
))

In [4]:
# Pool votes from the source RFCs, with each RFC casting 12 votes (its top-12
# candidates, going by the helper's name). Then print the 5 most-voted candidates.
votes_per_model = 12
count_list = count_num_topN_suggestions([list_of_source_models], X_candidate_desc_Ala4MeBn, votes_per_model)
print_suggestions(X_candidate_id, count_list[0], 5)

[NiBr2(glyme),  4CF3,                  MgCl2,                   tBuOMe] 50
[NiBr2(glyme),  4CF3,                  MgCl2,                    glyme] 45
[NiBr2(glyme),  4CF3,                  MgBr2,                    glyme] 42
[NiBr2(glyme),  4CF3,                  MgBr2,                   tBuOMe] 42
[NiCl2(glyme),  4CF3,                  MgCl2,                    glyme] 36


In [5]:
# Same vote pooling as above, but each RFC now casts only 4 votes;
# print the 5 most-voted candidates for comparison.
count_list = count_num_topN_suggestions([list_of_source_models], X_candidate_desc_Ala4MeBn, 4)
print_suggestions(X_candidate_id, count_list[0], 5)

[NiBr2(glyme),  4CF3,                  MgCl2,                   tBuOMe] 29
[NiBr2(glyme),  4CF3,                  MgCl2,                    glyme] 27
[NiBr2(glyme),  4CF3,                  MgBr2,                   tBuOMe] 17
[NiCl2(glyme),  4CF3,                  MgCl2,                    glyme] 16
[NiBr2(glyme),  4CF3,                  MgBr2,                    glyme] 15


In [6]:
# Iteration-1 experiments, written as id rows.
# NOTE(review): the 7 columns are presumably (NHPI substrate, Katritzky substrate,
# Ni source, ligand, additive id pair, solvent). Columns 5-6 match the additive table
# ([6,1]=MgCl2, [6,2]=MgBr2), and these are the top cell-4 suggestions -- confirm.
iter1_id_array = np.array([
    [2, 4, 2, 4, 6, 1, 9],
    [2, 4, 2, 4, 6, 1, 6],
    [2, 4, 2, 4, 6, 2, 6],
])
iter1_desc = id_array_to_desc_array(iter1_id_array, True)

# --- Prune trees that disagree with the iteration-1 results ---
# Weights are all-or-nothing. A tree keeps weight 1 only if it predicts class 0 for
# every iteration-1 experiment; a single non-zero prediction sets its weight to 0.
# NOTE(review): this treats class 0 as the observed outcome of all three
# experiments -- confirm against the lab results.
# Shapes come from the loaded models rather than a hard-coded 100 x 100. A larger
# ensemble would otherwise overflow, and a smaller one would leave phantom all-zero
# rows/columns that distort the pruned count and the voting below.
n_models = len(list_of_source_models)
n_trees = max((len(rfc.estimators_) for rfc in list_of_source_models), default=0)
tree_weights = np.zeros((n_models, n_trees))
for i, rfc in enumerate(list_of_source_models):
    for j, dtc in enumerate(rfc.estimators_):
        if not np.any(dtc.predict(iter1_desc)):
            tree_weights[i, j] = 1

# Column i holds, for each candidate, the sum over model i's surviving trees of
# predict_proba column 1. It is unnormalised, which does not change the
# within-column ranking used when voting.
n_candidates = X_candidate_desc_Ala4MeBn.shape[0]
weighted_probabilities = np.zeros((n_candidates, n_models))   # scheme 2

for i, rfc in enumerate(list_of_source_models):
    Me_probabilities_by_tree = np.zeros((n_candidates, n_trees))
    for j, dtc in enumerate(rfc.estimators_):
        Me_probabilities_by_tree[:, j] = dtc.predict_proba(X_candidate_desc_Ala4MeBn)[:, 1]
    weighted_probabilities[:, i] = np.matmul(Me_probabilities_by_tree, tree_weights[i, :].reshape(-1, 1)).reshape(-1,)

In [7]:
# How many entries of tree_weights are zero, i.e. how many decision trees were pruned above
(tree_weights == 0).sum()

3871

In [8]:
# Iteration-2 voting: each model votes for the 12 candidates with the highest
# pruned-ensemble score (its column of weighted_probabilities). Then print the
# 6 most-voted candidates.
# Models whose score column is entirely zero (e.g. every tree was pruned) carry no
# ranking information. Argsorting identical scores would hand out 12 arbitrary
# votes, so those models are skipped.
iter2_suggestion_counter = Counter()
for i in range(weighted_probabilities.shape[1]):
    model_scores = weighted_probabilities[:, i]
    if not np.any(model_scores):
        continue
    iter2_suggestion_counter.update(
        np.argsort(model_scores)[::-1][:12]
    )

print_suggestions(X_candidate_id, iter2_suggestion_counter, 6)

[NiCl2(glyme),  4CF3,                    KCl,                      THF] 24
[NiCl2(glyme),  4CF3,                  MgBr2,                      THF] 18
[NiCl2(glyme),  4CF3,                  MgCl2,                      THF] 18
[NiCl2(glyme),  4CF3,                    KCl,                  Dioxane] 17
[NiCl2(glyme),  4CF3,                  TMSCl,                      THF] 15
[NiCl2(glyme),  4CF3,            succinimide,                      THF] 14
