# In [None]:
import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn as nn
import time
from collections import Counter

# Define the model class (identical to training)
class HandSignNet(nn.Module):
    """Fully connected classifier for flattened hand landmarks.

    The network takes 63 input features and returns ``num_classes`` raw
    logits. It consists of three identical stages
    (Linear -> ReLU -> BatchNorm1d -> Dropout(0.3)) with output widths
    128, 256 and 128, followed by a single linear classification layer.

    The sub-module names (``features.0``, ``features.2``, ...,
    ``classifier.0``) are part of the checkpoint format, so the order in
    which layers are appended below must not change.
    """

    def __init__(self, num_classes=24):
        super().__init__()

        # Input width followed by the output width of every hidden stage.
        widths = (63, 128, 256, 128)

        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
                nn.Dropout(0.3),
            ]

        # Feature extraction stack
        self.features = nn.Sequential(*stages)

        # Classification head producing one logit per class
        self.classifier = nn.Sequential(nn.Linear(widths[-1], num_classes))

    def forward(self, x):
        """Return the class logits for a batch ``x`` of 63-feature rows."""
        return self.classifier(self.features(x))

# Landmark extraction for 63 features (exactly like in training)
def extract_landmarks(results):
    """
    Build the 63-value feature vector for the signing hand.

    The first detected hand whose handedness label is "Right" is used.
    If that yields nothing (the vector is still all zeros), the first
    detected hand is used instead. Without any detected hand a zero
    vector of length 63 is returned.

    NOTE(review): whether "Right" means the signer's right hand depends on
    whether the input image is mirrored -- confirm against the training
    pipeline, which this function is meant to replicate.
    """
    def _flatten(hand):
        # One (x, y, z) row per landmark, flattened into a 1-D vector
        return np.array([[pt.x, pt.y, pt.z] for pt in hand.landmark]).flatten()

    # Default result: 21 keypoints with x, y, z, all zero
    keypoints = np.zeros(21 * 3)

    detected = results.multi_hand_landmarks
    if not detected:
        return keypoints

    # Prefer the hand labelled "Right"
    for idx, hand in enumerate(detected):
        label = results.multi_handedness[idx].classification[0].label
        if label == "Right":
            keypoints = _flatten(hand)
            break

    # Nothing usable so far: fall back to the first detected hand
    if np.all(keypoints == 0):
        keypoints = _flatten(detected[0])

    return keypoints

def preprocess_image(image):
    """
    Pass-through preprocessing hook.

    No transformation is applied: the very same object that was passed
    in is handed back to the caller.
    """
    return image

class ASLPredictor:
    """Detects a hand in video frames and overlays the predicted ASL letter.

    Each frame is run through MediaPipe Hands; the landmarks of the signing
    hand are classified by a ``HandSignNet`` and the per-frame predictions
    are smoothed with a majority vote over the most recent frames.
    """

    def __init__(self,
                 model_path='/workspaces/asl_detection/machine_learning/models/asl_now/best_model.pth',
                 buffer_size=5,
                 min_agreement=0.6):
        """
        Args:
            model_path: Path to the ``state_dict`` checkpoint of the model.
            buffer_size: Number of recent per-frame predictions that take
                part in the majority vote (must be at least 1).
            min_agreement: The label is only drawn when the winning letter
                holds strictly more than this fraction of the buffer.

        Raises:
            ValueError: If ``buffer_size`` is smaller than 1.
        """
        if buffer_size < 1:
            raise ValueError("buffer_size must be at least 1")

        # Initialize MediaPipe Hands (like in training)
        self.mp_hands = mp.solutions.hands
        self.mp_drawing = mp.solutions.drawing_utils
        self.hands = self.mp_hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            min_detection_confidence=0.2,
            min_tracking_confidence=0.2)

        # The drawing styles never change, so build them once instead of
        # once per hand and frame.
        self._landmark_style = self.mp_drawing.DrawingSpec(
            color=(245, 117, 66), thickness=2, circle_radius=4)
        self._connection_style = self.mp_drawing.DrawingSpec(
            color=(245, 66, 230), thickness=2, circle_radius=2)

        # Load model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Letters mapping (all letters except j and z)
        self.letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'k', 'l', 'm',
                        'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y']

        # One output logit per entry of the letter table.
        self.model = HandSignNet(num_classes=len(self.letters)).to(self.device)
        # NOTE: torch.load unpickles the checkpoint, which can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()
        print("Model loaded successfully!")

        # Buffer for more stable predictions
        self.prediction_buffer = []
        self.buffer_size = buffer_size
        self.min_agreement = min_agreement
        self.last_prediction = ""

    def close(self):
        """Release the resources held by the MediaPipe Hands solution."""
        self.hands.close()

    def predict_frame(self, frame):
        """Annotate a BGR frame with hand landmarks and the predicted letter.

        The frame is drawn on in place and also returned. When no hand is
        detected the frame is returned without any overlay and the vote
        buffer is left untouched.
        """
        # Preprocess the image (currently a no-op)
        frame = preprocess_image(frame)

        # MediaPipe expects RGB, OpenCV delivers BGR
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.hands.process(frame_rgb)

        if not results.multi_hand_landmarks:
            return frame

        self._draw_hands(frame, results.multi_hand_landmarks)

        probabilities = self._classify(extract_landmarks(results))
        predicted_letter = self.letters[int(probabilities.argmax())]

        voted_letter, agreement = self._vote(predicted_letter)
        self.last_prediction = voted_letter

        # Only display if enough of the buffered predictions match
        if agreement > self.min_agreement:
            # Report the confidence of the letter that is actually shown.
            # The current frame's top class may differ from the vote winner,
            # and showing its confidence next to another letter is misleading.
            shown_confidence = float(probabilities[self.letters.index(voted_letter)])
            self._draw_label(frame, voted_letter, shown_confidence)

        return frame

    def _draw_hands(self, frame, hands):
        """Draw the landmarks and connections of every detected hand."""
        for hand_landmarks in hands:
            self.mp_drawing.draw_landmarks(
                frame,
                hand_landmarks,
                self.mp_hands.HAND_CONNECTIONS,
                self._landmark_style,
                self._connection_style,
            )

    def _classify(self, landmarks):
        """Return the 1-D tensor of class probabilities for one landmark vector."""
        with torch.no_grad():
            batch = torch.as_tensor(
                landmarks, dtype=torch.float32, device=self.device).unsqueeze(0)
            logits = self.model(batch)
            return torch.softmax(logits, dim=1)[0]

    def _vote(self, letter):
        """Record ``letter`` and return ``(winning_letter, agreement)``.

        ``agreement`` is the fraction of buffered predictions that equal the
        winning letter.
        """
        self.prediction_buffer.append(letter)

        # Keep only the newest ``buffer_size`` entries
        overflow = len(self.prediction_buffer) - self.buffer_size
        if overflow > 0:
            del self.prediction_buffer[:overflow]

        if not self.prediction_buffer:
            # Only reachable if buffer_size was lowered below 1 after
            # construction; report zero agreement so nothing is displayed.
            return letter, 0.0

        winner, votes = Counter(self.prediction_buffer).most_common(1)[0]
        return winner, votes / len(self.prediction_buffer)

    def _draw_label(self, frame, letter, confidence):
        """Draw the filled label box with the letter and its confidence."""
        cv2.rectangle(frame, (0, 0), (200, 100), (245, 117, 16), -1)
        cv2.putText(frame, letter.upper(),
                    (60, 60), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 2)
        cv2.putText(frame, f"Conf: {confidence:.2f}",
                    (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

def main():
    """Run the live webcam demo until 'q' is pressed or a frame read fails.

    Frames from the default camera (index 0) are annotated by an
    ``ASLPredictor`` and shown in an OpenCV window. The camera and the
    window are released even if processing a frame raises.
    """
    # Initialize predictor
    predictor = ASLPredictor()

    # Open webcam
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        # Without this check the loop below would be skipped silently.
        print("Error: Could not open webcam")
        cap.release()
        return

    print("Starting real-time detection... (Press 'q' to exit)")

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                print("Error reading webcam.")
                break

            # Process frame
            frame = predictor.predict_frame(frame)

            # Show frame
            cv2.imshow('ASL Letter Detection', frame)

            # Exit on 'q'
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always give the camera back and close the preview window.
        cap.release()
        cv2.destroyAllWindows()

# For notebook usage
class NotebookPredictor:
    """Wrapper around ``ASLPredictor`` that renders frames inline in a notebook."""

    def __init__(self, model_path='/workspaces/asl_detection/machine_learning/models/asl_now/best_model.pth'):
        """Create the underlying ``ASLPredictor`` from ``model_path``."""
        self.predictor = ASLPredictor(model_path)

    def process_webcam(self, num_frames=100, delay=0.1):
        """Process and display a fixed number of webcam frames.

        Args:
            num_frames: Maximum number of frames to read from camera 0.
            delay: Seconds to sleep after each displayed frame.

        The camera is released when the loop ends, including on error.
        """
        # Imported here so the module can be used outside of a notebook.
        from IPython.display import clear_output, display
        import PIL.Image

        # Open webcam
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print("Error: Could not open webcam")
            return

        try:
            for _ in range(num_frames):
                ret, frame = cap.read()
                if not ret:
                    print("Error reading webcam")
                    break

                # Process frame
                frame = self.predictor.predict_frame(frame)

                # Convert to RGB for notebook display
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Replace the previously shown frame with the new one
                clear_output(wait=True)
                display(PIL.Image.fromarray(frame_rgb))

                # Small pause for smoother display
                time.sleep(delay)
        finally:
            cap.release()

        print("Detection completed.")

# Entry point: start the live webcam demo when this module is run as the main program.
if __name__ == "__main__":
    main()