# In [3]:
import torch
from torch import nn
import torchvision.transforms as transforms
import numpy as np
import os
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.models.segmentation as segmentation
from torchvision.models.segmentation import deeplabv3_resnet101
from scipy.ndimage import grey_dilation, grey_erosion, binary_fill_holes
import segmentation_models_pytorch as smp

def load_model():
    """Build a single-class U-Net with a ResNet-34 encoder, in eval mode.

    Only ``encoder_weights='imagenet'`` is requested here; this function
    does not load any checkpoint for the rest of the network.
    NOTE(review): presumably the decoder and segmentation head therefore
    start untrained -- confirm against the segmentation_models_pytorch docs
    before relying on the predicted masks.

    Returns:
        The ``smp.Unet`` instance (3 input channels, 1 output class).
    """
    unet = smp.Unet(
        'resnet34',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
    )
    # eval() puts the module in inference mode and returns the module itself.
    return unet.eval()

def make_transparent_foreground(pic, mask):
    """Return an RGBA cut-out of ``pic`` that is visible only where ``mask`` is set.

    Args:
        pic: A Pillow image or an array-like of pixels. Greyscale (H, W),
            3-channel (H, W, 3) and 4-channel (H, W, 4) inputs are accepted;
            an existing fourth channel is discarded.
        mask: Array-like of shape (H, W). Any non-zero entry marks a pixel
            that is kept.

    Returns:
        A ``uint8`` array of shape (H, W, 4). Kept pixels carry the original
        colour channels (in their original order) with alpha 255; all other
        pixels are zero in every channel, i.e. fully transparent.

    Raises:
        ValueError: If ``pic`` has an unsupported shape or ``mask`` does not
            match its height and width.
    """
    # Pillow images in palette / greyscale / RGBA mode do not convert to a
    # three-channel array directly, so normalise them first. Plain arrays
    # have no ``convert`` attribute and skip this step.
    to_mode = getattr(pic, 'convert', None)
    if callable(to_mode):
        pic = to_mode('RGB')

    pixels = np.asarray(pic).astype('uint8')
    if pixels.ndim == 2:
        # Greyscale: replicate the single channel into three colour channels.
        pixels = np.stack([pixels, pixels, pixels], axis=2)
    elif pixels.ndim != 3 or pixels.shape[2] not in (3, 4):
        raise ValueError(
            'pic must have shape (H, W), (H, W, 3) or (H, W, 4), got %r'
            % (pixels.shape,)
        )
    colour = pixels[:, :, :3]

    keep = np.asarray(mask).astype(bool)
    if keep.shape != colour.shape[:2]:
        raise ValueError(
            'mask shape %r does not match image size %r'
            % (keep.shape, colour.shape[:2])
        )

    # Start from a fully transparent canvas and copy the kept pixels onto it.
    foreground = np.zeros(colour.shape[:2] + (4,), dtype=np.uint8)
    foreground[keep, :3] = colour[keep]
    foreground[keep, 3] = 255

    return foreground

def _pad_to_multiple(batch, multiple=32):
    """Zero-pad the bottom/right of an NCHW batch so H and W divide ``multiple``."""
    height, width = batch.shape[-2:]
    pad_h = -height % multiple
    pad_w = -width % multiple
    if pad_h or pad_w:
        # F.pad takes pairs starting from the last dimension: (left, right, top, bottom).
        batch = nn.functional.pad(batch, (0, pad_w, 0, pad_h))
    return batch


def _foreground_from_output(output):
    """Reduce a raw model output for a batch of one to a boolean (H, W) tensor."""
    # torchvision segmentation models return a dict with the logits under
    # 'out'; other models (e.g. the smp U-Net) return the logits tensor itself.
    if isinstance(output, dict):
        output = output['out']
    logits = output[0]
    if logits.shape[0] == 1:
        # Single-channel head: argmax over one channel is always 0, so
        # threshold instead. logit > 0 is the same as sigmoid(logit) > 0.5.
        return logits[0] > 0
    # Multi-class head: every class other than index 0 counts as foreground.
    return logits.argmax(0) != 0


def remove_background(model, input_file):
    """Segment the image at ``input_file`` and cut out its foreground.

    Args:
        model: A segmentation network called on a normalised (1, 3, H, W)
            batch. It may return either a logits tensor of shape
            (1, C, H, W) or a dict holding that tensor under ``'out'``.
            The model is moved to the GPU when CUDA is available.
        input_file: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A tuple ``(foreground, bin_mask)``: the RGBA ``uint8`` array produced
        by ``make_transparent_foreground`` and the ``uint8`` mask of shape
        (H, W) holding 255 for foreground pixels and 0 elsewhere.
    """
    # Force three channels: palette, greyscale and RGBA files would otherwise
    # not match the 3-channel Normalize below.
    input_image = Image.open(input_file).convert('RGB')
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # create a mini-batch as expected by the model
    input_batch = preprocess(input_image).unsqueeze(0)
    height, width = input_batch.shape[-2:]

    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    with torch.no_grad():
        # Encoder-decoder networks with five downsampling stages need H and W
        # divisible by 32; pad for the forward pass and crop the result back.
        output = model(_pad_to_multiple(input_batch))
    keep = _foreground_from_output(output)[:height, :width]

    # create a binary (black and white) mask of the profile foreground
    bin_mask = np.where(keep.cpu().numpy(), 255, 0).astype(np.uint8)

    foreground = make_transparent_foreground(input_image, bin_mask)

    return foreground, bin_mask

def process_image(input_file, model, device='cuda'):
    """Run background removal on one file and return the result as a Pillow image.

    Args:
        input_file: Path of the image; its lower-cased suffix decides whether
            the result is flattened to RGB.
        model: Segmentation model handed to ``remove_background``.
        device: Accepted for interface compatibility; not read in this function.

    Returns:
        A ``PIL.Image.Image`` built from the foreground array. It is converted
        to ``'RGB'`` when the path ends with one of the listed suffixes and
        left as created otherwise.
    """
    # The binary mask returned alongside the foreground is not needed here.
    rgba_pixels, _mask = remove_background(model, input_file)
    result = Image.fromarray(rgba_pixels)

    flatten_suffixes = ('jpg', 'jpeg', 'png', 'bmp', 'gif')
    if not input_file.lower().endswith(flatten_suffixes):
        return result
    return result.convert('RGB')

def process_images_in_folder(input_folder, output_folder, model, device='cuda'):
    """Process every image file in ``input_folder`` and save it under the same name.

    Args:
        input_folder: Directory whose entries are scanned (not recursively).
        output_folder: Directory that receives the results; created if missing.
        model: Segmentation model forwarded to ``process_image``.
        device: Forwarded unchanged to ``process_image``.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

    # Make sure the destination exists before anything is written to it.
    os.makedirs(output_folder, exist_ok=True)

    for entry in os.listdir(input_folder):
        # Entries without a recognised image extension are skipped.
        if not entry.lower().endswith(image_suffixes):
            continue

        source_path = os.path.join(input_folder, entry)
        output_image_path = os.path.join(output_folder, entry)

        # The output keeps the input's file name, so Pillow picks the format
        # from the same extension.
        process_image(source_path, model, device).save(output_image_path)

        print("Processed image saved at:", output_image_path)

# Script entry point: build the model once, then batch-process a folder.
if __name__ == "__main__":
    # Check if GPU is available and set device accordingly
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the model returned by load_model(). Despite the variable name,
    # this is whatever load_model() builds, not necessarily a DeepLab network.
    deeplab_model = load_model()

    # Set input and output folder paths (relative to the current working directory)
    input_folder = "input2"
    output_folder = "output"

    # Process all images in the input folder and save the processed images to the output folder
    process_images_in_folder(input_folder, output_folder, deeplab_model, device)

# Output of the cell above:
# Processed image saved at: output\1.jpg
# Processed image saved at: output\10.jpg
# Processed image saved at: output\11.jpg
# Processed image saved at: output\2.jpg
# Processed image saved at: output\3.jpg
# Processed image saved at: output\4.png
# Processed image saved at: output\5.jpg
# Processed image saved at: output\6.jpg
# Processed image saved at: output\7.jpg
# Processed image saved at: output\8.jpg
# Processed image saved at: output\9.jpg
