In [5]:
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

from models.birefnet import BiRefNet
from utils import check_state_dict

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Load the BiRefNet checkpoint onto the GPU in half precision for inference.
WEIGHTS_PATH = "weights/BiRefNet-general-epoch_244.pth"

birefnet = BiRefNet(bb_pretrained=False)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs, so a tampered checkpoint cannot run arbitrary code.
# It should also silence the torch.load warning echoed in this cell's output
# (presumably the weights_only FutureWarning of recent torch versions).
state_dict = torch.load(WEIGHTS_PATH, map_location='cpu', weights_only=True)
state_dict = check_state_dict(state_dict)
birefnet.load_state_dict(state_dict)

# 'high' lets float32 matmuls use faster reduced-precision internals (e.g. TF32);
# 'highest' would keep full float32 precision.
torch.set_float32_matmul_precision('high')
birefnet.to('cuda')
birefnet.eval()
# Half precision: inference inputs must be cast to float16 on cuda as well.
birefnet.half()
print("Model loaded")

  state_dict = torch.load("weights/BiRefNet-general-epoch_244.pth", map_location='cpu')


Model loaded


In [7]:
def extract_object(birefnet, imagepath, image_size=(1024, 1024)):
    """Segment the foreground of an image with BiRefNet and attach it as alpha.

    Args:
        birefnet: Loaded BiRefNet model (eval mode). The input tensor is moved
            to the device and dtype of the model's parameters, which is
            cuda / float16 in this notebook.
        imagepath: Path (str or ``pathlib.Path``) of the image to segment.
        image_size: (height, width) the image is resized to before inference.

    Returns:
        (image, mask): ``image`` is the source converted to RGB with ``mask``
        applied as its alpha channel (mode ``RGBA``). ``mask`` is the
        model's sigmoid prediction as a single-channel PIL image, resized back
        to the source resolution.
    """
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        # ImageNet mean / std normalization (3 channels).
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Force RGB: grayscale, palette or RGBA files would otherwise reach the
    # 3-channel Normalize with the wrong channel count. convert() returns an
    # independent copy, so the file handle can be closed right away.
    with Image.open(imagepath) as source:
        image = source.convert("RGB")

    # Follow the model's device/dtype instead of hard-coding cuda + half.
    param = next(birefnet.parameters())
    input_images = transform_image(image).unsqueeze(0).to(device=param.device, dtype=param.dtype)

    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid()
    # Last output presumably has shape (batch, 1, H, W) -- TODO confirm; squeeze
    # leaves a 2-D map. Quantize to 8 bits in float32 rather than float16.
    pred = preds[0].squeeze().float().cpu()
    mask = transforms.ToPILImage()(pred).resize(image.size)
    image.putalpha(mask)
    return image, mask

In [8]:
from pathlib import Path
import os

# Machine-specific dataset root; input images with background and the export target.
DATA_ROOT = Path("C:/Users/ZODNGUY1/datasets/zeiss")
img_dir = DATA_ROOT / "heart-bg" / "images"
out_dir = DATA_ROOT / "heart" / "images"
out_dir.mkdir(parents=True, exist_ok=True)

#### Export PNG

In [9]:
from tqdm import tqdm

# Only regular image files, in a stable order: iterdir() also yields
# sub-directories and stray files (e.g. Thumbs.db) that Image.open cannot read.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
image_paths = sorted(
    p for p in img_dir.iterdir()
    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)

for img_path in tqdm(image_paths, desc="Processing images"):
    image, _ = extract_object(birefnet, img_path)
    # The cutout has an alpha channel, which JPEG cannot store: always write PNG.
    output_path = out_dir / f"{img_path.stem}.png"
    image.save(output_path)

Processing images: 565it [05:50,  1.61it/s]


#### Export masks

In [10]:
from tqdm import tqdm

# Only regular image files, in a stable order; iterdir() also yields
# sub-directories and non-image files.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
image_paths = sorted(
    p for p in img_dir.iterdir()
    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)

# NOTE(review): masks go to the same out_dir that the RGBA export cell writes
# to; point out_dir at a dedicated mask directory first if they must not mix.
for img_path in tqdm(image_paths, desc="Processing images"):
    _, mask = extract_object(birefnet, img_path)
    # Drop the prefix up to the first underscore ("<prefix>_001.png" -> "001.png").
    # partition() keeps the remainder (and its extension) intact even when the
    # name has further underscores, and falls back to the full name when there
    # is no underscore instead of raising IndexError.
    _, _, rest = img_path.name.partition("_")
    output_name = rest or img_path.name
    # Masks must be stored losslessly, so always save as PNG.
    output_path = out_dir / Path(output_name).with_suffix(".png")
    mask.save(output_path)

Processing images: 446it [02:48,  2.64it/s]


#### Fix file names: strip zero padding (`007.png` → `7.png`)

In [2]:
from pathlib import Path
import os

# Machine-specific dataset root; zero-padded originals are read from images_old
# and copied under unpadded names into images.
DATA_ROOT = Path("C:/Users/ZODNGUY1/datasets/zeiss")
old_dir = DATA_ROOT / "heart-transparent" / "images_old"
new_dir = DATA_ROOT / "heart-transparent" / "images"
new_dir.mkdir(parents=True, exist_ok=True)

In [4]:
from shutil import copy2
from tqdm import tqdm

# Copy three-digit, zero-padded files ("007.png") to new_dir under their
# unpadded name ("7.png"). isdecimal() rejects characters such as "²" that
# isdigit() accepts but int() cannot parse. Files that do not match are not
# copied; they are collected and reported instead of being dropped silently.
skipped = []
for img_path in tqdm(sorted(old_dir.glob("*.png"))):
    stem = img_path.stem
    if len(stem) == 3 and stem.isdecimal():
        copy2(img_path, new_dir / f"{int(stem)}.png")
    else:
        skipped.append(img_path.name)

if skipped:
    print(f"Skipped {len(skipped)} file(s) without a 3-digit name: {skipped[:10]}")

565it [00:01, 523.96it/s]
