In [1]:
%cd ../../
%pwd

/Users/lucien/Documents/projects/torch_icnn


'/Users/lucien/Documents/projects/torch_icnn'

In [2]:
# TODO; experiment with a univariate version that handles the
# monotonoicity and convexity constraints.

import pandas as pd
import numpy as np
import altair as alt
from typing import Literal, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader


from torch_icnn.networks import (
    ConstraintSpec,
    PartiallyConvexNetwork,
    PartiallyConcaveNetwork,
    PartiallyMixedNetwork,
)
from torch import nn


In [None]:
def apply_col_monotone(
    raw_weight: torch.Tensor, mono_list: list | None, softplus: nn.Softplus
) -> torch.Tensor:
    """Apply per-column monotonicity constraints to a weight matrix.

    This mirrors the previous behavior but is a standalone helper so that
    multiple modules can reuse it.
    """
    if mono_list is None:
        return raw_weight
    _, n_in = raw_weight.shape
    if n_in == 0:
        return raw_weight
    raw_pos = softplus(raw_weight)
    W = raw_weight.clone()
    for j, m in enumerate(mono_list):
        if j >= n_in:
            break
        if m == "increasing":
            W[:, j] = raw_pos[:, j]
        elif m == "decreasing":
            W[:, j] = -raw_pos[:, j]
        # else keep raw
    return W


class UnivariateICNN(nn.Module):
    def __init__(
        self,
        hidden_sizes: list[int] = [16, 16],
        activation: nn.Module = nn.Softplus(),
        monotonic: Literal["increasing", "decreasing", "free"] = "free",
    ) -> None:
        """
        Class for building a univariate ICNN. Follows Amos et al. 2017, Eq. 2 closely.
        Small differences:
        * added parameters (one per layter) for scaling z_i before applying the monotonic,
        convex activation function g_i(z_i). This keeps the values of z_i in the most
        expressive part of g_i. Without it, the fit looks too much like g_i.
        * added biases (one per layer) to add/subtract from g_i(z_i). This allows
        outputs to be negative even if the activation function is always positive,
        as with Softplus.

        Parameters
        ----------
        hidden_sizes : list[int], optional
            Sizes of hidden layers, by default [16, 16]
        activation : nn.Module, optional
            Convex, monotonic activation function, by default nn.Softplus()
        monotonic : Literal["increasing", "decreasing", "free"], optional
            Whether the network should be monotonic increasing, decreasing, or free, by default "free"
        """
        super(UnivariateICNN, self).__init__()

        # Initialize Wx, which feeds inputs into each hidden layer.
        Wx = nn.ParameterList()
        for i, h_size in enumerate(hidden_sizes):
            Wx.append(
                nn.Parameter(torch.randn(h_size, 1))  # Input is 1D.
                if i == 0
                else nn.Parameter(torch.randn(h_size, 1))
            )
        Wx.append(nn.Parameter(torch.randn(1, 1)))

        # Initialize Wiz, which evolves the hidden state.
        Wi = nn.ParameterList()
        for i, h_size in enumerate(hidden_sizes):
            Wi.append(
                torch.zeros(h_size, 1)  # Zero in first layer.
                if i == 0
                else nn.Parameter(torch.randn(h_size, hidden_sizes[i - 1]))
            )
        Wi.append(nn.Parameter(torch.randn(1, hidden_sizes[-1])))

        # Initialize the biases
        biases_in = nn.ParameterList()
        biases_out = nn.ParameterList()
        z_scales = nn.ParameterList()
        for i, h_size in enumerate(hidden_sizes):
            biases_in.append(nn.Parameter(torch.randn(h_size)))
            biases_out.append(nn.Parameter(torch.randn(1)))
            z_scales.append(nn.Parameter(torch.randn(1)))
        biases_in.append(nn.Parameter(torch.randn(1)))
        biases_out.append(nn.Parameter(torch.randn(1)))
        z_scales.append(nn.Parameter(torch.randn(1)))

        self.Wx = Wx
        self.Wi = Wi
        self.biases_in = biases_in
        self.biases_out = biases_out
        self.z_scales = z_scales
        self.activation = activation
        self.mono_list = [monotonic]
        self.softplus = nn.Softplus()

    def forward(self, x):
        z = torch.zeros(x.size(0), 1)
        for i in range(len(self.Wx)):
            # compute pre-activation (z contribution, input contribution, bias)
            z = (
                F.linear(z, nn.Softplus()(self.Wi[i]))
                + F.linear(
                    x,
                    apply_col_monotone(
                        self.Wx[i], mono_list=self.mono_list, softplus=self.softplus
                    ),
                )
                + self.biases_in[i]
            )
            # batch-normalize pre-activation; note: small batch sizes can make this noisy
            z = (
                self.activation(z * nn.Softplus()(self.z_scales[i]))
                + self.biases_out[i]
            )
        return z


class MultivariateICNN(nn.Module):
    def __init__(
        self,
        hidden_sizes: list[int] = [16, 16],
        input_dim: int = 2,
        activation: nn.Module = nn.Softplus(),
        constraints: list[ConstraintSpec] | None = None,
        n_outputs: int = 1,
    ) -> None:
        """Multivariate ICNN-style network with per-input convexity/concavity/monotonicity constraints.

        Architecture (overview)
        - Splits input x = (y, n, f) into three groups:
            - y : inputs declared `convex` (may be empty)
            - n : inputs declared `concave` (may be empty)
            - f : inputs declared `free` (may be empty)
        - Builds ICNN-style layers that combine:
            - convex contributions from `y`,
            - concave contributions implemented as negated convex blocks from `n`,
            - unconstrained contributions from `f`.
        - Uses per-column **monotonicity** enforcement (Softplus sign parametrization)
        and non-negativity on recurrence/readout where required.

        Per-layer recursion (compact)
        Let z^{(0)} := 0. For layer l = 0..L-1:

        z^{(l+1)} = (
            sigma( s_cv^{(l)} * ( Wu_cv^{(l)} z^{(l)} + Wx_cv^{(l)} y + bx_cv^{(l)} )
            - sigma(s_cc^{(l)} * ( Wu_cc^{(l)} z^{(l)} + Wx_cc^{(l)} n + bx_cc^{(l)} )
            + sigma(s_fr^{(l)} * (Wu_fr^{(l)} z^{(l)} + Wx_fr^{(l)} f + bx_fr^{(l)} )
        )

        Where z^{0} = 0, z^{(L)} = output, s^{(l)} must be positive,
        and Wu_cv^{(l)} and Wu_cc^{(l)} must all be positive.

        Final readout (scalar or vector):
        f(x) = w^T z^{(L)} + lin_y(y) + lin_f(f) + b

        Guarantees (why convexity / monotonicity hold)
        - Activation σ must be **convex & non-decreasing** (e.g., ReLU, Softplus).
        - Recurrence matrices U are enforced elementwise ≥ 0 (implemented via Softplus on raw parameters),
        and readout weights w are enforced ≥ 0 for convex components.
        - Convexity in `y`: by induction, z^{(l)} is convex in y because (a) linear W_{x,cv} y is linear, (b) Upos z^{(l)} is a non-negative linear combination of convex z^{(l)} (hence convex), (c) σ preserves convexity when convex & non-decreasing, and (d) a non-negative readout preserves convexity of the final scalar.
        - Monotonicity per coordinate j: all columns that depend on x_j are constrained to have fixed sign
        (Softplus for positive, -Softplus for negative) across every matrix that consumes that input (W_{x,*}, lin terms, etc.).
        Since σ is non-decreasing, each path derivative is a product of fixed-sign terms, so ∂f/∂x_j has the requested sign globally.
        - Concavity: implemented by building a convex subnetwork h for concave dims and including −h(x); flipping monotonicity appropriately ensures correct monotone behavior after negation.

        Notes
        - Monotonicity/concavity constraints are enforced smoothly at forward time via Softplus parametrization (differentiable).
        - If you need full unconstrained expressivity when there are no `convex` dims, consider using a plain MLP for the corresponding paths (this is a safe fallback to avoid excessive inductive bias).

        Parameters
        ----------
        hidden_sizes : list[int], optional
            _description_, by default [16, 16]
        input_dim : int, optional
            _description_, by default 2
        activation : nn.Module, optional
            _description_, by default nn.Softplus()
        constraints : list[ConstraintSpec] | None, optional
            _description_, by default None
        n_outputs : int, optional
            _description_, by default 1

        Returns
        -------
        _type_
            _description_

        Raises
        ------
        ValueError
            _description_
        ValueError
            _description_
        """
        super(MultivariateICNN, self).__init__()
        self.activation = activation

        # parse constraints
        if constraints is None:
            # If no constraints, everything is free/free
            constraints = [
                ConstraintSpec(convexity="free", monotonicity="free")
                for _ in range(input_dim)
            ]
        # Any None constraints are free/free
        constraints = [
            ConstraintSpec(convexity="free", monotonicity="free") if c is None else c
            for c in constraints
        ]

        if len(constraints) != input_dim:
            raise ValueError("constraints length must equal input_dim")
        
        if any(c.convexity == 'concave' for c in constraints):
            raise NotImplementedError("Concavity not properly enforced.")
            # Figure out why, try to fix. If can't fix, switch to separate 
            # concave/convex with shared free variables. 

        # Figure out what input indices correspond with which convexities, which of those
        # are monotonic increasing/decreasing/free.
        self.cv_idx = [i for i, c in enumerate(constraints) if c.convexity == "convex"]
        self.cc_idx = [i for i, c in enumerate(constraints) if c.convexity == "concave"]
        self.fr_idx = [i for i, c in enumerate(constraints) if c.convexity == "free"]
        self.cv_mono = [constraints[i].monotonicity for i in self.cv_idx]
        self.cc_mono_r = [constraints[i].monotonicity_r for i in self.cc_idx]
        self.fr_mono = [constraints[i].monotonicity for i in self.fr_idx]
        self.n_cv = len(self.cv_idx)
        self.n_cc = len(self.cc_idx)
        self.n_fr = len(self.fr_idx)

        # Initialize W's

        # First, the ones used to propagate the hidden state.
        def _init_Wi(n_in: int) -> nn.ParameterList:
            Wi = nn.ParameterList()
            for i in range(len(hidden_sizes[:-1])):
                Wi.append(
                    nn.Parameter(torch.randn(hidden_sizes[i + 1], hidden_sizes[i]))
                )
            Wi.append(nn.Parameter(torch.randn(n_outputs, hidden_sizes[-1])))
            return Wi

        self.Wi_cv = _init_Wi(self.n_cv)
        self.Wi_cc = _init_Wi(self.n_cc)
        self.Wi_fr = _init_Wi(self.n_fr)

        def _init_Wx_bx(n_in: int) -> tuple[nn.ParameterList, nn.ParameterList]:
            Wx = nn.ParameterList()
            bx = nn.ParameterList()
            Wx.append(nn.Parameter(torch.randn(hidden_sizes[0], n_in)))
            bx.append(nn.Parameter(torch.randn(hidden_sizes[0])))
            for i in range(len(hidden_sizes[:-1])):
                Wx.append(nn.Parameter(torch.randn(hidden_sizes[i + 1], n_in)))
                bx.append(nn.Parameter(torch.randn(hidden_sizes[i + 1])))
            Wx.append(nn.Parameter(torch.randn(n_outputs, n_in)))
            bx.append(nn.Parameter(torch.randn(n_outputs)))
            return Wx, bx

        # Next, the ones used to directly feed inputs into each layer.
        # These are the ones that need montonicity constraints.
        self.Wx_cv, self.bx_cv = _init_Wx_bx(self.n_cv)
        self.Wx_cc, self.bx_cc = _init_Wx_bx(self.n_cc)
        self.Wx_fr, self.bx_fr = _init_Wx_bx(self.n_fr)

        # Scaling factors pre activation
        scale_cv = nn.ParameterList()
        scale_cc = nn.ParameterList()
        scale_fr = nn.ParameterList()
        for _ in range(len(hidden_sizes)):
            scale_cv.append(nn.Parameter(torch.randn(1)))
            scale_cc.append(nn.Parameter(torch.randn(1)))
            scale_fr.append(nn.Parameter(torch.randn(1)))
        scale_cv.append(nn.Parameter(torch.randn(1)))
        scale_cc.append(nn.Parameter(torch.randn(1)))
        scale_fr.append(nn.Parameter(torch.randn(1)))
        self.scale_cv = scale_cv
        self.scale_cc = scale_cc
        self.scale_fr = scale_fr

        self.softplus = nn.Softplus()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split x into convex, concave, free parts
        x_cv = x[:, self.cv_idx] if self.n_cv > 0 else torch.empty(x.size(0), 0)
        x_cc = x[:, self.cc_idx] if self.n_cc > 0 else torch.empty(x.size(0), 0)
        x_fr = x[:, self.fr_idx] if self.n_fr > 0 else torch.empty(x.size(0), 0)

        # Deal with first layer first - no hidden state yet
        z = torch.zeros(x.size(0), self.Wx_cv[0].size(0))
        if self.n_cv > 0:
            z_cv = self.activation(
                self.softplus(self.scale_cv[0])
                * F.linear(
                    x_cv,
                    apply_col_monotone(self.Wx_cv[0], self.cv_mono, self.softplus),
                    self.bx_cv[0],
                )
            )
            z += z_cv
        if self.n_cc > 0:
            z_cc = self.activation(
                self.softplus(self.scale_cc[0])
                * F.linear(
                    x_cc,
                    apply_col_monotone(self.Wx_cc[0], self.cc_mono_r, self.softplus),
                    self.bx_cc[0],
                )
            )
            z -= z_cc
        if self.n_fr > 0:
            z_fr = self.activation(
                self.softplus(self.scale_fr[0])
                * F.linear(
                    x_fr,
                    apply_col_monotone(self.Wx_fr[0], self.fr_mono, self.softplus),
                    self.bx_fr[0],
                )
            )
            z += z_fr

        # Now deal with subsequent layers
        for i in range(1, len(self.Wx_cv)):
            z_ip1 = torch.zeros(x.size(0), self.Wx_cv[i].size(0))
            if self.n_cv > 0:
                z_cv = self.activation(
                    self.softplus(self.scale_cv[i])
                    * (
                        F.linear(
                            z,
                            self.softplus(self.Wi_cv[i - 1]),
                        )
                        + F.linear(
                            x_cv,
                            apply_col_monotone(
                                self.Wx_cv[i], self.cv_mono, self.softplus
                            ),
                            self.bx_cv[i],
                        )
                    )
                )
                z_ip1 += z_cv
            if self.n_cc > 0:
                z_cc = self.activation(
                    self.softplus(self.scale_cc[i])
                    * (
                        F.linear(
                            z,
                            self.softplus(self.Wi_cc[i - 1]),
                        )
                        + F.linear(
                            x_cc,
                            apply_col_monotone(
                                self.Wx_cc[i], self.cc_mono_r, self.softplus
                            ),
                            self.bx_cc[i],
                        )
                    )
                )
                z_ip1 -= z_cc
            if self.n_fr > 0:
                z_fr = self.activation(
                    self.softplus(self.scale_fr[i])
                    * (
                        F.linear(
                            z,
                            self.softplus(self.Wi_fr[i - 1]),
                        )
                        + F.linear(
                            x_fr,
                            apply_col_monotone(
                                self.Wx_fr[i], self.fr_mono, self.softplus
                            ),
                            self.bx_fr[i],
                        )
                    )
                )
                z_ip1 += z_fr
            z = z_ip1

        return z

## Univariate testing

In [144]:
def f(x: np.ndarray) -> np.ndarray:
    return np.exp(0.5 * -x)


def generate_fake_data(
    n: int, f: Callable[[np.ndarray], np.ndarray]
) -> tuple[pd.DataFrame, pd.DataFrame]:
    x = np.random.normal(0, 1, size=n)
    y = f(x) + 0.1 * np.random.normal(0, 1, size=n)
    x_line = np.linspace(-4, 4, 1000)
    y_line = f(x_line)
    test_df = pd.DataFrame({"x": x, "y": y})
    line_df = pd.DataFrame({"x": x_line, "y": y_line})
    return test_df, line_df


test_df, line_df = generate_fake_data(3000, f)

In [None]:
# net = PartiallyConvexNetwork(
#     input_dim=1,
#     hidden_sizes=[8, 8, 8],
#     activation=nn.Softplus,
#     constraints=[ConstraintSpec(monotonicity="free", convexity="free")],
# )

# net = nn.Sequential(
#     MonotonicUnivariateLayer(1, 16),
#     nn.Softplus(),
#     MonotonicUnivariateLayer(16, 1),
#     # net,
# )

net = UnivariateICNN(
    hidden_sizes=[16, 16, 16, 16], activation=nn.Softplus(), monotonic="free"
)

# Train `net` on test_df (inputs: x, targets: y)

# prepare tensors (raw)
X = torch.tensor(test_df["x"].values, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(test_df["y"].values, dtype=torch.float32).unsqueeze(1)

# baseline MSE (predicting the mean)
mse_mean = ((y - y.mean()) ** 2).mean().item()
print("baseline MSE (mean):", mse_mean)

# dataset uses scaled values
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = net.to("cpu")


# reinitialize parameters (Xavier for weight-like tensors, zeros for biases)
def reinit_params(m):
    for _, p in m.named_parameters():
        if p.dim() >= 2:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.zeros_(p)


reinit_params(model)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=20, factor=0.5)
loss_fn = torch.nn.MSELoss()

model.train()
for epoch in range(100):
    epoch_loss = 0.0
    for batch_idx, (xb, yb) in enumerate(loader):
        opt.zero_grad()
        out = model(xb)
        loss = loss_fn(out, yb)
        loss.backward()

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        opt.step()
        epoch_loss += loss.item() * xb.size(0)
    epoch_loss = epoch_loss / len(dataset)
    scheduler.step(epoch_loss)
    if (epoch + 1) % 10 == 0:
        print(
            f"Epoch {epoch + 1:3d} loss: {epoch_loss:.6f}, lr={opt.param_groups[0]['lr']:.3e}"
        )

print("Final loss:", epoch_loss)


baseline MSE (mean): 0.3694140613079071
Epoch  10 loss: 0.250788, lr=1.000e-03
Epoch  20 loss: 0.011361, lr=1.000e-03
Epoch  30 loss: 0.011737, lr=1.000e-03
Epoch  40 loss: 0.010660, lr=1.000e-03
Epoch  50 loss: 0.010780, lr=1.000e-03
Epoch  60 loss: 0.011510, lr=1.000e-03
Epoch  70 loss: 0.010136, lr=5.000e-04
Epoch  80 loss: 0.010544, lr=5.000e-04
Epoch  90 loss: 0.010137, lr=5.000e-04
Epoch 100 loss: 0.010226, lr=2.500e-04
Final loss: 0.01022616225729386


In [242]:
# compute predictions on the grid and add to `line_df` (unscale predictions)
import torch

model = model.to("cpu")
model.eval()
with torch.no_grad():
    X_line = torch.tensor(line_df["x"].values, dtype=torch.float32).unsqueeze(1)
    preds_scaled = model(X_line).cpu().numpy().flatten()

line_df["y-hat"] = preds_scaled
print(line_df.head())

          x         y     y-hat
0 -4.000000  7.389056  6.500770
1 -3.991992  7.359529  6.480800
2 -3.983984  7.330121  6.460871
3 -3.975976  7.300830  6.440985
4 -3.967968  7.271656  6.421143


In [243]:
scatter = alt.Chart(test_df.iloc[:1000]).mark_point().encode(x="x", y="y")
line = (
    alt.Chart(line_df)
    .mark_line(color="red")
    .encode(x="x", y="y")
    .properties(width=800, height=400)
)
line_mod = (
    alt.Chart(line_df)
    .mark_line(color="black")
    .encode(x="x", y="y-hat")
    .properties(width=800, height=400)
)
chart = scatter + line + line_mod
chart  # .interactive()

## Bivariate testing

In [160]:
x0 = np.linspace(0, 4, 41)
x1 = np.linspace(0, 4, 41)
X0, X1 = np.meshgrid(x0, x1)
X_grid = np.vstack([X0.flatten(), X1.flatten()]).T
y_grid_true = X_grid[:, 0] ** 2 * X_grid[:, 1] ** 2 / 16
y_grid_vals = y_grid_true + np.random.normal(0, 1, size=y_grid_true.shape)

In [161]:
# Convert this grid to columnar data expected by Altair
source = pd.DataFrame({"x": X_grid[:, 0], "y": X_grid[:, 1], "z": y_grid_vals})
sourcet = pd.DataFrame({"x": X_grid[:, 0], "y": X_grid[:, 1], "z": y_grid_true})

alt.data_transformers.disable_max_rows()
c1 = (
    alt.Chart(source)
    .mark_rect()
    .encode(
        x="x:O",
        y="y:O",
        color=alt.Color(
            "z:Q",
        ),
    )
)
c2 = (
    alt.Chart(sourcet)
    .mark_rect()
    .encode(
        x="x:O",
        y="y:O",
        color=alt.Color(
            "z:Q",
        ),
    )
)
c1 | c2


In [204]:
net = MultivariateICNN(
    hidden_sizes=[16, 16],
    input_dim=2,
    activation=nn.Softplus(),
    constraints=[
        ConstraintSpec(monotonicity="increasing", convexity="concave"),
        ConstraintSpec(monotonicity="increasing", convexity="convex"),
    ],
)

X = torch.tensor(X_grid, dtype=torch.float32).to("cpu")
y = torch.tensor(y_grid_vals.reshape(-1, 1), dtype=torch.float32).to("cpu")

# baseline MSE (predicting the mean)
mse_mean = ((y - y.mean()) ** 2).mean().item()
print("baseline MSE (mean):", mse_mean)

# dataset uses scaled values
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = net.to("cpu")


# reinitialize parameters (Xavier for weight-like tensors, zeros for biases)
def reinit_params(m):
    for _, p in m.named_parameters():
        if p.dim() >= 2:
            nn.init.xavier_uniform_(p)
        else:
            nn.init.zeros_(p)


reinit_params(model)

opt = torch.optim.Adam(model.parameters(), lr=5e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=20, factor=0.5)
loss_fn = torch.nn.MSELoss()

model.train()
for epoch in range(600):
    epoch_loss = 0.0
    for batch_idx, (xb, yb) in enumerate(loader):
        opt.zero_grad()
        out = model(xb)
        loss = loss_fn(out, yb)
        loss.backward()

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        opt.step()
        epoch_loss += loss.item() * xb.size(0)
    epoch_loss = epoch_loss / len(dataset)
    scheduler.step(epoch_loss)
    if (epoch + 1) % 30 == 0:
        print(
            f"Epoch {epoch + 1:3d} loss: {epoch_loss:.6f}, lr={opt.param_groups[0]['lr']:.3e}"
        )

print("Final loss:", epoch_loss)

preds = model(X).detach().cpu().numpy().flatten()
preds_df = pd.DataFrame({"x": X_grid[:, 0], "y": X_grid[:, 1], "z": preds})

baseline MSE (mean): 8.821823120117188
Epoch  30 loss: 1.321659, lr=5.000e-03
Epoch  60 loss: 1.287611, lr=2.500e-03
Epoch  90 loss: 1.279380, lr=2.500e-03
Epoch 120 loss: 1.258749, lr=1.250e-03
Epoch 150 loss: 1.266613, lr=1.250e-03
Epoch 180 loss: 1.255242, lr=6.250e-04
Epoch 210 loss: 1.251470, lr=3.125e-04
Epoch 240 loss: 1.250673, lr=3.125e-04
Epoch 270 loss: 1.249029, lr=7.813e-05
Epoch 300 loss: 1.248443, lr=3.906e-05
Epoch 330 loss: 1.248156, lr=3.906e-05
Epoch 360 loss: 1.247304, lr=1.953e-05
Epoch 390 loss: 1.247068, lr=1.953e-05
Epoch 420 loss: 1.246865, lr=9.766e-06
Epoch 450 loss: 1.246763, lr=4.883e-06
Epoch 480 loss: 1.246715, lr=1.221e-06
Epoch 510 loss: 1.246724, lr=6.104e-07
Epoch 540 loss: 1.246719, lr=1.526e-07
Epoch 570 loss: 1.246715, lr=7.629e-08
Epoch 600 loss: 1.246715, lr=3.815e-08
Final loss: 1.2467150988144793


In [205]:
c1 = (
    alt.Chart(source)
    .mark_rect()
    .encode(
        x="x:O",
        y="y:O",
        color=alt.Color(
            "z:Q",
        ),
    )
)
c2 = (
    alt.Chart(preds_df.assign(err=preds_df["z"] - source["z"]))
    .mark_rect()
    .encode(
        x="x:O",
        y="y:O",
        color=alt.Color(
            "z:Q",
        ),
    )
)
c3 = (
    alt.Chart(preds_df.assign(err=preds_df["z"] - source["z"]))
    .mark_rect()
    .encode(
        x="x:O",
        y="y:O",
        color=alt.Color(
            "err:Q",
        ),
    )
)
c1 | c2 | c3

In [206]:
x_ = 3
pred = (
    alt.Chart(preds_df[np.isclose(preds_df["x"], x_)])
    .mark_line(color="black")
    .encode(
        x="y:Q",
        y="z:Q",
    )
)
true = (
    alt.Chart(sourcet[np.isclose(sourcet["x"], x_)])
    .mark_line(color="red")
    .encode(
        x="y:Q",
        y="z:Q",
    )
)
train = (
    alt.Chart(source[np.isclose(source["x"], x_)])
    .mark_point()
    .encode(x="y:Q", y="z:Q", color=alt.value("blue"))
)
pred + true + train

In [207]:
y_ = 3
pred = (
    alt.Chart(preds_df[np.isclose(preds_df["y"], y_)])
    .mark_line(color="black")
    .encode(
        x="x:Q",
        y="z:Q",
    )
)
true = (
    alt.Chart(sourcet[np.isclose(sourcet["y"], y_)])
    .mark_line(color="red")
    .encode(
        x="x:Q",
        y="z:Q",
    )
)
train = (
    alt.Chart(source[np.isclose(source["y"], y_)])
    .mark_point()
    .encode(x="x:Q", y="z:Q", color=alt.value("blue"))
)
pred + true + train