In [1]:
import os
import cv2
import json
import torch
import numpy as np
import pandas as pd 
from torch.utils.data import Dataset

In [2]:
# The ASL-Citizen split CSVs live under datasets/ASL-Citizen/ and are read in
# train / test / val order.
dataset_root = os.path.join('datasets', 'ASL-Citizen')
ttv = ['train.csv', 'test.csv', 'val.csv']
split_paths = [os.path.join(dataset_root, csv_name) for csv_name in ttv]
train, test, val = map(pd.read_csv, split_paths)

In [3]:
train.head(n=5)  # preview the first rows of the training split

Unnamed: 0,Participant ID,Video file,Gloss,ASL-LEX Code,label,fpath,prep_fpath,kpts_fpath
0,P22,15252109051698337-NOON.mp4,NOON,G_02_014,66,datasets\ASL-Citizen\top100_videos\15252109051...,datasets\ASL-Citizen\preprocess_videos\1525210...,datasets\ASL-Citizen\keypoints\152521090516983...
1,P22,45069896884439653-SAME 2.mp4,SAME,B_02_013,73,datasets\ASL-Citizen\top100_videos\45069896884...,datasets\ASL-Citizen\preprocess_videos\4506989...,datasets\ASL-Citizen\keypoints\450698968844396...
2,P35,9547038026063932-TEXT.mp4,TEXT,B_03_054,93,datasets\ASL-Citizen\top100_videos\95470380260...,datasets\ASL-Citizen\preprocess_videos\9547038...,datasets\ASL-Citizen\keypoints\954703802606393...
3,P27,09898660662683256-BOWL.mp4,BOWL,C_03_049,10,datasets\ASL-Citizen\top100_videos\09898660662...,datasets\ASL-Citizen\preprocess_videos\0989866...,datasets\ASL-Citizen\keypoints\098986606626832...
4,P47,13651840403204663-ELEVATOR.mp4,ELEVATOR,G_03_036,48,datasets\ASL-Citizen\top100_videos\13651840403...,datasets\ASL-Citizen\preprocess_videos\1365184...,datasets\ASL-Citizen\keypoints\136518404032046...


In [23]:
from data_loader import MSASLPreProcessedVideoDataset, plot_video_gif

In [None]:
def _make_split_ds(split_df):
    """Wrap one split's preprocessed-video paths and labels in a dataset.

    Reads the ``prep_fpath`` and ``label`` columns of ``split_df`` and passes
    them, as lists, to ``MSASLPreProcessedVideoDataset``.
    """
    video_paths = split_df.prep_fpath.to_list()
    video_labels = split_df.label.to_list()
    return MSASLPreProcessedVideoDataset(video_paths, video_labels)


train_ds = _make_split_ds(train)
valid_ds = _make_split_ds(val)

In [27]:
SAMPLE_IDX = 100  # arbitrary training example to visualise
# Item 0 of a dataset entry is what gets plotted (presumably the video tensor;
# NOTE(review): confirm against MSASLPreProcessedVideoDataset.__getitem__).
sample_video = train_ds[SAMPLE_IDX][0]
# The second argument's meaning is defined by data_loader.plot_video_gif
# (fps or frame count?) -- TODO confirm.
plot_video_gif(sample_video, 8)

In [28]:
from transformers import AutoConfig, AutoModel, AutoFeatureExtractor
import torch.nn as nn
import torch
from fastai.vision.all import *

class ViViTWrapper(nn.Module):
    """Video classifier built from a pretrained Hugging Face ViViT backbone and a small MLP head.

    Almost all of the backbone is frozen. Only its final encoder block(s) and
    the newly added head receive gradients.

    Args:
        model_name: Hugging Face model id (or local path) of the ViViT checkpoint.
        num_classes: number of logits produced by the head.
        freeze_final_blocks: despite its name, this chooses how many of the
            final encoder blocks stay *trainable*. True keeps the last two
            blocks trainable; False keeps only the last block trainable.
            Every other backbone parameter (embeddings, earlier blocks, final
            layernorm, pooler) has ``requires_grad = False``.
    """

    def __init__(self, model_name='google/vivit-b-16x2-kinetics400', num_classes=100, freeze_final_blocks=False):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name, num_labels=num_classes)
        self.base_model = AutoModel.from_pretrained(model_name, config=self.config)

        # Derive the trainable block indices from the config rather than
        # hard-coding 10/11, so backbones of other depths also work. For the
        # 12-layer base model this is exactly blocks {11} or {10, 11}.
        num_layers = self.config.num_hidden_layers
        num_trainable = 2 if freeze_final_blocks else 1
        trainable_markers = tuple(
            # The trailing '.' stops e.g. 'encoder.layer.1' from matching layer 10/11.
            f'encoder.layer.{i}.'
            for i in range(max(num_layers - num_trainable, 0), num_layers)
        )
        for name, param in self.base_model.named_parameters():
            param.requires_grad = any(marker in name for marker in trainable_markers)

        # Classifier head on the [CLS] token embedding.
        self.head = nn.Sequential(
            nn.Linear(self.base_model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (B, num_classes).

        Args:
            x: pixel values. Hugging Face ViViT documents the input layout as
                (batch, num_frames, channels, height, width), not
                (B, C, T, H, W). NOTE(review): confirm that the dataset emits
                frames in that order, and with the frame count and resolution
                the checkpoint expects.
        """
        out = self.base_model(pixel_values=x)
        # Use the layer-normed [CLS] token, as ViViT's own classification head
        # does. The Kinetics checkpoint was saved from VivitForVideoClassification,
        # whose backbone has no pooler, so ``pooler_output`` would come from a
        # freshly initialised pooler. Because that pooler is frozen above, it
        # would never be trained.
        cls_embedding = out.last_hidden_state[:, 0]
        return self.head(cls_embedding)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from fastai.data.core import DataLoaders


# DataLoaders.from_dsets already defaults to shuffle=(True, False) and
# drop_last=(True, False) for (train, valid). Passing shuffle=True explicitly
# would broadcast it to *both* loaders and shuffle the validation set as well,
# so the default is kept.
dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=8, num_workers=4)

In [None]:
N_EPOCHS = 5
MAX_LR = 1e-3

# ViViTWrapper already sets requires_grad on the backbone. fastai's default
# splitter (`trainable_params`) hands only the trainable parameters to the
# optimizer, as a single parameter group. `learn.freeze()` (i.e. freeze_to(-1))
# would therefore freeze nothing, so it is intentionally not called.
learn = Learner(
    dls,
    ViViTWrapper(),
    loss_func=CrossEntropyLossFlat(),
    metrics=accuracy,
)

learn.fit_one_cycle(N_EPOCHS, MAX_LR)
