Skip to content

Commit

Permalink
Fix and simplify semantic-segmentation example (#30145)
Browse files Browse the repository at this point in the history
* Remove unused augmentation

* Fix pad_if_smaller() and remove unused augmentation

* Add indentation

* Fix requirements

* Update dataset use instructions

* Replace transforms with albumentations

* Replace identity transform with None

* Fixing formatting

* Fixed comment place
  • Loading branch information
qubvel authored and ArthurZucker committed Apr 22, 2024
1 parent 3260549 commit 9e3aff0
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 318 deletions.
1 change: 1 addition & 0 deletions examples/pytorch/_tests_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ torchaudio
jiwer
librosa
evaluate >= 0.2.0
albumentations
5 changes: 4 additions & 1 deletion examples/pytorch/semantic-segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ The script leverages the [馃 Trainer API](https://huggingface.co/docs/transfor

Here we show how to fine-tune a [SegFormer](https://huggingface.co/nvidia/mit-b0) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset:

In order to use `segments/sidewalk-semantic`:
- Log in to Hugging Face with `huggingface-cli login` (token can be accessed [here](https://huggingface.co/settings/tokens)).
- Accept terms of use for `sidewalk-semantic` on [dataset page](https://huggingface.co/datasets/segments/sidewalk-semantic).

```bash
python run_semantic_segmentation.py \
--model_name_or_path nvidia/mit-b0 \
Expand All @@ -105,7 +109,6 @@ python run_semantic_segmentation.py \
--remove_unused_columns False \
--do_train \
--do_eval \
--evaluation_strategy steps \
--push_to_hub \
--push_to_hub_model_id segformer-finetuned-sidewalk-10k-steps \
--max_steps 10000 \
Expand Down
6 changes: 4 additions & 2 deletions examples/pytorch/semantic-segmentation/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
git://github.com/huggingface/accelerate.git
datasets >= 2.0.0
torch >= 1.3
evaluate
accelerate
evaluate
Pillow
albumentations
210 changes: 54 additions & 156 deletions examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@
import json
import logging
import os
import random
import sys
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import Optional

import albumentations as A
import evaluate
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import functional

import transformers
from transformers import (
Expand All @@ -57,118 +56,19 @@
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")


def pad_if_smaller(img, size, fill=0):
size = (size, size) if isinstance(size, int) else size
original_width, original_height = img.size
pad_height = size[1] - original_height if original_height < size[1] else 0
pad_width = size[0] - original_width if original_width < size[0] else 0
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
return img
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
"""Set `0` label as with value 255 and then reduce all other labels by 1.
Example:
Initial class labels: 0 - background; 1 - road; 2 - car;
Transformed class labels: 255 - background; 0 - road; 1 - car;
class Compose:
def __init__(self, transforms):
self.transforms = transforms

def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target


class Identity:
def __init__(self):
pass

def __call__(self, image, target):
return image, target


class Resize:
def __init__(self, size):
self.size = size

def __call__(self, image, target):
image = functional.resize(image, self.size)
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target


class RandomResize:
def __init__(self, min_size, max_size=None):
self.min_size = min_size
if max_size is None:
max_size = min_size
self.max_size = max_size

def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = functional.resize(image, size)
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target


class RandomCrop:
def __init__(self, size):
self.size = size if isinstance(size, tuple) else (size, size)

def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
crop_params = transforms.RandomCrop.get_params(image, self.size)
image = functional.crop(image, *crop_params)
target = functional.crop(target, *crop_params)
return image, target


class RandomHorizontalFlip:
def __init__(self, flip_prob):
self.flip_prob = flip_prob

def __call__(self, image, target):
if random.random() < self.flip_prob:
image = functional.hflip(image)
target = functional.hflip(target)
return image, target


class PILToTensor:
def __call__(self, image, target):
image = functional.pil_to_tensor(image)
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target


class ConvertImageDtype:
def __init__(self, dtype):
self.dtype = dtype

def __call__(self, image, target):
image = functional.convert_image_dtype(image, self.dtype)
return image, target


class Normalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std

def __call__(self, image, target):
image = functional.normalize(image, mean=self.mean, std=self.std)
return image, target


class ReduceLabels:
def __call__(self, image, target):
if not isinstance(target, np.ndarray):
target = np.array(target).astype(np.uint8)
# avoid using underflow conversion
target[target == 0] = 255
target = target - 1
target[target == 254] = 255

target = Image.fromarray(target)
return image, target
**kwargs are required to use this function with albumentations.
"""
labels[labels == 0] = 255
labels = labels - 1
labels[labels == 254] = 255
return labels


@dataclass
Expand Down Expand Up @@ -365,7 +265,7 @@ def main():
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: str(k) for k, v in id2label.items()}

# Load the mean IoU metric from the datasets package
# Load the mean IoU metric from the evaluate package
metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
Expand Down Expand Up @@ -424,64 +324,62 @@ def compute_metrics(eval_pred):
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# `reduce_labels` is a property of dataset labels, in case we use image_processor
# pretrained on another dataset we should override the default setting
image_processor.do_reduce_labels = data_args.reduce_labels

# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
# Define transforms to be applied to each image and target.
if "shortest_edge" in image_processor.size:
# We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
train_transforms = Compose(
height, width = image_processor.size["height"], image_processor.size["width"]
train_transforms = A.Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(
name="reduce_labels",
mask=reduce_labels_transform if data_args.reduce_labels else None,
p=1.0,
),
# pad image with 255, because it is ignored by loss
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
A.RandomCrop(height=height, width=width, p=1.0),
A.HorizontalFlip(p=0.5),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)
# Define torchvision transform to be applied to each image.
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
val_transforms = Compose(
val_transforms = A.Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(
name="reduce_labels",
mask=reduce_labels_transform if data_args.reduce_labels else None,
p=1.0,
),
A.Resize(height=height, width=width, p=1.0),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)

def preprocess_train(example_batch):
def preprocess_batch(example_batch, transforms: A.Compose):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = train_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
pixel_values.append(transformed["image"])
labels.append(transformed["mask"])

encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
encoding["labels"] = torch.stack(labels).to(torch.long)

return encoding

def preprocess_val(example_batch):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = val_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)

encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)

return encoding
# Preprocess function for dataset should have only one argument,
# so we use partial to pass the transforms
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)

if training_args.do_train:
if "train" not in dataset:
Expand All @@ -491,7 +389,7 @@ def preprocess_val(example_batch):
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
# Set the training transforms
dataset["train"].set_transform(preprocess_train)
dataset["train"].set_transform(preprocess_train_batch_fn)

if training_args.do_eval:
if "validation" not in dataset:
Expand All @@ -501,7 +399,7 @@ def preprocess_val(example_batch):
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# Set the validation transforms
dataset["validation"].set_transform(preprocess_val)
dataset["validation"].set_transform(preprocess_val_batch_fn)

# Initialize our trainer
trainer = Trainer(
Expand Down

0 comments on commit 9e3aff0

Please sign in to comment.