# Testing Notebook

## Imports

In [None]:
# %load ~/dev/marthaler/header.py
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np

from dataclasses import dataclass
from typing import Dict, Any, Tuple
from datetime import datetime, timedelta

import warnings

# Use the ggplot style sheet for every figure in this notebook.
plt.style.use("ggplot")
# NOTE(review): this suppresses every warning category for the whole session,
# including deprecation warnings raised by the libraries imported below.
warnings.filterwarnings('ignore')

# Register pandas formatters and converters with matplotlib.
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()

# Default figure size (16 x 10) and base font size for all plots.
plt.rc('figure', figsize=(16, 10))
plt.rc('font', size=14)


In [None]:
import math

import equinox as eqx
import jax
from jax import vmap
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax  # https://github.com/deepmind/optax
import torch

In [None]:
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping

In [None]:
from typing import Sequence, Union

In [None]:
# Report the platform of JAX's default backend (for example "cpu" or "gpu").
# jax.default_backend() is the public API for this; jax.lib.xla_bridge is an
# internal module that recent JAX releases deprecate.
print(jax.default_backend())

## Functions

In [None]:
def shift_image(
    img: np.ndarray,
    horizontal_shift: int,
    vertical_shift: int,
) -> np.ndarray:
    """
      Shift an image by a fixed number of pixels and fill the vacated pixels with 0.

      The image origin in numpy is at the top-left corner, so the vertical shift
      is negated internally: a positive ``vertical_shift`` moves the content up
      (towards row 0) and a positive ``horizontal_shift`` moves it to the right.
      The input array is not modified; a new array is returned. A shift whose
      magnitude is at least the image size along that axis yields all zeros.

      :param img: The input image, indexed as (row, column, ...)
      :type img: np.ndarray
      :param horizontal_shift: the number of pixels to shift horizontally
      :type horizontal_shift: int
      :param vertical_shift: the number of pixels to shift vertically
      :type vertical_shift: int
      :return: The shifted image, same shape and dtype as ``img``
      :rtype: np.ndarray
    """
    # The annotations are plain ``np.ndarray``: ``np.ndarray[np.uint8]`` is not a
    # valid parametrisation (ndarray is generic over shape *and* dtype) and some
    # NumPy releases reject it with a TypeError when the function is defined.

    # Negate the vertical shift to compensate for the origin at the top-left of image
    row_shift = -vertical_shift
    col_shift = horizontal_shift

    # np.roll wraps pixels around the border and returns a copy ...
    shifted = np.roll(img, row_shift, axis=0)
    shifted = np.roll(shifted, col_shift, axis=1)

    # ... so blank out the rows/columns that wrapped around.
    if row_shift > 0:
        shifted[:row_shift, :] = 0
    elif row_shift < 0:
        shifted[row_shift:, :] = 0
    if col_shift > 0:
        shifted[:, :col_shift] = 0
    elif col_shift < 0:
        shifted[:, col_shift:] = 0
    return shifted

## Load Data

In [None]:
# model parameters
in_shape=[10, 1, 64, 64]  # per-sample shape; unpacked as (T, C, H, W) when the encoder is built
spatio_kernel_enc = 3  # kernel size handed to the Encoder's ConvSC layers
spatio_kernel_dec = 3  # kernel size handed to the Decoder's ConvSC layers
model_type = 'gSTA'  # NOTE(review): not passed to anything in the cells visible here
hid_S = 64  # hidden channel count (C_hid) of the encoder/decoder
hid_T = 512  # NOTE(review): redefined as 256 in the Parameters section; SimVP accepts it but does not read it
N_T = 8  # SimVP accepts it but does not read it in its visible body
N_S = 2  # number of ConvSC layers requested for the stand-alone encoder/decoder below
# training
lr = 1e-3  # NOTE(review): overwritten with 0.01 in the next cell
SEED = 42  # seed for the JAX PRNG key
batch_size = 16  # passed to load_data
val_batch_size=16  # passed to load_data
num_workers=8  # passed to load_data
drop_path = 0  # NOTE(review): not referenced in the cells visible here
sched = 'onecycle'  # NOTE(review): not referenced in the cells visible here

In [None]:
# Training parameters
# NOTE(review): none of these three names is read by the training code visible
# in this notebook, which uses TOTAL_STEPS, PRINT_EVERY and LEARNING_RATE
# instead. `lr` here replaces the 1e-3 assigned in the previous cell.
epochs=1000
log_step=1
lr=0.01

In [None]:
import os
from pathlib import Path

from dataloader_moving_mnist import load_data

# Root of the SimVP data directory. The default is the original author's local
# checkout; set the SIMVP_DATA_DIR environment variable to run this notebook on
# another machine without editing the cell.
DATA_DIR = Path(os.environ.get("SIMVP_DATA_DIR", "/Users/daniel.marthaler/dev/SimVP/data"))

# Raw Moving-MNIST test sequences, loaded straight from the .npy file.
imgs = np.load(DATA_DIR / "moving_mnist" / "mnist_test_seq.npy")

In [None]:
# Build the train/validation/test loaders. load_data comes from the local
# dataloader_moving_mnist module; its arguments are passed positionally.
# data_mean / data_std are presumably normalisation statistics — verify
# against dataloader_moving_mnist.
train_loader, vali_loader, test_loader, data_mean, data_std = load_data(
    batch_size,
    val_batch_size,
    '/Users/daniel.marthaler/dev/SimVP/data',  # NOTE(review): absolute, machine-specific path
    num_workers
)

In [None]:
# Draw one (x, y) batch from the training loader and convert both PyTorch
# tensors to NumPy arrays in a single step.
x_batch, y_batch = (tensor.numpy() for tensor in next(iter(train_loader)))

In [None]:
from einops import rearrange
from functools import partial

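# NOTE(review): the example below looks like it belongs to a channels-last
# pixel_shuffle / pixel_unshuffle helper pair that is not defined in this
# notebook; the PixelShuffle class below works on channels-first input.
# x = jnp.arange(5 * 8 * 8 * 3).reshape(5, 8, 8, 3)
# print(pixel_shuffle(x).shape)  # (5, 4, 4, 12)
# print(pixel_unshuffle(pixel_shuffle(x)).shape)  # (5, 8, 8, 3)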

class PixelShuffle(eqx.Module):
    """Move channel blocks into space: ``(..., c*r*r, h, w) -> (..., c, h*r, w*r)``.

    ``r`` is ``scale_factor``. The einops pattern and both block sizes are
    bound once at construction time and stored in ``layer``, so calling the
    module only needs the input array.
    """
    scale_factor: int
    layer: partial

    def __init__(self, scale_factor: int) -> None:
        # Pre-bind everything except the input tensor.
        shuffle = partial(
            rearrange,
            pattern='... (c b1 b2) h w -> ... c (h b1) (w b2)',
            b1=scale_factor,
            b2=scale_factor,
        )
        self.scale_factor = scale_factor
        self.layer = shuffle

    def __call__(self, x: Array, key: jax.random.PRNGKey = None) -> Array:
        # `key` is accepted but never used by this layer.
        shuffled = self.layer(x)
        return shuffled


In [None]:
class BasicConv2d(eqx.Module):
    """2-D convolution, optionally followed by GroupNorm and SiLU.

    With ``upsampling=True`` the convolution produces ``out_channels * 4``
    channels and a ``PixelShuffle(2)`` folds them into a grid twice as large in
    each spatial dimension. The normalisation (two groups) and activation are
    always constructed but only applied when ``act_norm`` is true.
    ``act_inplace`` is accepted and ignored.
    """
    act_norm: bool
    conv: list
    norm: eqx.nn.GroupNorm
    act: jax.nn.silu

    def __init__(
        self,
        key: jax.random.PRNGKey,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]] = 3,
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[str, int, Sequence[int]] = 0,
        dilation: Union[str, int, Sequence[int]] = 1,
        upsampling: bool = False,
        act_norm: bool = False,
        act_inplace: bool = True,
    ) -> None:
        super(BasicConv2d, self).__init__()

        # Both variants share every convolution argument except the output width.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            key=key,
        )
        if upsampling is True:
            # 4x the channels, then PixelShuffle(2) converts them into a 2x larger map.
            widened = eqx.nn.Conv2d(in_channels, out_channels * 4, **conv_kwargs)
            conv = eqx.nn.Sequential([widened, PixelShuffle(2)])
        else:
            conv = eqx.nn.Conv2d(in_channels, out_channels, **conv_kwargs)

        self.conv = conv
        self.norm = eqx.nn.GroupNorm(2, out_channels)
        self.act = jax.nn.silu
        self.act_norm = act_norm

    def __call__(self, x):
        out = self.conv(x)
        # Normalise and activate only when requested at construction time.
        return self.act(self.norm(out)) if self.act_norm else out


In [None]:
class ConvSC(eqx.Module):
    """Thin wrapper around ``BasicConv2d`` that derives stride and padding.

    The stride is 2 when ``downsampling`` is ``True`` and 1 otherwise; the
    padding is ``(kernel_size - stride + 1) // 2``.
    """
    conv: eqx.nn.Conv

    def __init__(
        self,
        key: jax.random.PRNGKey,
        C_in: int,
        C_out: int,
        kernel_size: int = 3,
        downsampling=False,
        upsampling=False,
        act_norm: bool = True
    ) -> None:
        super(ConvSC, self).__init__()

        if downsampling is True:
            stride = 2
        else:
            stride = 1
        padding = (kernel_size - stride + 1) // 2

        self.conv = BasicConv2d(
            key=key,
            in_channels=C_in,
            out_channels=C_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            upsampling=upsampling,
            act_norm=act_norm,
        )

    def __call__(self, x):
        # Pure delegation to the wrapped block.
        return self.conv(x)

In [None]:
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping
def sampling_generator(N: int, reverse: bool = False) -> list:
    """Return ``N`` alternating resampling flags, starting with ``False``.

    The result is ``[False, True, False, True, ...]`` truncated to ``N``
    entries. Building it by index rather than by repeating ``[False, True]``
    ``N // 2`` times means an odd ``N`` yields ``N`` flags instead of
    ``N - 1``, and ``N == 1`` yields ``[False]`` instead of an empty list.
    Even values of ``N`` give the same result as before.

    :param N: number of flags to generate; values <= 0 give an empty list
    :param reverse: when true, return the flags in reverse order
    :return: a new list of ``N`` booleans
    """
    samplings = [bool(i % 2) for i in range(N)]
    if reverse:
        samplings.reverse()
    return samplings

class Encoder(eqx.Module):
    """Encoder for SimVP: a stack of ``ConvSC`` layers.

    The first layer maps ``C_in`` to ``C_hid`` channels; every later layer keeps
    ``C_hid`` channels. The per-layer downsampling flags come from
    ``sampling_generator(N_S)`` and each layer gets its own PRNG key.
    """
    enc: list

    def __init__(
        self,
        key: jax.random.PRNGKey,
        C_in: int,
        C_hid: int,
        N_S: int,
        spatio_kernel: int,
    ) -> None:
        super(Encoder, self).__init__()
        flags = sampling_generator(N_S)
        layer_keys = jax.random.split(key, N_S)

        # The first layer is the only one that changes the channel count.
        stem = ConvSC(layer_keys[0], C_in, C_hid, spatio_kernel, downsampling=flags[0],)
        tail = [
            ConvSC(layer_key, C_hid, C_hid, spatio_kernel, downsampling=flag)
            for layer_key, flag in zip(layer_keys[1:], flags[1:])
        ]
        self.enc = eqx.nn.Sequential([stem, *tail])

    def __call__(self, x: Array) -> Array:
        """Return ``(latent, enc1)``: the final output and the first layer's output."""
        enc1 = self.enc[0](x)
        latent = enc1
        idx = 1
        while idx < len(self.enc):
            latent = self.enc[idx](latent)
            idx += 1
        return latent, enc1

In [None]:
# Root PRNG key, seeded from SEED; the stand-alone encoder and decoder below are built from it.
key = jrandom.PRNGKey(SEED) 

In [None]:
# Display the configured per-sample shape (unpacked as T, C, H, W in the next cell).
in_shape

In [None]:
# Unpack the per-sample shape and build a stand-alone encoder whose input
# channel count is C and whose hidden width is hid_S.
T, C, H, W = in_shape
enc = Encoder(
    key=key,
    C_in=C,
    C_hid=hid_S,
    N_S=N_S,
    spatio_kernel=spatio_kernel_enc,
)

In [None]:
# Read the batch dimensions, then merge the batch and time axes so that every
# frame becomes an independent (C, H, W) image.
B, T, C, H, W = x_batch.shape
x = np.reshape(x_batch, (B * T, C, H, W))

In [None]:
# The encoder is written for a single (C, H, W) image, so vmap maps it over the
# leading B*T axis. `embed` is the final latent, `skip` the first layer's output.
embed, skip = vmap(enc)(x)

In [None]:
# Compare the shapes of the flattened input, the latent and the skip tensor.
x.shape, embed.shape, skip.shape

In [None]:
class Decoder(eqx.Module):
    """Decoder for SimVP: a stack of ``ConvSC`` layers plus a 1x1 readout.

    Every ``ConvSC`` layer keeps ``C_hid`` channels; the per-layer upsampling
    flags come from ``sampling_generator(N_S, reverse=True)``. ``readout`` is a
    kernel-size-1 ``Conv2d`` mapping ``C_hid`` to ``C_out`` channels.
    """
    readout: eqx.nn.Conv2d
    dec: list

    def __init__(
        self,
        key: jax.random.PRNGKey,
        C_hid: int,
        C_out: int,
        N_S: int,
        spatio_kernel: int,
    ) -> None:
        super(Decoder, self).__init__()
        samplings = sampling_generator(N_S, reverse=True)
        # One key per ConvSC layer plus one for the readout convolution.
        keys = jax.random.split(key, N_S+1)
        self.dec = eqx.nn.Sequential([
            *[ConvSC(k, C_hid, C_hid, spatio_kernel, upsampling=s,) for k,s in zip(keys[:-2],samplings[:-1])],
              ConvSC(keys[-2],C_hid, C_hid, spatio_kernel, upsampling=samplings[-1])
        ])

        self.readout = eqx.nn.Conv2d(C_hid, C_out, 1, key=keys[-1],)

    def __call__(self, hid, enc1=None):
        """Decode ``hid``; add the skip tensor ``enc1`` before the last layer.

        :param hid: latent tensor produced by the encoder
        :param enc1: optional skip tensor, added to the input of the last
            ``ConvSC`` layer. When it is ``None`` (the default) the skip
            connection is omitted; previously the default raised a TypeError
            on ``hid + None``.
        :return: output of the readout convolution
        """
        for i in range(0, len(self.dec)-1):
            hid = self.dec[i](hid)
        if enc1 is not None:
            hid = hid + enc1
        Y = self.dec[-1](hid)
        return self.readout(Y)

In [86]:
# Stand-alone decoder mapping hid_S hidden channels back to C output channels.
# NOTE(review): this reuses the same PRNG key that was given to the encoder
# above rather than a split of it.
dec = Decoder(key, hid_S, C, N_S, spatio_kernel_dec)

In [91]:
# Decode every frame independently; vmap maps over the leading axis of both
# the latent and the skip tensor.
y_pred = vmap(dec)(embed,skip)

In [94]:
# Sanity check: the decoder output shape should match the flattened input shape.
y_pred.shape, x.shape

((160, 1, 64, 64), (160, 1, 64, 64))

In [100]:
class SimVP(eqx.Module):
    """Encoder/decoder pair applied independently to every frame of a clip.

    Only the channel count is read from ``in_shape``. ``hid_T``, ``N_T`` and
    ``model_type`` are accepted but not used by this class: the output is the
    decoder applied directly to the encoder's latent and skip tensors.
    """
    enc: Encoder
    dec: Decoder

    def __init__(
        self,
        key: jax.random.PRNGKey,
        in_shape: Tuple,
        hid_S: int = 16,
        hid_T: int = 256,
        N_S: int = 4,
        N_T: int = 4,
        model_type: str = 'gSTA',
        spatio_kernel_enc: int = 3,
        spatio_kernel_dec: int = 3,
    ) -> None:
        super(SimVP, self).__init__()
        # in_shape must have exactly four entries (T, C, H, W); only C is needed here.
        _, C, _, _ = in_shape
        enc_key, dec_key = jax.random.split(key, 2)
        self.enc = Encoder(enc_key, C, hid_S, N_S, spatio_kernel_enc)
        self.dec = Decoder(dec_key, hid_S, C, N_S, spatio_kernel_dec)

    def __call__(self, x_raw: Array) -> Array:
        """Map a ``(B, T, C, H, W)`` batch to an output of the same shape."""
        n_batch, n_frames, n_chan, height, width = x_raw.shape
        # Merge batch and time so each frame is treated as a separate image.
        frames = x_raw.reshape(n_batch * n_frames, n_chan, height, width)
        latent, first_skip = vmap(self.enc)(frames)
        decoded = vmap(self.dec)(latent, first_skip)
        return decoded.reshape(n_batch, n_frames, n_chan, height, width)

## Parameters

In [147]:
seed = 42  # NOTE(review): lowercase `seed` is not read in the cells visible here; the PRNG key uses SEED
batch_size = 16  # same value as set in the Load Data section
val_batch_size=16
num_workers=8

# model parameters
in_shape=[10, 1, 64, 64]  # per-sample shape, unpacked as (T, C, H, W)
hid_S=64  # hidden channel count of the encoder/decoder
hid_T=256  # replaces the earlier 512; SimVP accepts it but does not read it
num_layers=6  # NOTE(review): passed positionally as N_S when the model is built below
N_T=8  # SimVP accepts it but does not read it
groups=4  # NOTE(review): not referenced in the cells visible here

# Training parameters
# NOTE(review): the training cells below use TOTAL_STEPS, PRINT_EVERY and
# LEARNING_RATE rather than these three names.
epochs=1000
log_step=1
lr=0.01

In [149]:
# Hyperparameters

# NOTE(review): BATCH_SIZE is not passed to anything visible here; the data
# loaders were built with batch_size / val_batch_size.
BATCH_SIZE = 64
LEARNING_RATE = 3e-4  # initial value of the cosine learning-rate schedule
TOTAL_STEPS = 1500  # optimiser steps; also the schedule's decay_steps
PRINT_EVERY = 30  # evaluate and log every this many steps
SEED = 42

# Fresh root key; rebinds the `key` created earlier in the notebook.
key = jax.random.PRNGKey(SEED)

## Model

In [150]:
# Build the full model. Keyword arguments make the mapping explicit; note that
# the N_S slot receives `num_layers`, not the `N_S` variable defined earlier.
model = SimVP(
    key=key,
    in_shape=tuple(in_shape),
    hid_S=hid_S,
    hid_T=hid_T,
    N_S=num_layers,
    N_T=N_T,
    spatio_kernel_enc=spatio_kernel_enc,
    spatio_kernel_dec=spatio_kernel_dec,
)

In [151]:
def compute_loss(model, x, y):
    """Summed L2 loss between ``model(x)`` and the target ``y``.

    The element-wise ``optax.losses.l2_loss`` values are added up over every
    axis, so the result is a sum and not a mean.
    (``optax.losses.huber_loss`` was tried as an alternative and is disabled.)
    """
    per_element = optax.losses.l2_loss(model(x), y)
    return per_element.sum()

# JIT-compiled entry point used by both training and evaluation.
loss = eqx.filter_jit(compute_loss)

In [152]:
def evaluate(model: SimVP, testloader: torch.utils.data.DataLoader):
    """Return the mean per-batch loss of ``model`` over ``testloader``.

    Each batch is converted from PyTorch tensors to NumPy arrays and scored
    with the jitted module-level ``loss``. The per-batch values are summed and
    divided by the number of batches. No accuracy is computed.
    """
    # All JAX work happens inside `loss`, which is JIT-wrapped.
    batch_losses = (loss(model, x.numpy(), y.numpy()) for x, y in testloader)
    return sum(batch_losses) / len(testloader)

In [153]:
# Cosine learning-rate schedule starting at LEARNING_RATE and decaying over TOTAL_STEPS steps.
# NOTE(review): alpha is the floor of the decay multiplier, so with alpha=0.95
# the rate only falls to 95% of its initial value — confirm this is intended.
cosine_decay_scheduler = optax.cosine_decay_schedule(LEARNING_RATE, decay_steps=TOTAL_STEPS, alpha=0.95)

In [154]:
# AdamW driven by the cosine schedule; every other AdamW setting is left at its optax default.
optim = optax.adamw(learning_rate=cosine_decay_scheduler)

In [155]:
def train(
    model: SimVP,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    steps: int,
    print_every: int,
) -> SimVP:
    """Run ``steps`` optimiser steps on ``model`` and return the updated model.

    Batches are drawn from ``trainloader``, restarting it whenever it is
    exhausted. Every ``print_every`` steps, and on the last step, the model is
    scored on the whole ``testloader`` with ``evaluate`` and both losses are
    printed.

    :param model: the model to train; the input object is not mutated
    :param trainloader: yields ``(x, y)`` PyTorch tensor batches for training
    :param testloader: yields ``(x, y)`` PyTorch tensor batches for evaluation
    :param optim: optax gradient transformation used for the updates
    :param steps: total number of optimiser steps
    :param print_every: logging/evaluation interval, in steps
    :return: the trained model
    """
    # Only the array leaves of the model are trainable, so the optimiser state
    # is built from those alone.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # Gradient computation, optimiser update and parameter update live in a
    # single JIT region.
    @eqx.filter_jit
    def make_step(
        model: SimVP,
        opt_state: PyTree,
        x: Float[Array, "batch 10 1 64 64"],
        y: Float[Array, "batch 10 1 64 64"],
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Optimisers that read the current parameters (AdamW's weight decay
        # does) must see the same array-only tree the state was initialised
        # with; passing the unfiltered model would mix in non-array leaves.
        params = eqx.filter(model, eqx.is_array)
        updates, opt_state = optim.update(grads, opt_state, params)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Loop over the training dataset as many times as needed.
    def infinite_trainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(steps), infinite_trainloader()):
        # PyTorch dataloaders give PyTorch tensors by default,
        # so convert them to NumPy arrays.
        x = x.numpy()
        y = y.numpy()
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
        if (step % print_every) == 0 or (step == steps - 1):
            test_loss = evaluate(model, testloader)
            print(
                f"{step}, train_loss={train_loss.item()}, test_loss={test_loss.item()}"
            )
    return model

In [None]:
# Train for TOTAL_STEPS optimiser steps, evaluating on the test loader every PRINT_EVERY steps.
# NOTE(review): `model` is rebound to the trained model, so re-running this
# cell continues from the trained weights rather than from a fresh model.
model = train(model, train_loader, test_loader, optim, TOTAL_STEPS, PRINT_EVERY)

0, train_loss=66819.3203125, test_loss=78841.9609375
30, train_loss=12492.623046875, test_loss=12646.6259765625
60, train_loss=12329.431640625, test_loss=12379.294921875
90, train_loss=11546.8603515625, test_loss=12329.7177734375
120, train_loss=13075.142578125, test_loss=12279.2744140625
150, train_loss=11741.0400390625, test_loss=12291.4951171875
180, train_loss=12274.7685546875, test_loss=12248.453125
210, train_loss=12309.384765625, test_loss=12240.2294921875
240, train_loss=12319.455078125, test_loss=12233.90234375
270, train_loss=12109.650390625, test_loss=12215.8271484375
300, train_loss=11715.0361328125, test_loss=12208.9091796875
330, train_loss=12492.328125, test_loss=12219.3857421875
360, train_loss=11388.431640625, test_loss=12201.052734375
390, train_loss=12307.001953125, test_loss=12209.7314453125
420, train_loss=10747.2216796875, test_loss=12249.8671875
450, train_loss=11461.69140625, test_loss=12235.90234375
480, train_loss=11432.734375, test_loss=12177.7744140625
510, 

In [None]:
# Run the trained model on the batch drawn earlier; the output has the same
# (B, T, C, H, W) shape as x_batch.
# NOTE(review): this rebinds y_pred, which previously held the stand-alone
# decoder output of shape (B*T, C, H, W).
y_pred = model(x_batch)

In [None]:
# Predicted image: first sample, first frame, first channel.
plt.imshow(y_pred[0,0,0,:,:])

In [None]:
# Matching slice of y_batch (first sample, first frame, first channel) for visual comparison.
plt.imshow(y_batch[0,0,0,:,:])