In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sigkernel
import torch
import torch.nn.functional as F
import torchsde
import plotly.graph_objects as go
from scipy import stats
from sklearn.preprocessing import StandardScaler
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

from src.sdes import ScaledBrownianMotion
from src.utils.helper_functions.plot_helper_functions import make_grid

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs up front: parameter sampling, SDE path simulation and the bootstrap
# permutations all draw random numbers, so Restart & Run All is only reproducible
# with fixed seeds. torch.manual_seed also seeds CUDA; numpy is seeded in case any
# helper draws from its global RNG.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

### 1. Train a model (e.g. an FFN) to learn the map from Gamma parameters to the $(1-\alpha)$ quantile

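A differentiable surrogate for the Gamma quantile function lets us backpropagate through the critical threshold (SciPy's `gamma.ppf` is not differentiable in torch).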

We want to learn the map from $(\mu, \theta)$ (the Gamma shape and scale) to the $(1-\alpha)$ critical threshold. We randomly sample values of $(\mu, \theta)$ (currently from exponential distributions) and train the model to predict the associated $(1-\alpha)$ quantile.

In [None]:
load_model = False
model_path = 'entire_model.pth'

if load_model:
    # The checkpoint is presumably a whole pickled nn.Module (not a state_dict), since it
    # is used directly as `model`. That requires full unpickling (weights_only=False),
    # which can execute arbitrary code: only load files you created yourself.
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()  # Important to set the model to evaluation mode

#### Generate data

In [None]:
def gamma_ppf(q: float, alpha, beta):
    """
    Percent-point function (inverse CDF) of a Gamma distribution at level ``q``.

    :param q:      Probability level(s) in [0, 1] at which to invert the CDF
    :param alpha:  Gamma shape parameter (scalar or array-like; broadcasts with q)
    :param beta:   Gamma *scale* parameter (not a rate)
    :return:       Value(s) x with P(X <= x) = q for X ~ Gamma(shape=alpha, scale=beta)
    """
    # Location is left at SciPy's default of 0; only shape and scale vary.
    return stats.gamma.ppf(q, alpha, loc=0, scale=beta)


n_train_samples = 8192*8
n_val_samples   = 8192
n_test_samples  = 8192
alpha           = 0.05
scale_data      = True

# Upper bounds for the uniform sampling alternative (currently unused)
mu_ub    = 1000
theta_ub = 100

batch_size = 512
# Exponential draws with rates 1/500 and 1/50 (means 500 and 50).
# Uniform alternative: torch.empty(x, 1).uniform_(0, mu_ub) / .uniform_(0, theta_ub)
mu_gen    = lambda x: torch.empty(x, 1, dtype=torch.float).exponential_(1/500)
theta_gen = lambda x: torch.empty(x, 1, dtype=torch.float).exponential_(1/50)


def _sample_split(n_samples):
    """Draw ``n_samples`` (mu, theta) pairs and their (1 - alpha) Gamma quantiles.

    Returns ``(mu, theta, labels)``, each of shape ``(n_samples, 1)``.
    """
    mu    = mu_gen(n_samples)
    theta = theta_gen(n_samples)
    # gamma_ppf broadcasts over arrays: one vectorised call instead of a per-sample loop
    labels = torch.tensor(gamma_ppf(1 - alpha, mu.numpy(), theta.numpy()))
    return mu, theta, labels


# Draw order (train, test, val) is kept fixed so seeded runs are reproducible.
train_mu, train_theta, train_labels = _sample_split(n_train_samples)
test_mu,  test_theta,  test_labels  = _sample_split(n_test_samples)
val_mu,   val_theta,   val_labels   = _sample_split(n_val_samples)


def _features(mu, theta):
    """Stack (mu, theta) column-wise into a float32 design matrix on ``device``."""
    return torch.cat([mu, theta], dim=1).to(device).float()


train_X = _features(train_mu, train_theta)
test_X  = _features(test_mu, test_theta)
val_X   = _features(val_mu, val_theta)

train_Y = train_labels.to(device).float()
test_Y  = test_labels.to(device).float()
val_Y   = val_labels.to(device).float()

# Standardise with statistics of the training split only, then apply the same
# transform to the test and validation splits.
if scale_data:
    scaler = StandardScaler().fit(train_X.cpu().numpy())

    def _standardise(X):
        """Apply the fitted scaler and return a float32 tensor on ``device``."""
        return torch.tensor(scaler.transform(X.cpu().numpy())).to(device).float()

    train_X = _standardise(train_X)
    test_X  = _standardise(test_X)
    val_X   = _standardise(val_X)

train_dataset = TensorDataset(train_X, train_Y)
# Reshuffle every epoch so mini-batch composition varies between epochs.
train_loader  = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

#### Select model

In [None]:
def polynomial_features(x, order=3):
    """Concatenate x, x**2, ..., x**order along dim 1 (element-wise powers, no cross terms).

    For a 2-D input of shape (n, d) the result has shape (n, d * order).
    """
    blocks = [x.clone()]
    blocks.extend(x ** power for power in range(2, order + 1))
    return torch.cat(blocks, dim=1)

# Polynomial Regression Model
class PolynomialRegressionModel(nn.Module):
    """Linear regression on element-wise polynomial features of the input.

    ``polynomial_features`` concatenates x, x**2, ..., x**order along dim 1, so for
    ``dim`` input columns the linear layer receives ``dim * order`` features.
    """

    def __init__(self, order, dim):
        super().__init__()
        self.order  = order
        # One block of `dim` columns per power 1..order (no cross terms).
        self.linear = nn.Linear(dim * order, 1)

    def forward(self, x):
        features = polynomial_features(x, self.order)
        return self.linear(features)
    
def inverse_distance_weighting(x, y, z, x_new, y_new, power=2):
    """
    Inverse-distance-weighted interpolation of scattered 2-D data at one query point.

    x, y:         Coordinates of the known data points (tensors).
    z:            Values observed at those points.
    x_new, y_new: Coordinates of the query point.
    power:        Distance exponent; larger values give nearby points more influence.

    Returns sum(w * z) / sum(w) with weights w = 1 / distance**power.
    """
    sq_dist = (x - x_new) ** 2 + (y - y_new) ** 2
    # Floor tiny distances so a query that coincides with a data point cannot divide by zero
    sq_dist = sq_dist.clamp(min=1e-6)
    # distance**power expressed via the squared distance: (d**2)**(power/2)
    w = 1 / sq_dist ** (power / 2)
    return (w * z).sum() / w.sum()

In [None]:
epochs            = 500
patience          = 20  # how many epochs to wait for improvement in the validation loss before stopping
lr                = 5e-4
l1_factor         = 1e-5
log_every         = max(1, epochs // 10)  # print train/val losses every this many epochs

model_choice = "idw"  # "ffn" / "polyfit" (trainable) or "idw" (non-parametric baseline)

# Book-keeping. len(train_loader) also counts a final partial batch, so `losses` cannot overflow.
batch_itr         = len(train_loader)
losses            = torch.zeros(epochs * batch_itr)
count             = 0
best_val_loss     = None
epochs_no_improve = 0


def _l1_penalty(module):
    """L1 norm of all parameters of ``module`` as a fresh, differentiable scalar."""
    # Built out-of-place: an in-place `+=` into a leaf created with requires_grad=True
    # raises on CPU (where .to(device) returns that same leaf).
    return sum(param.abs().sum() for param in module.parameters())


def _penalised_loss(X, Y):
    """MSE + L1 penalty of ``model`` on (X, Y), evaluated in eval mode without building a graph."""
    model.eval()  # BatchNorm must use, not update, its running statistics when evaluating
    with torch.no_grad():
        value = criterion(model(X), Y) + l1_factor * _l1_penalty(model)
    model.train()
    return value.item()


if model_choice == "ffn":
    n_neurons = 64
    model = nn.Sequential(
        nn.Linear(2, n_neurons),
        nn.BatchNorm1d(n_neurons),
        nn.ReLU(),
        nn.Linear(n_neurons, n_neurons),
        nn.BatchNorm1d(n_neurons),
        nn.ReLU(),
        nn.Linear(n_neurons, n_neurons),
        nn.BatchNorm1d(n_neurons),
        nn.ReLU(),
        nn.Linear(n_neurons, 1)
    ).to(device)
elif model_choice == "polyfit":
    model = PolynomialRegressionModel(order=3, dim=2).to(device)
elif model_choice == "idw":
    # Interpolate the training quantiles directly in raw (unscaled) (mu, theta) space.
    tmu   = train_mu.squeeze()
    tthet = train_theta.squeeze()
    tlab  = train_labels.squeeze()
    model = lambda vec, x=tmu, y=tthet, z=tlab: inverse_distance_weighting(x, y, z, vec[0], vec[1])
else:
    raise ValueError(f"Unknown model_choice {model_choice!r}; expected 'ffn', 'polyfit' or 'idw'")

if model_choice != "idw":
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=10)

    model.train()
    itr = tqdm(range(epochs))

    for epoch in itr:
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y) + l1_factor * _l1_penalty(model)
            loss.backward()
            optimizer.step()

            losses[count] = loss.item()
            count += 1

        # Validate every epoch: the scheduler and early stopping both need a current value.
        val_loss = _penalised_loss(val_X, val_Y)

        if epoch == 0 or (epoch + 1) % log_every == 0:
            tqdm.write(f'Epoch {epoch+1}: Loss/Train: {_penalised_loss(train_X, train_Y):.4f}')
            tqdm.write(f'Epoch {epoch+1}: Loss/Val: {val_loss:.4f}\n')

        prev_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        if current_lr != prev_lr:
            tqdm.write(f"Epoch {epoch+1}: Learning rate reduced from {prev_lr} to {current_lr}")

        # Early stopping
        if best_val_loss is None or val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            tqdm.write(f"Epoch {epoch+1}: Stopping training early")
            break

    # Later cells call the model on the whole test set and on a single input (batch size 1),
    # which BatchNorm rejects in training mode.
    model.eval()

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    make_grid(axis=ax)
    ax.plot(losses[:count].numpy(), alpha=0.5)  # only the optimiser steps actually taken
    ax.set_xlabel("Optimiser step")
    ax.set_ylabel("Penalised training loss")

In [None]:
# Generate predictions on the test split
with torch.no_grad():
    if model_choice != "idw":
        predictions = model(test_X).cpu().numpy()
    else:
        # IDW interpolates in raw (unscaled) parameter space, one query point at a time;
        # .item() turns each 0-d result into a plain float for numpy.
        test_X_ = torch.hstack([test_mu, test_theta])
        predictions = np.array([model(val).item() for val in test_X_])

# Plot against the raw (mu, theta) values: test_X may be standardised, which would make
# the axes labelled mu / theta show z-scores instead.
true_values  = test_Y.cpu().numpy()
mu_values    = test_mu.squeeze(1).numpy()
theta_values = test_theta.squeeze(1).numpy()


def _quantile_trace(z, color, name):
    """3-D scatter of quantile values ``z`` over the test (mu, theta) points."""
    return go.Scatter3d(
        x=mu_values,
        y=theta_values,
        z=z.flatten(),
        mode='markers',
        marker=dict(size=4, color=color, opacity=0.25),
        name=name,
    )


trace_true = _quantile_trace(true_values, 'blue', 'True Values')
trace_pred = _quantile_trace(predictions, 'red', 'Predictions')

data = [trace_true, trace_pred]

layout = go.Layout(
    scene=dict(
        xaxis_title='mu',
        yaxis_title='theta',
        zaxis_title='Quantile Value'
    ),
    margin=dict(r=0, b=0, l=0, t=0)
)

fig = go.Figure(data=data, layout=layout)
fig.show()

### 2. Training to learn the optimal scaling

Steps are as follows, for each epoch:

1. Sample N sets of $2M$ paths from $H_0$ and $M$ paths from $H_1$.
2. Compute bootstrapped (biased) MMD statistics under $H_0$ and $H_1$.
3. Approximate the critical value from the bootstrapped null distribution, e.g. with the (learnt) Gamma approximation (the quantile model must then be differentiable, such as the FFN, so we can backprop).
4. Calculate smoothed Type II error expectation
5. Backprop, repeat.

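The Gamma approximation is not strictly necessary: `torch.sort` (and therefore `torch.quantile`) is differentiable with respect to its input values, so an empirical quantile of the null statistics can be backpropagated through directly.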

In [None]:
# --- Hypotheses: both are ScaledBrownianMotion SDEs (see src.sdes for parameter meaning)
mu0, sig  = 0., 0.2   # H0 parameters
mu1, beta = 0., 0.3   # H1 parameters
noise_type, sde_type = "diagonal", "ito"

# --- Time grid
T           = 1
grid_points = 64
dt_scale    = 1e-1  # solver step as a fraction of the grid spacing; finer is more accurate but slower
ts          = torch.linspace(0, T, grid_points).to(device)
_dt         = dt_scale * (ts[1] - ts[0])

# --- Path bank
path_bank_size = 32768
state_size     = 1
y0 = torch.zeros(path_bank_size, state_size).to(device)  # every path starts at 0

h0_model, h1_model = (
    ScaledBrownianMotion(drift_like, vol_like, noise_type, sde_type).to(device)
    for drift_like, vol_like in ((mu0, sig), (mu1, beta))
)

# Simulate each bank once; no gradients are needed through the solver.
with torch.no_grad():
    h0_paths, h1_paths = (
        torchsde.sdeint(sde, y0, ts, method='euler', dt=_dt)
        for sde in (h0_model, h1_model)
    )

In [None]:
# Prepend the time grid as channel 0 and move the batch axis first:
# (grid, batch, state) -> (batch, grid, 1 + state)
time_channel = ts.unsqueeze(-1).expand(path_bank_size, ts.size(0), 1)
h0_paths = torch.cat([time_channel, h0_paths.transpose(0, 1)], dim=2)
h1_paths = torch.cat([time_channel, h1_paths.transpose(0, 1)], dim=2)

In [None]:
# Signature kernel built on a linear static kernel.
static_kernel    = sigkernel.LinearKernel()
dyadic_order     = 1  # At least 1 for accurate calculation of the signature kernel
signature_kernel = sigkernel.SigKernel(static_kernel=static_kernel, dyadic_order=dyadic_order)

In [None]:
class PhiKernel(torch.nn.Module):
    """Holds learnable path scalings and their fixed mixture weights.

    Attributes:
        lambda_:  learnable ``nn.Parameter``, one scaling per entry of ``initial_``.
        weights_: ``(n_scalings, 1)`` buffer of uniform weights (1 / n_scalings each).

    No ``forward`` is defined; callers read ``lambda_`` and ``weights_`` directly.
    """

    def __init__(self, initial_: torch.Tensor):
        super().__init__()

        n_scalings = initial_.shape[0]
        # Copy so optimising lambda_ does not silently mutate the caller's tensor;
        # the parameter lives on initial_'s device.
        self.lambda_ = torch.nn.Parameter(initial_.detach().clone(), requires_grad=True)
        # A registered buffer follows Module.to() and appears in state_dict().
        self.register_buffer(
            "weights_",
            (torch.ones(n_scalings, device=initial_.device) / n_scalings).unsqueeze(1),
        )

In [None]:
def smooth_step(x, scale_factor, threshold=0.0):
    """
    Smooth, differentiable approximation to the step function ``1[x > threshold]``.

    Parameters:
        x (Tensor): Input tensor
        scale_factor (float): Controls the sharpness of the transition. Larger values make the transition sharper.
        threshold (float): The point around which to transition from 0 to 1 (default 0).

    Returns:
        Tensor: ``sigmoid(scale_factor * (x - threshold))`` applied element-wise, in [0, 1].
    """
    return torch.sigmoid(scale_factor * (x - threshold))

In [None]:
def compute_biased_mmd(kernel, mu, nu, max_batch=128):
    """
    Biased (V-statistic) estimate of MMD^2 between samples ``mu`` and ``nu``.

    Averages full Gram matrices, diagonals included:
        mean k(mu, mu) + mean k(nu, nu) - 2 * mean k(mu, nu)
    ``kernel`` must provide ``compute_Gram(X, Y, sym=..., max_batch=...)``.
    """
    def gram(a, b, symmetric):
        return kernel.compute_Gram(a, b, sym=symmetric, max_batch=max_batch)

    within_mu = gram(mu, mu, True).mean()
    within_nu = gram(nu, nu, True).mean()
    across    = gram(mu, nu, False).mean()
    return within_mu + within_nu - 2. * across

def expected_type2_error(dist, crit_value):
    """
    Empirical Type II error: the fraction of statistics sampled under H1 that do not
    exceed the critical value (the test fails to reject H0).

    :param dist:        Tensor of test statistics under H1 (every element counts)
    :param crit_value:  Rejection threshold; H0 is rejected when a statistic > crit_value
    :return:            0-d float tensor in [0, 1]
    :raises ValueError: if ``dist`` is empty
    """
    if dist.numel() == 0:
        raise ValueError("expected_type2_error needs at least one statistic")
    # Vectorised mean of the failure mask (no Python-level iteration over elements)
    return (dist <= crit_value).float().mean()

def soft_quantile(v, tau, temperature=1.0):
    """
    Approximate the ``tau``-quantile of the values in ``v`` in a differentiable way.

    The values are sorted (``torch.sort`` is differentiable w.r.t. the values), then a
    softmax over ranks, peaked at rank ``tau * n``, weights the order statistics.

    - v: Input tensor of any shape; it is flattened, so an ``(n, 1)`` column works too.
    - tau: Desired quantile (0 to 1).
    - temperature: Temperature for the softmax, controls smoothness. Small values
      concentrate the weight on the order statistic(s) nearest rank ``tau * n``.
    """
    # Weights are defined over ranks, so they must be applied to *sorted* values.
    values = torch.sort(v.reshape(-1)).values
    n = values.numel()
    ranks = torch.arange(1, n + 1, device=values.device, dtype=values.dtype)
    target_rank = tau * n
    # Closer ranks get larger (less negative) scores, hence larger softmax weights
    scores = -(ranks - target_rank).abs()
    weights = F.softmax(scores / temperature, dim=0)
    return torch.sum(weights * values)

In [None]:
# Begin optimisation test
n_steps    = 256
tr_loss    = torch.zeros(n_steps).to(device)

n_paths    = 16    # paths in each bootstrapped MMD sample
n_atoms    = 64    # bootstrapped MMD statistics per hypothesis per step
alpha      = 0.05  # test size
n_scalings = 2

initial_   = torch.ones(n_scalings).to(device)
lr         = 5e-2
redr_flag  = True  # allow a one-off 10x learning-rate reduction later in training
optimizer_ = "Adam"
lambdas_   = torch.zeros((n_steps, n_scalings)).to(device)

scale_factor = 1e4         # sharpness of the smooth step used in the loss
method_      = "quantile"  # one of "gamma", "comparison", "soft_quantile", "quantile"
tau          = 1-alpha

# NOTE(review): dyadic_order 0 here, although at least 1 is recommended for an accurate
# signature kernel; presumably chosen for speed. TODO confirm.
dyadic_order     = 0
static_kernel    = sigkernel.LinearKernel()
signature_kernel = sigkernel.SigKernel(static_kernel=static_kernel, dyadic_order=dyadic_order)

phi                = PhiKernel(initial_).to(device)
time_normalisation = True
optimizer          = getattr(torch.optim, optimizer_)(phi.parameters(), lr=lr)


def scale(mean, var, N):
    """Moment-matched Gamma *shape* parameter, mean**2 / var (``N`` is unused)."""
    return mean**2/var


def rate(mean, var, N):
    """Moment-matched Gamma *scale* (despite the name) of the statistic ``N * MMD``: N * var / mean.

    It is used as theta, the second Gamma parameter, which ``gamma_ppf`` treats as a scale.
    """
    return (N*var)/mean


# Inputs to the quantile model must be transformed exactly like its training inputs.
if scale_data:
    scaler_means = torch.tensor(scaler.mean_, dtype=torch.float).to(device)
    scaler_stds  = torch.tensor(scaler.scale_, dtype=torch.float).to(device)
else:
    # The model was trained on raw (mu, theta): use the identity transform.
    scaler_means = torch.zeros(2, device=device)
    scaler_stds  = torch.ones(2, device=device)

# The gamma method evaluates the quantile model on a single input; BatchNorm layers
# reject batch size 1 in training mode, so switch trainable models to eval mode.
if method_ == "gamma" and isinstance(model, nn.Module):
    model.eval()

In [None]:
# Anomaly detection slows every backward pass considerably; enable it only to chase NaNs.
detect_anomaly = False
torch.autograd.set_detect_anomaly(detect_anomaly)


def _transform_paths(paths, lambd_):
    """Multiply ``paths`` by the scaling ``lambd_``.

    With ``time_normalisation`` the time channel (channel 0) is mapped to t / T instead,
    so only the state channels are scaled. Built out-of-place (no in-place edits).
    """
    scaled = lambd_ * paths
    if time_normalisation:
        scaled = torch.cat([paths[..., :1] / T, scaled[..., 1:]], dim=-1)
    return scaled


trange_steps = tqdm(range(n_steps), position=0)

for step in trange_steps:

    # Bootstrapped MMD statistics: one row per atom, one column per scaling
    mmd_h0 = torch.zeros((n_atoms, n_scalings)).to(device)
    mmd_h1 = torch.zeros((n_atoms, n_scalings)).to(device)

    for k, lambd_ in enumerate(phi.lambda_):
        for j in range(n_atoms):
            # Two disjoint H0 samples (ix, jx) and one H1 sample (iy), n_paths each
            h0_rands = torch.randperm(path_bank_size)[:int(2*n_paths)]
            ix, jx   = h0_rands[:n_paths], h0_rands[n_paths:]
            iy       = torch.randperm(path_bank_size)[:n_paths]

            # Transform only the sampled paths rather than the whole path bank
            x_h0 = _transform_paths(h0_paths[ix], lambd_)
            mmd_h0[j, k] = compute_biased_mmd(signature_kernel, x_h0, _transform_paths(h0_paths[jx], lambd_))
            mmd_h1[j, k] = compute_biased_mmd(signature_kernel, x_h0, _transform_paths(h1_paths[iy], lambd_))

    # Combine the per-scaling statistics with phi's weights: (n_atoms, n_scalings) @ (n_scalings, 1)
    mmd_h0_phi = torch.matmul(mmd_h0, phi.weights_)
    mmd_h1_phi = torch.matmul(mmd_h1, phi.weights_)

    if method_ == "gamma":
        # Use Gamma approximation as differentiable surrogate
        mean = mmd_h0_phi.mean()
        var  = mmd_h0_phi.var()

        mu    = scale(mean, var, n_paths).unsqueeze(-1)
        theta = rate(mean, var, n_paths).unsqueeze(-1)

        inputs = torch.hstack([mu, theta])

        # Transform the inputs using the scaler
        inputs_scaled = (inputs - scaler_means)/scaler_stds

        # Get critical threshold
        crit_thresh = model(inputs_scaled.unsqueeze(0))

        with torch.no_grad():
            # Detach before handing tensors to numpy/scipy: they are part of the autograd graph
            is_cpu = inputs_scaled.detach().cpu()
            true_thresh = gamma_ppf(1-alpha, mu.detach().cpu().numpy(), theta.detach().cpu().numpy())[0]
            trange_steps.write(f"Scaled input values: mu: {is_cpu[0].item():.4f}, theta: {is_cpu[1].item():.4f}")
            trange_steps.write(f"Estimated critical threshold: {crit_thresh.detach().cpu()[0][0].item(): .4f}")
            trange_steps.write(f"True critical threshold: {true_thresh: .4f}")

        # Calculate loss
        loss = 1 - smooth_step(n_paths*mmd_h1_phi - crit_thresh, scale_factor).mean()

    elif method_ == "comparison":
        loss = torch.mean(smooth_step(mmd_h0_phi - mmd_h1_phi, scale_factor))
    elif method_ == "soft_quantile":
        # Flatten the (n_atoms, 1) column so the rank weights align with the values
        crit_val = soft_quantile(mmd_h0_phi.flatten(), tau, temperature=1.0)
        loss = smooth_step(crit_val - mmd_h1_phi, scale_factor).mean()
    elif method_ == "quantile":
        # Use tau (= 1 - alpha) so changing alpha moves the critical value
        crit_val = torch.quantile(mmd_h0_phi, tau)
        loss = smooth_step(crit_val - mmd_h1_phi, scale_factor).mean()
    else:
        raise ValueError(f"Unknown method_ {method_!r}")

    if torch.isnan(loss).any():
        trange_steps.write("NaN values found in output")
        break

    if loss == 0:
        trange_steps.write("Loss reached minimum. Exiting")
        break

    # One-off learning-rate reduction once the loss is small
    if (loss < 0.3) and redr_flag:
        redr_flag = False
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1

    if (step + 1) % 16 == 0:
        trange_steps.write(f"Epoch {step+1}: Current loss: {loss.item(): .4f}")
        fprint_lambdas_ = phi.lambda_.detach()
        trange_steps.write(f"Current scalings: {', '.join([str(l.item())[:5] for l in fprint_lambdas_])}")

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Keep the scalings within a sane range
    with torch.no_grad():
        phi.lambda_.clamp_(min=1e-2, max=1e4)

    # Update history. Detach: copying a grad-requiring tensor in would attach lambdas_
    # to the autograd graph (requires_grad=True), breaking later .numpy() calls.
    tr_loss[step]  = loss.item()
    lambdas_[step] = phi.lambda_.detach()

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Detach before converting to numpy; lambdas_ may carry autograd history.
loss_history   = tr_loss.detach().cpu().numpy()
lambda_history = lambdas_.detach().cpu().reshape(lambdas_.shape[0], -1).numpy()

ax1.plot(loss_history, color="dodgerblue", alpha=0.75, label="loss")
palette = sns.color_palette("flare", lambda_history.shape[1])
for i in range(lambda_history.shape[1]):
    # One labelled curve per scaling; raw f-string because "\l" is not a valid escape.
    ax2.plot(lambda_history[:, i], color=palette[i], alpha=0.75, label=rf"$\lambda_{{{i + 1}}}$")

for ax in (ax1, ax2):
    ax.legend(loc="best")
    make_grid(axis=ax)

ax1.set_title(f"Loss, batch size ${n_paths}$, ${n_atoms}$ atoms", fontsize="medium")
ax1.set_xlabel("Epochs")
ax1.set_ylabel("Type II Error")
ax2.set_title(f"Scalings, number = ${n_scalings}$", fontsize="medium")
ax2.set_xlabel("Epochs")
ax2.set_ylabel("Scaling")
fig.suptitle(f"Learning optimal scaling, BM$({sig})$ vs BM$({beta})$")
fig.tight_layout()
fig.savefig("optimal_scaling_loss.png", dpi=300)

In [None]:
# Traverse and find a good scaling, brute force for now, to guide the optimization
n_scalings    = 32
scalings      = torch.linspace(1e-2, 3e1, n_scalings)
n_atoms       = 128
n_paths       = 32
dyadic_orders = [0, 1, 2]  # one kernel (and one row of `res`) per order

kernels = [sigkernel.SigKernel(static_kernel=static_kernel, dyadic_order=d) for d in dyadic_orders]

res = torch.zeros((len(dyadic_orders), n_scalings))
time_normalisation = True


def _sweep_transform(paths, lambd_):
    """Scale ``paths`` by ``lambd_``; with time normalisation channel 0 becomes t / T instead."""
    scaled = lambd_ * paths
    if time_normalisation:
        scaled = torch.cat([paths[..., :1] / T, scaled[..., 1:]], dim=-1)
    return scaled


# Pure evaluation: no autograd graph is needed.
with torch.no_grad():
    for k, kernel in enumerate(kernels):
        for j, lambd_ in enumerate(tqdm(scalings)):
            mmd_h0 = torch.zeros(n_atoms)
            mmd_h1 = torch.zeros(n_atoms)
            for i in range(n_atoms):
                # Two disjoint H0 samples (ix, jx) and one H1 sample (iy), n_paths each
                h0_rands = torch.randperm(path_bank_size)[:int(2*n_paths)]
                ix, jx   = h0_rands[:n_paths], h0_rands[n_paths:]
                iy       = torch.randperm(path_bank_size)[:n_paths]

                # Transform only the sampled paths rather than the whole path bank
                x_h0 = _sweep_transform(h0_paths[ix], lambd_)
                mmd_h0[i] = compute_biased_mmd(kernel, x_h0, _sweep_transform(h0_paths[jx], lambd_))
                mmd_h1[i] = compute_biased_mmd(kernel, x_h0, _sweep_transform(h1_paths[iy], lambd_))

            # Empirical (1 - alpha) quantile of the null statistics as the critical value
            crit_val  = mmd_h0.sort()[0][int(n_atoms*(1 - alpha))]
            res[k, j] = expected_type2_error(mmd_h1, crit_val)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
# One colour per row of `res`; row i holds the results of the kernel with dyadic order i.
colors = sns.diverging_palette(250, 20, n=len(res), center="dark")
for i, r in enumerate(res):
    ax.plot(scalings, r, alpha=0.5, label=f"dyadic_order_{i}", color=colors[i])
make_grid(axis=ax)
ax.set_xlabel(r"Scaling $\lambda$")
ax.set_ylabel("Empirical Type II error")
ax.set_title(f"Type II error vs scaling ({n_atoms} atoms, batch size {n_paths})", fontsize="medium")
ax.legend();