In [None]:
import os
import base64
from flask import Flask, request, jsonify
from flask_cors import CORS
import shutil
import urllib.request
import pandas as pd
import garmentiq as giq
from garmentiq.classification.model_definition import tinyViT
from garmentiq.landmark.detection.model_definition import PoseHighResolutionNet
from garmentiq.garment_classes import garment_classes
from garmentiq.landmark.derivation.derivation_dict import derivation_dict

# Service-wide state, filled in by the /setup endpoint and read by /measure.
# Everything starts as None until /setup has been called at least once.
tailor = do_refine = do_derive = background_color = None

In [None]:
# The Flask application serving the /health, /setup and /measure endpoints.
app = Flask(import_name=__name__)
# Enable cross-origin requests using flask_cors defaults (no options passed).
CORS(app)

# Working tree for the service: uploaded images, pipeline outputs and model weights.
BASE_DIR = 'tailor_files'
INPUT_DIR, OUTPUT_DIR, MODELS_DIR = (
    os.path.join(BASE_DIR, sub) for sub in ('input', 'output', 'models')
)
for _workdir in (INPUT_DIR, OUTPUT_DIR, MODELS_DIR):
    os.makedirs(_workdir, exist_ok=True)
del _workdir

In [None]:
# Remote locations of the two model weight files (classification + landmark detection).
classification_model_url, landmark_model_url = (
    "https://huggingface.co/lygitdata/garmentiq/resolve/main/tiny_vit_inditex_finetuned.pt",
    "https://huggingface.co/lygitdata/garmentiq/resolve/main/hrnet.pth",
)

# Where each weight file is cached locally, inside MODELS_DIR.
classification_model_path, landmark_model_path = (
    os.path.join(MODELS_DIR, fname)
    for fname in ("tiny_vit_inditex_finetuned.pt", "hrnet.pth")
)

def download_if_missing(url, destination_path):
    """Download ``url`` to ``destination_path`` unless that file already exists.

    The payload is first written to ``<destination_path>.part`` and then atomically
    moved into place with ``os.replace``. An interrupted or failed download therefore
    never leaves a truncated file at ``destination_path``, which a later call would
    otherwise mistake for a complete download and skip.

    Args:
        url: Source URL passed to ``urllib.request.urlretrieve``.
        destination_path: Final local path (str or path-like).

    Raises:
        Whatever ``urlretrieve`` raises (e.g. ``urllib.error.URLError``); the
        temporary ``.part`` file is removed before the error propagates.
    """
    if os.path.exists(destination_path):
        print(f"File already exists: {destination_path}. Skipping download.")
        return

    print(f"Downloading to {destination_path}...")
    tmp_path = os.fspath(destination_path) + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic rename: destination_path only ever holds a complete download.
        os.replace(tmp_path, destination_path)
    except BaseException:
        # Also covers KeyboardInterrupt, so an aborted notebook cell cleans up.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print("Download complete.")

# Fetch each weight file on the first run; later runs reuse the cached copy.
for _url, _dest in (
    (classification_model_url, classification_model_path),
    (landmark_model_url, landmark_model_path),
):
    download_if_missing(_url, _dest)

In [None]:
def encode_image_to_base64(path):
    """Return the raw bytes of the file at ``path`` as a base64 text string.

    Despite the name, any file type works; the bytes are not decoded as an image.
    """
    with open(path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

In [None]:
@app.route('/health')
def health_check():
    """Liveness endpoint: report that the service is up, with its name and version."""
    # NOTE(review): this token is a hardcoded literal returned to every caller, so it
    # cannot serve as a secret. If it is meant to be one, move it to configuration.
    payload = dict(
        status="active",
        token="31f21a4323cf148339669c736f522cf89fa570a2bffc80ba874078477b76f81b",
        model="web_tailor",
        version="1.0.0",
    )
    return jsonify(payload)

In [None]:
def _parse_background_color(r, g, b):
    """Parse optional RGB form strings into ``[r, g, b]`` ints, or None.

    Returns None unless all three components are non-empty (a partial colour is
    ignored, as before). Raises ValueError if a component is not an integer or is
    outside 0..255.
    """
    if not (r and g and b):
        return None
    rgb = [int(r), int(g), int(b)]
    if any(c < 0 or c > 255 for c in rgb):
        raise ValueError(f"RGB component out of range: {rgb}")
    return rgb


@app.route('/setup', methods=['POST'])
def setup():
    """Configure the module-level garmentiq tailor from form fields.

    Form fields:
        do_refine / do_derive: the string 'true' enables the option; anything else disables it.
        red / green / blue: optional background colour; used only when all three are given.

    Returns:
        (JSON, 200) on success; (JSON error, 400) for invalid RGB values.

    The module globals are only replaced after validation and tailor construction
    have both succeeded. A rejected request therefore leaves the previous
    configuration fully intact instead of half-updated.
    """
    global tailor, do_refine, do_derive, background_color

    new_do_refine = request.form.get('do_refine') == 'true'
    new_do_derive = request.form.get('do_derive') == 'true'
    try:
        new_background_color = _parse_background_color(
            request.form.get('red', ''),
            request.form.get('green', ''),
            request.form.get('blue', ''),
        )
    except ValueError:
        return jsonify(error='Invalid RGB values; must be integers 0–255'), 400

    # Model file names are relative to model_dir and must match the files
    # downloaded into MODELS_DIR.
    new_tailor = giq.tailor(
        input_dir=INPUT_DIR,
        model_dir=MODELS_DIR,
        output_dir=OUTPUT_DIR,
        class_dict=garment_classes,
        do_derive=new_do_derive,
        derivation_dict=derivation_dict,
        do_refine=new_do_refine,
        classification_model_path="tiny_vit_inditex_finetuned.pt",
        classification_model_class=tinyViT,
        classification_model_args={
            "num_classes": len(garment_classes),
            "img_size": (120, 184),
            "patch_size": 6,
            "resize_dim": (120, 184),
            "normalize_mean": [0.8047, 0.7808, 0.7769],
            "normalize_std": [0.2957, 0.3077, 0.3081],
        },
        segmentation_model_name="lygitdata/BiRefNet_garmentiq_backup",
        segmentation_model_args={
            "trust_remote_code": True,
            "resize_dim": (1024, 1024),
            "normalize_mean": [0.485, 0.456, 0.406],
            "normalize_std": [0.229, 0.224, 0.225],
            "high_precision": True,
            "background_color": new_background_color,
        },
        landmark_detection_model_path="hrnet.pth",
        landmark_detection_model_class=PoseHighResolutionNet(),
        landmark_detection_model_args={
            "scale_std": 200.0,
            "resize_dim": [288, 384],
            "normalize_mean": [0.485, 0.456, 0.406],
            "normalize_std": [0.229, 0.224, 0.225],
        },
    )

    # Commit the new configuration only after everything above succeeded.
    tailor = new_tailor
    do_refine = new_do_refine
    do_derive = new_do_derive
    background_color = new_background_color

    return jsonify(message='Tailor setup complete'), 200

In [None]:
def _safe_upload_name(filename):
    """Reduce a client-supplied upload name to a bare file name, or None if unusable.

    Any directory part (``/`` or ``\\`` separated) is stripped, so an upload cannot be
    written outside INPUT_DIR. Missing names, empty names and ``.``/``..`` are rejected.
    """
    name = os.path.basename((filename or '').replace('\\', '/'))
    if name in ('', '.', '..'):
        return None
    return name


def _cell(row, key):
    """Return ``row[key]``, mapping an absent column or a NaN cell to None.

    pandas fills missing values with NaN, which is truthy and is not valid JSON.
    Normalising it to None keeps the truthiness checks and jsonify output correct.
    """
    value = row.get(key)
    if isinstance(value, float) and pd.isna(value):
        return None
    return value


def _result_entry(row):
    """Build the JSON-serialisable response entry for one row of tailor metadata."""
    json_b64 = None
    json_path = _cell(row, "measurement_json")
    if json_path:
        with open(json_path, 'r', encoding='utf-8') as jf:
            text = jf.read()
        json_b64 = base64.b64encode(text.encode('utf-8')).decode('utf-8')

    entry = {
        "Image name": row['filename'],
        "Class": _cell(row, "class"),
        "Measurement image": encode_image_to_base64(row["measurement_image"]),
        "Measurement JSON (base64)": json_b64,
    }

    mask_path = _cell(row, "mask_image")
    if mask_path:
        entry["Mask"] = encode_image_to_base64(mask_path)

    bg_path = _cell(row, "bg_modified_image")
    if bg_path:
        entry["Background modified"] = encode_image_to_base64(bg_path)

    return entry


@app.route('/measure', methods=['POST'])
def measure():
    """Run segmentation and measurement on uploaded images.

    Expects multipart files under the form field ``images``. INPUT_DIR and OUTPUT_DIR
    are wiped and recreated for every request. The uploads are saved into INPUT_DIR
    and processed by the tailor configured via /setup.

    Returns:
        (JSON list, 200): one entry per image, sorted by file name.
        (JSON error, 400): /setup has not been called, no images were sent, or an
        upload has an unusable file name.
    """
    if tailor is None:
        return jsonify(error='Tailor is not configured; call /setup first'), 400

    files = request.files.getlist('images')
    if not files:
        return jsonify(error='No images uploaded'), 400

    # Validate every file name before touching the filesystem.
    uploads = []
    for f in files:
        name = _safe_upload_name(f.filename)
        if name is None:
            return jsonify(error=f'Invalid file name: {f.filename!r}'), 400
        uploads.append((name, f))

    # Ensure input/output directories are fresh
    for d in (INPUT_DIR, OUTPUT_DIR):
        if os.path.exists(d):
            shutil.rmtree(d)
        os.makedirs(d, exist_ok=True)

    for name, f in uploads:
        f.save(os.path.join(INPUT_DIR, name))

    metadata, _ = tailor.measure(save_segmentation_image=True, save_measurement_image=True)
    metadata = metadata.sort_values(by='filename', ascending=True)

    results = [_result_entry(row) for _, row in metadata.iterrows()]
    return jsonify(results), 200

In [None]:
if __name__ == '__main__':
    # Serve on all network interfaces, port 5002; this call blocks until stopped.
    app.run(port=5002, host='0.0.0.0')