In [1]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import torch
from torchvision.models.segmentation import deeplabv3_resnet101
import matplotlib.pyplot as plt 
import time

# Step 1: Semantic Segmentation with DeepLab
def deeplab_segmentation(image):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(image)
    input_batch = input_tensor.unsqueeze(0)

    model = deeplabv3_resnet101(pretrained=True)
    model.eval()
    with torch.no_grad():
        output = model(input_batch)['out'][0]
        output_predictions = output.argmax(0)

    return output_predictions.numpy()

# Step 2: Combine Semantic Segmentation Masks with Mask Generation Code
def generate_binary_masks(image_path,exclude_classes):

    image = Image.open(image_path).convert('RGB')

    # Semantic Segmentation with DeepLab
    semantic_mask = deeplab_segmentation(image)
    # Save semantic segmentation mask
    semantic_mask_image = Image.fromarray(semantic_mask.astype(np.uint8))

# Step 3: Run the Combined Code
if __name__ == '__main__':
    image_path = '/home/moborobo/datasets/test_robot/00035-image-1681470490.1975896.png'
    exclude_classes = [0] # only exclude background

    num_iterations = 10
    total_time = 0
    for i in range(num_iterations):
        start_time = time.time()
        generate_binary_masks(image_path, exclude_classes)
        end_time = time.time()
        iteration_time = end_time - start_time
        print(f"Iteration {i+1}: {iteration_time} seconds")
        total_time += iteration_time

    average_time = total_time / num_iterations
    print(f"Average running time over {num_iterations} iterations: {average_time} seconds")


Iteration 1: 14.758919477462769 seconds
Iteration 2: 14.310077905654907 seconds
Iteration 3: 14.277084589004517 seconds
Iteration 4: 13.94562840461731 seconds
Iteration 5: 14.155790567398071 seconds
Iteration 6: 14.114462852478027 seconds
Iteration 7: 14.141812086105347 seconds
Iteration 8: 14.104209661483765 seconds
Iteration 9: 14.05760145187378 seconds
Iteration 10: 14.088724136352539 seconds
Average running time over 10 iterations: 14.195431113243103 seconds
