# SimVP Dataset

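> PyTorch `Dataset` of mask sequences for SimVP: each sample is split into the leading input frames and the remaining frames to predict.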

In [None]:
#| default_exp simvp_dataset

In [None]:
#| export
from torch.utils.data import Dataset
import torch
import os
from torchvision import transforms
from maskpredformer.vis_utils import show_video_line

In [None]:
#| export
# Default data directory. It can be overridden via the MASKPREDFORMER_DATA_PATH
# environment variable, so other machines don't need to edit this hard-coded,
# user-specific path. The value is unchanged when the variable is unset.
DEFAULT_DATA_PATH = os.environ.get(
    "MASKPREDFORMER_DATA_PATH", "/home/enes/dev/maskpredformer/data/DL"
)

In [None]:
#| export
class DLDataset(Dataset):
    """Dataset of mask sequences loaded from ``.pt`` files.

    Each entry along the first axis of the loaded mask tensor is one
    sequence. ``__getitem__`` splits a sequence along its first axis into
    the first ``n_input_frames`` frames (returned as ``data``) and the
    remaining frames (returned as ``labels``), both cast to ``torch.long``.

    Args:
        root: Directory containing ``{mode}_masks.pt`` (and
            ``unlabeled_masks.pt`` when ``unlabeled`` is true).
        mode: Split name. Selects the mask file; when equal to ``"train"``
            a random horizontal flip is applied to each returned sequence.
        unlabeled: If true, concatenate the masks from
            ``unlabeled_masks.pt`` after the split's masks.
        n_input_frames: Number of leading frames returned as ``data``; the
            rest are returned as ``labels``. Defaults to 11.
    """

    def __init__(self, root, mode, unlabeled=False, *, n_input_frames=11):
        self.mask_path = os.path.join(root, f"{mode}_masks.pt")
        self.mode = mode
        self.n_input_frames = n_input_frames
        print("INFO: Loading masks from", self.mask_path)
        masks = torch.load(self.mask_path)
        if unlabeled:
            unlabeled_path = os.path.join(root, "unlabeled_masks.pt")
            print("INFO: Loading unlabeled masks from", unlabeled_path)
            # NOTE(review): .squeeze() with no dim removes *every* size-1 axis,
            # including the sequence axis if the file holds exactly one
            # sequence. Confirm the on-disk shape and pass an explicit dim.
            masks = torch.cat([masks, torch.load(unlabeled_path).squeeze()], dim=0)
        self.masks = masks
        # The transform is applied to a whole sequence in one call, so every
        # frame of a sample receives the same (random) flip decision.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
        ])

    def __len__(self):
        """Return the number of sequences."""
        return self.masks.shape[0]

    def __getitem__(self, idx):
        """Return ``(data, labels)`` for sequence ``idx`` as long tensors."""
        ep = self.masks[idx]
        if self.mode == "train":
            # Augmentation only on the training split.
            ep = self.transform(ep)
        n = self.n_input_frames
        return ep[:n].long(), ep[n:].long()

**Sanity-check the dataset** (load the `val` split and plot one sample's input and target frames)

In [None]:
dataset = DLDataset(root='../data/DL', mode='val')  # non-"train" mode: no random flip

In [None]:
(x, y) = dataset[0]  # input frames, target frames of the first sequence

In [None]:
(x.shape, y.shape)  # display both shapes as a tuple

In [None]:
show_video_line(x, x.shape[0])  # plot every input frame

In [None]:
show_video_line(y, y.shape[0])  # plot every target frame

In [None]:
#| hide
import nbdev

nbdev.nbdev_export()  # regenerate the exported module from this notebook