# Test notebook

This notebook contains example code showing how to train a GAIN model, impute data with it, and inspect its training log.
## Import and train GAIN model
Import the libraries, load the dummy dataset and specify the training arguments. Then train a GAIN model with the `run` function.

In [None]:
import pandas as pd
from src.runtime import run
import numpy as np
data_path = 'data/letter_3ohe.csv'
data = pd.read_csv(data_path)

colnames = data.columns 
dct_args = {
    "batch_size": 128,
    "hint_rate": 0.9,
    "alpha": 10,
    "beta": 10,
    "epochs": 30,
    "missing_rate": 0.2,
    "OHE_features": [(13, 29), (29, 45), (45, 61)],  # letter_3ohe: (13, 29), (29, 45), (45, 61); letter_ohe: (15, 31)
    "learning_rate": 1e-4,
    "log": True,
    "debug": False,
    "rounding": True,
    "early_stop": False,
    "normalized_loss": True,
    "validation": False,
}
df = data.values 

In [None]:
imputed, model = run(dct_args, df)  # df holds data.values from the previous cell
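
The `OHE_features` ranges passed to `run` can be sanity-checked against the data. The sketch below assumes each tuple is a half-open `(start, end)` column range (suggested by the shared boundaries 29 and 45, but not confirmed by the source) and that every block should be strictly one-hot. The helper `check_ohe_blocks` is hypothetical and not part of `src.runtime`.

```python
import numpy as np

def check_ohe_blocks(values, ohe_ranges):
    """Return the (start, end) ranges whose columns are not strictly one-hot.

    Assumption: each range is half-open, i.e. it covers columns start..end-1.
    """
    values = np.asarray(values)
    problems = []
    for start, end in ohe_ranges:
        block = values[:, start:end]
        is_binary = np.isin(block, (0, 1)).all()       # only 0s and 1s
        sums_to_one = (block.sum(axis=1) == 1).all()   # exactly one 1 per row
        if not (is_binary and sums_to_one):
            problems.append((start, end))
    return problems

# Toy array: column 0 is numeric, columns 1-2 and 3-4 are one-hot blocks.
toy = np.array([[0.5, 1, 0, 0, 1],
                [0.1, 0, 1, 1, 0]])
bad_ranges = check_ohe_blocks(toy, [(1, 3), (3, 5)])
```

With the notebook's objects the call would be `check_ohe_blocks(data.values, dct_args["OHE_features"])`; an empty list means every block passed the check.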

## Imputing with GAIN

The `run` function returns the fitted GAIN object, stored here in the variable `model`. This object can be used for imputation. Below we impute the first 50 rows of the training data set and create a data frame from the result. (NOTE: this cell is purely for demonstration; the data passed in does not have missing values.)

In [None]:
df_imputed = model.impute(data.head(50).values)
pd.DataFrame(df_imputed, columns=colnames).head(5)
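
Since the demonstration data has no missing values, one way to see the imputation do real work is to hide some entries first. The sketch below is a minimal example of masking values completely at random with NumPy. The helper `mask_at_random` is hypothetical, and it is an assumption (not confirmed by the source) that `model.impute` treats `NaN` entries as missing.

```python
import numpy as np

def mask_at_random(values, missing_rate=0.2, seed=0):
    """Return a float copy of `values` with a random subset set to NaN,
    together with the boolean mask of the hidden entries."""
    rng = np.random.default_rng(seed)
    masked = np.asarray(values, dtype=float).copy()
    mask = rng.random(masked.shape) < missing_rate  # True where a value is hidden
    masked[mask] = np.nan
    return masked, mask

# Small stand-in array so the sketch runs on its own.
demo = np.arange(20, dtype=float).reshape(5, 4)
masked_demo, demo_mask = mask_at_random(demo, missing_rate=0.2)

# With the notebook's objects this might look like the following
# (assuming impute() accepts NaN as the missing-value marker):
# masked, mask = mask_at_random(data.head(50).values, dct_args["missing_rate"])
# df_imputed = model.impute(masked)
```

Keeping the mask makes it possible to compare the imputed entries with the original values afterwards.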

### Evaluators

We can also use the Evaluator class to inspect the training process of the GAIN object. First, change the path and filename to point to the training log you want to evaluate. Then initialize an `Evaluator` object with that path. You can now call the `.print_summary()` and `.make_plot(col1: list, col2: list)` methods. The former prints statistics of the training metrics; the latter quickly builds a plotly graph of the selected training metrics on two separate axes.

In [None]:
from src.evaluator import Evaluator
path = 'logs/train-run-13-59.csv'  # MODIFY: point this to your latest training log

ev = Evaluator(path)
ev.print_summary()
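
Instead of editing the path by hand, the most recent log could be looked up automatically. The sketch below assumes that logs are written to a `logs/` directory and follow the `train-run-*.csv` naming seen in the path above. Both are inferred from that single example path, and `latest_train_log` is a hypothetical helper, not part of `src.evaluator`.

```python
from pathlib import Path

def latest_train_log(log_dir="logs", pattern="train-run-*.csv"):
    """Return the most recently modified file in `log_dir` matching `pattern`."""
    candidates = sorted(Path(log_dir).glob(pattern),
                        key=lambda p: p.stat().st_mtime)
    if not candidates:
        raise FileNotFoundError(f"no files matching {pattern!r} in {log_dir!r}")
    return candidates[-1]

# Possible usage in the cell above:
# path = str(latest_train_log())
# ev = Evaluator(path)
```

Sorting by modification time rather than by file name avoids relying on how the timestamp in the name is formatted.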

In [None]:
# Metrics in the first list go on one axis, metrics in the second list on the other
fig = ev.make_plot(['Generator', 'Discriminator'], ['MSE Reconstruction', 'Entropy Reconstruction'])
fig.show()