In [6]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
import torch
import json
import os
from PIL import Image, ImageDraw, ImageOps

class ClimbingHoldDataset(Dataset):
    """Dataset of per-hold image crops labelled with hold type and orientation.

    Every ``*.json`` file in ``annotations_dir`` is parsed as a document with
    COCO-like ``images``, ``annotations`` and ``categories`` lists. Each usable
    annotation becomes one sample: the annotation's bounding box is cropped out
    of the referenced image and returned together with integer class indices
    for the hold's ``Type`` and ``Orientation`` attributes.

    Args:
        annotations_dir: Directory containing the annotation ``.json`` files.
        images_dir: Directory containing the image files named in the
            annotations.
        output_size: ``(width, height)`` the crop is resized to before the
            tensor transform runs. Note that the transform itself resizes to
            224x224 afterwards, so returned tensors are always 3x224x224.
    """

    # Label vocabularies; a label's position in the tuple is its class index.
    TYPES = ('Jug', 'Sloper', 'Crimp', 'Jib', 'Pinch', 'Pocket', 'Edge')
    ORIENTATIONS = ('Up', 'Down', 'Side', 'UpAng', 'DownAng')

    def __init__(self, annotations_dir, images_dir, output_size=(128, 128)):
        super().__init__()
        self.images_dir = images_dir
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
        self.output_size = output_size
        self.holds = self._load_holds(annotations_dir)

    @staticmethod
    def _load_holds(annotations_dir):
        """Collect one record per usable annotation found in ``annotations_dir``.

        Files are visited in sorted order so that sample indices are stable
        between runs. Annotations that cannot possibly be loaded later (no
        matching image entry, or no 4-element bounding box) are skipped with a
        warning instead of failing in the middle of a training epoch.

        Returns:
            A list of dicts with the keys ``image_id`` (the image file name),
            ``difficulty``, ``type``, ``orientation``, ``bbox`` and
            ``segmentation``.
        """
        import warnings

        holds = []
        # os.listdir returns entries in arbitrary order; sort for determinism.
        for json_file in sorted(os.listdir(annotations_dir)):
            if not json_file.endswith(".json"):
                continue

            json_path = os.path.join(annotations_dir, json_file)
            with open(json_path, 'r', encoding='utf-8') as file:
                data = json.load(file)

            images = {img['id']: img['file_name'] for img in data.get('images', [])}
            difficulties = {cat['id']: cat['name'] for cat in data.get('categories', [])}

            for annotation in data.get('annotations', []):
                file_name = images.get(annotation.get("image_id"))
                bbox = annotation.get("bbox")
                if file_name is None or not bbox or len(bbox) != 4:
                    warnings.warn(
                        f"{json_path}: skipping annotation "
                        f"{annotation.get('id')!r} "
                        "(no matching image or invalid bbox)",
                        stacklevel=2,
                    )
                    continue

                # Annotations without an "attributes" object still yield a
                # sample; their labels map to -1 in __getitem__.
                attributes = annotation.get("attributes") or {}
                holds.append({
                    # Despite the key name this holds the image *file name*.
                    "image_id": file_name,
                    "difficulty": difficulties.get(annotation.get("category_id")),
                    "type": attributes.get("Type"),
                    "orientation": attributes.get("Orientation"),
                    "bbox": bbox,
                    "segmentation": annotation.get("segmentation"),
                })
        return holds

    def __len__(self):
        """Return the number of hold samples."""
        return len(self.holds)

    def __getitem__(self, idx):
        """Load, crop and transform the hold at position ``idx``.

        Returns:
            A dict with ``image`` (the transformed crop), ``type`` and
            ``orientation`` (integer class indices, ``-1`` when the label is
            missing or unknown).
        """
        hold_data = self.holds[idx]
        image_path = os.path.join(self.images_dir, hold_data["image_id"])
        x_min, y_min, width, height = hold_data["bbox"]

        # The context manager releases the file handle once the crop is made.
        with Image.open(image_path) as image:
            # Apply the EXIF orientation before cropping with the bbox.
            rotated_image = ImageOps.exif_transpose(image)
            cropped_image = rotated_image.crop(
                (x_min, y_min, x_min + width, y_min + height)
            )
            # Normalize below expects exactly three channels; grayscale, RGBA
            # or palette images would otherwise fail there.
            cropped_image = cropped_image.convert("RGB")

        cropped_image = cropped_image.resize(self.output_size)

        if self.transform:
            cropped_image = self.transform(cropped_image)

        return {
            "image": cropped_image,
            "type": self._map_type(hold_data["type"]),
            "orientation": self._map_orientation(hold_data["orientation"]),
        }

    def _map_type(self, type_label):
        """Return the index of ``type_label`` in ``TYPES``, or -1 if unknown."""
        return self.TYPES.index(type_label) if type_label in self.TYPES else -1

    def _map_orientation(self, orientation_label):
        """Return the index of ``orientation_label`` in ``ORIENTATIONS``, or -1 if unknown."""
        if orientation_label in self.ORIENTATIONS:
            return self.ORIENTATIONS.index(orientation_label)
        return -1

In [7]:
# Smoke check: build the dataset from the on-disk annotations and inspect
# the first sample to confirm loading, cropping and label mapping work.
annotations_dir = "data/annotations"
images_dir = "data/images"
dataset = ClimbingHoldDataset(annotations_dir, images_dir)

print("Dataset size:", len(dataset))

sample = dataset[0]

# Report each field of the sample on its own line.
for _label, _value in (
    ("Image tensor shape:", sample["image"].shape),
    ("Type index:", sample["type"]),
    ("Orientation index:", sample["orientation"]),
):
    print(_label, _value)

Dataset size: 1605
Image tensor shape: torch.Size([3, 224, 224])
Type index: 0
Orientation index: 3
