# Creating .json objects for VIZ-plot

This notebook shows you how to create a .json object that can be used for creating a VIZ-plot. Here, we use the Fish toxicity dataset and a Gaussian Process Regressor, but of course you can do this for any dataset and model.

In [None]:
# stopping sklearn warnings
import warnings


def warn(*args, **kwargs):
    """Accept any arguments and do nothing; a silent stand-in for ``warnings.warn``."""
    return None


# Rebind the module attribute so that code calling ``warnings.warn(...)`` after
# this cell has run reaches the no-op above instead of emitting a warning.
warnings.warn = warn

In [None]:
import sys
sys.path.append('../')

import sklearn as sk
import matplotlib.pyplot as plt
import numpy as np

import json

from math import sqrt

%matplotlib inline

from sklearn.gaussian_process import GaussianProcessRegressor

from sklearn.metrics import mean_squared_error, r2_score

from src.load_data import DataLoader
from src.retro_score import RetroScore, run_retro_score
from src.evaluation import rs_at_threshold_plot, overlapping_points
from src.dimensionality_reduction import get_activations
from src.visualization import sample_new_point, get_range, sort_axes, make_dict_df#, dim_reduction, pickle_dict, take_sample

# Seed NumPy's global RNG so the np.random draws made later in this notebook
# give the same values on every fresh Restart & Run All.
np.random.seed(42)

#from train_model import get_data, train_reg_model
#from select_data import select_features, get_points_by_index

### Load data

In [None]:
# DataLoader is the project's data helper (src.load_data); the same object is used
# further down for the variable names, the scalers and for un-scaling values.
data = DataLoader()

# select one dataset (comment others out)
# here: the Fish toxicity data mentioned in the introduction
data.toxicfish()

# randomize the order of the data
data.randomize_order()

# split into train and test
# test_size=0.2 is presumably the fraction held out for testing; random=False presumably
# skips a second shuffle because the order was randomized above -- TODO confirm in src.load_data
data.split_train_test(test_size=0.2, random=False)

# scale features
# NOTE(review): the arguments look like the target range [-1, 1]; whether the target
# is scaled as well is not visible from this cell -- confirm in src.load_data
data.scale_min_max(-1,1)

# unpack the four arrays of the split for use in the following cells
X_train, X_test, y_train, y_test = data.get_split_data()

### Train model

In [None]:
# fit a Gaussian Process regressor with default settings on the training split
reg = GaussianProcessRegressor()
reg.fit(X_train, y_train)

# predictions for both splits, each reshaped to a single column
y_pred = reg.predict(X_test).reshape(-1, 1)
y_train_pred = reg.predict(X_train).reshape(-1, 1)

# report RMSE and R^2, held-out split first
rmse_test = sqrt(mean_squared_error(y_test, y_pred))
r2_test = r2_score(y_test, y_pred)
print(f"{rmse_test} - RMSE test")
print(f"{r2_test} - r2 test")
print()

# same two metrics on the data the model was fitted on
rmse_train = sqrt(mean_squared_error(y_train, y_train_pred))
r2_train = r2_score(y_train, y_train_pred)
print(f"{rmse_train} - RMSE train")
print(f"{r2_train} - r2 train")

### Sample point from test set

Because we can only show one prediction in a VIZ-plot at a time, we must sample an instance.

In [None]:
# index for the selected point
# (presumably a row position within the test split -- confirm in src.visualization)
i=163

# pick the instance to explain from the test split
# NOTE(review): X is passed straight to reg.predict below, so it is presumably
# returned as a 2-D, single-row array -- confirm in src.visualization
X, y = sample_new_point(X_test, y_test, i=i)

# NOTE: y_pred is rebound here -- from this cell on it holds the prediction for the
# sampled instance only, no longer the predictions for the whole test set
y_pred = reg.predict(X).reshape(-1,1)

### Calculate the RETRO-score

We calculate the RETRO-score for the new instance, and obtain the instance and its neighbors.

In [None]:
# calculate RETRO score
# k is presumably the number of nearest neighbours considered per prediction
# -- confirm in src.retro_score
rs = RetroScore(k=5)

# run_retro_score receives the training data and training predictions together with
# the sampled instance X and its prediction y_pred, and returns four values:
#   retro_score      -- the score(s); element [0] is what gets exported further down
#   retro_score_unn  -- presumably an unnormalized variant; not used in the cells shown here
#   nbs_x, nbs_y     -- per the notebook text, the neighbours of the instance (features / targets)
retro_score, retro_score_unn, nbs_x, nbs_y = run_retro_score(rs, X_train, y_train, X, y_pred, y_train_pred)

### Create .json file

Here, we prepare the prediction and its neighbors and save them as a JSON file.

In [None]:
# reshape neighbors: targets become a flat 1-D array, features get one row per
# neighbor. The row count is taken from the neighbor targets instead of being
# hard-coded, so this cell keeps working when RetroScore is created with a
# different k.
y_nbs = nbs_y.reshape(-1)
X_nbs = nbs_x.reshape(y_nbs.shape[0], -1)

# names of the feature columns and of the target, as stored on the data loader
X_vars = list(data.vars_X)
y_var = data.var_y

# scale back data to original values (instead of normalized)
X_nbs_unscaled, y_nbs_unscaled = data.unscale(X_nbs, y_nbs)
X_unscaled, y_pred_unscaled = data.unscale(X, y_pred)

# get range of data (used to set axis length)
X_min, X_max = get_range(data.scalerX, X_unscaled)
y_min, y_max = get_range(data.scalerY, y_pred_unscaled)

# sort axes based on random sample from train data
# (row indices are drawn with replacement from NumPy's global RNG)
n_axis_sample = 70
sample_idx = np.random.randint(X_train.shape[0], size=n_axis_sample)
sample_X = X_train[sample_idx, :]
# pass a copy of the names so X_vars stays untouched whatever sort_axes does with its argument
sorted_axes = sort_axes(sample_X, list(X_vars))

# place points in dictionary of appropriate format
points = make_dict_df(X_unscaled, y_pred_unscaled,
                      X_nbs_unscaled, y_nbs_unscaled,
                      X_vars, y_var)

# place ranges in dictionary of appropriate format
# (minima and maxima are reshaped to single-row 2-D arrays before being passed on)
ranges = make_dict_df(X_min.reshape(1, -1), y_min.reshape(1, 1),
                      X_max.reshape(1, -1), y_max.reshape(1, 1),
                      X_vars, y_var, concat_data=False)

# all data in dictionary to pass to JavaScript
data_set = {'points': points,
            'ranges': ranges,
            'retro_score': retro_score[0],
            'target': y_var,
            'sorted_axes': sorted_axes}

# write the dictionary to vizplot.json in the current working directory
with open('vizplot.json', 'w') as outfile:
    json.dump(data_set, outfile)