In [15]:
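# Reference tutorial (Hugging Face Computer Vision Course, unit 3 – object detection with vision transformers / DETR):
# https://huggingface.co/learn/computer-vision-course/unit3/vision-transformers/vision-transformer-for-objection-detection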

In [16]:
import albumentations
import numpy as np
from datasets import load_from_disk
from transformers import (
    AutoImageProcessor,
    AutoModelForObjectDetection,
    Trainer,
    TrainingArguments,
)

In [17]:
# Load the locally saved dataset (written with `save_to_disk`); it is indexed
# by split name below ("train", "test").
DATASET_DIR = "../data/container"
data = load_from_disk(DATASET_DIR)

In [None]:
# Display the loaded dataset summary (indexed by split name, e.g. data["train"]).
data

In [19]:
# Image processor matching the DETR checkpoint: `datapipe` uses it to encode
# images + COCO-style annotations, and `collate_fn` uses its `pad` method.
DETR_CHECKPOINT = "facebook/detr-resnet-50"
preprocessor = AutoImageProcessor.from_pretrained(DETR_CHECKPOINT)

In [None]:
# Inspect one raw training example; `datapipe` reads its "image_id", "image"
# and "objects" fields (e.g. objects["bbox"] and objects["id"]).
data["train"][0]

In [21]:
# Augmentation pipeline applied to every example inside `datapipe`.
# Boxes use COCO format ([x_min, y_min, width, height]). "category" is declared
# as a label field, so albumentations keeps labels aligned with boxes: if a
# transform drops a box, its label is dropped too.
aug = albumentations.Compose(
    transforms=[
        albumentations.Resize(480, 480),
        # p=0.5 makes the flip a real augmentation. With p=1 every image was
        # mirrored deterministically, so the model never saw the original
        # orientation (and neither did evaluation, which reuses this pipeline).
        albumentations.HorizontalFlip(p=0.5),
        albumentations.RandomBrightnessContrast(p=1.0),
    ],
    bbox_params=albumentations.BboxParams(format="coco", label_fields=["category"]),
)

In [22]:
def datapipe(data, transform=None, processor=None):
    """Augment a batch of raw examples and encode it for DETR.

    Used as a ``datasets`` on-the-fly transform (``Dataset.with_transform``), so
    ``data`` is a *batch*: a mapping of column name -> list of values.

    Args:
        data: Batch with ``"image_id"``, ``"image"`` (objects with a PIL-style
            ``convert``) and ``"objects"`` columns. Each ``objects`` entry is
            expected to hold parallel lists ``"bbox"`` (COCO ``[x, y, w, h]``)
            and ``"id"``.
        transform: albumentations-style callable
            ``transform(image=..., bboxes=..., category=...)`` returning a dict
            with ``"image"``, ``"bboxes"`` and ``"category"``. Defaults to the
            global ``aug``. Pass a deterministic pipeline for evaluation data.
        processor: callable used for the final encoding. Defaults to the global
            ``preprocessor``.

    Returns:
        The result of ``processor(images=..., annotations=..., return_tensors="pt")``.
        ``collate_fn`` consumes its ``pixel_values`` and ``labels``.
    """
    transform = aug if transform is None else transform
    processor = preprocessor if processor is None else processor

    images, targets = [], []
    for image_id, image, objects in zip(data["image_id"], data["image"], data["objects"]):
        # Keep RGB channel order. The DETR processor normalizes with per-channel
        # RGB statistics, so the previous `[:, :, ::-1]` (RGB -> BGR) flip fed
        # the pretrained backbone swapped channels.
        pixels = np.array(image.convert("RGB"))
        # NOTE(review): class labels come from objects["id"]. With
        # id2label == {0: "container"} every label must be 0. If "id" is a
        # per-annotation id rather than a class id, switch to the class column.
        out = transform(image=pixels, bboxes=objects["bbox"], category=objects["id"])

        annotations = []
        # Take boxes and labels from the transform output so they stay aligned
        # even when the transform drops a box. Recompute the area from the
        # transformed box: the dataset's "area" refers to the pre-resize image.
        for category, box in zip(out["category"], out["bboxes"]):
            box = list(box)
            annotations.append(
                {
                    "image_id": image_id,
                    "category_id": category,
                    "iscrowd": 0,  # COCO key is lowercase; "isCrowd" was ignored
                    "area": float(box[2]) * float(box[3]),
                    "bbox": box,
                }
            )

        images.append(out["image"])
        targets.append({"image_id": image_id, "annotations": annotations})

    return processor(images=images, annotations=targets, return_tensors="pt")

In [23]:
# Encode lazily: `with_transform` runs `datapipe` on every read, so the random
# augmentation is re-sampled each time an example is loaded.
train_data = data["train"].with_transform(datapipe)
# No validation split is used (a data["validation"] pipeline was left disabled);
# the test split doubles as the per-epoch eval set passed to the Trainer.
# NOTE(review): the test split goes through the same `datapipe`, i.e. the same
# random training augmentation (`aug`), so eval loss is measured on augmented
# images. Consider a deterministic eval transform (e.g. Resize only).
test_data = data["test"].with_transform(datapipe)

In [24]:
def collate_fn(batch):
    """Merge examples produced by ``datapipe`` into one DETR training batch.

    ``preprocessor.pad`` pads every ``pixel_values`` tensor to a common size and
    supplies the matching ``pixel_mask``. ``labels`` are kept as a plain list
    (one target dict per image) rather than stacked.
    """
    padded = preprocessor.pad(
        [example["pixel_values"] for example in batch], return_tensors="pt"
    )
    return {
        "pixel_values": padded["pixel_values"],
        "pixel_mask": padded["pixel_mask"],
        "labels": [example["labels"] for example in batch],
    }

In [None]:
# Peek at one encoded training example without dumping the full pixel tensor:
# tensors are shown by shape, everything else (e.g. the labels dict) as-is.
# Indexing re-runs `datapipe`, so the random augmentation differs per access.
train_sample = train_data[0]
{
    name: tuple(value.shape) if hasattr(value, "shape") else value
    for name, value in train_sample.items()
}

In [None]:
# Single-class detector. label2id is derived from id2label so the two maps can
# never disagree. ignore_mismatched_sizes=True lets the checkpoint load even
# though its pretrained COCO classification head does not match a 1-label head.
id2label = {0: "container"}
label2id = {name: idx for idx, name in id2label.items()}

model = AutoModelForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    ignore_mismatched_sizes=True,
    id2label=id2label,
    label2id=label2id,
)

In [29]:
# Optimization and checkpointing settings. Evaluation runs once per epoch and is
# loss-only (no compute_metrics is passed to the Trainer).
# remove_unused_columns=False is required: the raw "image"/"objects" columns
# must reach `datapipe`, which runs lazily through `with_transform`.
training_args = TrainingArguments(
    output_dir="detr-resnet-50-container-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    learning_rate=1e-5,
    weight_decay=1e-4,
    eval_strategy="epoch",
    save_total_limit=2,
    remove_unused_columns=False,
)

# NOTE(review): `tokenizer=` is deprecated in recent transformers releases in
# favour of `processing_class=`; switch once the installed version supports it.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=collate_fn,
    tokenizer=preprocessor,
)

In [None]:
# Launch fine-tuning; the object returned by `train()` is shown as the cell output.
trainer.train()