In [2]:
import torch
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms
import torchvision.models.segmentation as models

def get_bounding_boxes(mask):
    """
    Given a binary mask, return bounding boxes of connected components.
    Each bounding box is represented as (x_min, y_min, x_max, y_max).
    """
    mask = mask.cpu().numpy().astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    boxes = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        boxes.append((x, y, x + w, y + h))
    
    return boxes

def process_image(image_path, model):
    """
    Process the image to segment tanks and return their coordinates.
    """
    image = Image.open(image_path).convert("RGB")
    original_width, original_height = image.size
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    image = transform(image).unsqueeze(0).to('cuda')
    
    model.eval()
    with torch.no_grad():
        output = model(image)['out']
        mask = torch.argmax(output, dim=1).squeeze()
        
    boxes = get_bounding_boxes(mask)
    
    coordinates = []
    for box in boxes:
        x_min, y_min, x_max, y_max = box
        # Calculate the center points
        x_center = (x_min + x_max) / 2
        y_center = (y_min + y_max) / 2
        # Calculate the ratios
        width_ratio = (x_max - x_min) / original_width
        height_ratio = (y_max - y_min) / original_height
        # Append to the list
        coordinates.append([x_center, y_center, width_ratio, height_ratio])
    
    return coordinates

# Load the model (assuming it's already trained and saved)
model = models.deeplabv3_resnet50(pretrained=False)
model.classifier[4] = torch.nn.Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
model.load_state_dict(torch.load('path/to/model.pth'))
model = model.to('device')

# Process an image
image_path = 'real_images/images/000000.jpg'
coordinates = process_image(image_path, model)
print(coordinates)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/francescostella/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 38.9MB/s]


FileNotFoundError: [Errno 2] No such file or directory: 'path/to/model.pth'