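# Theoretical Equilibrated Energy

This notebook trains deep linear predictive-coding (PC) networks on MNIST, Fashion-MNIST and CIFAR-10. At every training iteration it compares two quantities:
- the total energy reached by numerically integrating the PC inference dynamics;
- the closed-form equilibrated energy of a deep linear network.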

In [None]:
#@title Installations


%%capture
!pip install equinox==0.11.2
!pip install diffrax==0.5.1

!pip install torch==2.3.1
!pip install torchvision==0.18.1

!pip install plotly==5.11.0
!pip install -U kaleido

In [None]:
#@title Imports

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from typing import Callable, Optional, Tuple, Dict
from jaxtyping import ArrayLike, Scalar, PyTree, Array, PRNGKeyArray

import jax
from jax import vmap
import jax.random as jr
from jax.tree_util import tree_map
import jax.numpy as jnp

import os
import random
import numpy as np

import equinox as eqx
import equinox.nn as nn
from equinox import filter_grad

from diffrax import (
    diffeqsolve,
    ODETerm,
    SaveAt,
    Heun,
    PIDController,
    AbstractSolver,
    AbstractStepSizeController
)

import optax
from optax import (
    GradientTransformation,
    GradientTransformationExtraArgs,
    OptState
)

import plotly.graph_objs as go

In [None]:
#@title data utils


def get_dataloaders(dataset_id, batch_size):
    """Build normalised train and test DataLoaders for `dataset_id`.

    Both loaders use ``drop_last=True``, so every batch has exactly
    `batch_size` rows. This presumably keeps array shapes fixed for the
    jitted train/test steps and avoids recompilation. As a consequence, the
    final partial batch of each split is dropped.

    Only the training loader is shuffled. The test loader iterates in a fixed
    order, so repeated evaluations score the same test samples. It also does
    not draw from the global torch RNG that drives training shuffles.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_data = get_dataset(
        name=dataset_id,
        train=True,
        normalise=True
    )
    test_data = get_dataset(
        name=dataset_id,
        train=False,
        normalise=True
    )
    train_loader = DataLoader(
        dataset=train_data,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True
    )
    return train_loader, test_loader


def get_dataset(name, train, normalise):
    """Instantiate the flattened, one-hot-labelled dataset called `name`.

    Supported names are "MNIST", "Fashion-MNIST" and "CIFAR10".

    Raises:
        ValueError: if `name` is not one of the supported datasets. Previously
            this surfaced as an opaque UnboundLocalError.
    """
    if name == "MNIST":
        return MNIST(train, normalise)
    if name == "Fashion-MNIST":
        return FashionMNIST(train, normalise)
    if name == "CIFAR10":
        return CIFAR10(train, normalise)
    raise ValueError(
        f"Unknown dataset {name!r}; expected one of "
        "'MNIST', 'Fashion-MNIST', 'CIFAR10'."
    )


class MNIST(datasets.MNIST):
    """torchvision MNIST that yields flattened images and one-hot labels.

    The data is downloaded into `save_dir` if missing. With `normalise=True`
    pixels are standardised after `ToTensor`.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        pipeline = [transforms.ToTensor()]
        if normalise:
            # Presumably the usual MNIST training-set mean/std.
            pipeline.append(
                transforms.Normalize(mean=(0.1307), std=(0.3081))
            )
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(pipeline),
        )

    def __getitem__(self, index):
        image, target = super().__getitem__(index)
        # Flatten to a vector and encode the class index as a one-hot row.
        return torch.flatten(image), one_hot(target)


class FashionMNIST(datasets.FashionMNIST):
    """torchvision Fashion-MNIST that yields flattened images and one-hot labels.

    The data is downloaded into `save_dir` if missing. With `normalise=True`
    pixels are shifted/scaled with mean 0.5 and std 0.5 after `ToTensor`.
    """

    def __init__(self, train, normalise=True, save_dir="data"):
        pipeline = [transforms.ToTensor()]
        if normalise:
            pipeline.append(
                transforms.Normalize(mean=(0.5), std=(0.5))
            )
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(pipeline),
        )

    def __getitem__(self, index):
        image, target = super().__getitem__(index)
        # Flatten to a vector and encode the class index as a one-hot row.
        return torch.flatten(image), one_hot(target)


class CIFAR10(datasets.CIFAR10):
    """torchvision CIFAR-10 that yields flattened images and one-hot labels.

    The data is downloaded into `save_dir` if missing. With `normalise=True`
    each channel is standardised after `ToTensor`.
    """

    def __init__(self, train, normalise=True, save_dir=f"data/CIFAR10"):
        pipeline = [transforms.ToTensor()]
        if normalise:
            # Presumably per-channel (R, G, B) CIFAR-10 mean/std.
            pipeline.append(
                transforms.Normalize(
                    mean=(0.4914, 0.4822, 0.4465),
                    std=(0.247, 0.243, 0.261)
                )
            )
        super().__init__(
            save_dir,
            download=True,
            train=train,
            transform=transforms.Compose(pipeline),
        )

    def __getitem__(self, index):
        image, target = super().__getitem__(index)
        # Flatten to a vector and encode the class index as a one-hot row.
        return torch.flatten(image), one_hot(target)


def one_hot(labels, n_classes=10):
    """Return float one-hot encoding(s) of `labels` over `n_classes` classes.

    Works for a single integer label (returns a vector) or a tensor/list of
    labels (returns one row per label), by indexing rows of an identity matrix.
    """
    return torch.eye(n_classes)[labels]


In [None]:
#@title utils


def set_seed(seed):
    """Seed NumPy's, Python's and PyTorch's global RNGs with `seed`.

    JAX is not touched here; its randomness comes from explicit PRNG keys.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)


def make_linear_net(key, dataset, n_hidden, width):
    """Build a bias-free linear network as a list of `eqx.nn.Linear` layers.

    The input size is 3072 for "CIFAR10" and 784 otherwise, every hidden layer
    has `width` units, and the output has 10 units. The network has
    `n_hidden + 1` layers, each initialised from its own split of `key`.
    """
    n_layers = n_hidden + 1
    in_dim = 3072 if dataset == "CIFAR10" else 784
    out_dim = 10
    # dims[i] -> dims[i + 1] is the shape of layer i.
    dims = [in_dim] + [width] * n_hidden + [out_dim]
    layer_keys = jr.split(key, n_layers)
    return [
        eqx.nn.Linear(
            in_features=dims[i],
            out_features=dims[i + 1],
            key=layer_keys[i],
            use_bias=False
        )
        for i in range(n_layers)
    ]


def compute_accuracy(truths: ArrayLike, preds: ArrayLike) -> Scalar:
    """Fraction of rows whose argmax class matches between `truths` and `preds`.

    Both inputs are indexed along axis 1 for the class dimension, e.g.
    one-hot labels against network outputs.
    """
    true_classes = jnp.argmax(truths, axis=1)
    predicted_classes = jnp.argmax(preds, axis=1)
    return jnp.mean(true_classes == predicted_classes)


In [None]:
#@title pc


def init_activities_with_ffwd(
        model: PyTree[Callable],
        x: ArrayLike
) -> PyTree[Array]:
    """Initialise layer activities with a batched feedforward pass.

    Returns a list whose entry `l` is the output of `model[l]`. The first
    layer is applied to `x`, and every later layer to the previous activity.
    Each layer is `vmap`-ed over the leading (batch) axis.
    """
    activities = [vmap(model[0])(x)]
    for layer in model[1:]:
        activities.append(vmap(layer)(activities[-1]))
    return activities


def pc_energy_fn(
        model: PyTree[Callable],
        activities: PyTree[ArrayLike],
        y: ArrayLike,
        x: Optional[ArrayLike] = None,
        record_layers: bool = False
) -> Scalar | Array:
    r"""Computes the free energy for a feedforward neural network of the form

    $$
    \mathcal{F}(\mathbf{z}; θ) = 1/N \sum_i^N \sum_{\ell=1}^L || \mathbf{z}_{i, \ell} - f_\ell(\mathbf{z}_{i, \ell-1}; θ) ||^2
    $$

    given parameters $θ$, free activities $\mathbf{z}$, output
    $\mathbf{z}_L = \mathbf{y}$ and optionally input $\mathbf{z}_0 = \mathbf{x}$.
    The activity of each layer $\mathbf{z}_\ell$ is some function of the previous
    layer, e.g. ReLU$(W_\ell \mathbf{z}_{\ell-1} + \mathbf{b}_\ell)$
    for a fully connected layer with biases and ReLU as activation.

    !!! note

        The input x and output y correspond to the prior and observation of
        the generative model, respectively.

    **Main arguments:**

    - `model`: List of callable model (e.g. neural network) layers.
    - `activities`: List of activities for each layer free to vary. The last
        entry is never read, because the output is clamped to `y`.
    - `y`: Observation or target of the generative model.
    - `x`: Optional prior of the generative model.

    **Other arguments:**

    - `record_layers`: If `True`, returns energies for each layer.

    **Returns:**

    The total or layer-wise energy normalised by batch size. With
    `record_layers=True` the per-layer energies are ordered as
    `[last layer, layer 2, ..., layer L-1, first layer]`. This is not
    input-to-output order.

    """
    batch_size = y.shape[0]
    # With an input x, activities[0] is the first layer's state, so the
    # intermediate errors start at index 1. Without x, activities[0] is a free
    # prior z_0 and they start at index 2.
    start_activity_l = 1 if x is not None else 2
    n_activity_layers = len(activities) - 1
    n_layers = len(model) - 1

    # Output-layer error: the target y replaces the last activity.
    eL = y - vmap(model[-1])(activities[-2])
    energies = [jnp.sum(eL ** 2)]

    # Errors of the intermediate layers model[1], ..., model[-2].
    for act_l, net_l in zip(
            range(start_activity_l, n_activity_layers),
            range(1, n_layers)
    ):
        err = activities[act_l] - vmap(model[net_l])(activities[act_l-1])
        energies.append(jnp.sum(err ** 2))

    # First-layer error: predicted from x if given, else from the prior z_0.
    e1 = activities[0] - vmap(model[0])(x) if (
            x is not None
    ) else activities[1] - vmap(model[0])(activities[0])
    energies.append(jnp.sum(e1 ** 2))

    if record_layers:
        return jnp.array(energies) / batch_size
    else:
        return jnp.sum(jnp.array(energies)) / batch_size


def neg_activity_grad(
        t: float | int,
        activities: PyTree[ArrayLike],
        args: Tuple[PyTree[Callable], ArrayLike, Optional[ArrayLike]],
        energy_fn: Callable = pc_energy_fn,
) -> PyTree[Array]:
    r"""Computes the negative gradient of the energy with respect to the activities $- \partial \mathcal{F} / \partial \mathbf{z}$.

    This defines an ODE system to be integrated by `solve_pc_activities`.

    **Main arguments:**

    - `t`: Time step of the ODE system, used for downstream integration by
        `diffrax.diffeqsolve`. It is not used in the computation; the
        vector field is autonomous.
    - `activities`: List of activities for each layer free to vary.
    - `args`: 3-Tuple with
        (i) list of callable layers for the generative model,
        (ii) network output (observation), and
        (iii) network input (prior).
    - `energy_fn`: Free energy to take the gradient of. It is called as
        `energy_fn(model, activities, y, x)`. Defaults to `pc_energy_fn`.

    **Returns:**

    List of negative gradients of the energy w.r.t. the activities.

    """
    model, y, x = args
    # argnums=1 differentiates with respect to `activities` only.
    dFdzs = jax.grad(energy_fn, argnums=1)(
        model,
        activities,
        y,
        x
    )
    return tree_map(lambda dFdz: -dFdz, dFdzs)


def compute_pc_param_grads(
        model: PyTree[Callable],
        activities: PyTree[ArrayLike],
        y: ArrayLike,
        x: Optional[ArrayLike] = None
) -> PyTree[Array]:
    r"""Computes the gradient of the energy with respect to model parameters $\partial \mathcal{F} / \partial θ$.

    `equinox.filter_grad` differentiates `pc_energy_fn` with respect to its
    first argument (`model`). The activities are held fixed at the values
    passed in, typically the equilibrated ones.

    **Main arguments:**

    - `model`: List of callable model (e.g. neural network) layers.
    - `activities`: List of activities for each layer free to vary.
    - `y`: Observation or target of the generative model.
    - `x`: Optional prior of the generative model.

    **Returns:**

    List of parameter gradients for each network layer.

    """
    return filter_grad(pc_energy_fn)(
        model,
        activities,
        y,
        x
    )


def solve_pc_activities(
        model: PyTree[Callable],
        activities: PyTree[ArrayLike],
        y: ArrayLike,
        x: Optional[ArrayLike] = None,
        solver: AbstractSolver = Heun(),
        t1: int = 20,
        dt: Optional[float | int] = None,
        stepsize_controller: AbstractStepSizeController = PIDController(
            rtol=1e-3, atol=1e-3
        ),
        record_iters: bool = False
) -> PyTree[Array]:
    r"""Solves the activity (inference) dynamics of a predictive coding network.

    This is a wrapper around `diffrax.diffeqsolve` to integrate the gradient
    ODE system `neg_activity_grad` defining the PC activity dynamics

    $$
    \partial \mathbf{z} / \partial t = - \partial \mathcal{F} / \partial \mathbf{z}
    $$

    where $\mathcal{F}$ is the free energy, $\mathbf{z}$ are the activities,
    with $\mathbf{z}_L$ clamped to some target and $\mathbf{z}_0$ optionally
    equal to some prior.

    **Main arguments:**

    - `model`: List of callable model (e.g. neural network) layers.
    - `activities`: List of activities for each layer free to vary, used as
        the initial condition.
    - `y`: Observation or target of the generative model.
    - `x`: Optional prior of the generative model.

    **Other arguments:**

    - `solver`: Diffrax (ODE) solver to be used. Default is Heun, a 2nd order
        explicit Runge--Kutta method.
    - `t1`: Maximum end of integration region (20 by default).
    - `dt`: Integration step size. Defaults to None since the default
        `stepsize_controller` will automatically determine it.
    - `stepsize_controller`: diffrax controller for step size integration.
        Defaults to `PIDController`.
    - `record_iters`: If `True`, returns all integration steps.

    **Returns:**

    List with solution of the activity dynamics for each layer. Each entry
    has a leading saved-step axis. That axis holds only the state at `t1`
    when `record_iters` is `False`. When it is `True`, it holds one slot per
    allowed solver step; unused trailing slots are padded by diffrax (with
    `inf`, per the diffrax `SaveAt` docs).

    """
    sol = diffeqsolve(
        terms=ODETerm(neg_activity_grad),
        solver=solver,
        t0=0,
        t1=t1,
        dt0=dt,
        y0=activities,
        args=(model, y, x),
        stepsize_controller=stepsize_controller,
        saveat=SaveAt(t1=True, steps=record_iters)
    )
    return sol.ys


def get_t_max(activities_iters: PyTree[Array]) -> Array:
    """Index of the last valid saved step in a recorded activity trajectory.

    `activities_iters` is the output of `solve_pc_activities(record_iters=True)`.
    The saved-step axis comes first, and unused trailing slots are non-finite
    (diffrax pads them with inf). One probe coordinate (first layer, first
    batch item, first unit) is inspected. The number of finite entries minus
    one is the last valid index.

    Counting finite entries gives the same answer as locating the first `inf`
    slot. Unlike `argmax - 1`, it also stays correct when every slot is used
    and there is no padding.
    """
    probe = activities_iters[0][:, 0, 0]
    return jnp.sum(jnp.isfinite(probe)) - 1


def compute_infer_energies(
        model: PyTree[Callable],
        activities_iters: PyTree[Array],
        t_max: Array,
        y: ArrayLike,
        x: Optional[ArrayLike] = None
) -> PyTree[Scalar]:
    """Calculates layer energies during predictive coding inference.

    **Main arguments:**

    - `model`: List of callable model (e.g. neural network) layers.
    - `activities_iters`: Layer-wise activities at every saved inference
        step, with the saved-step axis first (4096 entries by diffrax
        default). The output buffer is sized from this axis.
    - `t_max`: Index of the last valid inference step (as returned by
        `get_t_max`). Energies are computed for steps `0, 1, ..., t_max`.
    - `y`: Observation or target of the generative model.
    - `x`: Optional prior of the generative model.

    **Returns:**

    Array of shape `(len(model), n_saved_steps)` with layer-wise energies at
    every inference step. Rows are ordered from the first (input-side) layer to
    the last (output) layer, and columns after `t_max` are left at zero.

    """
    n_layers = len(model)
    # Size the buffer from the recorded trajectory instead of assuming
    # diffrax's default maximum of 4096 steps.
    n_saved_steps = activities_iters[0].shape[0]

    def loop_body(t, energies_iters):
        energies = pc_energy_fn(
            model=model,
            activities=tree_map(lambda act: act[t], activities_iters),
            y=y,
            x=x,
            record_layers=True
        )
        return energies_iters.at[:, t].set(energies)

    # Inclusive of t_max so the final (equilibrated) step is recorded too.
    energies_iters = jax.lax.fori_loop(
        0,
        t_max + 1,
        loop_body,
        jnp.zeros((n_layers, n_saved_steps))
    )

    # pc_energy_fn returns layers as [last, 2, ..., L-1, first]. Permute to
    # [first, 2, ..., L-1, last]; a plain reversal is only right for L <= 3.
    if n_layers > 1:
        layer_order = [n_layers - 1, *range(1, n_layers - 1), 0]
    else:
        layer_order = [0]
    return energies_iters[jnp.array(layer_order), :]


@eqx.filter_jit
def make_pc_step(
      model: PyTree[Callable],
      optim: GradientTransformation | GradientTransformationExtraArgs,
      opt_state: OptState,
      y: ArrayLike,
      x: Optional[ArrayLike] = None,
      ode_solver: AbstractSolver = Heun(),
      t1: int = 20,
      dt: float | int = None,
      stepsize_controller: AbstractStepSizeController = PIDController(
          rtol=1e-3, atol=1e-3
      ),
      key: Optional[PRNGKeyArray] = None,
      layer_sizes: Optional[PyTree[int]] = None,
      batch_size: Optional[int] = None,
      sigma: Scalar = 0.05,
      record_activities: bool = False,
      record_energies: bool = False
) -> Dict:
    """Updates network parameters with predictive coding.

    **Main arguments:**

    - `model`: List of callable model (e.g. neural network) layers.
    - `optim`: Optax optimiser, e.g. `optax.sgd()`.
    - `opt_state`: State of Optax optimiser.
    - `y`: Observation or target of the generative model.
    - `x`: Optional prior of the generative model.

    !!! note

        `key`, `layer_sizes` and `batch_size` must be passed if `x` is
        `None`, since unsupervised training will be assumed and activities need
        to be initialised randomly.

    **Other arguments:**

    - `ode_solver`: Diffrax ODE solver to be used. Default is Heun, a 2nd order
        explicit Runge--Kutta method.
    - `t1`: Maximum end of integration region (20 by default).
    - `dt`: Integration step size. Defaults to None since the default
        `stepsize_controller` will automatically determine it.
    - `stepsize_controller`: diffrax controller for step size integration.
        Defaults to `PIDController`.
    - `key`: `jax.random.PRNGKey` for random initialisation of activities.
    - `layer_sizes`: Dimension of all layers (input, hidden and output). When
        `x` is `None`, one activity of shape `(batch_size, size)` is sampled
        per entry. Presumably this should have `len(model) + 1` entries, to
        match the unsupervised indexing in `pc_energy_fn` (TODO confirm).
    - `batch_size`: Dimension of data batch for activity initialisation.
    - `sigma`: Standard deviation for Gaussian to sample activities from for
        random initialisation. Defaults to 5e-2.
    - `record_activities`: If `True`, returns activities at every inference
        iteration.
    - `record_energies`: If `True`, returns layer-wise energies at every
        inference iteration (implies `record_activities`).

    **Returns:**

    Dict including model with updated parameters, optimiser, updated optimiser
    state, equilibrated activities, last inference step, MSE loss, and energies.

    **Raises:**

    - `ValueError` for inconsistent inputs.

    """
    if x is None and any(arg is None for arg in (key, layer_sizes, batch_size)):
        raise ValueError("""
            If there is no input (i.e. `x` is None), then unsupervised training
            is assumed, and `key`, `layer_sizes` and `batch_size` must be
            passed for random initialisation of activities.
        """)

    # Energies are computed from the recorded trajectory, so they require it.
    if record_energies:
        record_activities = True

    if x is None:
        # Unsupervised: there is no input to propagate, so sample every
        # activity (including the free prior z_0) from N(0, sigma^2).
        subkeys = jr.split(key, len(layer_sizes))
        activities = [
            sigma * jr.normal(subkey, (batch_size, size))
            for subkey, size in zip(subkeys, layer_sizes)
        ]
    else:
        # Supervised: start inference from the feedforward pass.
        activities = init_activities_with_ffwd(model=model, x=x)

    # MSE of the feedforward prediction; only meaningful with an input.
    mse_loss = jnp.mean((y - activities[-1])**2) if x is not None else None
    equilib_activities = solve_pc_activities(
        model=model,
        activities=activities,
        y=y,
        x=x,
        solver=ode_solver,
        t1=t1,
        dt=dt,
        stepsize_controller=stepsize_controller,
        record_iters=record_activities
    )
    t_max = get_t_max(equilib_activities) if record_activities else None
    energies = compute_infer_energies(
        model=model,
        activities_iters=equilib_activities,
        t_max=t_max,
        y=y,
        x=x
    ) if record_energies else None

    # Gradients are taken at the last valid step. Without recording, the
    # solution only holds the state at t1, stored at index 0.
    param_grads = compute_pc_param_grads(
        model=model,
        activities=tree_map(
            lambda act: act[t_max if record_activities else jnp.array(0)],
            equilib_activities
        ),
        y=y,
        x=x
    )
    updates, opt_state = optim.update(
        updates=param_grads,
        state=opt_state,
        params=model
    )
    model = eqx.apply_updates(model=model, updates=updates)
    return {
        "model": model,
        "optim": optim,
        "opt_state": opt_state,
        "activities": equilib_activities,
        "t_max": t_max,
        "loss": mse_loss,
        "energies": energies
    }


@eqx.filter_jit
def test_discriminative_pc(
        model: PyTree[Callable],
        y: ArrayLike,
        x: ArrayLike,
) -> Scalar:
    """Accuracy of a discriminative PC network's feedforward predictions.

    **Main arguments:**

    - `model`: List of callable model (e.g. neural network) layers.
    - `y`: One-hot targets (the observation of the generative model).
    - `x`: Network input (the prior of the generative model); required here.

    **Returns:**

    Fraction of samples whose predicted class matches the target class.

    """
    *_, output_activity = init_activities_with_ffwd(model=model, x=x)
    return compute_accuracy(truths=y, preds=output_activity)


In [None]:
#@title analytical


def linear_equilib_energy_single(
        network: PyTree[nn.Linear],
        x: ArrayLike,
        y: ArrayLike
) -> Array:
    """Theoretical equilibrated PC energy of a bias-free linear net for one sample.

    Computes r^T S^{-1} r with r = y - W_{L:1} x and
    S = I + sum_{l=2}^L W_{L:l} W_{L:l}^T, where W_{L:l} = W_L ... W_l and the
    weights are read from each layer's `.weight`.
    """
    Ws = [l.weight for l in network]
    L = len(Ws)
    d_out = Ws[-1].shape[0]

    # Accumulate the prefix products W_{L:l} for l = L, ..., 2 and add each
    # outer product to the rescaling matrix S.
    S = jnp.eye(d_out)
    W_L_to_l = jnp.eye(d_out)
    for i in range(L - 1, 0, -1):
        W_L_to_l = W_L_to_l @ Ws[i]
        S = S + W_L_to_l @ W_L_to_l.T

    # After the loop W_L_to_l = W_{L:2} (identity when L == 1), so one more
    # product gives the end-to-end map W_{L:1}.
    WLto1 = W_L_to_l @ Ws[0]

    # S = I + sum of M M^T is symmetric positive definite. Solving the linear
    # system is cheaper and more stable than forming S^{-1} explicitly.
    r = y - WLto1 @ x
    return r @ jnp.linalg.solve(S, r)


@eqx.filter_jit
def linear_equilib_energy(
        network: PyTree[nn.Linear],
        x: ArrayLike,
        y: ArrayLike
) -> Array:
    r"""Computes the theoretical equilibrated PC energy for a deep linear network (DLN).

    $$
    \mathcal{F}^* = 1/N \sum_i^N (\mathbf{y}_i - W_{L:1}\mathbf{x}_i)^T S^{-1}(\mathbf{y}_i - W_{L:1}\mathbf{x}_i)
    $$

    where the rescaling is $S = I_{d_y} + \sum_{\ell=2}^L (W_{L:\ell})(W_{L:\ell})^T$,
    and we use the shorthand $W_{L:\ell} = W_L W_{L-1} \dots W_\ell$.

    !!! note

        This expression assumes no biases.

    **Main arguments:**

    - `network`: Linear network defined as a list of Equinox Linear layers.
    - `x`: Network input, batched along the leading axis.
    - `y`: Network output, batched along the leading axis.

    **Returns:**

    Mean total analytical energy across data batch.

    """
    # Per-sample energy, vmapped over the batch axis, then averaged.
    return vmap(lambda x, y: linear_equilib_energy_single(
        network,
        x,
        y
    ))(x, y).mean()


In [None]:
#@title train & test scripts


def evaluate(model, test_loader):
    """Mean of the per-batch test accuracies of `model` over `test_loader`.

    Each torch batch is converted to NumPy and scored with
    `test_discriminative_pc`.
    """
    batch_accuracies = [
        test_discriminative_pc(
            model=model,
            y=label_batch.numpy(),
            x=img_batch.numpy()
        )
        for img_batch, label_batch in test_loader
    ]
    return sum(batch_accuracies) / len(test_loader)


def train(
      seed,
      dataset,
      n_hidden,
      width,
      lr,
      batch_size,
      t1,
      test_every,
      n_train_iters
):
    """Train a bias-free deep linear PC network and record its total energies.

    On each training iteration, this function:
    - computes the theoretical equilibrated energy of the current (pre-update)
      network with `linear_equilib_energy`;
    - takes one PC step with Adam;
    - records the summed layer energies of the inference trajectory at step
      `t_max - 1`.

    Test accuracy is printed every `test_every` iterations. Training stops after
    `n_train_iters` iterations, or earlier if the train loader is exhausted.

    Returns:
        Dict with `"experiment"` (numerical energies) and `"theory"`
        (analytical energies), one entry per training iteration.
    """
    # Initialise the model from the `seed` argument rather than a global.
    key = jr.PRNGKey(seed)
    model = make_linear_net(key, dataset, n_hidden, width)

    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))
    train_loader, test_loader = get_dataloaders(dataset, batch_size)

    num_total_energies, theory_total_energies = [], []
    for train_iter, (img_batch, label_batch) in enumerate(train_loader):
        img_batch = img_batch.numpy()
        label_batch = label_batch.numpy()

        # Analytical energy for the same parameters the PC step starts from.
        theory_total_energies.append(
            linear_equilib_energy(
                network=model,
                x=img_batch,
                y=label_batch
            )
        )
        result = make_pc_step(
            model,
            optim,
            opt_state,
            y=label_batch,
            x=img_batch,
            t1=t1,
            record_energies=True
        )
        model, optim, opt_state = result["model"], result["optim"], result["opt_state"]
        train_loss, t_max = result["loss"], result["t_max"]
        num_total_energies.append(result["energies"][:, t_max-1].sum())

        if ((train_iter+1) % test_every) == 0:
            avg_test_acc = evaluate(model, test_loader)
            print(
                f"Train iter {train_iter+1}, train loss={train_loss:4f}, "
                f"avg test accuracy={avg_test_acc:4f}"
            )
        # Enforce the budget on every iteration, not only on evaluation ones.
        if (train_iter+1) >= n_train_iters:
            break

    return {
        "experiment": jnp.array(num_total_energies),
        "theory": jnp.array(theory_total_energies)
    }


In [None]:
#@title plotting


def plot_total_energies(energies, save_path):
    """Plot each energy curve against training iteration and save it to `save_path`.

    `energies` maps a curve name to its per-iteration values and must contain a
    `"theory"` entry. The theory curve is dashed and listed first in the
    legend; every other curve is solid. The x-axis shows three ticks (first,
    middle and last iteration). The figure is written with plotly's
    `write_image`.
    """
    iters = list(range(1, len(energies["theory"]) + 1))

    fig = go.Figure()
    for energy_type, energy in energies.items():
        if energy_type == "theory":
            dash, colour, rank = "dash", "rgb(27, 158, 119)", 1
        else:
            dash, colour, rank = "solid", "#00CC96", 2
        fig.add_traces(
            go.Scatter(
                x=iters,
                y=energy,
                name=energy_type,
                mode="lines",
                line=dict(width=3, dash=dash, color=colour),
                legendrank=rank
            )
        )

    last_iter = iters[-1]
    ticks = [1, int(last_iter / 2), last_iter]
    fig.update_layout(
        height=300,
        width=450,
        xaxis=dict(
            title="Training iteration",
            tickvals=ticks,
            ticktext=list(ticks),
        ),
        yaxis=dict(title="Energy", nticks=3),
        font=dict(size=16),
    )
    fig.write_image(save_path)


In [None]:
# Experiment grid: one run per (dataset, number of hidden layers) pair.
DATASETS = ["MNIST", "Fashion-MNIST", "CIFAR10"]
N_HIDDENS = [2, 5, 10]

# Reproducibility and output location.
SEED = 0
RESULTS_DIR = "results"

# Network and optimiser hyperparameters.
WIDTH = 300           # units per hidden layer
LEARNING_RATE = 1e-3  # Adam step size
BATCH_SIZE = 64

# PC inference and training schedule.
T1 = 300              # end time of the activity ODE integration
TEST_EVERY = 10       # evaluate test accuracy every this many train iterations
N_TRAIN_ITERS = 100   # training-iteration budget per run

In [None]:
set_seed(SEED)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Train one deep linear PC network per (dataset, depth) pair and save a plot
# comparing the numerical and theoretical total energies.
for dataset in DATASETS:
    for n_hidden in N_HIDDENS:
        print(f"\n{dataset}, {n_hidden} hidden layers")
        run_config = dict(
            seed=SEED,
            dataset=dataset,
            n_hidden=n_hidden,
            width=WIDTH,
            lr=LEARNING_RATE,
            batch_size=BATCH_SIZE,
            test_every=TEST_EVERY,
            t1=T1,
            n_train_iters=N_TRAIN_ITERS,
        )
        energies = train(**run_config)
        figure_path = f"{RESULTS_DIR}/theory_energy_n_hidden_{dataset}_{n_hidden}.pdf"
        plot_total_energies(energies, figure_path)


MNIST, 2 hidden layers
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 39194193.65it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1220031.36it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 9705737.78it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4780559.29it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






Train iter 10, train loss=0.078076, avg test accuracy=0.752704
Train iter 20, train loss=0.062510, avg test accuracy=0.776342
Train iter 30, train loss=0.050910, avg test accuracy=0.780248
Train iter 40, train loss=0.052440, avg test accuracy=0.802684
Train iter 50, train loss=0.044833, avg test accuracy=0.829728
Train iter 60, train loss=0.045435, avg test accuracy=0.805188
Train iter 70, train loss=0.043837, avg test accuracy=0.798177
Train iter 80, train loss=0.056742, avg test accuracy=0.817508
Train iter 90, train loss=0.048744, avg test accuracy=0.810096
Train iter 100, train loss=0.045621, avg test accuracy=0.818409

MNIST, 5 hidden layers
Train iter 10, train loss=0.068379, avg test accuracy=0.755609
Train iter 20, train loss=0.057660, avg test accuracy=0.773538
Train iter 30, train loss=0.058808, avg test accuracy=0.774740
Train iter 40, train loss=0.047768, avg test accuracy=0.793970
Train iter 50, train loss=0.051430, avg test accuracy=0.776743
Train iter 60, train loss=0.06

100%|██████████| 26421880/26421880 [00:01<00:00, 13457452.78it/s]


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 272162.20it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5028931.59it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 7706023.19it/s]


Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Train iter 10, train loss=0.059050, avg test accuracy=0.678185
Train iter 20, train loss=0.048410, avg test accuracy=0.718550
Train iter 30, train loss=0.047848, avg test accuracy=0.759115
Train iter 40, train loss=0.046312, avg test accuracy=0.763922
Train iter 50, train loss=0.048569, avg test accuracy=0.750801
Train iter 60, train loss=0.044502, avg test accuracy=0.771635
Train iter 70, train loss=0.053061, avg test accuracy=0.768530
Train iter 80, train loss=0.053993, avg test accuracy=0.747095
Train iter 90, train loss=0.037717, avg test accuracy=0.773638
Train iter 100, train loss=0.039773, avg test accuracy=0.772135

Fashion-MNIST, 5 hidden layers
Train iter 10, train loss=0.059829, avg test accuracy=0.673978
Train iter 20, train loss=0.051413, avg test accuracy=0.724659
Train iter 30, train loss=0.056933, avg test accuracy=0.758413
Train iter 40, train loss=0.053957, avg test accuracy=0.766026


100%|██████████| 170498071/170498071 [00:02<00:00, 80706797.83it/s]


Extracting data/CIFAR10/cifar-10-python.tar.gz to data/CIFAR10
Files already downloaded and verified
Train iter 10, train loss=0.552155, avg test accuracy=0.115184
Train iter 20, train loss=0.258253, avg test accuracy=0.174279
Train iter 30, train loss=0.195730, avg test accuracy=0.203225
Train iter 40, train loss=0.137791, avg test accuracy=0.190605
Train iter 50, train loss=0.151842, avg test accuracy=0.193810
Train iter 60, train loss=0.145391, avg test accuracy=0.246394
Train iter 70, train loss=0.108184, avg test accuracy=0.223257
Train iter 80, train loss=0.111019, avg test accuracy=0.200621
Train iter 90, train loss=0.121446, avg test accuracy=0.232472
Train iter 100, train loss=0.105853, avg test accuracy=0.250501

CIFAR10, 5 hidden layers
Files already downloaded and verified
Files already downloaded and verified
Train iter 10, train loss=0.101535, avg test accuracy=0.186298
Train iter 20, train loss=0.105133, avg test accuracy=0.200321
Train iter 30, train loss=0.096691, avg 