In [None]:
import cv2
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.models import resnet50, resnet34, ResNet50_Weights
from torch.amp import autocast

# =======================================================================================
# === CONFIGURATION & SETTINGS ===
# =======================================================================================

# --- Enable/Disable Models ---
# Each flag gates the loading of the corresponding model further down.
# The person detector is loaded only when a body-level model
# (Fashionpedia or UPAR) is enabled.
USE_FASHIONPEDIA_MODEL = True
USE_UPAR_MODEL = True
USE_CELEBA_MODEL = True
USE_FAIRFACE_MODEL = True

# --- Input Image ---
# Path handed to cv2.imread by the main section.
IMAGE_PATH = "ss.jpg" # Change this to your image file

# --- Model File Paths ---
# *_PROTO / *_MODEL pairs are passed to cv2.dnn.readNetFromCaffe;
# *_WEIGHTS files are passed to torch.load and then load_state_dict.
FACE_PROTO = "deploy.prototxt.txt"
FACE_MODEL = "res10_300x300_ssd_iter_140000.caffemodel"
PERSON_PROTO = "deploy.prototxt"
PERSON_MODEL = "mobilenet_iter_73000.caffemodel"
CELEBA_WEIGHTS = "best_celeba_model.pth"
FAIRFACE_WEIGHTS = "best_fairface_model.pth"
FASHIONPEDIA_WEIGHTS = "best_fashionpedia_model.pth"
UPAR_WEIGHTS = "best_upar_model.pth"
# Only the column names of this CSV are used (all columns after the first).
CELEBA_CSV = "celeba_cleaned.csv"

# --- Confidence Thresholds ---
# All comparisons are strict (score > threshold). The two detection values
# filter detector scores; FASHIONPEDIA_PRED_CONF is applied to every
# sigmoid probability produced by the Fashionpedia model.
FACE_DETECTION_CONF = 0.4
PERSON_DETECTION_CONF = 0.5
FASHIONPEDIA_PRED_CONF = 0.5

# --- Labels & Model-Specific Thresholds ---

# FairFace
# Label lists are indexed by the argmax of the matching softmax head, so the
# position of each entry is significant.
fairface_race_labels = [
    'Black',
    'East Asian',
    'Indian',
    'Latino_Hispanic',
    'Middle Eastern',
    'Southeast Asian',
    'White',
]
fairface_gender_labels = ['Female', 'Male']
fairface_age_labels = [
    '0-2', '10-19', '20-29',
    '3-9', '30-39', '40-49',
    '50-59', '60-69', '70+',
]
# Minimum top-class probability required before a head's label is reported.
fairface_thresholds = {
    "race": 0.50,
    "gender": 0.2,
    "age": 0.2,
}

# CelebA
# Per-attribute sigmoid thresholds; attributes not listed here fall back to
# the default used inside predict_celeba.
celeba_thresholds = {
    "Male": 0.95,
    "Smiling": 0.50,
    "Wearing_Earrings": 0.50,
    "Heavy_Makeup": 0.75,
    "No_Beard": 0.95,
    "Eyeglasses": 0.70,
    "Young": 0.70,
}

# UPAR
# Per-attribute sigmoid thresholds. The key order doubles as the order in
# which attributes are evaluated and reported.
upar_thresholds = {
    'Accessory-Backpack': 0.80,
    'Accessory-Bag': 0.75,
    'Accessory-Glasses-Normal': 0.85,
    'Accessory-Hat': 0.85,
    'Age-Adult': 0.85,
    'Age-Young': 0.85,
    'Gender-Female': 0.85,
    'Hair-Length-Long': 0.85,
    'Hair-Length-Short': 0.85,
}
# Attributes that are actually reported (exactly the thresholded ones).
upar_label_cols = list(upar_thresholds)
# Full list for UPAR model's 30-class output mapping; position == output index.
upar_full_label_list = [
    # accessories
    'Accessory-Backpack', 'Accessory-Bag', 'Accessory-Glasses-Normal', 'Accessory-Hat',
    # age / gender / hair
    'Age-Adult', 'Age-Young', 'Gender-Female', 'Hair-Length-Long', 'Hair-Length-Short',
    # lower body
    'LowerBody-Color-Black', 'LowerBody-Color-Blue', 'LowerBody-Color-Brown',
    'LowerBody-Color-Grey', 'LowerBody-Color-Other', 'LowerBody-Color-White',
    'LowerBody-Length-Short', 'LowerBody-Type-Skirt&Dress', 'LowerBody-Type-Trousers&Shorts',
    # upper body
    'UpperBody-Color-Black', 'UpperBody-Color-Blue', 'UpperBody-Color-Brown',
    'UpperBody-Color-Green', 'UpperBody-Color-Grey', 'UpperBody-Color-Other',
    'UpperBody-Color-Pink', 'UpperBody-Color-Purple', 'UpperBody-Color-Red',
    'UpperBody-Color-White', 'UpperBody-Color-Yellow', 'UpperBody-Length-Short',
]


# Fashionpedia
# Display names for the attributes predicted by the Fashionpedia model.
# NOTE: insertion order is significant. predict_fashion pairs the i-th model
# output with the i-th *value* of this dict, and the model head is sized with
# len(fashionpedia_attributes). The integer keys are never used for indexing
# in this file (presumably they are the original dataset attribute ids --
# confirm against the training code). Do not reorder, add or remove entries
# without retraining / remapping.
fashionpedia_attributes = {
    218: "Has patch pockets", 204: "Regular sleeves", 205: "Dropped shoulders", 159: "3/4 sleeves",
    163: "Shirt-style collar", 225: "One row of buttons", 295: "No extra material", 137: "Loose fit",
    145: "No defined waist", 115: "Even on both sides", 148: "Very short", 149: "Mini length",
    316: "Nothing special in build", 317: "Plain design", 160: "Sleeves to the wrist", 128: "Straight cut",
    135: "Tight fit", 106: "Form-fitting", 140: "Waist sits low", 302: "Gathered fabric", 151: "Knee length",
    162: "Standard collar", 224: "Pockets with flaps", 214: "Puffy long sleeves", 133: "Short in front, long in back", 103: "Shirt-dress style",
    127: "Narrow shape", 325: "Floral print", 102: "Simple straight dress", 301: "Printed pattern", 142: "Normal waistline",
    200: "Straight neckline", 179: "No collar", 36: "Denim pants", 230: "Zipper fly", 136: "Standard fit",
    298: "Washed look", 154: "Full leg coverage", 223: "Curved pocket shape", 114: "Asymmetrical look", 147: "Length to the hips",
    112: "Tunic-style", 309: "Has a slit", 152: "A little below the knee", 182: "Round neck", 311: "Lined inside",
    146: "Ends above the hips", 229: "Zipper closure", 305: "Wrinkled effect", 312: "Has appliqués", 185: "Oval neck",
    186: "U-shaped neck", 119: "Fitted top, flared bottom", 141: "High waistline", 319: "Cartoon graphics", 300: "Frayed edges",
    174: "Notch-style lapel", 38: "Leggings", 17: "Blazer-style", 322: "Checkered print", 138: "Baggy fit", 289: "Fur material",
    304: "Pleated fabric", 155: "To the floor", 20: "Motorcycle jacket", 153: "Mid-calf length", 283: "Metal details", 10: "Camisole",
    187: "Heart-shaped neckline", 150: "Above the knee", 157: "Short length", 318: "Abstract print",
    108: "Gown style", 120: "Trumpet shape", 95: "Halter style", 328: "Striped pattern", 183: "V neckline", 207: "Short cap sleeves",
    129: "A-line shape", 143: "Low waist", 68: "Slim skirt", 213: "Loose sleeve style",
    192: "Very deep neckline", 314: "Metal rivets", 219: "Inset pockets", 222: "Slanted pockets", 126: "Peg pants",
    194: "Strap around neck", 209: "Puffy short sleeves", 132: "Wide-leg pants", 308: "Cutout design", 216: "Loose kimono sleeves",
    281: "Plastic material", 118: "Flared shape", 113: "Short flared dress", 117: "Circular shape", 177: "Fancy lapel",
    226: "Two rows of buttons", 297: "Worn-in look", 8: "Cropped top", 175: "Pointy lapel", 326: "Geometric print",
    50: "Short shorts", 181: "Round crew neck", 191: "Square neck", 11: "Tank top", 190: "Wide scoop neck",
    323: "Polka dots", 220: "Big front pocket", 0: "Basic tee", 320: "Letters or numbers", 158: "Sleeves to elbow",
    210: "Bell-shaped sleeves", 123: "Bell-shaped bottom", 315: "Has sequins", 197: "High neck", 101: "Simple loose dress",
    189: "Boat-shaped neckline", 166: "Wrapped collar", 180: "Uneven neckline", 198: "Turtleneck",
    307: "Layered fabric", 176: "Wide collar", 228: "Wrap-around", 2: "Undershirt", 286: "Gem detail", 313: "Beaded design",
    202: "Off-shoulder style", 121: "Mermaid shape", 221: "Stitched pockets", 234: "No visible opening", 203: "One shoulder",
    206: "Diagonal sleeve seam"
}

# =======================================================================================
# === INITIALIZATION ===
# =======================================================================================

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


# --- Image Transforms ---
def _build_transform(size, mean, std):
    """Return a Resize -> ToTensor -> Normalize preprocessing pipeline."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


# One pipeline per model: input size and normalisation statistics differ.
fairface_transform = _build_transform(
    (224, 224), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
celeba_transform = _build_transform(
    (224, 224), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
)
# The two body-level models share the same tall (height=231, width=93) input.
upar_transform = _build_transform(
    (231, 93), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
)
fashionpedia_transform = _build_transform(
    (231, 93), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
)

# =======================================================================================
# === MODEL & DETECTOR LOADING ===
# =======================================================================================

# --- Load Object Detectors ---
# Caffe networks run through OpenCV's dnn module. The person detector is only
# needed by the body-level models, so it is loaded conditionally.
face_detector_net = cv2.dnn.readNetFromCaffe(FACE_PROTO, FACE_MODEL)
if USE_FASHIONPEDIA_MODEL or USE_UPAR_MODEL:
    person_detector_net = cv2.dnn.readNetFromCaffe(PERSON_PROTO, PERSON_MODEL)

# --- Load CelebA Model ---
if USE_CELEBA_MODEL:
    # Build the bare architecture only: load_state_dict (strict by default)
    # overwrites every parameter and buffer, so downloading the ImageNet
    # weights first would be wasted network traffic and fail when offline.
    celeba_model = resnet50(weights=None)
    # The head size must match the checkpoint that is loaded below.
    celeba_model.fc = nn.Linear(celeba_model.fc.in_features, 35)
    celeba_model.load_state_dict(torch.load(CELEBA_WEIGHTS, map_location=device, weights_only=True))
    celeba_model.eval().to(device)
    # Only the header is needed (attribute names = every column after the
    # first), so skip parsing the data rows.
    celeba_attrs_list = pd.read_csv(CELEBA_CSV, nrows=0).columns[1:].tolist()

# --- Load FairFace Model ---
if USE_FAIRFACE_MODEL:
    class FairFaceMultiTask(nn.Module):
        """ResNet-34 trunk followed by a single 18-way linear layer.

        The 18 outputs are split into three groups of logits:
        columns 0-6 (race), 7-8 (gender) and 9-17 (age).
        """

        def __init__(self):
            super().__init__()
            trunk = resnet34(weights=None)
            # Keep every child module except the last one (the classifier).
            trunk_layers = list(trunk.children())[:-1]
            self.backbone = nn.Sequential(*trunk_layers)
            self.fc = nn.Linear(512, 18)

        def forward(self, x):
            """Return the (race, gender, age) logit tensors for a batch."""
            features = self.backbone(x)
            flat = features.view(features.size(0), -1)
            logits = self.fc(flat)
            race = logits[:, :7]
            gender = logits[:, 7:9]
            age = logits[:, 9:]
            return race, gender, age

    fairface_model = FairFaceMultiTask().to(device)
    fairface_model.load_state_dict(
        torch.load(FAIRFACE_WEIGHTS, map_location=device, weights_only=True)
    )
    fairface_model.eval()

# --- Load Fashionpedia Model ---
if USE_FASHIONPEDIA_MODEL:
    class FashionpediaModel(nn.Module):
        """ResNet-50 whose classifier is replaced by Dropout + Linear.

        Args:
            num_attributes: Number of output logits (one per attribute).
        """

        def __init__(self, num_attributes):
            super().__init__()
            # Use the torchvision constructor this file already imports instead
            # of torch.hub.load: same architecture and parameter names, but no
            # GitHub fetch and no execution of downloaded hub code at start-up.
            base = resnet50(weights=None)

            in_features = base.fc.in_features
            # Keep the Sequential(Dropout, Linear) head so the parameter names
            # (model.fc.1.*) match the saved checkpoint.
            base.fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(in_features, num_attributes)
            )
            self.model = base

        def forward(self, x):
            """Return the raw attribute logits for a batch of images."""
            return self.model(x)

    # The head is sized from the attribute table so the two cannot drift apart.
    fashionpedia_model = FashionpediaModel(num_attributes=len(fashionpedia_attributes)).to(device)
    fashionpedia_model.load_state_dict(torch.load(FASHIONPEDIA_WEIGHTS, map_location=device, weights_only=True))
    fashionpedia_model.eval()

# --- Load UPAR Model ---
if USE_UPAR_MODEL:
    class UPARModel(nn.Module):
        """ResNet-50 whose classifier is replaced by Dropout + Linear.

        Args:
            num_classes: Number of output logits (defaults to 30).
        """

        def __init__(self, num_classes=30):
            super().__init__()
            # Use the torchvision constructor this file already imports instead
            # of torch.hub.load: same architecture and parameter names, but no
            # GitHub fetch and no execution of downloaded hub code at start-up.
            base = resnet50(weights=None)

            in_features = base.fc.in_features
            # Keep the Sequential(Dropout, Linear) head so the parameter names
            # (model.fc.1.*) match the saved checkpoint.
            base.fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(in_features, num_classes)
            )
            self.model = base

        def forward(self, x):
            """Return the raw attribute logits for a batch of images."""
            return self.model(x)

    # Size the head from the full label list (30 entries) so the output width
    # and the index map below always agree.
    upar_model = UPARModel(num_classes=len(upar_full_label_list)).to(device)
    upar_model.load_state_dict(torch.load(UPAR_WEIGHTS, map_location=device, weights_only=True))
    upar_model.eval()
    # Map each reported label to its position in the model's output vector.
    upar_label_idx_map = {label: upar_full_label_list.index(label) for label in upar_label_cols}


# =======================================================================================
# === HELPER FUNCTIONS (DETECTION, PREDICTION, ANNOTATION) ===
# =======================================================================================

def detect_faces(image, confidence_threshold):
    """Run the face detector and return the boxes of confident detections.

    Args:
        image: Image array; its first two dimensions are read as
            (height, width).
        confidence_threshold: A detection is kept only when its score is
            strictly greater than this value.

    Returns:
        A list of integer arrays ``[x1, y1, x2, y2]`` in pixel units. The
        coordinates are not clipped to the image bounds.
    """
    img_h, img_w = image.shape[:2]
    net_input = cv2.dnn.blobFromImage(image, 1.0, (300, 300), [104, 117, 123], False, False)
    face_detector_net.setInput(net_input)
    # One row per candidate; column 2 is the score, columns 3..6 the
    # box corners as fractions of the image size.
    candidates = face_detector_net.forward()[0, 0]
    to_pixels = np.array([img_w, img_h, img_w, img_h])
    return [
        (row[3:7] * to_pixels).astype("int")
        for row in candidates
        if row[2] > confidence_threshold
    ]

def detect_people(image, confidence_threshold):
    """Run the person detector and return the boxes of confident people.

    Args:
        image: Image array; its first two dimensions are read as
            (height, width).
        confidence_threshold: A detection is kept only when its score is
            strictly greater than this value.

    Returns:
        A list of integer arrays ``[x1, y1, x2, y2]`` in pixel units for
        detections whose class id is 15. Coordinates are not clipped.
    """
    person_class_id = 15  # Class ID 15 is 'person'
    img_h, img_w = image.shape[:2]
    net_input = cv2.dnn.blobFromImage(image, 0.007843, (300, 300), 127.5)
    person_detector_net.setInput(net_input)
    # One row per candidate: column 1 = class id, column 2 = score,
    # columns 3..6 = box corners as fractions of the image size.
    candidates = person_detector_net.forward()[0, 0]
    to_pixels = np.array([img_w, img_h, img_w, img_h])

    people = []
    for row in candidates:
        if not row[2] > confidence_threshold:
            continue
        if int(row[1]) != person_class_id:
            continue
        people.append((row[3:7] * to_pixels).astype("int"))
    return people

def predict_celeba(face_crop):
    """Predict CelebA attributes for a face crop.

    The crop gets a black border of 10% of its size on every side, is
    converted from BGR to RGB, and is run through the CelebA model. An
    attribute is reported when its sigmoid probability is strictly above its
    entry in ``celeba_thresholds`` (0.8 when the attribute has no entry).

    Args:
        face_crop: BGR image array of the face region.

    Returns:
        A list of attribute names with underscores replaced by spaces. The
        first element is always "Male" or "Female", depending on whether the
        "Male" attribute passed its threshold.
    """
    crop_h, crop_w = face_crop.shape[:2]
    border_y = int(crop_h * 0.1)
    border_x = int(crop_w * 0.1)
    bordered = cv2.copyMakeBorder(
        face_crop, border_y, border_y, border_x, border_x,
        borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0],
    )

    rgb = cv2.cvtColor(bordered, cv2.COLOR_BGR2RGB)
    batch = celeba_transform(Image.fromarray(rgb)).unsqueeze(0).to(device)

    with torch.no_grad(), autocast(device_type=device.type):
        logits = celeba_model(batch)
        probs = torch.sigmoid(logits).cpu().squeeze()

    passed = [
        attr.replace("_", " ")
        for attr, prob in zip(celeba_attrs_list, probs)
        if prob.item() > celeba_thresholds.get(attr, 0.8)  # default 0.8 if not set
    ]

    # Gender always leads the list; "Male" itself is never repeated after it.
    gender = "Male" if "Male" in passed else "Female"
    return [gender] + [name for name in passed if name != "Male"]


def predict_fairface(face_crop):
    """Predict race, gender and age-group labels for a face crop.

    The crop gets a black border of 10% of its size on every side, is
    converted from BGR to RGB, and is run through the FairFace model. Each
    head's top class is reported only when its softmax probability is strictly
    above the matching entry in ``fairface_thresholds``.

    Args:
        face_crop: BGR image array of the face region.

    Returns:
        A list with up to three strings, in the order race, gender,
        ``"Age: <range>"``.
    """
    crop_h, crop_w = face_crop.shape[:2]
    border_y = int(crop_h * 0.1)
    border_x = int(crop_w * 0.1)
    bordered = cv2.copyMakeBorder(
        face_crop, border_y, border_y, border_x, border_x,
        borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0],
    )

    rgb = cv2.cvtColor(bordered, cv2.COLOR_BGR2RGB)
    batch = fairface_transform(Image.fromarray(rgb)).unsqueeze(0).to(device)

    with torch.no_grad(), autocast(device_type=device.type):
        head_logits = fairface_model(batch)
        race_probs, gender_probs, age_probs = [
            logits.softmax(1).cpu().squeeze() for logits in head_logits
        ]

    def _confident_label(task, probs, names):
        """Return the top label for *task*, or None when below threshold."""
        if probs.max(0).values.item() > fairface_thresholds[task]:
            return names[probs.argmax().item()]
        return None

    race = _confident_label("race", race_probs, fairface_race_labels)
    gender = _confident_label("gender", gender_probs, fairface_gender_labels)
    age = _confident_label("age", age_probs, fairface_age_labels)

    texts = []
    if race is not None:
        texts.append(race)
    if gender is not None:
        texts.append(gender)
    if age is not None:
        texts.append(f"Age: {age}")
    return texts


def predict_fashion(person_crop):
    """Predict Fashionpedia attributes for a person crop.

    Args:
        person_crop: BGR image array of the person region.

    Returns:
        The display names of all attributes whose sigmoid probability is
        strictly above ``FASHIONPEDIA_PRED_CONF``. Output index ``i`` is
        paired with the ``i``-th value of ``fashionpedia_attributes`` in
        insertion order.
    """
    image_pil = Image.fromarray(cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB))
    input_tensor = fashionpedia_transform(image_pil).unsqueeze(0).to(device)
    with torch.no_grad(), autocast(device_type=device.type):
        output = fashionpedia_model(input_tensor)
        # Bring all probabilities to the host in one transfer. Calling
        # .item() on each element of a device tensor would synchronise once
        # per attribute. The float() cast is exact for half-precision values.
        probs = torch.sigmoid(output).squeeze(0).float().cpu().tolist()

    return [
        name
        for name, prob in zip(fashionpedia_attributes.values(), probs)
        if prob > FASHIONPEDIA_PRED_CONF
    ]

def predict_upar(person_crop):
    """Predict UPAR attributes for a person crop.

    Args:
        person_crop: BGR image array of the person region.

    Returns:
        The labels from ``upar_label_cols`` (in that order) whose sigmoid
        probability is strictly above their entry in ``upar_thresholds``.
    """
    image_pil = Image.fromarray(cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB))
    input_tensor = upar_transform(image_pil).unsqueeze(0).to(device)
    with torch.no_grad(), autocast(device_type=device.type):
        output = upar_model(input_tensor)
        # Bring all probabilities to the host in one transfer. Calling
        # .item() on each element of a device tensor would synchronise once
        # per attribute. The float() cast is exact for half-precision values.
        probs = torch.sigmoid(output).squeeze(0).float().cpu().tolist()

    return [
        label
        for label in upar_label_cols
        if probs[upar_label_idx_map[label]] > upar_thresholds[label]
    ]

def draw_annotations(image, box, texts, position='below', color=(0, 255, 0)):
    """Draw a bounding box and a stack of filled text labels onto ``image``.

    The font scale is derived from the total image height (not from the box),
    clipped to [0.3, 1.2], so labels stay compact on low-resolution images.

    Args:
        image: Image array, modified in place.
        box: ``(x1, y1, x2, y2)`` corner coordinates in pixels.
        texts: Label strings, drawn top to bottom in the given order.
        position: ``'below'`` stacks the labels under the box starting at its
            left edge; any other value stacks them to the right of the box
            starting at its top edge.
        color: Colour used for the box outline and the label backgrounds.
            Label text is always drawn in black.

    Labels that would cross the right image edge are skipped individually.
    Drawing stops once a label would cross the bottom edge.
    """
    # Plain Python ints keep OpenCV's point parsing happy regardless of
    # whether the caller passes NumPy scalars.
    x1, y1, x2, y2 = (int(v) for v in box)
    img_h, img_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX

    # === Smaller Font Scale Based on Total Image Height ===
    font_scale = float(np.clip(img_h / 1600.0, 0.3, 1.2))
    thickness = max(1, int(font_scale * 1.2))

    # Estimate the line height from a reference glyph.
    (_, text_h), _ = cv2.getTextSize("A", font, font_scale, thickness)
    padding = int(text_h * 0.25)
    line_spacing = int(text_h * 0.35)

    # Draw bounding box
    cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)

    # Start label position
    if position == 'below':
        x_label_start, y_offset = x1, y2 + padding
    else:
        x_label_start, y_offset = x2 + padding, y1

    for txt in texts:
        (tw, th), _ = cv2.getTextSize(txt, font, font_scale, thickness)
        label_right = x_label_start + tw + padding
        label_bottom = y_offset + th + padding

        if label_bottom > img_h:
            # The stack only grows downwards, so stop here.
            break
        if label_right > img_w:
            # Only this label is too wide; a shorter one may still fit, so
            # skip it without consuming a row instead of dropping the rest.
            continue

        cv2.rectangle(image, (x_label_start, y_offset), (label_right, label_bottom), color, -1)
        cv2.putText(
            image,
            txt,
            (x_label_start + int(padding / 2), y_offset + th),
            font,
            font_scale,
            (0, 0, 0),
            thickness
        )
        y_offset += th + line_spacing



# =======================================================================================
# === MAIN EXECUTION ===
# =======================================================================================

if __name__ == "__main__":
    # Detection and cropping always read from original_image; annotated_image
    # is the only array that gets drawn on.
    original_image = cv2.imread(IMAGE_PATH)
    if original_image is None:
        raise FileNotFoundError(f"Error: Could not load image at {IMAGE_PATH}")

    annotated_image = original_image.copy()
    H, W = original_image.shape[:2]

    # --- 1. Process People for Fashion & UPAR Attributes ---
    if USE_FASHIONPEDIA_MODEL or USE_UPAR_MODEL:
        print("Detecting people for fashion/UPAR analysis...")
        person_boxes = detect_people(original_image, PERSON_DETECTION_CONF)
        print(f"Found {len(person_boxes)} people.")

        for p_box in person_boxes:
            px, py, px2, py2 = (int(v) for v in p_box)
            w, h = px2 - px, py2 - py

            # Extend the box upwards by 5% of its height, then pick a width
            # around the box centre that matches the 231x93 model input ratio.
            target_ratio = 231 / 93
            padding = int(h * 0.05)
            y1 = max(0, py - padding)
            # Clamp the bottom edge so the aspect ratio is computed from the
            # rows that are actually inside the image.
            y2 = min(H, py2)
            new_h = y2 - y1
            new_w = int(new_h / target_ratio)
            cx = px + w // 2
            x1 = max(0, cx - new_w // 2)
            x2 = min(W, x1 + new_w)

            person_crop = original_image[y1:y2, x1:x2]
            if person_crop.size == 0:
                continue

            all_person_preds = []
            if USE_UPAR_MODEL:
                all_person_preds.extend(predict_upar(person_crop))
            if USE_FASHIONPEDIA_MODEL:
                all_person_preds.extend(predict_fashion(person_crop))

            if all_person_preds:
                draw_annotations(annotated_image, (x1, y1, x2, y2), all_person_preds, position='right', color=(0, 255, 0))
            else:
                # Draw box even if no attributes found. Two corner points are
                # required: a single 4-tuple would be read by OpenCV as
                # (x, y, width, height).
                cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    # --- 2. Process Faces for Face Attributes ---
    if USE_CELEBA_MODEL or USE_FAIRFACE_MODEL:
        print("Detecting faces for attribute analysis...")
        # Detect on the clean image: annotated_image may already carry person
        # boxes and labels that would cover faces or confuse the detector.
        face_boxes = detect_faces(original_image, FACE_DETECTION_CONF)
        print(f"Found {len(face_boxes)} faces.")

        for f_box in face_boxes:
            fx1, fy1, fx2, fy2 = (int(v) for v in f_box)
            # Center-aligned square crop
            fw = fx2 - fx1
            fh = fy2 - fy1
            cx = fx1 + fw // 2
            cy = fy1 + fh // 2

            # Square side = longer edge of the detection. The multiplier is the
            # margin factor; it is currently 1, i.e. no extra margin.
            side = int(max(fw, fh) * 1)

            # Shift upward by 10% of side
            cy = max(0, cy - int(side * 0.1))

            x1 = max(0, cx - side // 2)
            y1 = max(0, cy - side // 2)
            x2 = min(W, cx + side // 2)
            y2 = min(H, cy + side // 2)

            face_crop = original_image[y1:y2, x1:x2]
            if face_crop.size == 0:
                continue

            all_face_preds = []
            if USE_FAIRFACE_MODEL:
                all_face_preds.extend(predict_fairface(face_crop))
            if USE_CELEBA_MODEL:
                all_face_preds.extend(predict_celeba(face_crop))

            if all_face_preds:
                # Annotate with the adjusted square box, not the raw detection.
                draw_annotations(annotated_image, (x1, y1, x2, y2), all_face_preds, position='below', color=(255, 182, 90))
            else:
                # Two corner points, see the note in the person branch.
                cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (255, 182, 90), 2)

    # --- 3. Show Final Result ---
    plt.figure(figsize=(18, 14))
    plt.imshow(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.title("Combined Model Predictions", fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
import cv2
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.models import resnet50, resnet34, ResNet50_Weights
from torch.amp import autocast

# =======================================================================================
# === CONFIGURATION & SETTINGS ===
# =======================================================================================

# --- Enable/Disable Models ---
# Each flag gates the loading of the corresponding model further down.
# The person detector is loaded only when a body-level model
# (Fashionpedia or UPAR) is enabled.
USE_FASHIONPEDIA_MODEL = True
USE_UPAR_MODEL = True
USE_CELEBA_MODEL = True
USE_FAIRFACE_MODEL = True

# --- Input Image ---
# Path handed to cv2.imread by the main section.
IMAGE_PATH = "ss.jpg" # Change this to your image file

# --- Model File Paths ---
# *_PROTO / *_MODEL pairs are passed to cv2.dnn.readNetFromCaffe;
# *_WEIGHTS files are passed to torch.load and then load_state_dict.
FACE_PROTO = "deploy.prototxt.txt"
FACE_MODEL = "res10_300x300_ssd_iter_140000.caffemodel"
PERSON_PROTO = "deploy.prototxt"
PERSON_MODEL = "mobilenet_iter_73000.caffemodel"
CELEBA_WEIGHTS = "best_celeba_model.pth"
FAIRFACE_WEIGHTS = "best_fairface_model.pth"
FASHIONPEDIA_WEIGHTS = "best_fashionpedia_model.pth"
UPAR_WEIGHTS = "best_upar_model.pth"
# Only the column names of this CSV are used (all columns after the first).
CELEBA_CSV = "celeba_cleaned.csv"

# --- Confidence Thresholds ---
# All comparisons are strict (score > threshold). The two detection values
# filter detector scores; FASHIONPEDIA_PRED_CONF is applied to every
# sigmoid probability produced by the Fashionpedia model.
FACE_DETECTION_CONF = 0.4
PERSON_DETECTION_CONF = 0.5
FASHIONPEDIA_PRED_CONF = 0.5

# --- Labels & Model-Specific Thresholds ---

# FairFace
# Label lists are indexed by the argmax of the matching softmax head, so the
# position of each entry is significant.
fairface_race_labels = [
    'Black',
    'East Asian',
    'Indian',
    'Latino_Hispanic',
    'Middle Eastern',
    'Southeast Asian',
    'White',
]
fairface_gender_labels = ['Female', 'Male']
fairface_age_labels = [
    '0-2', '10-19', '20-29',
    '3-9', '30-39', '40-49',
    '50-59', '60-69', '70+',
]
# Minimum top-class probability required before a head's label is reported.
fairface_thresholds = {
    "race": 0.50,
    "gender": 0.2,
    "age": 0.2,
}

# CelebA
# Per-attribute sigmoid thresholds; attributes not listed here fall back to
# the default used inside predict_celeba.
celeba_thresholds = {
    "Male": 0.95,
    "Smiling": 0.50,
    "Wearing_Earrings": 0.50,
    "Heavy_Makeup": 0.75,
    "No_Beard": 0.95,
    "Eyeglasses": 0.70,
    "Young": 0.70,
}

# UPAR
# Per-attribute sigmoid thresholds. The key order doubles as the order in
# which attributes are evaluated and reported.
upar_thresholds = {
    'Accessory-Backpack': 0.80,
    'Accessory-Bag': 0.75,
    'Accessory-Glasses-Normal': 0.85,
    'Accessory-Hat': 0.85,
    'Age-Adult': 0.85,
    'Age-Young': 0.85,
    'Gender-Female': 0.85,
    'Hair-Length-Long': 0.85,
    'Hair-Length-Short': 0.85,
}
# Attributes that are actually reported (exactly the thresholded ones).
upar_label_cols = list(upar_thresholds)
# Full list for UPAR model's 30-class output mapping; position == output index.
upar_full_label_list = [
    # accessories
    'Accessory-Backpack', 'Accessory-Bag', 'Accessory-Glasses-Normal', 'Accessory-Hat',
    # age / gender / hair
    'Age-Adult', 'Age-Young', 'Gender-Female', 'Hair-Length-Long', 'Hair-Length-Short',
    # lower body
    'LowerBody-Color-Black', 'LowerBody-Color-Blue', 'LowerBody-Color-Brown',
    'LowerBody-Color-Grey', 'LowerBody-Color-Other', 'LowerBody-Color-White',
    'LowerBody-Length-Short', 'LowerBody-Type-Skirt&Dress', 'LowerBody-Type-Trousers&Shorts',
    # upper body
    'UpperBody-Color-Black', 'UpperBody-Color-Blue', 'UpperBody-Color-Brown',
    'UpperBody-Color-Green', 'UpperBody-Color-Grey', 'UpperBody-Color-Other',
    'UpperBody-Color-Pink', 'UpperBody-Color-Purple', 'UpperBody-Color-Red',
    'UpperBody-Color-White', 'UpperBody-Color-Yellow', 'UpperBody-Length-Short',
]


# Fashionpedia
# Display names for the attributes predicted by the Fashionpedia model.
# NOTE: insertion order is significant. predict_fashion pairs the i-th model
# output with the i-th *value* of this dict, and the model head is sized with
# len(fashionpedia_attributes). The integer keys are never used for indexing
# in this file (presumably they are the original dataset attribute ids --
# confirm against the training code). Do not reorder, add or remove entries
# without retraining / remapping.
fashionpedia_attributes = {
    218: "Has patch pockets", 204: "Regular sleeves", 205: "Dropped shoulders", 159: "3/4 sleeves",
    163: "Shirt-style collar", 225: "One row of buttons", 295: "No extra material", 137: "Loose fit",
    145: "No defined waist", 115: "Even on both sides", 148: "Very short", 149: "Mini length",
    316: "Nothing special in build", 317: "Plain design", 160: "Sleeves to the wrist", 128: "Straight cut",
    135: "Tight fit", 106: "Form-fitting", 140: "Waist sits low", 302: "Gathered fabric", 151: "Knee length",
    162: "Standard collar", 224: "Pockets with flaps", 214: "Puffy long sleeves", 133: "Short in front, long in back", 103: "Shirt-dress style",
    127: "Narrow shape", 325: "Floral print", 102: "Simple straight dress", 301: "Printed pattern", 142: "Normal waistline",
    200: "Straight neckline", 179: "No collar", 36: "Denim pants", 230: "Zipper fly", 136: "Standard fit",
    298: "Washed look", 154: "Full leg coverage", 223: "Curved pocket shape", 114: "Asymmetrical look", 147: "Length to the hips",
    112: "Tunic-style", 309: "Has a slit", 152: "A little below the knee", 182: "Round neck", 311: "Lined inside",
    146: "Ends above the hips", 229: "Zipper closure", 305: "Wrinkled effect", 312: "Has appliqués", 185: "Oval neck",
    186: "U-shaped neck", 119: "Fitted top, flared bottom", 141: "High waistline", 319: "Cartoon graphics", 300: "Frayed edges",
    174: "Notch-style lapel", 38: "Leggings", 17: "Blazer-style", 322: "Checkered print", 138: "Baggy fit", 289: "Fur material",
    304: "Pleated fabric", 155: "To the floor", 20: "Motorcycle jacket", 153: "Mid-calf length", 283: "Metal details", 10: "Camisole",
    187: "Heart-shaped neckline", 150: "Above the knee", 157: "Short length", 318: "Abstract print",
    108: "Gown style", 120: "Trumpet shape", 95: "Halter style", 328: "Striped pattern", 183: "V neckline", 207: "Short cap sleeves",
    129: "A-line shape", 143: "Low waist", 68: "Slim skirt", 213: "Loose sleeve style",
    192: "Very deep neckline", 314: "Metal rivets", 219: "Inset pockets", 222: "Slanted pockets", 126: "Peg pants",
    194: "Strap around neck", 209: "Puffy short sleeves", 132: "Wide-leg pants", 308: "Cutout design", 216: "Loose kimono sleeves",
    281: "Plastic material", 118: "Flared shape", 113: "Short flared dress", 117: "Circular shape", 177: "Fancy lapel",
    226: "Two rows of buttons", 297: "Worn-in look", 8: "Cropped top", 175: "Pointy lapel", 326: "Geometric print",
    50: "Short shorts", 181: "Round crew neck", 191: "Square neck", 11: "Tank top", 190: "Wide scoop neck",
    323: "Polka dots", 220: "Big front pocket", 0: "Basic tee", 320: "Letters or numbers", 158: "Sleeves to elbow",
    210: "Bell-shaped sleeves", 123: "Bell-shaped bottom", 315: "Has sequins", 197: "High neck", 101: "Simple loose dress",
    189: "Boat-shaped neckline", 166: "Wrapped collar", 180: "Uneven neckline", 198: "Turtleneck",
    307: "Layered fabric", 176: "Wide collar", 228: "Wrap-around", 2: "Undershirt", 286: "Gem detail", 313: "Beaded design",
    202: "Off-shoulder style", 121: "Mermaid shape", 221: "Stitched pockets", 234: "No visible opening", 203: "One shoulder",
    206: "Diagonal sleeve seam"
}

# =======================================================================================
# === INITIALIZATION ===
# =======================================================================================

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


# --- Image Transforms ---
def _build_transform(size, mean, std):
    """Return a Resize -> ToTensor -> Normalize preprocessing pipeline."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


# One pipeline per model: input size and normalisation statistics differ.
fairface_transform = _build_transform(
    (224, 224), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
celeba_transform = _build_transform(
    (224, 224), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
)
# The two body-level models share the same tall (height=231, width=93) input.
upar_transform = _build_transform(
    (231, 93), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
)
fashionpedia_transform = _build_transform(
    (231, 93), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
)

# =======================================================================================
# === MODEL & DETECTOR LOADING ===
# =======================================================================================

# --- Load Object Detectors ---
# Caffe networks run through OpenCV's dnn module. The person detector is only
# needed by the body-level models, so it is loaded conditionally.
face_detector_net = cv2.dnn.readNetFromCaffe(FACE_PROTO, FACE_MODEL)
if USE_FASHIONPEDIA_MODEL or USE_UPAR_MODEL:
    person_detector_net = cv2.dnn.readNetFromCaffe(PERSON_PROTO, PERSON_MODEL)

# --- Load CelebA Model ---
if USE_CELEBA_MODEL:
    # Build the bare architecture only: load_state_dict (strict by default)
    # overwrites every parameter and buffer, so downloading the ImageNet
    # weights first would be wasted network traffic and fail when offline.
    celeba_model = resnet50(weights=None)
    # The head size must match the checkpoint that is loaded below.
    celeba_model.fc = nn.Linear(celeba_model.fc.in_features, 35)
    celeba_model.load_state_dict(torch.load(CELEBA_WEIGHTS, map_location=device, weights_only=True))
    celeba_model.eval().to(device)
    # Only the header is needed (attribute names = every column after the
    # first), so skip parsing the data rows.
    celeba_attrs_list = pd.read_csv(CELEBA_CSV, nrows=0).columns[1:].tolist()

# --- Load FairFace Model ---
if USE_FAIRFACE_MODEL:
    class FairFaceMultiTask(nn.Module):
        """ResNet-34 trunk followed by a single 18-way linear layer.

        The 18 outputs are split into three groups of logits:
        columns 0-6 (race), 7-8 (gender) and 9-17 (age).
        """

        def __init__(self):
            super().__init__()
            trunk = resnet34(weights=None)
            # Keep every child module except the last one (the classifier).
            trunk_layers = list(trunk.children())[:-1]
            self.backbone = nn.Sequential(*trunk_layers)
            self.fc = nn.Linear(512, 18)

        def forward(self, x):
            """Return the (race, gender, age) logit tensors for a batch."""
            features = self.backbone(x)
            flat = features.view(features.size(0), -1)
            logits = self.fc(flat)
            race = logits[:, :7]
            gender = logits[:, 7:9]
            age = logits[:, 9:]
            return race, gender, age

    fairface_model = FairFaceMultiTask().to(device)
    fairface_model.load_state_dict(
        torch.load(FAIRFACE_WEIGHTS, map_location=device, weights_only=True)
    )
    fairface_model.eval()

# --- Load Fashionpedia Model ---
if USE_FASHIONPEDIA_MODEL:
    class FashionpediaModel(nn.Module):
        """ResNet-50 whose classifier is replaced by Dropout + Linear.

        Args:
            num_attributes: Number of output logits (one per attribute).
        """

        def __init__(self, num_attributes):
            super().__init__()
            # Use the torchvision constructor this file already imports instead
            # of torch.hub.load: same architecture and parameter names, but no
            # GitHub fetch and no execution of downloaded hub code at start-up.
            base = resnet50(weights=None)

            in_features = base.fc.in_features
            # Keep the Sequential(Dropout, Linear) head so the parameter names
            # (model.fc.1.*) match the saved checkpoint.
            base.fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(in_features, num_attributes)
            )
            self.model = base

        def forward(self, x):
            """Return the raw attribute logits for a batch of images."""
            return self.model(x)

    # The head is sized from the attribute table so the two cannot drift apart.
    fashionpedia_model = FashionpediaModel(num_attributes=len(fashionpedia_attributes)).to(device)
    fashionpedia_model.load_state_dict(torch.load(FASHIONPEDIA_WEIGHTS, map_location=device, weights_only=True))
    fashionpedia_model.eval()

# --- Load UPAR Model ---
if USE_UPAR_MODEL:
    class UPARModel(nn.Module):
        """ResNet-50 whose classifier is replaced by Dropout + Linear.

        Args:
            num_classes: Number of output logits (defaults to 30).
        """

        def __init__(self, num_classes=30):
            super().__init__()
            # Use the torchvision constructor this file already imports instead
            # of torch.hub.load: same architecture and parameter names, but no
            # GitHub fetch and no execution of downloaded hub code at start-up.
            base = resnet50(weights=None)

            in_features = base.fc.in_features
            # Keep the Sequential(Dropout, Linear) head so the parameter names
            # (model.fc.1.*) match the saved checkpoint.
            base.fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(in_features, num_classes)
            )
            self.model = base

        def forward(self, x):
            """Return the raw attribute logits for a batch of images."""
            return self.model(x)

    # Size the head from the full label list (30 entries) so the output width
    # and the index map below always agree.
    upar_model = UPARModel(num_classes=len(upar_full_label_list)).to(device)
    upar_model.load_state_dict(torch.load(UPAR_WEIGHTS, map_location=device, weights_only=True))
    upar_model.eval()
    # Map each reported label to its position in the model's output vector.
    upar_label_idx_map = {label: upar_full_label_list.index(label) for label in upar_label_cols}


# =======================================================================================
# === HELPER FUNCTIONS (DETECTION, PREDICTION, ANNOTATION) ===
# =======================================================================================

def detect_faces(image, confidence_threshold):
    """Run the face detector network on ``image`` and return pixel-space boxes.

    Args:
        image: Image array; its first two dimensions are read as
            (height, width).
        confidence_threshold: A detection is kept only when its score
            (column 2 of the network output) is strictly greater than this.

    Returns:
        A list of integer arrays ``[x1, y1, x2, y2]``, obtained by scaling
        columns 3..6 of each kept detection by the image width/height.
        Coordinates are not clipped to the image bounds here.
    """
    img_h, img_w = image.shape[:2]

    # Scale factor 1.0, 300x300 input, per-channel mean subtraction,
    # no R/B swap and no cropping.
    net_input = cv2.dnn.blobFromImage(image, 1.0, (300, 300), [104, 117, 123], False, False)
    face_detector_net.setInput(net_input)
    raw = face_detector_net.forward()

    to_pixels = np.array([img_w, img_h, img_w, img_h])
    return [
        (row[3:7] * to_pixels).astype("int")
        for row in raw[0, 0]
        if row[2] > confidence_threshold
    ]

def detect_people(image, confidence_threshold):
    """Run the person detector network on ``image`` and return pixel-space boxes.

    Args:
        image: Image array; its first two dimensions are read as
            (height, width).
        confidence_threshold: A detection is kept only when its score
            (column 2 of the network output) is strictly greater than this.

    Returns:
        A list of integer arrays ``[x1, y1, x2, y2]`` for detections whose
        class id (column 1) is 15, obtained by scaling columns 3..6 by the
        image width/height. Coordinates are not clipped to the image bounds.
    """
    person_class_id = 15  # Class ID 15 is 'person'
    img_h, img_w = image.shape[:2]

    # Scale 0.007843 with mean 127.5, 300x300 input.
    net_input = cv2.dnn.blobFromImage(image, 0.007843, (300, 300), 127.5)
    person_detector_net.setInput(net_input)
    raw = person_detector_net.forward()

    to_pixels = np.array([img_w, img_h, img_w, img_h])
    return [
        (row[3:7] * to_pixels).astype("int")
        for row in raw[0, 0]
        if int(row[1]) == person_class_id and row[2] > confidence_threshold
    ]

def predict_celeba(face_crop):
    """Predict CelebA attributes for a BGR face crop.

    The crop is surrounded by a black border worth 10% of its height/width on
    each side, converted to RGB, run through ``celeba_transform`` and the
    CelebA model, and each sigmoid probability is compared against the
    attribute's entry in ``celeba_thresholds`` (0.8 when the attribute has no
    entry).

    Args:
        face_crop: BGR image array of the face region.

    Returns:
        A list of attribute names with underscores replaced by spaces. The
        first element is always a gender label: "Male" when that attribute
        passed its threshold, otherwise "Female".
    """
    crop_h, crop_w = face_crop.shape[:2]
    border_y = int(crop_h * 0.1)
    border_x = int(crop_w * 0.1)

    bordered = cv2.copyMakeBorder(
        face_crop, border_y, border_y, border_x, border_x,
        borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]
    )

    rgb_image = Image.fromarray(cv2.cvtColor(bordered, cv2.COLOR_BGR2RGB))
    # CelebA has its own preprocessing pipeline.
    batch = celeba_transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad(), autocast(device_type=device.type):
        logits = celeba_model(batch)
        probs = torch.sigmoid(logits).cpu().squeeze()

    labels = [
        attr.replace("_", " ")
        for attr, prob in zip(celeba_attrs_list, probs)
        if prob.item() > celeba_thresholds.get(attr, 0.8)  # default 0.8 if not set
    ]

    # Gender always leads the list; absence of "Male" is reported as "Female".
    if "Male" in labels:
        return ["Male"] + [name for name in labels if name != "Male"]
    return ["Female"] + labels


def predict_fairface(face_crop):
    """Predict race, gender and age-group labels for a BGR face crop.

    The crop is surrounded by a black border worth 10% of its height/width on
    each side, converted to RGB, run through ``fairface_transform`` and the
    FairFace model. Each of the three outputs is soft-maxed; a label is
    reported only when its top probability is strictly greater than the
    matching entry of ``fairface_thresholds``.

    Args:
        face_crop: BGR image array of the face region.

    Returns:
        A list with up to three strings, in the order race, gender,
        ``"Age: <group>"``.
    """
    crop_h, crop_w = face_crop.shape[:2]
    border_y = int(crop_h * 0.1)
    border_x = int(crop_w * 0.1)

    bordered = cv2.copyMakeBorder(
        face_crop, border_y, border_y, border_x, border_x,
        borderType=cv2.BORDER_CONSTANT, value=[0, 0, 0]
    )

    rgb_image = Image.fromarray(cv2.cvtColor(bordered, cv2.COLOR_BGR2RGB))
    # FairFace has its own preprocessing pipeline.
    batch = fairface_transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad(), autocast(device_type=device.type):
        out_race, out_gender, out_age = fairface_model(batch)
        race_probs, gender_probs, age_probs = [
            head.softmax(1).cpu().squeeze() for head in [out_race, out_gender, out_age]
        ]

    def _confident_label(probs, labels, key):
        """Return the arg-max label if its probability clears the threshold, else None."""
        if probs.max(0).values.item() > fairface_thresholds[key]:
            return labels[probs.argmax().item()]
        return None

    texts = []

    race = _confident_label(race_probs, fairface_race_labels, "race")
    if race is not None:
        texts.append(race)

    gender = _confident_label(gender_probs, fairface_gender_labels, "gender")
    if gender is not None:
        texts.append(gender)

    age = _confident_label(age_probs, fairface_age_labels, "age")
    if age is not None:
        texts.append(f"Age: {age}")

    return texts


def predict_fashion(person_crop):
    """Predict Fashionpedia attributes for a BGR person crop.

    The crop is converted to RGB, run through ``fashionpedia_transform`` and
    the Fashionpedia model; an attribute is reported when its sigmoid
    probability is strictly greater than ``FASHIONPEDIA_PRED_CONF``.

    Args:
        person_crop: BGR image array of the person region.

    Returns:
        A list of attribute names, in the iteration order of
        ``fashionpedia_attributes.values()``.
    """
    image_pil = Image.fromarray(cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB))
    # Fashionpedia has its own preprocessing pipeline.
    input_tensor = fashionpedia_transform(image_pil).unsqueeze(0).to(device)
    with torch.no_grad(), autocast(device_type=device.type):
        output = fashionpedia_model(input_tensor)
        # Move the whole probability vector to the CPU once (as the face
        # predictors do) so the per-attribute .item() reads below do not each
        # trigger a device-to-host copy. squeeze(0) drops only the batch axis,
        # so a single-output model still yields an indexable 1-D vector.
        probs = torch.sigmoid(output).cpu().squeeze(0)

    return [
        name for i, name in enumerate(fashionpedia_attributes.values())
        if probs[i].item() > FASHIONPEDIA_PRED_CONF
    ]

def predict_upar(person_crop):
    """Predict UPAR attributes for a BGR person crop.

    The crop is converted to RGB, run through ``upar_transform`` and the UPAR
    model. For every label in ``upar_label_cols`` the sigmoid probability at
    the label's index (from ``upar_label_idx_map``) is compared against that
    label's entry in ``upar_thresholds``.

    Args:
        person_crop: BGR image array of the person region.

    Returns:
        The labels from ``upar_label_cols`` whose probability is strictly
        greater than their threshold, in ``upar_label_cols`` order.
    """
    image_pil = Image.fromarray(cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB))
    # UPAR has its own preprocessing pipeline.
    input_tensor = upar_transform(image_pil).unsqueeze(0).to(device)
    with torch.no_grad(), autocast(device_type=device.type):
        output = upar_model(input_tensor)
        # Move the whole probability vector to the CPU once (as the face
        # predictors do) so the per-label .item() reads below do not each
        # trigger a device-to-host copy. squeeze(0) drops only the batch axis.
        probs = torch.sigmoid(output).cpu().squeeze(0)

    return [
        label for label in upar_label_cols
        if probs[upar_label_idx_map[label]].item() > upar_thresholds[label]
    ]

def draw_annotations(image, box, texts, position='below', color=(0, 255, 0)):
    """
    Draw a bounding box and a stacked column of text labels onto ``image``.

    Font size is derived from the image height (not from the box size) and
    clamped to the range 0.3..1.2. Labels are drawn as black text on filled
    rectangles of ``color``.

    Args:
        image: Image array that is drawn on in place.
        box: ``(x1, y1, x2, y2)`` corner coordinates of the bounding box.
        texts: Iterable of strings, one label per line.
        position: ``'below'`` stacks labels under the box starting at its left
            edge; any other value stacks them to the right of the box
            starting at its top edge.
        color: Colour used for the box outline and the label backgrounds.

    Drawing stops at the first label whose background would extend past the
    right or bottom edge of the image; that label and all later ones are
    skipped.
    """
    left, top, right, bottom = box
    frame_h, frame_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Font scale follows the total image height.
    scale = np.clip(frame_h / 1600.0, 0.3, 1.2)
    stroke = max(1, int(scale * 1.2))

    # Height of a reference glyph drives the padding and line spacing.
    (_, glyph_h), _ = cv2.getTextSize("A", font, scale, stroke)
    pad = int(glyph_h * 0.25)
    gap = int(glyph_h * 0.35)

    cv2.rectangle(image, (left, top), (right, bottom), color, stroke)

    # Anchor of the first label.
    if position == 'below':
        label_x = left
        label_y = bottom + pad
    else:
        label_x = right + pad
        label_y = top

    text_inset = int(pad / 2)
    for line in texts:
        (line_w, line_h), _ = cv2.getTextSize(line, font, scale, stroke)
        right_edge = label_x + line_w + pad
        bottom_edge = label_y + line_h + pad

        # Stop as soon as a label would leave the image.
        if right_edge > frame_w or bottom_edge > frame_h:
            break

        cv2.rectangle(image, (label_x, label_y), (right_edge, bottom_edge), color, -1)
        cv2.putText(
            image,
            line,
            (label_x + text_inset, label_y + line_h),
            font,
            scale,
            (0, 0, 0),
            stroke
        )
        label_y += line_h + gap


# =======================================================================================
# === MAIN EXECUTION (FOR LIVE CAMERA) ===
# =======================================================================================

if __name__ == "__main__":
    # --- 1. Initialize Video Capture ---
    cap = cv2.VideoCapture(0) # Use 0 for the default webcam.
                            # You might need to change this if you have multiple cameras.
    if not cap.isOpened():
        raise IOError("Cannot open webcam")

    print("Starting live detection... Press 'q' to quit.")

    # The capture loop runs inside try/finally so the camera and the display
    # window are released even when a frame raises (model error, Ctrl+C, ...),
    # not only when the loop exits normally.
    try:
        while True:
            # --- 2. Read a Frame from the Camera ---
            ret, frame = cap.read()
            if not ret:
                print("Failed to grab frame")
                break

            # Annotations are drawn on a copy so predictions always see the clean frame
            annotated_frame = frame.copy()
            H, W = frame.shape[:2]

            # --- 3. Process People for Fashion & UPAR Attributes ---
            if USE_FASHIONPEDIA_MODEL or USE_UPAR_MODEL:
                person_boxes = detect_people(frame, PERSON_DETECTION_CONF)
                for p_box in person_boxes:
                    px, py, px2, py2 = p_box
                    w, h = px2 - px, py2 - py

                    # Re-shape the detection into a crop with a fixed
                    # height/width ratio: extend the top by 5% of the box
                    # height, then derive the width from the new height and
                    # centre it horizontally on the detection.
                    target_ratio = 231 / 93
                    padding = int(h * 0.05)
                    y1 = max(0, py - padding)
                    new_h = (py2 - y1)
                    new_w = int(new_h / target_ratio)
                    cx = px + w // 2
                    x1 = max(0, cx - new_w // 2)
                    x2 = min(W, x1 + new_w)

                    # Use the original 'frame' for prediction to avoid using an already annotated image
                    person_crop = frame[y1:py2, x1:x2]
                    if person_crop.size == 0:
                        continue

                    all_person_preds = []
                    if USE_UPAR_MODEL:
                        all_person_preds.extend(predict_upar(person_crop))
                    if USE_FASHIONPEDIA_MODEL:
                        all_person_preds.extend(predict_fashion(person_crop))

                    if all_person_preds:
                        draw_annotations(annotated_frame, (x1, y1, x2, py2), all_person_preds, position='right', color=(0, 255, 0))
                    else:
                        cv2.rectangle(annotated_frame, (x1, y1, x2, py2), (0, 255, 0), 2)

            # --- 4. Process Faces for Face Attributes ---
            if USE_CELEBA_MODEL or USE_FAIRFACE_MODEL:
                face_boxes = detect_faces(frame, FACE_DETECTION_CONF)
                for f_box in face_boxes:
                    fx1, fy1, fx2, fy2 = f_box
                    fw, fh = fx2 - fx1, fy2 - fy1
                    cx, cy = fx1 + fw // 2, fy1 + fh // 2
                    # Square crop sized by the longer side, shifted up by 10%
                    # of that side and clamped to the frame.
                    side = int(max(fw, fh) * 1)
                    cy = max(0, cy - int(side * 0.1))
                    x1, y1 = max(0, cx - side // 2), max(0, cy - side // 2)
                    x2, y2 = min(W, cx + side // 2), min(H, cy + side // 2)

                    # Use the original 'frame' for prediction
                    face_crop = frame[y1:y2, x1:x2]
                    if face_crop.size == 0:
                        continue

                    all_face_preds = []
                    if USE_FAIRFACE_MODEL:
                        all_face_preds.extend(predict_fairface(face_crop))
                    if USE_CELEBA_MODEL:
                        all_face_preds.extend(predict_celeba(face_crop))

                    if all_face_preds:
                        draw_annotations(annotated_frame, (x1, y1, x2, y2), all_face_preds, position='below', color=(255, 182, 90))
                    else:
                        cv2.rectangle(annotated_frame, (x1, y1, x2, y2), (255, 182, 90), 2)

            # --- 5. Show Final Result in a Window ---
            # cv2.imshow is used for real-time display instead of matplotlib
            cv2.imshow("Live AI Detection", annotated_frame)

            # --- 6. Check for 'q' key to exit the loop ---
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # --- 7. Release Resources ---
        cap.release()
        cv2.destroyAllWindows()