# In [None]:
import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn as nn
import os
import time

# Short aliases for the MediaPipe Holistic solution and its drawing helpers
mp_holistic = mp.solutions.holistic
mp_drawing = mp.solutions.drawing_utils

# Constants
LABELS = ['g', 'h', 'l', 'x']  # The 4 classes of the model; indexed by the predicted class index

# Modellklasse definieren (identisch zum Training)
class HandSignNet(nn.Module):
    """Fully connected classifier for hand-sign feature vectors.

    The network expects 63 input features per sample and sends them through
    three ``Linear -> ReLU -> BatchNorm1d -> Dropout(0.3)`` stages with 128,
    256 and 128 units. A final linear layer maps the 128 features to
    ``num_classes`` raw scores; no softmax is applied.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        # Sizes of the consecutive linear layers: 63 -> 128 -> 256 -> 128.
        widths = (63, 128, 256, 128)

        # Build the stages in order so that the positions inside the
        # Sequential (and therefore the state_dict keys features.0 ...
        # features.11) stay the same as in a hand-written layer list.
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(0.3),
            ])
        self.features = nn.Sequential(*stages)

        # Classification head
        self.classifier = nn.Sequential(nn.Linear(widths[-1], num_classes))

    def forward(self, x):
        """Return the raw class scores for the input ``x``."""
        return self.classifier(self.features(x))

# Landmark-Extraktion für 63 Features
def extract_landmarks(results):
    """Turn the right-hand landmarks of a detection result into a flat vector.

    Args:
        results: Object exposing ``right_hand_landmarks``; when that attribute
            is truthy, its ``landmark`` sequence is read, each item providing
            ``x``, ``y`` and ``z``.

    Returns:
        A 1-D NumPy array of the landmark coordinates expressed relative to
        the first landmark (index 0, the wrist), or a zero vector of length
        63 when no right hand is present. The length 63 corresponds to
        21 landmarks x 3 coordinates.
    """
    right_hand = results.right_hand_landmarks
    if not right_hand:
        # No hand detected: hand back an all-zero feature vector.
        return np.zeros(63)

    # One row of (x, y, z) per landmark.
    coords = np.array([[point.x, point.y, point.z] for point in right_hand.landmark])

    # Subtracting the wrist row makes the features independent of where the
    # hand sits in the frame; the result is then flattened row by row.
    return (coords - coords[0]).flatten()

# Modell laden
def load_model(model_path):
    """Create a ``HandSignNet`` and fill it with weights stored at ``model_path``.

    The target device is CUDA when available, otherwise the CPU. The chosen
    device and the outcome of the load are printed.

    Args:
        model_path: Path of the saved state dict, passed to ``torch.load``.

    Returns:
        The network on the selected device and switched to eval mode, or
        ``None`` if reading or applying the weights raised an exception.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Gerät: {device}")

    net = HandSignNet()

    try:
        # map_location lets weights saved on a GPU be loaded on a CPU-only machine.
        state = torch.load(model_path, map_location=device)
        net.load_state_dict(state)
        print(f"Modell erfolgreich geladen von: {model_path}")
    except Exception as e:
        # Best effort: report the problem and let the caller decide what to do.
        print(f"Fehler beim Laden des Modells: {e}")
        return None

    # Both calls return the module itself, so they can be chained.
    return net.to(device).eval()

# Hauptfunktion
def main():
    """Run the live webcam demo for hand-sign recognition.

    Loads the trained network via ``load_model``, reads frames from webcam 0,
    detects hands with MediaPipe Holistic and draws both hands onto the frame.
    Every fifth frame in which a right hand is visible is classified; the
    label shown on screen is the most frequent one among the last five
    predictions. Pressing ``q`` in the preview window ends the loop.

    Returns:
        None. The function returns early (after printing a message) when the
        model cannot be loaded or the webcam cannot be opened.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter, deque

    # Model path
    model_path = "/home/geiger/asl_detection/machine_learning/models/asl_now/best_model.pth"
    model = load_model(model_path)
    if model is None:
        print("Konnte Modell nicht laden. Beende Programm.")
        return

    # load_model() may have moved the network to the GPU. Inputs must be
    # created on the same device, otherwise the forward pass fails with a
    # device-mismatch error.
    device = next(model.parameters()).device

    # Initialise the webcam
    cap = cv2.VideoCapture(0)

    # Text settings for the overlay
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.5
    font_thickness = 2

    # Frame counter and vote buffer for more stable predictions.
    # deque(maxlen=...) discards the oldest vote automatically.
    frame_counter = 0
    buffer_size = 5
    prediction_buffer = deque(maxlen=buffer_size)
    last_prediction = ""

    try:
        if not cap.isOpened():
            # Without this check the loop below would simply never run.
            print("Fehler: Webcam konnte nicht geöffnet werden.")
            return

        with mp_holistic.Holistic(
                min_detection_confidence=0.5,
                min_tracking_confidence=0.5) as holistic:

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    print("Fehler beim Erfassen des Frames")
                    break

                # MediaPipe expects RGB; marking the array read-only lets it
                # avoid an internal copy.
                image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image.flags.writeable = False

                # Detect keypoints
                results = holistic.process(image)

                # Back to writeable BGR for drawing and display
                image.flags.writeable = True
                image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

                # Draw hands only, since the model uses hand data exclusively
                mp_drawing.draw_landmarks(
                    image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                    mp_drawing.DrawingSpec(color=(121, 22, 76), thickness=2, circle_radius=4),
                    mp_drawing.DrawingSpec(color=(121, 44, 250), thickness=2, circle_radius=2)
                )
                mp_drawing.draw_landmarks(
                    image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS,
                    mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=4),
                    mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)
                )

                # Only classify every fifth frame, and only when the right
                # hand (the one the features are built from) is visible.
                frame_counter += 1
                if frame_counter % 5 == 0 and results.right_hand_landmarks:
                    landmarks = extract_landmarks(results)

                    with torch.no_grad():
                        landmarks_tensor = torch.as_tensor(
                            landmarks, dtype=torch.float32, device=device
                        ).unsqueeze(0)
                        outputs = model(landmarks_tensor)
                        predicted = outputs.argmax(dim=1)

                    # Record the vote and show the most frequent recent label
                    prediction_buffer.append(LABELS[predicted.item()])
                    last_prediction = Counter(prediction_buffer).most_common(1)[0][0]

                # Show the prediction on a filled box in the top-left corner
                cv2.rectangle(image, (0, 0), (200, 100), (245, 117, 16), -1)
                cv2.putText(image, last_prediction, (60, 60), font, font_scale,
                            (255, 255, 255), font_thickness, cv2.LINE_AA)

                # Show the result
                cv2.imshow('ASL Erkennung', image)

                # Quit when 'q' is pressed
                if cv2.waitKey(10) & 0xFF == ord('q'):
                    break
    finally:
        # Always free the camera and close the window, even if the loop raised.
        cap.release()
        cv2.destroyAllWindows()

# Start the webcam demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()