In [1]:
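'''
Evaluate a given model on a given benchmark

Example:
python evaluate.py --model_name birefnet --benchmark gcp_url_to_benchmark

In this notebook, set `model_name`, `benchmark`, and `path_to_weight` in the
"Launch point" cells instead of passing command-line flags.

Gian Favero
Ideogram
2025-10-29
'''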

import sys
sys.path.insert(0, "/home/gianfavero/projects/")

import argparse
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from BiRefNet.benchmarking.factory import get_model
from BiRefNet.ideogram_dataset import BenchmarkDataset
from BiRefNet.ideogram_utils import pil_image_to_bytes, reduce_spill
from tfrecords.benchmark.tfr import BenchmarkExample
from tfrecords.eval.tfr import EvalExample

from PIL import Image
import numpy as np
import tensorflow as tf
import cv2

  from .autonotebook import tqdm as notebook_tqdm
2025-11-12 20:01:07.499421: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-12 20:01:07.546590: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-12 20:01:08.599653: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variab

### Functional code for evaluation

Preprocessing, batched inference, and output writers used by the launch cells below.

In [None]:
# Built once rather than per sample; the pipeline holds no per-sample state.
_BG_REMOVAL_PIPELINE = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((1024, 1024)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def bg_removal_transform(sample): # from BiRefNet
    """Prepare one benchmark sample for background-removal inference.

    Args:
        sample: mapping with an ``"inpainting_image"`` entry (a PIL image).

    Returns:
        ``(image, input_image)``: the original PIL image, unchanged, and the model
        input, a (3, 1024, 1024) float tensor normalized with the standard ImageNet
        mean/std. After normalization the values are NOT confined to [0, 1].
    """
    image = sample["inpainting_image"]
    # Normalize uses 3-channel statistics, so RGBA / L / P images would fail there.
    # Only the model input is converted; the returned image keeps its original mode.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    input_image = _BG_REMOVAL_PIPELINE(rgb_image)
    return image, input_image

def collate_fn(batch):
    """Merge ``(image, input_tensor)`` samples into one batch dict.

    The original images stay a plain list because they can differ in size and
    type. The model inputs share a shape, so they are stacked along a new leading
    batch dimension.
    """
    originals = []
    tensors = []
    for sample in batch:
        originals.append(sample[0])
        tensors.append(sample[1])
    return {"images": originals, "input_images": torch.stack(tensors)}

@torch.no_grad()
def evaluate(model, dataloader):
    torch.set_float32_matmul_precision(['high', 'highest'][0])

    images_list = []
    masks_list = []
    for batch in dataloader:
        input_images = batch["input_images"].to(model.device).half() # needs to be full precision for rmbgv2
        images = batch["images"]

        masks = model(input_images)
        masks[masks < 0.1] = 0

        images_list.extend(images)
        masks_list.append(masks.detach().cpu())
    masks_list = torch.cat(masks_list, dim=0)

    output_list = []
    for image, mask in zip(images_list, masks_list):
        mask = transforms.ToPILImage()(mask)
        mask = mask.resize(image.size)
        image = reduce_spill(image, mask, r=90)

        image.putalpha(mask)

        output_list.append(image)

    return output_list

def save_output(output_list, model_name, benchmark, output_dir="eval-output"):
    """Save each output as ``<output_dir>/<benchmark>/<model_name>/sample_<i>.png``.

    Args:
        output_list: images to save. Each must provide ``.save(path)``, as the PIL
            images returned by ``evaluate`` do.
        model_name: model identifier, used as the innermost directory name.
        benchmark: benchmark name, used as the middle directory name.
        output_dir: root directory for results. It defaults to ``"eval-output"``,
            relative to the current working directory.
    """
    target_dir = os.path.join(output_dir, benchmark, model_name)
    os.makedirs(target_dir, exist_ok=True)
    for i, output in enumerate(output_list):
        output.save(os.path.join(target_dir, f"sample_{i}.png"))

def write_output_to_tfr(
    output_list,
    model_name,
    benchmark_url,
    output_prefix="gs://mobius-dev-us-east5/gian_favero_workspace/background_removal_examples",
):
    """Pair each output image with its benchmark record and write them as EvalExamples.

    The i-th image in ``output_list`` is paired with the i-th record of the TFRecord
    at ``benchmark_url``, so ``output_list`` must be in that file's order.

    Args:
        output_list: images produced by ``evaluate``, in benchmark order.
        model_name: model identifier embedded in the output file name.
        benchmark_url: path or URL of the source benchmark TFRecord.
        output_prefix: directory (local or ``gs://``) that receives the output file.

    Warns:
        UserWarning: when the number of outputs and benchmark records differ.
            Only the aligned prefix is written in that case.
    """
    import warnings

    # NOTE(review): the basename keeps its ".tfr" extension, giving names like
    # "<bench>.tfr_<model>_ablated.tfr". This is kept as-is in case downstream
    # tooling relies on it.
    writer_name = f"{output_prefix}/{benchmark_url.split('/')[-1]}_{model_name}_ablated.tfr"
    records = tf.data.TFRecordDataset(str(benchmark_url)).as_numpy_iterator()

    written = 0
    with tf.io.TFRecordWriter(writer_name) as writer:
        # zip() stops at the shorter input; `written` counts the pairs actually written.
        for image, serialized in zip(output_list, records):
            ex = BenchmarkExample.from_tf_example(serialized)
            writer.write(
                EvalExample(
                    prompt=ex,
                    images=[pil_image_to_bytes(image)]
                ).to_tf_example().SerializeToString()
            )
            written += 1

    # zip() pulls from output_list first, so any record still pending here was never paired.
    has_unpaired_records = next(records, None) is not None
    if written != len(output_list) or has_unpaired_records:
        warnings.warn(
            f"Output/benchmark length mismatch: {len(output_list)} outputs, "
            f"{written} pairs written, unpaired benchmark records: {has_unpaired_records}",
            stacklevel=2,
        )
    print(f"Wrote {written} eval examples to {writer_name}")

### Launch point

Choose the model, benchmark, and checkpoint below, then run inference and save the outputs.

In [10]:
# Model under evaluation; one of 'birefnet', 'rmbgv2', 'custom'.
model_name = "custom"
# Benchmark to evaluate on; one of 'green-benchmark', 'ig-benchmark'.
benchmark = "green-benchmark"
# Checkpoint handed to get_model (presumably only read for the 'custom' model; confirm).
path_to_weight = "/home/gianfavero/projects/BiRefNet/ckpts/test/step_473.pth"
# Run on the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [11]:
print(f"Evaluating {model_name} on {benchmark}")

# Source TFRecord for each supported benchmark name.
BENCHMARK_URLS = {
    "green-benchmark": "gs://mobius-dev-us-east5/gian_favero_workspace/background_removal_examples/11072025_green_graphic_bm_2k.tfr",
    "ig-benchmark": "gs://mobius-dev-us-east5/gian_favero_workspace/background_removal_examples/10292025_samples.tfr",
}
# Fail early with a clear message instead of a NameError on benchmark_url further down.
if benchmark not in BENCHMARK_URLS:
    raise ValueError(f"Unknown benchmark {benchmark!r}; expected one of {sorted(BENCHMARK_URLS)}")
benchmark_url = BENCHMARK_URLS[benchmark]

tfr_dataset = BenchmarkDataset(
    writer_name=benchmark_url,
    keys=["inpainting_image"],
    transform=bg_removal_transform,
)

tfr_dataloader = DataLoader(
    tfr_dataset,
    batch_size=2,
    shuffle=False,  # keep benchmark order: write_output_to_tfr pairs outputs with records by index
    collate_fn=collate_fn,
)

model = get_model(model_name, device=device, path_to_weight=path_to_weight)

output = evaluate(model, tfr_dataloader)

# Set to True to also write the outputs as EvalExamples next to the benchmark TFRecord.
write_tfr = False
if write_tfr:
    write_output_to_tfr(output, model_name, benchmark_url)

save_output(output, model_name, benchmark)

Evaluating custom on green-benchmark


2025-11-12 20:03:09.327415: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


RuntimeError: Input type (float) and bias type (c10::Half) should be the same