In [1]:
try:
    # always cd to the right directory in google colab
    from google.colab import drive
    drive.mount('/content/drive')
    %cd "/content/drive/MyDrive/Colab Notebooks/deep_branching_with_domain"
except:
    pass

import math
import time
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
from galerkin import DGMNet
from branch import Net
from functools import partial

try:
    from ray import tune
except:
    !pip install ray
    from ray import tune

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks/deep_branching_with_domain
Collecting ray
  Downloading ray-1.12.1-cp37-cp37m-manylinux2014_x86_64.whl (53.2 MB)
[K     |████████████████████████████████| 53.2 MB 270 kB/s 
[?25hCollecting grpcio<=1.43.0,>=1.28.1
  Downloading grpcio-1.43.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.1 MB)
[K     |████████████████████████████████| 4.1 MB 41.1 MB/s 
Collecting virtualenv
  Downloading virtualenv-20.14.1-py2.py3-none-any.whl (8.8 MB)
[K     |████████████████████████████████| 8.8 MB 35.4 MB/s 
Collecting frozenlist
  Downloading frozenlist-1.3.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (144 kB)
[K     |████████████████████████████████| 144 kB 37.7 MB/s 
Collecting aiosignal
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting distlib<1,>=0.3.1
  Downloading distlib-0.3.4-py2.py3-none-any.whl (461 kB)
[K     |█████████████████████████

# Implementation with the closed-form formula given by Borodin

In [19]:
# End points of the interval used by the helper functions in this cell.
lower_bound = -10
upper_bound = 10

def conditional_probability_to_survive(t, x, y, k_arr=range(-5, 5)):
    """
    Truncated series in (t, x, y), multiplied together along dim 0.

    With L = upper_bound - lower_bound, every k in `k_arr` contributes
        exp(((y - x)^2 - (y - x + 2kL)^2) / (2t))
      - exp(((y - x)^2 - (y + x - 2*lower_bound + 2kL)^2) / (2t)).
    The contributions are summed over k and the sum is reduced with a
    product over dim 0.

    presumably dim 0 indexes the spatial coordinates and the value is the
    probability of staying inside the interval between x and y -- TODO confirm
    """
    width = upper_bound - lower_bound
    gap = y - x
    gap_sq = gap ** 2
    two_t = 2 * t

    def series_term(k):
        # k-th pair of terms in the truncated series
        shift = 2 * k * width
        direct = torch.exp((gap_sq - (gap + shift) ** 2) / two_t)
        mirrored = torch.exp((gap_sq - (y + x - 2 * lower_bound + shift) ** 2) / two_t)
        return direct - mirrored

    # sum() starts from 0 and adds the terms in the order given by k_arr
    return sum(series_term(k) for k in k_arr).prod(dim=0)

def is_x_inside(x):
    """
    Boolean mask that is True where every entry of `x` along dim 0 lies in
    the closed interval [lower_bound, upper_bound].
    """
    not_below = x >= lower_bound
    not_above = x <= upper_bound
    return (not_below & not_above).all(dim=0)

def tune_wrapper(config, grid_d_dim_with_t, true):
    """
    Ray Tune trainable for a single trial.

    Builds `Net(**config)`, trains it, evaluates the last patch on the grid and
    reports two metrics through `tune.report`:
      * runtime        -- first entry of the last patch's record in the dict
                          returned by `train_and_eval`
      * test_max_error -- largest absolute difference between `true` and the
                          network output

    config            -- keyword arguments forwarded to `Net`
    grid_d_dim_with_t -- numpy array; cast to float32 and transposed before
                         being fed to the model
    true              -- reference values compared against the model output
    """
    torch.manual_seed(0)  # identical seed for every trial
    model = Net(**config)
    output_dict = model.train_and_eval(debug_mode=False, return_dict=True)

    last_patch = model.patches - 1
    grid_tensor = torch.tensor(
        grid_d_dim_with_t.astype(np.float32).T, device=model.device
    )
    prediction = model(grid_tensor, patch=last_patch).detach().cpu().numpy()

    tune.report(
        runtime=output_dict[f"patch_{last_patch}"][0],
        test_max_error=np.abs(true - prediction).max(),
    )

## Tuning using the heat equation

In [3]:
nu = 1  # multiplies (T - t) inside the square root in exact_example
y = 0  # centre of the window [a, b]
eps = 1e-1  # half-width of the window
a = y - eps
b = y + eps

# function definition
deriv_map = np.array([[0]])  # a single row holding one zero, shape (1, 1)
def f_example(y):
    """
    Return zeros with the shape, dtype and device of `y[0]`.

    idx 0 -> no deriv
    (presumably index 0 of `y` matches the single row of `deriv_map` -- TODO confirm)
    """
    template = y[0]
    return torch.zeros_like(template)

def phi_example(x):
    """
    Indicator of the window [a, b] applied to `x[0]`, returned as a float
    tensor (1.0 inside the window, 0.0 outside).
    """
    first_coord = x[0]
    in_window = (first_coord <= b) & (first_coord >= a)
    return in_window.float()

def exact_example(t, x, T, with_bound=False, k_arr=range(-5, 5)):
    """
    Reference values for the example, evaluated on `x[0]`.

    * t == T: the boolean indicator of a <= x[0] <= b.
    * otherwise, with sigma = sqrt(nu * (T - t)) and
      mass(mu) = cdf((b - mu) / sigma) - cdf((a - mu) / sigma):
        - with_bound=False: mass(x[0]);
        - with_bound=True: for every k in `k_arr`, add mass(x[0] - 2kL) and
          subtract mass(2*lower_bound - 2kL - x[0]),
          where L = upper_bound - lower_bound.
    """
    first_coord = x[0]
    if t == T:
        return np.logical_and(first_coord <= b, first_coord >= a)

    normal_std = math.sqrt(nu * (T - t))

    def window_mass(mu):
        # normal-cdf difference over [a, b] for mean mu and std normal_std
        return norm.cdf((b - mu) / normal_std) - norm.cdf((a - mu) / normal_std)

    if not with_bound:
        # without bound
        return window_mass(first_coord)

    # with bound
    period = 2 * (upper_bound - lower_bound)
    total = 0
    for k in k_arr:
        total = total + window_mass(first_coord - k * period)
        total = total - window_mass(2 * lower_bound - k * period - first_coord)
    return total

# Evaluation grid: 100 evenly spaced points over [x_lo, x_hi] at time t_lo.
t_lo = 0.
x_lo = lower_bound
x_hi = upper_bound
n = 0
grid = np.linspace(x_lo, x_hi, 100)
grid_d_dim = grid[np.newaxis, :]  # shape (1, 100)
# first row: the time value, second row: the spatial grid
grid_d_dim_with_t = np.vstack((t_lo * np.ones_like(grid_d_dim), grid_d_dim))

patches = 1
T = patches * 1.0
# reference values at t_lo (with the boundary correction) and at the terminal time
true = exact_example(t_lo, grid_d_dim, T, with_bound=True)
terminal = exact_example(T, grid_d_dim, T)

In [None]:
# Number of configurations sampled by Ray Tune.
runs = 500

# Keyword arguments forwarded to `Net` by `tune_wrapper`.
# First the settings that are identical for every trial ...
config = {
    "f_fun": f_example,
    "deriv_map": deriv_map,
    "phi_fun": phi_example,
    "conditional_probability_to_survive": conditional_probability_to_survive,
    "is_x_inside": is_x_inside,
    "device": device,
    "x_lo": x_lo,
    "x_hi": x_hi,
    "T": T,
    "verbose": False,
    "nu": nu,
    "branch_patches": patches,
    "branch_nb_path_per_state": 500,
    "outlier_multiplier": 50,
    "save_for_best": False,
}
# ... then the hyperparameters drawn independently for each trial.
config.update(
    branch_lr=tune.choice([1e-1, 1e-2, 1e-3]),
    epochs=tune.choice([3000, 5000, 8000]),
    layers=tune.choice([4, 5, 6]),
    neurons=tune.choice([20, 50, 100]),
    lr_gamma=tune.choice([.1, .5, .8]),
    branch_activation=tune.choice(["tanh", "relu", "softplus"]),
)

# ASHA ranks trials on `test_max_error` (lower is better).
# NOTE(review): `tune_wrapper` calls `tune.report` once per trial, so the
# scheduler only ever sees a single result per trial -- confirm that is intended.
scheduler = tune.schedulers.ASHAScheduler(
    metric="test_max_error",
    mode="min",
    max_t=10,
    grace_period=1,
    reduction_factor=2
)
# Progress table shown in the notebook, sorted by the tuned metric.
reporter = tune.JupyterNotebookReporter(
    overwrite=True,
    metric_columns=["test_max_error", "runtime"],
    max_progress_rows=runs,
    metric="test_max_error",
    mode="min",
    sort_by_metric=True,
)
result = tune.run(
    # bind the evaluation grid and the reference values to the trainable
    partial(tune_wrapper, grid_d_dim_with_t=grid_d_dim_with_t, true=true),
    # Request a GPU only when one exists, matching the CPU fallback used for
    # `device`; demanding a GPU on a CPU-only machine leaves the trials
    # unschedulable.
    resources_per_trial={"cpu": 2, "gpu": 1 if torch.cuda.is_available() else 0},
    config=config,
    num_samples=runs,
    scheduler=scheduler,
    progress_reporter=reporter,
    log_to_file=True,
)

Trial name,status,loc,branch_activation,branch_lr,epochs,layers,lr_gamma,neurons,test_max_error,runtime
tune_wrapper_63674_00111,RUNNING,172.28.0.2:9212,tanh,0.001,5000,5,0.8,20,,
tune_wrapper_63674_00112,PENDING,,relu,0.001,8000,5,0.5,50,,
tune_wrapper_63674_00113,PENDING,,softplus,0.01,5000,4,0.8,100,,
tune_wrapper_63674_00114,PENDING,,relu,0.1,3000,6,0.5,100,,
tune_wrapper_63674_00115,PENDING,,softplus,0.01,3000,5,0.5,100,,
tune_wrapper_63674_00116,PENDING,,softplus,0.01,3000,5,0.1,100,,
tune_wrapper_63674_00117,PENDING,,tanh,0.01,8000,5,0.8,50,,
tune_wrapper_63674_00118,PENDING,,tanh,0.001,8000,6,0.1,20,,
tune_wrapper_63674_00119,PENDING,,relu,0.1,3000,5,0.1,100,,
tune_wrapper_63674_00120,PENDING,,softplus,0.01,5000,4,0.5,20,,
