A Colab version of this notebook, set up for training on a T4 GPU, is available here:
[https://colab.research.google.com/drive/1TRE6DIhLm5P6k1bdFxXlRpOvzfalH4zC#scrollTo=xfosNdAaoWeH](https://colab.research.google.com/drive/1TRE6DIhLm5P6k1bdFxXlRpOvzfalH4zC#scrollTo=xfosNdAaoWeH)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt
import matplotlib.animation as animation

import av

from tqdm import tqdm

from IPython.display import HTML

from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
from bayesian_torch.layers.variational_layers.linear_variational import LinearReparameterization


### Load Dataset

In [2]:
from utils.dataset import load_msasl

# Number of labels to keep; with 100 the loader reports "top 100 labels" per split.
label_threshold = 100

# load_msasl's three results are bound, in order, to the test, train and
# validation datasets.
(
    test_dataset,
    train_dataset,
    validation_dataset,
) = load_msasl('bin', label_threshold)

[TRAIN] Loaded 3012 videos with top 100 labels
[TEST] Loaded 458 videos with top 100 labels
[VALIDATION] Loaded 815 videos with top 100 labels


In [3]:

def preview_first_batch(loader):
    """Print shapes, metadata and labels of the first batch yielded by `loader`.

    Each batch is expected to unpack into (videos, labels, metadata).
    Nothing is printed if the loader yields no batches.
    """
    for videos, labels, metadata in loader:
        print(f"Batch of videos: {videos.shape}") # (batch_size, frames, C, H, W)
        print(f"Batch of labels: {labels.shape}") # (batch_size,)
        print(f"Metadata sample: {metadata}") # Dictionary of metadata
        print(labels)
        break # Only the first batch is inspected


# Training batches are reshuffled every epoch.
data_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
preview_first_batch(data_loader)

# Evaluation loaders are not shuffled: ordering does not affect the metrics,
# and a fixed order keeps evaluation runs reproducible.
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)
validation_loader = DataLoader(validation_dataset, batch_size=8, shuffle=False, num_workers=4)
preview_first_batch(validation_loader)

Batch of videos: torch.Size([8, 64, 3, 224, 224])
Batch of labels: torch.Size([8])
Metadata sample: {'id': ['5c4907d5-07f9-4dd5-8d76-6e42e8d0f2e7', '8e933be6-40d2-4cce-b9e4-22751b1a6de4', '87926ff3-a734-4fa1-9369-5f327e1d7d8e', '39eba28a-fae8-4262-b34d-9112c079e2d7', '4f203dfd-6b08-49d4-8a32-adc29618418a', '59b45547-7016-4078-a1f5-982f65238b9a', '847a72a9-fef6-4cb6-853e-0dbf2799ae46', '132a4836-bca7-4146-b4e7-10425c95267f'], 'org_text': ['Friend', 'repeat/AGAIN', 'WATER', 'TABLE', 'TIRED', 'Brother', 'Cat', 'milk'], 'clean_text': ['friend', 'again', 'water', 'table', 'tired', 'brother', 'cat', 'milk'], 'signer_id': tensor([450,   0,  77,  32, 349,  32,  12, 152]), 'signer': tensor([55,  0, -1, 17, 10, 17,  6, 98]), 'file': ['ASL ABC Song  NEW with ASL Letters and Signs', 'repeatAGAIN', 'Unit 11 Vocabularymp4', 'ASL 1 Unit 3 Vocabulary', 'ASL 1st class VOCABULARY BUILDER', 'ASL 1 Unit 4 Vocabulary', 'Mastering ASL Unit 4 Vocabulary signed by Dr Wooten', 'Milk - Asl'], 'label': tensor([ 

In [4]:
# test_dataset.show_video(0)

### Select and Print the Training Device

In [5]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if device.type != 'cuda':
    # CPU fallback: just show which device was selected.
    print(device)
else:
    # Report how many GPUs are visible and the properties of GPU 0.
    print(f'num of GPU: {torch.cuda.device_count()}')
    print(torch.cuda.get_device_properties(0))

num of GPU: 1
_CudaDeviceProperties(name='NVIDIA GeForce RTX 3060 Laptop GPU', major=8, minor=6, total_memory=6143MB, multi_processor_count=30, uuid=a43e09d2-abbb-44a0-a8cb-9ebfcebe6d64, L2_cache_size=3MB)


### Load Model Architecture

In [6]:
from model.RetNet18_GRU import ResNet18_GRU

# Build the network, then move its parameters and buffers to the selected device.
cnn_gru_model = ResNet18_GRU()
cnn_gru_model = cnn_gru_model.to(device)

# Bare last expression so the notebook displays the module structure.
cnn_gru_model

ResNet18_GRU(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

### Train Model

In [7]:
import torch.optim.lr_scheduler as lr_scheduler

# Loss for multi-class classification on the model's output logits.
criterion = nn.CrossEntropyLoss()

# Adam optimizer; its learning rate is decayed by the scheduler below.
optimizer = optim.Adam(cnn_gru_model.parameters(), lr=0.001)

# Multiply the learning rate by 0.1 every 3 epochs (stepped once per epoch).
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_epochs = 12  # Adjust based on performance


def check_nan(tensor, name):
    """Print a warning if `tensor` contains any NaN or Inf value.

    Args:
        tensor: Tensor to inspect.
        name: Label used in the warning message.
    """
    if torch.isnan(tensor).any() or torch.isinf(tensor).any():
        print(f"⚠️ NaN or Inf detected in {name}!")


def prepare_batch(videos, labels):
    """Move a batch to `device` with the preprocessing shared by train and eval.

    Videos are cast to float32 and divided by 255 (assumes the dataset yields
    pixel values in [0, 255] -- TODO confirm against the dataset loader);
    labels are cast to long, as CrossEntropyLoss expects for class indices.

    Returns:
        Tuple of (videos, labels) on `device`.
    """
    videos = videos.to(device).float() / 255.0
    labels = labels.to(device).long()
    return videos, labels


for epoch in range(num_epochs):
    cnn_gru_model.train()  # Set model to training mode
    total_loss = 0

    train_loader = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    for videos, labels, metadata in train_loader:
        videos, labels = prepare_batch(videos, labels)

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = cnn_gru_model(videos)

        # Check the weights for NaN/Inf values
        for name, param in cnn_gru_model.named_parameters():
            check_nan(param, f"Param {name}")

        # Compute loss
        loss = criterion(outputs, labels)

        # Backpropagation
        loss.backward()

        # Check the gradients for NaN/Inf values before they are clipped
        for name, param in cnn_gru_model.named_parameters():
            if param.grad is not None:
                check_nan(param.grad, f"Grad {name}")

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(cnn_gru_model.parameters(), max_norm=5)

        # Update model weights
        optimizer.step()

        # Read the scalar loss once; each .item() call forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        train_loader.set_postfix(loss=batch_loss)  # Update tqdm display

    # Advance the learning-rate schedule once per epoch, after the optimizer steps.
    scheduler.step()

    avg_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Validation step
    cnn_gru_model.eval()
    correct, total = 0, 0

    # Wrap validation loader with tqdm for validation progress
    val_loader = tqdm(validation_loader, desc="Validating", leave=False)

    with torch.no_grad():
        # Iterate the validation split (not the test split), using the same
        # preprocessing as training so the model sees the same input scale.
        for videos, labels, metadata in val_loader:
            videos, labels = prepare_batch(videos, labels)
            outputs = cnn_gru_model(videos)
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # Guard against an empty validation loader.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Validation Accuracy: {accuracy:.2f}%")

Epoch 1/12:   0%|          | 0/377 [00:39<?, ?it/s]


KeyboardInterrupt: 