## Surfing Maneuver Recognition using R3D Transfer Learning
---


This notebook demonstrates transfer learning for video-based surfing maneuver recognition using the **R3D (3D ResNet)** model.

**Dataset**: Surf Maneuver Recognition Dataset

The dataset contains video clips of surfers performing four different maneuvers:
1. **Cutback-Frontside**: A sharp turn back towards the breaking part of the wave
2. **Take-off**: The act of catching the wave and standing up on the surfboard
3. **360**: A full rotation maneuver while on the wave
4. **Roller**: Riding on top of the breaking wave

**Model**: [R3D](https://arxiv.org/abs/1711.11248) (3D ResNet) is a video recognition model that extends the ResNet architecture by using 3D convolutions throughout the network, so every layer jointly models spatial (appearance) and temporal (motion) information in a clip. This notebook uses torchvision's 18-layer variant, `r3d_18`.
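
To make "spatiotemporal" concrete, the sketch below traces how the time (T), height (H), and width (W) of a clip shrink through `r3d_18`. It assumes torchvision's layout as we understand it: a stem convolution with kernel (3, 7, 7), stride (1, 2, 2), and padding (1, 3, 3); `layer1` keeps the size; `layer2` to `layer4` each open with a 3x3x3 convolution of stride 2 in all three dimensions. The helper names are illustrative, not part of torchvision.

```python
def conv_out(size, kernel, stride, padding):
    """Output length along one dimension of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def r3d18_feature_shapes(t, h, w):
    """Trace (T, H, W) through the r3d_18 stages (assumed torchvision layout)."""
    shapes = {}
    # Stem: kernel (3, 7, 7), stride (1, 2, 2), padding (1, 3, 3) -> no temporal downsampling
    t, h, w = conv_out(t, 3, 1, 1), conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)
    shapes["stem"] = (t, h, w)
    # Each stage opens with a 3x3x3 conv (padding 1); stride 2 halves T, H and W
    for name, stride in [("layer1", 1), ("layer2", 2), ("layer3", 2), ("layer4", 2)]:
        t, h, w = (conv_out(d, 3, stride, 1) for d in (t, h, w))
        shapes[name] = (t, h, w)
    return shapes


for name, shape in r3d18_feature_shapes(16, 112, 112).items():
    print(f"{name:>6}: T={shape[0]}, H={shape[1]}, W={shape[2]}")
```

For a 16x112x112 clip, `layer4` outputs 2x7x7 feature maps; for 64x224x224 it outputs 8x14x14. In both cases a global average pool reduces them to a 512-dimensional vector before `fc`, which is why the clip length and resolution can differ from the pretraining setting.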

**Approach**: We'll use transfer learning by loading R3D pretrained on Kinetics-400 (a large action recognition dataset with 400 classes) and fine-tune it for our surfing maneuver classification task.


### Import Libraries


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision.models.video import r3d_18, R3D_18_Weights
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import os
from pathlib import Path
import cv2
import numpy as np
import random


### Device Configuration


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # For Apple Silicon (M1/M2) Macs
print(f"Using device: {device}")


### ⚡ Performance Optimization: Cached Dataset

**Problem:** Loading and preprocessing videos on-the-fly is very slow:
- Video decoding happens for every epoch
- Frame extraction and resizing repeated unnecessarily
- With `num_workers=0` (as in the DataLoaders below), all of this runs in the main process, so the GPU waits on every batch

**Solution:** Preprocess videos **once** and cache them as `.pt` files:
1. Run `preprocess_videos_cache.py` to create cached tensors
2. Use `CachedSurfingManeuverDataset` to load preprocessed data
3. Training runs roughly **10-20x faster**, since videos are no longer decoded every epoch

**Usage:**
```bash
# First time only - preprocess and cache all videos
python preprocess_videos_cache.py
```

Then use the cached dataset below instead of the regular dataset.
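
`preprocess_videos_cache.py` itself is not shown in this notebook. As a rough, hypothetical sketch of the layout `CachedSurfingManeuverDataset` reads, assume the raw videos live under a root such as `surfing_dataset/<split>/<class_name>/` (an assumption; the raw-data paths are not given here). Each video then maps to one `.pt` file under `surfing_dataset_cache/<split>/<class_name>/`, holding a dict with a `'video'` tensor (already normalized, shaped [C, T, H, W]) and an integer `'label'`. `cache_path_for` is an illustrative helper, not part of the script.

```python
from pathlib import Path


def cache_path_for(video_path, raw_root, cache_root):
    """Map <raw_root>/<split>/<class>/<name>.<ext> to <cache_root>/<split>/<class>/<name>.pt.

    Hypothetical helper: the real preprocess_videos_cache.py may name files differently.
    """
    relative = Path(video_path).relative_to(raw_root)  # e.g. train/roller/clip_001.mp4
    return Path(cache_root) / relative.with_suffix(".pt")


out = cache_path_for("surfing_dataset/train/roller/clip_001.mp4",
                     "surfing_dataset", "surfing_dataset_cache")
print(out.as_posix())  # surfing_dataset_cache/train/roller/clip_001.pt

# A cache writer would then store one dict per clip, e.g.
#   torch.save({'video': processed_clip, 'label': class_idx}, out)
# using the keys that CachedSurfingManeuverDataset.__getitem__ reads.
```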


In [None]:
class CachedSurfingManeuverDataset(Dataset):
    """
    Fast dataset that loads pre-cached video tensors.
    
    Videos should be preprocessed using preprocess_videos_cache.py first.
    This provides 10-20x faster data loading compared to on-the-fly preprocessing.
    """
    
    def __init__(self, cache_dir):
        self.cache_dir = Path(cache_dir)
        
        # Get all cached .pt files
        self.cache_files = []
        self.labels = []
        
        # Relies on the global `class_names` list (defined in the "Class Names" cell),
        # which fixes the label index of each class folder.
        for class_idx, class_name in enumerate(class_names):
            class_cache_dir = self.cache_dir / class_name
            if class_cache_dir.exists():
                # Sort for a deterministic file order across machines and runs
                pt_files = sorted(class_cache_dir.glob('*.pt'))
                self.cache_files.extend(pt_files)
                self.labels.extend([class_idx] * len(pt_files))
        
        print(f"Found {len(self.cache_files)} cached videos across {len(class_names)} classes")
        
        if len(self.cache_files) == 0:
            print(f"\n⚠️  WARNING: No cached files found in {cache_dir}")
            print("Please run 'python preprocess_videos_cache.py' first to create the cache.")
    
    def __len__(self):
        return len(self.cache_files)
    
    def __getitem__(self, idx):
        """Load a preprocessed video tensor from cache."""
        cache_file = self.cache_files[idx]
        
        try:
            # Load cached data (already preprocessed!)
            data = torch.load(cache_file, map_location='cpu')
            video = data['video']
            label = data['label']
            
            return video, label
            
        except Exception as e:
            print(f"Error loading cached file {cache_file}: {e}")
            # Fall back to the next sample (raises RecursionError if no file can be loaded)
            return self.__getitem__((idx + 1) % len(self))


### Hyperparameters


In [None]:
# Adjust these values as needed
BATCH_SIZE = 4  # Smaller batch size for video (memory intensive)
EPOCHS = 5
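# Frames per clip. With the cached dataset, the clip length is fixed when the cache is built.
# (torchvision's Kinetics-400 R3D weights were trained on 16-frame clips.)
NUM_FRAMES = 64
# Spatial crop size used by video_transform (the pretrained weights used 112x112 crops;
# larger inputs still work because R3D ends in global average pooling)
FRAME_SIZE = 224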
LEARNING_RATE = 0.001
NUM_CLASSES = 4  # Cutback-Frontside, Take-off, 360, Roller

criterion = nn.CrossEntropyLoss()

# Set all random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Make training deterministic (may impact performance slightly)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Create generator for DataLoader
dataloader_generator = torch.Generator()
dataloader_generator.manual_seed(SEED)


### Class Names


In [None]:
class_names = [
    'cutback-frontside',
    'take-off',
    '360',
    'roller'
]


### Load Dataset

**Note:** This notebook uses cached preprocessed videos for fast training (10-20x speedup).

Before running the dataset-loading cell (after **Data Transforms** below) for the first time, preprocess the videos:
```bash
python preprocess_videos_cache.py
```

This creates `surfing_dataset_cache/` containing `train/`, `val/`, and `test/` splits, each with one subfolder of preprocessed `.pt` tensors per class.


### Custom Video Dataset

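The class below loads raw video files on the fly. It is not used by the training cells below, which read from the cache instead. It:
- Loads video files with OpenCV
- Samples `num_frames` frames uniformly across each clip
- Applies an optional transform
- Together with `video_transform`, produces the (C, T, H, W) layout R3D expects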
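
The uniform sampling step boils down to `np.linspace` over the frame range, rounded to integer indices. The standalone sketch below (the function name is illustrative; it mirrors `uniform_sample_indices` in the class) shows both cases: a long clip is subsampled evenly, while a clip shorter than the target length gets repeated indices, i.e. duplicated frames.

```python
import numpy as np


def sample_indices(num_frames, target_frames):
    """Evenly spaced frame indices from 0 to num_frames - 1 (may repeat for short clips)."""
    return np.linspace(0, num_frames - 1, target_frames).round().astype(int)


print(sample_indices(100, 8))  # [ 0 14 28 42 57 71 85 99]
print(sample_indices(5, 8))    # [0 1 1 2 2 3 3 4]
```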


In [None]:
class SurfingManeuverDataset(Dataset):
    """
    Dataset that decodes surfing maneuver videos directly from disk.
    
    Expects one subfolder per entry of `class_names` inside `data_dir`.
    """
    
    def __init__(self, data_dir, num_frames=13, transform=None):
        self.data_dir = Path(data_dir)
        self.num_frames = num_frames
        self.transform = transform
        
        # Get all video files and their labels
        self.video_paths = []
        self.labels = []
        
        # Supported video extensions (matched case-insensitively)
        video_extensions = {'.mp4', '.avi', '.mov', '.mkv'}
        
        for class_idx, class_name in enumerate(class_names):
            class_dir = self.data_dir / class_name
            if class_dir.exists():
                # Sort for a deterministic order across runs
                for video_file in sorted(class_dir.iterdir()):
                    if video_file.suffix.lower() in video_extensions:
                        self.video_paths.append(video_file)
                        self.labels.append(class_idx)
        
        print(f"Found {len(self.video_paths)} videos across {len(class_names)} classes")
        
    def __len__(self):
        return len(self.video_paths)
    
    def uniform_sample_indices(self, num_frames, target_frames):
        """
        Uniformly sample frame indices from a video.
        
        Args:
            num_frames: Total number of frames in the video
            target_frames: Number of frames to sample
            
        Returns:
            Array of frame indices
        """
        # Evenly spaced indices from the first to the last frame. If the video has
        # fewer frames than requested, some indices repeat (frames are duplicated).
        return np.linspace(0, num_frames - 1, target_frames).round().astype(int)
    
    def load_video(self, video_path):
        """
        Load a video and sample frames uniformly.
        Returns: uint8 tensor of shape (T, H, W, C) with values in [0, 255].
        This is the layout `video_transform` expects (torchvision's official
        VideoClassification preset expects (T, C, H, W) instead).
        """
        cap = cv2.VideoCapture(str(video_path))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        
        if total_frames == 0:
            cap.release()
            raise ValueError(f"Could not read video: {video_path}")
        
        # Uniform temporal sampling
        frame_indices = self.uniform_sample_indices(total_frames, self.num_frames)
        
        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # CRITICAL: Convert BGR to RGB (OpenCV decodes frames as BGR)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # No resizing here; video_transform handles it
                frames.append(frame)
            elif frames:
                # If reading fails, repeat the last successfully read frame
                frames.append(frames[-1])
            else:
                cap.release()
                raise ValueError(f"Could not read frame {idx} from {video_path}")
        
        cap.release()
        
        # Stack to a numpy array of shape (T, H, W, C), then wrap as a uint8 tensor
        video = np.stack(frames)
        video = torch.from_numpy(video).to(torch.uint8)
        
        return video
    
    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]
        
        try:
            video = self.load_video(video_path)
            
            if self.transform:
                video = self.transform(video)
            
            return video, label
        except Exception as e:
            print(f"Error loading video {video_path}: {e}")
            # Fall back to the next sample (raises RecursionError if no video can be loaded)
            return self.__getitem__((idx + 1) % len(self))


### Data Transforms

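R3D expects inputs normalized with the per-channel mean and standard deviation of its Kinetics-400 pretraining data. torchvision's official preset for these weights (`R3D_18_Weights.KINETICS400_V1.transforms()`) resizes frames to 128x171, center-crops them to 112x112, and expects input shaped (T, C, H, W). The manual pipeline below instead takes the (T, H, W, C) output of `load_video` directly, resizes to 256x256, and center-crops to 224x224; R3D accepts the larger input because it ends in global average pooling.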
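
Per-channel normalization relies on broadcasting: reshaping the 3-element mean and std to (3, 1, 1, 1) lets them apply across every frame and pixel of a [C, T, H, W] clip. The NumPy sketch below uses a random toy clip as a stand-in for the torch tensors in this notebook, and also shows the inverse, `x * std + mean`, which the visualization cell uses to recover displayable [0, 1] pixels.

```python
import numpy as np

# Kinetics-400 statistics (same values as MEAN / STD in this notebook), shaped for [C, T, H, W]
mean_np = np.array([0.43216, 0.394666, 0.37645]).reshape(3, 1, 1, 1)
std_np = np.array([0.22803, 0.22145, 0.216989]).reshape(3, 1, 1, 1)

rng = np.random.default_rng(0)
clip = rng.random((3, 16, 112, 112))  # toy [C, T, H, W] clip with values in [0, 1]

normalized = (clip - mean_np) / std_np     # broadcasts over T, H and W
restored = normalized * std_np + mean_np   # inverse transform, as used for plotting

print(normalized.shape)             # (3, 16, 112, 112)
print(np.allclose(restored, clip))  # True
```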


In [None]:
# R3D preprocessing. We use a manual pipeline because torchvision's official preset
# (R3D_18_Weights.KINETICS400_V1.transforms()) expects [T, C, H, W] input, while
# load_video returns [T, H, W, C].
import torchvision.transforms.functional as F

# Normalization values from R3D Kinetics-400 pretraining
MEAN = [0.43216, 0.394666, 0.37645]
STD = [0.22803, 0.22145, 0.216989]

def video_transform(video):
    """
    Transform video from [T, H, W, C] uint8 to [C, T, H, W] float32 normalized.
    
    Args:
        video: torch.Tensor of shape [T, H, W, C] with uint8 values [0, 255]
    
    Returns:
        Processed video tensor [C, T, H, W] ready for R3D model
    """
    # [T, H, W, C] -> [T, C, H, W] so the frames form a batch of images, scaled to [0, 1]
    video = video.permute(0, 3, 1, 2).float() / 255.0
    
    # Resize every frame to 256x256 (aspect ratio is not preserved), then center crop
    # to 224x224. These functional ops accept [..., H, W], so all T frames go at once.
    video = F.resize(video, [256, 256], antialias=True)
    video = F.center_crop(video, [224, 224])
    
    # [T, C, H, W] -> [C, T, H, W], the layout R3D expects
    video = video.permute(1, 0, 2, 3).contiguous()
    
    # Normalize each channel with the Kinetics-400 mean and std
    mean = torch.tensor(MEAN).view(3, 1, 1, 1)
    std = torch.tensor(STD).view(3, 1, 1, 1)
    return (video - mean) / std

train_transform = video_transform
val_transform = video_transform


In [None]:
# Cached dataset paths
CACHE_TRAIN_DIR = './surfing_dataset_cache/train'
CACHE_VAL_DIR = './surfing_dataset_cache/val'
CACHE_TEST_DIR = './surfing_dataset_cache/test'

# Load cached datasets
train_dataset = CachedSurfingManeuverDataset(cache_dir=CACHE_TRAIN_DIR)
val_dataset = CachedSurfingManeuverDataset(cache_dir=CACHE_VAL_DIR)
test_dataset = CachedSurfingManeuverDataset(cache_dir=CACHE_TEST_DIR)

# Create DataLoaders (no transform needed - videos are already preprocessed!)
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=PIN_MEMORY,
    generator=dataloader_generator  # seeded generator -> reproducible shuffle order
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=PIN_MEMORY
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=PIN_MEMORY
)

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")
print("✓ Using CACHED dataset - Fast data loading enabled!")


### Visualize Sample Videos

Let's visualize a few frames from sample videos in the training set.


In [None]:
def plot_video_frames(dataloader, classes, n_videos=2, n_frames=4):
    """
    Plot sample frames from videos in the dataset.
    """
    videos, labels = next(iter(dataloader))
    videos = videos[:n_videos]
    labels = labels[:n_videos]
    
    # Denormalize for visualization
    mean = torch.tensor(MEAN).view(1, 3, 1, 1, 1)
    std = torch.tensor(STD).view(1, 3, 1, 1, 1)
    videos = videos * std + mean
    
    # squeeze=False keeps `axes` 2-D even when n_videos or n_frames is 1
    fig, axes = plt.subplots(n_videos, n_frames, figsize=(15, 4 * n_videos), squeeze=False)
    
    for i in range(n_videos):
        # Sample frames uniformly from the video
        frame_indices = np.linspace(0, videos.shape[2] - 1, n_frames, dtype=int)
        
        for j, frame_idx in enumerate(frame_indices):
            # Extract frame: (C, T, H, W) -> (H, W, C)
            frame = videos[i, :, frame_idx, :, :].permute(1, 2, 0).numpy()
            frame = np.clip(frame, 0, 1)
            
            ax = axes[i, j]
            ax.imshow(frame)
            ax.set_title(f'Frame {frame_idx + 1}')
            # Hide the ticks only (not the whole axis) so the class label stays visible
            ax.set_xticks([])
            ax.set_yticks([])
            if j == 0:
                ax.set_ylabel(classes[labels[i].item()], fontsize=12, fontweight='bold')
    
    plt.tight_layout()
    plt.show()

# Visualize samples
if len(train_dataset) > 0:
    plot_video_frames(train_loader, class_names, n_videos=2, n_frames=4)
else:
    print("No videos found in training set. Please check your data directory.")


### Training Loop

Define a training function for video classification.
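
`train_epoch` and `validate` accumulate `loss.item() * videos.size(0)` and divide by the dataset size, because `CrossEntropyLoss` returns a per-batch mean and the last batch may be smaller than `BATCH_SIZE`. The pure-Python sketch below, with made-up batch losses, shows how this weighted average differs from naively averaging the batch losses.

```python
# Made-up per-batch mean losses and batch sizes: 9 clips with BATCH_SIZE = 4
batch_losses = [1.0, 1.0, 0.4]
batch_sizes = [4, 4, 1]

# What train_epoch / validate compute: sum of (mean loss x batch size) / number of samples
running_loss = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
epoch_loss = running_loss / sum(batch_sizes)

# A naive mean over batches over-weights the single-clip final batch
naive_loss = sum(batch_losses) / len(batch_losses)

print(f"per-sample mean: {epoch_loss:.4f}")   # per-sample mean: 0.9333
print(f"naive batch mean: {naive_loss:.4f}")  # naive batch mean: 0.8000
```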


In [None]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    """
    Train for one epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (videos, labels) in enumerate(train_loader):
        videos, labels = videos.to(device), labels.to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(videos)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item() * videos.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        if (batch_idx + 1) % 10 == 0:
            print(f'  Batch [{batch_idx + 1}/{len(train_loader)}], '
                  f'Loss: {loss.item():.4f}, Acc: {100. * correct / total:.2f}%')
    
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc


def validate(model, val_loader, criterion, device):
    """
    Validate the model.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for videos, labels in val_loader:
            videos, labels = videos.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(videos)
            loss = criterion(outputs, labels)
            
            # Statistics
            running_loss += loss.item() * videos.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc


def train_for_epochs(model, train_loader, val_loader, optimizer, criterion, device, epochs):
    """
    Train the model for multiple epochs.
    """
    history = {
        'train': {'loss': [], 'accuracy': []},
        'val': {'loss': [], 'accuracy': []}
    }
    
    for epoch in range(epochs):
        print(f'\nEpoch [{epoch + 1}/{epochs}]')
        print('-' * 50)
        
        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        
        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        # Save history
        history['train']['loss'].append(train_loss)
        history['train']['accuracy'].append(train_acc)
        history['val']['loss'].append(val_loss)
        history['val']['accuracy'].append(val_acc)
    
    return history


### Load Pretrained R3D Model

Let's load R3D pretrained on Kinetics-400 (a large action recognition dataset with 400 classes) and adapt it for our surfing maneuver classification task.
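
When the model is created from the pretrained weights (the `else` branch of the next cell), every backbone parameter is frozen, so the reported "Trainable parameters" count should equal the size of the new `fc` head alone. A linear layer holds `in_features x out_features` weights plus `out_features` biases, which makes the number easy to check by hand:

```python
def linear_param_count(in_features, out_features, bias=True):
    """Parameter count of nn.Linear(in_features, out_features, bias=bias)."""
    return in_features * out_features + (out_features if bias else 0)


new_head = linear_param_count(512, 4)         # replacement head: 4 surfing maneuvers
kinetics_head = linear_param_count(512, 400)  # original Kinetics-400 head, discarded

print(f"Trainable parameters (new fc): {new_head:,}")      # Trainable parameters (new fc): 2,052
print(f"Original Kinetics-400 fc:      {kinetics_head:,}")  # Original Kinetics-400 fc:      205,200
```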


In [None]:
# Model save path
MODEL_PATH = 'r3d_surfing_model.pth'

# Check if saved model exists
if os.path.exists(MODEL_PATH):
    print(f"Found saved model at '{MODEL_PATH}', loading...")
    
    # Create model structure
    model_transfer = r3d_18(weights=None)
    model_transfer.fc = nn.Linear(512, NUM_CLASSES)
    
    # Load saved weights
    model_transfer.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model_transfer = model_transfer.to(device)
    model_transfer.eval()
    
    print("✓ Loaded saved model successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model_transfer.parameters()):,}")
    
    # Create optimizer (for consistency, even though we won't train)
    optimizer_transfer = optim.Adam(
        model_transfer.fc.parameters(), 
        lr=LEARNING_RATE,
        weight_decay=1e-4
    )
    
else:
    print("No saved model found, creating new model...")
    
    # Load R3D with pretrained weights from Kinetics-400
    weights = R3D_18_Weights.KINETICS400_V1
    model_transfer = r3d_18(weights=weights)

    # Freeze all layers first
    for param in model_transfer.parameters():
        param.requires_grad = False

    # Replace the final classification layer for our 4 classes
    # R3D classifier is: fc (Linear layer with 512 input features -> 400 output classes)
    # We need to replace this with our own layer
    in_features = model_transfer.fc.in_features  # 512
    model_transfer.fc = nn.Linear(in_features, NUM_CLASSES)

    # Move to device
    model_transfer = model_transfer.to(device)

    # Optimizer - only optimize the final classifier layer
    optimizer_transfer = optim.Adam(
        model_transfer.fc.parameters(), 
        lr=LEARNING_RATE,
        weight_decay=1e-4
    )

    print("R3D Model (pretrained on Kinetics-400):")
    print(f"Total parameters: {sum(p.numel() for p in model_transfer.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model_transfer.parameters() if p.requires_grad):,}")


### Train the Model


In [None]:
# Train with transfer learning (only if model doesn't exist)
if not os.path.exists(MODEL_PATH):
    if len(train_dataset) > 0:
        print("\n" + "="*50)
        print("Training R3D with Transfer Learning...")
        print("="*50)
        history_transfer = train_for_epochs(
            model_transfer, 
            train_loader, 
            val_loader, 
            optimizer_transfer, 
            criterion, 
            device, 
            EPOCHS
        )
        
        # Save the trained model
        torch.save(model_transfer.state_dict(), MODEL_PATH)
        print(f"\n✓ Model saved as '{MODEL_PATH}'")
    else:
        print("Skipping training - no data found")
        print("Please run 'python preprocess_videos_cache.py' and check CACHE_TRAIN_DIR / CACHE_VAL_DIR.")
else:
    print("\n" + "="*50)
    print("Using existing trained model, skipping training.")
    print(f"Model loaded from: '{MODEL_PATH}'")
    print("="*50)
    print("\nTo retrain from scratch, delete the model file and rerun this cell.")


### (Optional) Fine-tuning: Unfreeze More Layers

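After training the new head, you can optionally unfreeze the last residual stage (`layer4`) and continue training with a 10x lower learning rate. If the model was loaded from `MODEL_PATH`, none of its layers are frozen, so freeze the backbone first. Note that this cell does not save the fine-tuned weights.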


In [None]:
# Uncomment to perform fine-tuning

# # Unfreeze the last few layers for fine-tuning
# # R3D structure: stem -> layer1 -> layer2 -> layer3 -> layer4 -> avgpool -> fc
# for param in model_transfer.layer4.parameters():
#     param.requires_grad = True

# # Use a lower learning rate for fine-tuning
# optimizer_finetune = optim.Adam(
#     filter(lambda p: p.requires_grad, model_transfer.parameters()),
#     lr=LEARNING_RATE * 0.1,
#     weight_decay=1e-4
# )

# print(f"Fine-tuning - Trainable parameters: {sum(p.numel() for p in model_transfer.parameters() if p.requires_grad):,}")

# # Fine-tune for additional epochs
# if len(train_dataset) > 0:
#     print("\n" + "="*50)
#     print("Fine-tuning R3D...")
#     print("="*50)
#     history_finetune = train_for_epochs(
#         model_transfer, 
#         train_loader, 
#         val_loader, 
#         optimizer_finetune, 
#         criterion, 
#         device, 
#         EPOCHS // 2  # Fine-tune for fewer epochs
#     )


### Plot Training History

Visualize the training and validation metrics over epochs.


In [None]:
# Plot training history
def plot_history(history, title):
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot loss
    axes[0].plot(history['train']['loss'], label='Train Loss', marker='o')
    axes[0].plot(history['val']['loss'], label='Val Loss', marker='s')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title(f'{title} - Loss')
    axes[0].legend()
    axes[0].grid(True)
    
    # Plot accuracy
    axes[1].plot(history['train']['accuracy'], label='Train Accuracy', marker='o')
    axes[1].plot(history['val']['accuracy'], label='Val Accuracy', marker='s')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title(f'{title} - Accuracy')
    axes[1].legend()
    axes[1].grid(True)
    
    plt.tight_layout()
    plt.show()

# Plot transfer learning results
if 'history_transfer' in locals():
    plot_history(history_transfer, 'R3D Transfer Learning')
    print(f"\nFinal Validation Accuracy: {history_transfer['val']['accuracy'][-1]:.2f}%")


### Inference Example

Let's test the model on some test videos to evaluate final performance.
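
`predict_video` turns the model's raw logits into probabilities with softmax and reports the largest one as the confidence. The NumPy sketch below reproduces that step on made-up logits (not real model output). Subtracting the maximum logit before exponentiating is the standard trick to avoid overflow and does not change the result.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


names = ['cutback-frontside', 'take-off', '360', 'roller']
logits = np.array([1.0, 3.0, 0.5, -1.0])  # made-up logits for one clip

probs = softmax(logits)
predicted = int(probs.argmax())
print(names[predicted], f"{probs[predicted] * 100:.2f}%")  # take-off 80.92%
```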


In [None]:
def predict_video(model, video_tensor, device, class_names):
    """
    Predict the class of a video.
    """
    model.eval()
    with torch.no_grad():
        video_tensor = video_tensor.unsqueeze(0).to(device)  # Add batch dimension
        output = model(video_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        confidence, predicted = probabilities.max(1)
        
    return class_names[predicted.item()], confidence.item()


# Test on TEST set
if 'model_transfer' in locals() and len(test_dataset) > 0:
    print("\n" + "="*50)
    print("Testing on TEST videos")
    print("="*50)
    
    for i in range(min(5, len(test_dataset))):
        video, true_label = test_dataset[i]
        predicted_class, confidence = predict_video(model_transfer, video, device, class_names)
        true_class = class_names[true_label]
        
        print(f"\nVideo {i+1}:")
        print(f"  True label: {true_class}")
        print(f"  Predicted: {predicted_class} (confidence: {confidence*100:.2f}%)")
        print(f"  {'✓ Correct' if predicted_class == true_class else '✗ Incorrect'}")
    
    # Evaluate on entire test set
    print("\n" + "="*50)
    print("Full Test Set Evaluation")
    print("="*50)
    test_loss, test_acc = validate(model_transfer, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.2f}%")


---

### References

- [R3D Paper: A Closer Look at Spatiotemporal Convolutions for Action Recognition](https://arxiv.org/abs/1711.11248)
- [PyTorch Video Models Documentation](https://pytorch.org/vision/stable/models.html#video-classification)
- [Kinetics-400 Dataset](https://deepmind.com/research/open-source/kinetics)
