In [1]:
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

from models.birefnet import BiRefNet
from utils import check_state_dict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
birefnet = BiRefNet(bb_pretrained=False)
state_dict = torch.load("weights/BiRefNet-general-epoch_244.pth", map_location='cpu')
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)

torch.set_float32_matmul_precision(['high', 'highest'][0])
birefnet.to('cuda')
birefnet.eval()
birefnet.half()

  state_dict = torch.load("weights/BiRefNet-general-epoch_244.pth", map_location='cpu')


BiRefNet(
  (bb): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0): SwinTransformerBlock(
            (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=192, out_features=576, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=192, out_features=192, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=192, out_features=768, bias=T

In [3]:
def extract_object(birefnet, imagepath):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    image = Image.open(imagepath)
    input_images = transform_image(image).unsqueeze(0).to('cuda').half()

    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)
    return image, mask

In [4]:
image, mask = extract_object(birefnet, "C:/Users/ZODNGUY1/datasets/zeiss/brain-bg/input/143542_0000.png")

In [13]:
from pathlib import Path
import os

img_dir = Path("C:/Users/ZODNGUY1/datasets/zeiss/brain-bg/input")
out_dir = Path("C:/Users/ZODNGUY1/datasets/zeiss/brain-transparent/input")
os.makedirs(out_dir, exist_ok=True)

In [14]:
from tqdm import tqdm

for img_path in tqdm(img_dir.iterdir(), desc="Processing images"):
    image, mask = extract_object(birefnet, img_path)
    output_path = out_dir / img_path.name
    image.save(output_path)

Processing images: 446it [06:35,  1.13it/s]
