In [4]:
import base64
import os
import shutil
import threading

from flask import Flask, request, jsonify
from flask_cors import CORS
import garmentiq as giq

# Flask application; CORS is enabled for every route ("/*") and every origin.
# NOTE(review): origins "*" lets any website call this API from a browser —
# restrict to the known front-end origin(s) if exposed beyond a dev setup.
app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# where we put uploads & outputs
# BASE_DIR is a relative path, so it resolves against the process's current
# working directory. Both working dirs are created at import time if missing.
BASE_DIR   = 'output_img'
INPUT_DIR  = os.path.join(BASE_DIR, 'input')
OUTPUT_DIR = os.path.join(BASE_DIR, 'output')
os.makedirs(INPUT_DIR,  exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# load model once
# Loaded synchronously at import time so every /segment request reuses the same
# model object instead of reloading weights per request.
print("Loading BiRefNet model…")
BiRefNet = giq.segmentation.load_model(
    pretrained_model='ZhengPeng7/BiRefNet',
    # NOTE(review): presumably forwarded to Hugging Face `from_pretrained`, where
    # trust_remote_code=True executes Python code shipped with the model repo —
    # confirm against garmentiq and pin a revision if that is the case.
    pretrained_model_args={'trust_remote_code': True},
    high_precision=True
)
print("Model loaded.")

@app.route('/', methods=['GET'])
def home():
    """Liveness check: confirm the segmentation service is up.

    Returns a fixed plain-text message so clients can verify the server is
    reachable without uploading any images.
    """
    status_message = "GarmentIQ Segmentation API is running."
    return status_message

# /segment requests are serialized. Each one wipes and then reuses the same
# INPUT_DIR / OUTPUT_DIR, and Flask's app.run() serves requests on multiple
# threads by default, so overlapping requests would otherwise delete or return
# each other's files.
_SEGMENT_LOCK = threading.Lock()


def _clear_dir(d):
    """Delete every entry inside directory *d*, files and sub-directories alike."""
    # Sub-directories must go too: results are written into OUTPUT_DIR/masks and
    # OUTPUT_DIR/bg_modified, so deleting only plain files would leak stale
    # images from earlier requests into later responses.
    for fn in os.listdir(d):
        fp = os.path.join(d, fn)
        if os.path.isdir(fp) and not os.path.islink(fp):
            shutil.rmtree(fp)
        else:
            os.remove(fp)


def _safe_filename(name, taken):
    """Return a bare, non-colliding on-disk name for an uploaded file.

    Args:
        name: client-supplied filename (untrusted; may be None or empty).
        taken: set of casefolded names already assigned in this request.

    Returns:
        The final path component of *name*, with both ``/`` and ``\\`` treated
        as separators so the upload cannot escape INPUT_DIR. When that name
        (compared case-insensitively) is already in *taken*, ``_1``, ``_2``, ...
        is appended to the stem. Returns None when nothing usable remains
        (empty, ``.`` or ``..``).
    """
    base = os.path.basename((name or '').replace('\\', '/')).strip()
    if base in ('', '.', '..'):
        return None
    stem, ext = os.path.splitext(base)
    candidate, n = base, 1
    while candidate.casefold() in taken:
        candidate = f'{stem}_{n}{ext}'
        n += 1
    return candidate


@app.route('/segment', methods=['POST'])
def segment():
    """Segment uploaded images with the preloaded BiRefNet model.

    Expects multipart form data with one or more files under ``images``.
    Optional form fields ``red``/``green``/``blue`` (all three, integers 0-255)
    request a background colour. In that case images are read from
    OUTPUT_DIR/bg_modified; otherwise they are read from OUTPUT_DIR/masks.

    Returns:
        200 ``{"images": [{"filename": str, "base64": str}, ...]}``, sorted by
        filename; ``base64`` is the raw output file, base64-encoded.
        400 ``{"error": str}`` when no usable file is uploaded or RGB is invalid.
        500 ``{"error": str}`` when segmentation fails or writes no output.
    """
    files = request.files.getlist('images')

    # Resolve safe, unique on-disk names up front. Entries without a usable
    # name (e.g. an empty file input in a browser form) are skipped.
    uploads = []
    taken = set()
    for f in files:
        name = _safe_filename(f.filename, taken)
        if name is not None:
            taken.add(name.casefold())
            uploads.append((f, name))
    if not uploads:
        return jsonify(error='No images uploaded'), 400

    # Parse the optional RGB triple before touching disk; it is only used when
    # all three channels are supplied.
    r = request.form.get('red', '')
    g = request.form.get('green', '')
    b = request.form.get('blue', '')
    background_color = None
    if r != '' and g != '' and b != '':
        try:
            rgb = [int(r), int(g), int(b)]
        except ValueError:
            rgb = None
        if rgb is None or any(c < 0 or c > 255 for c in rgb):
            return jsonify(error='Invalid RGB values; must be integers 0–255'), 400
        background_color = rgb

    # run the batch processor
    kwargs = {
        'image_dir': INPUT_DIR,
        'output_dir': OUTPUT_DIR,
        'model': BiRefNet,
        'resize_dim': (1024, 1024),
        'normalize_mean': [0.485, 0.456, 0.406],
        'normalize_std': [0.229, 0.224, 0.225],
        'high_precision': True
    }
    if background_color is not None:
        kwargs['background_color'] = background_color

    with _SEGMENT_LOCK:
        # Start from empty working dirs so only this request's images are
        # processed and returned.
        for d in (INPUT_DIR, OUTPUT_DIR):
            _clear_dir(d)
        for f, name in uploads:
            f.save(os.path.join(INPUT_DIR, name))

        try:
            giq.segmentation.process_and_save_images(**kwargs)
        except Exception:
            # Request boundary: log the traceback and answer in JSON rather
            # than Flask's default HTML 500 page.
            app.logger.exception('Segmentation failed')
            return jsonify(error='Segmentation failed'), 500

        # pick which folder to read from
        sub = 'bg_modified' if background_color is not None else 'masks'
        folder = os.path.join(OUTPUT_DIR, sub)
        if not os.path.isdir(folder):
            return jsonify(error='Segmentation produced no output'), 500

        out = []
        for fn in sorted(os.listdir(folder)):
            path = os.path.join(folder, fn)
            if not os.path.isfile(path):
                continue
            with open(path, 'rb') as imgf:
                b64 = base64.b64encode(imgf.read()).decode('utf-8')
            out.append({'filename': fn, 'base64': b64})

    return jsonify(images=out)

if __name__ == '__main__':
    # Launch Flask's built-in server on every interface (0.0.0.0), port 5000,
    # so it is reachable from outside the host, e.g. from outside a container.
    app.run(port=5000, host='0.0.0.0')

Loading BiRefNet model…
Model loaded.
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.17.0.2:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m
INFO:werkzeug:172.17.0.1 - - [01/May/2025 14:53:20] "GET / HTTP/1.1" 200 -
INFO:werkzeug:172.17.0.1 - - [01/May/2025 14:53:27] "GET / HTTP/1.1" 200 -


Processing Images:   0%|          | 0/1 [00:00<?, ?image/s]

INFO:werkzeug:172.17.0.1 - - [01/May/2025 14:53:40] "POST /segment HTTP/1.1" 200 -


Processing Images:   0%|          | 0/1 [00:00<?, ?image/s]

INFO:werkzeug:172.17.0.1 - - [01/May/2025 14:55:25] "POST /segment HTTP/1.1" 200 -


Processing Images:   0%|          | 0/1 [00:00<?, ?image/s]

INFO:werkzeug:172.17.0.1 - - [01/May/2025 14:55:43] "POST /segment HTTP/1.1" 200 -


Processing Images:   0%|          | 0/4 [00:00<?, ?image/s]

INFO:werkzeug:172.17.0.1 - - [01/May/2025 14:56:41] "POST /segment HTTP/1.1" 200 -


Processing Images:   0%|          | 0/4 [00:00<?, ?image/s]

INFO:werkzeug:172.17.0.1 - - [01/May/2025 14:57:26] "POST /segment HTTP/1.1" 200 -
