In [1]:
import os
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from u2net import U2NET

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the U2NET network instance that the following cells load weights into
# and run for inference.
model = U2NET(3, 1)  # 3 input channels, 1 output channel
# Checkpoint location, given as a path relative to the working directory.
model_path = 'saved_models/u2net/u2net.pth'

In [3]:
# Read the checkpoint with all tensors mapped to the CPU and copy the weights
# into the network.
# NOTE(review): torch.load deserializes with pickle, so only load checkpoints
# from a trusted source; consider weights_only=True if the installed torch
# version supports it.
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# Switch the module to evaluation mode. Because this is the cell's last
# expression, the value returned by eval() is displayed as the cell output.
model.eval()

U2NET(
  (stage1): RSU7(
    (rebnconvin): REBNCONV(
      (conv_s1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (rebnconv1): REBNCONV(
      (conv_s1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv2): REBNCONV(
      (conv_s1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn_s1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu_s1): ReLU(inplace=True)
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (rebnconv3): REBNCONV(
      (conv_s1): Conv2d(32, 32, k

In [4]:
def preprocess_image(image_path):
    """Load an image file and turn it into a normalized model input batch.

    The image is converted to 3-channel RGB before the transforms run, so
    grayscale, palette, and RGBA files (common for ``.png``) yield the three
    channels that the 3-value ``Normalize`` mean/std below are written for.

    Args:
        image_path: Path of the image file to open with PIL.

    Returns:
        A tuple ``(input_tensor, original_size)``. ``input_tensor`` is the
        resized, normalized image with a leading batch dimension added by
        ``unsqueeze(0)``; ``original_size`` is PIL's ``size`` of the file as
        loaded, i.e. ``(width, height)``.
    """
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # The context manager releases the underlying file handle; convert()
    # returns an independent in-memory image, so it stays usable afterwards.
    with Image.open(image_path) as source:
        original_size = source.size
        image = source.convert('RGB')
    input_tensor = transform(image).unsqueeze(0)
    return input_tensor, original_size

def postprocess_mask(mask, original_size):
    """Convert a raw model output tensor into an 8-bit mask of a given size.

    Args:
        mask: Tensor holding the predicted mask; its singleton dimensions are
            squeezed away before resizing.
        original_size: Target size, passed unchanged to ``cv2.resize`` as the
            destination size.

    Returns:
        A ``numpy.uint8`` array: the bilinearly resized mask multiplied by 255.
    """
    # Detach from autograd and bring the data to host memory for NumPy.
    as_array = mask.detach().squeeze().cpu().numpy()
    resized = cv2.resize(as_array, original_size, interpolation=cv2.INTER_LINEAR)
    # astype truncates the fractional part. NOTE(review): a 0-255 result
    # assumes the mask values lie in [0, 1] -- confirm against the model.
    return np.multiply(resized, 255).astype(np.uint8)

def generate_alpha_mask(image_path):
    """Run the notebook-level ``model`` on one image and return its alpha mask.

    Args:
        image_path: Path of the image file to process.

    Returns:
        Whatever ``postprocess_mask`` produces for the first element of the
        model's output, using the size reported by ``preprocess_image``.
    """
    batch, size = preprocess_image(image_path)

    # Inference only: no gradients are needed for the forward pass.
    with torch.no_grad():
        first_output = model(batch)[0]

    return postprocess_mask(first_output, size)

def save_mask_overlay(image_path, alpha_mask, output_folder):
    """Save the image with ``alpha_mask`` drawn over it as a translucent layer.

    The figure is created, saved, and closed inside this function, so the
    result does not depend on (or disturb) any other open pyplot figure.

    Args:
        image_path: Path of the source image; its base name is reused for the
            saved overlay file.
        alpha_mask: 2-D array to draw on top of the image. It is stretched to
            the image's extent, so it is expected to cover the same scene.
        output_folder: Directory that receives the overlay; created if missing.

    Returns:
        None. The overlay is written to
        ``os.path.join(output_folder, os.path.basename(image_path))``.
    """
    os.makedirs(output_folder, exist_ok=True)
    output_path = os.path.join(output_folder, os.path.basename(image_path))

    # Force RGB so imshow gets a plain colour image for any input mode.
    with Image.open(image_path) as source:
        rgb = np.asarray(source.convert('RGB'))
    height, width = rgb.shape[:2]

    fig, ax = plt.subplots()
    try:
        ax.imshow(rgb)
        # Pin the mask to the image's pixel extent so the two layers line up
        # even if the mask resolution differs from the image resolution.
        ax.imshow(
            alpha_mask,
            cmap='jet',
            alpha=0.5,
            extent=(-0.5, width - 0.5, height - 0.5, -0.5),
        )
        ax.axis('off')
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure; open figures accumulate in a loop.
        plt.close(fig)



In [5]:
def process_folder(folder_path, output_folder):
    """Generate and save an alpha mask for every image file in a folder.

    Files are selected by extension (``.jpg``, ``.jpeg``, ``.png``, compared
    case-insensitively) and processed in sorted name order. Each mask is
    written to ``output_folder`` as ``alpha_mask_<original filename>``.

    Args:
        folder_path: Directory to scan for images (not recursive).
        output_folder: Directory that receives the masks; created if missing.

    Returns:
        None. One status line is printed per processed file.
    """
    os.makedirs(output_folder, exist_ok=True)

    image_extensions = ('.jpg', '.jpeg', '.png')
    # os.listdir gives no ordering guarantee; sort for reproducible runs.
    for filename in sorted(os.listdir(folder_path)):
        # Lower-case the name so files such as "IMG_01.JPG" are not skipped.
        if not filename.lower().endswith(image_extensions):
            continue

        image_path = os.path.join(folder_path, filename)
        alpha_mask = generate_alpha_mask(image_path)

        # Save the alpha mask
        output_path = os.path.join(output_folder, f"alpha_mask_{filename}")
        # imwrite signals a failed write through its return value, so check
        # it rather than reporting success unconditionally.
        if cv2.imwrite(output_path, alpha_mask):
            print(f"Processed and saved: {output_path}")
        else:
            print(f"Failed to save: {output_path}")

In [6]:
# Folder scanned for input images (relative to the working directory).
input_folder = 'image'
# Folder that receives the generated alpha masks.
output_folder = 'output'

# Generate and save an alpha mask for each image file found in input_folder.
process_folder(input_folder, output_folder)



Processed and saved: output\alpha_mask_i1.jpg
Processed and saved: output\alpha_mask_i2.jpg
Processed and saved: output\alpha_mask_i3.jpg
Processed and saved: output\alpha_mask_i4.jpg
Processed and saved: output\alpha_mask_i5.jpg
