In [None]:
import base64
import os
import shutil
import threading

from flask import Flask, request, jsonify
from flask_cors import CORS
import garmentiq as giq

In [None]:
app = Flask(__name__)
CORS(app)

# On-disk workspace: uploads are written under INPUT_DIR and segmentation
# results are read back from OUTPUT_DIR.
BASE_DIR = 'images'
INPUT_DIR, OUTPUT_DIR = (
    os.path.join(BASE_DIR, sub) for sub in ('input', 'output')
)
os.makedirs(INPUT_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load the segmentation model a single time, when this cell runs, so every
# request handler reuses the same instance instead of reloading it.
print("Loading BiRefNet model…")
BiRefNet = giq.segmentation.load_model(
    high_precision=True,
    pretrained_model='ZhengPeng7/BiRefNet',
    # Lets the checkpoint run its own model code (presumably forwarded to a
    # Hugging Face from_pretrained call inside garmentiq — TODO confirm).
    pretrained_model_args=dict(trust_remote_code=True),
)
print("Model loaded.")

In [None]:
@app.route('/health')
def health_check():
    """Report that the service is up, along with basic service metadata.

    Returns:
        A JSON object with the keys ``status``, ``token``, ``model`` and
        ``version``.

    The ``token`` value comes from the ``HEALTH_TOKEN`` environment variable
    rather than being hard-coded in the notebook. It is an empty string when
    the variable is unset.
    """
    # NOTE(review): the token previously hard-coded here was committed to
    # source and is returned by an unauthenticated endpoint; treat the old
    # value as leaked and rotate it.
    token = os.environ.get("HEALTH_TOKEN", "")
    return jsonify({
        "status": "active",
        "token": token,
        "model": "web_segmentation",
        "version": "1.0.0",
    })

# Serializes /segment requests. Every request wipes and reuses the shared
# INPUT_DIR / OUTPUT_DIR. Flask's built-in server handles requests on several
# threads by default, so without this lock concurrent requests could delete
# or mix each other's uploads and results.
_SEGMENT_LOCK = threading.Lock()


def _safe_upload_name(filename):
    """Return a bare filename safe to join onto INPUT_DIR, or None if unusable.

    Client-supplied upload names are untrusted. Any directory components are
    stripped so names such as ``../../etc/passwd`` cannot escape INPUT_DIR.
    Empty names (e.g. a form submitted with no file chosen) are rejected.
    """
    # Normalise Windows separators first so full client paths are flattened too.
    name = os.path.basename((filename or '').replace('\\', '/'))
    if name in ('', '.', '..'):
        return None
    return name


def _parse_background_color(form):
    """Parse the optional ``red``/``green``/``blue`` form fields.

    Returns:
        ``[r, g, b]`` when all three fields are present and non-empty.
        ``None`` when any field is missing or empty; partial input is
        ignored, matching the endpoint's previous behavior.

    Raises:
        ValueError: if all three fields are given but any of them is not an
            integer in the range 0-255.
    """
    channels = [form.get(key, '') for key in ('red', 'green', 'blue')]
    if any(c == '' for c in channels):
        return None
    rgb = [int(c) for c in channels]  # int() raises ValueError on bad input
    if any(not 0 <= c <= 255 for c in rgb):
        raise ValueError('RGB channel out of range')
    return rgb


def _collect_b64_images(folder):
    """Base64-encode every regular file in *folder*, in sorted name order.

    Returns:
        A list of ``{'filename': name, 'base64': data}`` dicts. The list is
        empty if *folder* does not exist.
    """
    results = []
    if not os.path.isdir(folder):
        return results
    for fn in sorted(os.listdir(folder)):
        path = os.path.join(folder, fn)
        if os.path.isfile(path):
            with open(path, 'rb') as imgf:
                b64 = base64.b64encode(imgf.read()).decode('utf-8')
            results.append({'filename': fn, 'base64': b64})
    return results


@app.route('/segment', methods=['POST'])
def segment():
    """Segment the uploaded images with BiRefNet.

    Form fields:
        images: one or more uploaded image files (required).
        red, green, blue: optional integers 0-255. When all three are given,
            a background colour is passed to the segmentation step and the
            ``bg_modified`` outputs are returned too.

    Returns:
        JSON with ``masks`` and ``bg_modified``. Each is a list of
        ``{'filename', 'base64'}`` dicts; ``bg_modified`` is empty when no
        colour was given.
        A 400 error is returned when no usable image or an invalid colour is
        supplied.
    """
    files = request.files.getlist('images')
    if not files:
        return jsonify(error='No images uploaded'), 400

    # Validate the whole request before touching disk, so a bad request does
    # not wipe the outputs of the previous one.
    try:
        background_color = _parse_background_color(request.form)
    except ValueError:
        return jsonify(error='Invalid RGB values; must be integers 0–255'), 400

    uploads = []
    for f in files:
        name = _safe_upload_name(f.filename)
        if name is not None:
            uploads.append((f, name))
    if not uploads:
        return jsonify(error='No images uploaded'), 400

    kwargs = {
        'image_dir': INPUT_DIR,
        'output_dir': OUTPUT_DIR,
        'model': BiRefNet,
        'resize_dim': (1024, 1024),
        'normalize_mean': [0.485, 0.456, 0.406],
        'normalize_std': [0.229, 0.224, 0.225],
        'high_precision': True
    }
    if background_color is not None:
        kwargs['background_color'] = background_color

    with _SEGMENT_LOCK:
        # Start from empty input/output directories for this request.
        for d in (INPUT_DIR, OUTPUT_DIR):
            if os.path.exists(d):
                shutil.rmtree(d)
            os.makedirs(d, exist_ok=True)

        for f, name in uploads:
            f.save(os.path.join(INPUT_DIR, name))

        giq.segmentation.process_and_save_images(**kwargs)

        # Read results while still holding the lock, so that another request
        # cannot wipe OUTPUT_DIR before they are collected.
        masks = _collect_b64_images(os.path.join(OUTPUT_DIR, 'masks'))
        if background_color is not None:
            modified = _collect_b64_images(
                os.path.join(OUTPUT_DIR, 'bg_modified'))
        else:
            modified = []

    return jsonify(masks=masks, bg_modified=modified)

In [None]:
if __name__ == '__main__':
    # Listen on every network interface, port 5000. app.run() blocks, so when
    # executed in a notebook this cell occupies the kernel until interrupted.
    app.run(port=5000, host='0.0.0.0')