In [1]:

import requests, zipfile, io

def download_data(url="https://www.dropbox.com/s/l1e45oht447053f/ADE20k_toy_dataset.zip?dl=1",
                  extract_to=None, timeout=60):
    """Download a zip archive over HTTP and unpack it.

    Args:
        url (string): Location of the zip archive. Defaults to the ADE20k toy dataset.
        extract_to (string, optional): Directory to extract into. ``None`` (the default)
            extracts into the current working directory.
        timeout (float): Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the downloaded payload is not a valid zip archive.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on an HTTP error instead of handing an error page to ZipFile.
    response.raise_for_status()
    # The whole archive is held in memory; the context manager releases it afterwards.
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall(path=extract_to)

# Fetch the toy dataset archive and unpack it (extractall is called without a target path).
download_data()

In [2]:
from datasets import load_dataset

# Set to True to additionally load the "scene_parse_150" dataset through `load_dataset`.
load_entire_dataset = False

if load_entire_dataset:
    # NOTE(review): `dataset` only exists when the flag is True, and it is not used by
    # any of the cells visible here — confirm it is still needed.
    dataset = load_dataset("scene_parse_150")

In [3]:
from torch.utils.data import Dataset
import os
from PIL import Image

class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset.

    Expects the layout ``<root_dir>/images/<split>/...`` and
    ``<root_dir>/annotations/<split>/...`` where ``<split>`` is ``training`` or
    ``validation``. Images and annotations are paired by their position in the
    sorted file listings, so both trees must mirror each other.
    """

    def __init__(self, root_dir, feature_extractor, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            feature_extractor (SegFormerFeatureExtractor): feature extractor to prepare images + segmentation maps.
            train (bool): Whether to load "training" or "validation" images + annotations.

        Raises:
            AssertionError: If the number of images and annotations differ.
        """
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.train = train

        sub_path = "training" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, "images", sub_path)
        self.ann_dir = os.path.join(self.root_dir, "annotations", sub_path)

        # Paths are stored relative to img_dir / ann_dir so that files living in
        # nested folders can still be opened in __getitem__.
        self.images = self._list_files(self.img_dir)
        self.annotations = self._list_files(self.ann_dir)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths, relative to `directory`, of every file below it."""
        relative_paths = []
        for current_dir, _, file_names in os.walk(directory):
            prefix = os.path.relpath(current_dir, directory)
            for file_name in file_names:
                # Files directly inside `directory` keep their bare name.
                relative_paths.append(file_name if prefix == "." else os.path.join(prefix, file_name))
        return sorted(relative_paths)

    def __len__(self):
        """Return the number of image / annotation pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair `idx` and return the feature extractor's output without the batch dimension."""
        image_path = os.path.join(self.img_dir, self.images[idx])
        annotation_path = os.path.join(self.ann_dir, self.annotations[idx])

        # Open both files in a context manager so the file handles are released
        # as soon as the feature extractor has consumed the pixel data.
        with Image.open(image_path) as image, Image.open(annotation_path) as segmentation_map:
            # let the feature extractor turn image + segmentation map into tensors
            encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        for k, v in encoded_inputs.items():
            # remove only the leading batch dimension; a bare squeeze_() would also
            # drop any other dimension that happens to have size 1
            v.squeeze_(0)

        return encoded_inputs

In [4]:
from transformers import SegformerFeatureExtractor

# Folder holding the toy dataset, relative to the current working directory.
# presumably this is the top-level folder of the archive unpacked by download_data() — verify
root_dir = './ADE20k_toy_dataset'
# reduce_labels=True: the training cell later passes reduce_labels=False to the metric
# because the labels have already been reduced at this stage.
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)

# Both splits share the same feature extractor; `train` defaults to True for the first one.
train_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor)
valid_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor, train=False)



In [5]:

# report how many examples each split contains
print(f"Number of training examples: {len(train_dataset)}")
print(f"Number of validation examples: {len(valid_dataset)}")

Number of training examples: 10
Number of validation examples: 10


In [6]:

# Fetch the encoded example at index 0 of the training set.
encoded_inputs = train_dataset[0]
     
# The bare expressions below only inspect the example; they do not change any state.
# NOTE(review): these look like they were separate cells originally — inside a single
# cell only the final expression is echoed.
encoded_inputs["pixel_values"].shape
     
encoded_inputs["labels"].shape
     
encoded_inputs["labels"]
     
# Distinct values that occur in this example's label map.
encoded_inputs["labels"].squeeze().unique()

from torch.utils.data import DataLoader

# Batches of 2 examples; only the training loader is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=2)

# Pull a single batch to sanity-check what the loader yields.
batch = next(iter(train_dataloader))
     

# Print the shape of every tensor in the batch.
for k,v in batch.items():
    print(k, v.shape)
     

batch["labels"].shape

# Boolean mask of all label positions that are not 255
# (255 is the value passed as ignore_index to the metric in the training cell).
mask = (batch["labels"] != 255)
mask



# Label values at the positions kept by the mask.
batch["labels"][mask]

pixel_values torch.Size([2, 3, 512, 512])
labels torch.Size([2, 512, 512])


tensor([0, 0, 0,  ..., 3, 3, 3])

In [8]:
from transformers import SegformerForSemanticSegmentation
import json
from huggingface_hub import cached_download, hf_hub_url

# load id2label mapping from a local JSON file
# (an earlier version fetched it from the hub instead:
#  repo_id = "datasets/huggingface/label-files"
#  json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r")))
filename = "ade20k-id2label.json"
# use a context manager so the file handle is closed once the JSON is parsed
with open(filename, "r") as label_file:
    id2label = json.load(label_file)
# JSON object keys are strings; convert them to integer class ids
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

# define model
# num_labels is derived from the mapping (150 for the ADE20k label file) so the
# classification head always agrees with the num_labels later passed to the metric
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0",
                                                         num_labels=len(id2label),
                                                         id2label=id2label,
                                                         label2id=label2id,
)


from datasets import load_metric

# Metric object used in the training cell (add_batch / compute) to report mean IoU and mean accuracy.
# NOTE(review): `datasets.load_metric` is reported as deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package — confirm against the installed version.
metric = load_metric("mean_iou")

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.1.proj.bias', 'decode_head.classi

Downloading builder script:   0%|          | 0.00/3.14k [00:00<?, ?B/s]

In [None]:
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

# training hyper-parameters, kept together so they are easy to find and tune
LEARNING_RATE = 0.00006
NUM_EPOCHS = 200
LOG_EVERY_N_BATCHES = 100
CHECKPOINT_DIR = "./checkpoint"

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# move model to GPU (falls back to CPU when CUDA is not available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# torch.save does not create missing parent directories; without this the first
# checkpoint write fails on a fresh run (`os` is imported in the dataset cell)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

model.train()
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for idx, batch in enumerate(tqdm(train_dataloader)):
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        # evaluate
        with torch.no_grad():
            # logits come out at a lower resolution than the labels, so upsample them first
            upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)

            # note that the metric expects predictions + labels as numpy arrays
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())

        # let's print loss and metrics every LOG_EVERY_N_BATCHES batches
        # NOTE(review): the recorded run shows 5 batches per epoch, so this only fires at
        # idx == 0; presumably compute() then covers every batch added since the previous
        # compute() call — confirm this is the intended reporting window.
        if idx % LOG_EVERY_N_BATCHES == 0:
            metrics = metric.compute(num_labels=len(id2label),
                                     ignore_index=255,
                                     reduce_labels=False,  # we've already reduced the labels before
            )

            print("Loss:", loss.item())
            print("Mean_iou:", metrics["mean_iou"])
            print("Mean accuracy:", metrics["mean_accuracy"])

    # one checkpoint per epoch; the whole model object is pickled (format kept unchanged)
    torch.save(model, f"{CHECKPOINT_DIR}/model_{epoch}.pth")

Epoch: 0


  0%|          | 0/5 [00:00<?, ?it/s]

  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label


Loss: 5.083634853363037
Mean_iou: 0.000872726704809543
Mean accuracy: 0.02625071979286301
Epoch: 1


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.9570817947387695
Mean_iou: 0.0020488594530478576
Mean accuracy: 0.02228172344407478
Epoch: 2


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.785323143005371
Mean_iou: 0.0063430898239593015
Mean accuracy: 0.04676395855403611
Epoch: 3


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.602604866027832
Mean_iou: 0.014031503001674968
Mean accuracy: 0.10649824172317988
Epoch: 4


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.463755130767822
Mean_iou: 0.018681527606579894
Mean accuracy: 0.12115054003402756
Epoch: 5


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.264898777008057
Mean_iou: 0.03830783951840777
Mean accuracy: 0.17892253403300518
Epoch: 6


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.144379615783691
Mean_iou: 0.04317706407551832
Mean accuracy: 0.1700403152027156
Epoch: 7


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.1956071853637695
Mean_iou: 0.04118160799578239
Mean accuracy: 0.16185505252966206
Epoch: 8


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.8265485763549805
Mean_iou: 0.10915512843322746
Mean accuracy: 0.2553750304089508
Epoch: 9


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.9455580711364746
Mean_iou: 0.05513144236416093
Mean accuracy: 0.1903076201941269
Epoch: 10


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.9794440269470215
Mean_iou: 0.0972939497069342
Mean accuracy: 0.22521938509826214
Epoch: 11


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.5631840229034424
Mean_iou: 0.12991140997449352
Mean accuracy: 0.2031704879514114
Epoch: 12


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.6403746604919434
Mean_iou: 0.1588597536767722
Mean accuracy: 0.2445338007389535
Epoch: 13


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.0155792236328125
Mean_iou: 0.1764191782220241
Mean accuracy: 0.25050622398015127
Epoch: 14


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.99212384223938
Mean_iou: 0.16389293532289204
Mean accuracy: 0.2699191512693911
Epoch: 15


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.047090530395508
Mean_iou: 0.20902907785865532
Mean accuracy: 0.26924018558379786
Epoch: 16


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.8073089122772217
Mean_iou: 0.23869955253160394
Mean accuracy: 0.3124341089297942
Epoch: 17


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.5908772945404053
Mean_iou: 0.1771115542447049
Mean accuracy: 0.23823296937920926
Epoch: 18


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.375248670578003
Mean_iou: 0.2169632421626907
Mean accuracy: 0.29718072692124964
Epoch: 19


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.5620577335357666
Mean_iou: 0.1991143184847925
Mean accuracy: 0.25425310149786506
Epoch: 20


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.8127763271331787
Mean_iou: 0.18012757905852955
Mean accuracy: 0.24873413288410248
Epoch: 21


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.598235607147217
Mean_iou: 0.23228515051820794
Mean accuracy: 0.2792293507884741
Epoch: 22


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.4259865283966064
Mean_iou: 0.24344832399227115
Mean accuracy: 0.30447836968262465
Epoch: 23


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.501638889312744
Mean_iou: 0.19661824950280965
Mean accuracy: 0.23714354054907186
Epoch: 24


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.688905715942383
Mean_iou: 0.2277175576557373
Mean accuracy: 0.272707188180831
Epoch: 25


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.3967463970184326
Mean_iou: 0.24239703511116592
Mean accuracy: 0.30046512233539135
Epoch: 26


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.43829345703125
Mean_iou: 0.29615899878556123
Mean accuracy: 0.34215922629480217
Epoch: 27


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.1405038833618164
Mean_iou: 0.24143547877673488
Mean accuracy: 0.2863028685084338
Epoch: 28


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.3272511959075928
Mean_iou: 0.25665873198636524
Mean accuracy: 0.30161425344717985
Epoch: 29


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.322385311126709
Mean_iou: 0.2754421134240847
Mean accuracy: 0.32687974575472695
Epoch: 30


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.3545970916748047
Mean_iou: 0.32393720417416194
Mean accuracy: 0.3768557119365437
Epoch: 31


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.350829839706421
Mean_iou: 0.22320799456963988
Mean accuracy: 0.2735001358673426


In [None]:

# Build the path from `root_dir` instead of a hard-coded absolute Colab path, so the
# cell also works when the dataset was extracted somewhere other than /content.
image = Image.open(os.path.join(root_dir, "images", "training", "ADE_train_00000001.jpg"))
image

In [None]:
# run the image through the feature extractor and move the resulting tensor
# to the same device as the model
encoding = feature_extractor(image, return_tensors="pt")
pixel_values = encoding["pixel_values"].to(device)
print(pixel_values.shape)

In [9]:

# forward pass
# The training cell leaves the model in train() mode. Switch to eval() so that
# train-only behaviour (e.g. dropout, if the model has any) is disabled, and skip
# autograd bookkeeping since no gradients are needed for inference.
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

# logits are of shape (batch_size, num_labels, height/4, width/4)
logits = outputs.logits.cpu()
print(logits.shape)

def ade_palette():
    """ADE20K palette that maps each class to RGB values.

    Returns:
        list: 150 ``[R, G, B]`` triplets. The position of a triplet in the list is the
        class id it colours (callers iterate with ``enumerate`` to get the label),
        so the order of the entries must not be changed. A new list is built on
        every call.
    """
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]


NameError: name 'model' is not defined

In [None]:
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# Step 1: bring the logits back to the resolution of the input image.
# PIL reports size as (width, height), hence the reversal to (height, width).
upsampled_logits = nn.functional.interpolate(
    logits, size=image.size[::-1], mode='bilinear', align_corners=False
)

# Step 2: pick the highest-scoring class per pixel for the first (only) item of the batch
seg = upsampled_logits.argmax(dim=1)[0]

# Step 3: paint every pixel with the palette colour of its predicted class
color_seg = np.zeros((*seg.shape, 3), dtype=np.uint8)  # height, width, 3
palette = np.array(ade_palette())
for label, color in enumerate(palette):
    color_seg[seg == label, :] = color
# Reverse the channel order (RGB -> BGR)
# NOTE(review): the result is blended with an RGB image and shown by matplotlib, so this
# flip looks like a leftover from OpenCV-based code — confirm it is intended.
color_seg = color_seg[:, :, ::-1]

# Step 4: blend image and mask 50/50 and display the result
img = (np.array(image) * 0.5 + color_seg * 0.5).astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

In [None]:

# Build the path from `root_dir` instead of a hard-coded absolute Colab path, so the
# cell also works when the dataset was extracted somewhere other than /content.
# NOTE(review): `map` shadows the Python builtin; the name is kept because the next cell reads it.
map = Image.open(os.path.join(root_dir, "annotations", "training", "ADE_train_00000001.png"))
map

In [None]:
# Load the annotation as a NumPy array.
# It is re-read from disk rather than converted from the `map` variable left by the
# previous cell: the label reduction below overwrites `map`, so starting from the
# already reduced array on a second run of this cell would shift every class id again.
annotation_path = os.path.join(root_dir, "annotations", "training", "ADE_train_00000001.png")
with Image.open(annotation_path) as annotation_image:
    map = np.array(annotation_image)
map[map == 0] = 255 # background class is replaced by ignore_index
map = map - 1 # other classes are reduced by one
map[map == 254] = 255 # the former background (255 - 1) goes back to ignore_index

# translate the class ids found in this image into names; 255 (ignored) has no name
classes_map = np.unique(map).tolist()
unique_classes = [model.config.id2label[idx] if idx!=255 else None for idx in classes_map]
print("Classes in this image:", unique_classes)

# create coloured map; pixels whose value matches no palette index (e.g. 255) stay black
color_seg = np.zeros((map.shape[0], map.shape[1], 3), dtype=np.uint8) # height, width, 3
palette = np.array(ade_palette())
for label, color in enumerate(palette):
    color_seg[map == label, :] = color
# Convert to BGR
# NOTE(review): the result is blended with an RGB image and shown by matplotlib, so this
# flip looks like a leftover from OpenCV-based code — confirm it is intended.
color_seg = color_seg[..., ::-1]

# Show image + mask
img = np.array(image) * 0.5 + color_seg * 0.5
img = img.astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()