In [1]:
import polars as pl
import numpy as np
import torch
from torchvision.transforms import v2
from PIL import Image
import io
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor
import json

from pathlib import Path
from rembg import remove, new_session

# Root folder for all dataset files: the "data" directory next to the
# directory the notebook is launched from.
DATA_PATH = Path.cwd().parent.joinpath("data")

In [2]:
# Build the rembg session once so later cells can reuse it.
session = new_session(model_name="sam")

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
if has_cuda:
    # Show which accelerator (index 0) will be used.
    print(torch.cuda.get_device_name(0))

In [None]:
# transforms = v2.Compose([
#     v2.ToImage(),  # Convert to tensor, only needed if you had a PIL image
#     v2.ToDtype(torch.uint8, scale=True),  # optional, most input are already uint8 at this point
#     v2.Resize(size=(224, 224), antialias=True),  # Or Resize(antialias=True)
#     v2.ToDtype(torch.float32, scale=True),  # Normalize expects float input
#     # v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

In [None]:
# Lazily scan the validation parquet file; rows are only read on .collect().
data = pl.scan_parquet(DATA_PATH / "food101-validation.parquet")

# Print the total number of rows in the file.
n_validation_rows = data.select(pl.len()).collect().item()
print(n_validation_rows)

# Materialise the first 10 rows as a small sample for the inspection cells below.
test_data = data.head(10).collect()
test_data

In [None]:
def process_base(row):
    """Decode one row's image and save it under the base_jpg output folder.

    Args:
        row: A 3-tuple ``(image, filename, label)``. ``image`` must be the
            encoded image as a bytes-like object and ``filename`` the name
            the image is saved under.
            NOTE(review): presumably these are the unnested ``bytes`` and
            ``path`` fields of the dataset's image struct -- confirm against
            the caller.

    Returns:
        tuple: ``(filename, label)``, so callers can build a
        filename-to-label mapping from the results.
    """
    image, filename, label = row
    out_path = Path(f"{DATA_PATH}/base_jpg/food101-validation/{filename}")
    # Image.save does not create missing directories. exist_ok=True keeps this
    # safe when several worker threads reach this line at the same time.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Context manager releases the decoder's resources once the file is written.
    with Image.open(io.BytesIO(image)) as decoded:
        decoded.save(out_path)
    return filename, label

In [None]:
# Keep only the image struct and the label, expand the struct's fields into
# top-level columns, and materialise the whole result in memory.
full_data = (
    data
    .select(pl.col("image"), pl.col("label"))
    .unnest("image")
    .collect()
)

In [None]:
# Save every image on a thread pool and collect the (filename, label) pairs
# returned by process_base into a filename -> label dictionary.
with ThreadPoolExecutor() as pool:
    results = {
        saved_name: saved_label
        for saved_name, saved_label in pool.map(process_base, full_data.iter_rows())
    }
results

In [None]:
# Persist the filename -> label mapping as JSON next to the dataset files.
mappings_path = DATA_PATH / "food101-validation-mappings.json"
with mappings_path.open("w") as mappings_file:
    json.dump(results, mappings_file)

In [None]:
# Peek at the "path" field of the first sampled row's image struct.
test_data["image"][0]["path"]

In [None]:
# Take the row at index 9 of the sample and decode its image bytes with PIL.
test = test_data[9]
encoded_bytes = test["image"][0]["bytes"]
image = Image.open(io.BytesIO(encoded_bytes))
image

In [None]:
# Image.save does not create missing directories, so make sure the scratch
# folder exists before writing the sample image (named after its label).
Path("food_images").mkdir(parents=True, exist_ok=True)
image.save(f"food_images/{test['label'][0]}.jpg")

In [None]:
# Reload the saved sample as RGBA so the alpha channel can be used to count
# visible (non-transparent) pixels before and after background removal.
with open(f"food_images/{test['label'][0]}.jpg", "rb") as f:
    image = Image.open(f).convert("RGBA")
orig_img = np.count_nonzero(np.array(image)[:, :, 3] > 0)

# rembg's keyword for a preloaded session is `session`. The previous `model=`
# keyword is not a parameter of remove(), so the session created earlier in
# the notebook was not being used.
# NOTE(review): some session types may need extra prompt arguments -- confirm
# against the rembg docs for the model chosen in new_session().
image = remove(image, session=session, bgcolor=(0, 0, 0, 0), post_process_mask=True)
# image.save(f"food_images/{test['label'][0]}_no_bg.jpg")
post_img = np.count_nonzero(np.array(image)[:, :, 3] > 0)

# Share of originally visible pixels that became fully transparent.
print(f"Percent removed: {100 - (post_img / orig_img * 100)}")

image

In [None]:
# Show the first 100 raw pixel bytes of the processed image.
image.tobytes()[:100]

In [None]:
%% true
# Performant way to remove background from images

from pathlib import Path
from rembg import remove, new_session

session = new_session()

for file in Path('path/to/folder').glob('*.png'):
    input_path = str(file)
    output_path = str(file.parent / (file.stem + ".out.png"))

    with open(input_path, 'rb') as i:
        with open(output_path, 'wb') as o:
            input = i.read()
            output = remove(input, session=session)
            o.write(output)