This example illustrates everything you need to know to get started with the new `torchvision.transforms.v2` API. We’ll cover simple tasks like image classification, as well as more advanced ones like object detection and segmentation.

First, a bit of setup


In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from torchvision.transforms import v2
from torchvision.io import read_image
from helpers import plot

In [None]:
# Trim surrounding whitespace when matplotlib figures are saved.
plt.rcParams["savefig.bbox"] = 'tight'

# Seed torch's global RNG so the random transforms below are reproducible
# from run to run.
torch.manual_seed(1)

# If you're running this on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
# NOTE: '../assets' is resolved relative to the current working directory,
# so run the notebook from a sibling directory of 'assets'.
img = read_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")

# If you just care about image classification, things are very simple: call a
# transform directly on the image. A basic classification pipeline chains several
# transforms with v2.Compose, as in the second example below:


In [None]:
# A v2 transform is a callable: build it once, then call it on the image tensor.
# RandomCrop takes a randomly positioned crop of the requested size (224, 224).
transform = v2.RandomCrop(size=(224, 224))
out = transform(img)

# Display the original and the transformed image with the gallery `plot` helper.
plot([img, out])

In [None]:
# A typical classification augmentation pipeline: Compose applies the listed
# transforms one after another, in order.
transforms = v2.Compose([
    # Random crop (random scale / aspect ratio) resized to 224x224;
    # antialias=True smooths the resize.
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    # Flip left-right with probability 0.5.
    v2.RandomHorizontalFlip(p=0.5),
    # Convert to float32; scale=True rescales the integer pixel range to [0, 1]
    # (img is expected to be uint8 as returned by read_image — see the dtype
    # printed in the setup cell).
    v2.ToDtype(torch.float32, scale=True),
    # Per-channel (x - mean) / std. These mean/std values are the ImageNet
    # statistics commonly used with torchvision's pretrained models.
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)

# `out` is a normalized float tensor; how it is rendered is up to the helper.
plot([img, out])

# Detection, Segmentation, Videos


In [None]:
from torchvision import tv_tensors

# Wrap raw box coordinates in a BoundingBoxes TVTensor so v2 transforms can
# recognize them as boxes. format="XYXY" means each row is [x1, y1, x2, y2];
# canvas_size is the last two dims of img (expected to be (height, width)).
boxes = tv_tensors.BoundingBoxes(
    [
        [15, 10, 370, 510],
        [275, 340, 510, 510],
        [130, 345, 210, 425]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

# p=1 makes the photometric distortion and the horizontal flip always run,
# so their effect is visible in the plot.
transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
# Several inputs can be passed to one call; they come back in the same order.
# Per the torchvision v2 docs, geometric transforms (crop, flip) update the
# boxes consistently with the image, while photometric ones affect only images.
out_img, out_boxes = transforms(img, boxes)
# Expected to show that the output boxes are still a BoundingBoxes instance.
print(type(boxes), type(out_boxes))

# Each (image, boxes) tuple is drawn by the gallery `plot` helper.
plot([(img, boxes), (out_img, out_boxes)])

TVTensors are `torch.Tensor` subclasses. The available TVTensors are `Image`, `BoundingBoxes`, `Mask`, and `Video`.

TVTensors look and feel just like regular tensors — they *are* tensors. Everything that is supported on a plain `torch.Tensor`, like `.sum()` or any `torch.*` operator, will also work on a TVTensor:


In [5]:
# Build an Image TVTensor from random uint8 values in [0, 255] (randint's upper
# bound 256 is exclusive) with shape (3, 256, 256).
img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))

# An Image is still a torch.Tensor, so ordinary tensor attributes and ops work.
print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")

isinstance(img_dp, torch.Tensor) = True
img_dp.dtype = torch.uint8, img_dp.shape = torch.Size([3, 256, 256]), img_dp.sum() = tensor(25151131)
