
# Transforms v2: End-to-end object detection/segmentation example

<div class="alert alert-info"><h4>Note</h4><p>Try the original torchvision example on [Colab](https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_transforms_e2e.ipynb).</p></div>

Object detection and segmentation tasks are natively supported:
``torchvision.transforms.v2`` enables jointly transforming images, videos,
bounding boxes, and masks.

This example showcases an end-to-end object detection training case using
Torchvision utils from ``torchvision.datasets``, ``torchvision.models`` and
``torchvision.transforms.v2``. Everything covered here can be applied similarly
to instance segmentation or semantic segmentation tasks.


In [1]:
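# Raw image dimensions in pixels. PIL reports these images as size=3984x5312
# (width x height), so HEIGHT actually holds the shorter side and WIDTH the longer one.
HEIGHT = 3984
WIDTH  = 5312

# Downscaling factor for the shorter side, used by v2.Resize below
SCALE = 0.75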
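When given a single integer, ``v2.Resize`` resizes the shorter side of each image to that size. It then scales the longer side to preserve the aspect ratio. The sketch below works through that arithmetic for the images in this dataset in plain Python. ``resized_size`` is a hypothetical helper that mirrors this rule under our assumptions. It is not torchvision's own code, and the library's rounding may differ in edge cases.

```python
def resized_size(width, height, size):
    """Return (new_width, new_height) after resizing the shorter side to `size`.

    Hypothetical helper mirroring the aspect-preserving rule of v2.Resize(size)
    when `size` is a single int.
    """
    short, long = (width, height) if width <= height else (height, width)
    new_short, new_long = size, int(size * long / short)
    if width <= height:
        return new_short, new_long
    return new_long, new_short


# The raw images are 3984 px wide and 5312 px tall (PIL reports size=3984x5312)
WIDTH_PX, HEIGHT_PX = 3984, 5312
SCALE = 0.75
new_short_side = int(min(WIDTH_PX, HEIGHT_PX) * SCALE)
print(new_short_side, resized_size(WIDTH_PX, HEIGHT_PX, new_short_side))  # 2988 (2988, 3984)
```

For these images, ``int(3984 * 0.75)`` gives 2988. The longer side becomes exactly 3984 px, which is 75% of 5312.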

In [2]:
#!pip install torch torchvision matplotlib tqdm pycocotools

In [3]:
import pathlib
import os

import torch
import torch.utils.data

from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

from tqdm import tqdm

torch.manual_seed(0)

# This loads the flight263 dataset (COCO format) from a local folder. Point ROOT
# at your own data if you are adapting this example.
# If you're running this on Colab, you can download the helpers from
# https://github.com/pytorch/vision/tree/main/gallery/
ROOT = pathlib.Path("../data") / "flight263_COCO"
IMAGES_PATH = str(ROOT / "img")
RAW_ANNOTATIONS_PATH = str(ROOT / "annotations" / "instances_default.json")
# Filtered copy of the annotations (only images with at least one box), created below
ANNOTATIONS_PATH = ROOT / "annotations" / "instances_annotated.json"
from helpers import plot

## Dataset preparation

We start off by loading the ``torchvision.datasets.CocoDetection`` dataset to have a look at what it currently
returns.



In [4]:
coco_dataset = datasets.CocoDetection(IMAGES_PATH, RAW_ANNOTATIONS_PATH)

sample = coco_dataset[25]  # Pick an image that has annotations: not all images do
print(sample)
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


(<PIL.Image.Image image mode=RGB size=3984x5312 at 0x7F0DC20ECD00>, [{'id': 1, 'image_id': 26, 'category_id': 1, 'segmentation': [], 'area': 335.8352000000005, 'bbox': [106.13, 2526.16, 17.24, 19.48], 'iscrowd': 0, 'attributes': {'text': '?', 'bg_color': 'orange', 'txt_color': 'white', 'shape': 'circle', 'occluded': False, 'rotation': 0.0, 'track_id': 0, 'keyframe': True}}])
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'list'>
type(target[0]) = <class 'dict'>
target[0].keys() = dict_keys(['id', 'image_id', 'category_id', 'segmentation', 'area', 'bbox', 'iscrowd', 'attributes'])


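Not every image has annotations, so we create a subset containing only the images that have at least one bounding box annotation. The filtered annotation file is written once and reused on later runs: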

In [5]:
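def create_annotated_subset():
    """Write a copy of the annotation file that only lists images with at least one annotation."""
    import json

    # Query the pycocotools index of the dataset instead of iterating over
    # coco_dataset, which would decode every (large) image just to read its targets.
    # Using image ids directly also avoids assuming that image id == index + 1.
    coco = coco_dataset.coco
    annotated_ids = {img_id for img_id in coco_dataset.ids if len(coco.getAnnIds(imgIds=img_id)) > 0}
    print(f"{len(annotated_ids)} images have annotations")

    with open(RAW_ANNOTATIONS_PATH, "r") as f:
        instances = json.load(f)

    instances["images"] = [x for x in instances["images"] if x["id"] in annotated_ids]

    with open(ANNOTATIONS_PATH, "w") as f:
        json.dump(instances, f)


if not os.path.isfile(ANNOTATIONS_PATH):
    create_annotated_subset()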
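The filtering step itself only needs the parsed JSON. Collect the ``image_id`` of every annotation, then keep the images whose ``id`` is in that set. Below is a minimal stdlib sketch on a made-up, COCO-style dict. ``keep_annotated_images`` and the toy contents are hypothetical.

```python
import json


def keep_annotated_images(instances):
    """Return a copy of a COCO-style dict whose "images" list only keeps
    images referenced by at least one entry in "annotations"."""
    annotated_ids = {ann["image_id"] for ann in instances["annotations"]}
    filtered = dict(instances)  # shallow copy: the other keys are shared
    filtered["images"] = [img for img in instances["images"] if img["id"] in annotated_ids]
    return filtered


# Tiny made-up COCO-style dict, for illustration only
toy = {
    "images": [
        {"id": 1, "file_name": "a.jpg"},
        {"id": 2, "file_name": "b.jpg"},
        {"id": 26, "file_name": "c.jpg"},
    ],
    "annotations": [
        {"id": 1, "image_id": 26, "category_id": 1, "bbox": [106.13, 2526.16, 17.24, 19.48]},
    ],
    "categories": [{"id": 1, "name": "object"}],
}
subset = keep_annotated_images(toy)
print(json.dumps([img["id"] for img in subset["images"]]))  # [26]
```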

In [6]:
coco_dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)

# TODO: create sliding window as `transforms` arg

sample = coco_dataset[0]  # After filtering, every image has at least one annotation
print(sample)
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
(<PIL.Image.Image image mode=RGB size=3984x5312 at 0x7F0EC4AB3A60>, [{'id': 1, 'image_id': 26, 'category_id': 1, 'segmentation': [], 'area': 335.8352000000005, 'bbox': [106.13, 2526.16, 17.24, 19.48], 'iscrowd': 0, 'attributes': {'text': '?', 'bg_color': 'orange', 'txt_color': 'white', 'shape': 'circle', 'occluded': False, 'rotation': 0.0, 'track_id': 0, 'keyframe': True}}])
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'list'>
type(target[0]) = <class 'dict'>
target[0].keys() = dict_keys(['id', 'image_id', 'category_id', 'segmentation', 'area', 'bbox', 'iscrowd', 'attributes'])


Torchvision datasets preserve the data structure and types as they were intended
by the dataset authors. So by default, the output structure may not always be
compatible with the models or the transforms.

To overcome that, we can use the
``torchvision.datasets.wrap_dataset_for_transforms_v2`` function. For
``torchvision.datasets.CocoDetection``, this changes the target
structure from a list of per-object dictionaries to a single dictionary with one entry per requested key:



In [7]:
dataset = datasets.wrap_dataset_for_transforms_v2(coco_dataset, target_keys=("boxes", "labels"))

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }")

type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'dict'>
target.keys() = dict_keys(['boxes', 'labels'])
type(target['boxes']) = <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>
type(target['labels']) = <class 'torch.Tensor'>


We used the ``target_keys`` parameter to specify the kind of output we're
interested in. Our dataset now returns a target that is a dict whose values
are tensors: ``boxes`` is a ``BoundingBoxes`` TVTensor (a ``torch.Tensor``
subclass), and ``labels`` is a plain ``torch.Tensor``. We've dropped all unnecessary
keys from the previous output, but if you need any of the original keys, e.g.
``"image_id"``, you can still ask for them.
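For ``CocoDetection``, the wrapper regroups the per-object annotation dicts into one target dict. It also converts COCO's ``[x, y, width, height]`` boxes into ``[x1, y1, x2, y2]`` corners and returns them as a ``BoundingBoxes`` TVTensor. The plain-Python sketch below shows only that regrouping and box conversion. ``coco_to_detection_target`` is a hypothetical name; the real wrapper returns tensors and handles more keys.

```python
def coco_to_detection_target(annotations):
    """Hypothetical helper: turn a list of COCO annotation dicts into a single
    dict with XYXY boxes and integer labels."""
    boxes, labels = [], []
    for ann in annotations:
        x, y, w, h = ann["bbox"]  # COCO boxes are [x_min, y_min, width, height]
        boxes.append([x, y, x + w, y + h])  # -> [x_min, y_min, x_max, y_max]
        labels.append(ann["category_id"])
    return {"boxes": boxes, "labels": labels}


# Same values as the first annotation printed above
wrapped_target = coco_to_detection_target(
    [{"id": 1, "image_id": 26, "category_id": 1, "bbox": [106.13, 2526.16, 17.24, 19.48]}]
)
print(wrapped_target)
```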

<div class="alert alert-info"><h4>Note</h4><p>If you just want to do detection, you don't need to pass
    "masks" in ``target_keys``, and you shouldn't: if masks are present in the sample, they will
    be transformed, slowing down your transformations unnecessarily.</p></div>

As a baseline, let's have a look at a sample without transformations:



In [8]:
import helpers
# helpers.plot([dataset[0], dataset[10]])

## Transforms

Let's now define our pre-processing transforms. All the transforms know how
to handle images, bounding boxes and masks when relevant.

Transforms are typically passed as the ``transforms`` parameter of the
dataset so that they can leverage multi-processing from the
``torch.utils.data.DataLoader`` (when ``num_workers > 0``).



In [9]:
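transforms = v2.Compose(
    [
        v2.ToImage(),  # PIL image -> tv_tensors.Image
        v2.Resize(int(HEIGHT * SCALE)),  # resize the shorter side to 2988 px
        # v2.RandomPhotometricDistort(p=1),
        v2.RandomPerspective(distortion_scale=0.6, p=1.0),
        v2.RandomRotation(degrees=(0, 180)),
        # Fill the padded area of the image with a fixed RGB color, and other TVTensors (e.g. masks) with 0
        v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
        v2.RandomIoUCrop(),
        v2.RandomHorizontalFlip(p=1),  # always flip (p=1), for illustration
        v2.SanitizeBoundingBoxes(),  # drop degenerate boxes and their labels
        v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
    ]
)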

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels"])

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


A few things are worth noting here:

- We're converting the PIL image into a
  ``torchvision.tv_tensors.Image`` object. This isn't strictly
  necessary, but relying on Tensors (here: a Tensor subclass) will
  generally be faster.
- We are calling ``torchvision.transforms.v2.SanitizeBoundingBoxes`` to
  make sure we remove degenerate bounding boxes, as well as their
  corresponding labels and masks.
  ``SanitizeBoundingBoxes`` should be placed
  at least once at the end of a detection pipeline; it is particularly
  critical if ``torchvision.transforms.v2.RandomIoUCrop`` was used.
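The plain-Python stand-in below makes the role of ``SanitizeBoundingBoxes`` concrete. It assumes XYXY boxes and implements only the minimum-size check; the real transform performs additional checks, for example against the image boundaries. ``sanitize_boxes`` is a hypothetical name, and ``min_size=1.0`` is assumed to match the transform's default.

```python
def sanitize_boxes(boxes, labels, min_size=1.0):
    """Simplified stand-in for v2.SanitizeBoundingBoxes on XYXY boxes.

    Drops boxes whose width or height is below `min_size` (this includes
    degenerate boxes with x2 <= x1 or y2 <= y1) and drops the matching labels.
    """
    keep = [
        i for i, (x1, y1, x2, y2) in enumerate(boxes)
        if (x2 - x1) >= min_size and (y2 - y1) >= min_size
    ]
    return [boxes[i] for i in keep], [labels[i] for i in keep]


boxes = [
    [10.0, 10.0, 50.0, 40.0],  # valid box
    [30.0, 30.0, 30.0, 60.0],  # zero width, e.g. after an aggressive crop
    [70.0, 20.0, 70.5, 90.0],  # thinner than min_size
]
labels = [1, 2, 3]
print(sanitize_boxes(boxes, labels))  # ([[10.0, 10.0, 50.0, 40.0]], [1])
```

Keeping boxes and labels in sync is the important part: dropping a box without its label would misalign every target after it.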

Let's look at how the samples look with our augmentation pipeline in place:



In [10]:
# helpers.plot([dataset[25], dataset[28]])

With this pipeline, the images are perspective-warped, rotated, zoomed out, cropped, and flipped, and the bounding boxes are transformed accordingly.
And without any further ado, we can start training.
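For a horizontal flip, transforming the boxes accordingly is simple arithmetic. In an image of width W, an XYXY box ``[x1, y1, x2, y2]`` becomes ``[W - x2, y1, W - x1, y2]``. Here is a minimal sketch; ``hflip_box`` is a hypothetical helper, not the torchvision API.

```python
def hflip_box(box, image_width):
    """Mirror an XYXY box horizontally inside an image of width `image_width`."""
    x1, y1, x2, y2 = box
    # The left edge becomes the mirrored right edge and vice versa; y is unchanged
    return [image_width - x2, y1, image_width - x1, y2]


print(hflip_box([10, 20, 50, 80], image_width=100))  # [50, 20, 90, 80]
```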

## Data loading and training loop

Below we're using SSDLite (``ssdlite320_mobilenet_v3_large``), which is an object detection model, but
everything we've covered in this tutorial also applies to instance segmentation and
semantic segmentation tasks.



In [11]:
# Drop any `train_one_epoch` name left over from an earlier run of this notebook
try:
    del train_one_epoch
except NameError:
    pass

In [12]:
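# engine.py (and the utils.py, coco_utils.py and coco_eval.py files it imports) come from
# torchvision's detection references:
# https://github.com/pytorch/vision/tree/main/references/detection
from engine import train_one_epoch, evaluate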

In [13]:
# Note: both splits share the augmented `dataset`, so evaluation also sees random
# augmentations; build a separate dataset with deterministic transforms for a cleaner evaluation.
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])


def collate_fn(batch):
    # We need a custom collation function here, since the object detection
    # models expect a sequence of images and target dictionaries. The default
    # collation function tries to torch.stack() the individual elements,
    # which fails in general for object detection, because the number of bounding
    # boxes varies between the images of the same batch.
    return tuple(zip(*batch))


data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,  # reshuffle the training images every epoch
    drop_last=True,  # drop the last incomplete batch
    collate_fn=collate_fn,
)

data_loader_test = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    # Keep every test image: dropping the last batch would skip part of the evaluation
    collate_fn=collate_fn,
)
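The collation function above transposes a list of ``(image, target)`` pairs into a tuple of images and a tuple of targets. The toy illustration below shows this in plain Python, with strings standing in for image tensors. ``collate_pairs`` and the toy batch are hypothetical.

```python
def collate_pairs(batch):
    """Turn [(img0, tgt0), (img1, tgt1)] into ((img0, img1), (tgt0, tgt1))."""
    return tuple(zip(*batch))


# The two targets hold a different number of boxes, which is why they cannot be stacked
batch = [
    ("img0", {"boxes": [[0, 0, 10, 10]], "labels": [1]}),
    ("img1", {"boxes": [[0, 0, 5, 5], [2, 2, 8, 8]], "labels": [1, 1]}),
]
images, targets = collate_pairs(batch)
print(images)                              # ('img0', 'img1')
print([len(t["boxes"]) for t in targets])  # [1, 2]
```

Unlike a ``lambda``, a module-level function can be pickled. This matters if you later set ``num_workers > 0`` with the ``spawn`` multiprocessing start method, which is the default on Windows and macOS.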

In [14]:
# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = None

In [15]:
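import gc

# Free a model from a previous run before building a new one
if model is not None:
    del model
gc.collect()
torch.cuda.empty_cache()  # no-op if CUDA was never initialized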

In [16]:
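# num_classes defaults to 91 (the COCO label space, including background); pass
# num_classes=... to match the categories in your annotations plus 1 for background
model = models.get_model("ssdlite320_mobilenet_v3_large", weights=None, weights_backbone=None)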

# !pip install -U ultralytics
# model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
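One consequence is worth checking. Assume, as in torchvision's ``ssdlite320_mobilenet_v3_large`` builder, that the model internally resizes every input to a fixed 320x320. Under that assumption, the small objects in this dataset shrink to a pixel or two. The back-of-the-envelope sketch below uses the first annotation printed earlier and ignores the random augmentations. ``scale_box_to_fixed_size`` is a hypothetical helper.

```python
def scale_box_to_fixed_size(box_xywh, image_size, target_size=(320, 320)):
    """Scale a COCO [x, y, w, h] box from an image of (width, height) `image_size`
    to an image resized to exactly `target_size` (width, height), ignoring aspect ratio."""
    x, y, w, h = box_xywh
    sx = target_size[0] / image_size[0]
    sy = target_size[1] / image_size[1]
    return [x * sx, y * sy, w * sx, h * sy]


# First annotation printed earlier, on a 3984x5312 (width x height) image
_, _, w, h = scale_box_to_fixed_size([106.13, 2526.16, 17.24, 19.48], (3984, 5312))
print(f"{w:.2f} x {h:.2f} px")  # 1.38 x 1.17 px
```

The intermediate ``v2.Resize`` does not change this result, because the final fixed-size resize is relative to whatever size reaches the model.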

In [17]:
# move model to the right device
model.to(device)

SSD(
  (backbone): SSDLiteFeatureExtractorMobileNet(
    (features): Sequential(
      (0): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
          (2): Hardswish()
        )
        (1): InvertedResidual(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Conv2dNormActivation(
              (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            )
          )
        )
        (2): Invert

In [19]:
params = [p for p in model.parameters() if p.requires_grad]
# Adam has no `momentum` argument (it uses `betas` instead)
optimizer = torch.optim.Adam(
    params,
    lr=0.005,
    weight_decay=0.0005
)

# and a learning rate scheduler that multiplies the learning rate by 0.1 every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

num_epochs = 20

# PYTORCH_CUDA_ALLOC_CONF is read when the CUDA caching allocator is initialized, so
# setting it here, after CUDA has been used (e.g. by model.to(device) on a GPU),
# likely has no effect; set it before importing torch for it to apply.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

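``StepLR`` multiplies the learning rate by ``gamma`` every ``step_size`` calls to ``lr_scheduler.step()``. With one call per epoch, the rate in epoch ``e`` (0-based) is therefore ``lr * gamma ** (e // step_size)``. The small sketch below computes the resulting schedule with the values used above. ``step_lr`` is a hypothetical helper that computes this closed form.

```python
def step_lr(base_lr, epoch, step_size=3, gamma=0.1):
    """Learning rate that StepLR yields during `epoch` (0-based) when
    lr_scheduler.step() is called once at the end of every epoch."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 3, 6, 9, 19):
    print(epoch, f"{step_lr(0.005, epoch):.0e}")
```

With ``lr=0.005``, ``step_size=3`` and ``gamma=0.1``, the rate is already down to about 5e-09 from epoch 18 on. Most of the 20 epochs therefore run at a very small learning rate, so a larger ``step_size`` may be worth trying for a run this long.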

In [None]:
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
print(torch.cuda.memory_summary())

## Training References

From there, you can check out the [torchvision references](https://github.com/pytorch/vision/tree/main/references) where you'll find
the actual training scripts we use to train our models.

**Disclaimer:** The code in our references is more complex than what you'll
need for your own use-cases: this is because we're supporting different
backends (PIL, tensors, TVTensors) and different transforms namespaces (v1 and
v2). So don't be afraid to simplify and only keep what you need.

