# Object detection fine-tuning on a custom dataset

In [None]:
!pip install -U -q datasets transformers[torch] timm # evaluate  albumentations accelerate

In [None]:
from datasets import load_dataset

# Load the "full" configuration of the German traffic-sign detection dataset
# from the Hugging Face Hub; the result is indexed by split name further down.
dataset = load_dataset(path="keremberke/german-traffic-sign-detection", name="full")
dataset  # notebook display of the loaded splits

In [None]:
# Peek at the first training sample to inspect its fields (image, size and
# the "objects" annotation dict).
dataset["train"][0]

In [None]:
# Short handles on the two splits used in the rest of the notebook.
train_dataset, test_dataset = dataset["train"], dataset["test"]

In [None]:
import numpy as np
from PIL import Image, ImageDraw

# Class names for the traffic-sign dataset, listed in category-id order:
# a name's position in this list is its integer category id.
# NOTE(review): assumed to match the dataset's category ordering - verify
# against the dataset's own category names if the dataset changes.
_class_names = [
    "animals", "construction", "cycles crossing", "danger",  # 0-3
    "no entry", "pedestrian crossing", "school crossing", "snow",  # 4-7
    "stop", "bend", "bend left", "bend right",  # 8-11
    "give way", "go left", "go left or straight", "go right",  # 12-15
    "go right or straight", "go straight", "keep left", "keep right",  # 16-19
    "no overtaking", "no overtaking -trucks-", "no traffic both ways", "no trucks",  # 20-23
    "priority at next intersection", "priority road", "restriction ends",  # 24-26
    "restriction ends -overtaking -trucks--", "restriction ends -overtaking-",  # 27-28
    "restriction ends 80", "road narrows", "roundabout", "slippery road",  # 29-32
    "speed limit 100", "speed limit 120", "speed limit 20", "speed limit 30",  # 33-36
    "speed limit 50", "speed limit 60", "speed limit 70", "speed limit 80",  # 37-40
    "traffic signal", "uneven road",  # 41-42
]

# Forward (id -> name) and inverse (name -> id) lookup tables.
id2label = dict(enumerate(_class_names))
label2id = {name: class_id for class_id, name in id2label.items()}


def draw_image_from_idx(dataset, idx, *, label_names=None):
    """Return a copy of sample ``idx``'s image with its bounding boxes drawn on it.

    Each annotation is drawn as a red rectangle, with its class name written
    in green at the rectangle's top-left corner.

    Args:
        dataset: Indexable dataset whose samples hold an ``"image"`` (PIL
            image) and an ``"objects"`` dict with parallel ``"bbox"`` and
            ``"category"`` lists. Each box is read as ``[x, y, w, h]``.
        idx: Index of the sample to render.
        label_names: Optional lookup (dict or list) from category id to
            display name. Defaults to the module-level ``id2label``. Ids
            missing from the lookup are shown as their number.

    Returns:
        A new RGB ``PIL.Image.Image``; the dataset's own image is untouched.
    """
    if label_names is None:
        label_names = id2label

    sample = dataset[idx]
    # Draw on an RGB copy: never mutate the sample's image in place, and make
    # sure the red/green annotation colours render for grayscale/palette input.
    image = sample["image"].convert("RGB")
    draw = ImageDraw.Draw(image)
    # Scale normalized boxes by the size of the canvas actually drawn on.
    width, height = image.size

    annotations = sample["objects"]
    for box, category in zip(annotations["bbox"], annotations["category"]):
        x, y, w, h = box
        # Heuristic: coordinates that are all <= 1 are treated as normalized
        # to [0, 1]; anything larger is taken to be in pixels already.
        if max(box) <= 1.0:
            x, y, w, h = x * width, y * height, w * width, h * height
        x1, y1 = int(x), int(y)
        x2, y2 = int(x + w), int(y + h)

        try:
            label = label_names[category]
        except (KeyError, IndexError):
            label = str(category)

        draw.rectangle((x1, y1, x2, y2), outline="red", width=3)
        draw.text((x1, y1), label, fill="green")

    return image


# Preview training sample 10 with its boxes (rendered as the cell output).
draw_image_from_idx(train_dataset, 10)

In [None]:
import matplotlib.pyplot as plt


def plot_images(dataset, indices, *, num_cols=3):
    """Show annotated samples from ``dataset`` in a grid.

    Each sample is rendered with ``draw_image_from_idx`` (boxes and labels
    drawn on top). The grid has ``num_cols`` columns and as many rows as
    needed (ceiling division), and unused cells in the last row are hidden.
    An empty ``indices`` draws nothing.

    Args:
        dataset: Dataset accepted by ``draw_image_from_idx``.
        indices: Iterable of sample indices to show, e.g. ``range(9)``.
        num_cols: Number of grid columns (default 3).
    """
    indices = list(indices)  # accept any iterable, not only sized sequences
    if not indices:
        return

    num_rows = -(-len(indices) // num_cols)  # ceiling division
    # squeeze=False keeps ``axes`` 2-D even for a single row/column, so the
    # flattening below works for every grid shape. The size reproduces the
    # original 15x10 figure for a 3x3 grid.
    _, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(5 * num_cols, 10 * num_rows / 3),
        squeeze=False,
    )
    flat_axes = axes.ravel()

    for ax, idx in zip(flat_axes, indices):
        ax.imshow(draw_image_from_idx(dataset, idx))
    # Hide axis decorations everywhere, including leftover empty cells.
    for ax in flat_axes:
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Show the first nine training samples with their boxes in a 3x3 grid.
plot_images(dataset=train_dataset, indices=range(9))