<a href="https://colab.research.google.com/github/federicOO1/LAB-IA/blob/main/Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## import librerie


In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from tifffile import imread

In [33]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [34]:
import os
os.chdir("/content/drive/MyDrive/4_Ortho_RGBIR")

In [27]:
class PotsdamDataset(Dataset):
    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        self.transform = transform
        self.image_paths = []
        self.world_file_paths = []

        # Leggi i percorsi delle immagini e dei file .tfw
        for file_name in os.listdir(data_folder):
            if file_name.endswith('.tif'):
                image_path = os.path.join(data_folder, file_name)
                world_file_path = os.path.join(data_folder, file_name.replace('.tif', '.tfw'))
                if os.path.exists(world_file_path):
                    self.image_paths.append(image_path)
                    self.world_file_paths.append(world_file_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        world_file_path = self.world_file_paths[idx]

        # Leggi l'immagine e il file .tfw
        image = imread(image_path)
        world_params = self.load_world_file(world_file_path)

        # Applica trasformazioni se definite
        if self.transform:
            image = self.transform(image)

        return image, world_params

    def load_world_file(world_file_path):
      with open(world_file_path, 'r') as f:
          lines = f.readlines()
          print(lines)  # Stampare il contenuto del file .tfw per debug
          try:
              parameters = tuple(float(line.strip()) for line in lines if line.strip())
              if len(parameters) == 6:
                  return parameters
              else:
                  raise ValueError("Il file .tfw non contiene 6 parametri.")
          except Exception as e:
              print(f"Errore durante la lettura dei parametri di georeferenziazione: {str(e)}")
              return None



In [45]:
from torch.utils.data import Dataset
from tifffile import imread
import os
import torch
import numpy as np


class PotsdamDataset(Dataset):
    def __init__(self, data_folder):
        self.data_folder = data_folder
        self.image_paths = []
        self.world_file_paths = []
        self.mask_paths = []

        # Leggi i percorsi delle immagini, dei file .tfw e delle maschere
        for file_name in os.listdir(data_folder):
            if file_name.endswith('.tif'):
                image_path = os.path.join(data_folder, file_name)
                world_file_path = os.path.join(data_folder, file_name.replace('.tif', '.tfw'))
                mask_path = os.path.join(data_folder, file_name.replace('.tif', '_mask.tif'))

                if os.path.exists(world_file_path) and os.path.exists(mask_path):
                    self.image_paths.append(image_path)
                    self.world_file_paths.append(world_file_path)
                    self.mask_paths.append(mask_path)

    def __len__(self):
        return len(self.image_paths)

    def load_world_file(self, world_file_path):
        with open(world_file_path, 'r') as f:
          lines = f.readlines()
          print(lines)  # Stampare il contenuto del file .tfw per debug
          try:
              parameters = tuple(float(line.strip()) for line in lines if line.strip())
              if len(parameters) == 6:
                  return parameters
              else:
                  raise ValueError("Il file .tfw non contiene 6 parametri.")
          except Exception as e:
              print(f"Errore durante la lettura dei parametri di georeferenziazione: {str(e)}")
              return None

    def rgb_to_label(self, rgb_mask):
      # Definisci i colori delle classi nella maschera RGB
      colors_to_labels = {
          (0, 0, 0): 0,  # Colore di sfondo -> Classe 0
          (255, 0, 0): 1,  # Colore rosso -> Classe 1
          (0, 255, 0): 2,  # Colore verde -> Classe 2
          (0, 0, 255): 3  # Colore blu -> Classe 3
          # Aggiungi altri colori e classi se necessario
      }

      # Inizializza un'immagine di etichette con valori di sfondo (0)
      label_mask = np.zeros(rgb_mask.shape[:2], dtype=np.uint8)

      # Converte ciascun pixel nella maschera RGB nel corrispondente indice di classe
      for color, label in colors_to_labels.items():
          mask = np.all(rgb_mask == np.array(color), axis=-1)
          label_mask[mask] = label

      # Converte l'array di etichette in un tensore PyTorch con tipo LongTensor
      label_mask_tensor = torch.from_numpy(label_mask).long()

      return label_mask_tensor

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        world_file_path = self.world_file_paths[idx]
        mask_path = self.mask_paths[idx]

        # Carica l'immagine TIFF utilizzando Tifffile
        image = imread(image_path)

        # Carica i parametri di georeferenziazione dal file .tfw
        world_params = self.load_world_file(world_file_path)

        # Carica la maschera TIFF
        mask = imread(mask_path)

        # Converti la maschera RGB nel formato appropriato per CrossEntropyLoss
        label = self.rgb_to_label(mask)

        # Restituisci l'immagine, la maschera e i parametri di georeferenziazione
        return image, label, world_params


In [49]:
# Definisci il percorso della cartella contenente i dati .tif e .tfw
data_folder = '/content/drive/My Drive/4_Ortho_RGBIR'

# Crea un'istanza del dataset
dataset = PotsdamDataset(data_folder)

dataset.mask_paths

[]

In [39]:


# Crea un DataLoader
batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)





In [10]:
from PIL import Image

# Prova a caricare un'immagine singola
image_path = '/content/drive/My Drive/4_Ortho_RGBIR/top_potsdam_7_12_RGBIR.tif'
try:
    image = Image.open(image_path)
    print("Immagine caricata con successo!")
except Exception as e:
    print(f"Errore nel caricare l'immagine: {e}")


Errore nel caricare l'immagine: cannot identify image file '/content/drive/My Drive/4_Ortho_RGBIR/top_potsdam_7_12_RGBIR.tif'


In [None]:



# Lista di percorsi di immagini nel tuo dataset
image_paths = ['path/to/image1.tif', 'path/to/image2.tif', 'path/to/image3.tif']

# Numero di righe e colonne nella griglia
num_rows = 1
num_cols = len(image_paths)

# Crea una griglia di immagini
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 5))

for i, image_path in enumerate(image_paths):
    image = Image.open(image_path)
    axes[i].imshow(image)
    axes[i].set_title(f'Immagine {i+1}')
    axes[i].axis('off')

plt.tight_layout()
plt.show()
