In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
%matplotlib inline

from itertools import combinations
import torch.nn.functional as F
from matplotlib.colors import ListedColormap
import os
import random
import time

## Initialisations

In [None]:
def mala_acceptance_probability(x_k: torch.Tensor, x_k_plus_1: torch.Tensor,
                               grad_log_pi_x_k: torch.Tensor, grad_log_pi_x_k_plus_1: torch.Tensor,
                               log_pi_x_k: torch.Tensor, log_pi_x_k_plus_1: torch.Tensor,
                               epsilon: float) -> torch.Tensor:
    """Metropolis-Hastings acceptance probability for a Langevin proposal.

    The proposal is taken to be Gaussian with mean ``x + epsilon * grad_log_pi(x)``
    and variance ``2 * epsilon`` per coordinate, hence the ``1 / (4 * epsilon)``
    factor in the log proposal densities below.

    Args:
        x_k: current state (flattened).
        x_k_plus_1: proposed state, same shape as ``x_k``.
        grad_log_pi_x_k: gradient of the log target density at ``x_k``.
        grad_log_pi_x_k_plus_1: gradient of the log target density at ``x_k_plus_1``.
        log_pi_x_k: log target density at ``x_k`` (up to an additive constant).
        log_pi_x_k_plus_1: log target density at ``x_k_plus_1`` (same constant).
        epsilon: proposal step size.

    Returns:
        Scalar tensor ``min(1, pi(x') q(x | x') / (pi(x) q(x' | x)))``.
    """
    def _log_proposal(target, source, grad_at_source):
        # log q(target | source), dropping the normalising constant which cancels in the ratio.
        residual = target - (source + epsilon * grad_at_source)
        return -torch.sum(residual ** 2) / (4 * epsilon)

    log_q_forward = _log_proposal(x_k_plus_1, x_k, grad_log_pi_x_k)
    log_q_backward = _log_proposal(x_k, x_k_plus_1, grad_log_pi_x_k_plus_1)

    # Proposed state goes in the numerator, current state in the denominator.
    log_accept_ratio = (log_pi_x_k_plus_1 + log_q_backward) - (log_pi_x_k + log_q_forward)
    # Clamping before exp gives min(1, exp(r)) without overflowing for large r.
    return torch.exp(torch.clamp(log_accept_ratio, max=0.0))


def av_log_likelihood(model, X, Y, noise_variance=1.0):
    """Mean Gaussian log-likelihood of ``Y`` under predictions ``model(X)``.

    Switches ``model`` to eval mode (it is not switched back) and evaluates
    without tracking gradients. Each element contributes
    ``-0.5 * ((Y - pred)^2 / noise_variance + log(2 * pi * noise_variance))``;
    the mean over all elements is returned as a scalar tensor.
    """
    model.eval()
    with torch.no_grad():
        squared_residual = (Y - model(X)) ** 2
        log_normaliser = torch.log(torch.tensor(2 * np.pi * noise_variance))
        elementwise = -0.5 * (squared_residual / noise_variance + log_normaliser)
        return elementwise.mean()

In [None]:
def sgld_llc_estimator_behavioral(model, x_data, y_data=None, scale=1.0, step_size=1e-5, sgld_iters=1000, batch_size=32, burn_in=200, diagnostics = False):
    """Estimate a local learning coefficient with SGLD on a behavioural loss.

    The behavioural loss is the mean squared difference between the model's
    current outputs and the outputs it produced at the starting weights
    ``w*`` (so it is zero at ``w*``). SGLD samples around ``w*`` from a density
    proportional to ``exp(-n * beta * loss(w) - scale / 2 * ||w - w*||^2)`` with
    ``n = len(x_data)`` and ``beta = 1 / log(n)``; the estimate returned is
    ``n * mean(post-burn-in minibatch loss) / log(n)``.

    Args:
        model: network to probe; its weights are restored before returning,
            even if sampling raises.
        x_data: inputs; ``len(x_data)`` is used as the dataset size.
        y_data: unused (kept for interface compatibility).
        scale: strength of the quadratic pull back towards ``w*``.
        step_size: SGLD step size (drift uses ``step_size / 2``, noise variance is ``step_size``).
        sgld_iters: number of SGLD iterations.
        batch_size: minibatch size, sampled with replacement.
        burn_in: number of initial iterations discarded from the average.
            NOTE: if ``burn_in >= sgld_iters`` the average is taken over an empty array.
        diagnostics: if True, print the mean MALA acceptance probability
            measured every 20th iteration after burn-in.

    Returns:
        The estimated coefficient (a NumPy float).
    """
    dataset_size = len(x_data)
    beta_star = 1.0 / np.log(dataset_size)
    # Plain float so the diagnostic arithmetic never mixes NumPy scalars with tensors.
    n_beta = float(dataset_size * beta_star)

    model.eval()
    with torch.no_grad():
        f_theta_star = model(x_data).detach()

    # Save original weights
    w_star = {k: v.clone() for k, v in model.state_dict().items()}
    # Flat anchor in *parameter* order, so it lines up with the flattened
    # parameters/gradients even when the state dict also holds buffers.
    w_star_flat = torch.cat([w_star[name].flatten() for name, _ in model.named_parameters()])
    array_neg_behavioral_loss = []
    mala_probs = []

    def _flat_params():
        # torch.cat copies, so the snapshot is unaffected by later in-place updates.
        return torch.cat([p.detach().flatten() for p in model.parameters()])

    def _flat_grads():
        return torch.cat([p.grad.flatten() for p in model.parameters()])

    def _log_target_and_grad(loss_value, flat_params, flat_loss_grads):
        # Log density (up to a constant) and its gradient for the tempered,
        # localised posterior that the SGLD drift below follows.
        offset = w_star_flat - flat_params
        log_pi = -n_beta * loss_value - 0.5 * scale * torch.sum(offset ** 2)
        grad_log_pi = -n_beta * flat_loss_grads + scale * offset
        return log_pi, grad_log_pi

    try:
        for t in range(sgld_iters):
            idx = torch.randint(0, dataset_size, (batch_size,))
            x_batch = x_data[idx]
            f_star_batch = f_theta_star[idx]

            model.train()
            pred = model(x_batch)
            loss_before = ((pred - f_star_batch) ** 2).mean()  # Behavioral loss

            model.zero_grad()
            loss_before.backward()

            grads_before = _flat_grads()
            params_before = _flat_params()

            # One SGLD step: half-step along grad log pi plus N(0, step_size) noise.
            with torch.no_grad():
                for name, param in model.named_parameters():
                    grad = param.grad
                    eta = torch.randn_like(param) * np.sqrt(step_size)
                    prior_force = scale * (w_star[name] - param)
                    delta = (step_size / 2) * (prior_force + dataset_size * beta_star * (-grad)) + eta
                    param.add_(delta)

            model.eval()
            with torch.no_grad():
                pred_post = model(x_batch)
                behav_loss = ((pred_post - f_star_batch) ** 2).mean()
                array_neg_behavioral_loss.append(-behav_loss.item())

            if t % 20 == 0 and t >= burn_in:
                model.train()
                pred = model(x_batch)
                loss_after = ((pred - f_star_batch) ** 2).mean()
                model.zero_grad()
                loss_after.backward()

                grads_after = _flat_grads()
                params_after = _flat_params()

                # The acceptance formula needs gradients of the log target,
                # not of the raw loss, at both end points of the move.
                log_pi_before, grad_log_pi_before = _log_target_and_grad(
                    loss_before.item(), params_before, grads_before)
                log_pi_after, grad_log_pi_after = _log_target_and_grad(
                    loss_after.item(), params_after, grads_after)

                # The step above has drift step_size / 2 and noise variance
                # step_size, i.e. a Langevin proposal with epsilon = step_size / 2.
                mala_prob = mala_acceptance_probability(params_before, params_after,
                                                        grad_log_pi_before, grad_log_pi_after,
                                                        log_pi_before, log_pi_after, step_size / 2)
                mala_probs.append(mala_prob.item())
    finally:
        # Always hand the model back at w*, in eval mode.
        model.load_state_dict(w_star)
        model.eval()

    array_neg_behavioral_loss = np.array(array_neg_behavioral_loss[burn_in:])
    wbic_behavioral = -dataset_size * np.mean(array_neg_behavioral_loss)

    lambda_hat_behavioral = wbic_behavioral / np.log(dataset_size)

    if diagnostics:
      if mala_probs:
          print(f"Average MALA acceptance probability: {np.mean(mala_probs):.4f}")

    return lambda_hat_behavioral

In [None]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

class MonotonicLoss:
    """Base loss plus a penalty on negative input-gradients of ``model``.

    Calling the object returns ``(base + lambda_mono * penalty, base)`` where
    ``penalty`` is the mean of ``relu(-d model(x).sum() / d x)`` over all input
    entries, i.e. it is zero when the summed output is non-decreasing in every
    input coordinate at the given points.
    """

    def __init__(self, model, base_loss=None, lambda_mono=5.0, delta=1e-3):
        self.model = model
        self.base_loss = base_loss if base_loss is not None else nn.MSELoss()
        self.lambda_mono = lambda_mono
        # Stored but not read by any method of this class.
        self.delta = delta

    def __call__(self, pred, target, inputs):
        """Return ``(total_loss, base_loss)`` for one batch."""
        base = self.base_loss(pred, target)
        mono = self.monotonicity_penalty(inputs)
        return base + self.lambda_mono * mono, base

    def monotonicity_penalty(self, x):
        """Mean positive part of the negated input-gradient at ``x``.

        Runs its own forward pass on a detached copy of ``x``.
        ``create_graph=True`` keeps the penalty differentiable with respect to
        the model parameters so it can be trained on.
        """
        x = x.clone().detach().requires_grad_(True)
        y = self.model(x)
        grad = torch.autograd.grad(outputs=y.sum(), inputs=x, create_graph=True)[0]
        violations = F.relu(-grad)
        return violations.mean()

    @staticmethod
    def _flat_grad(scalar, params):
        """Flattened gradient of ``scalar`` w.r.t. ``params`` (zeros where unused)."""
        # Only the numeric values are needed, so no higher-order graph is built.
        grads = torch.autograd.grad(scalar, params, allow_unused=True)
        return torch.cat([
            (torch.zeros_like(p) if g is None else g).reshape(-1)
            for g, p in zip(grads, params)
        ])

    def grad_stats(self, xb, yb):
        """Compute base vs penalty gradient dominance metrics for one minibatch.

        Returns:
            ``(Rg, C_lp, Dg)`` as Python floats, where
            ``Rg`` is ``||lambda * g_pen|| / ||g_base||``,
            ``C_lp`` is the cosine similarity between ``g_base`` and ``lambda * g_pen``,
            ``Dg`` is ``(lambda * g_pen . g_total) / (g_base . g_total)``.
        """
        params = list(self.model.parameters())

        pred = self.model(xb)
        base_loss = self.base_loss(pred, yb)
        mono_pen = self.monotonicity_penalty(xb)

        g_base = self._flat_grad(base_loss, params)
        g_pen = self._flat_grad(mono_pen, params)

        # scaled penalty gradient (as it actually enters the total loss)
        g_pen_scaled = self.lambda_mono * g_pen
        g_total = g_base + g_pen_scaled

        Rg = (g_pen_scaled.norm() / (g_base.norm() + 1e-12)).item()
        C_lp = torch.nn.functional.cosine_similarity(g_base, g_pen_scaled, dim=0).item()
        Dg = (g_pen_scaled @ g_total).item() / ((g_base @ g_total).item() + 1e-12)

        return Rg, C_lp, Dg


class DeepNeuralNetwork(nn.Module):
    """Fully connected ReLU network with a built-in SGD training routine.

    Args:
        input_dim: number of input features.
        hidden_dims: list of hidden layer widths (ReLU after each hidden layer).
        output_dim: number of outputs (no activation on the last layer).
        loss_fn: pass ``'mono'`` to train with ``MonotonicLoss``; any other
            value selects plain ``nn.MSELoss``.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, loss_fn=None):
        super().__init__()

        layers = []
        dims = [input_dim] + hidden_dims + [output_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1], bias=True))
            if i < len(dims) - 2:
                layers.append(nn.ReLU())
        self.model = nn.Sequential(*layers)

        if loss_fn == 'mono':
            self.loss_fn = MonotonicLoss(self)
        else:
            self.loss_fn = nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def train_model(self, X_train, y_train, epochs=200, lr=1e-2, batch_size=64, holdout_size=1000):
        """Train with minibatch SGD and report convergence statistics.

        The first ``holdout_size`` rows are held out as a test set; the rest
        are used for training. An epoch counts as "converged" when the drop in
        average base (MSE) loss from the previous epoch is below ``1e-2``; the
        flag is cleared again if a later epoch improves by at least that much.

        Args:
            X_train, y_train: full data; split into test/train by position.
            epochs: number of passes over the training part.
            lr: SGD learning rate.
            batch_size: minibatch size (shuffled each epoch).
            holdout_size: number of leading rows used as the test set.

        Returns:
            ``(init_loss, converged_point, converged_loss, est_time_to_converge,
            avg_cos_sim, test_loss, param_trajectory)``. The four middle
            convergence entries are ``None`` if the run did not end in a
            converged state. ``param_trajectory`` holds flattened parameter
            snapshots taken on the epochs that print progress.

        Raises:
            ValueError: if no rows are left for training after the hold-out split.
        """
        X_test = X_train[:holdout_size]
        y_test = y_train[:holdout_size]
        X_train = X_train[holdout_size:]
        y_train = y_train[holdout_size:]

        # Without this the epoch average below divides by a dataset length of zero.
        if len(X_train) == 0:
            raise ValueError(
                f"No training data left after holding out {holdout_size} rows; "
                "provide more data or a smaller holdout_size."
            )

        uses_mono = isinstance(self.loss_fn, MonotonicLoss)
        param_trajectory = []

        self.train()
        optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        dataset = torch.utils.data.TensorDataset(X_train, y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        epoch_losses, epoch_base_losses = [], []
        converged_point, converged_loss = None, None
        previous_loss = float('inf')
        cosine_sims, prev_grad, time_elapsed = [], None, []

        grad_ratios, grad_cosines, grad_dir_ratios = [], [], []

        # initial test loss (base loss only; the monotonic penalty needs autograd,
        # so this is deliberately not wrapped in no_grad)
        self.eval()
        pred_init = self.forward(X_test)
        if uses_mono:
            _, init_loss = self.loss_fn(pred_init, y_test, X_test)
            init_loss = init_loss.item()
        else:
            init_loss = self.loss_fn(pred_init, y_test).item()
        self.train()

        for epoch in range(epochs):
            time_start = time.time()
            epoch_loss, epoch_base_loss = 0.0, 0.0

            for xb, yb in dataloader:
                pred = self(xb)
                if uses_mono:
                    loss, base_loss = self.loss_fn(pred, yb, xb)
                    epoch_base_loss += base_loss.item() * xb.size(0)
                else:
                    loss = self.loss_fn(pred, yb)
                    epoch_base_loss += loss.item() * xb.size(0)

                optimizer.zero_grad()
                loss.backward()

                # Cosine similarity between consecutive minibatch gradients.
                curr_grad = torch.cat([p.grad.view(-1) for p in self.parameters() if p.grad is not None])
                if prev_grad is not None:
                    cos_sim = torch.nn.functional.cosine_similarity(curr_grad, prev_grad, dim=0).item()
                    cosine_sims.append(cos_sim)
                prev_grad = curr_grad.detach()

                optimizer.step()
                epoch_loss += loss.item() * xb.size(0)

                # Measured after the step, i.e. at the updated parameters.
                if uses_mono:
                    Rg, C_lp, Dg = self.loss_fn.grad_stats(xb, yb)
                    grad_ratios.append(Rg)
                    grad_cosines.append(C_lp)
                    grad_dir_ratios.append(Dg)

            avg_loss = epoch_loss / len(dataloader.dataset)
            avg_base_loss = epoch_base_loss / len(dataloader.dataset)
            epoch_losses.append(avg_loss)
            epoch_base_losses.append(avg_base_loss)

            if (epoch + 1) % 50 == 0 or epoch < 15:
                if uses_mono:
                    print(f"Epoch {epoch+1:3d} | Total Loss: {avg_loss:.6f} | MSE Loss: {avg_base_loss:.6f}")
                else:
                    print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.6f}")
                param_vector = torch.cat([p.detach().flatten() for p in self.parameters()])
                param_trajectory.append(param_vector.clone().cpu())

            loss_change = previous_loss - avg_base_loss
            if converged_point is None and loss_change < 1e-2:
                converged_point, converged_loss = epoch + 1, avg_base_loss
                print(f"Converged at epoch {converged_point}")
            elif converged_point is not None and loss_change >= 1e-2:
                converged_point, converged_loss = None, None
                print(f"Convergence broken at epoch {epoch+1}")
            previous_loss = avg_base_loss

            time_elapsed.append(time.time() - time_start)

        avg_cos_sim = sum(cosine_sims)/len(cosine_sims) if cosine_sims else None

        # test evaluation
        self.eval()
        with torch.no_grad():
            y_test_pred = self(X_test)
            mse_loss_fn = nn.MSELoss()
            test_loss_val = mse_loss_fn(y_test_pred, y_test).item()
        print(f"Final test loss on held-out {len(X_test)} points: {test_loss_val:.6f}")

        if uses_mono and grad_ratios:
            print(f"Avg gradient norm ratio (penalty/base): {sum(grad_ratios)/len(grad_ratios):.3f}")
            print(f"Avg grad cosine (base vs penalty): {sum(grad_cosines)/len(grad_cosines):.3f}")
            print(f"Avg directional dominance ratio: {sum(grad_dir_ratios)/len(grad_dir_ratios):.3f}")

        if converged_point:
            # Time to convergence is estimated as epochs-to-converge times mean epoch time.
            return init_loss, converged_point, converged_loss, \
                   converged_point * sum(time_elapsed)/len(time_elapsed), \
                   avg_cos_sim, test_loss_val, param_trajectory
        else:
            return init_loss, None, None, None, None, test_loss_val, param_trajectory


In [None]:
class MonotoneLinear(nn.Module):
    """Linear layer whose effective weight matrix is ``softplus(weight_unconstrained)``.

    The trainable tensors are the unconstrained weight (shape
    ``(out_features, in_features)``) and, optionally, a bias that is used
    as-is. Both are initialised from a standard normal.
    """

    def __init__(self, in_features, out_features, bias = True):
        super().__init__()
        self.weight_unconstrained = nn.Parameter(torch.randn(out_features, in_features))
        if not bias:
            # Registered as None so the attribute exists and forward() can pass it through.
            self.register_parameter("bias_unconstrained", None)
        else:
            self.bias_unconstrained = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        # softplus maps every raw weight to a strictly positive value.
        return F.linear(x, F.softplus(self.weight_unconstrained), self.bias_unconstrained)

class MonotoneNetwork(nn.Module):
    """Stack of ``MonotoneLinear`` layers with ReLU between them.

    Args:
        input_dim: number of input features.
        hidden_dims: list of hidden layer widths.
        output_dim: number of outputs (no activation on the last layer).
        loss_fn: callable ``loss_fn(pred, target)``; defaults to ``nn.MSELoss()``.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, loss_fn=None):
        super().__init__()

        if loss_fn is None:
            self.loss_fn = nn.MSELoss()
        else:
            self.loss_fn = loss_fn

        layers = []
        dims = [input_dim] + hidden_dims

        for i in range(len(hidden_dims)):
            layers.append(MonotoneLinear(dims[i], dims[i + 1]))
            layers.append(nn.ReLU())

        layers.append(MonotoneLinear(dims[-1], output_dim))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def train_model(self, X_train, y_train, epochs=200, lr=1e-2, batch_size=64, holdout_size=1000):
        """Train with minibatch SGD and report convergence statistics.

        The first ``holdout_size`` rows are held out as a test set; the rest
        are used for training. An epoch counts as "converged" when the drop in
        average loss from the previous epoch is below ``1e-2``; the flag is
        cleared again if a later epoch improves by at least that much.

        Args:
            X_train, y_train: full data; split into test/train by position.
            epochs: number of passes over the training part.
            lr: SGD learning rate.
            batch_size: minibatch size (shuffled each epoch).
            holdout_size: number of leading rows used as the test set.

        Returns:
            ``(init_loss, converged_point, converged_loss, est_time_to_converge,
            avg_cos_sim, test_loss)``. The four middle entries are ``None`` if
            the run did not end in a converged state. When converged,
            ``avg_cos_sim`` averages only the gradient similarities from steps
            up to the end of the convergence epoch.

        Raises:
            ValueError: if no rows are left for training after the hold-out split.
        """
        X_test = X_train[:holdout_size]
        y_test = y_train[:holdout_size]
        X_train = X_train[holdout_size:]
        y_train = y_train[holdout_size:]

        # Without this the epoch average below divides by a dataset length of zero.
        if len(X_train) == 0:
            raise ValueError(
                f"No training data left after holding out {holdout_size} rows; "
                "provide more data or a smaller holdout_size."
            )

        self.train()
        optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        loss_fn = self.loss_fn

        converged_point = None
        converged_loss = None
        previous_loss = float('inf')

        cosine_sims = []
        prev_grad = None

        dataset = torch.utils.data.TensorDataset(X_train, y_train)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        epoch_losses = []
        time_elapsed = []

        self.eval()
        with torch.no_grad():
            pred_init = self.forward(X_test)
            init_loss = self.loss_fn(pred_init, y_test).item()
        self.train()

        for epoch in range(epochs):
            time_start = time.time()
            epoch_loss = 0.0
            for xb, yb in dataloader:
                pred = self(xb)
                loss = loss_fn(pred, yb)

                optimizer.zero_grad()
                loss.backward()

                # Cosine similarity between consecutive minibatch gradients.
                curr_grad = torch.cat([p.grad.view(-1) for p in self.parameters() if p.grad is not None])
                if prev_grad is not None:
                    cos_sim = torch.nn.functional.cosine_similarity(curr_grad, prev_grad, dim=0).item()
                    cosine_sims.append(cos_sim)
                prev_grad = curr_grad.detach()

                optimizer.step()
                epoch_loss += loss.item() * xb.size(0)

            avg_loss = epoch_loss / len(dataloader.dataset)
            epoch_losses.append(avg_loss)

            if (epoch + 1) % 50 == 0 or epoch < 15:
                print(f"Epoch {epoch + 1:3d} | Loss: {avg_loss:.6f}")

            loss_change = previous_loss - avg_loss

            if converged_point is None:
                if loss_change < 1e-2:
                    converged_point = epoch + 1
                    converged_loss = avg_loss
                    print(f"Converged at epoch {converged_point}")
            else:
                if loss_change >= 1e-2:
                    converged_point = None
                    converged_loss = None
                    print(f"Convergence broken at epoch {epoch + 1}")

            previous_loss = avg_loss
            time_end = time.time()
            time_elapsed.append(time_end - time_start)

        if converged_point is not None:
            # The very first step has no predecessor, so after c epochs of
            # len(dataloader) steps each there are c * steps - 1 similarities.
            # (Deriving steps-per-epoch from len(cosine_sims) // epochs would
            # under-count by one step per epoch.)
            cutoff = max(len(dataloader) * converged_point - 1, 0)
            cosine_sims = cosine_sims[:cutoff]

        avg_cos_sim = sum(cosine_sims) / len(cosine_sims) if cosine_sims else None

        self.eval()
        with torch.no_grad():
            y_test_pred = self(X_test)
            mse_loss_fn = nn.MSELoss()
            test_loss_val = mse_loss_fn(y_test_pred, y_test).item()

        print(f"Final test loss on held-out {len(X_test)} points: {test_loss_val:.6f}")

        if converged_point:
          # Time to convergence is estimated as epochs-to-converge times mean epoch time.
          return init_loss, converged_point, converged_loss, converged_point * sum(time_elapsed)/len(time_elapsed), avg_cos_sim, test_loss_val
        else:
          return init_loss, None, None, None, None, test_loss_val

## Visualisation Code

In [None]:
def visualise_2d_loss_grid(x_data, y_data, nn, param_range=2.0, resolution=50,
                           trajectory=None, savepath='monotonic.png'):
    """Plot a grid of 1-D and 2-D loss slices around the model's current parameters.

    One subplot is drawn per pair of scalar parameters: diagonal cells show the
    loss along a single parameter, off-diagonal cells show filled contours over
    a pair of parameters. All other parameters stay at their current values.
    The model's parameters are restored afterwards, also if evaluation fails.

    Args:
        x_data, y_data: data on which the loss is evaluated.
        nn: model exposing ``parameters()``, ``eval()`` and a ``loss_fn`` attribute.
        param_range: half-width of the scanned interval around each parameter.
        resolution: number of grid points per axis.
        trajectory: optional sequence of flattened parameter vectors to overlay.
        savepath: file the figure is written to.
    """
    params = list(nn.parameters())
    original_params = [p.detach().clone() for p in params]
    original_vector = torch.cat([p.detach().flatten() for p in params])
    n_params = len(original_vector)

    fig, axes = plt.subplots(n_params, n_params, figsize=(3*n_params, 3*n_params))
    if n_params == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid.
        axes = [[axes]]

    nn.eval()

    def _loss_with(overrides, use_total):
        # Evaluate the loss with the given (index, value) entries overridden.
        vector = original_vector.clone()
        for index, value in overrides:
            vector[index] = value
        _set_parameters_from_vector(nn, params, vector)
        pred = nn(x_data)
        if isinstance(nn.loss_fn, MonotonicLoss):
            total, base = nn.loss_fn(pred, y_data, x_data)
            return (total if use_total else base).item()
        return nn.loss_fn(pred, y_data).item()

    try:
        for i in range(n_params):
            for j in range(n_params):
                ax = axes[i][j]

                if i == j:
                    # 1D slice
                    # NOTE(review): for MonotonicLoss this plots the base loss,
                    # while the 2D slices plot the total loss - confirm intended.
                    orig_val = original_vector[i].item()
                    param_vals = np.linspace(orig_val - param_range, orig_val + param_range, resolution)
                    losses = [_loss_with([(i, val)], use_total=False) for val in param_vals]

                    ax.plot(param_vals, losses, 'b-', linewidth=2)
                    ax.axvline(orig_val, color='red', linestyle='--', alpha=0.7)

                    # overlay trajectory points
                    if trajectory is not None:
                        traj_vals = [vec[i].item() for vec in trajectory]
                        # A trajectory point has no loss value on this slice, so
                        # mark its parameter value along the bottom edge instead:
                        # x is in data coordinates, y in axes coordinates.
                        # (Passing None as y draws nothing at all.)
                        ax.plot(traj_vals, [0.02] * len(traj_vals), 'ko--', markersize=4,
                                transform=ax.get_xaxis_transform(), clip_on=False)

                    ax.set_xlabel(f"W[{i}]")
                    ax.set_ylabel('Loss')
                    ax.set_title(f'Param {i} slice')
                    ax.grid(True, alpha=0.3)

                else:
                    # 2D slice
                    pi_orig = original_vector[i].item()
                    pj_orig = original_vector[j].item()
                    pi_range = np.linspace(pi_orig - param_range, pi_orig + param_range, resolution)
                    pj_range = np.linspace(pj_orig - param_range, pj_orig + param_range, resolution)
                    Pi, Pj = np.meshgrid(pi_range, pj_range)
                    loss_surface = np.zeros_like(Pi)

                    for ii in range(resolution):
                        for jj in range(resolution):
                            loss_surface[ii, jj] = _loss_with(
                                [(i, Pi[ii, jj]), (j, Pj[ii, jj])], use_total=True)

                    contourf_plot = ax.contourf(Pi, Pj, loss_surface, levels=12, cmap='viridis', alpha=0.6)
                    ax.contour(Pi, Pj, loss_surface, levels=12, cmap='viridis')
                    ax.plot(pi_orig, pj_orig, 'ro', markersize=4)

                    # overlay trajectory
                    if trajectory is not None:
                        traj_pi = [vec[i].item() for vec in trajectory]
                        traj_pj = [vec[j].item() for vec in trajectory]
                        ax.plot(traj_pi, traj_pj, 'ko--', markersize=4)

                    plt.colorbar(contourf_plot, ax=ax, shrink=0.8).set_label('Loss', rotation=270, labelpad=15)

                    ax.set_xlabel(f"W[{i}]")
                    ax.set_ylabel(f"W[{j}]")
                    ax.set_title(f"W[{i}] vs W[{j}]")
                    ax.grid(True, alpha=0.3)
    finally:
        # Put the scanned model back to the weights it had on entry.
        for param, orig in zip(params, original_params):
            param.data.copy_(orig)

    plt.tight_layout()
    plt.savefig(savepath, dpi=300, bbox_inches='tight')
    plt.show()


def _set_parameters_from_vector(nn, params, param_vector):
    """Helper function to set network parameters from a flattened vector."""
    start_idx = 0
    for param in params:
        param_size = param.numel()
        param.data = param_vector[start_idx:start_idx + param_size].reshape(param.shape)
        start_idx += param_size

In [None]:
def check_monotonicity_violation(nn, x_data):
    """
    Measure how strongly the network's output decreases in its inputs.

    Puts ``nn`` in eval mode, differentiates the summed output with respect to
    a detached copy of ``x_data`` and returns the mean of ``relu(-gradient)``
    as a scalar tensor (0 means no negative input-gradient at these points).
    No loss function is involved.
    """
    nn.eval()

    probe = x_data.clone().detach().requires_grad_(True)
    (input_grad,) = torch.autograd.grad(
        outputs=nn(probe).sum(),
        inputs=probe,
        create_graph=False,
        retain_graph=False,
    )

    return torch.relu(-input_grad).mean()

In [None]:
def visualise_feasible_region_mono(x_data, y_data, nn, param_range=2.0, resolution=50,
                                  violation_threshold=0, savepath='constraint.png'):
    """
    Visualize the feasible region where monotonicity constraints are satisfied,
    overlaid with the loss landscape.
    - Feasible region (green/red) depends ONLY on weights (gradient check).
    - Loss contours depend on the model's loss function (MSE vs. MonotonicLoss).

    One subplot is drawn per pair of scalar parameters; all other parameters
    stay at their current values. A point counts as feasible when
    ``check_monotonicity_violation`` is at most ``violation_threshold``.
    The model's parameters are restored afterwards, also if evaluation fails.
    """
    params = list(nn.parameters())
    param_names = [f"W{i}[{j},{k}]" if len(p.shape) > 1 else f"W{i}[{j}]"
                   for i, p in enumerate(params) for j in range(p.shape[0])
                   for k in (range(p.shape[1]) if len(p.shape) > 1 else [0])]

    original_params = [p.detach().clone() for p in params]
    original_vector = torch.cat([p.detach().flatten() for p in params])
    n_params = len(original_vector)

    fig, axes = plt.subplots(n_params, n_params, figsize=(4*n_params, 4*n_params))
    if n_params == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid.
        axes = [[axes]]

    nn.eval()

    # Feasibility of the unmodified weights is the same for every subplot,
    # so it is evaluated once, before any parameter is touched.
    original_violation = check_monotonicity_violation(nn, x_data).item()
    original_feasible = original_violation <= violation_threshold
    original_color = 'blue' if original_feasible else 'orange'
    original_label = f'Original ({"Feasible" if original_feasible else "Infeasible"})'

    def _loss_and_violation(overrides, use_total):
        # Set the given (index, value) overrides, then return (loss, violation).
        vector = original_vector.clone()
        for index, value in overrides:
            vector[index] = value
        _set_parameters_from_vector(nn, params, vector)

        # Compute loss (depends on model's loss function)
        pred = nn(x_data)
        if isinstance(nn.loss_fn, MonotonicLoss):
            total, base = nn.loss_fn(pred, y_data, x_data)
            loss = total if use_total else base
        else:
            loss = nn.loss_fn(pred, y_data)

        # Compute violation (independent of loss function)
        violation = check_monotonicity_violation(nn, x_data)
        return loss.item(), violation.item()

    try:
        for i in range(n_params):
            for j in range(n_params):
                ax = axes[i][j]

                if i == j:
                    # 1D slice: Plot loss curve with feasible regions colored
                    # NOTE(review): for MonotonicLoss this plots the base loss,
                    # while the 2D slices plot the total loss - confirm intended.
                    orig_val = original_vector[i].item()
                    param_vals = np.linspace(orig_val - param_range, orig_val + param_range, resolution)
                    losses = []
                    violations = []

                    for val in param_vals:
                        loss_val, violation_val = _loss_and_violation([(i, val)], use_total=False)
                        losses.append(loss_val)
                        violations.append(violation_val)

                    # Color segments based on feasibility of the segment's left end point
                    for k in range(len(param_vals) - 1):
                        color = 'green' if violations[k] <= violation_threshold else 'red'
                        ax.plot([param_vals[k], param_vals[k + 1]], [losses[k], losses[k + 1]],
                               color=color, linewidth=3, alpha=0.7)

                    ax.axvline(orig_val, color='blue', linestyle='--', alpha=0.7, label='Original')
                    ax.set_xlabel(param_names[i])
                    ax.set_ylabel('Loss')
                    ax.set_title(f'{param_names[i]} slice\n(Green=Feasible, Red=Infeasible)')
                    ax.grid(True, alpha=0.3)
                    ax.legend()

                else:
                    # 2D slice: Plot loss contours with feasible overlay
                    pi_orig = original_vector[i].item()
                    pj_orig = original_vector[j].item()

                    pi_range = np.linspace(pi_orig - param_range, pi_orig + param_range, resolution)
                    pj_range = np.linspace(pj_orig - param_range, pj_orig + param_range, resolution)
                    Pi, Pj = np.meshgrid(pi_range, pj_range)

                    loss_surface = np.zeros_like(Pi)
                    feasible_mask = np.zeros_like(Pi, dtype=bool)

                    for ii in range(resolution):
                        for jj in range(resolution):
                            loss_val, violation_val = _loss_and_violation(
                                [(i, Pi[ii, jj]), (j, Pj[ii, jj])], use_total=True)
                            loss_surface[ii, jj] = loss_val
                            feasible_mask[ii, jj] = violation_val <= violation_threshold

                    # Plot loss contours (model-specific)
                    contour = ax.contour(Pi, Pj, loss_surface, levels=12, colors='gray', alpha=0.5)
                    ax.clabel(contour, inline=True, fontsize=8)

                    # Overlay feasible region (weight-based only)
                    feasible_colored = np.where(feasible_mask, 1, 0)
                    cmap = ListedColormap(['red', 'lightgreen'])
                    ax.contourf(Pi, Pj, feasible_colored, levels=[0, 0.5, 1], cmap=cmap, alpha=0.4)

                    # Mark original point
                    ax.plot(pi_orig, pj_orig, 'o', color=original_color, markersize=8,
                            label=original_label)

                    ax.set_xlabel(param_names[i])
                    ax.set_ylabel(param_names[j])
                    ax.set_title(f'{param_names[i]} vs {param_names[j]}\n(Green=Feasible, Red=Infeasible)')
                    ax.grid(True, alpha=0.3)
                    ax.legend()
    finally:
        # Restore original parameters
        for param, orig in zip(params, original_params):
            param.data.copy_(orig)

    plt.tight_layout()
    plt.savefig(savepath, dpi=300, bbox_inches='tight')
    plt.show()

## Begin Training

In [None]:
# Seed all three random number generators seeded in this notebook
# (PyTorch, NumPy and Python's `random`) with the same fixed value.
torch.manual_seed(10)
np.random.seed(10)
random.seed(10)

In [None]:
def clone_weights(model):
    """Return detached copies of all parameters, in ``model.parameters()`` order."""
    snapshot = []
    for param in model.parameters():
        snapshot.append(param.detach().clone())
    return snapshot

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# store metrics
# One entry per run: the tuple returned by each model's train_model().
base_losses, loss_losses, arch_losses = [], [], []

# Flattened initial weights of the baseline model, one tensor per run.
init_weights_all = []
# Per-run lists of final parameter tensors for each of the three models.
final_weights_base = []
final_weights_loss = []
final_weights_arch = []

# Per-run mean squared change between initial and final parameters.
base_distances = []
loss_distances = []
arch_distances = []

def clone_monotone_weights(model):
    """Return detached copies of the layer tensors in ``model.model``.

    For each ``MonotoneLinear`` layer the unconstrained weight and (if present)
    bias are copied; for each ``nn.Linear`` layer the weight and (if present)
    bias. Other layers are skipped. Order is weight then bias, layer by layer.
    """
    weights = []
    for layer in model.model:
        if isinstance(layer, MonotoneLinear):
            tensors = (layer.weight_unconstrained, layer.bias_unconstrained)
        elif isinstance(layer, nn.Linear):
            tensors = (layer.weight, layer.bias)
        else:
            continue
        # A missing bias is stored as None and simply left out.
        weights.extend(t.detach().clone() for t in tensors if t is not None)
    return weights


# Repeat the three-way comparison (baseline / penalty loss / monotone architecture)
# on 10 freshly generated datasets and freshly initialised models.
for run in range(10):
    # Generate synthetic data: 2000 points with 2 strictly positive features
    # (softplus of scaled Gaussian noise).
    example_mono_x = torch.randn(2000, 2) * 4
    example_mono_x = F.softplus(example_mono_x)

    # Target is a product of powers, y = A * x0^alpha * x1^beta, with positive
    # exponents; reshaped to a column so it matches the models' output shape.
    A, alpha, beta = 1.0, 0.6, 0.4
    example_mono_y = A * (example_mono_x[:, 0] ** alpha) * (example_mono_x[:, 1] ** beta)
    example_mono_y = example_mono_y.unsqueeze(1)

    # Models: 2 inputs, one hidden layer of width 2, 1 output.
    baseline_nn = DeepNeuralNetwork(2, [2], 1)
    loss_nn = DeepNeuralNetwork(2, [2], 1, loss_fn='mono')
    arch_nn = MonotoneNetwork(2, [2], 1)

    # Clone initial weights (using the standard clone_weights for DeepNeuralNetwork)
    init_weights_base = clone_weights(baseline_nn)
    init_weights_loss = clone_weights(loss_nn)
    # Clone initial weights for MonotoneNetwork
    init_weights_arch = clone_monotone_weights(arch_nn)


    # Copy initial weights into models
    # NOTE(review): each model receives the clone of its *own* initial weights,
    # so these copies leave the values unchanged; the three models do not start
    # from a shared initialisation. Confirm whether that is intended.
    with torch.no_grad():
        for p, w in zip(baseline_nn.parameters(), init_weights_base):
            p.copy_(w)
        for p, w in zip(loss_nn.parameters(), init_weights_loss):
            p.copy_(w)

        # Copy initial weights into MonotoneNetwork (unconstrained weights)
        # `i` walks through init_weights_arch in the weight, bias, weight, ... order
        # produced by clone_monotone_weights.
        i = 0
        for layer in arch_nn.model:
            if isinstance(layer, MonotoneLinear):
                layer.weight_unconstrained.copy_(init_weights_arch[i])
                i += 1
                if layer.bias_unconstrained is not None:
                    layer.bias_unconstrained.copy_(init_weights_arch[i])
                    i += 1


    # Train (each train_model call holds out the first 1000 rows as a test set
    # and stores its returned statistics tuple).
    base_losses.append(baseline_nn.train_model(example_mono_x, example_mono_y))
    loss_losses.append(loss_nn.train_model(example_mono_x, example_mono_y))
    arch_losses.append(arch_nn.train_model(example_mono_x, example_mono_y))

    # Final weights
    final_w_base = clone_weights(baseline_nn)
    final_w_loss = clone_weights(loss_nn)
    final_w_arch = clone_monotone_weights(arch_nn)

    final_weights_base.append(final_w_base)
    final_weights_loss.append(final_w_loss)
    final_weights_arch.append(final_w_arch)


    # Distances: element-wise squared change per parameter tensor, then
    # flattened and averaged over all parameter entries (a mean squared
    # difference, not a norm).
    base_diff_sq = [(f - i) ** 2 for f, i in zip(final_w_base, init_weights_base)]
    loss_diff_sq = [(f - i) ** 2 for f, i in zip(final_w_loss, init_weights_loss)]
    arch_diff_sq = [(f - i) ** 2 for f, i in zip(final_w_arch, init_weights_arch)]


    base_distances.append(torch.mean(torch.cat([d.flatten() for d in base_diff_sq])))
    loss_distances.append(torch.mean(torch.cat([d.flatten() for d in loss_diff_sq])))
    arch_distances.append(torch.mean(torch.cat([d.flatten() for d in arch_diff_sq])))

    # Store all initial weights for averaging
    init_weights_all.append(torch.cat([w.flatten() for w in init_weights_base])) # Use base as representative


# === Summary statistics ===
# Means are taken over every parameter entry of every run.
print("Average initial weight value:", torch.mean(torch.cat(init_weights_all)))
print("Average final weight (Base):", torch.mean(torch.cat([torch.cat([w.flatten() for w in weights]) for weights in final_weights_base])))
print("Average final weight (Loss):", torch.mean(torch.cat([torch.cat([w.flatten() for w in weights]) for weights in final_weights_loss])))
# For the monotone network these are the unconstrained (pre-softplus) weights.
print("Average final weight (Arch):", torch.mean(torch.cat([torch.cat([w.flatten() for w in weights]) for weights in final_weights_arch])))

print("Average distance (Base):", torch.mean(torch.stack(base_distances)))
print("Average distance (Loss):", torch.mean(torch.stack(loss_distances)))
print("Average distance (Arch):", torch.mean(torch.stack(arch_distances)))

Epoch   1 | Loss: 1.871620
Epoch   2 | Loss: 0.842809
Epoch   3 | Loss: 0.724643
Epoch   4 | Loss: 0.654878
Epoch   5 | Loss: 0.644125
Epoch   6 | Loss: 0.634808
Converged at epoch 6
Epoch   7 | Loss: 0.638245
Epoch   8 | Loss: 0.639710
Epoch   9 | Loss: 0.637361
Epoch  10 | Loss: 0.635154
Epoch  11 | Loss: 0.634940
Epoch  12 | Loss: 0.637070
Epoch  13 | Loss: 0.636958
Epoch  14 | Loss: 0.636602
Epoch  15 | Loss: 0.638195
Convergence broken at epoch 42
Converged at epoch 43
Epoch  50 | Loss: 0.638882
Convergence broken at epoch 55
Converged at epoch 56
Convergence broken at epoch 59
Converged at epoch 60
Convergence broken at epoch 64
Converged at epoch 65
Convergence broken at epoch 66
Converged at epoch 67
Convergence broken at epoch 78
Converged at epoch 79
Convergence broken at epoch 91
Converged at epoch 92
Epoch 100 | Loss: 0.639124
Convergence broken at epoch 149
Epoch 150 | Loss: 0.638067
Converged at epoch 150
Convergence broken at epoch 187
Converged at epoch 188
Epoch 200 | 

In [None]:
# `base`, `loss` and `arch` are used by this and the following cells but were
# never assigned, which raised a NameError on a fresh kernel. Bind them to
# copies of the per-run result lists collected by the training loop, so later
# edits to these names leave the collected results untouched.
base = list(base_losses)
loss = list(loss_losses)
arch = list(arch_losses)

print(base)
print(loss)
print(arch)

NameError: name 'base' is not defined

In [None]:
# del loss[9]
# del loss[7]
# del loss[4]
# del loss[1]
# del loss[0]

# del base[9]
# del base[7]
# del base[4]
# del base[1]
# del base[0]

# del arch[9]
# del arch[7]
# del arch[4]
# del arch[1]
# del arch[0]

In [None]:
# base = [(a, b, c, d, g) for (a, b, c, d, _, g) in base]
# loss = [(a, b, c, d, g) for (a, b, c, d, _, g) in loss]
# arch = [(a, b, c, d, g) for (a, b, c, d, _, g) in arch]

In [None]:
# base = [sum(x) / len(x) for x in zip(*base)]
# loss = [sum(x) / len(x) for x in zip(*loss)]
# arch = [sum(x) / len(x) for x in zip(*arch)]

In [None]:
base

In [None]:
loss

In [None]:
arch

In [None]:
# Take the last element of the last run's result tuple. DeepNeuralNetwork.train_model
# returns its parameter trajectory in that position.
# Presumably `base` / `loss` hold the per-run result tuples of the baseline and
# penalty-loss models - verify they are bound before this cell runs.
base_trajectories = base[-1][-1]
loss_trajectories = loss[-1][-1]

In [None]:
# Loss-slice grids for the models and data left over from the final loop
# iteration, with the recorded parameter trajectories overlaid.
visualise_2d_loss_grid(example_mono_x, example_mono_y, baseline_nn, trajectory=base_trajectories,  savepath='baseline.png')
visualise_2d_loss_grid(example_mono_x, example_mono_y, loss_nn, trajectory=loss_trajectories, savepath='loss.png')
# visualise_2d_loss_grid(example_mono_x, example_mono_y, arch_nn, savepath='arch.png')

In [None]:
# visualise_feasible_region_mono(example_mono_x, example_mono_y, baseline_nn, savepath='baseline_f.png')
# visualise_feasible_region_mono(example_mono_x, example_mono_y, loss_nn, savepath='loss_f.png')
# visualise_feasible_region_mono(example_mono_x, example_mono_y, arch_nn, savepath='arch_f.png')

In [None]:
# Print the current parameters of the three models from the final loop iteration.
# For arch_nn these are the unconstrained tensors, not the softplus-transformed weights.
print(clone_weights(baseline_nn))
print(clone_weights(loss_nn))
print(clone_weights(arch_nn))