In [None]:
import json
import os

import cv2
import mediapipe as mp
import numpy as np
import torch

In [None]:
# Directory containing the trained LSTM weights and the label mapping.
# Set the ASL_LSTM_MODEL_DIR environment variable to run outside the dev container;
# without it, the original hard-coded location is used.
base_path = os.environ.get(
    "ASL_LSTM_MODEL_DIR",
    "/workspaces/asl_detection/machine_learning/models/lstm",
)
MODEL_PATH = os.path.join(base_path, "lstm_model.pth")
LABELS_PATH = os.path.join(base_path, "label_to_index.json")

# MediaPipe Holistic solution (pose + both hands) and its landmark-drawing helpers.
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils

In [None]:
class LSTMModel(torch.nn.Module):
    """Sequence classifier: stacked LSTM, mean-pooled over time, then a linear head.

    The submodule attribute names ``lstm`` and ``fc`` determine the parameter keys
    that ``load_state_dict`` (strict by default) matches against, so they must stay
    in sync with the saved checkpoint.

    Args:
        input_size: Features per time step (defaults to 225 keypoint values).
        hidden_size: LSTM hidden-state width.
        num_layers: Number of stacked LSTM layers (dropout 0.2 between them).
        output_size: Number of output classes.
    """

    def __init__(self, input_size=225, hidden_size=256, num_layers=3, output_size=168):
        super().__init__()
        self.lstm = torch.nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=0.2,
        )
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return class logits of shape (batch, output_size) for input (batch, seq, input_size)."""
        # batch_first=True, so the LSTM's per-step outputs are (batch, seq, hidden_size).
        per_step, _state = self.lstm(x)
        # Average over the time axis before classifying.
        pooled = per_step.mean(dim=1)
        return self.fc(pooled)

# Prefer the GPU when present; otherwise run inference on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network on the target device, restore the trained weights, and switch
# to eval mode (disables the LSTM's inter-layer dropout).
model = LSTMModel()
model.to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

# Load the word -> class-index mapping once (it was previously loaded twice), and
# build the reverse lookup used to turn the model's argmax index back into a word.
# JSON is UTF-8 by specification; be explicit so labels with non-ASCII characters
# decode the same on every platform.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    label_to_index = json.load(f)
index_to_label = {index: label for label, index in label_to_index.items()}

In [None]:
# Ask which word to test; predictions matching it are drawn in green by the capture loop.
test_word = input("Welches Wort willst du testen? ").strip().lower()
if test_word not in label_to_index:
    print(f"❌ Fehler: '{test_word}' existiert nicht in den Labels!")
    # exit() is meant for the interactive shell. In a Jupyter kernel it asks the kernel
    # to shut down instead of stopping this cell, and the next line would still run
    # (KeyError). SystemExit halts the cell / "Run All" cleanly in both a notebook and a script.
    raise SystemExit(1)

test_label_index = label_to_index[test_word]

In [None]:
def extract_keypoints(results):
    """Flatten MediaPipe Holistic landmarks into one fixed-length feature vector.

    The output concatenates pose (33 points), left hand (21 points) and right hand
    (21 points), each point as (x, y, z), for 225 values in total. A missing body
    part contributes zeros, so the length stays constant.

    Args:
        results: Object exposing ``pose_landmarks``, ``left_hand_landmarks`` and
            ``right_hand_landmarks``. Each is falsy when not detected, otherwise has
            a ``landmark`` iterable of points with ``x``, ``y``, ``z``.

    Returns:
        1-D numpy array: pose, then left hand, then right hand coordinates.
    """
    def _flatten(landmark_list, point_count):
        # Zero-fill a part that was not detected so downstream shapes are stable.
        if not landmark_list:
            return np.zeros(point_count * 3)
        return np.array([(p.x, p.y, p.z) for p in landmark_list.landmark]).flatten()

    parts = (
        (results.pose_landmarks, 33),
        (results.left_hand_landmarks, 21),
        (results.right_hand_landmarks, 21),
    )
    return np.concatenate([_flatten(lms, count) for lms, count in parts])

In [None]:
# Number of consecutive frames fed to the model per prediction (sliding window).
# NOTE(review): presumably this must equal the sequence length used in training — confirm.
frame_count = 102

cap = cv2.VideoCapture(0)
sequence = []
last_recognized_word = None  # log only when the prediction changes, not every frame

if not cap.isOpened():
    print("❌ Fehler: Kamera konnte nicht geöffnet werden!")

try:
    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # MediaPipe expects RGB. Marking the buffer read-only lets it be processed
            # by reference (as in MediaPipe's own examples).
            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image.flags.writeable = False
            results = holistic.process(image)
            image.flags.writeable = True
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

            # Keep a sliding window of the most recent `frame_count` keypoint vectors.
            sequence.append(extract_keypoints(results))
            if len(sequence) > frame_count:
                sequence.pop(0)

            if len(sequence) == frame_count:
                # Stack into a single (frame_count, features) float32 array first:
                # building a tensor from a list of ndarrays is very slow. unsqueeze(0)
                # adds the batch axis expected by the batch_first LSTM.
                window = np.asarray(sequence, dtype=np.float32)
                input_tensor = torch.from_numpy(window).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(input_tensor)
                # softmax is monotonic, so argmax over the raw logits picks the same class.
                pred_label = output.argmax(dim=1).item()
                recognized_word = index_to_label.get(pred_label, "Unbekannt")

                # Green when the prediction matches the word under test, red otherwise (BGR).
                color = (0, 255, 0) if pred_label == test_label_index else (0, 0, 255)
                cv2.putText(image, recognized_word, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)
                if recognized_word != last_recognized_word:
                    print(f"🔍 Erwartet: {test_word} | Erkannt: {recognized_word}")
                    last_recognized_word = recognized_word

            mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)
            mp_drawing.draw_landmarks(image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
            mp_drawing.draw_landmarks(image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)

            cv2.imshow("ASL Testmodus", image)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
finally:
    # Always free the webcam and close the window, even on an exception or a
    # KeyboardInterrupt (the notebook's "interrupt kernel"); otherwise the camera
    # stays locked until the kernel restarts.
    cap.release()
    cv2.destroyAllWindows()