In [1]:
import torch
from torch import nn
import torchvision.transforms as transforms
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.models.segmentation as segmentation
from torchvision.models.segmentation import deeplabv3_resnet101
from scipy.ndimage import binary_dilation, binary_erosion, binary_fill_holes

def load_model(device='cuda'):
    """Load a pretrained DeepLabV3-ResNet101 segmentation model for inference.

    Args:
        device: Torch device (e.g. ``'cuda'`` or ``'cpu'``) the model is moved to.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.
    """
    # Name the checkpoint explicitly instead of passing the deprecated boolean
    # `weights=True`. This is the same checkpoint the boolean form resolves to,
    # and it pins the Pascal VOC label set, so hard-coded class indices used by
    # callers keep their meaning.
    model = deeplabv3_resnet101(weights="COCO_WITH_VOC_LABELS_V1")
    model.to(device)
    model.eval()
    return model

def process_image(image_path, model, device='cuda', target_size=768, object_class=15):
    """Segment one image and return the selected object on a black background.

    The image is resized to ``target_size`` x ``target_size``, normalized and
    run through ``model``. The per-pixel class map is scaled back to the
    original resolution, the pixels labelled ``object_class`` are kept, and the
    top rows of that mask are tightened with a 5x5 erosion before compositing.

    Args:
        image_path: Path of the image file to process.
        model: Segmentation model whose forward pass returns a mapping with an
            ``'out'`` entry of per-class scores (as torchvision's DeepLabV3 does).
        device: Torch device the input tensor is moved to; must match the model.
        target_size: Side length, in pixels, of the square network input.
        object_class: Class index to keep. The default, 15, is ``person`` in
            the Pascal VOC label set used by torchvision's pretrained DeepLabV3.

    Returns:
        A new RGB ``PIL.Image`` with the original image size in which every
        pixel outside the object mask is black.
    """
    # Convert inside the context manager so the pixel data is fully loaded and
    # the underlying file handle is released right away.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((target_size, target_size)),
        transforms.ToTensor(),
        # torchvision's pretrained weights expect ImageNet-normalized input;
        # feeding raw [0, 1] values degrades the segmentation.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(img).unsqueeze(0).to(device)

    # Per-pixel class index at network resolution
    with torch.no_grad():
        scores = model(input_tensor)['out'][0]
    class_map = scores.argmax(0).cpu().numpy().astype('uint8')

    # Scale the label map back to the original resolution in a single step.
    # PIL sizes are (width, height), which is what Image.resize expects.
    # Nearest-neighbour keeps the values valid class indices; interpolating
    # filters would blend neighbouring labels into unrelated class numbers.
    class_map_full = Image.fromarray(class_map).resize(img.size, resample=Image.NEAREST)
    object_mask = np.array(class_map_full) == object_class

    # Tighten the mask near the top edge of the image: erode the first
    # `context_rows` rows and use the result for the first `refined_rows` rows.
    # The extra context rows keep the lower boundary of the replaced band from
    # being eroded by the slice edge.
    context_rows = 100
    refined_rows = 50
    eroded_top = binary_erosion(object_mask[:context_rows], structure=np.ones((5, 5)))
    refined_mask = object_mask.copy()
    refined_mask[:refined_rows] = eroded_top[:refined_rows]

    # Composite the masked pixels onto black. An 8-bit 0/255 mask is used
    # because it is handled uniformly by Image.paste.
    paste_mask = Image.fromarray(refined_mask.astype(np.uint8) * 255)
    new_image = Image.new("RGB", img.size, color=(0, 0, 0))
    new_image.paste(img, (0, 0), mask=paste_mask)

    return new_image

def process_images_in_folder(input_folder, output_folder, model, device='cuda'):
    """Run ``process_image`` on every image file in a folder and save the results.

    Args:
        input_folder: Directory whose entries are scanned for image files.
        output_folder: Directory the results are written to; created if missing.
        model: Segmentation model forwarded to ``process_image``.
        device: Torch device forwarded to ``process_image``.

    Each result is saved under the same file name as its source, and one line
    is printed per saved file.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

    # Make sure the destination exists before anything is written
    os.makedirs(output_folder, exist_ok=True)

    for entry in os.listdir(input_folder):
        # Skip anything that is not a recognised image type (case-insensitive)
        if not entry.lower().endswith(image_suffixes):
            continue

        result = process_image(os.path.join(input_folder, entry), model, device)

        # Keep the source file name for the processed copy
        destination = os.path.join(output_folder, entry)
        result.save(destination)

        print("Processed image saved at:", destination)

if __name__ == "__main__":
    # Run on the GPU when CUDA is usable, otherwise fall back to the CPU
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Pretrained segmentation model, already placed on the chosen device
    model = load_model(device)

    # Source folder to read from and destination folder to write to
    input_folder = "input2"
    output_folder = "output2"

    # Segment every image in the source folder and store the results
    process_images_in_folder(
        input_folder=input_folder,
        output_folder=output_folder,
        model=model,
        device=device,
    )





Processed image saved at: output2\image_2023-07-19_17-00-48.png
Processed image saved at: output2\photo1689764388.jpeg
