A Colab version of this notebook, set up for training on a T4 GPU, is available here:
[https://colab.research.google.com/drive/1TRE6DIhLm5P6k1bdFxXlRpOvzfalH4zC#scrollTo=xfosNdAaoWeH](https://colab.research.google.com/drive/1TRE6DIhLm5P6k1bdFxXlRpOvzfalH4zC#scrollTo=xfosNdAaoWeH)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset

import os

import matplotlib.pyplot as plt
import matplotlib.animation as animation

import av

from tqdm import tqdm

from IPython.display import HTML

from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
from bayesian_torch.layers.variational_layers.linear_variational import LinearReparameterization


### Load Dataset

In [3]:
from utils.dataset import load_msasl

# Number of most-frequent labels to keep; the loader log reports
# "top 100 labels" for this value. Also reused later as the classifier's
# output width.
label_threshold = 100

# NOTE(review): 'bin' is passed through unchanged to load_msasl; its meaning
# (presumably a data directory or storage format) is not visible here — confirm.
msasl_splits = load_msasl('bin', label_threshold)

# NOTE(review): the unpack order is (test, train, validation) — confirm it
# matches the order load_msasl actually returns.
test_dataset, train_dataset, validation_dataset = msasl_splits

[TRAIN] Loaded 3012 videos with top 100 labels
[TEST] Loaded 458 videos with top 100 labels
[VALIDATION] Loaded 815 videos with top 100 labels


In [4]:

from torch.utils.data import DataLoader

# Loader settings shared by every split, kept in one place so the three
# loaders cannot silently drift apart.
BATCH_SIZE = 8
NUM_WORKERS = 4

# Training batches are reshuffled on every pass over the data.
data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Evaluation splits are read in a fixed order: shuffling does not change the
# accuracy and only makes the evaluation order (and the preview below)
# non-reproducible between runs.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Sanity check: print the first batch of the training loader, then of the
# validation loader. `videos`, `labels` and `metadata` stay bound to the
# validation batch afterwards, as before.
for loader in (data_loader, validation_loader):
    for videos, labels, metadata in loader:
        print(f"Batch of videos: {videos.shape}") # (B, T, C, H, W)
        print(f"Batch of labels: {labels.shape}") # (batch_size,)
        print(f"Metadata sample: {metadata}") # Dictionary of metadata
        print(labels)
        break # Checking the first batch only

Batch of videos: torch.Size([8, 64, 3, 224, 224])
Batch of labels: torch.Size([8])
Metadata sample: {'id': ['4e27d655-a1c3-4e18-88a7-66d54a79faa8', 'b69acd18-b57e-4ecd-b34c-58c41350d2ca', 'a8a79544-1d0d-4200-a1c6-bdc96bcc1faf', '043f4248-8c6b-417c-8802-cc46db558491', 'edfd9c7b-5332-402d-bd37-c92fcf754118', '44bfe295-d98a-4752-900e-4d7742817267', '7f605d99-a586-4f22-96b1-c8dee1dc9fbc', '72610c19-ef0d-412d-bb81-d47e911fa7a9'], 'org_text': ['pink ', 'FAMILY', 'learn', 'boy cousin(male location)', 'FOOD', 'SICK', 'what', 'wish'], 'clean_text': ['pink', 'family', 'learn', 'cousin', 'eat', 'sick', 'what', 'wish'], 'signer_id': tensor([144,  72,  40, 107, 124,  77,  40,  36]), 'signer': tensor([-1, 23, -1,  2, -1, -1, -1, 89]), 'file': ['pink - ASL sign for pink', 'Unit 08 Vocabulary', 'ActionsVerbs in school ASL', 'Unit 4 vocabulary pg 1-3', 'Major Exports Vocabulary  ASL - American Sign Language', 'Unit 6 Vocabularymp4', 'Robber and the Geek Vocab video for ASL class', 'wish'], 'label': ten

In [5]:
# Optional visual sanity check, left disabled: uncomment to call
# show_video on test sample 0.
# test_dataset.show_video(0)

### Print Training Device

In [6]:
# Prefer the first CUDA device and fall back to the CPU when CUDA is unavailable.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if has_cuda else 'cpu')

if not has_cuda:
    # CPU-only run: just show which device was selected.
    print(device)
else:
    # Report how many GPUs are visible and the properties of device 0.
    print(f'num of GPU: {torch.cuda.device_count()}')
    print(torch.cuda.get_device_properties(0))

num of GPU: 1
_CudaDeviceProperties(name='NVIDIA GeForce RTX 3060 Laptop GPU', major=8, minor=6, total_memory=6143MB, multi_processor_count=30, uuid=a43e09d2-abbb-44a0-a8cb-9ebfcebe6d64, L2_cache_size=3MB)


### Load Model Architecture

In [13]:
# ResNet18_GRU is a project-local model class.
# NOTE(review): the module file is spelled "RetNet18_GRU" while the class is
# "ResNet18_GRU"; the import path must keep matching the file name on disk.
from model.RetNet18_GRU import ResNet18_GRU

# Build the network with its default constructor arguments. Its `fc` head is
# replaced by a Bayesian layer in the next section.
cnn_gru_model = ResNet18_GRU()

### Convert FC Layer to Bayesian

In [15]:
# Swap the model's final fully connected layer for a variational (Bayesian) one.
ori_in_features = cnn_gru_model.fc.in_features
# Recorded for reference only: the new head is sized by label_threshold,
# not by the original head's output width.
ori_out_features = cnn_gru_model.fc.out_features

bayesian_fc_kwargs = dict(
    in_features=ori_in_features,
    out_features=label_threshold,
    prior_mean=0,
    prior_variance=1,
    posterior_mu_init=0,
    posterior_rho_init=-3.0,
    bias=True,
)
bayesian_fc = LinearReparameterization(**bayesian_fc_kwargs)

# NOTE(review): in bayesian-torch this flag is expected to make forward()
# return only the layer output instead of an (output, kl) tuple, with the KL
# term collected separately through get_kl_loss — confirm against the
# installed library version.
bayesian_fc.dnn_to_bnn_flag = True

cnn_gru_model.fc = bayesian_fc

# Move the whole model (including the new head) to the selected device; as the
# cell's last expression this also displays the module tree.
cnn_gru_model.to(device)

ResNet18_GRU(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

### Train Model

In [None]:
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
from datetime import datetime

# Classification term of the objective; the KL term of the Bayesian head is
# added per batch inside the loop.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters at a constant learning rate.
optimizer = optim.Adam(cnn_gru_model.parameters(), lr=0.001)

# The StepLR scheduler is constructed but deliberately NOT stepped (see the
# commented-out call at the end of the epoch loop), so training currently runs
# at a fixed lr=0.001. Re-enabling `scheduler.step()` would multiply the rate
# by 0.1 every 3 epochs.
scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

num_epochs = 12  # Adjust based on performance

# Divisor for the KL term. Read from the loader so it cannot drift from the
# batch size actually used; falls back to 8 if the loader has no fixed batch
# size (e.g. it was built with a batch_sampler).
batch_size = data_loader.batch_size or 8

# Directory to save models and results
output_dir = "model_outputs"
os.makedirs(output_dir, exist_ok=True)


def check_nan(tensor, name):
    """Print a warning if ``tensor`` contains any NaN or infinite value.

    Args:
        tensor: Tensor to inspect (a parameter or a gradient).
        name: Label included in the warning message.
    """
    # A single isfinite pass flags both NaN and +/-Inf entries.
    if not torch.isfinite(tensor).all():
        print(f"⚠️ NaN or Inf detected in {name}!")


results_file = os.path.join(output_dir, "training_metrics.txt")

for epoch in range(num_epochs):
    cnn_gru_model.train()  # Set model to training mode
    total_loss = 0.0
    correct_train = 0
    total_train = 0

    train_loader_tqdm = tqdm(data_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    for videos, labels, metadata in train_loader_tqdm:
        # Same preprocessing as the validation/test loops: float pixels
        # divided by 255 and int64 labels, as CrossEntropyLoss expects for
        # class indices.
        videos = videos.to(device).float() / 255.0
        labels = labels.to(device).long()

        # Reset gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = cnn_gru_model(videos)

        # Diagnostic only: warn if any parameter has become NaN/Inf
        for name, param in cnn_gru_model.named_parameters():
            check_nan(param, f"Param {name}")

        # Loss = cross-entropy + KL of the Bayesian layer(s) scaled by the
        # batch size.
        kl = get_kl_loss(cnn_gru_model)
        loss = criterion(outputs, labels) + kl / batch_size

        # Backpropagation
        loss.backward()

        # Diagnostic only: warn if any gradient is NaN/Inf before clipping
        for name, param in cnn_gru_model.named_parameters():
            if param.grad is not None:
                check_nan(param.grad, f"Grad {name}")

        # Gradient clipping: cap the global gradient norm at 5
        torch.nn.utils.clip_grad_norm_(cnn_gru_model.parameters(), max_norm=5)

        # Update model weights
        optimizer.step()

        # Read the scalar loss once and reuse it for the running total and
        # the progress bar.
        loss_value = loss.item()
        total_loss += loss_value
        predicted = outputs.argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()
        train_loader_tqdm.set_postfix(loss=loss_value)  # Update tqdm display

    avg_loss = total_loss / len(data_loader)
    train_accuracy = 100 * correct_train / total_train
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

    # Validation step
    # NOTE(review): the Bayesian head presumably still samples its weights in
    # eval mode, so this accuracy comes from a single stochastic forward pass
    # per batch — confirm.
    cnn_gru_model.eval()
    correct_val, total_val = 0, 0

    # Wrap validation loader with tqdm for validation progress
    val_loader_tqdm = tqdm(validation_loader, desc="Validating", leave=False)

    with torch.no_grad():
        for videos, labels, metadata in val_loader_tqdm:
            videos = videos.to(device).float() / 255.0 # Convert videos to float32
            labels = labels.to(device).long()   # Convert labels to long
            outputs = cnn_gru_model(videos)
            predicted_val = outputs.argmax(dim=1)
            correct_val += (predicted_val == labels).sum().item()
            total_val += labels.size(0)

    val_accuracy = 100 * correct_val / total_val
    print(f"Validation Accuracy: {val_accuracy:.2f}%")

    # Save the model after each epoch; the timestamp keeps checkpoints from
    # different runs from overwriting each other.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
    model_name = f"cnn_gru_epoch_{epoch+1}_{timestamp}.pth"
    model_path = os.path.join(output_dir, model_name)
    torch.save(cnn_gru_model.state_dict(), model_path)
    print(f"Model saved to: {model_path}")

    # Append this epoch's metrics to the results file
    with open(results_file, "a") as f:
        f.write(f"Epoch: {epoch+1}, Loss: {avg_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Validation Accuracy: {val_accuracy:.2f}%, Model Name: {model_name}\n")

    # Learning-rate decay is intentionally disabled; uncomment to enable StepLR.
    # scheduler.step()

print("Training finished. Results saved to:", results_file)

Epoch 1/12:   0%|          | 0/377 [01:11<?, ?it/s]


KeyboardInterrupt: 

### Test set evaluation

In [None]:
# Score the trained model once on the held-out test split and append the
# result to the metrics file written during training.
cnn_gru_model.eval()
correct_test = 0
total_test = 0
test_loader_tqdm = tqdm(test_loader, desc="Testing", leave=False)

with torch.no_grad():
    for videos, labels, metadata in test_loader_tqdm:
        # Same preprocessing as training/validation: float pixels divided
        # by 255 and int64 labels on the model's device.
        videos = videos.to(device).float().div(255.0)
        labels = labels.to(device).long()

        predicted_test = cnn_gru_model(videos).argmax(dim=1)

        correct_test += predicted_test.eq(labels).sum().item()
        total_test += labels.shape[0]

test_accuracy = 100 * correct_test / total_test
print(f"Test Accuracy: {test_accuracy:.2f}%")

# `results_file` is defined in the training cell; that cell must have run first.
with open(results_file, "a") as f:
    f.write(f"Final Test Accuracy: {test_accuracy:.2f}%\n")