## BRIA Background Removal v1.4 Model
#### https://huggingface.co/briaai/RMBG-1.4

In [2]:
pip install -qr https://huggingface.co/briaai/RMBG-1.4/resolve/main/requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [3]:
pip install transformers

Note: you may need to restart the kernel to use updated packages.


In [12]:
import os
import json
from transformers import pipeline
from tqdm import tqdm
import time

In [5]:
def get_image_mask(image_path):
    pipe = pipeline("image-segmentation", model="briaai/RMBG-1.4", trust_remote_code=True)
    pillow_mask = pipe(image_path, return_mask = True) # outputs a pillow mask
    # pillow_image = pipe(image_path) # applies mask on input and returns a pillow image
    return pillow_mask

In [21]:
def save_image(objectId, path, pillow_mask):
    filename_mask = path + str(objectId) + "mask.png"

    if not os.path.exists(filename_mask):
        pillow_mask.save(filename_mask)
    else:
        print(f"Skipped: {filename_mask} already exists.")

In [22]:
def save_BRIA_image(obj, path):
    image_path = obj["imageUrl"]
    objectId = obj["objectId"]

    pillow_mask = get_image_mask(image_path)
    save_image(objectId, path, pillow_mask)

In [23]:
def run_batched_BRIA_processing(data, save_path, batch_size=10, delay=0):
    """
    Iterates through an array of objects and calls save_BRIA_image in batches.

    Args:
        data (list): List of objects with keys "objectId" and "imageUrl"
        save_path (str): Where to save output images
        batch_size (int): How many to process per batch
        delay (float): Optional delay between batches (in seconds)
    """
    total = len(data)

    for i in tqdm(range(0, total, batch_size), desc="Processing batches"):
        batch = data[i:i + batch_size]

        for obj in batch:
            try:
                save_BRIA_image(obj, save_path)
            except Exception as e:
                print(f"Error processing object {obj.get('objectId')}: {e}")

        if delay:
            time.sleep(delay)

#### TEST

In [8]:
with open("../objects_db_test.json", "r") as f:
    test_data = json.load(f)

In [17]:
# print(len(test_data['objects']))
test_objs = test_data['objects']

In [15]:
test_output_path = "../public/data/initial_test/BRIA/"

In [24]:
save_BRIA_image(test_objs[0], test_output_path)

Device set to use mps:0


In [25]:
run_batched_BRIA_processing(test_objs, test_output_path)

Processing batches:   0%|          | 0/28 [00:00<?, ?it/s]Device set to use mps:0


Skipped: ../public/data/initial_test/BRIA/14318mask.png already exists.


Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Processing batches:   4%|▎         | 1/28 [00:25<11:24, 25.35s/it]Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Processing batches:   7%|▋         | 2/28 [00:52<11:28, 26.48s/it]Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Processing batches:  11%|█         | 3/28 [01:15<10:15, 24.64s/it]Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set to use mps:0
Device set