In [None]:
import cv2
import mediapipe as mp
import torch
import numpy as np
import json
import time

# --- PATHS ---
# Directory that holds the trained LSTM artifacts; both files below live in it.
base_path = "/workspaces/asl_detection/machine_learning/models/lstm"
# Serialized model weights, read with torch.load further down.
MODEL_PATH = base_path + "/best_lstm_model.pth"
# JSON mapping from label name to class index.
LABELS_PATH = base_path + "/label_to_index.json"

# --- Initialize Mediapipe ---
# Holistic solution module: supplies the Holistic tracker class and the
# POSE_CONNECTIONS / HAND_CONNECTIONS constants used when drawing below.
mp_holistic = mp.solutions.holistic
# Drawing helpers used to overlay the detected landmarks on each video frame.
mp_drawing = mp.solutions.drawing_utils

# --- Define LSTM Model ---
class LSTMModel(torch.nn.Module):
    """Sequence classifier: stacked LSTM, mean over time, then a linear head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the hidden state of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of classes scored by the linear head.
    """

    def __init__(self, input_size=225, hidden_size=128, num_layers=2, output_size=209):
        super().__init__()
        # The attribute names double as state_dict key prefixes ("lstm.*",
        # "fc.*"), so they must stay as they are for checkpoints to load.
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.1,
        )
        self.fc = torch.nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return per-class log-probabilities for a batch-first input ``x``."""
        hidden_states = self.lstm(x)[0]
        # Average the per-step outputs over dim 1 (the time axis, because the
        # LSTM is batch-first) instead of using only the last step.
        pooled = torch.mean(hidden_states, dim=1)
        logits = self.fc(pooled)
        # Log-softmax rather than softmax for numerical stability.
        return torch.nn.functional.log_softmax(logits, dim=1)

# --- Load Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=device)

    # Size the network from the checkpoint instead of hard-coding it: a model
    # built with different hidden_size / num_layers than the saved weights
    # fails in load_state_dict with size-mismatch / unexpected-key errors.
    # For torch.nn.LSTM, weight_ih_l0 is (4 * hidden, input) and weight_hh_l0
    # is (4 * hidden, hidden); there is one weight_hh_l{k} entry per layer.
    # Assumes the file is a plain state_dict with "lstm.*" and "fc.*" keys, as
    # LSTMModel produces; any other layout raises here and is reported below.
    lstm_layer_keys = [
        key for key in checkpoint
        if key.startswith("lstm.weight_hh_l") and not key.endswith("_reverse")
    ]
    model = LSTMModel(
        input_size=checkpoint["lstm.weight_ih_l0"].shape[1],
        hidden_size=checkpoint["lstm.weight_hh_l0"].shape[1],
        num_layers=len(lstm_layer_keys),
        output_size=checkpoint["fc.weight"].shape[0],
    ).to(device)
    # NOTE(review): this only guarantees matching tensor shapes; confirm that
    # LSTMModel.forward matches the forward pass used during training.
    model.load_state_dict(checkpoint)
    model.eval()  # inference mode: disables the LSTM's inter-layer dropout
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Model loading error: {e}")
    # Raise instead of calling exit(): exit() reports status 0 when run as a
    # script, and inside a notebook it does not appear to stop the rest of the
    # cell, which would go on to run without trained weights.
    raise SystemExit(1) from e

# --- Load Labels ---
try:
    # The file maps label name -> class index. Read it as UTF-8 explicitly so
    # the result does not depend on the platform's default encoding.
    with open(LABELS_PATH, "r", encoding="utf-8") as f:
        label_to_index = json.load(f)
    # Inverted view (class index -> label name) for decoding predictions.
    index_to_label = {v: k for k, v in label_to_index.items()}
    print(f"✅ {len(label_to_index)} Labels loaded!")
except Exception as e:
    print(f"❌ Error loading labels: {e}")
    # Raise instead of calling exit(): exit() reports status 0 when run as a
    # script, and inside a notebook it does not appear to stop the rest of the
    # cell from executing.
    raise SystemExit(1) from e

# --- Extract Keypoints ---
def extract_keypoints(results):
    """Flatten the pose and hand landmarks of one frame into a single vector.

    Args:
        results: Object exposing ``pose_landmarks``, ``left_hand_landmarks``
            and ``right_hand_landmarks``. Each is either falsy (nothing
            detected) or has a ``landmark`` iterable of points with ``x``,
            ``y`` and ``z`` attributes.

    Returns:
        A 1-D array: pose coordinates, then left hand, then right hand. A
        missing group is replaced by zeros (33 * 3 values for the pose,
        21 * 3 for each hand). If anything goes wrong, the error is printed
        and an all-zero vector of length 225 is returned.
    """
    # (attribute on ``results``, landmark count used for the zero fallback)
    layout = (
        ("pose_landmarks", 33),
        ("left_hand_landmarks", 21),
        ("right_hand_landmarks", 21),
    )
    try:
        parts = []
        for attr_name, landmark_total in layout:
            group = getattr(results, attr_name)
            if group:
                coords = np.array([[p.x, p.y, p.z] for p in group.landmark]).flatten()
            else:
                coords = np.zeros(landmark_total * 3)
            parts.append(coords)
        return np.concatenate(parts)
    except Exception as err:
        print(f"❌ Error extracting keypoints: {err}")
        return np.zeros(225)

# --- Pad sequence if movement is shorter than required frames ---
def pad_sequence(sequence, target_length=64):
    """Return a copy of ``sequence`` padded to ``target_length`` items.

    The last item is repeated until the result has ``target_length`` entries.
    A sequence that is already long enough is returned as an unmodified copy
    (it is never truncated).

    The input is not modified. Padding it in place would fill a caller's
    rolling frame buffer with duplicated frames that then take part in every
    later prediction.

    Args:
        sequence: Frames (any items) to pad.
        target_length: Minimum length of the returned list.

    Returns:
        A new list with at least ``target_length`` items.

    Raises:
        IndexError: If ``sequence`` is empty but padding is required, because
            there is no last item to repeat.
    """
    padded = list(sequence)
    missing = target_length - len(padded)
    if missing <= 0:
        return padded
    if not padded:
        raise IndexError("cannot pad an empty sequence: no last frame to repeat")
    padded.extend([padded[-1]] * missing)
    return padded

# --- Capture from Webcam ---
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    print("❌ Error: Webcam could not be opened!")
    cap.release()
    # Raise instead of calling exit(): exit() reports status 0 when run as a
    # script, and inside a notebook it does not appear to stop the cell.
    raise SystemExit(1)

sequence = []  # rolling window holding the most recent per-frame keypoint vectors
frame_count = 64  # Default sequence length for prediction
MIN_FRAMES = 20  # do not predict before this many frames have been collected
PREDICTION_INTERVAL_S = 1.5  # minimum pause between two predictions, in seconds
# A monotonic clock is used for the interval so that wall-clock adjustments
# cannot stall or burst the predictions.
last_prediction_time = time.monotonic()
predicted_word = "Waiting..."

try:
    # Low confidence thresholds: trade landmark precision for performance.
    with mp_holistic.Holistic(min_detection_confidence=0.3, min_tracking_confidence=0.3) as holistic:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                print("❌ Error: No image received from the camera!")
                break

            # OpenCV delivers BGR; the frame is converted to RGB before it is
            # handed to Mediapipe. Drawing happens on the original BGR frame,
            # so no conversion back is needed.
            results = holistic.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image = frame

            # Extract keypoints and store them in the rolling window
            sequence.append(extract_keypoints(results))
            if len(sequence) > frame_count:
                sequence.pop(0)

            # Predict once enough frames exist and the interval has elapsed
            now = time.monotonic()
            if len(sequence) >= MIN_FRAMES and (now - last_prediction_time) > PREDICTION_INTERVAL_S:
                # Pad a copy: the live window must keep only real frames, so
                # it is never handed to pad_sequence directly.
                padded_sequence = pad_sequence(list(sequence), frame_count)
                # Stack into one (1, frames, features) array first; building a
                # tensor from a list of ndarrays is much slower.
                batch = np.stack(padded_sequence)[np.newaxis, ...]
                input_tensor = torch.from_numpy(batch).to(device=device, dtype=torch.float32)
                with torch.no_grad():
                    output = model(input_tensor)
                pred_label = torch.argmax(output, dim=1).item()
                predicted_word = index_to_label.get(pred_label, "Unknown")
                last_prediction_time = now

            # Display recognized word and frame count
            cv2.putText(image, f"Recognized: {predicted_word}", (10, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
            cv2.putText(image, f"Frames: {len(sequence)}/{frame_count}", (10, 80),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

            # Draw detected landmarks
            mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)
            mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
            mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)

            # On-screen hint for the reset key
            cv2.putText(image, "Press 'r' to reset frames", (10, 140),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

            # Display video feed; 'q' quits, 'r' clears the collected frames
            cv2.imshow("ASL Test Mode", image)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            elif key == ord('r'):
                sequence.clear()  # Reset sequence
                print("🔄 Frames reset!")
finally:
    # Always free the camera and close the window, even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()


❌ Model loading error: Error(s) in loading state_dict for LSTMModel:
	Unexpected key(s) in state_dict: "lstm.weight_ih_l2", "lstm.weight_hh_l2", "lstm.bias_ih_l2", "lstm.bias_hh_l2". 
	size mismatch for lstm.weight_ih_l0: copying a param with shape torch.Size([1024, 225]) from checkpoint, the shape in current model is torch.Size([512, 225]).
	size mismatch for lstm.weight_hh_l0: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for lstm.bias_ih_l0: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for lstm.bias_hh_l0: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for lstm.weight_ih_l1: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for lstm.weight_hh_l1: co

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1741548050.993450   47846 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1741548051.020329   47846 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1741548051.022887   47852 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1741548051.023242   47855 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1741548051.023663   47847 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1741548051.030334   47

: 