In [2]:
import cv2
import torch
import time
import numpy as np
import pandas as pd
from VideoLoader import KeypointExtractor, read_video
from VideoDataset import process_keypoints
from model import SLR

import torch._dynamo
torch._dynamo.config.suppress_errors = True

In [8]:

def record_video_tensor(prep_time=3, record_time=3, fps=30):
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise RuntimeError("Cannot access the webcam")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    print(f"Preparing to record at {width}x{height}, {fps} FPS...")

    # Preparation countdown
    start_prep = time.time()
    while time.time() - start_prep < prep_time:
        ret, frame = cap.read()
        if not ret:
            continue

        # Countdown overlay
        time_left = prep_time - int(time.time() - start_prep)
        cv2.putText(frame, f"Recording in {time_left}", (50, 100),
                    cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4)

        cv2.imshow("Preview", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            cap.release()
            cv2.destroyAllWindows()
            return None

    # Start recording
    print("Recording started!")
    num_frames = int(record_time * fps)
    frames = []
    start_record = time.time()

    while len(frames) < num_frames:
        ret, frame = cap.read()
        if not ret:
            break

        # Time remaining
        elapsed = time.time() - start_record
        remaining = max(0, record_time - elapsed)
        cv2.putText(frame, f"{remaining:.1f}s left", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)

        # Show frame
        cv2.imshow("Preview", frame)

        # Convert and store
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame_rgb)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        # FPS control
        expected = len(frames) / fps
        if expected > elapsed:
            time.sleep(expected - elapsed)

    cap.release()
    cv2.destroyAllWindows()

    video_np = np.stack(frames)
    video_tensor = torch.from_numpy(video_np).float()

    print(f"Captured video tensor with shape {video_tensor.shape}")
    return video_tensor


In [3]:
gloss_info = pd.read_csv('./gloss.csv')
idx_to_word = {}
word_to_idx = {}
for i in range(len(gloss_info)):
    idx_to_word[gloss_info['idx'][i]] = gloss_info['word'][i]
    word_to_idx[gloss_info['word'][i]] = gloss_info['idx'][i]

In [5]:
# from model import SLR
# # model = SLR(
# #     n_embd=12*64, 
# #     n_cls_dict={'asl_citizen':2305, 'lsfb': 4657, 'wlasl':2000, 'autsl':226, 'rsl':1001},
# #     n_head=12, 
# #     n_layer=4,
# #     n_keypoints=63,
# #     dropout=0.2, 
# #     max_len=64,
# #     bias=True
# # )

# # model = torch.compile(model)
# # model.load_state_dict(torch.load('./models/small_model.pth', map_location=torch.device('cpu')))


# # Run a bigger model. About 2.5x larger. Validation accuracy is about the same however

from model import SLR
model = SLR(
    n_embd=16*64, 
    n_cls_dict={'asl_citizen':2305, 'lsfb': 4657, 'wlasl':2000, 'autsl':226, 'rsl':1001},
    n_head=16, 
    n_layer=6,
    n_keypoints=63,
    dropout=0.6, 
    max_len=64,
    bias=True
)

model = torch.compile(model)
model.load_state_dict(torch.load('./models/big_model.pth', map_location=torch.device('cpu')))




# model.eval()
# print(f'Trainable parameters: {model.num_params()}')

<All keys matched successfully>

In [6]:

# Load a video or record it:

#video = record_video_tensor(fps=20, record_time=3)
video = read_video('./example3.mp4')
video = video.permute(0, 3, 1, 2)/255


In [7]:
from VideoLoader import KeypointExtractor
# Over here it runs the media pipe model. Perhaps the biggest bottle neck overall. 

pose = KeypointExtractor().extract(video)
height, width = video.shape[-2], video.shape[-1]
del video


In [8]:
selected_keypoints = list(range(42)) 
selected_keypoints = selected_keypoints + [x + 42 for x in ([291, 267, 37, 61, 84, 314, 310, 13, 80, 14] + [152])]
selected_keypoints = selected_keypoints + [x + 520 for x in ([2, 5, 7, 8, 11, 12, 13, 14, 15, 16])]


flipped_selected_keypoints = list(range(21, 42)) + list(range(21)) 
flipped_selected_keypoints = flipped_selected_keypoints + [x + 42 for x in ([61, 37, 267, 291, 314, 84, 80, 13, 310, 14] + [152])]
flipped_selected_keypoints = flipped_selected_keypoints + [x + 520 for x in ([5, 2, 8, 7, 12, 11, 14, 13, 16, 15])]



In [15]:
from VideoDataset import process_keypoints
import torch._dynamo
torch._dynamo.config.suppress_errors = True
# with augmentation to true, sample multiple frames and feed it to the model. take the average of the result.
# Since torch.compile is used, the model is compiled the first time it is ran. Running it afterwards will be faster.

sample_amount = 20 # Run the model 20 times

logits = 0
with torch.no_grad():
    model.eval()
    for i in range(sample_amount):
        keypoints, valid_keypoints = process_keypoints(pose, 64, selected_keypoints, height=height, width=width, augment=True)
        keypoints[:,:, 0] = keypoints[:,:, 0]
        logits = logits + model.heads['asl_citizen'](model(keypoints.unsqueeze(0), valid_keypoints.unsqueeze(0)))

idx = torch.argsort(logits, descending=True)[0].tolist()
idx[:10]

[1610, 1584, 2148, 138, 927, 414, 1366, 483, 972, 772]

In [14]:
sample_amount = 10
with torch.no_grad():
    model.eval()
    
    # Simple approach - collect individual samples
    keypoints_list = []
    valid_keypoints_list = []
    
    for i in range(sample_amount):
        # Get a single sample
        single_keypoints, single_valid_keypoints = process_keypoints(
            pose, 64, selected_keypoints, height=height, width=width, augment=True
        )
        
        # Add to list (not using subscript operations)
        keypoints_list.append(single_keypoints)
        valid_keypoints_list.append(single_valid_keypoints)
    
    # Stack when done
    keypoints_batch = torch.stack(keypoints_list)
    valid_keypoints_batch = torch.stack(valid_keypoints_list)
    
    # Process batch
    output_logits = model.heads['asl_citizen'](
        model(keypoints_batch, valid_keypoints_batch)
    )
    
    # Average logits
    logits = output_logits.mean(dim=0)
idx = torch.argsort(logits, descending=True)[0].tolist()
idx   

  keypoints = torch.tensor(keypoints[indices])


1584

In [16]:
print("Top 5 words")
print(', '.join([idx_to_word[idx[i]] for i in range(5)]))

Top 5 words
library, snorkel, moon, cross, i love you


In [11]:
idx.index(word_to_idx['hug']) # search for a word's idx

1796