In [59]:
import torch
import torchvision
from torch.utils.data import Dataset
import numpy as np
import os
from PIL import Image

In [60]:
# Exact RGB colour used in the segmentation masks -> integer class id.
# rgb_to_class_map matches these colours exactly on every channel.
color_to_class = dict([
    ((255, 197, 25), 0),    # Forklift
    ((140, 255, 25), 1),    # Rack
    ((140, 25, 255), 2),    # Crate
    ((226, 255, 25), 3),    # Floor
    ((255, 111, 25), 4),    # Railing
    ((255, 25, 197), 5),    # Pallet
    ((54, 255, 25), 6),     # Stillage
    ((25, 255, 82), 7),     # iwhub
    ((25, 82, 255), 8),     # Dolly
    ((0, 0, 0), 9),         # Background
])

In [61]:
def rgb_to_class_map(image, color_to_class, default_class=0):
    """Convert an RGB colour-coded segmentation mask into a 2-D map of class ids.

    Parameters
    ----------
    image : PIL.Image.Image, np.ndarray or array-like
        Mask of shape (height, width, C) with C >= 3. Only the first three
        channels are compared, so RGBA masks are accepted too.
    color_to_class : dict[tuple[int, int, int], int]
        Mapping from an exact (R, G, B) colour to its integer class id.
    default_class : int, optional
        Class id written to pixels whose colour is not a key of
        ``color_to_class``. Defaults to 0 for backward compatibility. If 0 is
        also a real class id in ``color_to_class``, unknown colours become
        indistinguishable from that class; pass e.g. an ignore index instead.

    Returns
    -------
    np.ndarray
        ``int64`` array of shape (height, width).

    Raises
    ------
    ValueError
        If ``image`` is not 3-D with at least three channels.
    """
    # np.asarray is a no-op for ndarrays and converts PIL images through the
    # array interface, so no PIL-specific branch is needed.
    pixels = np.asarray(image)
    if pixels.ndim != 3 or pixels.shape[-1] < 3:
        raise ValueError(
            f"expected an image of shape (height, width, C>=3), got {pixels.shape}"
        )
    # Drop any alpha channel so comparisons against 3-tuples broadcast.
    rgb = pixels[..., :3]

    class_map = np.full(rgb.shape[:2], default_class, dtype=np.int64)
    for rgb_value, class_id in color_to_class.items():
        # Exact match on all three channels; off-palette (e.g. anti-aliased)
        # pixels keep default_class.
        class_map[np.all(rgb == rgb_value, axis=-1)] = int(class_id)

    return class_map

In [62]:
class DataSet(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Images and masks are paired by position after sorting each directory
    listing by file name: the i-th image (in sorted order) is matched with
    the i-th mask. Both directories must therefore hold the same number of
    entries, with names that sort into corresponding order.

    Each item is ``(image, mask)``. ``image`` is the RGB image as a NumPy
    array and ``mask`` is the class-id map from ``rgb_to_class_map``, unless
    ``transform`` replaces them.
    """

    def __init__(self, image_dir, mask_dir, transform=None, color_map=None) -> None:
        """
        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the RGB colour-coded masks.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` keys (presumably an albumentations
                pipeline -- confirm against the caller).
            color_map: Optional RGB -> class-id mapping. When ``None``, the
                notebook-level ``color_to_class`` table is used, looked up at
                item-access time.

        Raises:
            ValueError: If the two directories hold a different number of entries.
        """
        super().__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.color_map = color_map
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort both listings so that index i pairs corresponding files.
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"found {len(self.images)} images in {image_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}"
            )

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, index):
        """Load, convert and (optionally) transform the pair at ``index``."""
        image_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.masks[index])

        # Context managers close the underlying file handles promptly.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            mask_rgb = np.array(msk.convert("RGB"))

        colors = self.color_map if self.color_map is not None else color_to_class
        mask = rgb_to_class_map(mask_rgb, colors)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask