In [12]:
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from PIL import Image
import numpy as np

# Dataset class for images
class ActionImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.labels = []
        self.classes = os.listdir(root_dir)
        for label, action in enumerate(self.classes):
            action_dir = os.path.join(root_dir, action)
            for img_file in os.listdir(action_dir):
                img_path = os.path.join(action_dir, img_file)
                self.data.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        label = self.labels[idx]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
        img = Image.fromarray(img)  # Convert NumPy array to PIL image
        if self.transform:
            img = self.transform(img)
        return img, label


# Dataset class for videos
# class ActionVideoDataset(Dataset):
#     def __init__(self, video_dir, num_frames=16, transform=None):
#         self.video_dir = video_dir
#         self.num_frames = num_frames
#         self.transform = transform
#         self.video_files = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.mp4')]

#     def __len__(self):
#         return len(self.video_files)

#     def __getitem__(self, idx):
#         video_path = self.video_files[idx]
#         cap = cv2.VideoCapture(video_path)
#         frames = []
#         while len(frames) < self.num_frames:
#             ret, frame = cap.read()
#             if not ret:
#                 break
#             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#             frame = Image.fromarray(frame)  # Convert NumPy array to PIL image
#             if self.transform:
#                 frame = self.transform(frame)
#             frames.append(frame)
#         cap.release()
#         # Pad missing frames with zeros if video is too short
#         while len(frames) < self.num_frames:
#             frames.append(torch.zeros_like(frames[0]))
#         frames = torch.stack(frames)
#         return frames, video_path
# class ActionVideoDataset(Dataset):
#     def __init__(self, video_dir, num_frames=16, transform=None):
#         self.video_dir = video_dir
#         self.num_frames = num_frames
#         self.transform = transform
#         self.classes = os.listdir(video_dir)  # List of action folder names
#         self.label_to_class = {i: cls for i, cls in enumerate(self.classes)}  # Map index to folder name
#         self.video_files = [(os.path.join(video_dir, cls, f), i) 
#                             for i, cls in enumerate(self.classes) 
#                             for f in os.listdir(os.path.join(video_dir, cls)) if f.endswith('.mp4')]

#     def __len__(self):
#         return len(self.video_files)

#     def __getitem__(self, idx):
#         video_path, label = self.video_files[idx]
#         cap = cv2.VideoCapture(video_path)
#         frames = []
#         while len(frames) < self.num_frames:
#             ret, frame = cap.read()
#             if not ret:
#                 break
#             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#             frame = Image.fromarray(frame)
#             if self.transform:
#                 frame = self.transform(frame)
#             frames.append(frame)
#         cap.release()
#         while len(frames) < self.num_frames:
#             frames.append(torch.zeros_like(frames[0]))
#         frames = torch.stack(frames)
#         return frames, label, video_path

class ActionVideoDatasetSingle(Dataset):
    def __init__(self, video_dir, num_frames=16, transform=None):
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.transform = transform
        self.classes = os.listdir(video_dir)  # List of action folder names
        self.label_to_class = {i: cls for i, cls in enumerate(self.classes)}  # Map index to folder name

        # Select one video per action label folder
        self.video_files = []
        for label, cls in enumerate(self.classes):
            action_folder = os.path.join(video_dir, cls)
            videos = [f for f in os.listdir(action_folder) if f.endswith('.mp4')]
            if videos:
                # Pick the first video from the folder (or a random video)
                self.video_files.append((os.path.join(action_folder, videos[0]), label))

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path, label = self.video_files[idx]
        cap = cv2.VideoCapture(video_path)
        frames = []
        while len(frames) < self.num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)  # Convert NumPy array to PIL image
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)
        cap.release()

        # Pad missing frames with zeros if the video is too short
        while len(frames) < self.num_frames:
            frames.append(torch.zeros_like(frames[0]))
        frames = torch.stack(frames)
        return frames, label, video_path


# Transforms
image_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

video_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load datasets
train_dataset = ActionImageDataset(root_dir='action_sp', transform=image_transform)
test_dataset = ActionVideoDatasetSingle(video_dir='resized_videos', transform=video_transform)

# Subset to train on 50 images
# train_subset = Subset(train_dataset, list(range(200)))
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

# Test loader
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)



In [9]:
# import os
# import cv2
# import torch
# import torch.nn as nn
# from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
# from torchvision.transforms import Compose, ToTensor, Resize, Normalize, RandomHorizontalFlip, RandomRotation, ColorJitter
# from PIL import Image
# import numpy as np
# from collections import Counter

# # Dataset class for images
# class ActionImageDataset(Dataset):
#     def __init__(self, root_dir, transform=None):
#         self.root_dir = root_dir
#         self.transform = transform
#         self.data = []
#         self.labels = []
#         self.classes = os.listdir(root_dir)
#         for label, action in enumerate(self.classes):
#             action_dir = os.path.join(root_dir, action)
#             for img_file in os.listdir(action_dir):
#                 img_path = os.path.join(action_dir, img_file)
#                 self.data.append(img_path)
#                 self.labels.append(label)

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         img_path = self.data[idx]
#         label = self.labels[idx]
#         img = cv2.imread(img_path)
#         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
#         img = Image.fromarray(img)  # Convert NumPy array to PIL image
#         if self.transform:
#             img = self.transform(img)
#         return img, label


# # Dataset class for videos
# class ActionVideoDatasetSingle(Dataset):
#     def __init__(self, video_dir, num_frames=16, transform=None):
#         self.video_dir = video_dir
#         self.num_frames = num_frames
#         self.transform = transform
#         self.classes = os.listdir(video_dir)  # List of action folder names
#         self.label_to_class = {i: cls for i, cls in enumerate(self.classes)}  # Map index to folder name

#         # Select one video per action label folder
#         self.video_files = []
#         for label, cls in enumerate(self.classes):
#             action_folder = os.path.join(video_dir, cls)
#             videos = [f for f in os.listdir(action_folder) if f.endswith('.mp4')]
#             if videos:
#                 # Pick the first video from the folder (or a random video)
#                 self.video_files.append((os.path.join(action_folder, videos[0]), label))

#     def __len__(self):
#         return len(self.video_files)

#     def __getitem__(self, idx):
#         video_path, label = self.video_files[idx]
#         cap = cv2.VideoCapture(video_path)
#         frames = []
#         while len(frames) < self.num_frames:
#             ret, frame = cap.read()
#             if not ret:
#                 break
#             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#             frame = Image.fromarray(frame)  # Convert NumPy array to PIL image
#             if self.transform:
#                 frame = self.transform(frame)
#             frames.append(frame)
#         cap.release()

#         # Pad missing frames with zeros if the video is too short
#         while len(frames) < self.num_frames:
#             frames.append(torch.zeros_like(frames[0]))
#         frames = torch.stack(frames)
#         return frames, label, video_path


# # Augmented transforms
# augmented_transform = Compose([
#     Resize((224, 224)),
#     RandomHorizontalFlip(),
#     RandomRotation(15),
#     ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
#     ToTensor(),
#     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# # Regular transforms for test dataset
# video_transform = Compose([
#     Resize((224, 224)),
#     ToTensor(),
#     Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# # Load datasets
# train_dataset = ActionImageDataset(root_dir='action_sp', transform=augmented_transform)
# test_dataset = ActionVideoDatasetSingle(video_dir='resized_videos', transform=video_transform)

# # Class balancing with WeightedRandomSampler
# class_counts = Counter(train_dataset.labels)
# total_samples = sum(class_counts.values())
# class_weights = {label: total_samples / count for label, count in class_counts.items()}

# # Create a subset of the dataset
# train_subset = Subset(train_dataset, list(range(200)))

# # Adjust weights for the subset
# sample_weights = [class_weights[train_dataset.labels[idx]] for idx in train_subset.indices]
# sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_subset), replacement=True)

# # DataLoader with sampler for training
# train_loader = DataLoader(train_subset, batch_size=14, sampler=sampler)


# # Test loader
# test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [83]:
# import sys
# print(sys.executable)
# print(sys.version)


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# class RViT(nn.Module):
#     def __init__(self, num_classes, hidden_dim, num_layers, frame_dim):
#         super().__init__()
#         self.hidden_dim = hidden_dim
        
    #     # Patch embedding with reduced stride and kernel size
    #     # self.patch_embedding = nn.Conv3d(16, hidden_dim, kernel_size=(3, 8, 8), stride=(3, 4, 4), padding=(1, 2, 2))
    #     self.patch_embedding = nn.Conv3d(3, hidden_dim, kernel_size=(3, 8, 8), stride=(3, 4, 4), padding=(1, 2, 2))

    #     # Simplified learnable position encoding
    #     # Adjust the temporal dimension to match the expected input
    #     self.position_encoding = nn.Parameter(torch.randn(1, hidden_dim, 1, 56, 56), requires_grad=True)
        
    #     self.rvit_units = nn.ModuleList([RViTUnit(hidden_dim) for _ in range(num_layers)])
    #     self.classifier = nn.Linear(hidden_dim, num_classes)
        
    #     self.frame_reconstruction = nn.Sequential(
    #         nn.Conv3d(hidden_dim, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
    #         nn.ReLU(),
    #         nn.Conv3d(64, frame_dim[0], kernel_size=(1, 3, 3), padding=(0, 1, 1))
    #     )
    #     self.temporal_upsample = nn.Upsample(size=(15, 224, 224), mode='trilinear', align_corners=False)

    # def forward(self, x):
    #     # Patch embedding
    #     patches = self.patch_embedding(x)  # Shape: [batch_size, hidden_dim, depth, height, width]
    #     # print("After patch embedding:", patches.shape)
        
    #     # Dynamically resize position encoding to match patches' spatial dimensions
    #     _, _, depth, height, width = patches.shape
    #     pos_encoding = F.interpolate(
    #         self.position_encoding, size=(depth, height, width), mode='trilinear', align_corners=False
    #     )  # Adjusted shape: [1, hidden_dim, depth, height, width]
    #     # print("Position encoding shape after interpolation:", pos_encoding.shape)
        
    #     # Add position encoding
        # patches += pos_encoding
        # # print("After adding position encoding:", patches.shape)

        # # Initialize recurrent state
        # h = torch.zeros_like(patches)
        # # print("Initialized recurrent state:", h.shape)

        # # Pass through RViT units
        # for i, unit in enumerate(self.rvit_units):
        #     h = unit(patches, h)
        #     # print(f"After RViT unit {i+1}:", h.shape)

        # # Classification
        # h_last = h.mean(dim=(2, 3, 4))  # Global average pooling over spatial dimensions
        # action_logits = self.classifier(h_last)
        # print("Action logits shape:", action_logits.shape)

        # # Frame reconstruction
        # reconstructed_frame = self.frame_reconstruction(h)
        # # print("Reconstructed frame shape:", reconstructed_frame.shape)

        # return action_logits, reconstructed_frame

class RViT(nn.Module):
    def __init__(self, num_classes, hidden_dim, num_layers, frame_dim):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Patch embedding
        self.patch_embedding = nn.Conv3d(3, hidden_dim, kernel_size=(3, 8, 8), stride=(3, 4, 4), padding=(1, 2, 2))
        self.position_encoding = nn.Parameter(torch.randn(1, hidden_dim, 1, 56, 56), requires_grad=True)

        # Attention layers
        self.scaled_attention = ScaledDotProductAttention(hidden_dim)
        self.linear_attention = LinearAttention(hidden_dim)

        # Recurrent Vision Transformer Units
        self.rvit_units = nn.ModuleList([RViTUnit(hidden_dim) for _ in range(num_layers)])
        
        # Classifier
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # Patch embedding
        patches = self.patch_embedding(x)
        
        # Add positional encoding
        _, _, depth, height, width = patches.shape
        pos_encoding = F.interpolate(self.position_encoding, size=(depth, height, width), mode='trilinear', align_corners=False)
        patches += pos_encoding

        # Attention mechanisms
        scaled_attn_output = self.scaled_attention(patches, patches)
        lin_attn_output = self.linear_attention(scaled_attn_output, scaled_attn_output)

        # Recurrent processing
        h = torch.zeros_like(lin_attn_output).to(lin_attn_output.device)
        for unit in self.rvit_units:
            h = unit(lin_attn_output, h)

        # Classification
        h_last = h.mean(dim=(2, 3, 4))  # Global average pooling
        logits = self.classifier(h_last)
        return logits, h




class ScaledDotProductAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.Wq = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.Wk = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.Wv = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.scale = hidden_dim ** -0.5

    def forward(self, x, h):
        # Compute queries, keys, and values
        q = self.Wq(x)
        k = self.Wk(h)
        v = self.Wv(h)
        
        # Compute scaled dot-product attention
        attn = torch.softmax((q * k).sum(dim=1, keepdim=True) * self.scale, dim=-1)
        output = attn * v  # Apply attention weights to values
        
        return output
        
class LinearAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.Wq = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.Wk = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.Wv = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.Wo = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)

    def forward(self, x, h):
        # Compute queries, keys, and values
        q = self.Wq(x)  # Shape: [batch_size, hidden_dim, depth, height, width]
        k = self.Wk(h)  # Shape: [batch_size, hidden_dim, depth, height, width]
        v = self.Wv(h)  # Shape: [batch_size, hidden_dim, depth, height, width]
        
        # Reshape q and k for attention computation
        q = q.flatten(start_dim=2)  # Shape: [batch_size, hidden_dim, depth*height*width]
        k = k.flatten(start_dim=2)  # Shape: [batch_size, hidden_dim, depth*height*width]
        v = v.flatten(start_dim=2)  # Shape: [batch_size, hidden_dim, depth*height*width]

        # Compute attention weights
        attn_weights = torch.bmm(q.transpose(1, 2), k)  # Shape: [batch_size, depth*height*width, depth*height*width]
        attn_weights = attn_weights / (k.size(1) ** 0.5)  # Scale by sqrt of hidden_dim
        attn_weights = torch.softmax(attn_weights, dim=-1)  # Apply softmax over last dimension

        # Apply attention weights to values
        attn_output = torch.bmm(v, attn_weights.transpose(1, 2))  # Shape: [batch_size, hidden_dim, depth*height*width]
        
        # Reshape back to 3D
        attn_output = attn_output.view_as(h)  # Shape: [batch_size, hidden_dim, depth, height, width]

        # Final projection to match the input shape
        output = self.Wo(attn_output)  # Shape: [batch_size, hidden_dim, depth, height, width]
        return output



class RViTUnit(nn.Module):
    def __init__(self, hidden_dim, dropout_rate=0.5
                 ):
        super().__init__()
        self.attention_gate = LinearAttention(hidden_dim)  # Use LinearAttention here
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.recurrent_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, h):
        # Attention mechanism with linear attention
        attn_output = self.attention_gate(x, h)
        attn_output = self.recurrent_dropout(attn_output)
    
        # Reshape for LayerNorm
        batch_size, hidden_dim, depth, height, width = attn_output.shape
        h_flat = h.permute(0, 2, 3, 4, 1).reshape(-1, hidden_dim)
        attn_output_flat = attn_output.permute(0, 2, 3, 4, 1).reshape(-1, hidden_dim)
    
        # Apply LayerNorm
        h_new_flat = self.layer_norm1(h_flat + attn_output_flat)
    
        # Reshape back to original shape
        h_new = h_new_flat.reshape(batch_size, depth, height, width, hidden_dim).permute(0, 4, 1, 2, 3)
    
        # Apply FFN with LayerNorm
        h_new_flat = h_new.permute(0, 2, 3, 4, 1).reshape(-1, hidden_dim)
        h_new_flat = self.layer_norm2(h_new_flat + self.ffn(h_new_flat))
    
        h_new = h_new_flat.reshape(batch_size, depth, height, width, hidden_dim).permute(0, 4, 1, 2, 3)
        return h_new





#new trial

In [11]:
import os
from collections import Counter
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, ColorJitter, Resize, ToTensor, Normalize

# Dataset class for images
class ActionImageDataset(Dataset):
    def __init__(self, root_dir, transform=None, target_count=None, augment=False):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.labels = []
        self.classes = os.listdir(root_dir)
        self.augment = augment

        # Count images per class and balance classes with augmentation
        if target_count and augment:
            self._balance_classes(target_count)

        # Load dataset
        for label, action in enumerate(self.classes):
            action_dir = os.path.join(root_dir, action)
            for img_file in os.listdir(action_dir):
                img_path = os.path.join(action_dir, img_file)
                self.data.append(img_path)
                self.labels.append(label)

    def _balance_classes(self, target_count):
        # Data augmentation pipeline
        augmentation_transform = Compose([
            RandomHorizontalFlip(p=0.5),
            RandomRotation(degrees=15),
            ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        ])

        for label, action in enumerate(self.classes):
            action_dir = os.path.join(self.root_dir, action)
            image_files = os.listdir(action_dir)
            count = len(image_files)
            if count < target_count:
                print(f"Augmenting {action} with {target_count - count} new images.")
                for i in range(target_count - count):
                    img_path = os.path.join(action_dir, image_files[i % count])
                    img = Image.open(img_path)
                    
                    # Convert image to RGB mode if not already
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    
                    # Apply augmentation
                    augmented_img = augmentation_transform(img)
                    
                    # Save augmented image
                    augmented_img_path = os.path.join(action_dir, f"augmented_{i}.jpg")
                    augmented_img.save(augmented_img_path)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        label = self.labels[idx]
        img = Image.open(img_path).convert("RGB")  # Ensure it's in RGB format
        if self.transform:
            img = self.transform(img)
        return img, label


class ActionVideoDatasetSingle(Dataset):
    def __init__(self, video_dir, num_frames=16, transform=None):
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.transform = transform
        self.classes = os.listdir(video_dir)  # List of action folder names
        self.label_to_class = {i: cls for i, cls in enumerate(self.classes)}  # Map index to folder name

        # Select one video per action label folder
        self.video_files = []
        for label, cls in enumerate(self.classes):
            action_folder = os.path.join(video_dir, cls)
            videos = [f for f in os.listdir(action_folder) if f.endswith('.mp4')]
            if videos:
                # Pick the first video from the folder (or a random video)
                self.video_files.append((os.path.join(action_folder, videos[0]), label))

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path, label = self.video_files[idx]
        cap = cv2.VideoCapture(video_path)
        frames = []
        while len(frames) < self.num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)  # Convert NumPy array to PIL image
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)
        cap.release()

        # Pad missing frames with zeros if the video is too short
        while len(frames) < self.num_frames:
            frames.append(torch.zeros_like(frames[0]))
        frames = torch.stack(frames)
        return frames, label, video_path


# Transforms
image_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

video_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Parameters for class balancing
target_count = 15  # Define the target number of images per class

# Load datasets with augmentation for imbalance correction
train_dataset = ActionImageDataset(root_dir='action_sp', transform=image_transform, target_count=target_count, augment=True)
test_dataset = ActionVideoDatasetSingle(video_dir='resized_videos', transform=video_transform)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Confirm updated dataset statistics
print(f"Number of training samples: {len(train_dataset)}")


Number of training samples: 2100


In [22]:
import os
from PIL import Image
import torch
from torchvision.transforms import (
    Compose, RandomHorizontalFlip, RandomRotation, ColorJitter,
    RandomResizedCrop, GaussianBlur, RandomPerspective
)
from tqdm import tqdm

class BalancedAugmentor:
    def __init__(self, root_dir, target_count=100, transform=None):
        """
        Args:
            root_dir (str): Path to the original dataset root directory.
            target_count (int): Desired number of images per action label after augmentation.
            transform (callable, optional): Transform to apply for augmentation.
        """
        self.root_dir = root_dir
        self.target_count = target_count
        self.transform = transform if transform else self.default_augmentation()

    def default_augmentation(self):
        """
        Defines the augmentation pipeline for intense augmentation.
        """
        return Compose([
            RandomResizedCrop(size=(224, 224), scale=(0.8, 1.2)),  # Random crop and resize
            RandomHorizontalFlip(p=0.5),
            RandomRotation(degrees=30),  # Random rotation
            ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
            GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2.0)),
            RandomPerspective(distortion_scale=0.5, p=0.5),
        ])

    def augment_class(self, label_dir, label, current_count):
        """
        Augments a specific class to reach the target number of images.
        Args:
            label_dir (str): Directory containing images for the specific class label.
            label (str): Name of the label.
            current_count (int): Current number of images in the class folder.
        """
        image_files = os.listdir(label_dir)
        num_to_augment = self.target_count - current_count

        if num_to_augment > 0:
            print(f"Augmenting class '{label}' with {num_to_augment} new images.")
            for i in range(num_to_augment):
                img_path = os.path.join(label_dir, image_files[i % len(image_files)])
                img = Image.open(img_path).convert("RGB")  # Ensure all images are in RGB mode

                # Apply augmentation
                augmented_img = self.transform(img)

                # Save augmented image
                augmented_img_path = os.path.join(label_dir, f"{label}_aug_{i}.jpg")
                augmented_img.save(augmented_img_path)

    def balance_dataset(self):
        """
        Balances all classes by augmenting underrepresented ones.
        """
        action_labels = os.listdir(self.root_dir)
        for label in tqdm(action_labels, desc="Balancing Classes"):
            label_dir = os.path.join(self.root_dir, label)
            if not os.path.isdir(label_dir):
                continue  # Skip if it's not a directory
            current_count = len(os.listdir(label_dir))
            if current_count < self.target_count:
                self.augment_class(label_dir, label, current_count)

# Parameters
root_dir = 'action_sp'  # Original dataset path
target_count = 100  # Desired number of images per action label

# Perform balancing augmentation
augmentor = BalancedAugmentor(root_dir, target_count=target_count)
augmentor.balance_dataset()

# Verify final counts
final_counts = {label: len(os.listdir(os.path.join(root_dir, label))) for label in os.listdir(root_dir)}
print(f"Final image counts per class: {final_counts}")


Balancing Classes:   0%|          | 0/21 [00:00<?, ?it/s]

Augmenting class 'questioning and answering' with 84 new images.


Balancing Classes:   5%|▍         | 1/21 [00:00<00:11,  1.80it/s]

Augmenting class 'giving or receiving award' with 81 new images.


Balancing Classes:  10%|▉         | 2/21 [00:00<00:08,  2.27it/s]

Augmenting class 'wrapping present' with 85 new images.


Balancing Classes:  14%|█▍        | 3/21 [00:01<00:07,  2.41it/s]

Augmenting class 'wrestling' with 85 new images.


Balancing Classes:  19%|█▉        | 4/21 [00:01<00:06,  2.48it/s]

Augmenting class 'checking tires' with 85 new images.


Balancing Classes:  24%|██▍       | 5/21 [00:02<00:06,  2.40it/s]

Augmenting class 'air drumming' with 82 new images.


Balancing Classes:  29%|██▊       | 6/21 [00:02<00:05,  2.54it/s]

Augmenting class 'playing trombone' with 85 new images.


Balancing Classes:  33%|███▎      | 7/21 [00:02<00:05,  2.60it/s]

Augmenting class 'milking cow' with 85 new images.


Balancing Classes:  38%|███▊      | 8/21 [00:03<00:05,  2.46it/s]

Augmenting class 'applauding' with 82 new images.


Balancing Classes:  43%|████▎     | 9/21 [00:03<00:04,  2.52it/s]

Augmenting class 'moving furniture' with 85 new images.


Balancing Classes:  48%|████▊     | 10/21 [00:04<00:04,  2.56it/s]

Augmenting class 'abseiling' with 65 new images.


Balancing Classes:  52%|█████▏    | 11/21 [00:04<00:03,  2.75it/s]

Augmenting class 'pushing wheelchair' with 85 new images.


Balancing Classes:  57%|█████▋    | 12/21 [00:04<00:03,  2.66it/s]

Augmenting class 'riding elephant' with 85 new images.


Balancing Classes:  62%|██████▏   | 13/21 [00:05<00:03,  2.59it/s]

Augmenting class 'opening bottle' with 85 new images.


Balancing Classes:  67%|██████▋   | 14/21 [00:05<00:02,  2.47it/s]

Augmenting class 'playing harp' with 85 new images.


Balancing Classes:  71%|███████▏  | 15/21 [00:06<00:02,  2.47it/s]

Augmenting class 'applying cream' with 85 new images.


Balancing Classes:  76%|███████▌  | 16/21 [00:06<00:01,  2.51it/s]

Augmenting class 'throwing axe' with 85 new images.


Balancing Classes:  81%|████████  | 17/21 [00:06<00:01,  2.44it/s]

Augmenting class 'archery' with 81 new images.


Balancing Classes:  86%|████████▌ | 18/21 [00:07<00:01,  2.54it/s]

Augmenting class 'doing aerobics' with 80 new images.


Balancing Classes:  90%|█████████ | 19/21 [00:07<00:00,  2.48it/s]

Augmenting class 'juggling soccer ball' with 81 new images.


Balancing Classes:  95%|█████████▌| 20/21 [00:08<00:00,  2.49it/s]

Augmenting class 'waxing chest' with 85 new images.


Balancing Classes: 100%|██████████| 21/21 [00:08<00:00,  2.49it/s]

Final image counts per class: {'questioning and answering': 100, 'giving or receiving award': 100, 'wrapping present': 100, 'wrestling': 100, 'checking tires': 100, 'air drumming': 100, 'playing trombone': 100, 'milking cow': 100, 'applauding': 100, 'moving furniture': 100, 'abseiling': 100, 'pushing wheelchair': 100, 'riding elephant': 100, 'opening bottle': 100, 'playing harp': 100, 'applying cream': 100, 'throwing axe': 100, 'archery': 100, 'doing aerobics': 100, 'juggling soccer ball': 100, 'waxing chest': 100}





In [23]:
# # Initialize model, loss, optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RViT(num_classes=21, hidden_dim=512, num_layers=4, frame_dim=(3, 224, 224)).to(device)
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Initialize model, loss, optimizer
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = RViT(num_classes=5, hidden_dim=128, num_layers=4, frame_dim=(3, 224, 224)).to(device)
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)



In [2]:

# num_epochs = 10
# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0
#     for images, labels in train_loader:
#         images = images.unsqueeze(2).to(device)
#         labels = labels.to(device)

#         optimizer.zero_grad()
#         logits, _ = model(images)
#         loss = criterion(logits, labels)
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     scheduler.step()  # Adjust learning rate
#     print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.unsqueeze(2).to(device)  # Add temporal dimension
        labels = labels.to(device)
        
        optimizer.zero_grad()
        logits, _ = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")

In [3]:
import os

def count_images_in_folder(folder_path, extensions=('jpg', 'jpeg', 'png')):
    total_images = 0
    
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.lower().endswith(extensions):
                total_images += 1
                
    return total_images

# Example usage
folder_path = "action_sp"  # Replace with your folder path
total_images = count_images_in_folder(folder_path)
print(f"Total images: {total_images}")


Total images: 2098


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RViT(num_classes=21, hidden_dim=512, num_layers=4, frame_dim=(3, 224, 224)).to(device)

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# Parameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # Assuming the model is defined
scaler = GradScaler()  # Mixed-precision training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()  # Replace with your loss function
gradient_accumulation_steps = 4  # Accumulate gradients for 4 steps
num_epochs = 10

# Training Loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()
    
    for step, (inputs, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
        inputs, labels = inputs.to(device), labels.to(device)
        if inputs.ndim == 4:  # If depth dimension is missing
            inputs = inputs.unsqueeze(2)
        
        with autocast():
            outputs = model(inputs)
            if isinstance(outputs, tuple):  # Extract logits if model returns a tuple
                outputs = outputs[0]
            loss = criterion(outputs, labels)
            loss = loss / gradient_accumulation_steps

        
        # Backpropagation with scaling
        scaler.scale(loss).backward()

        # Update weights after gradient accumulation steps
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        running_loss += loss.item() * gradient_accumulation_steps

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")


  scaler = GradScaler()  # Mixed-precision training
  with autocast():
Epoch 1/10: 100%|██████████| 36/36 [07:53<00:00, 13.15s/it]


Epoch 1/10, Loss: 0.3281


Epoch 2/10: 100%|██████████| 36/36 [08:07<00:00, 13.54s/it]


Epoch 2/10, Loss: 0.3229


Epoch 3/10: 100%|██████████| 36/36 [08:10<00:00, 13.63s/it]


Epoch 3/10, Loss: 0.3120


Epoch 4/10: 100%|██████████| 36/36 [08:04<00:00, 13.46s/it]


Epoch 4/10, Loss: 0.3115


Epoch 5/10: 100%|██████████| 36/36 [07:52<00:00, 13.14s/it]


Epoch 5/10, Loss: 0.3113


Epoch 6/10: 100%|██████████| 36/36 [07:38<00:00, 12.74s/it]


Epoch 6/10, Loss: 0.3069


Epoch 7/10: 100%|██████████| 36/36 [08:09<00:00, 13.59s/it]


Epoch 7/10, Loss: 0.3058


Epoch 8/10: 100%|██████████| 36/36 [08:35<00:00, 14.31s/it]


Epoch 8/10, Loss: 0.3064


Epoch 9/10: 100%|██████████| 36/36 [08:23<00:00, 13.99s/it]


Epoch 9/10, Loss: 0.3065


Epoch 10/10: 100%|██████████| 36/36 [07:58<00:00, 13.29s/it]

Epoch 10/10, Loss: 0.3027





In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RViT(num_classes=21, hidden_dim=512, num_layers=6, frame_dim=(3, 224, 224)).to(device)

In [12]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

scaler = GradScaler()  # Mixed-precision training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()  # Replace with your loss function
gradient_accumulation_steps = 4  # Accumulate gradients for 4 stepsscaler = GradScaler()  # Mixed-precision training
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()  # Replace with your loss function
gradient_accumulation_steps = 4  # Accumulate gradients for 4 steps

# Track epoch losses
epoch_losses = []
num_epochs = 20
# Training Loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()
    
    for step, (inputs, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
        inputs, labels = inputs.to(device), labels.to(device)
        if inputs.ndim == 4:  # If depth dimension is missing
            inputs = inputs.unsqueeze(2)
        
        with autocast():
            outputs = model(inputs)
            if isinstance(outputs, tuple):  # Extract logits if model returns a tuple
                outputs = outputs[0]
            loss = criterion(outputs, labels)
            loss = loss / gradient_accumulation_steps

        # Backpropagation with scaling
        scaler.scale(loss).backward()

        # Update weights after gradient accumulation steps
        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        
        running_loss += loss.item() * gradient_accumulation_steps

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_losses.append(epoch_loss)  # Save epoch loss
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# Plot the loss
plt.figure(figsize=(10, 6))
plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o', label='Train Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.legend()
plt.grid(True)
plt.show()


  scaler = GradScaler()  # Mixed-precision training
  with autocast():
Epoch 1/20: 100%|██████████| 210/210 [1:30:29<00:00, 25.85s/it]


Epoch 1/20, Loss: 0.3170


Epoch 2/20: 100%|██████████| 210/210 [1:28:28<00:00, 25.28s/it]


Epoch 2/20, Loss: 0.3099


Epoch 3/20: 100%|██████████| 210/210 [1:22:57<00:00, 23.70s/it]


Epoch 3/20, Loss: 0.3084


Epoch 4/20:  65%|██████▍   | 136/210 [50:44<27:36, 22.39s/it] 


KeyboardInterrupt: 

In [8]:
# import matplotlib.pyplot as plt
# import numpy as np

# def visualize_attention_map(frame, attention_weights, output_path="attention_map.png"):
#     """
#     Visualize attention map overlayed on a video frame.
#     Args:
#         frame: Original frame (H, W, C)
#         attention_weights: Attention weights (H, W)
#     """
#     # Normalize attention weights
#     attention_weights = (attention_weights - attention_weights.min()) / (attention_weights.max() - attention_weights.min())
    
#     # Resize attention weights to match the frame size
#     attention_resized = cv2.resize(attention_weights, (frame.shape[1], frame.shape[0]))
    
#     # Create heatmap
#     heatmap = cv2.applyColorMap((attention_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
    
#     # Blend heatmap with original frame
#     overlay = cv2.addWeighted(frame, 0.6, heatmap, 0.4, 0)
    
#     # Save or display the result
#     plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
#     plt.axis('off')
#     plt.savefig(output_path, bbox_inches='tight')
#     plt.close()

# # Example usage
# sample_frame = cv2.imread("sample_frame.jpg")  # Replace with an actual frame
# sample_attention = np.random.rand(56, 56)  # Replace with actual attention weights
# visualize_attention_map(sample_frame, sample_attention)



In [28]:
inputs = inputs.unsqueeze(2)  # Adds depth dimension
print(inputs.shape)  # Should be [batch_size, channels, depth, height, width]


torch.Size([10, 3, 1, 224, 224])


In [37]:
# Validate the fine-tuned model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for video_frames, true_label, video_path in test_loader:
        video_frames = video_frames.permute(0, 2, 1, 3, 4).to(device)
        logits, _ = model(video_frames)
        predicted_class = torch.argmax(logits, dim=1).item()
        total += 1
        correct += (predicted_class == true_label.item())
        print(f"Video: {video_path}")
        print(f"True Label: {true_label.item()}, Predicted Label: {predicted_class}")

accuracy = correct / total * 100
print(f"Validation Accuracy: {accuracy:.4f}%")


Video: ('resized_videos/giving or receiving award/JBWeDivEHFI.mp4',)
True Label: 0, Predicted Label: 10
Video: ('resized_videos/wrapping present/HscLLuC-PQs.mp4',)
True Label: 1, Predicted Label: 10


KeyboardInterrupt: 

In [87]:
model.eval()
with torch.no_grad():
    for video_frames, true_label, video_path in test_loader:
        video_frames = video_frames.permute(0, 2, 1, 3, 4).to(device)
        logits, _ = model(video_frames)
        
        predicted_class = torch.argmax(logits, dim=1)[0].item()
        predicted_folder = test_dataset.label_to_class[predicted_class]
        true_folder = test_dataset.label_to_class[true_label.item()]  # Convert tensor to integer

        print(f"Video: {video_path}")
        print(f"True Label: {true_folder}, Predicted Label: {predicted_folder}")
# Testing loop
# model.eval()
# with torch.no_grad():
#     for video_frames, true_label, video_path in test_loader:
#         video_frames = video_frames.permute(0, 2, 1, 3, 4).to(device)
#         logits, _ = model(video_frames)
        
#         predicted_class = torch.argmax(logits, dim=1)[0].item()
#         predicted_folder = test_dataset.label_to_class[predicted_class]
#         true_folder = test_dataset.label_to_class[true_label.item()]  # Convert tensor to integer

#         print(f"Video: {video_path}")
#         print(f"True Label: {true_folder}, Predicted Label: {predicted_folder}")


Video: ('resized_videos/giving or receiving award/JBWeDivEHFI.mp4',)
True Label: giving or receiving award, Predicted Label: giving or receiving award
Video: ('resized_videos/wrapping present/HscLLuC-PQs.mp4',)
True Label: wrapping present, Predicted Label: giving or receiving award
Video: ('resized_videos/wrestling/yxJHCSA35Ns.mp4',)
True Label: wrestling, Predicted Label: giving or receiving award
Video: ('resized_videos/answering questions/j6ogVLOLQug.mp4',)
True Label: answering questions, Predicted Label: giving or receiving award
Video: ('resized_videos/checking tires/pqlY2l0KhEY.mp4',)
True Label: checking tires, Predicted Label: giving or receiving award
Video: ('resized_videos/air drumming/5M80ZTWfzOU.mp4',)
True Label: air drumming, Predicted Label: giving or receiving award
Video: ('resized_videos/kissing/dNh1nMXCJs8.mp4',)
True Label: kissing, Predicted Label: giving or receiving award
Video: ('resized_videos/playing trombone/ScZAUtVAShI.mp4',)
True Label: playing trombone,

In [88]:
from collections import Counter

class_counts = Counter(train_dataset.labels)
print(f"Class distribution: {class_counts}")


Class distribution: Counter({1: 3753, 3: 2655, 0: 2258, 4: 2220, 2: 1715})


In [89]:
# class RViT(nn.Module):
#     def __init__(self, num_classes, hidden_dim, num_layers, frame_dim):
#         super().__init__()
#         self.hidden_dim = hidden_dim

#         # Patch embedding
#         self.patch_embedding = nn.Conv3d(3, hidden_dim, kernel_size=(3, 8, 8), stride=(3, 4, 4), padding=(1, 2, 2))
#         self.position_encoding = nn.Parameter(torch.randn(1, hidden_dim, 1, 56, 56), requires_grad=True)

#         # Attention layers
#         self.scaled_attention = ScaledDotProductAttention(hidden_dim)
#         self.linear_attention = LinearAttention(hidden_dim)

#         # Recurrent Vision Transformer Units
#         self.rvit_units = nn.ModuleList([RViTUnit(hidden_dim) for _ in range(num_layers)])
        
#         # Classifier
#         self.classifier = nn.Linear(hidden_dim, num_classes)

#     def forward(self, x):
#         # Patch embedding
#         patches = self.patch_embedding(x)
        
#         # Add positional encoding
#         _, _, depth, height, width = patches.shape
#         pos_encoding = F.interpolate(self.position_encoding, size=(depth, height, width), mode='trilinear', align_corners=False)
#         patches += pos_encoding

#         # Attention mechanisms
#         scaled_attn_output = self.scaled_attention(patches, patches)
#         lin_attn_output = self.linear_attention(scaled_attn_output, scaled_attn_output)

#         # Recurrent processing
#         h = torch.zeros_like(lin_attn_output).to(lin_attn_output.device)
#         for unit in self.rvit_units:
#             h = unit(lin_attn_output, h)

#         # Classification
#         h_last = h.mean(dim=(2, 3, 4))  # Global average pooling
#         logits = self.classifier(h_last)
#         return logits, h


In [94]:
import matplotlib.pyplot as plt

another test

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RViT(nn.Module):
    def __init__(self, num_classes, hidden_dim, num_layers, frame_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_embedding = nn.Conv3d(3, hidden_dim, kernel_size=(3, 8, 8), stride=(3, 4, 4), padding=(1, 2, 2))
        self.position_encoding = nn.Parameter(torch.randn(1, hidden_dim, 1, 56, 56), requires_grad=True)
        self.rvit_units = nn.ModuleList([RViTUnit(hidden_dim) for _ in range(num_layers)])
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        patches = self.patch_embedding(x)
        _, _, depth, height, width = patches.shape
        pos_encoding = F.interpolate(self.position_encoding, size=(depth, height, width), mode='trilinear', align_corners=False)
        patches += pos_encoding
        h = torch.zeros_like(patches).to(patches.device)
        for unit in self.rvit_units:
            h = unit(patches, h)
        h_last = h.mean(dim=(2, 3, 4))
        logits = self.classifier(h_last)
        return logits, h

class RViTUnit(nn.Module):
    def __init__(self, hidden_dim, dropout_rate=0.5):
        super().__init__()
        self.attention_gate = LinearAttention(hidden_dim)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, h):
        attn_output = self.attention_gate(x, h)
        attn_output = self.dropout(attn_output)
        batch_size, hidden_dim, depth, height, width = attn_output.shape
        h_flat = h.permute(0, 2, 3, 4, 1).reshape(-1, hidden_dim)
        attn_output_flat = attn_output.permute(0, 2, 3, 4, 1).reshape(-1, hidden_dim)
        h_new_flat = self.layer_norm1(h_flat + attn_output_flat)
        h_new = h_new_flat.reshape(batch_size, depth, height, width, hidden_dim).permute(0, 4, 1, 2, 3)
        h_new_flat = h_new.permute(0, 2, 3, 4, 1).reshape(-1, hidden_dim)
        h_new_flat = self.layer_norm2(h_new_flat + self.ffn(h_new_flat))
        h_new = h_new_flat.reshape(batch_size, depth, height, width, hidden_dim).permute(0, 4, 1, 2, 3)
        return h_new

class LinearAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.Wq = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.Wk = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)
        self.Wv = nn.Conv3d(hidden_dim, hidden_dim, kernel_size=1)

    def forward(self, x, h):
        q = self.Wq(x).flatten(start_dim=2)
        k = self.Wk(h).flatten(start_dim=2)
        v = self.Wv(h).flatten(start_dim=2)
        attn_weights = torch.bmm(q.transpose(1, 2), k) / (k.size(1) ** 0.5)
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.bmm(v, attn_weights.transpose(1, 2))
        return attn_output.view_as(h)


In [17]:
import os
import cv2
from PIL import Image
from collections import Counter
from random import shuffle
from torch.utils.data import Dataset

class ActionImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Dataset for loading images grouped by classes into subfolders.
        Each subfolder represents a class.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data = []  # List to hold image paths
        self.labels = []  # List to hold corresponding class labels
        self.classes = sorted(os.listdir(root_dir))  # List of class names (subfolder names)

        for label, action in enumerate(self.classes):  # Iterate through subfolders
            action_dir = os.path.join(root_dir, action)
            if not os.path.isdir(action_dir):
                continue  # Skip non-folder files (if any)
            for img_file in os.listdir(action_dir):
                if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):  # Check for image files
                    img_path = os.path.join(action_dir, img_file)
                    self.data.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        label = self.labels[idx]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        if self.transform:
            img = self.transform(img)
        return img, label



class ActionVideoDatasetSingle(Dataset):
    def __init__(self, video_dir, num_frames=16, transform=None):
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.transform = transform
        self.classes = os.listdir(video_dir)
        self.video_files = []
        for label, cls in enumerate(self.classes):
            action_folder = os.path.join(video_dir, cls)
            videos = [f for f in os.listdir(action_folder) if f.endswith('.mp4')]
            shuffle(videos)
            if videos:
                self.video_files.append((os.path.join(action_folder, videos[0]), label))

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path, label = self.video_files[idx]
        cap = cv2.VideoCapture(video_path)
        frames = []
        while len(frames) < self.num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)
        cap.release()
        while len(frames) < self.num_frames:
            frames.append(torch.zeros_like(frames[0]))
        frames = torch.stack(frames)
        return frames, label


In [18]:
import os
import cv2
from PIL import Image
from collections import Counter
from random import shuffle
from torch.utils.data import Dataset

class ActionImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        Dataset for loading images directly from a folder containing images (no subfolders).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data = []  # List to hold image paths
        self.labels = []  # List to hold corresponding action labels (optional, can be all 0)

        # Single class definition for compatibility
        self.classes = ['action_sp']  # Define a single class name

        # Collect all .jpg files in the directory
        for img_file in os.listdir(root_dir):
            if img_file.endswith('.jpg'):
                img_path = os.path.join(root_dir, img_file)  # Full path to the image
                self.data.append(img_path)
                self.labels.append(0)  # Assign a single class label (e.g., 0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data[idx]
        label = self.labels[idx]  # Will always be 0 in this case
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
        img = Image.fromarray(img)  # Convert NumPy array to PIL image
        if self.transform:
            img = self.transform(img)
        return img, label



class ActionVideoDatasetSingle(Dataset):
    def __init__(self, video_dir, num_frames=16, transform=None):
        self.video_dir = video_dir
        self.num_frames = num_frames
        self.transform = transform
        self.classes = os.listdir(video_dir)
        self.video_files = []
        for label, cls in enumerate(self.classes):
            action_folder = os.path.join(video_dir, cls)
            videos = [f for f in os.listdir(action_folder) if f.endswith('.mp4')]
            shuffle(videos)
            if videos:
                self.video_files.append((os.path.join(action_folder, videos[0]), label))

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path, label = self.video_files[idx]
        cap = cv2.VideoCapture(video_path)
        frames = []
        while len(frames) < self.num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            if self.transform:
                frame = self.transform(frame)
            frames.append(frame)
        cap.release()
        while len(frames) < self.num_frames:
            frames.append(torch.zeros_like(frames[0]))
        frames = torch.stack(frames)
        return frames, label


In [14]:
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torch.utils.data import DataLoader, WeightedRandomSampler
from collections import Counter
import torch.optim as optim

# Transforms
transform = Compose([Resize((224, 224)), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Datasets
train_dataset = ActionImageDataset(root_dir='action_sp', transform=transform)
test_dataset = ActionVideoDatasetSingle(video_dir='resized_videos', transform=transform)

# Weighted Sampler
class_counts = Counter(train_dataset.labels)
class_weights = [1.0 / class_counts[label] for label in train_dataset.labels]
sampler = WeightedRandomSampler(weights=class_weights, num_samples=len(train_dataset), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=16, sampler=sampler)
test_loader = DataLoader(test_dataset, batch_size=1)

# Model, Loss, Optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RViT(num_classes=len(train_dataset.classes), hidden_dim=128, num_layers=4, frame_dim=(3, 224, 224)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.unsqueeze(2).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

# Save Pretrained Model
torch.save(model.state_dict(), "rvit_pretrained.pth")

# Fine-Tuning
model.load_state_dict(torch.load("rvit_pretrained.pth"))
model.train()
for param in model.parameters():
    param.requires_grad = True
for epoch in range(5):
    running_loss = 0.0
    for video_frames, labels in test_loader:
        video_frames = video_frames.permute(0, 2, 1, 3, 4).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(video_frames)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Fine-Tune Epoch {epoch+1}, Loss: {running_loss/len(test_loader)}")


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [19]:
train_dataset = ActionImageDataset(root_dir='action_sp', transform=transform)
print(f"Number of samples in dataset: {len(train_dataset)}")  # Should be > 0
print(f"Classes: {train_dataset.classes}")  # Should list all subfolder names
print(f"First few samples: {train_dataset.data[:5]}")  # Should show file paths
print(f"First few labels: {train_dataset.labels[:5]}")  # Should show class labels


Number of samples in dataset: 0
Classes: ['action_sp']
First few samples: []
First few labels: []
