In [7]:
!pip install torch torchvision torchaudio



Collecting torch
  Downloading torch-2.5.0-cp312-cp312-win_amd64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.20.0-cp312-cp312-win_amd64.whl.metadata (6.2 kB)
Collecting torchaudio
  Downloading torchaudio-2.5.0-cp312-cp312-win_amd64.whl.metadata (6.5 kB)
Collecting filelock (from torch)
  Downloading filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.10.0-py3-none-any.whl.metadata (11 kB)
Collecting sympy==1.13.1 (from torch)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.5.0-cp312-cp312-win_amd64.whl (203.1 MB)
   ---------------------------------------- 0.0/203.1

In [10]:
import torch
import cv2
import numpy as np
import os
from torchvision.models.segmentation import (
    deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large,
    DeepLabV3_ResNet50_Weights, DeepLabV3_ResNet101_Weights, DeepLabV3_MobileNet_V3_Large_Weights
)
from PIL import Image

def load_model(model_name: str):
    """Build a pretrained DeepLabV3 segmentation model and its matching preprocessing.

    Args:
        model_name: backbone to use, one of 'mobilenet', 'resnet_50' or
            'resnet_101' (case-insensitive).

    Returns:
        (model, transforms): the model switched to eval mode, and the inference
        transforms belonging to the exact weights that were loaded.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # Normalise once so validation and dispatch agree: e.g. "ResNet_50" must
    # select ResNet-50 rather than pass validation and fall through to MobileNet.
    name = model_name.lower()
    if name not in ("mobilenet", "resnet_50", "resnet_101"):
        raise ValueError("'model_name' should be one of ('mobilenet', 'resnet_50', 'resnet_101')")

    if name == "resnet_50":
        builder, weights = deeplabv3_resnet50, DeepLabV3_ResNet50_Weights.DEFAULT
    elif name == "resnet_101":
        builder, weights = deeplabv3_resnet101, DeepLabV3_ResNet101_Weights.DEFAULT
    else:
        builder, weights = deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT

    # Derive the transforms from the same weights object that was loaded, so the
    # preprocessing can never drift from the checkpoint.
    model = builder(weights=weights)
    transforms = weights.transforms()

    model.eval()
    # One dummy forward pass, presumably a warm-up / sanity check. Run it without
    # autograd so no computation graph is built for it.
    with torch.no_grad():
        model(torch.randn(1, 3, 520, 520))
    return model, transforms

def create_main_person_mask(output, person_class_id: int = 15):
    """Build a 0/1 mask covering only the most prominent "person" region of a prediction.

    Args:
        output: segmentation logits (the model's "out" entry). The class axis is
            dim 1 and dim 0 is squeezed away, so a batch of one, i.e.
            (1, num_classes, H, W), is assumed — TODO confirm for other callers.
        person_class_id: class index treated as "person". The default 15 is the
            "person" index in the Pascal VOC label order used by the
            COCO_WITH_VOC_LABELS DeepLabV3 weights.

    Returns:
        np.ndarray of dtype uint8 and shape (H, W). It is 1 inside the outer
        contour of the largest person region and 0 elsewhere. Holes enclosed by
        that outline are filled in. The mask is all zeros when no pixel is
        predicted as ``person_class_id``.
    """
    output_predictions = output.argmax(1).squeeze(0).cpu().numpy()
    person_mask = (output_predictions == person_class_id).astype(np.uint8)

    # Outer contours only. `[-2]` selects the contour list under both the
    # OpenCV 3 (image, contours, hierarchy) and OpenCV 4 (contours, hierarchy)
    # return signatures.
    contours = cv2.findContours(person_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        # No person pixels at all: the class mask is already all zeros.
        return person_mask

    # The outline enclosing the largest area is taken as the main person; filling
    # it also fills any holes inside that outline.
    largest_contour = max(contours, key=cv2.contourArea)
    mask = np.zeros_like(person_mask, dtype=np.uint8)
    cv2.drawContours(mask, [largest_contour], -1, color=1, thickness=cv2.FILLED)
    return mask

def remove_background(original_image, binary_mask):
    """Black out every pixel of ``original_image`` that lies outside ``binary_mask``.

    Args:
        original_image: image accepted by ``np.asarray`` (e.g. a PIL RGB image),
            shaped (H, W, C) or (H, W).
        binary_mask: 0/1 mask of shape (h, w), (h, w, 1) or (h, w, C). It is
            resized to (H, W) with nearest-neighbour interpolation, so it may come
            from a lower-resolution prediction.

    Returns:
        np.ndarray equal to ``image * mask``: the original pixels where the mask is
        1 and 0 where it is 0. For uint8 image and mask the result stays uint8.
    """
    image_np = np.asarray(original_image)
    mask = np.asarray(binary_mask)
    height, width = image_np.shape[:2]

    # cv2.resize drops a trailing singleton channel, which would then fail to
    # broadcast against the image; squeeze it up front instead.
    if mask.ndim == 3 and mask.shape[2] == 1:
        mask = mask[:, :, 0]

    # cv2 takes the target size as (width, height); nearest keeps values 0/1.
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    # Broadcast a single-channel mask across every colour channel of the image.
    if mask.ndim == 2 and image_np.ndim == 3:
        mask = mask[:, :, np.newaxis]

    return image_np * mask

def perform_inference(model_name, image_dir, save_dir):
    """Cut the main person out of every readable image in ``image_dir``.

    Each result is written to ``save_dir`` under the input's file name. Pixels
    outside the main "person" mask (see create_main_person_mask) are black.

    Args:
        model_name: one of 'mobilenet', 'resnet_50', 'resnet_101' (see load_model).
        image_dir: directory of input images; hidden entries and sub-directories
            are skipped, and unreadable files are reported and skipped.
        save_dir: output directory, created if it does not exist.
    """
    model, transforms = load_model(model_name)
    os.makedirs(save_dir, exist_ok=True)

    # Sorted so every run processes (and logs) files in the same order.
    for img_file in sorted(os.listdir(image_dir)):
        img_path = os.path.join(image_dir, img_file)

        # Skip hidden files or directories
        if img_file.startswith('.') or os.path.isdir(img_path):
            continue

        # Try to open the image and skip if it fails. FileNotFoundError and PIL's
        # UnidentifiedImageError are both OSError subclasses. `with` closes the
        # underlying file handle once the RGB copy has been made.
        try:
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except OSError as e:
            print(f"Skipping file {img_file} due to error: {e}")
            continue

        input_image = transforms(image).unsqueeze(0)
        with torch.no_grad():
            output = model(input_image)["out"]

        # Create mask specifically for the main "person" in the image
        main_person_mask = create_main_person_mask(output)
        segmented_image = remove_background(image, main_person_mask)

        output_path = os.path.join(save_dir, img_file)
        # Save through PIL rather than cv2.imwrite. The array is in RGB order, and
        # cv2.imwrite assumes BGR, so it would swap red and blue. cv2.imwrite also
        # reports failures (e.g. non-ASCII paths on Windows) only through its
        # return value.
        try:
            Image.fromarray(segmented_image).save(output_path)
        except (OSError, ValueError) as e:
            print(f"Failed to save {output_path}: {e}")
            continue
        print(f"Saved segmented image to: {output_path}")

# Run configuration: which DeepLabV3 backbone to use, where the raw images are
# read from, and where the person cut-outs are written.
model_name = 'resnet_50'  # one of 'mobilenet', 'resnet_50', 'resnet_101'
ROOT_raw_image_directory = r"data2/5"
output_segmented_directory = r"data_seg/5"

# Segment every image in the input directory and save the results.
perform_inference(model_name, ROOT_raw_image_directory, output_segmented_directory)


Saved segmented image to: data_seg/5\Adele.jpg
Saved segmented image to: data_seg/5\Amber Riley.jpg
Saved segmented image to: data_seg/5\Amy Schumer.jpg
Saved segmented image to: data_seg/5\Britney Spears.jpg
Saved segmented image to: data_seg/5\Dawn French.jpg
Saved segmented image to: data_seg/5\Drew Barrymore.jpg
Saved segmented image to: data_seg/5\Gabourey Sidibe.jpg
Saved segmented image to: data_seg/5\Jennifer Hudson.jpg
Saved segmented image to: data_seg/5\Meghan Trainor.jpg
Skipping file Mindy Kaling.png due to error: cannot identify image file 'C:\\Users\\91820\\Desktop\\deep learnin\\ImageClassification\\data2\\5\\Mindy Kaling.png'
Saved segmented image to: data_seg/5\Nikki Blonsky.jpg
Saved segmented image to: data_seg/5\Octavia Spencer.jpg
Saved segmented image to: data_seg/5\Oprah Winfrey.jpg
Saved segmented image to: data_seg/5\Queen Latifah.jpg
Saved segmented image to: data_seg/5\Rebel Wilson.jpg
Skipping file Rosie O’donnell.png due to error: cannot identify image fil