In [1]:
import torchvision.transforms.v2 as T

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.tv_tensors import Mask
from torchvision.datasets import VOCSegmentation
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torch.utils.data import DataLoader, Dataset

In [3]:
import torchvision
torchvision.__version__

'0.18.0'

In [5]:
def f(x):
    print(type(x))
    x = x.clone()  # Ensure no in-place modifications
    x[x == 255] = 0
    return x


l = T.Lambda(f, Mask)

In [6]:
transform = T.Compose([
    T.ToImage(),
    l,
    T.RandomResizedCrop(size=(224, 224), antialias=True, scale=(0.5, 1.0)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [7]:
ds = VOCSegmentation(root="/Users/ts/Datasets/", year="2012", image_set="train", download=True, transforms=transform)
ds = wrap_dataset_for_transforms_v2(ds)

len(ds)

Using downloaded and verified file: /Users/ts/Datasets/VOCtrainval_11-May-2012.tar
Extracting /Users/ts/Datasets/VOCtrainval_11-May-2012.tar to /Users/ts/Datasets/


1464

In [8]:
img, mask = ds[0]
img.shape, mask.unique()

<class 'torchvision.tv_tensors._mask.Mask'>


(torch.Size([3, 224, 224]), tensor([ 0,  1, 15], dtype=torch.uint8))

In [31]:
img, mask = transform(img, mask)
img.shape, mask.shape

(torch.Size([3, 224, 224]), torch.Size([1, 224, 224]))

In [16]:
dl = DataLoader(ds, batch_size=2, shuffle=False, num_workers=2)

In [17]:
next(iter(dl))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 316, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 173, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 173, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 145, in collate
    return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 213, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
  File "/Users/ts/miniconda3/envs/new_ml/lib/python3.9/site-packages/torchvision/tv_tensors/_tv_tensor.py", line 77, in __torch_function__
    output = func(*args, **kwargs or dict())
RuntimeError: stack expects each tensor to be equal size, but got [1, 281, 500] at entry 0 and [1, 375, 500] at entry 1


In [3]:
out = torch.randn(8, 197, 768)

In [4]:
class SegmentationHead(nn.Module):
    def __init__(self,
                 in_channels: int,
                 width: int,
                 height: int,
                 num_classes: int,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.width = width
        self.height = height
        self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.reshape(-1, self.height, self.width, self.in_channels) # (bs, num_tokens, c) -> (bs, h, w, c)
        x = x.permute(0, 3, 1, 2) # (bs, h, w, c) -> (bs, c, h, w)
        return self.classifier(x) # (bs, c, h, w) -> (bs, num_classes, h, w)

head = SegmentationHead(in_channels=768, width=14, height=14, num_classes=2)

In [7]:
logits = head(out[:, 1:, :])
logits.shape

torch.Size([8, 2, 14, 14])

In [8]:
logits = F.interpolate(logits, size=(224, 224), mode="bilinear", align_corners=False)

In [11]:
logits.squeeze().shape

torch.Size([8, 2, 224, 224])