In [1]:
import numpy as np
import sys
import json

from ipywidgets import Dropdown
from widget_periodictable import PTableWidget
from tqdm.auto import tqdm

sys.path.append("../")
from config import ROOT_DIR
from Source.applicability_domain.knn_ad import knnAD
from Source.applicability_domain.mahalanobis import MahalanobisAD

In [2]:
# All 118 element symbols: main table first, then the lanthanide (La-Lu) and
# actinide (Ac-Lr) rows. Element 114 is Flerovium, symbol 'Fl'.
all_elements = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', 'Cs', 'Ba', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr']
# Elements scored for applicability domain; each is looked up in the SkipAtom
# vector table, so every symbol here needs an entry in that JSON file.
skipatom_elements = ['Si', 'C', 'Pb', 'I', 'Br', 'Cl', 'Eu', 'O', 'Fe', 'Sb', 'In', 'S', 'N', 'U', 'Mn', 'Lu', 'Se', 'Tl', 'Hf', 'Ir', 'Ca', 'Ta', 'Cr', 'K', 'Pm', 'Mg', 'Zn', 'Cu', 'Sn', 'Ti', 'B', 'W', 'P', 'H', 'Pd', 'As', 'Co', 'Np', 'Tc', 'Hg', 'Pu', 'Al', 'Tm', 'Tb', 'Ho', 'Nb', 'Ge', 'Zr', 'Cd', 'V', 'Sr', 'Ni', 'Rh', 'Th', 'Na', 'Ru', 'La', 'Re', 'Y', 'Er', 'Ce', 'Pt', 'Ga', 'Li', 'Cs', 'F', 'Ba', 'Te', 'Mo', 'Gd', 'Pr', 'Bi', 'Sc', 'Ag', 'Rb', 'Dy', 'Yb', 'Nd', 'Au', 'Os', 'Pa', 'Sm', 'Be', 'Ac', 'Xe', 'Kr', 'Cm', 'Am', 'Ra', 'Bk', 'Cf']
# Metals used to fit the applicability-domain estimators.
all_train_metals = ['Ac', 'Ag', 'Al', 'Am', 'Au', 'Ba', 'Be', 'Bi', 'Bk', 'Ca', 'Cd', 'Ce', 'Cf', 'Cm', 'Co', 'Cr', 'Cs', 'Cu', 'Dy', 'Er', 'Eu', 'Fe', 'Ga', 'Gd', 'Hf', 'Hg', 'Ho', 'In', 'K', 'La', 'Li', 'Lu', 'Mg', 'Mn', 'Mo', 'Na', 'Nd', 'Ni', 'Np', 'Pa', 'Pb', 'Pd', 'Pm', 'Pr', 'Pt', 'Pu', 'Rb', 'Re', 'Rh', 'Sb', 'Sc', 'Sm', 'Sn', 'Sr', 'Tb', 'Th', 'Ti', 'Tl', 'Tm', 'U', 'V', 'Y', 'Yb', 'Zn', 'Zr']
# Sc, Y, the lanthanides La-Lu and the actinides Ac-Cf.
all_test_metals = ['Sc', 'Y', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho', 'Er', 'Tm', 'Yb', 'Lu', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf']

# Sanity checks on the hand-written element lists.
assert len(all_elements) == len(set(all_elements)) == 118
assert set(skipatom_elements) <= set(all_elements)
assert set(all_train_metals) <= set(skipatom_elements)  # every training metal needs a vector
assert set(all_test_metals) <= set(all_train_metals)    # test metals are a subset of the training metals

In [3]:
# SkipAtom element embeddings: element symbol -> vector (the file name says
# 200-dimensional). JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale encoding.
with open(ROOT_DIR / "Source/models/GCNN_FCNN/skipatom_vectors_dim200.json", "r", encoding="utf-8") as f:
    get_vector = json.load(f)
def get_features(metals, vectors=None):
    """Stack the embedding vectors of ``metals`` into a NumPy array.

    Parameters
    ----------
    metals : iterable of str
        Element symbols; each must be a key of the lookup table.
    vectors : mapping, optional
        Symbol -> vector lookup table. Defaults to the module-level
        ``get_vector`` loaded from the SkipAtom JSON file.

    Returns
    -------
    numpy.ndarray
        One row per symbol, in input order (2-D as long as every vector has
        the same length).

    Raises
    ------
    KeyError
        If a symbol has no entry in the lookup table.
    """
    table = get_vector if vectors is None else vectors
    return np.array([table[metal] for metal in metals])

In [4]:
# Leave-one-out applicability-domain (AD) check: for each test metal, fit the
# kNN and Mahalanobis AD estimators on the training metals minus that metal,
# then evaluate them on every element in `skipatom_elements`.
# {metal: (knn_ad, distances, densities)} -- each entry is a list of ints,
# presumably one value per element of `skipatom_elements` (row of x_test).
get_ad = {}
# The evaluated element set never changes between iterations, so build its
# feature matrix once. Assumes the estimators do not modify x_test in place.
x_test = get_features(skipatom_elements)
for metal in tqdm(all_test_metals, desc="metals"):
    x_train = get_features([m for m in all_train_metals if m != metal])

    knn_ad_estimator = knnAD(x_train)
    mahalanobis_ad_estimator = MahalanobisAD(x_train)

    knn_ad = knn_ad_estimator.get_dataset_ad(x_test)
    distances, densities = mahalanobis_ad_estimator.get_dataset_ad(x_test)

    # NOTE(review): `.astype(int)` truncates toward zero; if `distances` is a
    # float array its fractional part is discarded -- confirm this is intended.
    get_ad[metal] = (
        knn_ad.astype(int).tolist(),
        distances.astype(int).tolist(),
        densities.astype(int).tolist(),
    )

metals:   0%|          | 0/27 [00:00<?, ?it/s]

In [5]:
widget = PTableWidget(states=2, selected_colors=("red", "green"), width="20px")
metal_dropdown = Dropdown(options=all_test_metals, description='Metal:')


def change_metal(value):
    """Redraw the periodic-table widget for the metal picked in the dropdown.

    Sets each element of `skipatom_elements` to the selection state given by
    the held-out metal's `densities` values. Disables elements that are not in
    `skipatom_elements`. Highlights the training metals (all except the
    held-out one) with italic/bold/underlined display names.

    Parameters
    ----------
    value : dict
        Change notification passed by ``Dropdown.observe``. Only its ``'new'``
        key, the selected metal symbol, is read.
    """
    metal = value['new']
    # Only the densities are visualised; the kNN flags and distances are unused here.
    _, _, densities = get_ad[metal]
    # `widget` is only mutated (attributes assigned), never rebound, so no
    # `global` declaration is needed.
    widget.selected_elements = dict(zip(skipatom_elements, densities))
    widget.disabled_elements = [e for e in all_elements if e not in skipatom_elements]
    widget.display_names_replacements = {
        e: f"<i><b><u>{e}</u></b></i>" for e in all_train_metals if e != metal
    }


# Draw the initial state, then redraw whenever the dropdown value changes.
change_metal({"new": metal_dropdown.value})
metal_dropdown.observe(change_metal, names='value')

display(metal_dropdown)
display(widget)

Dropdown(description='Metal:', options=('Sc', 'Y', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy',…

PTableWidget(allElements=['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si', 'P', '…