# Train the predictor with an autoregressive loss

In [None]:
#| default_exp autoregressive_trainer

In [None]:
#| export
import lightning as pl
import torch
from maskpredformer.mask_simvp import MaskSimVP
from maskpredformer.simvp_dataset import DLDataset

In [None]:
#| export
class MaskSimVPAutoRegressiveModule(pl.LightningModule):
    """Fine-tune a ``MaskSimVP`` with an autoregressive rollout loss.

    ``step`` feeds the model a window of context frames, appends the arg-max of
    its prediction to the window (dropping the same number of oldest frames),
    and repeats until as many frames have been predicted as the target holds.
    Cross-entropy is computed between the collected per-step logits and the
    target frames. Gradients reach each step's logits but do not flow through
    the arg-max feedback.

    Args:
        in_shape, hid_S, hid_T, N_S, N_T, model_type, drop_path, downsample:
            forwarded to ``MaskSimVP``.
        batch_size: batch size of the train and validation dataloaders.
        lr: Adam learning rate and OneCycleLR peak learning rate.
        weight_decay: Adam weight decay.
        max_epochs: number of epochs the OneCycleLR schedule spans.
        data_root: forwarded to ``DLDataset``.
        pre_seq_len: forwarded to ``MaskSimVP`` and used as the datasets'
            input length so the two always agree.
        aft_seq_len: forwarded to ``MaskSimVP`` (presumably the number of
            frames predicted per call; ``step`` handles any count >= 1).
        unlabeled: forwarded to the training ``DLDataset``.
    """

    # Number of future frames the datasets return as the rollout target.
    ROLLOUT_LEN = 11

    def __init__(self, in_shape, hid_S, hid_T, N_S, N_T, model_type,
                 batch_size, lr, weight_decay, max_epochs,
                 data_root, pre_seq_len=11, aft_seq_len=1,
                 drop_path=0.0, unlabeled=False, downsample=False):
        super().__init__()
        self.save_hyperparameters()
        self.model = MaskSimVP(
            in_shape, hid_S, hid_T, N_S, N_T, model_type, downsample=downsample, drop_path=drop_path,
            pre_seq_len=pre_seq_len, aft_seq_len=aft_seq_len
        )
        # The rollout keeps the window length fixed at the input length, so the
        # datasets must supply the same number of context frames the model was
        # configured with (previously hard-coded to 11, ignoring pre_seq_len).
        self.train_set = DLDataset(
            data_root, "train", unlabeled=unlabeled,
            pre_seq_len=pre_seq_len, aft_seq_len=self.ROLLOUT_LEN,
        )
        self.val_set = DLDataset(
            data_root, "val",
            pre_seq_len=pre_seq_len, aft_seq_len=self.ROLLOUT_LEN,
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(
            self.train_set, batch_size=self.hparams.batch_size,
            num_workers=8, shuffle=True, pin_memory=True
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return torch.utils.data.DataLoader(
            self.val_set, batch_size=self.hparams.batch_size,
            num_workers=8, shuffle=False, pin_memory=True
        )

    def calculate_loss(self, logits, target):
        """Cross-entropy over every frame of a batch.

        Both tensors have batch in dim 0 and time in dim 1; those two axes are
        folded together before applying ``self.criterion``. ``logits`` carries
        class scores in dim 2 (the axis ``step`` arg-maxes over).

        Raises:
            ValueError: if the batch/time dims of ``logits`` and ``target`` differ.
        """
        if logits.shape[:2] != target.shape[:2]:
            raise ValueError(
                f"logits batch/time dims {tuple(logits.shape[:2])} do not match "
                f"target batch/time dims {tuple(target.shape[:2])}"
            )
        # flatten(0, 1) == view(b*t, ...) but also works on non-contiguous input.
        return self.criterion(logits.flatten(0, 1), target.flatten(0, 1))

    def step(self, x, y):
        """Roll the model out autoregressively and score it against ``y``.

        Args:
            x: context frames, batch in dim 0 and time in dim 1.
            y: target frames, batch in dim 0 and time in dim 1; ``y.shape[1]``
                frames are predicted.

        Returns:
            ``(loss, y_hat_logits, cur_seq)``: the cross-entropy loss, the
            per-step logits concatenated along dim 1 (truncated to
            ``y.shape[1]`` frames), and the input window after the last step.

        Raises:
            ValueError: if the model returns zero frames (the rollout could
                never advance).
        """
        n_target = y.shape[1]
        step_logits = []
        cur_seq = x  # never modified in place, so no clone is needed
        n_predicted = 0
        while n_predicted < n_target:
            logits_t = self.model(cur_seq)
            step_logits.append(logits_t)  # keep logits for backprop
            y_hat = torch.argmax(logits_t, dim=2)  # hard prediction fed back in
            n_new = y_hat.shape[1]
            if n_new == 0:
                raise ValueError("model returned no frames; rollout cannot advance")
            # Slide the window: drop as many old frames as were just predicted,
            # so its length stays constant even when n_new > 1.
            cur_seq = torch.cat([cur_seq[:, n_new:], y_hat], dim=1)
            n_predicted += n_new

        # Concatenate along the time axis so the layout matches y (batch first).
        # torch.stack would insert the step axis at dim 0 and misalign with y.
        y_hat_logits = torch.cat(step_logits, dim=1)[:, :n_target]
        loss = self.calculate_loss(y_hat_logits, y)
        return loss, y_hat_logits, cur_seq

    def training_step(self, batch, batch_idx):
        """Run one rollout on a training batch and log ``train_loss``."""
        x, y = batch
        loss, _, _ = self.step(x, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one rollout on a validation batch and log ``val_loss``."""
        x, y = batch
        loss, _, _ = self.step(x, y)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam with a per-step OneCycleLR schedule spanning all epochs."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=self.hparams.lr,
            # One scheduler step per training batch, for every epoch.
            total_steps=self.hparams.max_epochs*len(self.train_dataloader()),
            final_div_factor=1e4
        )
        opt_dict = {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

        return opt_dict
        

**Test the `MaskSimVPAutoRegressiveModule`**

In [None]:
# Change to the parent directory (IPython magic). Presumably the notebook sits one
# level below the repo root, so relative paths like "checkpoints/..." resolve there.
%cd ..

In [None]:
# Load the SimVP checkpoint onto the CPU so it opens on machines without the
# device it was saved from; its weights are copied into the module afterwards.
mask_sim_vp_ckpt = torch.load(
    "checkpoints/simvp_epoch=13-val_loss=0.015.ckpt", map_location="cpu"
)

In [None]:
# Reuse the checkpoint's hyper-parameters (the same dict object, updated in place)
# and switch off the training dataset's `unlabeled` option.
autoregressive_params = mask_sim_vp_ckpt["hyper_parameters"]
autoregressive_params.update(unlabeled=False)

In [None]:
# Build the autoregressive module from the checkpoint's hyper-parameters, then
# initialise it with the checkpoint weights (nn.Module's default strict loading,
# so parameter names must match exactly).
pl_module = MaskSimVPAutoRegressiveModule(**autoregressive_params)
pl_module.load_state_dict(mask_sim_vp_ckpt["state_dict"])

In [None]:
def test_prior_model_results():
    """Roll the loaded model out on the first validation sample and report it.

    The rollout runs in eval mode and without gradient tracking. This avoids
    building an autograd graph across every rollout step and disables
    train-only behaviour of any submodules. The module's previous train/eval
    mode is restored afterwards.
    """
    x, y = pl_module.val_set[0]
    # Add a batch dimension and move to the module's device.
    x = x.unsqueeze(0).to(pl_module.device)
    y = y.unsqueeze(0).to(pl_module.device)
    was_training = pl_module.training
    pl_module.eval()
    try:
        with torch.no_grad():
            loss, _, cur_seq = pl_module.step(x, y)
    finally:
        pl_module.train(was_training)
    print(f"rollout loss: {loss.item():.4f}")
    print(cur_seq.shape)

test_prior_model_results()