1- Loading the dataset with PyTorch

In [None]:
import os
import json
import matplotlib.pyplot as plt
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader

class ImageLabelDataset(Dataset):
    """Pair image files with their JSON label files across one or more splits.

    Expected layout for each entry of ``data_types`` (e.g. ``"Training"``)::

        <root_dir>/<data_type>/images/<stem>.<ext>
        <root_dir>/<data_type>/labels/json/<stem>.json

    Files in ``images`` without a matching JSON file are skipped. Items are
    ``(image_path, label_path)`` string tuples; reading pixels and parsing
    labels is left to the caller.

    Args:
        root_dir: Directory containing one sub-directory per split.
        data_types: Split names to index; their items are concatenated in
            the given order.
    """

    def __init__(self, root_dir, data_types):
        self.data = {}
        for data_type in data_types:
            data_path = os.path.join(root_dir, data_type)
            image_folder_path = os.path.join(data_path, "images")
            label_folder_path = os.path.join(data_path, "labels", "json")
            images = []
            labels = []
            # os.listdir returns entries in arbitrary order; sort so the
            # index -> file mapping is reproducible between runs/machines.
            for image_file in sorted(os.listdir(image_folder_path)):
                image_path = os.path.join(image_folder_path, image_file)
                # splitext strips only the final extension, so a name like
                # "frame.001.jpg" maps to "frame.001.json", not "frame.json".
                stem, _ = os.path.splitext(image_file)
                label_path = os.path.join(label_folder_path, stem + ".json")
                if os.path.exists(label_path):
                    images.append(image_path)
                    labels.append(label_path)
            self.data[data_type] = {"images": images, "labels": labels}

    def __len__(self):
        """Return the total number of image/label pairs across all splits."""
        return sum(len(split["images"]) for split in self.data.values())

    def __getitem__(self, idx):
        """Return the ``(image_path, label_path)`` pair at global index ``idx``.

        Splits are concatenated in the order they were indexed. Negative
        indices count from the end, as for a Python sequence.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError("Index out of range.")
        # Walk the splits, subtracting each split's size until idx falls
        # inside the current one.
        for split in self.data.values():
            count = len(split["images"])
            if idx < count:
                return split["images"][idx], split["labels"][idx]
            idx -= count
        # Unreachable after the bounds check; kept as a defensive guard.
        raise IndexError("Index out of range.")


2- Visualizing some of the labeled images

In [None]:
def visualize_images(data_loader, num_batches):
    """Show the first ``num_batches`` batches with their bounding boxes drawn.

    Each label file is read as a JSON list of objects carrying ``Left``,
    ``Top``, ``Right``, ``Bottom`` and ``ObjectClassName``; classes missing
    from the colour table are drawn in black. Unreadable images, unparsable
    JSON files and malformed box entries are reported and skipped.

    Args:
        data_loader: Iterable yielding ``(image_paths, label_paths)`` batches,
            e.g. a ``DataLoader`` over ``ImageLabelDataset``.
        num_batches: Number of batches to display; each image gets its own
            figure.
    """
    # Interpreted as RGB: boxes are drawn after the BGR -> RGB conversion.
    colors = {"bin": (255, 0, 0), "dolly": (0, 255, 0), "jack": (0, 0, 255)}

    for batch_idx, (image_paths, label_paths) in enumerate(data_loader):
        if batch_idx >= num_batches:
            break

        for image_path, label_path in zip(image_paths, label_paths):
            image_bgr = cv2.imread(image_path)
            if image_bgr is None:
                # cv2.imread returns None instead of raising on unreadable files.
                print(f"Error: Unable to read image: {image_path}")
                continue
            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

            # Separate name so the batch's label-path sequence is not shadowed.
            try:
                with open(label_path, "r") as f:
                    annotations = json.load(f)
            except json.JSONDecodeError:
                print(f"Error: Unable to parse JSON file: {label_path}")
                continue

            for bbox in annotations:
                try:
                    # JSON may hold float coordinates; OpenCV drawing needs ints.
                    left, top, right, bottom = (
                        int(bbox[key]) for key in ("Left", "Top", "Right", "Bottom")
                    )
                    class_name = bbox["ObjectClassName"]
                except (KeyError, TypeError, ValueError):
                    print(f"Error: Malformed label in file: {label_path}")
                    continue
                color = colors.get(class_name, (0, 0, 0))
                cv2.rectangle(image_rgb, (left, top), (right, bottom), color, 2)
                cv2.putText(image_rgb, str(class_name), (left, top - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 4)

            # Display once, after every box has been drawn.
            plt.figure(figsize=(8, 6))
            plt.imshow(image_rgb)
            plt.title(image_path)
            plt.show()

# Index each split, wrap it in a DataLoader and preview one batch from each.
root_dir = "data"
batch_size = 2
num_batches_to_visualize = 1

dataset_train = ImageLabelDataset(root_dir, ["Training"])
dataset_test = ImageLabelDataset(root_dir, ["Testing"])

# Only the training loader shuffles, so its previewed batch varies per run.
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

for split_name, split_loader in (("training", train_loader), ("testing", test_loader)):
    print(f"Visualizing {split_name} images:")
    visualize_images(split_loader, num_batches_to_visualize)

3- Adding augmentations

In [None]:
import albumentations as A

def visualize_images(data_loader, num_batches):
    """Show randomly augmented images from the first ``num_batches`` batches.

    Each image is read, converted to RGB, passed together with its boxes
    through an albumentations pipeline, and displayed with the transformed
    boxes drawn on it. Unreadable images, unparsable JSON files and malformed
    box entries are reported and skipped.

    Args:
        data_loader: Iterable yielding ``(image_paths, label_paths)`` batches,
            e.g. a ``DataLoader`` over ``ImageLabelDataset``.
        num_batches: Number of batches to display; each image gets its own
            figure.
    """
    # Interpreted as RGB, matching the un-augmented visualisation.
    colors = {"bin": (255, 0, 0), "dolly": (0, 255, 0), "jack": (0, 0, 255)}

    # Build the pipeline once rather than once per batch.
    augmentation = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
        A.RandomGamma(gamma_limit=(80, 120), p=0.5),
        A.Blur(blur_limit=(3, 7), p=0.5),
    ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']))

    for batch_idx, (image_paths, label_paths) in enumerate(data_loader):
        if batch_idx >= num_batches:
            break

        for image_path, label_path in zip(image_paths, label_paths):
            image_bgr = cv2.imread(image_path)
            if image_bgr is None:
                # cv2.imread returns None instead of raising on unreadable files.
                print(f"Error: Unable to read image: {image_path}")
                continue
            # Convert before augmenting and drawing so the colour table above
            # is applied in RGB space and the result can be shown directly.
            image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            height, width = image.shape[:2]

            # Use the label path the dataset paired with this image rather than
            # re-deriving it from the image path via string replacement.
            if not os.path.exists(label_path):
                continue
            try:
                with open(label_path, "r") as f:
                    annotations = json.load(f)
            except json.JSONDecodeError:
                print(f"Error: Unable to parse JSON file: {label_path}")
                continue

            bboxes = []
            class_labels = []
            for bbox in annotations:
                try:
                    left, top, right, bottom = (
                        float(bbox[key]) for key in ("Left", "Top", "Right", "Bottom")
                    )
                    class_name = bbox["ObjectClassName"]
                except (KeyError, TypeError, ValueError):
                    print(f"Error: Malformed label in file: {label_path}")
                    continue
                # Albumentations validates pascal_voc boxes and raises on
                # coordinates outside the image or non-positive width/height,
                # so clip to the image and drop boxes that collapse.
                left, right = max(0.0, left), min(float(width), right)
                top, bottom = max(0.0, top), min(float(height), bottom)
                if right <= left or bottom <= top:
                    continue
                bboxes.append([left, top, right, bottom])
                class_labels.append(class_name)

            augmented = augmentation(image=image, bboxes=bboxes, class_labels=class_labels)
            # OpenCV drawing functions require a C-contiguous array.
            augmented_image = np.ascontiguousarray(augmented['image'])

            for bbox, class_label in zip(augmented['bboxes'], augmented['class_labels']):
                left, top, right, bottom = map(int, bbox)
                color = colors.get(class_label, (0, 0, 0))
                cv2.rectangle(augmented_image, (left, top), (right, bottom), color, 2)
                cv2.putText(augmented_image, str(class_label), (left, top - 5),
                            cv2.FONT_HERSHEY_COMPLEX_SMALL, 1.3, color, 2)

            plt.figure(figsize=(8, 6))
            plt.imshow(augmented_image)
            plt.title(image_path)
            plt.show()
           
# Rebuild both splits with unshuffled batches of three and preview one
# augmented batch from each.
root_dir = "data"
dataset_train = ImageLabelDataset(root_dir, ["Training"])
dataset_test = ImageLabelDataset(root_dir, ["Testing"])

batch_size = 3
num_batches_to_visualize = 1
train_loader, test_loader = (
    DataLoader(split_dataset, batch_size=batch_size, shuffle=False)
    for split_dataset in (dataset_train, dataset_test)
)

for split_name, split_loader in (("training", train_loader), ("testing", test_loader)):
    print(f"visualizing {split_name} images Ag:")
    visualize_images(split_loader, num_batches_to_visualize)
