# Run the search for dispersants using PyPAL

In [3]:
from pypal import PALCoregionalized
from pypal.models.gpr import build_coregionalized_model
from pypal.pal.utils import get_maxmin_samples

from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
import pandas as pd
import os
import numpy as np 

# Relative path to the directory that holds all input CSV files.
DATADIR = "../data/"

# NOTE(review): TARGETS is not referenced in the code visible here, and it does
# not match the objective columns that load_data assembles (Rg, -deltaGmin,
# deltaGmax). Confirm whether it is stale before relying on it.
TARGETS = ["deltaGmin", "A2_normalized"]

# Columns selected from the feature CSV, in this exact order.
FEATURES = [
    "num_[W]",
    "max_[W]",
    "num_[Tr]",
    "max_[Tr]",
    "num_[Ta]",
    "max_[Ta]",
    "num_[R]",
    "max_[R]",
    "[W]",
    "[Tr]",
    "[Ta]",
    "[R]",
    "rel_shannon",
    "length",
]

In [4]:
def load_data(n_samples, label_scaling: bool = False, datadir=None):
    """Load candidate features and objective values as NumPy arrays for PAL.

    Reads the feature table and the objective columns from CSV files in
    ``datadir``. It then drops zero-variance feature columns, standardizes
    the remaining features and selects ``n_samples`` seed points with
    ``get_maxmin_samples``.

    Parameters
    ----------
    n_samples : int
        Number of seed points requested from ``get_maxmin_samples``.
    label_scaling : bool, optional
        If True, standardize every objective column with ``StandardScaler``.
    datadir : str, optional
        Directory containing the CSV files. Defaults to the module-level
        ``DATADIR``, looked up at call time.

    Returns
    -------
    X : np.ndarray
        Standardized feature matrix, one row per row of the feature CSV.
    y : np.ndarray
        Objective matrix with three columns, in order: ``Rg``,
        ``-deltaGmin`` and ``deltaGmax``. Standardized if ``label_scaling``.
    greedy_indices
        Seed-point indices as returned by ``get_maxmin_samples``.

    Raises
    ------
    ValueError
        If the input files do not all have the same number of rows.
    """
    if datadir is None:
        # Resolve at call time so a reassigned module-level DATADIR is honored.
        datadir = DATADIR

    def _read(filename):
        """Read a single CSV file from ``datadir`` into a DataFrame."""
        return pd.read_csv(os.path.join(datadir, filename))

    features = _read('new_features_full_random.csv')[FEATURES].values
    # Read once; deltaGmax comes from this file and its row count is checked below.
    virial = _read('b1-b21_random_virial_large_new.csv')
    # NOTE(review): deltaGmin is negated, presumably so that every objective is
    # to be maximized by PAL. Confirm.
    gibbs = _read('b1-b21_random_deltaG.csv')['deltaGmin'].values * (-1)
    gibbs_max = virial['deltaGmax'].values
    rg = _read('rg_results.csv')['Rg'].values

    # One column per objective: Rg, -deltaGmin, deltaGmax.
    y = np.hstack([rg.reshape(-1, 1), gibbs.reshape(-1, 1), gibbs_max.reshape(-1, 1)])

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not len(features) == len(virial) == len(gibbs) == len(y):
        raise ValueError(
            'Input files have mismatched row counts: '
            f'features={len(features)}, virial={len(virial)}, '
            f'deltaG={len(gibbs)}, Rg={len(rg)}'
        )

    # Drop constant (zero-variance) feature columns, then standardize the rest.
    X = VarianceThreshold().fit_transform(features)
    X = StandardScaler().fit_transform(X)

    if label_scaling:
        y = StandardScaler().fit_transform(y)

    # Only the seed indices are needed; the sampled X/y subsets are discarded.
    _, _, greedy_indices = get_maxmin_samples(X, y, n_samples)

    return X, y, greedy_indices

In [5]:
# How many max-min (diverse) points seed PAL's initial training set.
n_samples: int = 100

In [6]:
X, y, indices = load_data(n_samples)

# Build the coregionalized GP only from the seed points: labels of points PAL
# has not sampled yet should not be used to construct the surrogate model.
# NOTE(review): assumes PALCoregionalized refits the model on the data passed
# to update_train_set; confirm against the PyPAL version in use.
m = [build_coregionalized_model(X[indices], y[indices])]

In [7]:
# Per-step snapshots of PAL's classification, appended to by the search loop.
pareto_optimal, discarded = [], []

In [15]:
# One objective per column of y (Rg, -deltaGmin, deltaGmax as built by load_data).
# Derive the count from y so it cannot drift from the data.
n_objectives = y.shape[1]
palinstance = PALCoregionalized(
    X,
    m,
    n_objectives,
    beta_scale=1 / 6,
    # Same tolerance (0.05) for every objective.
    epsilon=[0.05] * n_objectives,
)



In [16]:
# Seed PAL with the labels of the max-min selected points.
seed_labels = y[indices]
palinstance.update_train_set(indices, seed_labels)

In [None]:
# Run PAL until no candidate is left unclassified, recording the Pareto-optimal
# and discarded index sets after every step.
while np.any(palinstance.unclassified):
    idx = palinstance.run_one_step()
    pareto_optimal.append(palinstance.pareto_optimal_indices)
    discarded.append(palinstance.discarded_indices)
    if idx is None:
        # Nothing was sampled, so the training set would not grow and further
        # steps could repeat indefinitely; stop the search here.
        break
    # Add the precomputed label of the sampled point; the slice keeps y 2-D.
    palinstance.update_train_set(np.array([idx]), y[idx : idx + 1, :])
    



Optimization restart 1/20, f = -432.85667668355785
Optimization restart 2/20, f = -432.1124650375778
Optimization restart 3/20, f = -431.9781724223013
Optimization restart 4/20, f = -414.86786607593683
Optimization restart 5/20, f = -430.58178219461905




Optimization restart 6/20, f = -431.5053118541259
Optimization restart 7/20, f = -432.2799417444196




Optimization restart 8/20, f = -431.25490093493437
Optimization restart 9/20, f = -431.96362733989247
Optimization restart 10/20, f = -432.4582432749563
Optimization restart 11/20, f = -432.3338536415798
Optimization restart 12/20, f = -430.91802818821895
Optimization restart 13/20, f = -431.77048622140387
Optimization restart 14/20, f = -430.9032897564181
Optimization restart 15/20, f = -431.4934285883095
Optimization restart 16/20, f = -432.1556886576966
Optimization restart 17/20, f = -430.8031848468037
Optimization restart 18/20, f = -429.6846573868791
Optimization restart 19/20, f = -432.4438012139982
Optimization restart 20/20, f = -432.3829362284903
Optimization restart 1/20, f = -445.31930969878454
Optimization restart 2/20, f = -444.203033908652
Optimization restart 3/20, f = -444.5686632137766
Optimization restart 4/20, f = -444.47521654334105
Optimization restart 5/20, f = -402.67356292555655
Optimization restart 6/20, f = -443.86567630671453
Optimization restart 7/20, f = -



Optimization restart 12/20, f = -431.91061890099127
Optimization restart 13/20, f = -445.9064619375821
Optimization restart 14/20, f = -444.59120471695803
Optimization restart 15/20, f = -404.23094750968926
Optimization restart 16/20, f = -405.4376568535362
Optimization restart 17/20, f = -442.3270097097816
Optimization restart 18/20, f = -443.7501822772981
Optimization restart 19/20, f = -444.28709836466066
Optimization restart 20/20, f = -395.5449359626949
Optimization restart 1/20, f = -455.6943462323014
Optimization restart 2/20, f = -454.3807804576156
Optimization restart 3/20, f = -452.18931581700497
Optimization restart 4/20, f = -452.54632790979264
Optimization restart 5/20, f = -408.15169768267236
Optimization restart 6/20, f = -454.4086646654191
Optimization restart 7/20, f = -452.8610178355792
Optimization restart 8/20, f = -433.7070084683705
Optimization restart 9/20, f = -453.936270309805
Optimization restart 10/20, f = -407.78108156884196
Optimization restart 11/20, f = -



Optimization restart 15/20, f = -995.0358983395002
Optimization restart 16/20, f = -993.5916449412346
Optimization restart 17/20, f = -974.1579406875023
Optimization restart 18/20, f = -999.6480281326417
Optimization restart 19/20, f = -987.8139846387792
Optimization restart 20/20, f = -988.905233237075
Optimization restart 1/20, f = -1460.8799751314214
Optimization restart 2/20, f = -1457.3842480123656
Optimization restart 3/20, f = -1457.470467523227
Optimization restart 4/20, f = -1457.7045747178495
Optimization restart 5/20, f = -1454.1892942012782
Optimization restart 6/20, f = -1455.3081981286207
Optimization restart 7/20, f = -1446.6898143484257
Optimization restart 8/20, f = -1457.9722779585627
Optimization restart 9/20, f = -1405.6289559547995
Optimization restart 10/20, f = -1457.6571327853703
Optimization restart 11/20, f = -1424.475961780121
Optimization restart 12/20, f = -1455.532172429555
Optimization restart 13/20, f = -1455.554922061876
Optimization restart 14/20, f = 