<a href="https://colab.research.google.com/github/lucasmarques/colab-notebooks/blob/main/BiRefNet_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
# Install necessary libraries
!pip install -q transformers torch torchvision kornia

In [32]:
from transformers import AutoModelForImageSegmentation, AutoImageProcessor
import torch
from PIL import Image
import requests
import matplotlib.pyplot as plt

# Load model
model = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True)
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

BiRefNet(
  (bb): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=192, out_features=576, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=192, out_features=192, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=192, out_features=768, bias=T

In [33]:
# Helper preprocessing
def preprocess(image, fp16=True):
    from torchvision import transforms
    transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    t = transform(image).unsqueeze(0).to(device)
    return t.half() if fp16 else t


In [34]:
# Load example
img = Image.open(requests.get("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png", stream=True).raw).convert("RGB")
inp = preprocess(img, fp16=False)

In [35]:
# Inference
with torch.no_grad():
    outputs = model(inp)

logits = outputs[0]
mask = logits[0, 0].sigmoid().cpu()  # single-channel mask

In [36]:
import numpy as np
from torchvision.transforms.functional import to_pil_image
import io

# 1. Get original image size
original_size = img.size  # (width, height)

# Ensure mask is in float32 [0, 1]
mask_np = np.clip(mask.numpy(), 0, 1)

# Convert mask to PIL and resize to original image size
mask_img = Image.fromarray((mask_np * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)

# Resize mask to match original image size
mask_resized = mask_img.resize(original_size, Image.BILINEAR)

# Convert resized mask back to numpy
alpha_channel = np.array(mask_resized).astype(np.uint8)

# Convert to RGBA
image_rgba = img.convert("RGBA")
image_np = np.array(image_rgba)

# Make sure shapes match
if alpha_channel.shape != image_np[..., 0].shape:
    raise ValueError(f"Alpha channel shape {alpha_channel.shape} doesn't match image shape {image_np.shape[:2]}")

# Assign alpha channel
image_np[..., 3] = alpha_channel  # replace alpha with mask

# Convert back to image
final_image = Image.fromarray(image_np)

buf = io.BytesIO()
final_image.save(buf, format="PNG")
image_bytes = buf.getvalue()

In [37]:
from google.colab import files
# 1. Save bytes to a file
with open("output.png", "wb") as f:
    f.write(image_bytes)
# 2. Trigger download
files.download("output.png")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>