### Settings

In [None]:
!pip install ml_collections

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jacrev, jit, lax, random, vmap

# Standard-normal helpers (JAX, vectorised): cdf Phi, pdf phi, inverse cdf.
N, N_prime, N_inv = (
    jax.scipy.stats.norm.cdf,
    jax.scipy.stats.norm.pdf,
    jax.scipy.stats.norm.ppf,
)
# Small positive constant used as a zero-denominator guard.
eps = 1e-15


def bound(x):
    """Return ``x`` floored at 1e-15, element-wise.

    Applied to spot, strike, maturity and volatility before they enter
    logs, square roots and divisions.
    """
    floor = 1e-15
    return jnp.maximum(floor, x)


def d1_(s, k, r, sigma, tau):
    """Black-Scholes ``d1`` term.

    d1 = (ln(s/k) + (r + sigma^2/2) * tau) / (sigma * sqrt(tau)),
    with s, k, tau and sigma first floored by ``bound``.
    """
    s, k, tau, sigma = bound(s), bound(k), bound(tau), bound(sigma)
    numerator = jnp.log(s/k) + (r + sigma*sigma/2)*tau
    denominator = sigma*jnp.sqrt(tau)
    return numerator/denominator


def d2_(s, k, r, sigma, tau):
    """Black-Scholes ``d2`` term: d2 = d1 - sigma * sqrt(tau) (inputs floored by ``bound``)."""
    s, k, tau, sigma = bound(s), bound(k), bound(tau), bound(sigma)
    vol_sqrt_t = sigma*jnp.sqrt(tau)
    return d1_(s, k, r, sigma, tau) - vol_sqrt_t


def bs(s, k, tau, r, cp, sigma):
    """Black-Scholes price of a European option.

    Args:
        s: spot price.
        k: strike.
        tau: time to maturity.
        r: continuously-compounded rate (discount factor exp(-r*tau)).
        cp: +1 gives the call formula, -1 the put formula.
        sigma: Black-Scholes volatility.

    s, k, tau and sigma are floored by ``bound`` before pricing.
    """
    s, k, tau, sigma = bound(s), bound(k), bound(tau), bound(sigma)
    d1 = d1_(s, k, r, sigma, tau)
    d2 = d2_(s, k, r, sigma, tau)
    discount = jnp.exp(-r*tau)
    return cp*s*N(cp*d1) - cp*k*discount*N(cp*d2)


def black(fwd, k, tau, r, cp, sigma):
    """Black-76 price: zero-rate Black-Scholes price on the forward, discounted by exp(-r*tau)."""
    fwd, k, tau, sigma = bound(fwd), bound(k), bound(tau), bound(sigma)
    undiscounted = bs(fwd, k, tau, 0., cp, sigma)
    return undiscounted * jnp.exp(-r*tau)


def lv_var(vol, tau):
    """Total implied variance w = vol**2 * tau."""
    variance_rate = vol ** 2
    return variance_rate * tau


def lv_fwd_sqr(k, w, tau, dk_w, d2k_w, dt_w):
    """Squared local volatility from total implied variance.

    Args:
        k: log-moneyness ln(K / fwd).
        w: total implied variance sigma_BS^2 * tau (floored by ``bound``).
        tau: maturity; not used by the body.
        dk_w, d2k_w: first and second derivatives of w with respect to k.
        dt_w: derivative of w with respect to maturity.

    Returns:
        dt_w / (1 - k/w*dk_w + 1/4*(-1/4 - 1/w + k^2/w^2)*dk_w^2 + 1/2*d2k_w)
    """
    # Reference for the (k, T)-space form:
    # https://quant.stackexchange.com/questions/16343/in-dupires-paper-why-is-s-t-t-in-the-k-t-space
    w = bound(w)
    slope_term = -k/w * dk_w
    convexity_term = 0.25 * (-0.25 - 1/w + (k**2)/(w**2)) * (dk_w **2)
    A = slope_term + convexity_term
    denominator = 1.0 + A + 0.5*d2k_w
    return dt_w/denominator


def lv_fwd_pde(lv_fwd_sqr, k, d2k_v, dt_v):
    """Residual dt_v - 1/2 * lv_fwd_sqr * k^2 * d2k_v."""
    diffusion = 0.5 * lv_fwd_sqr * (k **2) * d2k_v
    return dt_v - diffusion


def lv_sqr(s, k, r, v, tau, dk_v, d2k_v, dtau_v, zero_guard=1e-15):
    """Dupire local variance expressed through the implied-volatility surface.

    sigma_loc^2 = A / D with
        A = v^2 + 2 v tau (dv/dtau + r k dv/dk)
        D = (1 - k y / v * dv/dk)^2
            + k v tau (dv/dk - 1/4 k v tau (dv/dk)^2 + k d2v/dk2)
        y = ln(k / (s exp(r tau)))

    Args:
        s: spot.
        k: strike.
        r: risk-free rate.
        v: implied volatility at (k, tau).
        tau: maturity.
        dk_v, d2k_v: first and second strike derivatives of v.
        dtau_v: maturity derivative of v.
        zero_guard: added to D wherever D is exactly 0 so the quotient stays
            finite. Kept local (rather than reading a module-level constant)
            so later cells cannot silently change it.

    Returns:
        The local variance A / D.
    """
    A = v**2 + 2.0*v*tau*(dtau_v + r*k*dk_v)
    y = jnp.log(k/(jnp.exp(r*tau)*s))
    B = (1 - k*y/v*dk_v)**2
    C = k*v*tau*(dk_v - 0.25*k*v*tau*(dk_v**2) + k*d2k_v)
    D = B + C
    D = D + (D == 0.)*zero_guard
    # Divide by the guarded denominator, not the raw B + C.
    return A/D


def bs_vega(s, k, tau, r, sigma):
    """Black-Scholes vega: s * sqrt(tau) * phi(d1).

    Only ``d1_`` floors its inputs; s and tau in the prefactor are used as given.
    """
    d1 = d1_(s, k, r, sigma, tau)
    return s*jnp.sqrt(tau)*N_prime(d1)


def bs_iv(C, s, k, tau, cp, r=0.0,
            tol=1e-15, tol_vega=1e-15, ini=0.5,
            thr=2.0, max_it=20):
    """Implied volatility via a fixed number of vectorised Newton steps.

    Args:
        C: observed option prices.
        s, k, tau, cp, r: as in ``bs``.
        tol: entries with |model price - C| < tol stop updating.
        tol_vega: entries with vega < tol_vega stop updating, which avoids
            dividing by a vanishing vega.
        ini: initial volatility guess.
        thr: upper clip for the returned volatility.
        max_it: number of Newton iterations.

    Returns:
        Volatilities clipped to [0, thr].
    """
    sigma = ini * jnp.ones_like(C)
    for _ in range(max_it):
        diff = bs(s, k, tau, r, cp, sigma) - C
        vega = bs_vega(s, k, tau, r, sigma)

        end = jnp.logical_or(abs(diff) < tol, vega < tol_vega)
        # Select rather than multiply by a mask: with vega == 0 the Newton
        # step is inf, and inf * 0 would turn a frozen entry into nan.
        sigma = jnp.where(end, sigma, sigma - diff / vega)
    return jnp.minimum(jnp.maximum(sigma, 0.), thr)


def derivatives(fn, x):
    """Per-row partial derivatives of a network's first output.

    ``fn`` maps a single point (x1, x2, ...) to an array whose entry 0 is used.
    Vectorised over the rows of ``x``; returns (df/dx1, d2f/dx1^2, df/dx2).
    """
    def scalar_out(point):
        return fn(point)[0]

    def d_first(point):
        return grad(scalar_out)(point)[0]

    first = vmap(grad(scalar_out), 0)(x)
    # Gradient of df/dx1 (not the full Hessian); column 0 is d2f/dx1^2.
    second = vmap(grad(d_first), 0)(x)
    return first.T[0], second.T[0], first.T[1]


def call_derivatives(fn, x):
    """Derivatives of the call price implied by the volatility network ``fn``.

    Each row of ``x`` is (K, T); the price is bs(s_0, K, T, r, 1, fn(row))
    using the module-level ``s_0`` and ``r``. Returns (dC/dK, d2C/dK2, dC/dT).
    """
    def call_price(point):
        return bs(s_0, point.T[0], point.T[1], r, 1, fn(point).flatten())

    return derivatives(call_price, x)


def error(fn, data):
    """Point-wise and mean loss terms for the volatility network ``fn``.

    Args:
        fn: maps (K, T) rows to implied volatilities.
        data: (x_train, y_train, x_mesh): training (K, T) points, their
            observed call premiums, and the mesh on which the PDE and
            no-arbitrage terms are evaluated.

    Returns:
        (errors, means): two dicts keyed 'e_acc', 'e_pde', 'e_arb_dK',
        'e_arb_d2K', 'e_arb_dT' holding per-point squared errors and their means.
    """
    x_train, y_train, x_mesh = data
    K, T = x_mesh[:,0], x_mesh[:,1]

    # Derivatives of the implied vol and of the implied call price.
    dK_V, d2K_V, dT_V = derivatives(fn, x_mesh)
    dK_C, d2K_C, dT_C = call_derivatives(fn, x_mesh)
    V = fn(x_mesh).flatten()
    LV_sqr = lv_sqr(s_0, K, r, V, T, dK_V, d2K_V, dT_V)

    # Fit to observed premiums.
    pred = bs(s_0, x_train.T[0], x_train.T[1], r, 1, fn(x_train).flatten())
    e_acc = (pred - y_train.ravel()) ** 2
    # Dupire residual: dC/dT + r K dC/dK - 1/2 sigma_loc^2 K^2 d2C/dK2.
    e_pde = (dT_C + r*K*dK_C - 0.5*LV_sqr*(K**2)*d2K_C) ** 2
    # Penalise violations of -exp(-rT) <= dC/dK <= 0, d2C/dK2 >= 0, dC/dT >= 0.
    e_arb = {
        'dK': jnp.where(dK_C > 0, dK_C**2, jnp.where(dK_C < -jnp.exp(-r*T), dK_C**2, 0)),
        'd2K': jnp.where(d2K_C < 0, d2K_C**2, 0),
        'dT': jnp.where(dT_C < 0, dT_C**2, 0)
    }

    # Build the result once; the means follow the same key order.
    errors = {'e_acc': e_acc, 'e_pde': e_pde,
              **{'e_arb_' + name: term for name, term in e_arb.items()}}
    means = {name: jnp.mean(term) for name, term in errors.items()}
    return errors, means


def adj(loss, lw, m):
    """Weighted loss lw * mean(m * loss), with ``m`` the per-point weights."""
    weighted = m * loss
    return lw * jnp.mean(weighted)


def make_loss_lb(components):
    """Build a loss function returning one weighted loss per component name.

    The returned ``loss_fn(fn, data, l_ws, params_sa)`` evaluates ``error`` and
    yields ({name: adj(err[name], l_ws[name], params_sa[name])}, metrics).
    """
    def loss_fn(fn, data, l_ws, params_sa):
        per_point, metrics = error(fn, data)
        weighted = {}
        for name in components:
            weighted[name] = adj(per_point[name], l_ws[name], params_sa[name])
        return weighted, metrics
    return loss_fn


# Per-component weighted losses for each model family.
loss_fn_lb = dict(
    MLP=make_loss_lb(['e_acc']),
    PINN=make_loss_lb(['e_acc', 'e_pde']),
    DCPINN=make_loss_lb(['e_acc', 'e_pde', 'e_arb_dK', 'e_arb_d2K', 'e_arb_dT']),
)


def make_loss_fn(components):
    """Build a scalar loss: the sum of the terms from ``loss_fn_lb[components]``.

    ``components`` is a key of ``loss_fn_lb`` ("MLP", "PINN" or "DCPINN").
    """
    def loss_fn(fn, data, l_ws, params_sa):
        terms, metrics = loss_fn_lb[components](fn, data, l_ws, params_sa)
        total = sum(terms.values())
        return total, metrics
    return loss_fn


# Scalar training loss for each model family.
loss_fn = {name: make_loss_fn(name) for name in ("MLP", "PINN", "DCPINN")}

In [None]:
import optax

# Optimizer configuration: Adam with an exponentially decaying learning rate.
optimizer = "Adam"  # informational; ``tx`` below is always optax.adam
beta1 = 0.9
beta2 = 0.999
# Named ``adam_eps`` so it does not rebind the module-level ``eps`` defined in
# the first cell together with the pricing helpers.
adam_eps = 1e-8
learning_rate = 1e-3
decay_rate = 0.9
decay_steps = 2000
grad_accum_steps = 0  # currently unused

lr = optax.exponential_decay(
    init_value=learning_rate,
    transition_steps=decay_steps,
    decay_rate=decay_rate,
)
tx = optax.adam(
    learning_rate=lr, b1=beta1, b2=beta2, eps=adam_eps
)

In [None]:
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import jax.numpy as jnp
from jax import jit, random as jran
from jax.nn.initializers import glorot_normal, normal, zeros
import ml_collections
from flax import linen as nn
from flax.core.frozen_dict import freeze


# Supported activation names mapped to element-wise JAX functions.
activation_fn = dict(tanh=jnp.tanh, sin=jnp.sin)


def _get_activation(str):
    """Look up an activation function by name.

    Raises:
        NotImplementedError: if the name is not a key of ``activation_fn``.
    """
    if str not in activation_fn:
        raise NotImplementedError(f"Activation {str} not supported yet!")
    return activation_fn[str]


def _weight_fact(init_fn, mean, stddev):
    """Initializer for a weight-factorised kernel W = g * v.

    Draws W with ``init_fn`` and a per-output-column scale
    g = exp(mean + N(0, stddev^2)), then returns (g, W / g).
    """
    def init(key, shape):
        w_key, g_key = jran.split(key)
        full_kernel = init_fn(w_key, shape)
        log_scale = mean + normal(stddev)(g_key, (shape[-1],))
        scale = jnp.exp(log_scale)
        return scale, full_kernel / scale

    return init


class Dense(nn.Module):
    """Affine layer y = x @ kernel + bias, optionally weight-factorised.

    Attributes:
        features: output width.
        kernel_init: initializer for the kernel (or for W before factorisation).
        bias_init: initializer for the bias.
        reparam: None for a plain kernel, or a mapping
            {"type": "weight_fact", "mean": ..., "stddev": ...} that stores the
            kernel as (g, v) with kernel = g * v.
    """
    features: int
    kernel_init: Callable = glorot_normal()
    bias_init: Callable = zeros
    reparam: Union[None, Dict] = None

    @nn.compact
    def __call__(self, x):
        kernel_shape = (x.shape[-1], self.features)
        if self.reparam is None:
            kernel = self.param("kernel", self.kernel_init, kernel_shape)
        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    stddev=self.reparam["stddev"],
                ),
                kernel_shape,
            )
            kernel = g * v
        else:
            # Without this branch an unknown type left ``kernel`` unbound and
            # failed with an unhelpful UnboundLocalError.
            raise NotImplementedError(
                f"Reparametrisation {self.reparam['type']} not supported yet!"
            )
        bias = self.param("bias", self.bias_init, (self.features,))
        y = jnp.dot(x, kernel) + bias

        return y


class MLP(nn.Module):
    """Fully connected network with a softplus output layer.

    ``hidden_dim`` lists the hidden-layer widths. ``periodicity`` and
    ``fourier_emb`` are accepted but not used by ``__call__``.
    """
    arch_name: Optional[str]="MLP"
    hidden_dim: Tuple[int]=(32, 16)
    out_dim: int=1
    activation: str="tanh"
    periodicity: Union[None, Dict]=None
    fourier_emb: Union[None, Dict]=None
    reparam: Union[None, Dict]=None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        for width in self.hidden_dim:
            hidden = Dense(features=width, reparam=self.reparam)(x)
            x = self.activation_fn(hidden)
        logits = Dense(features=self.out_dim, reparam=self.reparam)(x)
        # CAUTION: not a plain MLP head -- softplus keeps the output non-negative.
        return nn.softplus(logits)


class ModifiedMLP(nn.Module):
    """Gated MLP: every hidden activation x is blended as x * u + (1 - x) * v.

    u and v are activations of two input projections of width
    ``hidden_dim[0]``. Because each hidden layer is mixed element-wise with
    them, all entries of ``hidden_dim`` must be equal; a ValueError is raised
    otherwise (so the default (32, 16) has to be overridden). The output goes
    through softplus. ``periodicity`` and ``fourier_emb`` are accepted but unused.
    """
    arch_name: Optional[str]="ModifiedMLP"
    hidden_dim: Tuple[int]=(32, 16)
    out_dim: int=1
    activation: str="tanh"
    periodicity: Union[None, Dict]=None
    fourier_emb: Union[None, Dict]=None
    reparam: Union[None, Dict]=None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        width = self.hidden_dim[0]
        # Fail early with a clear message instead of a broadcasting error in
        # the gating product below.
        if any(h != width for h in self.hidden_dim):
            raise ValueError(
                f"ModifiedMLP needs equal hidden widths, got {tuple(self.hidden_dim)}"
            )

        u = Dense(features=width, reparam=self.reparam)(x)
        v = Dense(features=width, reparam=self.reparam)(x)

        u = self.activation_fn(u)
        v = self.activation_fn(v)

        for features in self.hidden_dim:
            x = Dense(features=features, reparam=self.reparam)(x)
            x = self.activation_fn(x)
            x = x * u + (1 - x) * v

        x = Dense(features=self.out_dim, reparam=self.reparam)(x)
        x = nn.softplus(x) # CAUTION NOT PLAIN MLP
        return x


def ann_gen(config):
    """Instantiate the network described by ``config``.

    Reads config.ann_str ("MLP" or "ModifiedMLP"), ann_hidden_dim,
    ann_out_dim, ann_activation_str, ann_periodicity, ann_fourier_emb and
    ann_reparam. ann_reparam == "weight_fact" enables weight factorisation
    (mean 0.5, stddev 0.1); any other value means no reparametrisation.

    Raises:
        NotImplementedError: for an unknown ``config.ann_str`` (instead of
            returning None, which only failed later at ``ann.apply``).
    """
    reparam = None
    if config.ann_reparam == "weight_fact":
        reparam = ml_collections.ConfigDict({"type": "weight_fact", "mean": 0.5, "stddev": 0.1})

    architectures = {"MLP": MLP, "ModifiedMLP": ModifiedMLP}
    if config.ann_str not in architectures:
        raise NotImplementedError(f"Architecture {config.ann_str} not supported yet!")
    return architectures[config.ann_str](
        arch_name=config.ann_str,
        hidden_dim=config.ann_hidden_dim,
        out_dim=config.ann_out_dim,
        activation=config.ann_activation_str,
        periodicity=config.ann_periodicity,
        fourier_emb=config.ann_fourier_emb,
        reparam=reparam,
    )

### Input Data

In [None]:
# SABR model parameters and market inputs for the synthetic surface.
s_0 = 1.0    # spot
r = 0.05     # risk-free rate (forward = s_0 * exp(r * tau))
# SABR: initial vol level, CEV exponent, spot/vol correlation, vol-of-vol
alpha, beta, rho, nu = 0.3, 0.7, -0.6, 0.6

In [None]:
import math
import numpy as np
from scipy.stats import truncnorm


def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
    """Frozen ``scipy.stats.truncnorm`` with location ``mean`` and scale ``sd``,
    truncated to [low, upp] given in the original (unstandardised) units."""
    a = (low - mean) / sd
    b = (upp - mean) / sd
    return truncnorm(a, b, loc=mean, scale=sd)


def black(F, K, T, r, sigma):
    """Discounted Black-76 call price exp(-rT) * (F N(d1) - K N(d2)).

    For T <= 0 or K <= 0 the discounted intrinsic value
    exp(-rT) * max(F - K, 0) is returned instead.
    NOTE: this rebinds ``black``; the earlier definition has a different signature.
    """
    if T <= 0. or K <= 0.:
        return np.exp(-r*T)*max(F - K, 0.)
    d1 = d1_(F, K, 0., sigma, T)
    d2 = d2_(F, K, 0., sigma, T)
    call = F*N(d1) - K*N(d2)
    return call*np.exp(-r*T)


def chi(z, rho):
    """Hagan's x(z) = ln((sqrt(1 - 2 rho z + z^2) + z - rho) / (1 - rho)).

    For rho == 1 the limiting form ln(1 + z / |1 - z|) is used.
    """
    if rho != 1.:
        s = math.sqrt(1. - 2*rho*z + z*z)
        return math.log((s + z - rho) / (1 - rho))
    else:
        return math.log(1 + z / abs(1 - z))


def SabrVolHagan(F, K, T, alpha, beta, nu, rho, h):
    """Hagan's lognormal implied-volatility approximation for SABR.

    Args:
        F: forward.
        K: strike; 0 is returned for K <= 0.
        T: maturity.
        alpha, beta, nu, rho: SABR parameters.
        h: shift entering log((F + h) / (K + h)), the near-ATM term and y3/y4;
            h = 0 gives the unshifted formula.

    Returns:
        Black (lognormal) implied volatility.
    """
    if K <= 0.:
        return 0.
    FK = F*K
    subBeta = 1. - beta
    if abs(F - K) > 1e-4:
        logFK = math.log(F / K)
        a0 = pow(subBeta, 2) / 24*pow(logFK, 2)
        a1 = pow(subBeta, 4) / 1920*pow(logFK, 4)
        c0 = pow(FK, (subBeta) / 2)*(1 + a0 + a1)

        z = nu / alpha*pow(FK, (subBeta) / 2)*logFK
        # z / x(z) -> 1 as z -> 0 (e.g. nu == 0). chi(0, rho) is 0, so the
        # quotient must not be evaluated there; this also replaces the old
        # "return alpha" shortcut, which ignored beta.
        z_over_chi = 1. if z == 0. else z / chi(z, rho)
        c1 = z_over_chi*math.log((F + h) / (K + h)) / logFK
        coef_vol = alpha / c0*c1
    else:
        coef_vol = alpha*pow(F, beta) / (F + h)
    FK_subBeta = pow(FK, subBeta)
    sqrtFK = math.sqrt(FK)
    y0 = pow(subBeta, 2) / 24.*pow(alpha, 2) / FK_subBeta
    y1 = 0.25*(alpha*beta*rho*nu) / pow(FK, (subBeta) / 2)
    y2 = (2 - 3.*pow(rho, 2)) / 24.*pow(nu, 2)
    y3 = h*(2.*sqrtFK + h)*pow(alpha, 2)
    y4 = 24.*pow(sqrtFK + h, 2)*FK_subBeta

    return coef_vol*(1 + (y0 + y1 + y2 - y3 / y4)*T)


def get_data(n_pts, n_h, seed=None):
    """Generate noisy synthetic SABR call premiums plus an evaluation mesh.

    Args:
        n_pts: number of random (K, tau) training points.
        n_h: points per axis of the regular (K, tau) mesh on
            [1e-3, 2.5 + 1e-3] x [1e-3, 5 + 1e-3].
        seed: optional integer for reproducible sampling; ``None`` (default)
            draws from NumPy's global random state as before.

    Returns:
        (x_train, y_train, x_mesh) as JAX arrays. x_train has shape (n_pts, 2)
        with columns (K, tau). y_train has shape (n_pts, 1) and holds
        discounted Black call premiums priced with the Hagan SABR vol times
        (1 + N(0, 0.1^2)) noise. x_mesh has shape (n_h**2, 2).
    """
    eps = 1e-3
    rng = None if seed is None else np.random.RandomState(seed)
    normal_draw = np.random.normal if rng is None else rng.normal

    xs_mesh, ts_mesh = np.meshgrid(np.linspace(eps, 2.5+eps, n_h), np.linspace(eps, 5.0+eps, n_h))
    x_mesh = np.array([xs_mesh.flatten(), ts_mesh.flatten()]).T

    x_train = np.array([
        get_truncated_normal(mean=s_0, sd=0.4, low=eps, upp=2.5).rvs(n_pts, random_state=rng),
        get_truncated_normal(mean=eps, sd=2, low=eps, upp=5).rvs(n_pts, random_state=rng),
    ]).T
    premiums = []
    for K, tau in x_train:
        F = s_0*np.exp(r*tau)
        vol = SabrVolHagan(F, K, tau, alpha, beta, nu, rho, 0.0)*(1 + normal_draw(0, 0.1))
        premiums.append(black(F, K, tau, r, vol))
    y_train = np.array(premiums)[..., np.newaxis]
    return jnp.array(x_train), jnp.array(y_train), jnp.array(x_mesh)

In [None]:
# Input examples: scatter the synthetic training premiums over (K, tau).
import matplotlib.pyplot as plt

data = get_data(300, 101)

# premium surfaces
fig, ax = plt.subplots(1, 1, figsize=(6, 4),
                    subplot_kw=dict(projection='3d'))
fig.suptitle('premium surface input')
ax.scatter(data[0][:, 0], data[0][:, 1], data[1], s=2)
ax.set_xlim(0, 2.5); ax.set_ylim(0, 5.); ax.set_zlim(0.0, 1)
# Save before showing: with non-inline backends plt.show() can leave no
# current figure, so a savefig afterwards writes a blank image.
plt.savefig('img_prem_input.png')
plt.show()
plt.close()

### Calibration

In [None]:
import os
import pickle
import time
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
from jax import grad, jacrev, jit, lax, random
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_leaves, tree_map
import jax.example_libraries.optimizers as optimizers

import ml_collections
from flax import linen
from flax.training import train_state, orbax_utils
from flax.serialization import to_state_dict, from_state_dict
import orbax.checkpoint


def save_params(params, path_item: str) -> None:
    """Pickle the Flax state dict of ``params`` to ``path_item``."""
    state = to_state_dict(params)
    with open(path_item, 'wb') as handle:
        pickle.dump(state, handle)

def load_params(params_initialized, path_item: str):
    """Restore parameters pickled by ``save_params``.

    ``params_initialized`` supplies the target structure; leaf values come
    from the file. Flax's signature is ``from_state_dict(target, state)``;
    passing them the other way round returned the initialised leaves.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    with open(path_item, 'rb') as f:
        loaded_dict = pickle.load(f)
    return from_state_dict(params_initialized, loaded_dict)

def flatten_pytree(pytree):
    """Concatenate all leaves of ``pytree`` into one 1-D array."""
    flat, _unravel = ravel_pytree(pytree)
    return flat

# Initial global loss weights (all 1.0) per model family and loss component.
init_l_ws = {
    name: {component: 1. for component in components}
    for name, components in (
        ("MLP", ('e_acc',)),
        ("PINN", ('e_acc', 'e_pde')),
        ("DCPINN", ('e_acc', 'e_pde', 'e_arb_dK', 'e_arb_d2K', 'e_arb_dT')),
    )
}

def init_params_sa(loss_str, data):
    """Initial self-adaptive point weights (all ones) for a loss family.

    'e_acc' gets one weight per training point (len(data[0])); every other
    component gets one weight per mesh point (len(data[2])).

    Raises:
        KeyError: for an unknown ``loss_str``.
    """
    components = {
        "MLP": ('e_acc',),
        "PINN": ('e_acc', 'e_pde'),
        "DCPINN": ('e_acc', 'e_pde', 'e_arb_dK', 'e_arb_d2K', 'e_arb_dT'),
    }[loss_str]
    n_train, n_mesh = len(data[0]), len(data[2])
    # Allocate only the requested family's arrays, not all three.
    return {name: jnp.ones(n_train if name == 'e_acc' else n_mesh)
            for name in components}

def calibration(config, data):
    """Train the volatility network described by ``config`` on ``data``.

    Network parameters are updated with the module-level optimizer ``tx``.
    For ``config.loss_str == "DCPINN"`` two extra updates also run:
      * every 100 epochs, the global loss weights ``l_ws`` are re-balanced.
        Each raw weight is the sum over components of the mean |gradient|,
        divided by this component's mean |gradient|. It is smoothed with
        ``config.loss_balancing_momentum``.
      * every 100 epochs (offset by 50), one SGD *ascent* step is taken on
        the self-adaptive point weights. The step size is
        ``config.self_adaptive_lr`` (1.0 if the key is absent).

    The final parameters ('params'), self-adaptive weights ('ms') and loss
    weights ('ls') are saved with orbax to ./checkpoints, overwriting any
    existing checkpoint there.

    Returns:
        (fn, hist_loss): ``fn(x)`` evaluates the trained network, and
        ``hist_loss`` holds (epoch, loss, metrics) every 1000 epochs.
    """
    ann = ann_gen(config)
    ofunc = loss_fn[config.loss_str]

    l_ws = init_l_ws[config.loss_str]
    params_sa = init_params_sa(config.loss_str, data)

    key = jax.random.PRNGKey(config.seed)
    key, key_init = jax.random.split(key, 2)
    dummy = jnp.ones((1, config.ann_in_dim), dtype=jnp.float32)
    state = train_state.TrainState.create(apply_fn=ann.apply,
                                        params=ann.init(key_init, dummy),
                                        tx=tx)

    # Self-adaptive point weights: plain SGD, used for gradient ascent below.
    # The step size comes from the config (it was previously hard-coded 1.0).
    opt_init_sa, opt_update_sa, get_params_sa = optimizers.sgd(
        config.get("self_adaptive_lr", 1.0))
    state_sa = opt_init_sa(params_sa)

    @jit
    def train_step(state, data, l_ws, state_sa):
        """One optimizer step on the network parameters."""
        params_sa = get_params_sa(state_sa)
        def loss_fn(params):
            def fn(x): return ann.apply(params, x)
            return ofunc(fn, data, l_ws, params_sa)

        params = state.params
        (loss, metric), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
        state = state.apply_gradients(grads=grads)

        return state, loss, metric, state_sa

    hist_loss = []
    momentum = config.loss_balancing_momentum
    start_time = time.time()

    @jit
    def train_step_sa(step, state, data, l_ws, state_sa):
        """One SGD ascent step on the self-adaptive weights (gradients negated)."""
        params = state.params
        def fn(x): return ann.apply(params, x)
        def loss_fn_sa(params_sa):
            return ofunc(fn, data, l_ws, params_sa)[0]

        params_sa = get_params_sa(state_sa)
        grads_sa = jax.grad(loss_fn_sa)(params_sa)
        grads_sa = {name: g * -1. for name, g in grads_sa.items()}
        state_sa = opt_update_sa(step, grads_sa, state_sa)

        return state_sa

    @jit
    def update_loss_weights(state, data, l_ws, state_sa):
        """Gradient-magnitude loss balancing, smoothed by ``momentum``."""
        params = state.params
        params_sa = get_params_sa(state_sa)
        def loss_fn(params):
            def fn(x): return ann.apply(params, x)
            return loss_fn_lb[config.loss_str](fn, data, l_ws, params_sa)[0]
        grads = jacrev(loss_fn)(params)

        # Mean absolute gradient of each loss component over all parameters.
        grad_norm_dict = {name: jnp.abs(flatten_pytree(value)).mean()
                          for name, value in grads.items()}

        sum_grad_norm = jnp.sum(jnp.stack(tree_leaves(grad_norm_dict)))
        # Components with zero gradient get a raw weight of 1 instead of a
        # division by zero.
        w = tree_map(lambda x: jnp.where(x==0., 1., sum_grad_norm / x), grad_norm_dict)

        running_average = (
            lambda old_w, new_w: old_w * momentum + (1 - momentum) * new_w
        )
        weights = tree_map(running_average, l_ws, w)
        weights = lax.stop_gradient(weights)

        return weights

    # Training loop
    print(f"{config.loss_str} calibration------>")
    for epoch in range(config.num_epochs):

        # Whack-a-mole Learning
        if config.loss_str=="DCPINN" and epoch % 100 == 0:
            l_ws = update_loss_weights(state, data, l_ws, state_sa)
        if config.loss_str=="DCPINN" and (epoch+50) % 100 == 0:
            state_sa = train_step_sa(epoch, state, data, l_ws, state_sa)

        state, loss, metric, state_sa = train_step(state, data, l_ws, state_sa)

        # Print progress every 1000 epochs
        if (epoch % 1000) == 0:
            print(f"Epoch {epoch}: loss = {loss:.6f}", end="\r")
            hist_loss.append((epoch, float(loss), metric))

    comp_time = time.time() - start_time
    print(f"------> completed in {comp_time:.2f} seconds")

    # Save checkpoint
    CKPT_DIR = 'checkpoints'
    CKPT_DIR = os.path.abspath(CKPT_DIR)
    ckpt = {'params': state.params, 'ms': get_params_sa(state_sa), 'ls': l_ws}
    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    save_args = orbax_utils.save_args_from_target(ckpt)
    orbax_checkpointer.save(CKPT_DIR, ckpt, force=True, save_args=save_args)

    def fn(x): return ann.apply(state.params, x)
    return fn, hist_loss

def run_experiment(config):
    """Generate data, calibrate the model for ``config`` and pickle the history.

    {'config': ..., 'history': ...} is written to results_latest.pkl,
    overwriting earlier runs. Returns the trained model function.
    """
    data = get_data(config.pts_num, 101)
    model, history = calibration(config, data)

    with open('results_latest.pkl', 'wb') as handle:
        pickle.dump({'config': config, 'history': history}, handle)

    return model

In [None]:
# Example configurations; they differ only in ``loss_str``.
# ``ann_reparam``: anything other than "weight_fact" disables reparametrisation.
config_MLP = ml_collections.ConfigDict(dict(
    data_source="SABR_syn",
    pts_num=300,
    ann_str="MLP",
    loss_str="MLP",
    num_epochs=5000,
    ann_in_dim=2,
    ann_out_dim=1,
    ann_activation_str="tanh",
    self_adaptive_lr=1.0,
    loss_balancing_momentum=0.5,
    seed=42,
    ann_periodicity=None,
    ann_fourier_emb=None,
    ann_reparam=False,
    ann_hidden_dim=(16, 16, 16, 16),
))

config_DCPINN = ml_collections.ConfigDict(dict(
    data_source="SABR_syn",
    pts_num=300,
    ann_str="MLP",
    loss_str="DCPINN",
    num_epochs=5000,
    ann_in_dim=2,
    ann_out_dim=1,
    ann_activation_str="tanh",
    self_adaptive_lr=1.0,
    loss_balancing_momentum=0.5,
    seed=42,
    ann_periodicity=None,
    ann_fourier_emb=None,
    ann_reparam=False,
    ann_hidden_dim=(16, 16, 16, 16),
))


# Run one experiment per configuration.
model_MLP = run_experiment(config_MLP)
model_DCPINN = run_experiment(config_DCPINN)

### Visualization

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm


def load_model(checkpoint_path, config):
    """Load a model saved by ``calibration`` from an orbax checkpoint.

    Returns:
        (fn, params_sa): ``fn(x)`` evaluates the restored network, and
        ``params_sa`` are the stored self-adaptive weights (checkpoint key 'ms').
    """
    checkpoint_path = os.path.abspath(checkpoint_path)

    orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    ckpt = orbax_checkpointer.restore(checkpoint_path)
    params = ckpt['params']
    params_sa = ckpt['ms']

    # Named ``ann`` so it does not shadow the ``nn`` alias for flax.linen.
    ann = ann_gen(config)

    def fn(x): return ann.apply(params, x)
    return fn, params_sa

def plot_volatility_surface(model, config, save_path=None):
    """3-D plot of ``model`` on a 101 x 101 (K, T) grid over [0, 2.5] x [0, 5].

    Saves to ``save_path`` (300 dpi) when given, then shows and closes the figure.
    """
    strikes, times = np.meshgrid(np.linspace(0, 2.5, 101), np.linspace(0, 5, 101))
    grid = np.array([strikes.flatten(), times.flatten()]).T

    fig, ax = plt.subplots(1, 1, figsize=(6, 4), subplot_kw=dict(projection='3d'))
    fig.suptitle(f'Volatility Surface ({config.loss_str})')

    surface_values = model(grid).reshape(101, 101)
    ax.plot_surface(strikes, times, surface_values,
                    cmap=cm.coolwarm, linewidth=0, antialiased=False)

    ax.set_xlabel('Strike')
    ax.set_ylabel('Time')
    ax.view_init(30, -70)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

def plot_arbitrage_heatmaps(model, config, save_path=None):
    """Heatmaps of the no-arbitrage penalty terms on a 201 x 201 (K, T) grid.

    Panels show squared violations of -exp(-rT) <= dC/dK <= 0,
    d2C/dK2 >= 0 and dC/dT >= 0. The colour range of +-1e-50 saturates any
    non-zero value, so each map is effectively a violation indicator.
    """
    K_mesh = np.linspace(0, 2.5, 201)
    T_mesh = np.linspace(0, 5, 201)
    Ks_mesh, Ts_mesh = np.meshgrid(K_mesh, T_mesh)
    x_mesh = np.array([Ks_mesh.flatten(), Ts_mesh.flatten()]).T
    T = x_mesh[:, 1]

    dK_C, d2K_C, dT_C = call_derivatives(model, x_mesh)
    e_arb = {
        'dK': jnp.where(dK_C > 0, dK_C**2, jnp.where(dK_C < -jnp.exp(-r*T), dK_C**2, 0)),
        'd2K': jnp.where(d2K_C < 0, d2K_C**2, 0),
        'dT': jnp.where(dT_C < 0, dT_C**2, 0)}

    fig, axes = plt.subplots(1, 3, figsize=(18, 4))
    fig.suptitle(f'Arbitrage Condition Heatmaps ({config.loss_str})')

    panels = (('dK', '∂C/∂K'), ('d2K', '∂²C/∂K²'), ('dT', '∂C/∂T'))
    for ax, (name, title) in zip(axes, panels):
        image = ax.pcolormesh(Ks_mesh, Ts_mesh,
                              e_arb[name].reshape(len(T_mesh), len(K_mesh)),
                              cmap='bwr',
                              vmin=-1e-50, vmax=1e-50, shading='auto', alpha=0.7)
        ax.set_title(title)
        plt.colorbar(image, ax=ax)
        ax.set_xlabel('Strike')
        ax.set_ylabel('Time')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

def plot_training_history(history_path, save_path=None):
    """Plot the total loss and each recorded metric from a results pickle (log y-axis).

    NOTE: uses ``pickle.load``; only open trusted files.
    """
    with open(history_path, 'rb') as handle:
        history = pickle.load(handle)['history']

    epochs, losses, metrics = zip(*history)

    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    ax.plot(epochs, losses, 'b-', label='Total Loss')

    # Dashed line per metric; unknown metric names are drawn in black.
    colors = {'e_acc': 'r', 'e_pde': 'g', 'e_arb_dK': 'm', 'e_arb_d2K': 'c', 'e_arb_dT': 'y'}
    for name in metrics[0]:
        color = colors.get(name, "k")
        ax.plot(epochs, [m[name] for m in metrics], f'{color}--', label=name)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_yscale('log')
    ax.grid(True)
    ax.legend()
    plt.title('Training History')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

def compare_with_sabr(model, save_path=None):
    """3-D plots of the Hagan SABR vol, ``model``'s vol, and model minus SABR.

    Grid: 101 x 101 over K in [0, 2.5] and T in [0, 5]. The SABR vol is NaN
    where K == 0 or T == 0.
    """
    x_mesh = np.linspace(0.0, 2.5, 101)
    t_mesh = np.linspace(0.0, 5.0, 101)
    xs_mesh, ts_mesh = np.meshgrid(x_mesh, t_mesh)
    x_vec = np.array([xs_mesh.flatten(), ts_mesh.flatten()]).T

    def sabr_vol(K, tau):
        if K == 0. or tau == 0.:
            return np.nan
        F = s_0 * np.exp(r * tau)
        return SabrVolHagan(F, K, tau, alpha, beta, nu, rho, 0.0)

    sabr_vols = np.array([sabr_vol(K, tau) for K, tau in x_vec])
    model_vols = model(x_vec).flatten()
    grid_shape = (len(t_mesh), len(x_mesh))

    fig, axes = plt.subplots(1, 3, figsize=(18, 5), subplot_kw=dict(projection='3d'))
    fig.suptitle('SABR vs Learned Volatility Surface')

    panels = (
        ('SABR', sabr_vols),
        ('Calibrated', model_vols),
        ('Difference', model_vols - sabr_vols),
    )
    for ax, (title, values) in zip(axes, panels):
        surf = ax.plot_surface(xs_mesh, ts_mesh, values.reshape(grid_shape),
                               cmap=cm.coolwarm)
        ax.set_title(title)
        plt.colorbar(surf, ax=ax)
        ax.set_xlabel('Strike')
        ax.set_ylabel('Time')
        ax.set_zlabel('Volatility')
        ax.view_init(30, -70)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()


# Create the output directory if it doesn't exist.
os.makedirs('figures', exist_ok=True)

plot_volatility_surface(model_MLP, config_MLP, 'figures/volatility_surface_MLP.png')
plot_volatility_surface(model_DCPINN, config_DCPINN, 'figures/volatility_surface_DCPINN.png')
plot_arbitrage_heatmaps(model_MLP, config_MLP, 'figures/arbitrage_heatmaps_MLP.png')
plot_arbitrage_heatmaps(model_DCPINN, config_DCPINN, 'figures/arbitrage_heatmaps_DCPINN.png')
# Optional extras (disabled). results_latest.pkl holds only the most recent
# run_experiment call. compare_with_sabr needs a trained model function such
# as model_MLP or model_DCPINN in place of ``model``.
# plot_training_history('results_latest.pkl', 'figures/training_history.png')
# compare_with_sabr(model, 'figures/sabr_comparison.png')