In [None]:
import garmentiq as giq
from garmentiq.landmark.detection.model_definition import PoseHighResolutionNet
from garmentiq.garment_classes import garment_classes
from flask_cors import CORS
from flask import Flask, request, jsonify
from PIL import Image
import numpy as np
import io
import cv2
import base64

In [None]:
%%bash
# Fetch the HRNet landmark-detection weights once; later runs reuse the cached file.
set -eu
MODEL_DIR=/app/working/examples/web_landmark_detection/models
MODEL_FILE="$MODEL_DIR/hrnet.pth"
MODEL_URL=https://huggingface.co/lygitdata/garmentiq/resolve/main/hrnet.pth

mkdir -p "$MODEL_DIR"
# -s (exists AND is non-empty) instead of -f: an empty file left behind by an
# earlier failed download must not be mistaken for a valid model.
if [ ! -s "$MODEL_FILE" ]; then
    # `wget -O` creates the output file even when the transfer fails, so download
    # to a temporary name and only move it into place once wget reports success.
    if wget -q -O "$MODEL_FILE.part" "$MODEL_URL"; then
        mv "$MODEL_FILE.part" "$MODEL_FILE"
    else
        rm -f "$MODEL_FILE.part"
        echo "Failed to download $MODEL_URL" >&2
        exit 1
    fi
else
    echo "Model file already exists, skipping download."
fi

In [None]:
# Flask application for the landmark-detection API. CORS(app) with no options
# enables cross-origin requests for the app's routes, so a browser front-end
# hosted elsewhere can call it.
app = Flask(__name__)
CORS(app)

# Location of the HRNet weights (fetched by the preceding %%bash cell).
HRNET_WEIGHTS_PATH = "/app/working/examples/web_landmark_detection/models/hrnet.pth"

# The model is loaded a single time when this cell runs; every request handler
# below reuses this same `HRNet` object instead of reloading the weights.
print("Loading HRNet model…")
HRNet = giq.landmark.detection.load_model(
    model_class=PoseHighResolutionNet(),
    model_path=HRNET_WEIGHTS_PATH,
)
print("Model loaded.")

In [None]:
@app.route('/health')
def health_check():
    """Health-check endpoint reporting that the service is up.

    Returns:
        A JSON response with fixed ``status``, ``token``, ``model`` and
        ``version`` fields; nothing in it depends on runtime state.
    """
    # NOTE(review): this token is returned to any caller without authentication,
    # so it cannot act as a secret. If it is meant to be one, load it from the
    # environment and stop echoing it from this route.
    payload = dict(
        status="active",
        token="a3f7d2c14e65bb2e8f01a9dc4f6c9823d279f1e05b3a6d74c0987b1c2fae3c65",
        model="web_landmark_detection",
        version="1.0.0",
    )
    return jsonify(payload)

def _draw_landmarks(image_rgb, points, radius, color=(0, 255, 0)):
    """Return a copy of ``image_rgb`` with a filled circle drawn at each landmark.

    Args:
        image_rgb: H x W x 3 uint8 array in RGB channel order, as produced by
            ``np.array(PIL.Image.convert('RGB'))``.
        points: iterable of ``(x, y)`` pixel coordinates; each value is rounded
            to the nearest integer pixel.
        radius: circle radius in pixels.
        color: fill colour, in the same channel order as ``image_rgb``.

    Returns:
        A new array; ``image_rgb`` itself is left unmodified.
    """
    # cv2.circle writes `color` channel-by-channel regardless of RGB/BGR, and pure
    # green is the same tuple in both orders, so no colour conversion is needed.
    annotated = image_rgb.copy()
    for point in points:
        x, y = int(round(point[0])), int(round(point[1]))
        cv2.circle(annotated, (x, y), radius, color, -1)  # thickness -1 = filled
    return annotated


def _encode_png_base64(image_rgb):
    """Encode an RGB uint8 array as PNG and return it as a base64 text string."""
    buffer = io.BytesIO()
    Image.fromarray(image_rgb).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@app.route('/landmark_detection', methods=['POST'])
def landmark_detection():
    """Detect garment landmarks on uploaded images and return annotated PNGs.

    Expects a multipart/form-data POST with:
        images: one or more image files that PIL can decode.
        garment_class: a key of ``garment_classes`` selecting the landmark set.

    Returns:
        200 ``{"results": [<base64 PNG>, ...]}``, one annotated image per upload
        in upload order (an empty list when no images were sent).
        400 ``{"error": ...}`` when ``garment_class`` is missing/unknown or an
        upload cannot be decoded as an image.
    """
    files = request.files.getlist('images')
    garment_class = request.form.get('garment_class')

    # Nothing to process: keep answering with an empty result list.
    if not files:
        return jsonify({"results": []})

    # Reject bad input with a 400 instead of letting the detector raise (HTTP 500).
    # Presumably detect() looks class_name up in class_dict — confirm in garmentiq.
    if garment_class not in garment_classes:
        return jsonify({"error": "Unknown or missing garment_class."}), 400

    # Decode every upload before running the model, so a bad file is rejected
    # without spending inference time on the others.
    images_np = []
    for file in files:
        try:
            with Image.open(file.stream) as image:
                images_np.append(np.array(image.convert('RGB')))
        except (OSError, Image.DecompressionBombError):
            # UnidentifiedImageError / truncated data are OSError subclasses.
            return jsonify({"error": f"Could not decode uploaded file {file.filename!r} as an image."}), 400

    results = []
    for image_np in images_np:
        # Preprocessing must match how HRNet was trained: the normalisation values
        # are the standard ImageNet mean/std; resize_dim is the network input size
        # (presumably width x height — TODO confirm in garmentiq).
        coords, _, _ = giq.landmark.detect(
            class_name=garment_class,
            class_dict=garment_classes,
            image_path=image_np,
            model=HRNet,
            scale_std=200.0,
            resize_dim=[288, 384],
            normalize_mean=[0.485, 0.456, 0.406],
            normalize_std=[0.229, 0.224, 0.225],
        )

        # Scale the marker with the image (1.25% of the shorter side), at least 2 px.
        height, width = image_np.shape[:2]
        radius = max(2, int(round(min(height, width) * 0.0125)))

        # coords is (1, N, 2): a single batch of N (x, y) landmark points.
        annotated = _draw_landmarks(image_np, coords[0], radius)
        results.append(_encode_png_base64(annotated))

    return jsonify({"results": results})

In [None]:
# Start Flask's built-in development server. host='0.0.0.0' listens on all
# network interfaces rather than only localhost.
# Inside a notebook kernel __name__ is '__main__', so this cell blocks until
# the server is stopped (interrupt the kernel to stop it).
if __name__ == '__main__':
    app.run(port=5001, host='0.0.0.0')