In [1]:
# System
import os
import time
from IPython.display import clear_output
from tqdm.auto import tqdm

# Standard libs
import numpy as np
import pandas as pd
import random

# Plotting
import matplotlib.pyplot as plt

# Image utils
from PIL import Image
import cv2
from glob import glob

# PyTorch 
import torch
import torchvision
import torch.nn.functional as F
import torchvision.datasets
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from torchvision.io import read_image
from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader

In [8]:
# DATA_PATH = '/kaggle/input/german-traffic-sign-detection-benchmark-gtsdb'
DATA_PATH = 'DATA/GTSDB/'
# Image folder per split: training images live under TrainIJCNN2013, the
# held-out test images under TestIJCNN2013. The gt.txt annotations used by the
# dataset cell are read from DATA_PATH itself.
TRAIN_DATA_PATH = os.path.join(DATA_PATH, 'TrainIJCNN2013/TrainIJCNN2013')
TEST_DATA_PATH = os.path.join(DATA_PATH, 'TestIJCNN2013/TestIJCNN2013Download')

In [11]:
class TrafficSignsDataset(Dataset):
    """Map-style dataset yielding one sample per row of a GTSDB-style ``gt.txt``.

    The annotation file is read with ``sep=";"`` and no header row; its
    columns are named ``filename, x1, y1, x2, y2, class``. Indexing is per
    row, so an image listed on several rows is returned once per row, each
    time carrying only that row's single box.
    NOTE(review): detection models usually expect every box of an image in
    one target — confirm that per-row samples are intended here.

    Args:
        img_dir: directory holding the image files named in ``filename``.
        annotations_file: anything ``pd.read_csv`` accepts (path or buffer).
        transform: optional callable ``transform(image, target)`` returning
            ``(image, target)``.
    """

    def __init__(self, img_dir, annotations_file, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.annotations = pd.read_csv(annotations_file, sep=";", header=None,
                                       names=["filename", "x1", "y1", "x2", "y2", "class"])

    def __len__(self):
        """Return the number of annotation rows (not distinct images)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for annotation row ``idx``.

        Before ``transform`` runs, ``image`` is an RGB PIL image and
        ``target`` holds ``boxes`` (float32 array ``[x1, y1, x2, y2]``),
        ``labels`` (this row's ``class`` value) and ``image_id``
        (``torch.tensor([idx])``).
        """
        img_path = os.path.join(self.img_dir, self.annotations.iloc[idx, 0])
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        boxes = self.annotations.iloc[idx, 1:5].values.astype(np.float32)
        labels = self.annotations.iloc[idx, 5]

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
        }

        # `is not None` rather than truthiness, so a callable that happens to
        # be falsy (e.g. defines __len__ == 0) is still applied.
        if self.transform is not None:
            image, target = self.transform(image, target)

        return image, target

In [None]:
class myDataset(torch.utils.data.Dataset):
    """Dataset over the images in ``<root>/imagesf`` paired with per-file annotations.

    Annotations are looked up by image filename. They come from the
    ``annotations`` mapping when one is given; otherwise from a
    notebook-global ``dic`` that this class does not define (it must exist
    before ``__getitem__`` is called).

    Each annotation value is a sequence of objects. For every object,
    elements ``[0:4]`` are ``xmin, ymin, xmax, ymax`` and the last element is
    the class label (anything ``int()`` accepts).

    Args:
        root: dataset root containing an ``imagesf`` sub-directory.
        transforms: optional callable ``transforms(img, target)`` returning
            ``(img, target)``. Unlike single-argument ``torchvision.transforms``,
            it also receives the target so box-aware augmentations (e.g. a
            horizontal flip) can update the boxes; see the torchvision
            detection references for examples.
        annotations: optional mapping ``filename -> objects``; defaults to the
            global ``dic`` for backward compatibility.
    """

    def __init__(self, root, transforms=None, annotations=None):
        self.root = root
        self.transforms = transforms
        self.annotations = annotations
        # Sort so the index -> file mapping is deterministic and stays aligned.
        self.imgs = sorted(os.listdir(os.path.join(root, "imagesf")))

    @staticmethod
    def _build_target(objects, idx):
        """Build the target dict (boxes, labels, image_id, area, iscrowd) for one image."""
        # Builtin int/float: the np.int / np.float aliases were removed in NumPy 1.24.
        labels = [int(obj[-1]) for obj in objects]
        coords = [[float(obj[0]), float(obj[1]), float(obj[2]), float(obj[3])]
                  for obj in objects]

        # reshape keeps an (N, 4) layout even when the image has no objects,
        # so the column indexing used for `area` cannot fail.
        boxes = torch.as_tensor(coords, dtype=torch.float32).reshape(-1, 4)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        return {
            "boxes": boxes,
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([idx]),
            "area": area,
            # Suppose all instances are not crowd.
            "iscrowd": torch.zeros((len(objects),), dtype=torch.int64),
        }

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the ``idx``-th sorted image file."""
        filename = self.imgs[idx]
        img_path = os.path.join(self.root, "imagesf", filename)
        # Close the file handle deterministically; convert() yields a loaded copy.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        # Fall back to the notebook-global `dic` when no mapping was passed in.
        lookup = self.annotations if self.annotations is not None else dic
        target = self._build_target(lookup[filename], idx)

        if self.transforms is not None:
            # The transform sees (and may rewrite) the target, including its boxes.
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of image files found under ``<root>/imagesf``."""
        return len(self.imgs)

In [12]:
# Training-split dataset; annotations come from the gt.txt at the dataset root.
dt = TrafficSignsDataset(
    img_dir=TRAIN_DATA_PATH,
    annotations_file=os.path.join(DATA_PATH, "gt.txt"),
)

In [13]:
# Peek at the first sample, an (image, target) pair. Direct indexing on the
# map-style dataset avoids the iterator protocol, which would turn an
# IndexError raised inside __getitem__ into a misleading StopIteration.
dt[0]

(<PIL.Image.Image image mode=RGB size=1360x800>,
 {'boxes': array([774., 411., 815., 446.], dtype=float32),
  'labels': 11,
  'image_id': tensor([0])})