# In [None]:  (Jupyter cell marker left over from the notebook export)
import base64
import functools
import io
from collections import Counter

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
from facenet_pytorch import MTCNN
from PIL import Image
from sklearn.cluster import KMeans

# Step 1: Face Detection Branch
class FaceDetectionBranch:
    """Detect faces with MTCNN and crop them out of an image.

    Convention: numpy arrays handed to this class are treated as RGB (the
    layout ``np.array(pil_image)`` produces) unless ``input_is_bgr=True`` is
    passed. Crops returned by ``crop_faces`` are BGR so they can be fed
    directly to OpenCV-based code.
    """

    def __init__(self, image_size=160, margin=20, min_face_size=20, thresholds=None, factor=0.709):
        """
        Initialize the face detection branch with MTCNN.

        Args:
            image_size, margin, min_face_size, factor: forwarded to MTCNN.
            thresholds: MTCNN's three per-stage thresholds. ``None`` means
                ``[0.6, 0.7, 0.7]``; a fresh list is built per instance instead
                of sharing one mutable default list between all instances.
        """
        if thresholds is None:
            thresholds = [0.6, 0.7, 0.7]

        # Prefer a CUDA GPU, then Apple-Silicon MPS, then fall back to CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')
        print(f"Using device: {self.device}")

        self.mtcnn = MTCNN(
            image_size=image_size,
            margin=margin,
            min_face_size=min_face_size,
            thresholds=thresholds,
            factor=factor,
            device=self.device
        )

    @staticmethod
    def _to_rgb_array(image, input_is_bgr=False):
        """Return *image* as an H x W x 3 RGB numpy array.

        PIL images go through ``convert("RGB")``. For arrays, a 2-D grayscale
        image is replicated to three channels, a fourth (alpha) channel is
        dropped, and BGR input is flipped to RGB when ``input_is_bgr`` is set.
        """
        if isinstance(image, np.ndarray):
            array = image
        elif isinstance(image, Image.Image):
            return np.array(image.convert("RGB"))
        else:
            array = np.asarray(image)

        if array.ndim == 2:
            array = np.stack([array] * 3, axis=-1)
        elif array.ndim == 3 and array.shape[-1] == 4:
            array = array[..., :3]

        if input_is_bgr and array.ndim == 3 and array.shape[-1] == 3:
            array = array[..., ::-1]

        # Hand back contiguous memory rather than a sliced/flipped view.
        return np.ascontiguousarray(array)

    def detect_faces(self, image, input_is_bgr=False):
        """
        Detect faces in the input image.

        Args:
            image: PIL image or numpy array (RGB unless ``input_is_bgr``).
            input_is_bgr: set True when passing an OpenCV-style BGR array.

        Returns:
            A list of ``[x, y, width, height]`` integer boxes; empty when no
            face is found.
        """
        # MTCNN expects RGB. The previous version unconditionally swapped
        # channels of uint8 arrays, which corrupted RGB input from PIL.
        rgb_image = self._to_rgb_array(image, input_is_bgr)

        boxes, _probs = self.mtcnn.detect(rgb_image)
        if boxes is None:
            return []

        # MTCNN boxes are (x1, y1, x2, y2) floats; convert to x/y/width/height.
        return [[int(x1), int(y1), int(x2 - x1), int(y2 - y1)] for x1, y1, x2, y2 in boxes]

    def crop_faces(self, image, faces, padding=0.2, input_is_bgr=False):
        """
        Crop detected faces from the original image.

        Args:
            image: PIL image or numpy array (RGB unless ``input_is_bgr``).
            faces: iterable of ``[x, y, width, height]`` boxes.
            padding: fraction of the box width/height added on each side.
            input_is_bgr: set True when passing an OpenCV-style BGR array.

        Returns:
            A list of BGR crops; boxes that end up empty after clipping to the
            image bounds are skipped.
        """
        rgb = self._to_rgb_array(image, input_is_bgr)
        image_bgr = np.ascontiguousarray(rgb[..., ::-1])

        height, width = image_bgr.shape[:2]
        cropped_faces = []

        for face in faces:
            x, y, w, h = (int(v) for v in face)

            pad_w = int(w * padding)
            pad_h = int(h * padding)

            # Clamp every bound into [0, size]: a negative end index would be
            # read by Python slicing as "from the end" and yield a bogus crop.
            x1 = max(0, x - pad_w)
            y1 = max(0, y - pad_h)
            x2 = min(width, max(0, x + w + pad_w))
            y2 = min(height, max(0, y + h + pad_h))

            face_image = image_bgr[y1:y2, x1:x2]

            if face_image.size > 0:
                cropped_faces.append(face_image)

        return cropped_faces

# Step 2: Skin Tone Extractor
class SkinToneExtractor:
    """Estimate the dominant skin colour of a BGR face crop.

    Pixels inside a fixed HSV skin range are clustered with K-means and the
    centre of the most populated cluster is reported.
    """

    # Fallback used when the skin mask selects no pixel. Stored BGR-first like
    # every colour handled inside this class, because extract_skin_tone flips
    # BGR -> RGB; as RGB it is (211, 169, 150), i.e. hex #d3a996.
    DEFAULT_BGR = (150, 169, 211)

    def __init__(self, n_clusters=4, random_state=0):
        """
        Initialize the skin tone extractor.

        Args:
            n_clusters: maximum number of K-means clusters.
            random_state: K-means seed so the same face yields the same colour
                on every run; pass ``None`` for unseeded clustering.
        """
        self.n_clusters = n_clusters
        self.random_state = random_state

        # Skin range in OpenCV's 8-bit HSV scale (H 0-179, S and V 0-255).
        self.lower_hsv = np.array([0, 20, 70], dtype=np.uint8)
        self.upper_hsv = np.array([25, 255, 255], dtype=np.uint8)

    def extract_skin_mask(self, image):
        """
        Return a uint8 mask of likely-skin pixels of a BGR image.

        The HSV threshold mask is opened (erode then dilate) to remove speckle
        and lightly blurred; any non-zero value counts as skin downstream.
        """
        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        skin_mask = cv2.inRange(hsv_image, self.lower_hsv, self.upper_hsv)

        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        skin_mask = cv2.erode(skin_mask, kernel, iterations=1)
        skin_mask = cv2.dilate(skin_mask, kernel, iterations=1)

        skin_mask = cv2.GaussianBlur(skin_mask, (3, 3), 0)

        return skin_mask

    def find_dominant_color(self, image, mask):
        """
        Return the dominant colour of the masked pixels.

        Args:
            image: H x W x 3 BGR face crop.
            mask: H x W array; every non-zero entry marks a skin pixel.

        Returns:
            A 3-element int array (BGR): the rounded centre of the largest
            K-means cluster, or DEFAULT_BGR when the mask selects nothing.
        """
        # Select pixels directly through the mask. Filtering "non-black"
        # pixels instead also discarded real skin pixels with any zero channel.
        pixels = np.asarray(image).reshape(-1, 3)[np.asarray(mask).reshape(-1) > 0]

        if len(pixels) == 0:
            return np.array(self.DEFAULT_BGR)

        # KMeans raises ValueError when n_samples < n_clusters (tiny crops).
        n_clusters = min(self.n_clusters, len(pixels))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=self.random_state)
        kmeans.fit(pixels)

        largest_cluster = Counter(kmeans.labels_).most_common(1)[0][0]
        return np.rint(kmeans.cluster_centers_[largest_cluster]).astype(int)

    def extract_skin_tone(self, face_image):
        """
        Extract the skin tone from a BGR face image.

        Args:
            face_image: BGR crop, or a one-element list holding one.

        Returns:
            ``(hex_code, rgb_color)`` where hex_code is ``#rrggbb`` and
            rgb_color is the dominant colour in R, G, B order.
        """
        if isinstance(face_image, list) and len(face_image) == 1:
            face_image = face_image[0]

        skin_mask = self.extract_skin_mask(face_image)
        dominant_color = self.find_dominant_color(face_image, skin_mask)

        # The dominant colour is BGR; reverse it for RGB / hex output.
        rgb_color = dominant_color[::-1]
        hex_code = '#{:02x}{:02x}{:02x}'.format(*(int(c) for c in rgb_color))

        return hex_code, rgb_color

# Function for Fitzpatrick scale classification
def classify_skin_tone(rgb_color):
    """
    Classify skin tone according to Fitzpatrick scale
    """
    # Calculate Individual Typology Angle (ITA)
    # ITA = arctan((L* - 50) / b*) * 180/π
    # Convert RGB to L*a*b*
    rgb_normalized = np.array([[rgb_color]], dtype=np.float32) / 255.0
    bgr_normalized = rgb_normalized[:, :, ::-1]  # RGB to BGR
    lab_color = cv2.cvtColor(bgr_normalized, cv2.COLOR_BGR2Lab)[0][0]

    L = lab_color[0]
    b = lab_color[2] - 128  # Adjust b* value (b* in OpenCV is from 0-255, center at 128)

    # Calculate ITA
    if b == 0:
        b = 0.01  # Avoid division by zero
    ita = np.arctan((L - 50) / b) * 180 / np.pi

    # Classify based on ITA values
    if ita > 55:
        return "Type I - Very fair, always burns"
    elif 48 <= ita <= 55:
        return "Type II - Fair, usually burns"
    elif 41 <= ita < 48:
        return "Type III - Medium, sometimes burns"
    elif 30 <= ita < 41:
        return "Type IV - Olive, rarely burns"
    elif 19 <= ita < 30:
        return "Type V - Brown, very rarely burns"
    else:
        return "Type VI - Dark brown to black, never burns"

# Palette recommendation
def predict_palette(hex_code):
    """
    Pick a five-colour palette for a skin tone given as ``#rrggbb``.

    The tone's perceived brightness (luma weights 0.299 / 0.587 / 0.114 on
    a 0-255 scale) selects one of three fixed palettes: above 170 -> light,
    above 100 -> medium, otherwise dark. A fresh list is returned each call.
    """
    digits = hex_code.lstrip('#')
    red, green, blue = (int(digits[pos:pos + 2], 16) for pos in range(0, 6, 2))
    luma = (red * 299 + green * 587 + blue * 114) / 1000

    # (exclusive lower bound, palette), checked brightest first.
    tiers = (
        (170, ["#F4C2C2", "#AEC6CF", "#E6E6FA", "#FFFFF0", "#FFDB58"]),
        (100, ["#D2B48C", "#FFDB58", "#E2725B", "#778899", "#B0C4DE"]),
    )
    for floor, colours in tiers:
        if luma > floor:
            return colours
    return ["#FFD700", "#00A86B", "#DC143C", "#4169E1", "#800080"]

@functools.lru_cache(maxsize=1)
def _get_face_detector():
    """Build the MTCNN-backed face detector once and reuse it across calls."""
    return FaceDetectionBranch()


# Detect faces and extract skin tone
def process_face_image(image):
    """
    Process an image to detect faces and extract skin tone.

    Args:
        image: PIL image (what the Gradio input delivers with type="pil") or
            numpy array; arrays are forwarded unchanged.

    Returns:
        ``(hex_code, palette, fitzpatrick_type)`` for the first face the
        detector reports, or ``(None, None, None)`` when there is no image or
        no usable face.
    """
    if image is None:
        return None, None, None

    # Normalise PIL input to 3-channel RGB so RGBA / grayscale / palette
    # images don't reach OpenCV with an unexpected channel count.
    if isinstance(image, Image.Image):
        img_for_detection = np.array(image.convert("RGB"))
    else:
        img_for_detection = image

    # Reused instance: constructing MTCNN loads its weights, which the old
    # code repeated on every button click.
    face_detector = _get_face_detector()

    faces = face_detector.detect_faces(img_for_detection)
    if not faces:
        return None, None, None

    cropped_faces = face_detector.crop_faces(img_for_detection, faces)
    if not cropped_faces:
        return None, None, None

    # NOTE(review): only the first reported face is analysed; its position in
    # MTCNN's output ordering is assumed meaningful (e.g. largest) — confirm.
    hex_code, rgb_color = SkinToneExtractor().extract_skin_tone(cropped_faces[0])

    fitzpatrick_type = classify_skin_tone(rgb_color)
    palette = predict_palette(hex_code)

    return hex_code, palette, fitzpatrick_type

# Display results with HTML
def display_output(image):
    """
    Render the analysis of *image* as two HTML fragments.

    Returns a ``(tone_html, palette_html)`` pair for the two HTML output
    panes; placeholder messages are returned when no image was supplied or
    no face was found.
    """
    no_palette = "<p>No palette available.</p>"
    if image is None:
        return "<p>Please upload an image first.</p>", no_palette

    hex_code, palette, fitzpatrick_type = process_face_image(image)
    if hex_code is None:
        return "<p>No face detected in the image. Please try another image.</p>", no_palette

    tone_section = f"""
    <div style='text-align:center; padding: 10px;'>
        <h3 style='text-align:center;'>Detected Skin Tone</h3>
        <div style='display:flex; justify-content:center; align-items:center; gap:20px;'>
            <div style='width:80px; height:80px; background:{hex_code};
                border-radius:8px; box-shadow:0 1px 4px rgba(0,0,0,0.1);'></div>
            <div>
                <h2 style='color:#333; margin:0;'>{hex_code}</h2>
                <p style='margin:5px 0 0 0;'>{fitzpatrick_type}</p>
            </div>
        </div>
    </div>
    """

    # One square swatch per palette colour, laid out in a centred flex row.
    swatch_divs = "".join(
        f"<div style='width:60px; height:60px; background:{color}; border-radius:8px; box-shadow:0 1px 4px rgba(0,0,0,0.1);'></div>"
        for color in palette
    )
    swatches = (
        "<div style='display:flex; gap:12px; justify-content:center; margin-top:10px;'>"
        + swatch_divs
        + "</div>"
    )

    palette_section = f"""
    <div style='padding: 10px;'>
        <h3 style='text-align:center;'>Recommended Color Palette</h3>
        {swatches}
    </div>
    """

    return tone_section, palette_section

# Custom CSS passed to the page through gr.Blocks(css=custom_css, ...) below.
# "#upload-btn" styles the image-upload component created with
# elem_id="upload-btn"; ".gr-button" is intended for the "Detect Tone" button.
# NOTE(review): confirm ".gr-button" still matches the button's class in the
# installed Gradio version — the rule silently does nothing if it doesn't.
custom_css = """
body {
    background: #fdf6f0;
    font-family: 'Helvetica Neue', sans-serif;
    color: #222;
}
#upload-btn {
    margin: auto;
    border: 2px dashed #ccc;
    padding: 16px;
    background: white;
    box-shadow: 0 4px 12px rgba(0,0,0,0.05);
    border-radius: 10px;
}
.gr-button {
    background-color: #FFBC80;
    color: #222;
    font-weight: bold;
    border: none;
}
.gr-button:hover {
    background-color: #ffaa55;
}
"""

# Create Gradio Interface.
# Components appear on the page in the order they are created inside this
# `with` block, so statement order here defines the layout.
with gr.Blocks(css=custom_css, title="SkinShade AI") as demo:
    # Header banner (raw HTML rendered through Markdown).
    gr.Markdown("""
        <div style='text-align:center; max-width: 640px; margin: auto; padding-top: 40px;'>
            <h1 style='font-size: 32px;'>Skin Tone & Color Palette Recommender</h1>
            <p style='font-size: 16px; color: #444;'>
                Upload a photo to find your skin tone and get personalized color suggestions.
            </p>
        </div>
    """)

    # type="pil" means display_output receives a PIL image; elem_id links the
    # component to the "#upload-btn" rule in custom_css.
    image_input = gr.Image(type="pil", label="Upload Your Photo", elem_id="upload-btn")
    # Two HTML panes: skin-tone summary and recommended palette.
    tone_output = gr.HTML()
    palette_output = gr.HTML()
    btn = gr.Button("Detect Tone")

    # display_output returns (tone_html, palette_html), matching this outputs order.
    btn.click(display_output, inputs=image_input, outputs=[tone_output, palette_output])

    # Footer privacy notice.
    # NOTE(review): Gradio presumably writes uploads to its temp cache directory;
    # confirm the "don't store any images" wording holds for this deployment.
    gr.Markdown("""
        <div style='text-align:center; margin-top: 40px; color: #666; font-size: 14px;'>
            We don't store any images. Everything runs locally during the session.
        </div>
    """)

# Start the web UI when this file is executed directly.
# NOTE(review): share=True asks Gradio for a public share link in addition to
# the local server — confirm that is intended given the "runs locally" footer.
if __name__ == "__main__":
    demo.launch(share=True)

# Output of the last notebook run: ModuleNotFoundError: No module named 'gradio'
# (install the dependency first, e.g. `pip install gradio`).