In [None]:
# Mount Google Drive (Colab only) so the /content/drive/ paths used in this notebook resolve.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


# Conteggio macro categorie del dataset originale

In [None]:
import json
import os
from collections import defaultdict

# Path to the folder that holds the annotation JSON files
annotations_dir = '/content/drive/MyDrive/Colab Notebooks/TACO dataset/ds/ann'

# Supercategory names to count
# NOTE(review): the recorded output reports 0 objects for "Plastic gloves", while the
# category mapping used later in this notebook spells the name "Plastic glooves" --
# confirm which spelling the annotation files actually use.
classes = [
    'background', "Plastic bag & wrapper", "Cup", "Plastic gloves", "Styrofoam piece",
    "Aluminium foil", "Shoe", "Unlabeled litter", "Lid", "Rope & strings",
    "Broken glass", "Paper bag", "Blister pack", "Carton", "Bottle cap", "Paper",
    "Plastic container", "Pop tab", "Straw", "Bottle", "Plastic utensils",
    "Other plastic", "Glass jar", "Battery", "Food waste", "Scrap metal",
    "Can", "Cigarette", "Squeezable tube"
]

# Per-class object counter (missing classes default to 0)
class_counts = defaultdict(int)

# Object ids already counted, so the same id is never counted twice across files
counted_object_ids = set()

# Walk every file in the annotations directory
for filename in os.listdir(annotations_dir):
    if filename.endswith('.json'):  # only JSON files are considered
        file_path = os.path.join(annotations_dir, filename)
        with open(file_path, 'r') as f:
            data = json.load(f)

            # Walk the objects listed in this annotation file
            for obj in data.get("objects", []):
                # Take the value of the first tag named "supercategory", if any
                supercategory = None
                for tag in obj.get("tags", []):
                    if tag.get("name") == "supercategory":
                        supercategory = tag.get("value")
                        break

                # Only objects whose exterior contour has more than 2 points are counted
                points = obj.get("points", {}).get("exterior", [])
                if len(points) > 2:
                    # Count the object only if its class is known and its id is new
                    if supercategory in classes and obj['id'] not in counted_object_ids:
                        # Increment the counter for this class
                        class_counts[supercategory] += 1
                        # Remember the id to avoid double counting
                        counted_object_ids.add(obj['id'])

# Print the total for every class, in the order of `classes`
print("Conteggio oggetti per classe:")
for cls in classes:
    print(f"{cls}: {class_counts[cls]}")


Conteggio oggetti per classe:
background: 0
Plastic bag & wrapper: 976
Cup: 195
Plastic gloves: 0
Styrofoam piece: 124
Aluminium foil: 62
Shoe: 7
Unlabeled litter: 551
Lid: 94
Rope & strings: 30
Broken glass: 142
Paper bag: 28
Blister pack: 7
Carton: 255
Bottle cap: 290
Paper: 151
Plastic container: 72
Pop tab: 99
Straw: 183
Bottle: 448
Plastic utensils: 44
Other plastic: 290
Glass jar: 6
Battery: 2
Food waste: 8
Scrap metal: 20
Can: 275
Cigarette: 669
Squeezable tube: 7


# Download immagini e maschere

In [None]:
def download_taco_images_and_masks(annotations_path, target_class=None, max_images=None, image_output_dir=None, mask_output_dir=None, skip_existing=True):
    """
    Download TACO images and write one semantic mask per image.

    Every annotation polygon is drawn into a single-channel mask using the label
    id that `class_finals` assigns to the annotation's category name (categories
    missing from `class_finals` are drawn as 0). Where polygons overlap, the
    highest label id wins. Images are passed through `ImageOps.exif_transpose`
    before being saved, and polygons are rotated by the angle returned by
    `get_exif_rotation_angle`.

    The set of images can be restricted to those containing one label id and
    capped to a maximum number.

    Args:
        annotations_path (str): Path to the annotations JSON file.
        target_class (int, optional): Label id (value of `class_finals`, e.g. 4 for
            paper/carton) an image must contain to be processed. None = all images.
        max_images (int, optional): Maximum number of images to process. None = no limit.
        image_output_dir (str, optional): Folder for the images. Default: the folder
            containing `annotations_path`.
        mask_output_dir (str, optional): Folder for the masks. Default:
            `image_output_dir/masks`.
        skip_existing (bool): If True, an image whose image file and mask file both
            exist is skipped, and an existing mask is never regenerated. If False,
            masks are regenerated (an image already on disk is still not re-downloaded).

    Returns:
        int: Number of images newly downloaded.
    """

    import os
    import json
    import requests
    import sys
    import numpy as np
    from PIL import Image, ImageDraw, ImageOps
    from io import BytesIO
    from collections import defaultdict

    # Without a timeout a single stalled connection would block the whole run.
    request_timeout = 60

    # Category name -> final label id (user-provided mapping)
    class_finals = {
        # PLASTIC AND POLYMERS (1)
        "Other plastic bottle": 1,
        "Clear plastic bottle": 1,
        "Plastic bottle cap": 1,
        "Spread tub": 1,
        "Tupperware": 1,
        "Disposable food container": 1,
        "Other plastic container": 1,
        "Plastic film": 1,
        "Garbage bag": 1,
        "Single-use carrier bag": 1,
        "Polypropylene bag": 1,
        "Plastified paper bag": 1,
        "Carded blister pack": 1,
        "Other plastic wrapper": 1,
        "Crisp packet": 1,
        "Disposable plastic cup": 1,
        "Other plastic cup": 1,
        "Plastic lid": 1,
        "Plastic glooves": 1,
        "Plastic utensils": 1,
        "Plastic straw": 1,
        "Other plastic": 1,
        "Six pack rings": 1,

        # METALS (2)
        "Food Can": 2,
        "Aerosol": 2,
        "Drink can": 2,
        "Metal bottle cap": 2,
        "Metal lid": 2,
        "Pop tab": 2,
        "Scrap metal": 2,
        "Aluminium foil": 2,
        "Aluminium blister pack": 2,

        # GLASS (3)
        "Glass bottle": 3,
        "Glass jar": 3,
        "Glass cup": 3,
        "Broken glass": 3,

        # PAPER AND CARTON (4)
        "Other carton": 4,
        "Egg carton": 4,
        "Drink carton": 4,
        "Corrugated carton": 4,
        "Meal carton": 4,
        "Pizza box": 4,
        "Magazine paper": 4,
        "Normal paper": 4,
        "Wrapping paper": 4,
        "Paper bag": 4,
        "Paper cup": 4,
        "Paper straw": 4,
        "Tissues": 4,
        "Toilet tube": 4,

        # STYROFOAM (5)
        "Foam cup": 5,
        "Foam food container": 5,
        "Styrofoam piece": 5,

        # CIGARETTES (6) -- food waste and batteries are mapped to this id as well
        "Food waste": 6,
        "Battery": 6,
        "Cigarette": 6,

        # TO BE REMOVED (8)
        "Shoe": 8,
        "Rope & strings": 8,
        "Squeezable tube": 8,

        # UNCLASSIFIED (7)
        "Unlabeled litter": 7
    }

    # EXIF orientation value (tag 274) -> rotation angle in degrees
    orientation_to_angle = {3: 180, 6: 270, 8: 90}

    def get_exif_rotation_angle(img):
        """Return 90, 180 or 270 according to the EXIF orientation tag, else 0.

        Best effort: any failure while reading EXIF data yields 0.

        NOTE(review): both call sites pass an image that has already gone through
        `.convert('RGB')`, and `_getexif` is a private Pillow accessor. If the
        converted object does not provide it, the resulting AttributeError is
        swallowed here and 0 is returned, so the polygon-rotation branch would
        never run. Confirm this against the generated masks before changing it.
        """
        try:
            exif = img._getexif()
            if exif is not None:
                return orientation_to_angle.get(exif.get(274), 0)
        except Exception:
            pass
        return 0

    def apply_rotation_to_polygon(polygon, angle, width, height):
        """Rotate a flat [x0, y0, x1, y1, ...] polygon by 90, 180 or 270 degrees.

        Always returns (points, new_width, new_height); for any other angle the
        points are returned unchanged together with the original size.
        """
        # Quarter turns swap the canvas sides; computed up front so the result
        # is well defined even for an empty polygon.
        if angle in (90, 270):
            new_width, new_height = height, width
        else:
            new_width, new_height = width, height

        if angle not in (90, 180, 270):
            return list(polygon), new_width, new_height

        points = []
        for idx in range(0, len(polygon), 2):
            x, y = polygon[idx], polygon[idx + 1]
            if angle == 90:
                new_x, new_y = height - y, x
            elif angle == 180:
                new_x, new_y = width - x, height - y
            else:  # 270
                new_x, new_y = y, width - x
            points.extend([new_x, new_y])

        return points, new_width, new_height

    def show_progress(index, total, skipped):
        """Redraw the one-line text progress bar on stdout."""
        bar_size = 30
        filled = int(bar_size * index / total)
        sys.stdout.write("%s[%s%s] - %i/%i (skipped: %i)\r" % ('Processing: ', "=" * filled, "." * (bar_size - filled), index + 1, total, skipped))
        sys.stdout.flush()

    if image_output_dir is None:
        image_output_dir = os.path.dirname(annotations_path)
    if mask_output_dir is None:
        mask_output_dir = os.path.join(image_output_dir, "masks")

    os.makedirs(image_output_dir, exist_ok=True)
    os.makedirs(mask_output_dir, exist_ok=True)

    with open(annotations_path, 'r') as f:
        annotations = json.load(f)

    category_id_to_name = {c['id']: c['name'] for c in annotations['categories']}
    image_id_to_anns = defaultdict(list)
    for ann in annotations['annotations']:
        image_id_to_anns[ann['image_id']].append(ann)

    def label_for(ann):
        """Label id of an annotation; unknown categories map to 0."""
        return class_finals.get(category_id_to_name.get(ann['category_id']), 0)

    def image_contains_target_class(image_id):
        """True if any annotation of the image maps to `target_class`."""
        return any(label_for(ann) == target_class for ann in image_id_to_anns.get(image_id, []))

    images_to_process = annotations['images']
    if target_class is not None:
        images_to_process = [img for img in annotations['images']
                             if image_contains_target_class(img['id'])]
        print(f"Trovate {len(images_to_process)} immagini contenenti la classe {target_class}")

    if max_images is not None and len(images_to_process) > max_images:
        images_to_process = images_to_process[:max_images]
        print(f"Limitate a {max_images} immagini")

    nr_images = len(images_to_process)
    downloaded_count = 0
    skipped_count = 0

    print(f"Processando {nr_images} immagini...")

    for i, image in enumerate(images_to_process):
        file_name = image['file_name']
        url_original = image['flickr_url']
        width, height = image['width'], image['height']

        # Flatten "subfolder/name.jpg" into "subfolder_name.jpg"
        subfolder = os.path.dirname(file_name)
        base_name = os.path.basename(file_name)
        new_file_name = f"{subfolder}_{base_name}" if subfolder else base_name

        img_path = os.path.join(image_output_dir, new_file_name)
        mask_path = os.path.join(mask_output_dir, new_file_name.replace('.jpg', '.png').replace('.JPG', '.png'))

        if skip_existing and os.path.isfile(img_path) and os.path.isfile(mask_path):
            skipped_count += 1
            show_progress(i, nr_images, skipped_count)
            continue

        try:
            is_new_download = not os.path.isfile(img_path)
            if is_new_download:
                response = requests.get(url_original, timeout=request_timeout)
                response.raise_for_status()
                img_raw = Image.open(BytesIO(response.content)).convert('RGB')
            else:
                img_raw = Image.open(img_path).convert('RGB')

            rotation_angle = get_exif_rotation_angle(img_raw)
            img = ImageOps.exif_transpose(img_raw)

            if is_new_download:
                img.save(img_path)
                downloaded_count += 1

            if not skip_existing or not os.path.isfile(mask_path):
                # Mask canvas size after the (possible) quarter-turn rotation
                if rotation_angle in [90, 270]:
                    final_width, final_height = height, width
                else:
                    final_width, final_height = width, height

                mask = np.zeros((final_height, final_width), dtype=np.uint8)

                for ann in image_id_to_anns[image['id']]:
                    label_id = label_for(ann)

                    for poly in ann.get('segmentation', []):
                        if rotation_angle != 0:
                            rotated_poly, _, _ = apply_rotation_to_polygon(
                                poly, rotation_angle, width, height
                            )
                        else:
                            rotated_poly = poly

                        # Draw on a scratch canvas, then merge with max so that the
                        # highest label id wins where polygons overlap.
                        mask_img = Image.new('L', (final_width, final_height), 0)
                        ImageDraw.Draw(mask_img).polygon(rotated_poly, outline=label_id, fill=label_id)
                        mask = np.maximum(mask, np.array(mask_img))

                # Force the mask to the size of the saved image; nearest-neighbour
                # keeps label ids intact.
                img_width, img_height = img.size
                if mask.shape[1] != img_width or mask.shape[0] != img_height:
                    mask_resized = Image.fromarray(mask).resize((img_width, img_height), Image.NEAREST)
                else:
                    mask_resized = Image.fromarray(mask)

                mask_resized.save(mask_path)

        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 403:
                print(f"\nError 403: {file_name} - {url_original}")
            else:
                print(f"\nErrore HTTP su {file_name}: {e}")
            continue
        except Exception as e:
            # Includes timeouts and malformed polygons: report and move on.
            print(f"\nErrore nel processare {file_name}: {e}")
            continue

        show_progress(i, nr_images, skipped_count)

    sys.stdout.write('Finished\n')

    if target_class is not None:
        print(f'Scaricate {downloaded_count} nuove immagini della classe {target_class} (saltate {skipped_count}) e generate maschere in {mask_output_dir}')
    else:
        print(f'Scaricate {downloaded_count} nuove immagini (saltate {skipped_count}) e generate maschere in {mask_output_dir}')

    return downloaded_count

In [None]:
# Process every image listed in annotations.json (no class filter, no image limit).
download_taco_images_and_masks(annotations_path='/content/drive/MyDrive/TACO dataset/annotations.json',
                               image_output_dir='/content/drive/MyDrive/TACO dataset/images_new_categories',
                               mask_output_dir='/content/drive/MyDrive/TACO dataset/masks_new_categories')

In [None]:
# Second pass over annotations_unofficial.json, restricted to images that contain
# label id 3 (the glass group in class_finals); output goes to the same folders.
download_taco_images_and_masks(annotations_path='/content/drive/MyDrive/TACO dataset/annotations_unofficial.json',
                               target_class=3,
                               image_output_dir='/content/drive/MyDrive/TACO dataset/images_new_categories',
                               mask_output_dir='/content/drive/MyDrive/TACO dataset/masks_new_categories')

Rimozione immagini e maschere contenenti la classe 8 (da rimuovere)

In [None]:
import os
from PIL import Image
import numpy as np

def remove_images_with_class_8(image_dir, mask_dir):
    """
    Delete every image/mask pair whose mask contains pixel value 8 (the class to be removed).

    Only mask files with a `.png` extension (any case) are examined. The matching
    image is looked up in `image_dir` under the same base name with a `.jpg` or
    `.JPG` extension; masks without a matching image are reported and left alone.

    Args:
        image_dir (str): Folder containing the images.
        mask_dir (str): Folder containing the masks (grayscale PNG).

    Returns:
        int: Number of image/mask pairs removed.
    """
    # List the folder once so the progress total matches the files actually examined.
    mask_filenames = [name for name in os.listdir(mask_dir) if name.lower().endswith('.png')]
    total_files = len(mask_filenames)
    removed_count = 0

    for i, mask_filename in enumerate(mask_filenames):
        mask_path = os.path.join(mask_dir, mask_filename)

        # Build the image name from the stem so that only the extension changes
        # (a plain str.replace would also touch matching text elsewhere in the path).
        stem = os.path.splitext(mask_filename)[0]
        image_path = None
        for extension in ('.jpg', '.JPG'):
            candidate = os.path.join(image_dir, stem + extension)
            if os.path.isfile(candidate):
                image_path = candidate
                break

        if image_path is None:
            print(f"Immagine non trovata per {mask_filename}")
            continue

        try:
            # Close the mask file before deleting it.
            with Image.open(mask_path) as mask_file:
                contains_class_8 = bool(np.any(np.asarray(mask_file) == 8))
            if contains_class_8:
                os.remove(mask_path)
                os.remove(image_path)
                removed_count += 1
        except Exception as e:
            print(f"Errore con {mask_filename}: {e}")

        bar_size = 30
        x = int(bar_size * i / total_files)
        print("%s[%s%s] - %i/%i" % ('Elaborazione: ', "=" * x, "." * (bar_size - x), i+1, total_files), end="\r")

    print(f"\nRimosse {removed_count} coppie immagine/maschera contenenti la classe 8.")
    return removed_count


In [None]:
# Drop every image/mask pair of the new-categories dataset whose mask contains class 8.
remove_images_with_class_8("/content/drive/MyDrive/TACO dataset/images_new_categories", "/content/drive/MyDrive/TACO dataset/masks_new_categories")

# Elaborazione annotazioni manuali

In [None]:
# Earlier 16-entry palette, disabled by wrapping it in a triple-quoted string
# (the string is evaluated and discarded; it defines nothing).
'''fine_tune_dict = {
    0: (0,0,0), # background
    1: (102,255,102), #ferro e alluminio
    2: (170, 240, 209), # Bottiglie plastica
    3: (250,50,83), # Bottiglie vetro
    4: (50,183,250), # Tappi e coperchi
    5: (221,255,51), # Vetro rotto
    6: (51,221,255), # Lattine
    7: (89,134, 179), # cartone
    8: (36,179,83), # Bicchieri
    9: (255,0,204), # Plastica generica
    10: (52,209,183), # Carta
    11: (250,125,187), # Sacchetti e buste di plastica
    12: (34,25,77), # Spazzatura generica
    13: (138,138,138), # Cannucce
    14: (255,96,55), # Polistirolo
    15: (184,61,245) # Sigarette
}'''

# Active palette: grayscale label value -> RGB colour found in the manually
# annotated colour masks.
# NOTE(review): key 100 sits out of sequence between 4 and 5 -- it looks like a
# temporary placeholder value to be remapped later; confirm the intended label id.
fine_tune_dict={
    0: (0,0,0),
    1: (51,221,255),
    2: (250,50,83),
    3: (52,209,183),
    4: (255,0,124),
    100: (255,96,55),
    5: (221,255,51),
    6: (36,179,83),
    7: (184,61,245)
}

Creazione maschere RGB -> grayscale

In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm

def replace_colors_with_grayscale(image_folder, fine_tune_dict, output_folder):
    """
    Convert colour-coded masks into single-channel label masks.

    For every readable image in `image_folder`, each pixel whose RGB colour is a
    value of `fine_tune_dict` is written as the corresponding key; every other
    pixel is written as 0. Results are always saved as PNG (same base name) in
    `output_folder`, because a lossy format would alter the label values.

    Args:
        image_folder (str): Folder containing the colour masks.
        fine_tune_dict (dict): label value (int, 0-255) -> (R, G, B) colour.
        output_folder (str): Folder where the grayscale masks are written.

    Raises:
        ValueError: If a label value does not fit in 8 bits, or if two label
            values share the same colour.
    """
    # Build the colour -> label lookup, rejecting mappings that cannot be
    # represented unambiguously in a uint8 mask.
    color_map = {}
    for key, color in fine_tune_dict.items():
        if not 0 <= key <= 255:
            raise ValueError(f"Label value {key} does not fit in an 8-bit mask (0-255)")
        rgb_color = tuple(color)
        if rgb_color in color_map:
            raise ValueError(f"Colour {rgb_color} is assigned to both {color_map[rgb_color]} and {key}")
        color_map[rgb_color] = key

    os.makedirs(output_folder, exist_ok=True)

    for filename in tqdm(os.listdir(image_folder)):
        img_path = os.path.join(image_folder, filename)
        img = cv2.imread(img_path)
        if img is None:
            # Not an image OpenCV can read: skip it.
            continue

        # OpenCV loads BGR; the palette is expressed in RGB.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        gray_img = np.zeros(img_rgb.shape[:2], dtype=np.uint8)

        for rgb_color, gray_value in color_map.items():
            mask = np.all(img_rgb == rgb_color, axis=-1)
            gray_img[mask] = gray_value

        # Always write PNG so label values survive encoding.
        output_name = os.path.splitext(filename)[0] + '.png'
        output_path = os.path.join(output_folder, output_name)
        if not cv2.imwrite(output_path, gray_img):
            print(f"Could not write {output_path}")

    print(f"Processing completed. Grayscale images saved in {output_folder}")

# Example usage: convert the colour masks in "defaultannot" and write the result
# into masks_new_categories (the same mask folder used by the download step).
image_folder = "/content/drive/MyDrive/TACO dataset/defaultannot"
output_folder = "/content/drive/MyDrive/TACO dataset/masks_new_categories"

replace_colors_with_grayscale(image_folder, fine_tune_dict, output_folder)


In [None]:
import os
import cv2
import numpy as np

def replace_pixel_values(folder_path, old_value, new_value):
    """
    Replace every pixel equal to `old_value` with `new_value` in the images of a folder.

    Each file with a common image extension is read as grayscale. Files that
    contain at least one matching pixel are overwritten in place; files with no
    matching pixel are left untouched, so they are not needlessly re-encoded.

    Args:
        folder_path (str): Folder containing the images.
        old_value (int): Pixel value to replace.
        new_value (int): Replacement pixel value.

    Returns:
        int: Number of files that were modified.

    Raises:
        ValueError: If either value is outside the 8-bit range 0-255.
    """
    # Both values must be representable in an 8-bit image
    if not (0 <= old_value <= 255 and 0 <= new_value <= 255):
        raise ValueError("old_value e new_value devono essere tra 0 e 255")

    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')
    modified_count = 0

    for filename in os.listdir(folder_path):
        if not filename.lower().endswith(image_extensions):
            continue

        file_path = os.path.join(folder_path, filename)
        img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # unreadable file: skip it

        matches = img == old_value
        if not matches.any():
            continue  # nothing to replace: keep the original file as it is

        img[matches] = new_value

        # cv2.imwrite reports some failures through its return value only.
        if not cv2.imwrite(file_path, img):
            print(f"Scrittura fallita per {file_path}")
            continue
        modified_count += 1

    print("Sostituzione completata.")
    return modified_count

# Remap label 2 to label 6 in every image of the "bibu" folder (files are overwritten).
replace_pixel_values(
    folder_path="/content/drive/MyDrive/TACO dataset/bibu",
    old_value=2,   # value to replace
    new_value=6    # replacement value
)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Quick sanity check on one mask: load it as grayscale, downscale it with
# nearest-neighbour (so label values are not blended), show it and list the
# label values it contains.
x = "/content/drive/MyDrive/TACO dataset/masks_new_categories/7e2e4c62f0712d78_jpg.rf.eb8e4d1c5ad572cb57a8d17d1aa25858.png"
mask_image = Image.open(x).convert('L')
mask_image = mask_image.resize((256,256), Image.NEAREST)

plt.imshow(mask_image)
mask_array = np.array(mask_image)
print(np.unique(mask_array))

# Analisi dataset

Controllo delle coppie immagine-maschera spaiate

In [None]:
import os

# Compare the two folders by file name without extension and report the names
# that appear in only one of them.
folder_images = "/content/drive/MyDrive/TACO dataset/images_new_categories"
folder_masks = "/content/drive/MyDrive/TACO dataset/masks_new_categories"

# Base names (extension stripped) present in each folder
images = {os.path.splitext(f)[0] for f in os.listdir(folder_images) }
masks = {os.path.splitext(f)[0] for f in os.listdir(folder_masks) }

# Set differences: images with no mask, and masks with no image
images_senza_maschere = images - masks
maschere_senza_immagini = masks - images

print("Immagini senza maschere:", images_senza_maschere)
print("Maschere senza immagini:", maschere_senza_immagini)


Immagini senza maschere: set()
Maschere senza immagini: set()


Conteggio immagini

In [None]:
import os

def count_files_in_folder(folder_path):
    """Return the number of regular files directly inside `folder_path`.

    Sub-directories are not counted and not descended into.
    """
    # The context manager releases the directory handle as soon as counting is done.
    with os.scandir(folder_path) as entries:
        return sum(1 for entry in entries if entry.is_file())

# Count the files stored in the images folder.
folder = "/content/drive/MyDrive/TACO dataset/images_new_categories"
num_files = count_files_in_folder(folder)
print(f"Numero di file: {num_files}")

Numero di file: 3180


## Esempi

Esempi di coppie immagine-maschera

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Paths
# TODO: both folders are empty strings here; set them before running this cell,
# otherwise listing `img_folder` below fails.
img_folder = ''
masks_folder = ''

# Label id -> display name (background included)
class_labels = {
    0: 'Background',
    1: 'PLASTICA E POLIMERI',
    2: 'METALLI',
    3: 'VETRO',
    4: 'CARTA E CARTONE',
    5: 'POLISTIROLO',
    6: 'SIGARETTE',
    7: 'NON CLASSIFICATI',
}

# Colours per label id (matplotlib "tab20", 20 entries).
# plt.get_cmap is used instead of plt.cm.get_cmap, which recent matplotlib
# releases no longer provide.
colors = plt.get_cmap('tab20', 20)

def show_image_and_mask(image_path, mask_path):
    """
    Show an image next to its colour-coded semantic mask and print the classes found.

    The mask is read as grayscale; every label id listed in `class_labels` is
    painted with its `colors` entry. If the mask size differs from the image
    size, the mask is resized with nearest-neighbour interpolation for display.

    Args:
        image_path (str): Path to the image file.
        mask_path (str): Path to the grayscale mask file.

    Raises:
        FileNotFoundError: If the image or the mask cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Impossibile leggere l'immagine: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Impossibile leggere la maschera: {mask_path}")

    # Boolean indexing below needs matching sizes; nearest-neighbour keeps label ids intact.
    if mask.shape != img.shape[:2]:
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Colour the mask
    mask_rgb = np.zeros_like(img)
    for class_id in class_labels:
        mask_rgb[mask == class_id] = (np.array(colors(class_id))[:3] * 255).astype(np.uint8)

    # Label ids present in this mask
    present_ids = np.unique(mask)

    # Side-by-side plot
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(img)
    axs[0].set_title("Immagine originale")
    axs[0].axis('off')

    axs[1].imshow(mask_rgb)
    axs[1].set_title("Maschera semantica")
    axs[1].axis('off')

    # Legend: one empty plot per present class, used only as a coloured handle
    patches = [plt.plot([],[], marker="s", ls="", mec=None,
                        color=colors(i)[:3], label=label)[0]
               for i, label in class_labels.items() if i in present_ids]
    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

    # Print the classes found
    print("Classi presenti nella maschera:")
    for class_id in present_ids:
        if class_id in class_labels:
            print(f"  {class_id}: {class_labels[class_id]}")
    print("-" * 50)


# Show up to N image/mask examples, optionally skipping the first ones and
# filtering by file-name prefix.
N = 1500               # Maximum number of images to display
start_count = 0        # How many matching images to skip before displaying
filename_prefix = "unofficial"   # Only file names starting with this string are shown; leave empty to show all

count = 0
shown = 0
for fname in sorted(os.listdir(img_folder)):
    # Accept both .jpg and .JPG, like the other steps of this notebook.
    if not fname.lower().endswith('.jpg'):
        continue

    # Prefix filter
    if filename_prefix and not fname.startswith(filename_prefix):
        continue

    name, _ = os.path.splitext(fname)
    mask_path = os.path.join(masks_folder, f"{name}.png")
    img_path = os.path.join(img_folder, fname)

    # Only images that have a mask are counted and shown
    if os.path.exists(mask_path):
        if count >= start_count:
            print(f"Mostrando: {fname}, count {count}")
            show_image_and_mask(img_path, mask_path)
            shown += 1
            if shown >= N:
                break
        count += 1

# Final statistics
if filename_prefix:
    print(f"\nFiltro applicato: filename che iniziano con '{filename_prefix}'")
print(f"Immagini mostrate: {shown}")
print(f"Immagini totali processate: {count}")

# Creazione split

In [None]:
import os
import shutil
import random
from collections import defaultdict
from PIL import Image
import numpy as np

def stratified_fixed_split(mask_folder, image_folder, dest_folder, train_size=1832, val_size=350, test_size=350, seed=None):
    """
    Split the dataset into train/val/test sets of fixed size and copy the files.

    Despite the name, the selection is a plain random shuffle: no per-class
    balancing is performed. Only masks that contain at least one non-background
    (non-zero) pixel are eligible. Files are copied into
    `dest_folder/<split>/images` and `dest_folder/<split>/masks`.

    Args:
        mask_folder (str): Folder containing the masks (.png).
        image_folder (str): Folder containing the images (.jpg or .JPG, same base name).
        dest_folder (str): Destination folder for the splits.
        train_size (int): Number of images in the training set.
        val_size (int): Number of images in the validation set.
        test_size (int): Number of images in the test set.
        seed (int, optional): Seed for a reproducible shuffle. If None, the global
            `random` state is used, as before.

    Raises:
        ValueError: If more files are requested than are available.
        FileNotFoundError: If a selected mask has no matching image; raised
            before any file is copied.
    """
    split_names = ['train', 'val', 'test']
    os.makedirs(dest_folder, exist_ok=True)
    for split in split_names:
        os.makedirs(os.path.join(dest_folder, split, 'images'), exist_ok=True)
        os.makedirs(os.path.join(dest_folder, split, 'masks'), exist_ok=True)

    # Keep the masks that contain at least one non-background pixel.
    # The list is sorted so that a given seed always yields the same split.
    all_files = sorted(f for f in os.listdir(mask_folder) if f.lower().endswith('.png'))
    unique_files = []
    for fname in all_files:
        with Image.open(os.path.join(mask_folder, fname)) as mask_file:
            has_foreground = bool(np.any(np.asarray(mask_file) != 0))
        if has_foreground:
            unique_files.append(fname)

    if seed is None:
        random.shuffle(unique_files)
    else:
        random.Random(seed).shuffle(unique_files)

    # Fixed-size split
    total_required = train_size + val_size + test_size
    if total_required > len(unique_files):
        raise ValueError(f"Richiesti {total_required} file ma disponibili solo {len(unique_files)}.")

    selected_files = unique_files[:total_required]
    train_files = selected_files[:train_size]
    val_files = selected_files[train_size:train_size + val_size]
    test_files = selected_files[train_size + val_size:]

    def find_image(mask_name):
        """Return the path of the image paired with a mask (.jpg first, then .JPG)."""
        stem = os.path.splitext(mask_name)[0]
        for extension in ('.jpg', '.JPG'):
            candidate = os.path.join(image_folder, stem + extension)
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"Immagine non trovata per {mask_name}")

    # Resolve every image first, so a missing file cannot leave a half-copied split.
    image_for_mask = {fname: find_image(fname) for fname in selected_files}

    # Files left over from a previous run would end up mixed with the new split.
    for split in split_names:
        for sub in ('images', 'masks'):
            existing_dir = os.path.join(dest_folder, split, sub)
            if os.listdir(existing_dir):
                print(f"Attenzione: {existing_dir} contiene già dei file; i nuovi file verranno aggiunti a quelli esistenti.")

    def copy_files(file_list, split_name):
        """Copy each mask and its image into the folders of one split."""
        for fname in file_list:
            src_img = image_for_mask[fname]
            dst_img = os.path.join(dest_folder, split_name, 'images', os.path.basename(src_img))
            shutil.copy2(src_img, dst_img)

            src_mask = os.path.join(mask_folder, fname)
            dst_mask = os.path.join(dest_folder, split_name, 'masks', fname)
            shutil.copy2(src_mask, dst_mask)

    # Copy the files
    print("Copia file...")
    copy_files(train_files, 'train')
    copy_files(val_files, 'val')
    copy_files(test_files, 'test')

    print(f"Split completato:")
    print(f"Train: {len(train_files)} immagini")
    print(f"Val: {len(val_files)} immagini")
    print(f"Test: {len(test_files)} immagini")

# Build the split used for training: 2700 train + 480 validation images, no test set.
stratified_fixed_split(
    mask_folder='/content/drive/MyDrive/TACO dataset/masks_new_categories',
    image_folder='/content/drive/MyDrive/TACO dataset/images_new_categories',
    dest_folder='/content/drive/MyDrive/TACO dataset/dataset_split',
    train_size=2700,
    val_size=480,
    test_size=0
)

Copia file...
Split completato:
Train: 2700 immagini
Val: 480 immagini
Test: 0 immagini


# Conteggio occorrenze classi

In [None]:
import os
from collections import defaultdict
from PIL import Image
import numpy as np

def count_class_occurrences_in_masks(mask_folder):
    """
    Count, for every pixel value, the number of masks in which it appears at least once.

    Args:
        mask_folder (str): Folder containing the semantic masks (.png).

    Returns:
        dict: pixel value (int) -> number of masks containing that value.
    """
    occurrences = {}

    for entry in os.listdir(mask_folder):
        # Only PNG files (any extension case) are treated as masks
        if not entry.lower().endswith(".png"):
            continue

        pixels = np.array(Image.open(os.path.join(mask_folder, entry)))

        # Each distinct value counts once per mask, however many pixels carry it
        for value in np.unique(pixels):
            label = int(value)
            occurrences[label] = occurrences.get(label, 0) + 1

    return occurrences

In [None]:
# Count in how many masks each label value appears
mask_dir = "/content/drive/MyDrive/TACO dataset/masks_new_categories"
counts = count_class_occurrences_in_masks(mask_dir)

# Print the counts sorted by label value
for class_id in sorted(counts):
    print(f"Classe {class_id}: presente in {counts[class_id]} maschere")

# Reference legend of the label ids, kept as an unused string literal
'''
    0: 'Background',
    1: 'PLASTICA E POLIMERI',
    2: 'METALLI',
    3: 'VETRO',
    4: 'CARTA E CARTONE',
    5: 'POLISTIROLO',
    6: 'SIGARETTE',
    7: 'NON CLASSIFICATI'
'''

- Classe 0: presente in 3181 maschere
- Classe 1: presente in 1183 maschere
- Classe 2: presente in 720 maschere
- Classe 3: presente in 609 maschere
- Classe 4: presente in 726 maschere
- Classe 5: presente in 604 maschere
- Classe 6: presente in 613 maschere
- Classe 7: presente in 523 maschere