In [1]:
import os

import torch
from torch import nn

from cs330_project.datasets.video_data import TinyVIRAT
from cs330_project.models import ViTAutoEncoder, ViTClassifier
from cs330_project.datasets.data_loading import MaskedVideoAutoencoderTransform, VideoAugmentTransform, TransformDataset, DataLoader
from cs330_project.training import train_classifier_model, make_optimizer, make_scheduler
from cs330_project.utils import get_rel_pkg_path

# Run on the first visible CUDA GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Seed torch's RNGs (on all devices) so model initialization and any torch-based
# randomness in training are reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Input clip geometry: `num_frames` frames of `img_size` pixels with
# `num_channels` channels. `sampling_rate` is passed to TinyVIRAT as `new_step`
# (presumably the stride between sampled source frames — confirm).
img_size = (32, 32)
num_frames = 16
num_channels = 3
sampling_rate = 4
# Tokenization passed to the ViT models: presumably patch_size x patch_size pixels
# spatially and tubelet_size frames temporally, i.e. (32/8)**2 * (16/4) = 64 tokens.
patch_size = 8
tubelet_size = 4

In [3]:
# Root of the processed TinyVIRAT data. Set the TINY_VIRAT_ROOT environment
# variable to point elsewhere instead of editing this cell; the default below is
# the original author's local path. (An alternative location previously noted
# here: r"D:\tiny_virat_composite_dataset".)
root_dir = os.environ.get(
    "TINY_VIRAT_ROOT",
    r"C:\Users\Windows\Desktop\Shahir\cs330-final-project-2022\resources\tiny_virat_processed")

# Clip-sampling options shared by both splits; only the `train` flag differs.
tiny_virat_kwargs = dict(
    root_dir=root_dir,
    new_length=num_frames,
    new_step=sampling_rate,
    temporal_jitter=False,
    verbose=False)
dataset_train_orig = TinyVIRAT(train=True, **tiny_virat_kwargs)
dataset_test_orig = TinyVIRAT(train=False, **tiny_virat_kwargs)

In [4]:
# Pretrained autoencoder checkpoint used to initialize the classifier's encoder.
# Set VMAE_WEIGHTS_PATH to use a different file instead of editing the hardcoded
# default (the original author's local path).
weights_fname = os.environ.get(
    "VMAE_WEIGHTS_PATH",
    "C:\\Users\\Windows\\Desktop\\Shahir\\cs330-final-project-2022\\weights\\Experiment 12-13-2022 01-33-21 AM\\Weights Latest.pckl")

In [5]:
# Rebuild the video autoencoder whose pretrained encoder initializes the classifier.
vmae_model = ViTAutoEncoder(
    in_img_size=img_size,
    in_channels=num_channels,
    patch_size=patch_size,
    spatio_temporal=True,
    tubelet_size=tubelet_size,
    in_num_frames=num_frames,
    encoder_embed_dim=96,
    encoder_depth=6,
    encoder_num_heads=8,
    decoder_embed_dim=48,
    decoder_depth=3,
    decoder_num_heads=8,
    mlp_dim_ratio=2,
    head_dim=16,
    class_embed=True,
    is_spt=True,
    is_lsa=False,
    use_masking=True)
vmae_model = vmae_model.to(device)
# map_location restores the checkpoint's tensors onto `device`, so a checkpoint
# saved from a CUDA run also loads on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints you trust.
# NOTE(review): the saved output of this cell is a RuntimeError. The checkpoint has
# no "mask_token"/"decoder.mask_token" entries, and its decoder.pos_embedding is
# [1, 65, 48] versus [1, 66, 48] in this model. It was presumably saved from a
# model built with different arguments (use_masking is one candidate) — TODO
# confirm the training-time configuration before relying on this cell.
vmae_model.load_state_dict(torch.load(weights_fname, map_location=device))

RuntimeError: Error(s) in loading state_dict for ViTAutoEncoder:
	Missing key(s) in state_dict: "mask_token", "decoder.mask_token". 
	size mismatch for decoder.pos_embedding: copying a param with shape torch.Size([1, 65, 48]) from checkpoint, the shape in current model is torch.Size([1, 66, 48]).

In [None]:
# Video classifier whose encoder arguments mirror `vmae_model`'s. The strict
# load_state_dict at the end raises if the two encoders' parameters disagree.
# NOTE(review): num_classes=26 is presumably the TinyVIRAT label count — confirm.
model = ViTClassifier(
    # input clip geometry
    in_img_size=img_size,
    in_num_frames=num_frames,
    in_channels=num_channels,
    # spatio-temporal patch tokenization
    spatio_temporal=True,
    patch_size=patch_size,
    tubelet_size=tubelet_size,
    # encoder size: must match the pretrained autoencoder's encoder
    encoder_embed_dim=96,
    encoder_depth=6,
    encoder_num_heads=8,
    mlp_dim_ratio=2,
    head_dim=16,
    # architecture switches (same values as the autoencoder)
    class_embed=True,
    is_spt=True,
    is_lsa=False,
    use_masking=True,
    # output head
    num_classes=26,
).to(device)
# Initialize the classifier's encoder from the pretrained autoencoder's encoder.
model.encoder.load_state_dict(vmae_model.encoder.state_dict())

In [None]:
# Both splits reuse the autoencoder's clip transform with mask_ratio=0.0
# (presumably: no tokens masked, so the classifier sees whole clips).
# The test split forces a center crop. The train split keeps the transform's
# default crop_type (presumably a random crop — TODO confirm).
train_transform = MaskedVideoAutoencoderTransform(
    input_size=img_size, num_patches=model.encoder.num_patches, mask_ratio=0.0)
test_transform = MaskedVideoAutoencoderTransform(
    input_size=img_size, num_patches=model.encoder.num_patches,
    crop_type=VideoAugmentTransform.CROP_TYPE_CENTER, mask_ratio=0.0)

# Attach the transforms to the raw TinyVIRAT splits, keeping labels.
dataset_train = TransformDataset(dataset_train_orig, labeled=True, transform_func=train_transform)
dataset_test = TransformDataset(dataset_test_orig, labeled=True, transform_func=test_transform)

In [None]:
# Optimizer and LR scheduler from the project's training helpers, called with
# their default settings.
# NOTE(review): `scheduler` is not passed to train_classifier_model(...) in this
# notebook. Unless that function reaches it some other way, it is never stepped —
# TODO confirm whether the training loop should receive it.
optimizer = make_optimizer(model)
scheduler = make_scheduler(optimizer)

In [None]:
# Loader settings shared by both splits (per-loader worker pool, pinned host
# memory for faster host-to-GPU copies, workers kept alive between epochs).
loader_kwargs = dict(
    batch_size=20,
    num_workers=20,
    pin_memory=True,
    prefetch_factor=10,
    persistent_workers=True)
# Reshuffle the training clips every epoch. This assumes DataLoader is (a
# re-export of) torch.utils.data.DataLoader, whose shuffle defaults to False.
# Without it, training batches would arrive in the same fixed order every epoch.
# The test loader keeps a fixed order so evaluation is deterministic.
dataloader_train = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
dataloader_test = DataLoader(dataset_test, **loader_kwargs)

In [None]:
# Output directory handed to train_classifier_model below. It is resolved by the
# project helper get_rel_pkg_path (presumably relative to the package root — confirm).
weights_dir = get_rel_pkg_path("weights/")

In [None]:
# Fine-tune the encoder-initialized classifier with cross-entropy loss. Checkpoints
# (best and latest) and the training log are saved under weights_dir.
# NOTE(review): `scheduler` is not passed here. Unless train_classifier_model
# reaches it some other way, it is never stepped — TODO confirm.
loss_fn = nn.CrossEntropyLoss()
tracker = train_classifier_model(
    device, model,
    dataloader_train, dataloader_test,
    loss_fn, optimizer,
    weights_dir,
    num_epochs=10,
    save_model=True,
    save_latest=True,
    save_log=True)

In [None]:
# Show where this training run saved its outputs (presumably the experiment
# directory created under weights_dir — confirm).
tracker.save_dir

In [None]:
# Pull a single training batch for inspection. next(iter(...)) raises
# StopIteration on an empty loader, instead of leaving `x` undefined or
# holding a stale value from an earlier cell.
x = next(iter(dataloader_train))

In [None]:
# Shape of the second element of the sampled batch. Presumably this is the label
# tensor, since the datasets are built with labeled=True — TODO confirm the
# batch layout produced by TransformDataset.
x[1].shape