## GANITE (PyTorch): Training and Evaluation

This notebook shows how to train and evaluate the PyTorch version of GANITE.

The GANITE implementation used here is adapted in the local `ite` library.

Before running the notebook, make sure that all the dependencies are installed.
```
pip install -r requirements.txt
pip install .
```

### Setup

First, we import all the dependencies necessary for the task.

In [None]:
# Double check that we are using the correct interpreter.
import sys
print(sys.executable)
    
# Import the dependencies.
import ite.algs.ganite_torch.model as alg
import ite.datasets as ds
import ite.utils.numpy as utils

from matplotlib import pyplot as plt
import torch

### Load the dataset

We load the `twins` dataset and use 80% of it for training.

In [None]:
train_ratio = 0.8
 
Train_X, Train_T, Train_Y, Opt_Train_Y, Test_X, Test_Y = ds.load("twins", train_ratio)
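
To make the role of `train_ratio` concrete, here is a minimal sketch of a ratio-based split on synthetic data. The helper `split_by_ratio` and the toy arrays are hypothetical and exist only for illustration; `ds.load` may shuffle and split the data differently.

```python
import numpy as np


def split_by_ratio(X, T, Y, train_ratio, seed=0):
    """Hypothetical helper: shuffle the rows, then split them into train and test parts."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_train = int(train_ratio * len(X))
    train_idx, test_idx = idx[:n_train], idx[n_train:]
    train = (X[train_idx], T[train_idx], Y[train_idx])
    test = (X[test_idx], T[test_idx], Y[test_idx])
    return train, test


# Synthetic stand-in data (not the twins dataset):
# 10 units, 3 features, a binary treatment and one observed outcome.
X = np.arange(30, dtype=float).reshape(10, 3)
T = np.array([0, 1] * 5)
Y = np.arange(10, dtype=float)

(train_X, train_T, train_Y), (test_X, test_T, test_Y) = split_by_ratio(X, T, Y, 0.8)
print(train_X.shape, test_X.shape)
```

With a ratio of 0.8, eight of the ten toy rows land in the training part and the remaining two in the test part.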

### Load the model

In [None]:
dim = len(Train_X[0])
dim_hidden = 8
dim_outcome = Test_Y.shape[1]
 
model = alg.GaniteTorch(
    dim, # number of features
    dim_hidden, # size of the hidden layers
    dim_outcome, # size of the output
    num_iterations=10000, # number of training iterations
    alpha=2, # alpha hyperparameter, used for the Generator block loss
    beta=2, # beta hyperparameter, used for the ITE block loss
    minibatch_size=128, # data batch size
    num_discr_iterations=10, # number of iterations executed by the discriminator.
)

assert model is not None

### Train

In [None]:
metrics = model.fit(Train_X, Train_T, Train_Y, Test_X, Test_Y)

### Plot training metrics on the test set

In [None]:
plt.plot(metrics["gen_block"]["D_loss"], label="Cf Discriminator loss")
plt.plot(metrics["gen_block"]["G_loss"], label="Cf Generator loss")
plt.legend()
plt.show()

plt.plot(metrics["ite_block"]["I_loss"], label="ITE loss")
plt.plot(metrics["ite_block"]["Loss_sqrt_PEHE"], label="Loss_sqrt_PEHE")
plt.plot(metrics["ite_block"]["Loss_ATE"], label="Loss_ATE")
plt.legend()
plt.show()
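
For readers unfamiliar with the `Loss_ATE` curve, the sketch below shows one common way to measure the error on the average treatment effect (ATE). The helper `ate_error` is hypothetical, and it assumes that column 0 holds the control outcome and column 1 the treated outcome; the `ite` library may compute this metric differently.

```python
import numpy as np


def ate_error(y, hat_y):
    """Hypothetical helper: absolute gap between the true and the estimated ATE.

    Assumes column 0 is the control outcome and column 1 the treated outcome.
    """
    true_ate = np.mean(y[:, 1] - y[:, 0])
    est_ate = np.mean(hat_y[:, 1] - hat_y[:, 0])
    return float(np.abs(true_ate - est_ate))


# Toy potential outcomes for four units (illustrative values only).
y_true = np.array([[0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])
y_pred = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]])

print(ate_error(y_true, y_pred))
```

In this toy case the true effects average to 0 while the predicted effects average to 1, so the error is 1.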

### Predict

In [None]:
hat_y = model.predict(Test_X)

print(type(hat_y), type(Test_X))
utils.sqrt_PEHE(hat_y.to_numpy(), Test_Y)
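
As a reference for the metric used above, the following sketch computes the square root of the PEHE (precision in estimation of heterogeneous effects) under its usual definition: the root mean squared difference between true and estimated individual treatment effects. The function `sqrt_pehe` is a hypothetical stand-in, not the body of `utils.sqrt_PEHE`, and it assumes that column 0 holds the control outcome and column 1 the treated outcome.

```python
import numpy as np


def sqrt_pehe(y, hat_y):
    """Hypothetical helper: root mean squared error on the individual treatment effects.

    Assumes column 0 is the control outcome and column 1 the treated outcome.
    """
    true_ite = y[:, 1] - y[:, 0]
    est_ite = hat_y[:, 1] - hat_y[:, 0]
    return float(np.sqrt(np.mean((true_ite - est_ite) ** 2)))


# Toy potential outcomes for four units (illustrative values only).
outcomes = np.array([[0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])
estimates = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]])

print(sqrt_pehe(outcomes, estimates))
```

A perfect prediction gives 0, and larger values mean the estimated per-unit effects are further from the true ones.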

### Test

We can also run inference and get the metrics directly.

In [None]:
test_metrics = model.test(Test_X, Test_Y)

test_metrics["sqrt_PEHE"]