In [4]:
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.io import read_video
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

# Define Dataset Class
class YogaVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.video_files = [f for f in os.listdir(root_dir) if f.endswith('.mp4')]  # Update file extension if needed
        
        if not self.video_files:
            print("No video files found in the directory.")
        else:
            print(f"Found {len(self.video_files)} video files.")

        self.label_map = {'downward_dog': 0, 'tree_pose': 1, 'warrior_pose': 2, 'bridge_pose': 3, 'plank': 4}  # Update as per your labels

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path = os.path.join(self.root_dir, self.video_files[idx])
        video, _, info = read_video(video_path, pts_unit='sec')
        label = self.label_map[self.video_files[idx].split('_')[0]]  # Assuming label is part of the filename

        if self.transform:
            video = self.transform(video)
        
        return video, label

# Transformations to apply to each video (resize and normalize)
frame_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize frames
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalization (example)
])

# Define the Dataset and DataLoader
dataset = YogaVideoDataset(root_dir='D:/mini project NN/dataset/bhujangasan', transform=frame_transform) 
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Check the video shape and labels
for batch_idx, (videos, labels) in enumerate(dataloader):
    print(f"Video batch shape: {videos.shape}")
    print(f"Label batch: {labels}")
    
    # Visualizing a few frames from the batch (only display 5 frames to avoid memory overload)
    num_frames_to_show = min(videos.size(1), 5)  # Ensure we don't try to plot more than available frames
    
    fig, axes = plt.subplots(1, num_frames_to_show)  # Ensure no out-of-bounds indexing
    for i in range(num_frames_to_show):  # Iterate over the number of frames in the batch
        frame = videos[0, i].permute(1, 2, 0).numpy()  # Select one video (0th index) and convert to numpy
        axes[i].imshow(frame.astype('uint8'))  # Display the frame as an image
        axes[i].axis('off')
    
    plt.show()

    break  # Stop after one batch for testing


Found 18 video files.


KeyError: 'Copy of VID20250127142108.mp4'