In [None]:
'''
Evaluate a given model on a given benchmark

Example:
python evaluate.py --model_name birefnet --benchmark gcp_url_to_benchmark

Gian Favero
Ideogram
2025-10-29
'''

import sys
sys.path.insert(0, "/home/gianfavero/projects/")
sys.path.insert(0, "/home/gianfavero/projects/BiRefNet/")

import argparse
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from BiRefNet.benchmarking.factory import get_model
from BiRefNet.ideogram_utils import pil_image_to_bytes, reduce_spill, recover_original_rgba

import PIL
from PIL import Image
import numpy as np

In [None]:
def bg_removal_transform(sample): # from BiRefNet
    """Prepare one image array for the background-removal model.

    Args:
        sample: Image as an ``np.ndarray`` laid out (H, W, C).

    Returns:
        tuple: ``(image, input_image)`` where ``image`` is the PIL image
        built from ``sample`` at its original size, and ``input_image`` is
        the (C, 1024, 1024) tensor obtained by converting to a tensor,
        resizing, and normalising each channel with the mean/std below.
    """
    image = PIL.Image.fromarray(sample)

    # Same three steps, in the same order, as a transforms.Compose pipeline.
    to_tensor = transforms.ToTensor()
    resize = transforms.Resize((1024, 1024))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    input_image = normalize(resize(to_tensor(image)))
    return image, input_image

def collate_fn(batch):
    """Collate ``(image, input_image)`` pairs into one batch dict.

    Args:
        batch: Sequence of 2-item samples; item 0 is kept as-is, item 1 is
            a tensor that can be stacked with the others.

    Returns:
        dict: ``"images"`` is the list of first items in order;
        ``"input_images"`` is the second items stacked along a new dim 0.
    """
    originals = []
    tensors = []
    for sample in batch:
        originals.append(sample[0])
        tensors.append(sample[1])
    return {"images": originals, "input_images": torch.stack(tensors)}

@torch.no_grad()
def evaluate(model, dataloader):
    """Run the model over every batch and build one RGBA cut-out per image.

    Args:
        model: Callable model exposing a ``device`` attribute. It is fed
            half-precision batches and must return a tensor whose first
            dimension indexes the samples of the batch.
            NOTE(review): mask layout is presumably (B, 1, H, W) with
            values in [0, 1] -- confirm against the model factory.
        dataloader: Iterable of dicts with keys ``"input_images"`` (stacked
            tensor) and ``"images"`` (list of PIL images), i.e. the shape
            produced by ``collate_fn``.

    Returns:
        list: One image per input, in dataloader order, with the predicted
        mask installed as its alpha channel. Empty list when the dataloader
        yields no batches.
    """
    torch.set_float32_matmul_precision('high')

    images_list = []
    masks_list = []
    for batch in dataloader:
        input_images = batch["input_images"].to(model.device).half() # needs to be full precision for rmbgv2
        images = batch["images"]

        masks = model(input_images)
        # Suppress low-confidence mask values.
        masks[masks < 0.1] = 0

        images_list.extend(images)
        masks_list.append(masks.detach().cpu())

    # torch.cat raises on an empty list, so an empty dataloader is handled
    # explicitly instead of crashing.
    if not masks_list:
        return []
    masks_list = torch.cat(masks_list, dim=0)

    to_pil = transforms.ToPILImage()  # loop-invariant, build once
    output_list = []
    for image, mask in zip(images_list, masks_list):
        # Bring the mask back to the resolution of the source image.
        mask = to_pil(mask)
        mask = mask.resize(image.size)

        # Project helpers; presumably foreground colour recovery followed by
        # colour-spill suppression -- see BiRefNet.ideogram_utils.
        recovered_rgba = recover_original_rgba(image, mask)
        image = reduce_spill(recovered_rgba, mask, r=90)

        image.putalpha(mask)

        output_list.append(image)

    return output_list

In [None]:
# Run configuration for the cells below.
# model_name is passed to get_model(); the bracketed list names the known options.
model_name = "custom" # ['birefnet', 'rmbgv2', 'custom']
# NOTE(review): benchmark is not referenced by any cell visible here -- confirm it is still needed.
benchmark = "base-benchmark" # ['green-benchmark', 'base-benchmark', 'ig-benchmark']
# Checkpoint handed to get_model(); absolute, machine-specific path.
path_to_weight = "/home/gianfavero/projects/BiRefNet/ckpts/green_1e-5_cosine_matting/step_43186.pth"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import time
from PIL import Image

# input dir
image_paths = ["test_samples/monopoly.png"]

# The model does not depend on the image, so load it once instead of once
# per loop iteration.
model = get_model(model_name, device=device, path_to_weight=path_to_weight)

for idx, path in enumerate(image_paths):
    # Load the upscaled image for testing ("upscaled.png"); the context
    # manager releases the file handle as soon as the pixels are converted.
    with Image.open(path) as source_image:
        image = source_image.convert("RGB")
    img_np = np.array(image)

    # Get tile size and the shape of the image
    tile_size = 1024 * 1
    overlap = 128  # Overlap in pixels (adjust as needed, e.g., 128, 256, etc.)
    stride = tile_size - overlap
    h, w = img_np.shape[0], img_np.shape[1]

    # Create tiles with overlap and track their positions.
    # A tile starting at `pos` is only needed when the previous tile (which
    # ends at pos + overlap) stops short of the image edge, i.e. when
    # pos < size - overlap. Stopping the range there avoids emitting thin
    # sliver tiles that are already fully covered and would be stretched to
    # the model's input size. max(..., 1) keeps one tile for tiny images.
    tiles = []
    tile_positions = []
    for y in range(0, max(h - overlap, 1), stride):
        for x in range(0, max(w - overlap, 1), stride):
            # Extract tile, handling boundaries
            y_end = min(y + tile_size, h)
            x_end = min(x + tile_size, w)
            tile = img_np[y:y_end, x:x_end]

            # (PIL tile, model-ready tensor) pair consumed by collate_fn
            tile = bg_removal_transform(tile)
            tiles.append(tile)
            tile_positions.append((x, y, x_end - x, y_end - y))  # (x, y, width, height)

    # Make a dataloader for the tiles
    dataloader = DataLoader(tiles, batch_size=1, shuffle=False, collate_fn=collate_fn)

    # perf_counter is a monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    output_list = evaluate(model, dataloader)
    end_time = time.perf_counter()
    print(f"Time taken: {end_time - start_time} seconds")

    # Function to create a feather mask for blending
    def create_feather_mask(width, height, feather_size):
        """Create a feather mask with soft edges for blending.

        The weight ramps linearly from ``1 / (feather_size + 1)`` on the
        outermost row/column up to 1.0 once a pixel is ``feather_size`` or
        more pixels away from every edge. Weights are strictly positive so
        that pixels covered by only one tile edge (the border of the full
        image) still receive a non-zero weight and are not zeroed out when
        the accumulated image is normalised.

        Args:
            width: Mask width in pixels.
            height: Mask height in pixels.
            feather_size: Width of the ramp in pixels; ``<= 0`` disables
                feathering and yields an all-ones mask.

        Returns:
            np.ndarray: float32 array of shape ``(height, width)``.
        """
        if feather_size <= 0:
            return np.ones((height, width), dtype=np.float32)

        def edge_ramp(length):
            # Distance of each index to the nearest end, turned into a weight.
            positions = np.arange(length)
            distance = np.minimum(positions, length - 1 - positions)
            ramp = np.minimum(1.0, (distance + 1) / (feather_size + 1))
            return ramp.astype(np.float32)

        # A pixel's weight is limited by whichever edge is closest.
        return np.minimum.outer(edge_ramp(height), edge_ramp(width))

    # Stitch the output tiles back together with blending.
    # Each tile is accumulated with per-pixel feather weights; dividing by
    # the summed weights afterwards gives a weighted average in the overlaps.
    output_image = np.zeros((h, w, 4), dtype=np.float32)
    weight_map = np.zeros((h, w), dtype=np.float32)

    feather_size = overlap // 2  # Feather half of the overlap region

    for (x, y, tile_w, tile_h), output in zip(tile_positions, output_list):
        # Convert PIL image to numpy array (tile_h, tile_w, RGBA)
        tile_array = np.array(output).astype(np.float32)

        # Create feather mask for this tile
        feather_mask = create_feather_mask(tile_w, tile_h, feather_size)

        # Weight all four channels at once by broadcasting the 2-D mask.
        output_image[y:y+tile_h, x:x+tile_w] += tile_array * feather_mask[:, :, None]

        # Accumulate weights
        weight_map[y:y+tile_h, x:x+tile_w] += feather_mask

    # Normalize by the weight map to get the final blended result; pixels
    # that received no weight stay at zero instead of dividing by zero.
    weights = weight_map[:, :, None]
    output_image = np.divide(
        output_image,
        weights,
        out=np.zeros_like(output_image),
        where=weights != 0
    )

    # Round to nearest before the uint8 cast: a plain astype truncates, so
    # float round-off such as 254.9999 would otherwise become 254.
    output_image = np.clip(np.rint(output_image), 0, 255).astype(np.uint8)

    # Convert back to PIL Image
    output_image = Image.fromarray(output_image, mode='RGBA')

    output_image.save(f"sample_{idx}.png")

    print(f"Processed {len(output_list)} overlapping tiles with {overlap}px overlap")