In [1]:
import os

import torch

from cs330_project.datasets.video_data import TinyVIRAT
from cs330_project.models import ViTAutoEncoder
from cs330_project.datasets.data_loading import MaskedVideoAutoencoderTransform, TransformDataset, DataLoader
from cs330_project.training import train_mae_single_epoch, make_optimizer, make_scheduler
from cs330_project.losses import autoencoder_loss

# Use the CUDA device when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Clip / patch geometry shared by the dataset, transform and model cells below.
img_size = (32, 32)         # spatial size passed as `input_size` / `in_img_size`
num_frames = 16             # frames per clip (dataset `new_length`, model `in_num_frames`)
num_channels = 3            # model `in_channels`
patch_size = 4              # model `patch_size`
tubelet_size = 2            # model `tubelet_size`
tublet_size = tubelet_size  # legacy misspelled alias kept so older cells keep working
sampling_rate = 4           # dataset `new_step` between sampled frames

# Seed torch's global RNG so weight init (and torch-based masking, if any) is repeatable.
seed = 0
torch.manual_seed(seed)

In [3]:
# Location of the preprocessed TinyVIRAT data. To point elsewhere (e.g. an alternative
# copy such as D:\tiny_virat_composite_dataset), set the TINY_VIRAT_ROOT environment
# variable rather than editing this cell. The fallback is the original author's local path.
root_dir = os.environ.get(
    "TINY_VIRAT_ROOT",
    r"C:\Users\Windows\Desktop\Shahir\cs330-final-project-2022\resources\tiny_virat_processed")
if not os.path.isdir(root_dir):
    # Fail here with an actionable message rather than deep inside the dataset constructor.
    raise FileNotFoundError(
        f"TinyVIRAT root directory not found: {root_dir!r}; "
        "set the TINY_VIRAT_ROOT environment variable to the dataset location")

# Raw (untransformed) training split: `num_frames` frames per clip, sampled every
# `sampling_rate` frames, without temporal jitter.
dataset_train_orig = TinyVIRAT(
    root_dir=root_dir,
    train=True,
    new_length=num_frames,
    new_step=sampling_rate,
    temporal_jitter=False,
    verbose=False)

In [4]:
# ViT autoencoder over spatio-temporal patches. The decoder (48-dim, depth 3) is
# smaller than the encoder (96-dim, depth 9).
# NOTE(review): is_spt / is_lsa presumably toggle shifted-patch tokenization and
# locality self-attention; confirm against ViTAutoEncoder.
model = ViTAutoEncoder(
    in_img_size=img_size,
    in_channels=num_channels,
    patch_size=patch_size,
    spatio_temporal=True,
    # Use the config-cell constant; this was a hard-coded 2 that ignored the config.
    tubelet_size=tublet_size,
    in_num_frames=num_frames,
    encoder_embed_dim=96,
    encoder_depth=9,
    encoder_num_heads=12,
    decoder_embed_dim=48,
    decoder_depth=3,
    decoder_num_heads=16,
    mlp_dim_ratio=2,
    head_dim=16,
    class_embed=True,
    is_spt=True,
    is_lsa=False).to(device)
# Moving the model to `device` here is harmless if train_mae_single_epoch also moves it.
# The optimizer is built in a later cell, so it sees the on-device parameters.

In [5]:
# The masking transform is told how many patches the encoder produces
# (presumably so it can size its patch mask).
num_encoder_patches = model.encoder.num_patches
train_transform = MaskedVideoAutoencoderTransform(
    input_size=img_size, num_patches=num_encoder_patches)

# Wrap the raw labeled clips so each sample goes through `train_transform`.
dataset_train = TransformDataset(
    dataset_train_orig, labeled=True, transform_func=train_transform)

In [6]:
# Optimizer over the model's parameters, and a learning-rate schedule bound to it.
# NOTE(review): `scheduler` is never referenced by a later cell in this notebook.
# Confirm whether training should pass or step it.
optimizer = make_optimizer(
    model)
scheduler = make_scheduler(
    optimizer)

In [7]:
# Training loader over the masked-clip dataset.
# - shuffle=True: without it, every epoch visits clips in dataset order.
# - num_workers: capped at the CPU count (each worker is a separate process; on
#   Windows they are spawned, so oversubscribing slows start-up noticeably).
# - pin_memory: only useful when batches are copied to a CUDA device.
# NOTE(review): assumes this DataLoader accepts torch.utils.data.DataLoader's
# keyword arguments (it already uses pin_memory / prefetch_factor); confirm.
dataloader_train = DataLoader(
    dataset_train,
    batch_size=20,
    shuffle=True,
    num_workers=min(20, os.cpu_count() or 1),
    pin_memory=device.type == "cuda",
    prefetch_factor=10)

In [None]:
# One pass over the training loader. Whatever the helper returns (if not None) is
# shown as this cell's output.
# NOTE(review): the learning-rate `scheduler` is neither passed here nor stepped
# afterwards. Confirm whether that is intended.
epoch_output = train_mae_single_epoch(
    model, autoencoder_loss, dataloader_train, optimizer, device)
epoch_output

  0%|          | 0/384 [00:39<?, ?it/s]