In [1]:
import os
import random
import cv2
import mediapipe as mp
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# Paths to datasets (relative to the notebook's working directory)
# video_path is joined with the 'j' and 'z' sub-folders when the .avi clips are listed.
video_path = '../data/ZJ-videos'
# alphabet_path is joined with one sub-folder per letter when the .jpg images are listed.
alphabet_path = '../data/mnist_asl_alphabet_train'

In [3]:
# Number of frames every landmark sequence is padded or trimmed to.
# NOTE: the extraction functions in this notebook take their own default
# arguments with the same values rather than reading these globals.
selected_frame_dim = 180  # Example value
# Zero "frame" used for padding: (1 frame, 21 landmarks, 3 coordinates)
padding_value = torch.zeros((1, 21, 3), dtype=torch.float32)  # Padding value

# Target size for resizing frames and images
target_size = (224, 224)

# Label mapping from class name to integer class index.
# NOTE(review): index 2 is skipped ('nothing' -> 3). If the downstream
# classifier has exactly three outputs, a target of 3 would be out of range
# for an index-based loss -- confirm whether 'nothing' should be 2.
label_mapping = {'J': 0, 'Z': 1, 'nothing': 3}

In [4]:
# Initialize Mediapipe Hands
# One module-level detector is shared by every extraction function below.
# It tracks at most one hand and is created in video-stream mode
# (static_image_mode=False).
# NOTE(review): the same instance is also used for unrelated still images in
# extract_landmarks_from_image; MediaPipe documents static_image_mode=True as
# the setting for batches of unrelated images -- confirm this is intended.
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5)

I0000 00:00:1721164616.508004  783115 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1721164616.525374  783174 gl_context.cc:357] GL version: 3.1 (OpenGL ES 3.1 Mesa 23.2.1-1ubuntu3.1~22.04.2), renderer: D3D12 (NVIDIA GeForce RTX 3060)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [5]:
# Augment an image
def augment_image(image, flip=False):
    """Return a randomly rotated (and optionally mirrored) copy of ``image``.

    Args:
        image: Image array; only its first two dimensions (height, width)
            are read here.
        flip: When true, the rotated image is additionally flipped
            horizontally.

    Returns:
        The transformed image, with the same width and height as the input.
    """
    # Rotation angle drawn uniformly from [-15, 15]
    rotation_angle = random.uniform(-15, 15)

    rows, cols = image.shape[:2]
    center = (cols // 2, rows // 2)

    # Rotate about the image centre with scale factor 1 (no zoom)
    rotation_matrix = cv2.getRotationMatrix2D(center, rotation_angle, 1)
    result = cv2.warpAffine(image, rotation_matrix, (cols, rows))

    # Flip code 1 mirrors the image around its vertical axis
    return cv2.flip(result, 1) if flip else result

In [6]:
# Returns list of tuples of videos, with their corresponding label
def load_videos(path, label):
    """List the ``.avi`` files directly inside ``path`` with a class index.

    Args:
        path: Directory to scan (not recursive).
        label: Key into ``label_mapping``; its integer value is attached to
            every file found.

    Returns:
        A list of ``(file_path, label_index)`` tuples sorted by file path.
    """
    avi_paths = sorted(
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if entry.endswith('.avi')
    )
    return [(avi_path, label_mapping[label]) for avi_path in avi_paths]

# Returns list of tuples of images, with their corresponding label
def load_images(path, labels, samples_per_label=12):
    """Sample image files from per-label sub-folders, all labelled 'nothing'.

    Args:
        path: Root directory containing one sub-folder per entry in ``labels``.
        labels: Iterable of sub-folder names to sample from.
        samples_per_label: Maximum number of ``.jpg`` files drawn from each
            sub-folder (default 12). Folders holding fewer files contribute
            everything they have instead of raising.

    Returns:
        A list of ``(file_path, label_mapping['nothing'])`` tuples.
    """
    image_files = []
    for label in labels:
        label_dir = os.path.join(path, label)
        # os.listdir returns entries in arbitrary order; sorting first makes
        # the random draw repeatable when the random seed is fixed.
        files = sorted(
            os.path.join(label_dir, f)
            for f in os.listdir(label_dir)
            if f.endswith('.jpg')
        )
        # random.sample raises ValueError when asked for more items than
        # exist, so cap the request at the number of available files.
        files = random.sample(files, min(samples_per_label, len(files)))
        # Every sampled image is labelled 'nothing', regardless of its folder
        image_files.extend((f, label_mapping['nothing']) for f in files)
    return image_files

# Both loaders return a list of (file_path, label_index) tuples: [(path, label), (path, label), ...]

In [7]:
# Function to extract landmarks from a video file
def extract_landmarks_from_video(video_file, target_size=(224, 224), selected_frame_dim=180):
    """Run hand-landmark detection on every frame of a video.

    Each frame is resized to ``target_size``, converted from BGR to RGB and
    passed to the shared ``hands`` detector. A frame with a detection
    contributes a ``(1, 21, 3)`` tensor of ``[x, y, z]`` landmark values; a
    frame without one contributes an all-zero ``(1, 21, 3)`` tensor.

    Args:
        video_file: Path of the video to read.
        target_size: ``(width, height)`` each frame is resized to.
        selected_frame_dim: Number of frames the result is padded or trimmed to.

    Returns:
        ``(landmarks, corrupted)`` where ``landmarks`` is the float32 tensor
        produced by ``pad_sequence`` (``selected_frame_dim`` frames stacked
        along dim 0) and ``corrupted`` is True when no frame could be read or
        when at least one frame is entirely zero (no hand detected in it).
    """
    cap = cv2.VideoCapture(video_file)
    frames = []

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, target_size)  # Resize frame
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = hands.process(frame_rgb)

            if results.multi_hand_landmarks:
                for hand_landmarks in results.multi_hand_landmarks:
                    frame_landmarks = torch.tensor(
                        [[lm.x, lm.y, lm.z] for lm in hand_landmarks.landmark],
                        dtype=torch.float32,
                    )
                    frames.append(frame_landmarks.unsqueeze(0))  # Add a leading frame dimension
            else:
                # No hand detected: keep the time step, filled with zeros
                frames.append(torch.zeros((1, 21, 3), dtype=torch.float32))
    finally:
        # Release the capture even if resizing or detection raises
        cap.release()

    # A video is unusable when it has no frames or any frame lacks landmarks
    corrupted_video = len(frames) == 0 or not all(
        bool(torch.any(frame != 0)) for frame in frames
    )

    # Pad or trim frames to the selected frame dimension
    padded_frames_tensor = pad_sequence(
        frames,
        selected_frame_dim,
        padding_value=torch.zeros((1, 21, 3), dtype=torch.float32),
    )

    return padded_frames_tensor, corrupted_video

In [8]:
# Function to extract landmarks from an image
def extract_landmarks_from_image(image_file, target_size=(224, 224)):
    """Detect hand landmarks on a single (augmented) image.

    The image is read, resized to ``target_size``, passed through
    ``augment_image`` (random rotation), converted from BGR to RGB and handed
    to the shared ``hands`` detector.

    Args:
        image_file: Path of the image to read.
        target_size: ``(width, height)`` the image is resized to.

    Returns:
        ``(landmarks, corrupted)``:
        * ``(None, True)`` when the file cannot be read;
        * ``(zeros of shape (21, 3), True)`` when no hand is detected;
        * ``(float32 tensor of [x, y, z] rows for the first detected hand,
          False)`` otherwise.
    """
    image = cv2.imread(image_file)
    if image is None:
        return None, True  # Mark as corrupted if image cannot be read

    image = cv2.resize(image, target_size)  # Resize image
    augmented_image = augment_image(image)  # Apply augmentation
    image_rgb = cv2.cvtColor(augmented_image, cv2.COLOR_BGR2RGB)
    # NOTE(review): ``hands`` is created with static_image_mode=False (video
    # mode) -- confirm that is acceptable for unrelated still images.
    results = hands.process(image_rgb)

    if not results.multi_hand_landmarks:
        # Placeholder has the same (21, 3) layout as a successful detection,
        # so both branches return tensors of the same rank.
        return torch.zeros((21, 3), dtype=torch.float32), True

    # Only the first detected hand is used
    hand_landmarks = results.multi_hand_landmarks[0]
    landmarks = torch.tensor(
        [[lm.x, lm.y, lm.z] for lm in hand_landmarks.landmark],
        dtype=torch.float32,
    )
    return landmarks, False

In [9]:
# Function to pad or trim frames to a specified length
def pad_sequence(sequence, target_length, padding_value):
    """Concatenate frames along dim 0 after padding/trimming to a fixed count.

    Args:
        sequence: Iterable of tensors to concatenate along dim 0.
        target_length: Number of entries kept; shorter inputs are extended
            with ``padding_value``, longer inputs are cut off.
        padding_value: Tensor appended as many times as needed to reach
            ``target_length`` entries.

    Returns:
        ``torch.cat`` of the first ``target_length`` entries along dim 0.
    """
    entries = list(sequence)
    shortfall = target_length - len(entries)
    if shortfall > 0:
        # Top up with references to the same padding tensor
        entries.extend([padding_value] * shortfall)
    return torch.cat(entries[:target_length], dim=0)

In [10]:
def _collect_landmarks(files, extractor, desc):
    """Run ``extractor`` over ``(file, label)`` pairs and stack the usable results.

    Args:
        files: Iterable of ``(file_path, label)`` tuples.
        extractor: Callable returning ``(landmarks, is_corrupted)`` for a path.
        desc: Text shown on the tqdm progress bar.

    Returns:
        ``(data, labels)``: the kept landmark tensors stacked along a new
        leading dimension, and their labels as an int64 tensor. Items that
        are flagged corrupted or whose landmarks are ``None`` are dropped.
    """
    data = []
    labels = []
    for file_path, label in tqdm(files, desc=desc):
        landmarks, is_corrupted = extractor(file_path)
        if not is_corrupted and landmarks is not None:
            data.append(landmarks)
            labels.append(label)
    return torch.stack(data), torch.tensor(labels, dtype=torch.int64)


# Process videos
def process_videos(video_files):
    """Extract landmarks for each ``(video_file, label)`` pair.

    Returns the stacked landmark tensors of all non-corrupted videos together
    with an int64 tensor of their labels.
    """
    return _collect_landmarks(video_files, extract_landmarks_from_video, 'Processing videos')


# Process images
def process_images(image_files):
    """Extract landmarks for each ``(image_file, label)`` pair.

    Returns the stacked landmark tensors of all non-corrupted images together
    with an int64 tensor of their labels.
    """
    return _collect_landmarks(image_files, extract_landmarks_from_image, 'Processing images')

In [11]:
# Example usage: run the extractor on a single clip and report the outcome
video_file = '../data/ZJ-videos/j/24.avi'
landmarks, is_corrupted = extract_landmarks_from_video(video_file)

if not is_corrupted and landmarks is not None:
    # The extractor already returns a tensor; torch.tensor(<tensor>) would
    # copy-construct it and emit PyTorch's UserWarning, so cast with .to().
    landmarks_tensor = landmarks.to(torch.float32)
    print(f"Successfully extracted landmarks from video: {video_file}")
    print(f"Shape of landmarks tensor: {landmarks_tensor.shape}")
else:
    print(f"Video {video_file} is corrupted or has no valid landmarks.")

Video ../data/ZJ-videos/j/24.avi is corrupted or has no valid landmarks.


[mjpeg @ 0x8b54000] overread 8


In [12]:
# j/24.avi is flagged as corrupted by extract_landmarks_from_video (see the output above), so it is not a usable sample.

In [None]:
# Load and process datasets
def _images_to_sequences(image_data, num_frames):
    """Lift per-image landmarks (N, 21, 3) to sequences (N, num_frames, 21, 3).

    Each image becomes the first frame of a sequence that is zero-padded with
    ``pad_sequence``, the same helper the video extractor uses for padding.
    NOTE(review): repeating the single frame instead of zero-padding is an
    alternative encoding for static signs -- confirm which the model expects.
    """
    return torch.stack([
        pad_sequence([sample.unsqueeze(0)], num_frames, padding_value)
        for sample in image_data
    ])


# Seed Python's RNG, which drives the image sampling in load_images and the
# random rotation in augment_image.
random.seed(42)

video_files = load_videos(os.path.join(video_path, 'j'), 'J') + load_videos(os.path.join(video_path, 'z'), 'Z')
image_files = load_images(alphabet_path, list('ABCDEFGHIJKLMNOPQRSTUVWXYZ'))

train_videos, val_videos = train_test_split(video_files, test_size=0.2, random_state=42)
train_images, val_images = train_test_split(image_files, test_size=0.2, random_state=42)

train_video_data, train_video_labels = process_videos(train_videos)
val_video_data, val_video_labels = process_videos(val_videos)
train_image_data, train_image_labels = process_images(train_images)
val_image_data, val_image_labels = process_images(val_images)

# Video samples are 4-D (N, frames, 21, 3) while image samples are 3-D
# (N, 21, 3); torch.cat requires matching dimensions, so convert the images
# into sequences with the same number of frames as the videos.
train_image_sequences = _images_to_sequences(train_image_data, train_video_data.shape[1])
val_image_sequences = _images_to_sequences(val_image_data, val_video_data.shape[1])

train_data = torch.cat((train_video_data, train_image_sequences), dim=0)
train_labels = torch.cat((train_video_labels, train_image_labels), dim=0)
val_data = torch.cat((val_video_data, val_image_sequences), dim=0)
val_labels = torch.cat((val_video_labels, val_image_labels), dim=0)

Processing videos:   1%|▌                                                               | 5/569 [00:04<08:17,  1.13it/s][mjpeg @ 0x9229680] overread 8
Processing videos:   1%|▉                                                               | 8/569 [00:06<07:19,  1.28it/s][mjpeg @ 0x9305b40] overread 8
Processing videos:   2%|█▎                                                             | 12/569 [00:09<07:34,  1.22it/s][mjpeg @ 0x8b71e40] overread 8
Processing videos:   2%|█▌                                                             | 14/569 [00:10<07:03,  1.31it/s][mjpeg @ 0x9311180] overread 8
Processing videos:   5%|███▏                                                           | 29/569 [00:25<08:49,  1.02it/s][mjpeg @ 0x9342840] overread 8
Processing videos:   5%|███▍                                                           | 31/569 [00:26<07:31,  1.19it/s][mjpeg @ 0x91fd640] overread 8
Processing videos:   7%|████                                                           | 37/56

In [None]:
# Save the data
# Each split is written to the current working directory as one
# (data, labels) tuple, which is the layout the later loading cell unpacks.
torch.save((train_data, train_labels), 'train_data.pt')
torch.save((val_data, val_labels), 'val_data.pt')

In [None]:
# Visualize train data
def visualize_data(data, labels, num_examples=50):
    """Plot the hand landmarks of the first frames of one random sample.

    Args:
        data: Indexable collection of samples; each sample is indexed as
            ``[frame, landmark, coordinate]`` with x at index 0 and y at 1.
        labels: Per-sample labels; ``labels[i].item()`` is printed and shown
            in the figure title.
        num_examples: Maximum number of leading frames to plot (default 50).

    The grid has 10 columns and as many rows as the plotted frames need, so
    values other than 50 are neither truncated nor padded with empty axes
    that still show ticks.
    """
    # Select a random example from the data
    random_idx = random.randint(0, data.shape[0] - 1)
    selected_data = data[random_idx][:num_examples]  # Take the first num_examples frames
    selected_label = labels[random_idx].item()

    # Print the respective label
    print(f"Label for selected example: {selected_label}")

    # Size the grid from the number of frames actually available
    num_frames = selected_data.shape[0]
    num_cols = 10
    num_rows = max(1, -(-num_frames // num_cols))  # ceiling division, at least one row

    # squeeze=False keeps ``axes`` two-dimensional even for a single row
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(20, 2 * num_rows), squeeze=False)
    fig.suptitle(f'Hand Landmarks for First {num_examples} Frames - Label: {selected_label}', fontsize=16)

    for i, ax in enumerate(axes.flat):
        if i >= num_frames:
            # Hide the unused cells of the last row
            ax.axis('off')
            continue
        frame = selected_data[i]
        ax.scatter(frame[:, 0], frame[:, 1], c='b', marker='o')
        # Annotate every landmark with its index
        for j in range(frame.shape[0]):
            ax.text(float(frame[j, 0]), float(frame[j, 1]), str(j), fontsize=9)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.invert_yaxis()  # Image coordinates grow downwards
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'Frame {i+1}')

    plt.tight_layout()
    plt.show()

# Visualize train data
visualize_data(train_data, train_labels)

In [None]:
import random

In [None]:
def load_data(filename):
    """Return the object that ``torch.load`` deserializes from ``filename``.

    NOTE(review): ``torch.load`` unpickles the file, so only use this on
    files you produced yourself or otherwise trust.
    """
    return torch.load(filename)

# Load data
# Each file is unpacked as a (data, labels) pair, matching how the splits
# were saved earlier in this notebook.
train_data, train_labels = load_data('train_data.pt')
val_data, val_labels = load_data('val_data.pt')

In [None]:
# Select a random example from the reloaded training data and plot the hand
# landmarks of its first 50 frames. This calls visualize_data (defined in an
# earlier cell) instead of repeating its body inline.
visualize_data(train_data, train_labels, num_examples=50)

In [None]:
# Debugging functions
# NOTE(review): this reuses the name ``extract_landmarks_from_video`` from an
# earlier cell with a different signature and return value. Once this cell
# has run, process_videos (which calls that name) receives (frames, landmarks)
# lists instead of (tensor, corrupted flag) -- consider renaming this variant.
def extract_landmarks_from_video(video_file, target_size=(224, 224), max_frames=50):
    """Collect the resized frames and per-frame landmarks of a video's start.

    Args:
        video_file: Path of the video to read.
        target_size: ``(width, height)`` each frame is resized to.
        max_frames: Maximum number of frames to read.

    Returns:
        ``(frames, landmarks)``: the resized BGR frames, and for each frame a
        list of ``[x, y, z]`` landmark values (21 ``[0, 0, 0]`` rows when no
        hand is detected).
    """
    cap = cv2.VideoCapture(video_file)
    frames = []
    landmarks = []
    frame_count = 0
    try:
        while cap.isOpened() and frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.resize(frame, target_size)  # Resize frame
            frames.append(frame)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = hands.process(frame_rgb)
            if results.multi_hand_landmarks:
                for hand_landmarks in results.multi_hand_landmarks:
                    frame_landmarks = [[lm.x, lm.y, lm.z] for lm in hand_landmarks.landmark]
                    landmarks.append(frame_landmarks)
            else:
                landmarks.append([[0, 0, 0]] * 21)  # If no hand detected, append zero landmarks
            frame_count += 1
    finally:
        # Release the capture even if resizing or detection raises
        cap.release()
    return frames, landmarks

# Function to normalize landmarks based on image dimensions
def normalize_landmarks(landmarks, image_width, image_height):
    """Scale landmark x/y values by the image width/height.

    Args:
        landmarks: Sequence of frames, each a sequence of points indexable as
            ``point[0]`` (x), ``point[1]`` (y) and ``point[2]`` (z).
        image_width: Factor applied to every x value.
        image_height: Factor applied to every y value.

    Returns:
        A new nested list with the same frame/point structure holding
        ``[x * image_width, y * image_height, z]``; z is passed through
        unchanged.
    """
    return [
        [
            [point[0] * image_width, point[1] * image_height, point[2]]
            for point in frame_points
        ]
        for frame_points in landmarks
    ]

# Function to visualize frames
def visualize_frames(frames):
    """Show up to 50 frames on a 5x10 grid of subplots.

    Args:
        frames: Sequence of BGR images; each is converted to RGB for display.
            Frames beyond the 50th are not drawn.
    """
    fig, axes = plt.subplots(5, 10, figsize=(20, 10))
    fig.suptitle('Frames', fontsize=16)

    # zip stops at whichever runs out first: the 50 axes or the frames
    for i, (ax, bgr_frame) in enumerate(zip(axes.flat, frames)):
        rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
        ax.imshow(rgb_frame)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'Frame {i+1}')

    plt.tight_layout()
    plt.show()

# Function to visualize landmarks separately
def visualize_landmarks(frames, landmarks):
    """Scatter-plot per-frame landmarks in pixel coordinates on a 5x10 grid.

    Args:
        frames: Sequence of images; only the shape of ``frames[0]`` is read,
            to obtain the image height and width.
        landmarks: Per-frame lists of ``[x, y, z]`` points, scaled to pixel
            coordinates with ``normalize_landmarks`` before plotting.
    """
    image_height, image_width, _ = frames[0].shape
    pixel_landmarks = normalize_landmarks(landmarks, image_width, image_height)

    fig, axes = plt.subplots(5, 10, figsize=(20, 10))
    fig.suptitle('Hand Landmarks for First 50 Frames', fontsize=16)

    # zip stops at whichever runs out first: the 50 axes or the frames
    for i, (ax, points) in enumerate(zip(axes.flat, pixel_landmarks)):
        xs = [point[0] for point in points]
        ys = [point[1] for point in points]
        ax.scatter(xs, ys, c='b', marker='o')
        # Annotate every landmark with its index
        for j, point in enumerate(points):
            ax.text(point[0], point[1], str(j), fontsize=9)
        ax.set_xlim(0, image_width)
        ax.set_ylim(0, image_height)
        ax.invert_yaxis()
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f'Frame {i+1}')

    plt.tight_layout()
    plt.show()

# Path to the video file
video_file_path = '../data/ZJ-videos/j/25.avi'  # Update this path

# Extract and visualize the first 50 frames and landmarks
# (this calls the debugging variant defined in this cell, which returns the
# resized frames together with plain-list landmarks)
frames, landmarks = extract_landmarks_from_video(video_file_path)

# Visualize frames
visualize_frames(frames)

# Visualize landmarks
# NOTE: visualize_landmarks reads frames[0], so it needs at least one frame.
visualize_landmarks(frames, landmarks)

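# MediaPipe landmark x and y coordinates are normalized to [0, 1] by the image width and height, which is why they are multiplied by the frame size above before being plotted in pixel space.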