# Tutorial #9: Monte Carlo Predictive Coding with transposed convolution

In this notebook we will see how to create and train a simple MCPC model to learn a Gaussian data distribution.

In [1]:
from typing import Callable, Union, Sequence
from typing import Any, Callable, Optional, Union
import math, os, random
from pathlib import Path
from tqdm import tqdm
import tempfile, shutil, os, subprocess, warnings


# Core dependencies
import jax
import jax.numpy as jnp
import numpy as np
import optax
import equinox as eqx
from optax._src import base 
from optax._src import combine
from optax._src import transform

# pcax
import pcax as px
import pcax.predictive_coding as pxc
import pcax.nn as pxnn
import pcax.utils as pxu
import pcax.functional as pxf
from pcax.nn import Layer
from pcax.core import RandomKeyGenerator, RKG

import torch
from pytorch_fid.fid_score_mnist import fid_mnist, save_stats_mnist
from inception_score import get_mnist_inception_score

from torchvision import transforms
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

2024-05-29 08:47:45.795343: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
class ConvTranspose(Layer):
    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Sequence[int], Sequence[tuple[int, int]]] = 0,
        output_padding: Union[int, Sequence[int]] = 0,
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        padding_mode: str = "ZEROS",
        dtype=None,
        *,
        rkg: RandomKeyGenerator = RKG,
    ):
        super().__init__(
            eqx.nn.ConvTranspose,
            num_spatial_dims=num_spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            use_bias=use_bias,
            padding_mode=padding_mode,
            key=rkg(),
        )

In [3]:
STATUS_FORWARD = "forward"

class Model(pxc.EnergyModule):
    def __init__(
        self,
        *,
        num_layers: int,
        input_dim: tuple[int, int, int],
        output_dim: tuple[int, int, int],
        out_channels_per_layer: list[int] | None = None,
        kernel_size: Union[int, Sequence[int]],
        bottleneck_dim: int,
        act_fn: Callable[[jax.Array], jax.Array],
        output_act_fn: Callable[[jax.Array], jax.Array] = lambda x: x,
        channel_last: bool = True,
    ):
        super().__init__()
        self.act_fn = px.static(act_fn)
        self.output_act_fn = px.static(output_act_fn)
        self.channel_last = px.static(channel_last)

        input_spatial_dim: np.array
        output_spatial_dim: np.array
        input_channels: int
        output_channels: int
        if self.channel_last.get():
            input_spatial_dim = np.array(input_dim[:-1])
            output_spatial_dim = np.array(output_dim[:-1])
            input_channels = input_dim[-1]
            output_channels = output_dim[-1]
        else:
            input_spatial_dim = np.array(input_dim[1:])
            output_spatial_dim = np.array(output_dim[1:])
            input_channels = input_dim[0]
            output_channels = output_dim[0]

        input_dim = (input_channels, *input_spatial_dim)
        output_dim = (output_channels, *output_spatial_dim)

        spatial_scale = output_spatial_dim / input_spatial_dim
        if np.any(spatial_scale % 1 != 0):
            raise ValueError(
                "scale=(output_dim/input_dim) must be an integer "
                f"input_dim: {input_dim}, output_dim: {output_dim}, scale: {spatial_scale}"
            )

        step_scale = spatial_scale ** (1 / num_layers)
        if np.any(step_scale % 1 != 0):
            raise ValueError(
                "The scale=(output_dim/input_dim) must be a power of the stride number: scale = stride^num_layers. "
                f"Scale: {spatial_scale}, num_layers: {num_layers}, stride: {step_scale}"
            )
        step_scale = step_scale.astype(np.int32)

        if out_channels_per_layer:
            if len(out_channels_per_layer) != num_layers:
                raise ValueError(
                    "out_channels_per_layer must be equal to the number of layers. "
                    f"num_layers: {num_layers}, channels_per_layer: {out_channels_per_layer}"
                )
            if out_channels_per_layer[-1] != output_channels:
                raise ValueError(
                    "The number of channels in the last layer must be equal to the number of output channels. "
                    f"output_channels: {output_channels}, channels_per_layer[-1]: {out_channels_per_layer[-1]}"
                )
        else:
            channel_diff = output_channels - input_channels
            if channel_diff >= 0:
                raise ValueError(
                    "The number of input channels must be greater than the number of output channels. "
                    f"input_channels: {input_channels}, output_channels: {output_channels}"
                )
            step_channel_diff = channel_diff // num_layers
            out_channels_per_layer = [
                (input_channels + i * step_channel_diff) if i < num_layers else output_channels
                for i in range(1, num_layers + 1)
            ]

        input_dims: list[tuple[int, int, int]] = [input_dim]
        output_dims: list[tuple[int, int, int]] = []
        for i in range(num_layers):
            inp = input_dims[i]
            output_dims.append(
                (
                    out_channels_per_layer[i],
                    inp[1] * step_scale[0],
                    inp[2] * step_scale[1],
                )
            )
            if i < num_layers - 1:
                input_dims.append(output_dims[-1])
        assert len(input_dims) == len(output_dims)
        assert output_dims[-1] == output_dim

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        self.layers = [pxnn.Linear(in_features=bottleneck_dim, out_features=input_dim[0] * input_dim[1] * input_dim[2])]

        for layer_input, layer_output in zip(input_dims, output_dims):
            paddings = [
                _calculate_padding_and_output_padding(
                    input_dim=layer_input[i + 1],
                    output_dim=layer_output[i + 1],
                    stride=step_scale[i],
                    kernel_size=kernel_size[i],
                )
                for i in range(2)
            ]

            padding, output_padding = zip(*paddings)

            expected_output = tuple(
                step_scale[i] * (layer_input[i + 1] - 1) + kernel_size[i] - 2 * padding[i] + output_padding[i]
                for i in range(2)
            )
            assert expected_output == tuple(layer_output[1:])

            self.layers.append(
                ConvTranspose(
                    num_spatial_dims=2,
                    in_channels=layer_input[0],
                    out_channels=layer_output[0],
                    kernel_size=kernel_size,
                    stride=(step_scale[0], step_scale[1]),
                    padding=padding,
                    output_padding=output_padding,
                )
            )

        self.vodes = [
            pxc.Vode(
                (bottleneck_dim,),
                energy_fn=pxc.zero_energy,
                ruleset={pxc.STATUS.INIT: ("h, u <- u:to_zero",)},
                tforms={"to_zero": lambda n, k, v, rkg: jnp.zeros(n.shape.get())},
            ),
            pxc.Vode(
                input_dim,
                ruleset={pxc.STATUS.INIT: ("h, u <- u:to_zero",), STATUS_FORWARD: ("h -> u",)},
                tforms={"to_zero": lambda n, k, v, rkg: jnp.zeros_like(v)},
            ),
        ]
        for layer_output in output_dims:
            self.vodes.append(
                pxc.Vode(
                    layer_output,
                    ruleset={pxc.STATUS.INIT: ("h, u <- u:to_zero",), STATUS_FORWARD: ("h -> u",)},
                    tforms={"to_zero": lambda n, k, v, rkg: jnp.zeros_like(v)},
                )
            )
        self.vodes[-1].h.frozen = True

        self.input_dim = input_dim

    def __call__(self, x, y) -> jax.Array:
        # The defined ruleset for the first node is to set the hidden state to zero,
        # independent of the input, so we always pass '-1'.
        if x is None:
            x = self.vodes[0](-1)
        
        x = self.layers[0](self.act_fn(x))
        x = self.vodes[1](x.reshape(self.input_dim))

        for i, layer in enumerate(self.layers[1:]):
            act_fn = self.act_fn
            x = layer(act_fn(x))
            if i == (len(self.layers) - 1):
                x = self.output_act_fn(x)
            x = self.vodes[i + 2](x)

        if y is not None:
            if self.channel_last.get():
                y = y.transpose(2, 0, 1)
            self.vodes[-1].set("h", y)

        pred = self.vodes[-1].get("u")
        if self.channel_last.get():
            pred = pred.transpose(1, 2, 0)
        return pred

    # def generate(self, internal_state: jax.Array | None = None) -> jax.Array:
    #     x = internal_state
    #     if x is None:
    #         x = self.internal_state

    #     x = self.act_fn(self.layers[0](x)).reshape((8, 8, 8))
    #     for i, layer in enumerate(self.layers[1:]):
    #         act_fn = self.act_fn if i < len(self.layers) - 1 else self.output_act_fn
    #         x = act_fn(layer(x))

    #     if self.channel_last.get():
    #         x = x.transpose(1, 2, 0)
    #     return x

    @property
    def internal_state(self) -> jax.Array:
        return self.vodes[0].get("h")


def _calculate_padding_and_output_padding(
    *, input_dim: int, output_dim: int, stride: int, kernel_size: int
) -> tuple[int, int]:
    """
    Calculate the padding and output_padding required for a ConvTranspose layer to achieve the desired output dimension.

    Parameters:
    input_dim (int): The size of the input dimension (height or width).
    output_dim (int): The desired size of the output dimension (height or width).
    stride (int): The stride of the convolution.
    kernel_size (int): The size of the convolution kernel.

    Returns:
    tuple: The required padding and output_padding to achieve the desired output dimension.
    """
    no_padding_output_dim = (input_dim - 1) * stride + kernel_size

    padding = math.ceil(max(no_padding_output_dim - output_dim, 0) / 2)
    output_padding = max(output_dim - (no_padding_output_dim - 2 * padding), 0)

    assert no_padding_output_dim - 2 * padding + output_padding == output_dim

    return padding, output_padding

In [4]:
@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=0, out_axes=0)
def forward(x, y, *, model: Model):
    return model(x, y)

@pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=0, out_axes=(None, 0), axis_name="batch")
def energy(x, *, model: Model):
    y_ = model(x, None)
    return jax.lax.pmean(model.energy().sum(), "batch"), y_

# @pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=0, out_axes=0)
# def forward(example: jax.Array | None = None, *, model: PCDeconvDecoder) -> jax.Array:
#     return model(example=example)


# @pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=0, out_axes=0)
# def generate(internal_state: jax.Array, *, model: PCDeconvDecoder) -> jax.Array:
#     return model.generate(internal_state=internal_state)


# @pxf.vmap(pxu.Mask(pxc.VodeParam | pxc.VodeParam.Cache, (None, 0)), in_axes=0, out_axes=(None, 0), axis_name="batch")
# def energy(example: jax.Array, *, model: PCDeconvDecoder) -> jax.Array:
#     y_ = model(example=example)
#     return jax.lax.pmean(model.energy().sum(), "batch"), y_


In [5]:
@pxf.jit(static_argnums=0)
def train_on_batch(
    T: int,
    x: jax.Array,
    y: jax.Array,
    *,
    model: Model,
    optim_w: pxu.Optim,
    optim_h: pxu.Optim
):
    def h_step(i, x, *, model, optim_h):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (e, y_), g = pxf.value_and_grad(
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                has_aux=True
            )(energy)(x, model=model)
        optim_h.step(model, g["model"], True)
        return x, None

    print("Training!")
    model.train()
    
    # Init step
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, y, model=model)
    
    # Inference steps
    pxf.scan(h_step, xs=jax.numpy.arange(T))(x, model=model, optim_h=optim_h)

    # Learning step
    with pxu.step(model, clear_params=pxc.VodeParam.Cache):
        (e, y_), g = pxf.value_and_grad(pxu.Mask(pxnn.LayerParam, [False, True]), has_aux=True)(energy)(x, model=model)
    optim_w.step(model, g["model"])


def train(dl, T, *, model: Model, optim_w: pxu.Optim, optim_h: pxu.Optim):
    model.vodes[-1].h.frozen = True
    for x, y in tqdm(dl):
        train_on_batch(T, x, y, model=model, optim_w=optim_w, optim_h=optim_h)

In [6]:
@pxf.jit(static_argnums=0)
def eval_on_batch(
    T: int,
    x: jax.Array, 
    *, 
    model: Model,
    optim_h: pxu.Optim
    ):
    def h_step(i, x, *, model, optim_h):
        with pxu.step(model, clear_params=pxc.VodeParam.Cache):
            (e, y_), g = pxf.value_and_grad(
                pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True), [False, True]),
                has_aux=True
            )(energy)(x, model=model)
        optim_h.step(model, g["model"], True)
        return x, None

    print("Evaluation!")  
    model.train()

    if model.vodes[-1].h.frozen:
        print("vode[-1] should not be frozen! set frozen=False before calling eval function.")

    # Init step
    with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
        forward(x, None, model=model)
    
    # Inference steps
    x = pxf.scan(h_step, xs=jax.numpy.arange(T))(x, model=model, optim_h=optim_h)


def gen_imgs(dl, T, *, model: Model, optim_h: pxu.Optim):
    model.vodes[-1].h.frozen = False
    ys_ = []
    es = []
    for x, y in dl:
        eval_on_batch(T, x, model=model, optim_h=optim_h)
        u = forward(x, None, model=model)
        e = model.energy()
        ys_.append(u)
        es.append(e)
    return np.concatenate(ys_, axis=0), e

def tmp_save_imgs(imgs):
    tf = tempfile.NamedTemporaryFile()
    new_folder = False
    while not new_folder:
        try:
            new_folder=True
            os.makedirs("./data"+tf.name+"_")
        except OSError:
            print("ERROR")
            tf = tempfile.NamedTemporaryFile()
            new_folder=False
    for img_idx in range(len(imgs)):
        save_image(imgs[img_idx], "./data"+tf.name+"_"+"/"+str(img_idx)+".png")
    return "./data"+tf.name+"_"


def make_compressed_MNIST_files(test_dataset, data_folder):
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)
    # save test images
    test_img_folder = data_folder + "/mnist_test"
    data, label = list(test_loader)[0]
    images = data.view(-1,28,28)
    images = images/2 + 0.5     # remove normalisation
    os.makedirs(test_img_folder, exist_ok=True)
    for img_idx in tqdm(range(len(images))):
        save_image(images[img_idx], test_img_folder+"/"+str(img_idx)+".png")
    # get and save summary statistics of test images
    compressed_filename = test_img_folder + ".npz"    
    save_stats_mnist(test_img_folder, compressed_filename)    


from scipy.ndimage import zoom
# Assuming `images` is your numpy array of shape (batch, 32, 32)
def resize_images_scipy(images, new_size=(28, 28)):
    batch_size = images.shape[0]
    resized_images = np.zeros((batch_size, *new_size))
    for i in range(batch_size):
        resized_images[i] = zoom(images[i], (new_size[0] / images.shape[1], new_size[1] / images.shape[2]))
    return resized_images

# MCPC evaluation loop for 1D data
def eval(dl, dataset, T, *, model: Model, optim_h: pxu.Optim):
    model.vodes[-1].h.frozen = False

    # check if summary statistics of test dataset used for FID exist
    data_folder = './data'
    if not os.path.exists(data_folder + "/mnist_test.npz") :
        print(data_folder + "/mnist_test" + "does not exist")
        print("Creating compressed MNIST files for faster FID measure ...")
        make_compressed_MNIST_files(dataset, data_folder=data_folder)
 
    # generate images from model
    imgs, e = gen_imgs(dl, T, model=model, optim_h=optim_h)
    imgs = imgs/2 + 0.5     # from -1 -> 1 to 0 -> 1
    imgs = np.clip(imgs, 0, 1).squeeze()
    imgs = resize_images_scipy(imgs)

    # # save images
    img_folder = tmp_save_imgs(torch.tensor(imgs).reshape(-1,28,28))
    # get inceptions score
    is_mean, is_std = get_mnist_inception_score(img_folder)

    # get mnist fid
    fid = fid_mnist(data_folder + "/mnist_test.npz", img_folder, device=torch.device("cpu"), num_workers=0, verbose=False)

    shutil.rmtree(img_folder)

    return is_mean, fid, imgs, e

In [7]:
import optax

batch_size = 10
latent_dim = 128
img_dim = (1,32,32)
model = Model(
    num_layers=2,
    input_dim=(8,8,8),
    output_dim=img_dim,
    kernel_size=5,
    bottleneck_dim=latent_dim,
    act_fn=jax.nn.tanh,
    output_act_fn=lambda x: x,
    channel_last=False,
)

In [8]:
## define noisy sgd optimiser for MCPC
def sgdld(
    learning_rate: base.ScalarOrSchedule,
    momentum: Optional[float] = None,
    h_var: float = 1.0,
    gamma: float = 0.,
    nesterov: bool = False,
    accumulator_dtype: Optional[Any] = None,
    seed: int = 0
) -> base.GradientTransformation:
  eta = 2*h_var*(1-momentum)/learning_rate if momentum is not None else 2*h_var/learning_rate
  return combine.chain(
      transform.add_noise(eta, gamma, seed),
      (transform.trace(decay=momentum, nesterov=nesterov,
                       accumulator_dtype=accumulator_dtype)
       if momentum is not None else base.identity()),
      transform.scale_by_learning_rate(learning_rate)
  )

In [11]:
h_optimiser_fn = sgdld
lr = 0.03
momentum = 0.9
h_var = 1.0
gamma = 0.0
lr_p = 0.007

with pxu.step(model, pxc.STATUS.INIT, clear_params=pxc.VodeParam.Cache):
    forward(None, jax.numpy.zeros((batch_size, *img_dim)), model=model)
    # optim_h = pxu.Optim(h_optimiser_fn(lr, momentum, h_var, gamma), pxu.Mask(pxu.m(pxc.VodeParam).has_not(frozen=True))(model))
    # optim_w = pxu.Optim(optax.adam(lr_p), pxu.Mask(pxnn.LayerParam)(model))
    # # make optimiser that also optimises the activity of the model layer[-1]
    # model.vodes[-1].h.frozen = False
    # forward(jax.numpy.zeros((batch_size, latent_dim)), None, model=model)
    # optim_h_eval = pxu.Optim(h_optimiser_fn(lr, momentum, h_var, gamma), pxu.Mask(pxu.m(pxc.VodeParam))(model))
    # model.vodes[-1].h.frozen = True

ValueError: Tuple arity mismatch: 3 != 2; tuple: (None, Array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],


       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],


       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],


       ...,


       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],


       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],


       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]]], dtype=float32), {'__RKG': (RandomKeyGenerator):
  .key: RKGState([2], uint32), 'model': (Model):
  .layers[0].nn.weight: LayerParam([512,128], float32)
  .layers[0].nn.bias: LayerParam([512], float32)
  .layers[1].nn.weight: LayerParam([4,8,5,5], float32)
  .layers[1].nn.bias: LayerParam([4,1,1], float32)
  .layers[2].nn.weight: LayerParam([1,4,5,5], float32)
  .layers[2].nn.bias: LayerParam([1,1,1], float32)
  .vodes[0].h: VodeParam(None)
  .vodes[0].cache: Cache(params={})
  .vodes[1].h: VodeParam(None)
  .vodes[1].cache: Cache(params={})
  .vodes[2].h: VodeParam(None)
  .vodes[2].cache: Cache(params={})
  .vodes[3].h: VodeParam(None)
  .vodes[3].cache: Cache(params={})
  .input_dim[0]: 8
  .input_dim[1]: 8
  .input_dim[2]: 8}).

In [None]:
# Define the transformation to scale pixels to the range [-1, 1], resize and add channel
transform = transforms.Compose([
    transforms.Resize((32, 32)),                # Resize the image to 32x32 pixels
    transforms.ToTensor(),                      # Convert the image to a PyTorch tensor
    transforms.Normalize((0.5,), (0.5,)),       # Normalize the tensor to the range [-1, 1]
    transforms.Lambda(lambda x: x.unsqueeze(0) if x.dim() == 2 else x) # Ensure the tensor shape is [1 x 32 x 32]
])

# Load the MNIST training dataset
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)

# Load the MNIST test dataset
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

In [None]:
# prep data for unsupervised training
train_dl = DataLoader(train_dataset, batch_size=train_dataset.__len__(), shuffle=True)

data, label = list(train_dl)[0]
nm_elements = len(data)
X = np.zeros((batch_size * (nm_elements // batch_size), latent_dim))
y = data.numpy()[:batch_size * (nm_elements // batch_size)]

nm_elements_test =  1024
X_test = np.zeros((batch_size * (nm_elements_test // batch_size), latent_dim))
y_test = np.zeros((batch_size * (nm_elements_test // batch_size), *img_dim)) # is not used

In [None]:
train_dl = list(zip(X.reshape(-1, batch_size, latent_dim), y.reshape(-1, batch_size, *img_dim)))
test_dl = tuple(zip(X_test.reshape(-1, batch_size, latent_dim), y_test.reshape(-1, batch_size, *img_dim)))

Inception score -> higher is better

FID -> lower is better

In [None]:
nm_epochs = 40

T = 250
T_eval = 10000
# is_, fid, imgs, energies = eval(test_dl, test_dataset, T_eval, model=model, optim_h=optim_h_eval)
# print(f"Epoch {0}/{nm_epochs} - Inception score: {is_ :.2f}, FID score: {fid :.2f}")
for e in range(nm_epochs):
    random.shuffle(train_dl)
    train(train_dl, T=T, model=model, optim_w=optim_w, optim_h=optim_h)
    if e % 10 == 9:
        is_, fid, imgs, energies = eval(test_dl, test_dataset, T_eval, model=model, optim_h=optim_h_eval)
        print(f"Epoch {e}/{nm_epochs} - Inception score: {is_ :.2f}, FID score: {fid :.2f}")

is_, fid, imgs, energies = eval(test_dl, test_dataset, T_eval, model=model, optim_h=optim_h_eval)
print(f"Epoch {e}/{nm_epochs} - Inception score: {is_ :.2f}, FID score: {fid :.2f}")

  0%|                                                                                                                                                                           | 0/6000 [00:00<?, ?it/s]


Training!


TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function _wrap_fn at /data/ndcn-computational-neuroscience/pemb6612/pcax/pcax/functional/_transform.py:252 for jit. This concrete value was not available in Python because it depends on the value of the argument kwargs['model'].input_dim[0].
The error occurred while tracing the function _wrap_fn at /data/ndcn-computational-neuroscience/pemb6612/pcax/pcax/functional/_transform.py:252 for jit. This concrete value was not available in Python because it depends on the value of the argument kwargs['model'].input_dim[1].
The error occurred while tracing the function _wrap_fn at /data/ndcn-computational-neuroscience/pemb6612/pcax/pcax/functional/_transform.py:252 for jit. This concrete value was not available in Python because it depends on the value of the argument kwargs['model'].input_dim[2].