In [21]:
import gradio as gr
import numpy as np
import cv2
import mediapipe as mp
import base64
import tensorflow as tf
from PIL import Image
from io import BytesIO
import json

# Aliases for MediaPipe's drawing-utility and holistic solution modules.
mp_drawing = mp.solutions.drawing_utils
mp_holistic = mp.solutions.holistic

def load_model(model_path):
    """Load the saved Keras model stored at ``model_path`` and return it."""
    keras_model = tf.keras.models.load_model(model_path)
    return keras_model

# One loaded Keras model per lesson category, keyed by the category name that
# is shown in the UI dropdown. Models are loaded eagerly, in the listed order.
models = {
    category: load_model(weights_file)
    for category, weights_file in (
        ('Getting Started', 'getting_started_v2.h5'),
        ('Polite Expressions', 'polite_expressions.h5'),
        ('Questions', 'questions.h5'),
        ('Responses', 'responses.h5'),
        ('Compliments', 'compliments.h5'),
        ('Feelings', 'feelings.h5'),
        ('Objects', 'objects.h5'),
        ('Animals', 'animals.h5'),
        ('Fruits', 'fruits.h5'),
        ('The Sky', 'the_sky.h5'),
    )
}

# Class labels for each category. predict_sequence indexes these lists with
# positions of the model's output vector, so the order within each list matters.
model_classes = {
    "Getting Started": ["Hello", "Goodbye", "Dad", "Mom", "I love you"],
    "Polite Expressions": ["Please", "Excuse Me", "Thank You", "Sorry", "You are Welcome"],
    "Questions": ["Question", "Where", "Who", "Why", "What"],
    "Responses": ["Yes", "No", "Now", "Later", "Tomorrow"],
    "Compliments": ["Beautiful", "Cute", "Nice", "Funny", "Smart"],
    "Feelings": ["Happy", "Sad", "Proud", "Excited", "Hungry"],
    "Objects": ["Computer", "Phone", "Camera", "Bag", "Toothbrush"],
    "Animals": ["Llama", "Horse", "Cat", "Pig", "Goat"],
    "The Sky": ["Sky", "Sun", "Moon", "Clouds", "Stars"],
    "Fruits": ["Fruit", "Apple", "Orange", "Strawberry", "Grapes"],
}

def extract_keypoints(results, frame_width, frame_height):
    """Flatten holistic landmarks into a single 258-value feature vector.

    The vector is laid out as ``[pose (33*4), left hand (21*3), right hand (21*3)]``.
    Pose landmarks contribute ``x, y, z, visibility``; hand landmarks contribute
    ``x, y, z``. Any landmark set that is missing from ``results`` is replaced
    by zeros of the same length.

    Args:
        results: Object exposing ``pose_landmarks``, ``left_hand_landmarks`` and
            ``right_hand_landmarks``; each is either falsy or has a ``.landmark``
            iterable of points.
        frame_width: Divisor applied to every x value.
        frame_height: Divisor applied to every y value.

    Returns:
        1-D float ``numpy.ndarray`` of length 258.

    NOTE(review): MediaPipe documents landmark x/y as already normalised to the
    image size; the extra division is kept because the trained models presumably
    expect features scaled this way -- confirm against the training pipeline.
    NOTE(review): right-hand y values are now scaled with stride 3 (x, y, z per
    landmark), consistent with the left hand. If the models were trained on
    features produced with a different right-hand scaling, re-validate them.
    """
    def _flatten(landmark_list, fields, n_landmarks):
        # Missing detections become an all-zero vector of the expected length.
        if not landmark_list:
            return np.zeros(n_landmarks * len(fields))
        return np.array(
            [[getattr(point, field) for field in fields] for point in landmark_list.landmark],
            dtype=float,
        ).flatten()

    pose = _flatten(results.pose_landmarks, ("x", "y", "z", "visibility"), 33)
    lh = _flatten(results.left_hand_landmarks, ("x", "y", "z"), 21)
    rh = _flatten(results.right_hand_landmarks, ("x", "y", "z"), 21)

    # Normalization: the stride equals the number of values stored per landmark,
    # so offset 0 selects every x and offset 1 selects every y.
    for vector, stride in ((pose, 4), (lh, 3), (rh, 3)):
        vector[0::stride] = vector[0::stride] / frame_width
        vector[1::stride] = vector[1::stride] / frame_height

    return np.concatenate([pose, lh, rh])

def process_image_data_uri(data_uri):
    """Decode a base64 image data URI into a 3-channel BGR array.

    Args:
        data_uri: String of the form ``<header>,<base64 payload>``; only the
            part after the first comma is decoded.

    Returns:
        ``numpy.ndarray`` produced by ``cv2.cvtColor(..., cv2.COLOR_RGB2BGR)``.

    Raises:
        ValueError: If ``data_uri`` contains no comma separating header and payload.
    """
    _header, separator, encoded = data_uri.partition(",")
    if not separator:
        raise ValueError("Invalid data URI: expected '<header>,<base64 payload>'.")

    image_data = base64.b64decode(encoded)
    with Image.open(BytesIO(image_data)) as pil_image:
        # Force RGB so that non-RGB modes (e.g. grayscale, palette, alpha)
        # yield a consistent 3-channel array for the colour conversion below.
        rgb_pixels = np.array(pil_image.convert("RGB"))
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)

def is_all_masked(sequence, mask_value=0):
    """Return whether every element of ``sequence`` equals ``mask_value``."""
    matches_mask = sequence == mask_value
    return matches_mask.all()

def extract_keypoints_from_sequence(image_data_uris):
    """Run MediaPipe Holistic on each frame and stack the keypoint vectors.

    Args:
        image_data_uris: Iterable of base64 image data URIs, one per frame,
            in the order they should be processed.

    Returns:
        ``numpy.ndarray`` built from the per-frame vectors returned by
        ``extract_keypoints`` (one row per frame).
    """
    keypoints_sequence = []

    # The context manager closes the Holistic instance even when decoding or
    # processing one of the frames raises, so its resources are not leaked.
    with mp_holistic.Holistic(min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
        for data_uri in image_data_uris:
            image = process_image_data_uri(data_uri)
            frame_height, frame_width = image.shape[:2]

            # MediaPipe keypoint extraction (the decoded frame is BGR; process() is given RGB)
            results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            keypoints_sequence.append(extract_keypoints(results, frame_width, frame_height))

    return np.array(keypoints_sequence)

def predict_sequence(model_name, image_data_uris_json):
    """Predict the most likely signs for a sequence of frames.

    Args:
        model_name: Key into ``models`` / ``model_classes`` selecting the category.
        image_data_uris_json: JSON text encoding an array of image data URIs.

    Returns:
        A human-readable string: up to three ``"<label> (<pct>%)"`` entries whose
        probability exceeds 30%, joined by ``", "``, or a message describing why
        no prediction could be made. Failures are reported in the returned
        string rather than raised.
    """
    try:
        image_data_uris = json.loads(image_data_uris_json)
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError; TypeError covers a None input.
        return f"Error processing sequence: invalid JSON input ({e})"

    if not isinstance(image_data_uris, list):
        return "Error processing sequence: expected a JSON array of image data URIs."

    try:
        keypoints_sequence = extract_keypoints_from_sequence(image_data_uris)

        if is_all_masked(keypoints_sequence):
            return "Nothing to detect. Try Again"

        # The model input is shaped as [1, 40, 258]; comparing the whole shape
        # tuple also rejects arrays with an unexpected number of dimensions.
        if keypoints_sequence.shape != (40, 258):
            return "Incorrect keypoints sequence shape. Expected 40 sets of 258 keypoints."
        model_input = np.expand_dims(keypoints_sequence, axis=0)

        # .get() so that an unknown or unselected model name reaches the
        # "Model not found." branch instead of raising KeyError.
        current_model = models.get(model_name)
        if current_model is None:
            return "Model not found."

        actions = model_classes[model_name]
        res = current_model.predict(model_input)[0]

        # Three highest-scoring classes, best first, keeping only those above 30%.
        top_action_indices = np.argsort(res)[-3:][::-1]
        top_actions = [
            (actions[index], res[index] * 100)
            for index in top_action_indices
            if res[index] * 100 > 30
        ]

        if not top_actions:
            return "Try again"
        return ', '.join(f"{name} ({prob:.2f}%)" for name, prob in top_actions)
    except Exception as e:
        return f"Error processing sequence: {str(e)}"
        
# Gradio interface
# predict_sequence receives the two inputs in the order listed: the dropdown
# value (one of the keys of `models`) and the textbox content, which is
# expected to be a JSON array of image data URIs. Its returned string is shown
# in the "Predicted Action" text output.
iface = gr.Interface(
    fn=predict_sequence,
    inputs=[
        gr.Dropdown(label="Model Selection", choices=list(models.keys())),
        gr.Textbox(label="Enter Image Data URIs Here (JSON array)")
    ],
    outputs=gr.Text(label="Predicted Action"),
    title="Hand Sign Prediction with aslmodel_v2"
)

# Launch the web UI only when this code is executed as the main module.
if __name__ == "__main__":
    iface.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


In [22]:
iface.close()

Closing server running on port: 7860
