Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and simplify semantic-segmentation example #30145

Merged
merged 9 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/pytorch/_tests_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ torchaudio
jiwer
librosa
evaluate >= 0.2.0
albumentations
5 changes: 4 additions & 1 deletion examples/pytorch/semantic-segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ The script leverages the [🤗 Trainer API](https://huggingface.co/docs/transfor

Here we show how to fine-tune a [SegFormer](https://huggingface.co/nvidia/mit-b0) model on the [segments/sidewalk-semantic](https://huggingface.co/datasets/segments/sidewalk-semantic) dataset:

In order to use `segments/sidewalk-semantic`:
- Log in to Hugging Face with `huggingface-cli login` (token can be accessed [here](https://huggingface.co/settings/tokens)).
- Accept terms of use for `sidewalk-semantic` on [dataset page](https://huggingface.co/datasets/segments/sidewalk-semantic).

```bash
python run_semantic_segmentation.py \
--model_name_or_path nvidia/mit-b0 \
Expand All @@ -105,7 +109,6 @@ python run_semantic_segmentation.py \
--remove_unused_columns False \
--do_train \
--do_eval \
--evaluation_strategy steps \
--push_to_hub \
--push_to_hub_model_id segformer-finetuned-sidewalk-10k-steps \
--max_steps 10000 \
Expand Down
6 changes: 4 additions & 2 deletions examples/pytorch/semantic-segmentation/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
git://github.com/huggingface/accelerate.git
datasets >= 2.0.0
torch >= 1.3
evaluate
accelerate
evaluate
Pillow
albumentations
210 changes: 54 additions & 156 deletions examples/pytorch/semantic-segmentation/run_semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@
import json
import logging
import os
import random
import sys
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import Optional

import albumentations as A
import evaluate
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import functional

import transformers
from transformers import (
Expand All @@ -57,118 +56,19 @@
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")


def pad_if_smaller(img, size, fill=0):
size = (size, size) if isinstance(size, int) else size
original_width, original_height = img.size
pad_height = size[1] - original_height if original_height < size[1] else 0
pad_width = size[0] - original_width if original_width < size[0] else 0
img = functional.pad(img, (0, 0, pad_width, pad_height), fill=fill)
return img
def reduce_labels_transform(labels: np.ndarray, **kwargs) -> np.ndarray:
"""Set `0` label as with value 255 and then reduce all other labels by 1.

Example:
Initial class labels: 0 - background; 1 - road; 2 - car;
Transformed class labels: 255 - background; 0 - road; 1 - car;

class Compose:
def __init__(self, transforms):
self.transforms = transforms

def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target


class Identity:
def __init__(self):
pass

def __call__(self, image, target):
return image, target


class Resize:
def __init__(self, size):
self.size = size

def __call__(self, image, target):
image = functional.resize(image, self.size)
target = functional.resize(target, self.size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target


class RandomResize:
def __init__(self, min_size, max_size=None):
self.min_size = min_size
if max_size is None:
max_size = min_size
self.max_size = max_size

def __call__(self, image, target):
size = random.randint(self.min_size, self.max_size)
image = functional.resize(image, size)
target = functional.resize(target, size, interpolation=transforms.InterpolationMode.NEAREST)
return image, target


class RandomCrop:
def __init__(self, size):
self.size = size if isinstance(size, tuple) else (size, size)

def __call__(self, image, target):
image = pad_if_smaller(image, self.size)
target = pad_if_smaller(target, self.size, fill=255)
crop_params = transforms.RandomCrop.get_params(image, self.size)
image = functional.crop(image, *crop_params)
target = functional.crop(target, *crop_params)
return image, target


class RandomHorizontalFlip:
def __init__(self, flip_prob):
self.flip_prob = flip_prob

def __call__(self, image, target):
if random.random() < self.flip_prob:
image = functional.hflip(image)
target = functional.hflip(target)
return image, target


class PILToTensor:
def __call__(self, image, target):
image = functional.pil_to_tensor(image)
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target


class ConvertImageDtype:
def __init__(self, dtype):
self.dtype = dtype

def __call__(self, image, target):
image = functional.convert_image_dtype(image, self.dtype)
return image, target


class Normalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std

def __call__(self, image, target):
image = functional.normalize(image, mean=self.mean, std=self.std)
return image, target


class ReduceLabels:
def __call__(self, image, target):
if not isinstance(target, np.ndarray):
target = np.array(target).astype(np.uint8)
# avoid using underflow conversion
target[target == 0] = 255
target = target - 1
target[target == 254] = 255

target = Image.fromarray(target)
return image, target
**kwargs are required to use this function with albumentations.
"""
labels[labels == 0] = 255
labels = labels - 1
labels[labels == 254] = 255
return labels


@dataclass
Expand Down Expand Up @@ -365,7 +265,7 @@ def main():
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: str(k) for k, v in id2label.items()}

# Load the mean IoU metric from the datasets package
# Load the mean IoU metric from the evaluate package
metric = evaluate.load("mean_iou", cache_dir=model_args.cache_dir)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
Expand Down Expand Up @@ -424,64 +324,62 @@ def compute_metrics(eval_pred):
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# `reduce_labels` is a property of dataset labels, in case we use image_processor
# pretrained on another dataset we should override the default setting
image_processor.do_reduce_labels = data_args.reduce_labels

# Define torchvision transforms to be applied to each image + target.
# Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
# Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
# Define transforms to be applied to each image and target.
if "shortest_edge" in image_processor.size:
# We instead set the target size as (shortest_edge, shortest_edge) to here to ensure all images are batchable.
size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
height, width = image_processor.size["shortest_edge"], image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
train_transforms = Compose(
height, width = image_processor.size["height"], image_processor.size["width"]
train_transforms = A.Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
RandomCrop(size=size),
RandomHorizontalFlip(flip_prob=0.5),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(
name="reduce_labels",
mask=reduce_labels_transform if data_args.reduce_labels else None,
p=1.0,
),
# pad image with 255, because it is ignored by loss
A.PadIfNeeded(min_height=height, min_width=width, border_mode=0, value=255, p=1.0),
A.RandomCrop(height=height, width=width, p=1.0),
A.HorizontalFlip(p=0.5),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)
# Define torchvision transform to be applied to each image.
# jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
val_transforms = Compose(
val_transforms = A.Compose(
[
ReduceLabels() if data_args.reduce_labels else Identity(),
Resize(size=size),
PILToTensor(),
ConvertImageDtype(torch.float),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
A.Lambda(
name="reduce_labels",
mask=reduce_labels_transform if data_args.reduce_labels else None,
p=1.0,
),
A.Resize(height=height, width=width, p=1.0),
A.Normalize(mean=image_processor.image_mean, std=image_processor.image_std, max_pixel_value=255.0, p=1.0),
ToTensorV2(),
]
)

def preprocess_train(example_batch):
def preprocess_batch(example_batch, transforms: A.Compose):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = train_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)
transformed = transforms(image=np.array(image.convert("RGB")), mask=np.array(target))
pixel_values.append(transformed["image"])
labels.append(transformed["mask"])

encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)
encoding["pixel_values"] = torch.stack(pixel_values).to(torch.float)
encoding["labels"] = torch.stack(labels).to(torch.long)

return encoding

def preprocess_val(example_batch):
pixel_values = []
labels = []
for image, target in zip(example_batch["image"], example_batch["label"]):
image, target = val_transforms(image.convert("RGB"), target)
pixel_values.append(image)
labels.append(target)

encoding = {}
encoding["pixel_values"] = torch.stack(pixel_values)
encoding["labels"] = torch.stack(labels)

return encoding
# Preprocess function for dataset should have only one argument,
# so we use partial to pass the transforms
preprocess_train_batch_fn = partial(preprocess_batch, transforms=train_transforms)
preprocess_val_batch_fn = partial(preprocess_batch, transforms=val_transforms)

if training_args.do_train:
if "train" not in dataset:
Expand All @@ -491,7 +389,7 @@ def preprocess_val(example_batch):
dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
)
# Set the training transforms
dataset["train"].set_transform(preprocess_train)
dataset["train"].set_transform(preprocess_train_batch_fn)

if training_args.do_eval:
if "validation" not in dataset:
Expand All @@ -501,7 +399,7 @@ def preprocess_val(example_batch):
dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
)
# Set the validation transforms
dataset["validation"].set_transform(preprocess_val)
dataset["validation"].set_transform(preprocess_val_batch_fn)

# Initialize our trainer
trainer = Trainer(
Expand Down