# SimVP Mask Predictor Model

> A SimVP-based model that maps sequences of integer mask tokens to per-position class logits.

In [1]:
#| default_exp mask_simvp

In [2]:
#| export
from openstl.models.simvp_model import SimVP_Model
import torch.nn as nn

In [3]:
#| export
# Default keyword arguments for constructing MaskSimVP.
# For MetaVP models, the hyperparameters that matter most are
# N_S, N_T, hid_S, hid_T and model_type.
DEFAULT_MODEL_CONFIG = dict(
    # MaskSimVP reads in_shape[1] (32 here) as its embedding channel width.
    in_shape=[11, 32, 160, 240],
    hid_S=64,
    hid_T=256,
    N_S=4,
    N_T=8,
    model_type='gSTA',
)

In [4]:
#| export
class MaskSimVP(nn.Module):
    """SimVP wrapper that turns integer token maps into per-position class logits.

    Tokens are embedded into ``in_shape[1]`` channels, run through
    ``SimVP_Model``, and projected to ``num_classes`` logits by a 1x1 conv.

    Args:
        in_shape: Shape forwarded to ``SimVP_Model``; index 1 is used as the
            channel count (presumably ``(T, C, H, W)`` -- confirm against
            the openstl documentation).
        hid_S: Forwarded unchanged to ``SimVP_Model``.
        hid_T: Forwarded unchanged to ``SimVP_Model``.
        N_S: Forwarded unchanged to ``SimVP_Model``.
        N_T: Forwarded unchanged to ``SimVP_Model``.
        model_type: Forwarded unchanged to ``SimVP_Model``.
        num_classes: Size of the token vocabulary and number of output
            logits per position. Defaults to 49, the value that used to be
            hard-coded.
    """

    def __init__(self, in_shape, hid_S, hid_T, N_S, N_T, model_type,
                 num_classes=49):
        super().__init__()
        c = in_shape[1]
        self.num_classes = num_classes
        self.simvp = SimVP_Model(
            in_shape=in_shape, hid_S=hid_S,
            hid_T=hid_T, N_S=N_S, N_T=N_T,
            model_type=model_type)
        # Embedding width must equal the channel count the backbone expects.
        self.token_embeddings = nn.Embedding(num_classes, c)
        # 1x1 conv: an independent linear classifier at every spatial position.
        self.out_conv = nn.Conv2d(c, num_classes, 1, 1)

    def forward(self, tokens):
        """Compute class logits for a batch of token maps.

        Args:
            tokens: Integer tensor of rank 4, ``(b, t, h, w)``, with values
                valid for the embedding table.

        Returns:
            Tensor of shape ``(b, t, num_classes, h, w)``.
        """
        x = self.token_embeddings(tokens)   # (b, t, h, w, c)
        x = x.permute(0, 1, 4, 2, 3)        # (b, t, c, h, w)
        x = self.simvp(x)

        b, t, c, h, w = x.shape
        # Fold time into the batch axis so the 2D conv sees plain image
        # batches; reshape (not view) tolerates a non-contiguous output.
        x = x.reshape(b * t, c, h, w)

        x = self.out_conv(x)
        return x.reshape(b, t, self.num_classes, h, w)



# Tests

In [5]:
# Build the model from the default configuration; a mismatch between the
# config keys and the constructor signature would fail right here.
model = MaskSimVP(**DEFAULT_MODEL_CONFIG)

In [6]:
import torch

# Smoke test: one batch of random token ids must come back as a 49-way
# logit map per frame at the input resolution.
n_batch, n_frames, height, width = 1, 11, 160, 240
x = torch.randint(0, 49, (n_batch, n_frames, height, width)).long()
out = model(x)
assert out.shape == (n_batch, n_frames, 49, height, width)

In [None]:
#| hide
# Trigger nbdev's export step for this notebook.
import nbdev; nbdev.nbdev_export()