In [2]:
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torch

class ScanNet2DSegmentation(Dataset):
    """ScanNet 2D frames served as ``{'rgb', 'depth', 'label'}`` tensor dicts.

    Expected layout under ``root_dir``::

        <scene>/color/<frame>.jpg
        <scene>/depth/<frame>.png
        <scene>/label/<frame>.png

    A frame is indexed only when its label PNG exists and both the colour JPEG
    and the depth PNG with the same frame id are present. Scenes without a
    ``label`` directory (and stray files directly in ``root_dir``) are skipped.
    Scenes and frames are visited in sorted order, so sample indices are stable
    across machines and filesystems.

    Args:
        root_dir: Directory containing one sub-directory per scene.
        transform: Optional callable applied to the RGB tensor
            (float32, 3xHxW, values in [0, 1]).
        target_transform: Optional callable applied to the label tensor (int64, HxW).
        depth_transform: Optional callable applied to the depth tensor
            (float32, 1xH_dxW_d). Colour and depth images can have different
            resolutions (e.g. 968x1296 colour vs 480x640 depth), so this is the
            place to resize depth if aligned tensors are needed.

    Note:
        ``transform`` only touches the RGB image. Random *spatial* augmentations
        passed there will desynchronise the image from its depth and label.
    """

    def __init__(self, root_dir, transform=None, target_transform=None, depth_transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        self.depth_transform = depth_transform

        # Each entry is an (rgb_path, depth_path, label_path) triple.
        # Sorting makes idx -> frame deterministic; os.listdir order is arbitrary.
        self.samples = []
        for scene in sorted(os.listdir(root_dir)):
            self.samples.extend(self._scene_samples(os.path.join(root_dir, scene)))

    @staticmethod
    def _scene_samples(scene_path):
        """Return the (rgb, depth, label) path triples of one scene, in frame-id order."""
        color_dir = os.path.join(scene_path, 'color')
        depth_dir = os.path.join(scene_path, 'depth')
        label_dir = os.path.join(scene_path, 'label')

        # isdir (not exists) also rejects plain files sitting in root_dir.
        if not os.path.isdir(label_dir):
            return []

        samples = []
        for fname in sorted(os.listdir(label_dir)):
            if not fname.endswith('.png'):
                continue
            # Strip only the extension so frame ids that contain dots stay intact.
            f_id = fname[:-len('.png')]
            rgb_path = os.path.join(color_dir, f_id + '.jpg')
            depth_path = os.path.join(depth_dir, f_id + '.png')
            label_path = os.path.join(label_dir, fname)

            if os.path.exists(rgb_path) and os.path.exists(depth_path):
                samples.append((rgb_path, depth_path, label_path))
        return samples

    def __len__(self):
        """Number of indexed frames."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load frame ``idx``.

        Returns:
            dict with
              - ``'rgb'``:   float32 tensor, 3xHxW, scaled to [0, 1], then ``transform``;
              - ``'depth'``: float32 tensor, 1xH_dxW_d, raw PNG values (not rescaled),
                then ``depth_transform``;
              - ``'label'``: int64 tensor, HxW, raw label-PNG values, then
                ``target_transform``.
        """
        rgb_path, depth_path, label_path = self.samples[idx]

        # The context managers close the file handles deterministically. np.array(...)
        # copies the pixel data while the file is still open.
        with Image.open(rgb_path) as img:
            # Convert straight to float32; this avoids a full-size float64 intermediate.
            rgb = np.array(img.convert("RGB"), dtype=np.float32) / 255.0
        with Image.open(depth_path) as img:
            # A 16-bit PNG decodes to uint16, which older torch.from_numpy
            # releases reject. Casting in NumPy sidesteps that.
            depth = np.array(img, dtype=np.float32)
        with Image.open(label_path) as img:
            label = np.array(img, dtype=np.int64)

        image = torch.from_numpy(rgb).permute(2, 0, 1)  # HxWxC -> CxHxW
        depth = torch.from_numpy(depth).unsqueeze(0)    # 1xHxW
        label = torch.from_numpy(label)                 # HxW

        if self.transform is not None:
            image = self.transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return {
            'rgb': image,
            'depth': depth,
            'label': label
        }


In [None]:
from torch.utils.data import DataLoader

# Smoke test: build the dataset and pull one shuffled batch to inspect tensor shapes.
DATA_ROOT = '../scannet_frames_25k'  # relative to the notebook's working directory
SEED = 0  # fixes the shuffle order so the inspected batch is reproducible

dataset = ScanNet2DSegmentation(DATA_ROOT)
if len(dataset) == 0:
    # Otherwise DataLoader(shuffle=True) fails with a less descriptive num_samples=0 error.
    raise FileNotFoundError(f"No (rgb, depth, label) frames found under {DATA_ROOT!r}")

loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=2,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

batch = next(iter(loader))
print(batch['rgb'].shape)    # [4, 3, H, W]
print(batch['depth'].shape)  # [4, 1, H_d, W_d] -- depth can be stored at a lower resolution than RGB
print(batch['label'].shape)  # [4, H, W]

torch.Size([4, 3, 968, 1296])
torch.Size([4, 1, 480, 640])
torch.Size([4, 968, 1296])
Sample 0: RGB shape torch.Size([3, 968, 1296]), Depth shape torch.Size([1, 480, 640]), Label shape torch.Size([968, 1296])
Sample 1: RGB shape torch.Size([3, 968, 1296]), Depth shape torch.Size([1, 480, 640]), Label shape torch.Size([968, 1296])
Sample 2: RGB shape torch.Size([3, 968, 1296]), Depth shape torch.Size([1, 480, 640]), Label shape torch.Size([968, 1296])
Sample 3: RGB shape torch.Size([3, 968, 1296]), Depth shape torch.Size([1, 480, 640]), Label shape torch.Size([968, 1296])
Sample 4: RGB shape torch.Size([3, 968, 1296]), Depth shape torch.Size([1, 480, 640]), Label shape torch.Size([968, 1296])
Sample 5: RGB shape torch.Size([3, 968, 1296]), Depth shape torch.Size([1, 480, 640]), Label shape torch.Size([968, 1296])
