<a href="https://colab.research.google.com/github/durml91/MMath-Project/blob/duo-branch/Image_Diffusion_(working)/MNIST_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import array
import functools as ft
import gzip
import os
import struct
import urllib.request

import diffrax as dfx
import einops
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax

### MLP Mixer

Mixer Layer

Here we haven't introduced a GELU in the MLP layers (`eqx.nn.MLP` defaults to a ReLU activation). The basic idea is that we divide the image into disjoint, non-overlapping patches and flatten each one, so the height and width of a patch collapse into a single vector (much shorter than the flattened full image). Each patch is then projected to a hidden size, which acts as our channels. For a single image this gives a 2D array (a square) with one axis for the hidden channels and one for the patches. The patch axis behaves like a batch axis here, but it is the number of patches and not the batch of images; stacking a batch of images would add a third axis. Think cube! We write P rather than S for the number of patches, so num_patches = P and hidden_size = C, and P should equal (h x w) / patch_size ** 2. Note that P does not depend on the number of input channels: for an RGB image the three colour channels are absorbed by the projection to the hidden size, so there are still P patches, not 3 x P.

Each mixer block is a token-mixing (patch-mixing) step followed by a channel-mixing step. First we take the array of shape C x P and apply a layer norm across all of C and P (for a single input, NOT across the batch). We then pass it through the token mixer, an MLP with input size P and output size P (i.e. a map $\mathbb{R}^{P} \to \mathbb{R}^{P}$) applied to each of the C rows, and add the result back as a residual. The output still has shape C x P, so we transpose it to P x C and put it through a second layer norm and then the channel mixer, a map $\mathbb{R}^{C} \to \mathbb{R}^{C}$ applied to each of the P rows, again with a residual connection. Once it is out, we transpose the array again to recover the same shape we started with.
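
As a rough check of the shapes described above, here is a minimal NumPy sketch (not the Equinox code used in this notebook). The sizes assume a 28 x 28 image with patch_size = 4 and hidden_size = 64. Each mixer is replaced by a single hypothetical random linear map (`w_token`, `w_channel`) rather than an MLP, and the layer norms are left out.

```python
import numpy as np

rng = np.random.default_rng(0)

height, width, patch_size, hidden_size = 28, 28, 4, 64
num_patches = (height * width) // patch_size**2  # P

# One image after patch embedding: shape (C, P).
y = rng.normal(size=(hidden_size, num_patches))

# Hypothetical stand-ins for the two mixers (single linear maps, no activation).
w_token = rng.normal(size=(num_patches, num_patches)) / num_patches**0.5
w_channel = rng.normal(size=(hidden_size, hidden_size)) / hidden_size**0.5

# Token mixing: each of the C rows (length P) is mapped R^P -> R^P.
y = y + y @ w_token.T

# Channel mixing: transpose to (P, C), map each row R^C -> R^C, transpose back.
y_t = y.T
y_t = y_t + y_t @ w_channel.T
y = y_t.T

print(num_patches, y.shape)
```

The point is only that both mixing steps leave the (C, P) shape unchanged, which is what lets the blocks be stacked.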

In [None]:
class MixerBlock(eqx.Module):
    patch_mixer: eqx.nn.MLP
    hidden_mixer: eqx.nn.MLP
    norm1: eqx.nn.LayerNorm
    norm2: eqx.nn.LayerNorm

    def __init__(
        self, num_patches, hidden_size, mix_patch_size, mix_hidden_size, *, key
    ):
        tkey, ckey = jr.split(key, 2)
        self.patch_mixer = eqx.nn.MLP(
            num_patches, num_patches, mix_patch_size, depth=1, key=tkey
        )
        self.hidden_mixer = eqx.nn.MLP(
            hidden_size, hidden_size, mix_hidden_size, depth=1, key=ckey
        )
        self.norm1 = eqx.nn.LayerNorm((hidden_size, num_patches))
        self.norm2 = eqx.nn.LayerNorm((num_patches, hidden_size))

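    def __call__(self, y):
        # y has shape (hidden_size, num_patches), i.e. (c, p).
        # Token mixing: the patch MLP maps R^p -> R^p, vmapped over the c rows.
        y = y + jax.vmap(self.patch_mixer)(self.norm1(y))
        y = einops.rearrange(y, "c p -> p c")
        # Channel mixing: the hidden MLP maps R^c -> R^c, vmapped over the p rows.
        y = y + jax.vmap(self.hidden_mixer)(self.norm2(y))
        # Transpose back so the output shape matches the input shape.
        y = einops.rearrange(y, "p c -> c p")
        return y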

Actual network

Now for the full model. Recall hidden_size = C and num_patches = P. Before we start, we rescale the time by t1, the end time of the diffusion process, so that it lies in [0, 1]. We then read off the height and width of the image so that we can build a constant array of t's of shape 1 x h x w and concatenate it to the image along axis 0, which is the default axis for `jnp.concatenate` and the channel axis here. For grayscale MNIST this gives 2 channels, which is why the input convolution takes input_size + 1 channels, i.e. t acts as another channel. Next we get our patches. The input convolution is 2D, with in_channels = input_size + 1 (where input_size is the first entry of img_size) and out_channels = hidden_size, i.e. C. Setting both kernel_size and stride to patch_size means each output position sees exactly one patch, so num_patches is fixed implicitly. After the convolution the array has shape C x (h / patch_size) x (w / patch_size); the patch_height and patch_width in the code are the number of patches along each axis, not the size of a single patch. We flatten this grid to C x P and pass it through the mixer blocks, each of which preserves the shape. We then apply a final layer norm and unflatten back to the patch grid. Finally we plug the result into the upsampler (the transposed convolution) in order to piece the patches back together into an array with the shape of the original image.
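
To see why a convolution with kernel_size = stride = patch_size acts as a per-patch linear projection, here is a small NumPy sketch of the first few steps. It is an illustration under assumed sizes (a 1 x 28 x 28 image, patch_size = 4, hidden_size = 64) with a random hypothetical weight matrix and no bias, not the `eqx.nn.Conv2d` layer itself.

```python
import numpy as np

rng = np.random.default_rng(0)

c, height, width = 1, 28, 28
patch_size, hidden_size = 4, 64
t, t1 = 3.0, 10.0

y = rng.normal(size=(c, height, width))

# Time enters as one extra constant channel, concatenated along axis 0.
t_channel = np.full((1, height, width), t / t1)
y_in = np.concatenate([y, t_channel], axis=0)  # (c + 1, h, w)

# Cut the image into a grid of non-overlapping patches and flatten each patch.
gh, gw = height // patch_size, width // patch_size
patches = (
    y_in.reshape(c + 1, gh, patch_size, gw, patch_size)
    .transpose(1, 3, 0, 2, 4)  # (gh, gw, c + 1, patch, patch)
    .reshape(gh * gw, (c + 1) * patch_size * patch_size)
)

# Hypothetical projection weights: one linear map shared by every patch.
weight = rng.normal(size=((c + 1) * patch_size * patch_size, hidden_size))
tokens = (patches @ weight).T  # (hidden_size, num_patches)

print(y_in.shape, patches.shape, tokens.shape)
```

Every patch is multiplied by the same weight matrix, which is the weight sharing a strided convolution provides.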

In [None]:
class Mixer2d(eqx.Module):
    conv_in: eqx.nn.Conv2d
    conv_out: eqx.nn.ConvTranspose2d
    blocks: list
    norm: eqx.nn.LayerNorm
    t1: float

    def __init__(
        self,
        img_size,
        patch_size,
        hidden_size,
        mix_patch_size,
        mix_hidden_size,
        num_blocks,
        t1,
        *,
        key,
    ):
        input_size, height, width = img_size
        assert (height % patch_size) == 0
        assert (width % patch_size) == 0
        num_patches = (height // patch_size) * (width // patch_size)
        inkey, outkey, *bkeys = jr.split(key, 2 + num_blocks)

        self.conv_in = eqx.nn.Conv2d(
            input_size + 1, hidden_size, patch_size, stride=patch_size, key=inkey
        )
        self.conv_out = eqx.nn.ConvTranspose2d(
            hidden_size, input_size, patch_size, stride=patch_size, key=outkey
        )
        self.blocks = [
            MixerBlock(
                num_patches, hidden_size, mix_patch_size, mix_hidden_size, key=bkey
            )
            for bkey in bkeys
        ]
        self.norm = eqx.nn.LayerNorm((hidden_size, num_patches))
        self.t1 = t1

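    def __call__(self, t, y):
        # Rescale time by t1 and broadcast it to an extra (1, h, w) channel.
        t = t / self.t1
        _, height, width = y.shape
        t = einops.repeat(t, "-> 1 h w", h=height, w=width)
        y = jnp.concatenate([y, t])  # axis 0 is the channel axis
        # Patch embedding: (input_size + 1, h, w) -> (hidden_size, h / patch, w / patch).
        y = self.conv_in(y)
        _, patch_height, patch_width = y.shape  # number of patches along each axis
        y = einops.rearrange(y, "c h w -> c (h w)")
        for block in self.blocks:
            y = block(y)
        y = self.norm(y)
        y = einops.rearrange(y, "c (h w) -> c h w", h=patch_height, w=patch_width)
        # The transposed convolution maps the patch grid back to the image shape.
        return self.conv_out(y)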

### Loss

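The single loss is for one image at one time t. It's the denoising score matching objective (the rearranged loss function familiar from the DDPM material): we use the reparameterisation y = mean + std * noise to draw a noisy version of the data, plug that into the model, and compare the prediction with the score of the Gaussian perturbation kernel, which is -noise / std. The mean in the code is taken over the pixels of the image; because only one noise sample is drawn per image, this is a single-sample Monte Carlo estimate of the expectation over the random Gaussian noise. We also multiply by a weight function of t.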
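
For reference, here is a short derivation of why the regression target is -noise / std, written with the same quantities as single_loss_fn. It assumes the model $s_\theta$ is meant to approximate the score $\nabla_y \log p_t(y)$, which is how it is used in the sampler; $x$ is the clean image, $\lambda(t)$ the weight function and $\varepsilon$ the noise.

$$
y = \mu_t + \sigma_t \varepsilon, \qquad \mu_t = x \, e^{-\frac{1}{2}\int_0^t \beta(s)\,ds}, \qquad \sigma_t^2 = 1 - e^{-\int_0^t \beta(s)\,ds}, \qquad \varepsilon \sim \mathcal{N}(0, I)
$$

$$
\nabla_y \log p(y \mid x) = -\frac{y - \mu_t}{\sigma_t^2} = -\frac{\varepsilon}{\sigma_t}
$$

$$
\text{loss}(x, t) = \lambda(t) \, \mathbb{E}_{\varepsilon} \left\| s_\theta(t, y) + \frac{\varepsilon}{\sigma_t} \right\|^2
$$

The code replaces the squared norm by a mean over pixels, which only changes the loss by a constant factor.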

In [None]:
def single_loss_fn(model, weight, int_beta, data, t, key):
    mean = data * jnp.exp(-0.5 * int_beta(t))
    var = jnp.maximum(1 - jnp.exp(-int_beta(t)), 1e-5)
    std = jnp.sqrt(var)
    noise = jr.normal(key, data.shape)
    y = mean + std * noise
    pred = model(t, y)
    return weight(t) * jnp.mean((pred + noise / std) ** 2)

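The following is the loss over a batch. Recall there are three expectations in the loss: over the time t, over the data and over the noise. The noise is handled inside single_loss_fn, and here we cover the other two with a single mean rather than two. This works because we draw one time per image: the interval from 0 to t1 is split into batch_size equal strata and one uniform sample is drawn from each, which gives a lower-variance estimate than sampling every t independently. We then vmap the single loss function over the batch of data, times and keys and take the mean, which is the Monte Carlo estimate over both the randomly sampled data and the times.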
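
The stratified time sampling can be sketched in plain Python as follows. This is a standalone illustration with a hypothetical helper `stratified_times`, using the standard `random` module instead of `jax.random`.

```python
import random


def stratified_times(t1, batch_size, rng):
    """Draw one time from each of batch_size equal-width strata of [0, t1]."""
    width = t1 / batch_size
    return [i * width + rng.uniform(0.0, width) for i in range(batch_size)]


rng = random.Random(0)
t1, batch_size = 10.0, 8
width = t1 / batch_size
times = stratified_times(t1, batch_size, rng)

for i, t in enumerate(times):
    print(f"stratum [{i * width:.2f}, {(i + 1) * width:.2f}] -> t = {t:.3f}")
```

Each batch therefore always covers the whole of [0, t1], whereas independent uniform draws could leave parts of the interval unsampled in a given batch.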

In [None]:
def batch_loss_fn(model, weight, int_beta, data, t1, key):
    batch_size = data.shape[0]
    tkey, losskey = jr.split(key)
    losskey = jr.split(losskey, batch_size)
    # Low-discrepancy sampling over t to reduce variance
    t = jr.uniform(tkey, (batch_size,), minval=0, maxval=t1 / batch_size)
    t = t + (t1 / batch_size) * jnp.arange(batch_size)
    loss_fn = ft.partial(single_loss_fn, model, weight, int_beta)
    loss_fn = jax.vmap(loss_fn)
    return jnp.mean(loss_fn(data, t, losskey))

### Update step

Recall that Equinox works slightly differently to Flax in that the model is itself a pytree, so you don't have to separate the parameters from the model. Here we work out the loss value and its gradient together: `eqx.filter_value_and_grad` evaluates batch_loss_fn, which computes the average loss, and backpropagates to get the gradients of the loss w.r.t. its first argument, the model, i.e. the weights and biases. We then call opt_update on the grads; here opt_state is the state of the optimiser (e.g. the moment estimates for Adam), not the optimiser itself. opt_update gives back the updates, which are the increments to add to the parameters rather than the new weights, together with the new optimiser state, and `eqx.apply_updates` adds the updates to the model. We then return a fresh key so that the next step draws new times and noise instead of reusing the same randomness.
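
The pattern of "gradients and state in, updates and new state out" can be illustrated without any library. The sketch below is a toy stand-in, not Optax: `momentum_update` is a hypothetical hand-written optimiser with the same calling convention as opt_update, applied to a single scalar parameter.

```python
def loss_and_grad(w):
    """Toy loss (w - 3)^2 and its gradient with respect to w."""
    return (w - 3.0) ** 2, 2.0 * (w - 3.0)


def momentum_update(grad, state, lr=0.1, decay=0.9):
    """Hypothetical optimiser: returns (update, new_state), not the new weight."""
    velocity = decay * state + grad
    return -lr * velocity, velocity


w, opt_state = 0.0, 0.0
for _ in range(200):
    loss, grad = loss_and_grad(w)
    update, opt_state = momentum_update(grad, opt_state)
    w = w + update  # the analogue of eqx.apply_updates

print(w, loss)
```

Keeping the state outside the optimiser function is what lets the whole step stay a pure function of its inputs.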

In [None]:
def make_step(model, weight, int_beta, data, t1, key, opt_state, opt_update):
    loss_fn = eqx.filter_value_and_grad(batch_loss_fn)
    loss, grads = loss_fn(model, weight, int_beta, data, t1, key)
    updates, opt_state = opt_update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    key = jr.split(key, 1)[0]
    return loss, model, key, opt_state

### ODE solver

This is the ODE sampler using the Diffrax package. We define the drift term of the probability flow ODE, dy/dt = -0.5 * beta(t) * (y + score(t, y)), with the model standing in for the score. Since we only specify int_beta, the integral of beta, the `jax.jvp` call with a tangent of ones evaluates its derivative beta(t) at the current t. We use the Tsit5() solver. We sample from a standard Gaussian with the shape of the data and then solve backwards from t1 to t0, with step size dt0 (but negative given we are going in the other direction).
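
To make the drift concrete, here is a one-dimensional sketch in plain Python where the score is known in closed form, so no trained model is needed. Everything specific here is an assumption for illustration: the data distribution is a hypothetical Gaussian N(MU, S**2), int_beta(t) = t as in main, beta is written by hand instead of via `jax.jvp`, and a fixed-step Euler loop replaces Diffrax's Tsit5.

```python
import math

MU, S = 2.0, 0.5  # hypothetical 1-D data distribution N(MU, S**2)


def int_beta(t):
    return t  # same choice as in main, so beta(t) = 1


def beta(t):
    return 1.0


def marginal(t):
    """Mean and variance of the noised data at time t."""
    a2 = math.exp(-int_beta(t))
    return MU * math.sqrt(a2), S**2 * a2 + 1.0 - a2


def score(t, y):
    mean, var = marginal(t)
    return -(y - mean) / var


def drift(t, y):
    return -0.5 * beta(t) * (y + score(t, y))


def sample(y1, t1=10.0, dt0=0.001):
    """Integrate the ODE backwards from t1 to 0 with explicit Euler steps."""
    y = y1
    for i in range(round(t1 / dt0)):
        t = t1 - i * dt0
        y = y - dt0 * drift(t, y)
    return y


def exact(y1, t1=10.0):
    """For Gaussian marginals the flow preserves (y - mean) / std."""
    mean1, var1 = marginal(t1)
    return MU + S * (y1 - mean1) / math.sqrt(var1)


print(sample(1.0), exact(1.0))
```

A starting point one standard deviation above the mean of the noise distribution ends up roughly one standard deviation above the mean of the data distribution, which is the deterministic transport the ODE sampler relies on.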

If the SDE sampler doesn't work, this is precisely why. t comes into that function as a vector, as it is the entire discretisation step applied at once. If you want beta_t to be multi-dimensional, you need to extend beta along another axis, e.g. a new axis of size 28, so that it is actually a matrix with the timesteps along one axis and the value for each timestep repeated along the other, as we need it to broadcast in the right dimension.

In [None]:
def single_sample_fn(model, int_beta, data_shape, dt0, t1, key):
    def drift(t, y, args):
        _, beta = jax.jvp(int_beta, (t,), (jnp.ones_like(t),))
        return -0.5 * beta * (y + model(t, y))

    term = dfx.ODETerm(drift)
    solver = dfx.Tsit5()
    t0 = 0
    y1 = jr.normal(key, data_shape)
    # reverse time, solve from t1 to t0
    sol = dfx.diffeqsolve(term, solver, t1, t0, -dt0, y1, adjoint=dfx.NoAdjoint())
    return sol.ys[0]

Data import

Nothing much to say except we specify the data dimension here.
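
The 16-byte header read by the loader can be illustrated with a fabricated example. The sketch below builds a fake header with `struct.pack` rather than downloading the file; 2051 is the magic number of IDX image files, and the counts are those of the MNIST training set.

```python
import struct

# Fake 16-byte IDX header: magic number, number of images, rows, columns.
header = struct.pack(">IIII", 2051, 60000, 28, 28)

# ">IIII" means four big-endian unsigned 32-bit integers.
magic, batch, rows, cols = struct.unpack(">IIII", header)

# The loader inserts a channel axis of size 1 because the images are grayscale.
shape = (batch, 1, rows, cols)
print(len(header), magic, shape)
```

The first entry of data.shape[1:] is this channel axis, which is what Mixer2d later unpacks as input_size.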

In [None]:
def mnist():
    filename = "train-images-idx3-ubyte.gz"
    url_dir = "https://storage.googleapis.com/cvdf-datasets/mnist"
    target_dir = os.getcwd() + "/data/mnist"
    url = f"{url_dir}/{filename}"
    target = f"{target_dir}/{filename}"

    if not os.path.exists(target):
        os.makedirs(target_dir, exist_ok=True)
        urllib.request.urlretrieve(url, target)
        print(f"Downloaded {url} to {target}")

    with gzip.open(target, "rb") as fh:
        _, batch, rows, cols = struct.unpack(">IIII", fh.read(16))
        shape = (batch, 1, rows, cols)
        return jnp.array(array.array("B", fh.read()), dtype=jnp.uint8).reshape(shape)

Data batch randomiser

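The data loader shuffles the data into randomly permuted batches. We get the indices of the dataset, permute them and extract a random batch from the permuted array of indices (batch_size of them at a time). `yield` makes this function a generator rather than an ordinary function: think of it as return, except that the function is paused and the next request resumes from where it left off, so start and end carry over and we move on to the next batch. All pretty simple: the nested while loop runs once per batch until the permutation is used up, at which point the outer while loop reshuffles and starts a new pass over the data. Note that the condition end < dataset_size means the leftover samples at the end of each pass (fewer than a full batch, or a final batch that ends exactly at dataset_size) are skipped.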
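
The generator behaviour is easier to see on a tiny dataset. This is a plain-Python sketch with a hypothetical `toy_dataloader` that mirrors the loop structure above, using `random.shuffle` in place of `jr.permutation`.

```python
import itertools
import random


def toy_dataloader(data, batch_size, seed=0):
    rng = random.Random(seed)
    indices = list(range(len(data)))
    while True:  # one iteration per pass over the data
        perm = indices[:]
        rng.shuffle(perm)
        start, end = 0, batch_size
        while end < len(data):  # same stopping rule as the loader above
            yield [data[i] for i in perm[start:end]]
            start, end = end, end + batch_size


data = list(range(10))
# The generator never ends by itself, so take the first five batches only.
batches = list(itertools.islice(toy_dataloader(data, batch_size=4), 5))
for batch in batches:
    print(batch)
```

With 10 items and batch_size = 4, each pass yields two batches and drops the remaining two items, so the fifth batch already belongs to the third shuffle.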

In [None]:
def dataloader(data, batch_size, *, key):
    dataset_size = data.shape[0]
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield data[batch_perm]
            start = end
            end = start + batch_size

### Main function

Here we input all of the parameters and hyperparameters that we need. We load the data and also remember to standardise it (subtract the mean and divide by the standard deviation, so the values have zero mean and unit variance rather than lying in $[0,1]$). We initialise the model and the optimiser. We also define the weight function and the int_beta function. We then train the model for the specified number of steps (num_steps batches, not epochs). We then sample from the model and "de-standardise" the samples. Finally we clip them, which ensures all the values lie between the min and max values of the original data.
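
The standardise, de-standardise and clip steps can be checked on a handful of numbers. The pixel values and the "sample" below are made up for illustration, and the standard `statistics` module stands in for `jnp.mean` and `jnp.std`.

```python
import statistics

pixels = [0, 0, 0, 128, 255]  # hypothetical raw pixel values
mean = statistics.fmean(pixels)
std = statistics.pstdev(pixels)  # population std, like jnp.std's default

standardised = [(p - mean) / std for p in pixels]

# Pretend the sampler overshoots the data range at both ends.
sample = [standardised[0] - 0.5, standardised[3], standardised[4] + 0.5]

# Undo the standardisation, then clip to the range of the original data.
restored = [mean + std * s for s in sample]
clipped = [min(max(r, min(pixels)), max(pixels)) for r in restored]

print(min(standardised), max(standardised))
print(restored, clipped)
```

The standardised values are centred on zero and are not confined to [0, 1], which is why the clip at the end uses the min and max of the raw data.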

In [None]:
def main(
    # Model hyperparameters
    patch_size=4,
    hidden_size=64,
    mix_patch_size=512,
    mix_hidden_size=512,
    num_blocks=4,
    t1=10.0,
    # Optimisation hyperparameters
    num_steps=1_000_000,
    lr=3e-4,
    batch_size=256,
    print_every=10_000,
    # Sampling hyperparameters
    dt0=0.1,
    sample_size=10,
    # Seed
    seed=5678,
):
    key = jr.PRNGKey(seed)
    model_key, train_key, loader_key, sample_key = jr.split(key, 4)
    data = mnist()
    data_mean = jnp.mean(data)
    data_std = jnp.std(data)
    data_max = jnp.max(data)
    data_min = jnp.min(data)
    data_shape = data.shape[1:]
    data = (data - data_mean) / data_std

    model = Mixer2d(
        data_shape,
        patch_size,
        hidden_size,
        mix_patch_size,
        mix_hidden_size,
        num_blocks,
        t1,
        key=model_key,
    )
    int_beta = lambda t: t  # Try experimenting with other options here!
    weight = lambda t: 1 - jnp.exp(
        -int_beta(t)
    )  # Just chosen to upweight the region near t=0.

    opt = optax.adabelief(lr)
    # Optax will update the floating-point JAX arrays in the model.
    opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

    total_value = 0
    total_size = 0
    for step, data in zip(
        range(num_steps), dataloader(data, batch_size, key=loader_key)
    ):
        value, model, train_key, opt_state = make_step(
            model, weight, int_beta, data, t1, train_key, opt_state, opt.update
        )
        total_value += value.item()
        total_size += 1
        if (step % print_every) == 0 or step == num_steps - 1:
            print(f"Step={step} Loss={total_value / total_size}")
            total_value = 0
            total_size = 0

    sample_key = jr.split(sample_key, sample_size**2)
    sample_fn = ft.partial(single_sample_fn, model, int_beta, data_shape, dt0, t1)
    sample = jax.vmap(sample_fn)(sample_key)
    sample = data_mean + data_std * sample
    sample = jnp.clip(sample, data_min, data_max)
    sample = einops.rearrange(
        sample, "(n1 n2) 1 h w -> (n1 h) (n2 w)", n1=sample_size, n2=sample_size
    )
    plt.imshow(sample, cmap="Greys")
    plt.axis("off")
    plt.tight_layout()
    plt.show()