In [1]:
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.models.segmentation import fcn_resnet50

Class indices used in the segmentation masks:

- 0: BACKGROUND
- 1: SKY
- 2: VEGETATION
- 3: BUILDING
- 4: WINDOW
- 5: GROUND
- 6: NOISE
- 7: DOOR

In [2]:
# Maps each class index to the (R, G, B) colour that encodes it in a label mask.
CLASS_COLORS = {
    0: (0, 0, 0),        # BACKGROUND
    1: (0, 0, 255),      # SKY
    2: (0, 255, 0),      # VEGETATION
    3: (0, 125, 125),    # BUILDING
    4: (255, 255, 0),    # WINDOW
    5: (125, 125, 125),  # GROUND
    6: (255, 0, 0),      # NOISE
    7: (125, 125, 0),    # DOOR
}

In [3]:
class CustomDataset(Dataset):
    """Segmentation dataset pairing each image with its colour-coded label mask.

    Samples are enumerated from the regular files found in ``mask_dir``; the
    image belonging to a sample is looked up in ``img_dir`` under the same
    file name.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the RGB label masks.
        class_colors: Mapping ``{class_index: (R, G, B)}`` used to decode masks.
        transform: Optional callable applied to the RGB PIL image.
        mask_transform: Optional callable applied to the class-index PIL mask.
    """

    def __init__(self, img_dir, mask_dir, class_colors, transform = None, mask_transform = None):
        self.__img_dir = img_dir
        self.__mask_dir = mask_dir
        # os.listdir() returns entries in arbitrary order and also lists
        # sub-directories (e.g. ".ipynb_checkpoints"); keep regular files only
        # and sort them so that sample indices are stable between runs.
        self.__image_filenames = sorted(
            name for name in os.listdir(mask_dir)
            if os.path.isfile(os.path.join(mask_dir, name))
        )
        self.__class_colors = class_colors
        self.__img_transform = transform
        self.__mask_transform = mask_transform

    def __rgb_mask_to_class(self, mask):
        """Convert an RGB mask into a single-channel class-index mask.

        Pixels whose colour matches no entry of ``class_colors`` keep the
        initial value 0.

        Returns:
            PIL.Image: uint8 image whose pixel values are class indices.
        """
        mask_array = np.array(mask)  # a PIL image has no ``shape`` attribute
        label_mask = np.zeros((mask_array.shape[0], mask_array.shape[1]), dtype=np.uint8)
        for idx, value in self.__class_colors.items():
            # True where all three channels equal this class colour
            label_mask[np.all(mask_array == value, axis=-1)] = idx
        return Image.fromarray(label_mask)

    def __len__(self):
        """Return the number of samples (mask files) in the dataset."""
        return len(self.__image_filenames)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is the (optionally
            transformed) RGB image and ``mask`` is a ``torch.long`` tensor of
            class indices with shape [H, W].
        """
        filename = self.__image_filenames[idx]
        # os.path.join works whether or not the directories end with a separator
        image_path = os.path.join(self.__img_dir, filename)
        mask_path = os.path.join(self.__mask_dir, filename)

        # convert() returns a fully loaded copy, so the files can be closed right away
        with Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("RGB")

        mask = self.__rgb_mask_to_class(mask)

        if self.__img_transform:
            image = self.__img_transform(image)

        if self.__mask_transform:
            mask = self.__mask_transform(mask)

        mask = torch.tensor(np.array(mask), dtype=torch.long)

        return image, mask

In [4]:
# Pre-processing applied to every input image.
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # PIL image -> float tensor with shape [C, H, W]
    # Normalisation statistics used with pretrained models
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Pre-processing applied to every class-index mask. Nearest-neighbour resampling
# is required so that resizing never invents class indices by blending
# neighbouring labels. torchvision expects its own InterpolationMode enum here
# rather than the PIL integer constant.
mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
])

In [5]:
# Training set: images and colour-coded ground-truth labels.
train_dataset = CustomDataset(
    img_dir='./TMBuD/images/',
    mask_dir='./TMBuD/gt_label/',
    class_colors=CLASS_COLORS,
    transform=image_transform,
    mask_transform=mask_transform,
)
# Shuffled mini-batches of four samples.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

In [6]:
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
num_classes = len(CLASS_COLORS)

# Replace the last 1x1 convolution of both heads so that they predict
# `num_classes` channels. The input width is read from the layer being
# replaced instead of being hard-coded.
model.classifier[4] = nn.Conv2d(model.classifier[4].in_channels, num_classes, kernel_size=1)
model.aux_classifier[4] = nn.Conv2d(model.aux_classifier[4].in_channels, num_classes, kernel_size=1)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)



In [7]:
# Loss comparing the per-pixel class scores with the class-index masks.
criterion = nn.CrossEntropyLoss()

# Adam over every model parameter.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [8]:
num_epochs = 50

# Move every batch to whichever device the model parameters live on.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)

        # The loss needs class-index targets of shape [batch_size, H, W].
        # Only drop a channel axis when there really is one: an unconditional
        # squeeze(1) on a [batch_size, H, W] batch would remove H if H == 1.
        if masks.dim() == 4:
            masks = masks.squeeze(1)

        optimizer.zero_grad()
        # Only the main head ('out') contributes to the loss.
        outputs = model(images)['out']
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean batch loss over the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(train_loader):.4f}")

Epoch [1/50], Loss: 0.7665
Epoch [2/50], Loss: 0.4607
Epoch [3/50], Loss: 0.3632
Epoch [4/50], Loss: 0.3065
Epoch [5/50], Loss: 0.2697
Epoch [6/50], Loss: 0.2532
Epoch [7/50], Loss: 0.2130
Epoch [8/50], Loss: 0.1929
Epoch [9/50], Loss: 0.1869
Epoch [10/50], Loss: 0.1817
Epoch [11/50], Loss: 0.1611
Epoch [12/50], Loss: 0.1474
Epoch [13/50], Loss: 0.1422
Epoch [14/50], Loss: 0.1373
Epoch [15/50], Loss: 0.1294
Epoch [16/50], Loss: 0.1229
Epoch [17/50], Loss: 0.1171
Epoch [18/50], Loss: 0.1123
Epoch [19/50], Loss: 0.1087
Epoch [20/50], Loss: 0.1062
Epoch [21/50], Loss: 0.1031
Epoch [22/50], Loss: 0.0999
Epoch [23/50], Loss: 0.0980
Epoch [24/50], Loss: 0.0962
Epoch [25/50], Loss: 0.0949
Epoch [26/50], Loss: 0.0929
Epoch [27/50], Loss: 0.0911
Epoch [28/50], Loss: 0.0898
Epoch [29/50], Loss: 0.0884
Epoch [30/50], Loss: 0.0870
Epoch [31/50], Loss: 0.0864
Epoch [32/50], Loss: 0.0867
Epoch [33/50], Loss: 0.0855
Epoch [34/50], Loss: 0.0836
Epoch [35/50], Loss: 0.0816
Epoch [36/50], Loss: 0.0801
E

In [8]:
def mask_to_color(mask, class_colors):
    """
    Turn a class-index mask into an RGB image.

    Args:
        mask (numpy array): 2D array holding one class index per pixel.
        class_colors (dict): Mapping {class_index: (R, G, B)}.

    Returns:
        numpy array: uint8 array of shape [H, W, 3]. Pixels whose index is
        not a key of `class_colors` stay black.
    """
    height, width = mask.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint one class at a time
    for class_index in class_colors:
        selected = (mask == class_index)
        rgb[selected] = class_colors[class_index]

    return rgb


In [None]:
def remove_class_from_image(image, mask, class_to_remove, replace_with=(0, 0, 0)):
    """
    Blank out every pixel of an image that the mask assigns to one class.

    The input image is never modified; a new array is returned.

    Args:
        image (PIL Image or NumPy array): The original image (RGB or grayscale).
        mask (torch.Tensor or NumPy array): Class-index mask of shape [H, W].
        class_to_remove (int): Label of the class to remove.
        replace_with (tuple): RGB colour written over the removed pixels.

    Returns:
        NumPy array: Copy of the image with the selected class replaced.

    Raises:
        ValueError: If mask and image sizes still differ after resizing.
    """
    # np.array() always copies, so neither a PIL image nor a caller-owned
    # ndarray is altered by the assignment at the end of this function.
    image = np.array(image)

    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()

    # Make sure `image` has shape [H, W, 3]: expand grayscale to RGB
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    # If the sizes differ, resize the mask to the image. Nearest-neighbour
    # keeps the values valid class indices. cv2 takes the size as (width, height).
    if mask.shape[:2] != image.shape[:2]:
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Boolean mask of the pixels to remove
    mask_remove = (mask == class_to_remove)

    # Final size check
    if mask_remove.shape[:2] != image.shape[:2]:
        raise ValueError(f"Dimensioni non corrispondenti! Image: {image.shape}, Mask: {mask_remove.shape}")

    # Overwrite the selected pixels
    image[mask_remove] = replace_with

    return image

In [10]:
def crop_background(image, threshold=10):
    """
    Crop the image to the bounding box of everything that is not background.

    The background level is the grayscale value of the top-left pixel; a pixel
    counts as content when its grayscale value differs from that level by more
    than `threshold`.

    Args:
        image (NumPy array): The RGB image to crop.
        threshold (int): Tolerance around the background level (default=10).

    Returns:
        NumPy array: The cropped image, or the input image unchanged when
        every pixel is within the tolerance.
    """
    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Background level: the top-left pixel
    bg_color = gray[0, 0]

    # Compare in floating point. Subtracting in the image's unsigned dtype
    # wraps around for pixels darker than the background (e.g. 0 - 255 == 1
    # in uint8), which would misclassify them.
    difference = np.abs(gray.astype(np.float64) - float(bg_color))
    mask = difference > threshold

    # Coordinates (row, col) of all content pixels
    coords = np.argwhere(mask)

    if coords.size > 0:
        y_min, x_min = coords.min(axis=0)
        y_max, x_max = coords.max(axis=0)

        # Crop to the inclusive bounding box
        cropped_image = image[y_min:y_max+1, x_min:x_max+1]
    else:
        cropped_image = image  # everything is background: nothing to crop

    return cropped_image

Segmenta le immagini 

In [None]:
image_folder = Path('./results/geo_rectification')
output_folder = Path('./results/segmentation')
# imsave does not create missing directories
output_folder.mkdir(parents=True, exist_ok=True)

model.eval()
# Run inference on whichever device the model parameters live on.
device = next(model.parameters()).device

for image_path in sorted(image_folder.glob('*')):
    if not image_path.is_file():
        continue  # glob('*') also yields sub-directories

    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    image_tensor = image_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image_tensor)['out']
    # Most likely class per pixel, as a 2D array for the single image in the batch
    output = torch.argmax(logits, dim=1).cpu().numpy()[0]

    # NOTE(review): the mask has the resolution produced by image_transform,
    # not the resolution of the original image.
    color_mask = mask_to_color(output, CLASS_COLORS)

    plt.imsave(str(output_folder / image_path.name), color_mask)

Rimuove una classe dalle immagini

In [11]:
SKY_CLASS = 1  # index of SKY in CLASS_COLORS

# Load the weights onto whichever device the model parameters live on.
device = next(model.parameters()).device

# NOTE(review): no cell visible in this notebook writes this file; it must
# already exist on disk. torch.load unpickles its input, so only load
# checkpoints from a trusted source.
state_dict = torch.load('./results/weights/fcn_resnet50_weights.pth', map_location=device)

model.load_state_dict(state_dict)
model.eval()

image_folder = Path('./data')
output_folder = Path('./results/no_sky_images')
# imsave does not create missing directories
output_folder.mkdir(parents=True, exist_ok=True)

for image_path in sorted(image_folder.glob('*')):
    if not image_path.is_file():
        continue  # glob('*') also yields sub-directories

    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    image_tensor = image_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(image_tensor)['out']
    # Most likely class per pixel, as a 2D array for the single image in the batch
    output = torch.argmax(logits, dim=1).cpu().numpy()[0]

    # Blank the sky at full image resolution, then trim the uniform border
    rem_image = remove_class_from_image(image, output, SKY_CLASS)
    cropped_img = crop_background(rem_image)

    plt.imsave(str(output_folder / image_path.name), cropped_img)