# Benchmark
#### Author: JP Melo

### Imports

In [None]:
from derpinns.nn import *
from derpinns.utils import *
from derpinns.trainer import *
import time
import torch
import kfac
import json

## Parameters

In [None]:
# Seed the NumPy and torch RNGs so repeated runs draw the same numbers
np.random.seed(0)
torch.manual_seed(0)

# Global parameters shared by every benchmark configuration
dtype = torch.float32
device = torch.device("cpu")
sampler = "pseudo"
# Network shape identifiers swept by the benchmark: "10x1", "10x3", "32x1", "32x3"
nn_shapes = [f"{a}x{b}" for a in (10, 32) for b in (1, 3)]
# Asset counts swept by the benchmark: 0 through 9
# NOTE(review): the sweep starts at 0 assets -- confirm that is intended
n_assets = [*range(10)]

## Training

We use a multi-stage training process: the first stage is done with Adam and the second stage with SSBroyden, as these methods achieve the best performance.

In [None]:
# Reuse the results stored in bench.json when the file is present; otherwise
# start from an empty dict and write it out so the file exists from now on.
try:
    src = open("bench.json", "r")
except FileNotFoundError:
    bench = {}
    with open("bench.json", "w") as dst:
        json.dump(bench, dst)
else:
    with src:
        bench = json.load(src)

In [None]:
# Benchmark sweep: for every network shape and asset count, train a model in
# two stages (Adam, then SSBroyden), time each stage, save the weights, compare
# the model against Monte Carlo prices and store the results in bench.json.
for nn_shape in nn_shapes:
    for assets in n_assets:
        params = OptionParameters(
            n_assets=assets,
            tau=1.0,
            sigma=np.array([0.2] * assets),
            # Ones on the diagonal, 0.25 everywhere else
            rho=np.eye(assets) + 0.25 * (np.ones((assets, assets)) - np.eye(assets)),
            r=0.05,
            strike=100,
            payoff=payoff
        )

        # Build the net to be used. Use the global dtype (not a hard-coded
        # torch.float32) so the model stays in sync with the dataset, the
        # closure and the trainer, which all receive the same `dtype`.
        model = build_nn(
            nn_shape=nn_shape,
            input_dim=assets,
            dtype=dtype
        ).apply(weights_init).to(device)
        model.train()

        # ---- Stage 1: Adam -------------------------------------------------
        # Set the training parameters
        batch_size = 100
        total_iter = 200
        boundary_samples = 20_000
        interior_samples = boundary_samples*assets*2
        initial_samples = boundary_samples*assets*2

        # Create the dataset to train over
        dataset = SampledDataset(
            params, interior_samples, initial_samples, boundary_samples, sampler, dtype, device, seed=0)

        # Set optimizer and preconditioner
        # 1e-2 is big enough to reach a reasonable min in few steps
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, amsgrad=True)
        preconditioner = kfac.preconditioner.KFACPreconditioner(model)

        # Set the training function (mini-batches, reshuffled)
        closure = (
            DimlessBS()
            .with_dataset(dataset, loader_opts={'batch_size': batch_size, "shuffle": True, "pin_memory": True})
            .with_model(model)
            .with_device(device)
            .with_dtype(dtype)
        )

        # Parenthesised chain instead of backslash continuations: the original
        # ended with a dangling backslash that only parsed because the next
        # line happened to be blank.
        trainer = (
            PINNTrainer()
            .with_optimizer(optimizer)
            .with_device(device)
            .with_dtype(dtype)
            .with_training_step(closure)
            .with_preconditioner(preconditioner)
            .with_epochs(total_iter)
        )

        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start = time.perf_counter()
        trainer.train()
        adam_time = time.perf_counter() - start

        # ---- Stage 2: SSBroyden --------------------------------------------
        boundary_samples = 500
        interior_samples = boundary_samples*assets*2
        initial_samples = boundary_samples*assets*2

        # We create new samples
        dataset = SampledDataset(
            params, interior_samples, initial_samples, boundary_samples, sampler, dtype, device, seed=0)

        optimizer = SSBroyden(
            model.parameters(),
            max_eval=1_000,
        )
        batch_size = len(dataset)  # we use all samples

        closure = closure.with_dataset(
            dataset, loader_opts={'batch_size': batch_size, "shuffle": False, "pin_memory": True})

        # NOTE(review): only the optimizer and training step are replaced here;
        # whether the trainer still applies the K-FAC preconditioner from
        # stage 1 depends on PINNTrainer -- confirm this is intended.
        trainer = trainer.with_optimizer(optimizer).with_training_step(closure)

        start = time.perf_counter()
        trainer.train()
        ssbroyden_time = time.perf_counter() - start

        state = closure.get_state()
        # We save the model
        torch.save(model.state_dict(), f"model_{nn_shape}_{assets}.pt")

        l2_err = compare_with_mc(model, params, n_prices=200,
                                 n_simulations=10_000, dtype=dtype, device=device, seed=42)['l2_rel_error']
        print("L2 Error (%): ", l2_err*100)

        bench[f"{nn_shape}_{assets}"] = {
            "adam_time": adam_time,
            "ssbroyden_time": ssbroyden_time,
            "state": state,
            "l2_err": l2_err,
        }

        # Rewrite the results file after every configuration so that finished
        # results are on disk even if a later configuration fails.
        with open("bench.json", "w") as f:
            json.dump(bench, f)