<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/MaskFormer/Fine-tuning/Fine_tuning_MaskFormerForInstanceSegmentation_on_semantic_sidewalk.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Load data

Now let's the dataset from the hub.

"But how can I use my own dataset?" Glad you asked. I wrote a detailed guide for that [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/semantic-segmentation#note-on-custom-data).

In [None]:
import yaml
arch_config = "config/arch/fusion-512-OV.yml"
data_cfg = "config/labels/semantic-kitti-OV.yaml"
data_dir = '../SemanticKITTI/dataset'
ARCH = yaml.safe_load(open(arch_config, 'r'))
DATA = yaml.safe_load(open(data_cfg, 'r'))

Let's take a look at this dataset in more detail. It consists of 1000 examples:

In [None]:
from dataset.kitti.parser import Parser
from transformers import Mask2FormerImageProcessor
import torch

preprocessor = Mask2FormerImageProcessor(ignore_index=0, reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)

def collate_fn(batch):
    actuall_input = []
    dummy = []
    labels = []
    for sample in batch:
        proj_data, rgb_data = sample
        in_vol, proj_mask, proj_labels = proj_data[0:3]
        actuall_input.append(in_vol)
        dummy.append(in_vol[:3,:,:])
        labels.append(proj_labels)
        
    batch = preprocessor(
        dummy,
        segmentation_maps=labels,
        return_tensors="pt",
    )
    actuall_input = torch.stack(actuall_input,dim=0)
    batch["pixel_values"] = actuall_input
    batch["original_images"] = actuall_input
    batch["original_segmentation_maps"] = labels
    
    return batch


dataset = Parser(root=data_dir,
                 train_sequences=DATA["split"]["train"],
                 valid_sequences=DATA["split"]["valid"],
                 test_sequences=None,
                 labels=DATA["labels"],
                 color_map=DATA["color_map"],
                 learning_map=DATA["learning_map"],
                 learning_map_inv=DATA["learning_map_inv"],
                 sensor=ARCH["dataset"]["sensor"],
                 max_points=ARCH["dataset"]["max_points"],
                 batch_size=ARCH["train"]["batch_size"],
                 workers=ARCH["train"]["workers"],
                 gt=True,
                 shuffle_train=True,
                 overfit=ARCH["train"]["overfit"],
                 share_subset_train=ARCH["train"]["share_subset_train"],
                 collate=collate_fn)


In [None]:
# # shuffle + split dataset
# dataset = dataset.shuffle(seed=1)
# dataset = dataset["train"].train_test_split(test_size=0.2)
# train_ds = dataset["train"]
# test_ds = dataset["test"]

In [None]:
# # let's look at one example (images are pretty high resolution)
# example = train_ds[1]
# image = example['pixel_values']
# image

In [None]:
# import numpy as np

# # load corresponding ground truth segmentation map, which includes a label per pixel
# segmentation_map = np.array(example['label'])
# segmentation_map

Let's look at the semantic categories in this particular example.

In [None]:
# np.unique(segmentation_map)

Cool, but we want to know the actual class names. For that we need the id2label mapping, which is hosted in a repo on the hub.

In [None]:
from huggingface_hub import hf_hub_download
import json

id2label = DATA["id2label"]
print(id2label)

In [None]:
# labels = [id2label[label] for label in np.unique(segmentation_map)]
# print(labels)

Let's visualize it:

In [None]:
def color_palette():
    """Color palette that maps each class to RGB values.
    
    This one is actually taken from ADE20k.
    """
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]

palette = DATA["color_map"]

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
# for label, color in enumerate(palette):
#     color_segmentation_map[segmentation_map - 1 == label, :] = color
# # Convert to BGR
# ground_truth_color_seg = color_segmentation_map[..., ::-1]

# img = np.array(image) * 0.5 + ground_truth_color_seg * 0.5
# img = img.astype(np.uint8)

# plt.figure(figsize=(15, 10))
# plt.imshow(img)
# plt.show()

## Create PyTorch Dataset

Next, we create a standard PyTorch dataset. Each item of the dataset consists of the image and corresponding ground truth segmentation map. We also include the original image + map (before preprocessing) in order to compute metrics like mIoU.

In [None]:
import numpy as np
from torch.utils.data import Dataset

class ImageSegmentationDataset(Dataset):
    """Image segmentation dataset."""

    def __init__(self, dataset, transform):
        """
        Args:
            dataset
        """
        self.dataset = dataset
        self.transform = transform
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        original_image = np.array(self.dataset[idx]['pixel_values'])
        original_segmentation_map = np.array(self.dataset[idx]['label'])
        
        transformed = self.transform(image=original_image, mask=original_segmentation_map)
        image, segmentation_map = transformed['image'], transformed['mask']

        # convert to C, H, W
        image = image.transpose(2,0,1)

        return image, segmentation_map, original_image, original_segmentation_map

The dataset accepts image transformations which can be applied on both the image and the map. Here we use Albumentations, to resize, randomly crop + flip and normalize them. Data augmentation is a widely used technique in computer vision to make the model more robust.

In [None]:
# import albumentations as A

# ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
# ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

# train_transform = A.Compose([
#     A.LongestMaxSize(max_size=1333),
#     A.RandomCrop(width=512, height=512),
#     A.HorizontalFlip(p=0.5),
#     A.Normalize(mean=ADE_MEAN, std=ADE_STD),
# ])

# test_transform = A.Compose([
#     A.Resize(width=512, height=512),
#     A.Normalize(mean=ADE_MEAN, std=ADE_STD),

# ])

# train_dataset = dataset.get_train_set()
# test_dataset = dataset.get_valid_set()

In [None]:
# image, segmentation_map, _, _ = train_dataset[0]
# print(image.shape)
# print(segmentation_map.shape)

A great way to check that our data augmentations are working well is by denormalizing the pixel values. So here we perform the inverse operation of Albumentations' normalize method and visualize the image:

In [None]:
# from PIL import Image

# unnormalized_image = (image * np.array(ADE_STD)[:, None, None]) + np.array(ADE_MEAN)[:, None, None]
# unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
# unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
# Image.fromarray(unnormalized_image)

This looks ok. Let's also verify whether the corresponding ground truth map is still ok.

In [None]:
# segmentation_map.shape

In [None]:
# labels = [id2label[label] for label in np.unique(segmentation_map)]
# print(labels)

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
# for label, color in enumerate(palette):
#     color_segmentation_map[segmentation_map == label, :] = color
# # Convert to BGR
# ground_truth_color_seg = color_segmentation_map[..., ::-1]

# img = np.moveaxis(image, 0, -1) * 0.5 + ground_truth_color_seg * 0.5
# img = img.astype(np.uint8)

# plt.figure(figsize=(15, 10))
# plt.imshow(img)
# plt.show()

Ok great!

## Create PyTorch DataLoaders

Next we create PyTorch DataLoaders, which allow us to get batches of the dataset. For that we define a custom so-called "collate function", which PyTorch allows you to do. It's in this function that we'll use the preprocessor of MaskFormer, to turn the images + maps into the format that MaskFormer expects.

It's here that we make the paradigm shift that the MaskFormer authors introduced: the "per-pixel" annotations of the segmentation map will be turned into a set of binary masks and corresponding labels. It's this format on which we can train MaskFormer. MaskFormer namely casts any image segmentation task to this format.

In [None]:
# from transformers import MaskFormerImageProcessor

# # Create a preprocessor
# preprocessor = MaskFormerImageProcessor(ignore_index=0, reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)

In [None]:
# from torch.utils.data import DataLoader

# def collate_fn(batch):
#     inputs = list(zip(*batch))
#     images = inputs[0]
#     segmentation_maps = inputs[1]
#     # this function pads the inputs to the same size,
#     # and creates a pixel mask
#     # actually padding isn't required here since we are cropping
#     batch = preprocessor(
#         images,
#         segmentation_maps=segmentation_maps,
#         return_tensors="pt",
#     )

#     batch["original_images"] = inputs[2]
#     batch["original_segmentation_maps"] = inputs[3]
    
#     return batch

train_dataloader = dataset.get_train_set()
test_dataloader = dataset.get_valid_set()

## Verify data (!!)

Next, it's ALWAYS very important to check whether the data you feed to the model actually makes sense. It's one of the main principles of [this amazing blog post](http://karpathy.github.io/2019/04/25/recipe/), if you wanna debug your neural networks.

Let's check the first batch, and its content.

In [None]:
import torch

batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k,v.shape)
  else:
    print(k,v[0].shape)

In [None]:
pixel_values = batch["pixel_values"][0].numpy()
pixel_values.shape

Again, let's denormalize an image and see what we got.

In [None]:
# unnormalized_image = (pixel_values * np.array(ADE_STD)[:, None, None]) + np.array(ADE_MEAN)[:, None, None]
# unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
# unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
# Image.fromarray(unnormalized_image)

Let's verify the corresponding binary masks + class labels.

In [None]:
# verify class labels
labels = [id2label[label] for label in batch["class_labels"][0].tolist()]
print(labels)

In [None]:
# verify mask labels
batch["mask_labels"][0].shape

In [None]:
import numpy as np
from PIL import Image
def visualize_mask(labels, label_name):
  print("Label:", label_name)
  idx = labels.index(label_name)

  visual_mask = (batch["mask_labels"][0][idx].bool().numpy() * 255).astype(np.uint8)
  return Image.fromarray(visual_mask)

In [None]:
visualize_mask(labels, "traffic-sign")

## Define model

Next, we define the model. We equip the model with pretrained weights from the 🤗 hub. We will replace only the classification head. For that we provide the id2label mapping, and specify to ignore mismatches keys to replace the already fine-tuned classification head.

In [None]:
import torch
import torch.nn as nn
from transformers import MaskFormerForInstanceSegmentation
class New2Former(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(5,3,1)
        self.bn = nn.BatchNorm2d(3)
        self.lrelu = nn.LeakyReLU(inplace=True)
        self.main = MaskFormerForInstanceSegmentation.from_pretrained("facebook/mask2former-swin-tiny-cityscapes-semantic",
                                                          id2label=id2label,
                                                          ignore_mismatched_sizes=True)
    
    def forward(self, pixel_values, class_labels, mask_labels):
        pixel_values = self.conv(pixel_values)
        pixel_values = self.bn(pixel_values)
        pixel_values = self.lrelu(pixel_values)

        x = self.main(pixel_values=pixel_values, class_labels=class_labels, mask_labels=mask_labels)
        return x

In [None]:

model = New2Former()

See also the warning here: it's telling us that we are only replacing the class_predictor, which makes sense. As it's the only parameters that we will train from scratch.

## Compute initial loss

Another good way to debug neural networks is to verify the initial loss, see if it makes sense.

In [None]:
outputs = model(batch["pixel_values"],
                class_labels=batch["class_labels"],
                mask_labels=batch["mask_labels"])

In [None]:
outputs.loss

## Train the model

It's time to train the model! We'll use the mIoU metric to track progress.

In [None]:
#!pip install -q evaluate

In [None]:
import evaluate

metric = evaluate.load("mean_iou")

In [33]:
import torch
from tqdm.auto import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

running_loss = 0.0
num_samples = 0
for epoch in range(100):
  print("Epoch:", epoch)
  model.train()
  for idx, batch in enumerate(tqdm(train_dataloader)):
      # Reset the parameter gradients
      optimizer.zero_grad()

      # Forward pass
      outputs = model(
          pixel_values=batch["pixel_values"].to(device),
          mask_labels=[labels.to(device) for labels in batch["mask_labels"]],
          class_labels=[labels.to(device) for labels in batch["class_labels"]],
      )

      # Backward propagation
      loss = outputs.loss
      loss.backward()

      batch_size = batch["pixel_values"].size(0)
      running_loss += loss.item()
      num_samples += batch_size

      if idx % 100 == 0:
        print("Loss:", running_loss/num_samples)

      # Optimization
      optimizer.step()

  model.eval()
  for idx, batch in enumerate(tqdm(test_dataloader)):
    if idx > 5:
      break

    pixel_values = batch["pixel_values"]
    
    # Forward pass
    with torch.no_grad():
      outputs = model(pixel_values=pixel_values.to(device), mask_labels=None, class_labels=None)

    # get original images
    original_images = batch["original_images"]
    target_sizes = [(image.shape[1], image.shape[2]) for image in original_images]
    # predict segmentation maps
    predicted_segmentation_maps = preprocessor.post_process_semantic_segmentation(outputs,
                                                                                  target_sizes=target_sizes)

    # get ground truth segmentation maps
    ground_truth_segmentation_maps = batch["original_segmentation_maps"]

    metric.add_batch(references=ground_truth_segmentation_maps, predictions=predicted_segmentation_maps)
  
  # NOTE this metric outputs a dict that also includes the mIoU per category as keys
  # so if you're interested, feel free to print them as well
  print("Mean IoU:", metric.compute(num_labels = len(id2label), ignore_index = 0)['mean_iou'])

Loss: 8.283796694851661


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 86


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.275933181645769


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 87


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.268375822992036


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 88


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.260364496752564


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 89


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.25353060828315


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 90


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.24765528165377


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 91


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.242518224577973


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 92


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.23660537418926


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 93


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.229419153632847


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 94


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.22374761481034


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 95


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.218466407722897


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 96


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.2134730873239


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 97


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.206566894946455


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 98


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.203398489390159


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113
Epoch: 99


  0%|          | 0/1 [00:00<?, ?it/s]

Loss: 8.197525908152262


  0%|          | 0/1 [00:00<?, ?it/s]

Mean IoU: 0.0005414909608642113


## Inference

After training, we can use the model to make predictions on new data.

Let's showcase this one of the examples of a test batch.

In [34]:
# let's take the first test batch
batch = next(iter(test_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k,v.shape)
  else:
    print(k,len(v))

pixel_values torch.Size([6, 5, 64, 512])
pixel_mask torch.Size([6, 64, 512])
mask_labels 6
class_labels 6
original_images torch.Size([6, 5, 64, 512])
original_segmentation_maps 6


In [36]:
# forward pass
with torch.no_grad():
  outputs = model(batch["pixel_values"].to(device), mask_labels=None, class_labels=None)

In [42]:
original_images = batch["original_images"]
target_sizes = [(image.shape[1], image.shape[2]) for image in original_images]
# predict segmentation maps
predicted_segmentation_maps = preprocessor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)

In [39]:
image = batch["original_images"][0]
Image.fromarray(image)

KeyboardInterrupt: 

In [44]:
import numpy as np
import matplotlib.pyplot as plt

segmentation_map = predicted_segmentation_maps[0].cpu().numpy()

color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
for label, color in enumerate(palette):
    color_segmentation_map[segmentation_map == label, :] = color
# Convert to BGR
ground_truth_color_seg = color_segmentation_map[..., ::-1]

img = ground_truth_color_seg * 0.5
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

Compare to the ground truth:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

segmentation_map = batch["original_segmentation_maps"][0]

color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
for label, color in enumerate(palette):
    color_segmentation_map[segmentation_map == label, :] = color
# Convert to BGR
ground_truth_color_seg = color_segmentation_map[..., ::-1]

img = image * 0.5 + ground_truth_color_seg * 0.5
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

I didn't do a lot of training (only 2 epochs), and results don't look too bad. I'd suggest checking the paper to find all details regarding training hyperparameters (number of epochs, learning rate, etc.).