In [1]:

import requests, zipfile, io

def download_data():
    """Download the ADE20k toy dataset archive and unpack it into the working directory.

    The zip payload is fetched into memory and extracted relative to the
    current working directory.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond in time.
        zipfile.BadZipFile: if the downloaded payload is not a valid zip archive.
    """
    url = "https://www.dropbox.com/s/l1e45oht447053f/ADE20k_toy_dataset.zip?dl=1"
    # Without a timeout a stalled connection would block the cell indefinitely.
    response = requests.get(url, timeout=60)
    # Surface HTTP failures here rather than as a confusing BadZipFile raised
    # when an error page is handed to zipfile below.
    response.raise_for_status()
    # The context manager closes the archive once extraction has finished.
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall()

download_data()

In [2]:
from datasets import load_dataset

# Toggle: set to True to additionally load the complete "scene_parse_150"
# dataset through `load_dataset`. The dataset cells visible in this notebook
# read the toy dataset from disk instead.
load_entire_dataset = False

# `dataset` is only defined when the toggle above is True.
if load_entire_dataset:
    dataset = load_dataset("scene_parse_150")

In [3]:
from torch.utils.data import Dataset
import os
from PIL import Image

class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset.

    Expects the on-disk layout ``<root_dir>/images/<split>`` and
    ``<root_dir>/annotations/<split>`` where ``<split>`` is ``"training"`` or
    ``"validation"``. Images and annotations are paired purely by their
    position in the sorted file listings, so both folders must use matching
    naming schemes.
    """

    def __init__(self, root_dir, feature_extractor, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            feature_extractor (SegFormerFeatureExtractor): feature extractor to prepare images + segmentation maps.
            train (bool): Whether to load "training" or "validation" images + annotations.

        Raises:
            AssertionError: if the number of images differs from the number of annotations.
        """
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.train = train

        sub_path = "training" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, "images", sub_path)
        self.ann_dir = os.path.join(self.root_dir, "annotations", sub_path)

        # Sorted listings; entries are paths relative to img_dir / ann_dir.
        self.images = self._list_files(self.img_dir)
        self.annotations = self._list_files(self.ann_dir)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of all files below ``directory``, relative to it.

        Files directly inside ``directory`` are listed by their bare name.
        Files in sub-folders keep their relative sub-path so that joining an
        entry with ``directory`` always yields an existing file (a bare file
        name would not be resolvable for nested files).
        """
        relative_paths = []
        for current_dir, _, file_names in os.walk(directory):
            for file_name in file_names:
                full_path = os.path.join(current_dir, file_name)
                relative_paths.append(os.path.relpath(full_path, directory))
        return sorted(relative_paths)

    def __len__(self):
        """Return the number of image/annotation pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load and preprocess the ``idx``-th image/annotation pair.

        Returns:
            The feature extractor's output for this single example with the
            leading batch dimension removed from every tensor.
        """
        image = Image.open(os.path.join(self.img_dir, self.images[idx]))
        segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx]))

        # Let the feature extractor jointly preprocess the image and its
        # segmentation map; the exact transformations depend on its configuration.
        encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        # Remove only the leading batch dimension. A bare squeeze_() would also
        # drop any other axis that happens to have size 1.
        for _, tensor in encoded_inputs.items():
            tensor.squeeze_(0)

        return encoded_inputs

In [4]:
from transformers import SegformerFeatureExtractor

# Dataset root, relative to the working directory (presumably the folder
# produced by unpacking the downloaded archive — verify the folder name).
root_dir = './ADE20k_toy_dataset'
# NOTE(review): reduce_labels=True presumably shifts the annotation ids so that
# the background class maps to the ignore index 255 used further below —
# confirm against the transformers documentation.
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)

# Both splits share the same feature extractor; `train` selects the sub-folder.
train_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor)
valid_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor, train=False)



In [5]:

# Sanity check: report how many examples each split holds.
num_train_examples = len(train_dataset)
num_valid_examples = len(valid_dataset)
print("Number of training examples:", num_train_examples)
print("Number of validation examples:", num_valid_examples)

Number of training examples: 10
Number of validation examples: 10


In [6]:

# Fetch the first training example; this runs SemanticSegmentationDataset.__getitem__.
encoded_inputs = train_dataset[0]
     
# NOTE(review): a notebook cell only renders its last bare expression — if the
# inspection lines below really share one cell, the intermediate ones show nothing.
# Shape of the preprocessed image tensor.
encoded_inputs["pixel_values"].shape
     
# Shape of the preprocessed segmentation map.
encoded_inputs["labels"].shape
     
# Raw label tensor.
encoded_inputs["labels"]
     
# Distinct label ids present in this example.
encoded_inputs["labels"].squeeze().unique()

from torch.utils.data import DataLoader

# Batches of two examples; only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=2)

# Pull a single batch to inspect what the model will receive.
batch = next(iter(train_dataloader))
     

# Print the name and shape of every tensor in the batch.
for k,v in batch.items():
    print(k, v.shape)
     

batch["labels"].shape

# Boolean mask of the pixels whose label is not 255 (255 is the ignore index
# passed to the metric in the training cell).
mask = (batch["labels"] != 255)
mask



# Label ids of the non-ignored pixels, flattened by the boolean indexing.
batch["labels"][mask]

pixel_values torch.Size([2, 3, 512, 512])
labels torch.Size([2, 512, 512])


tensor([ 0,  0,  0,  ..., 12, 12, 12])

In [7]:
from transformers import SegformerForSemanticSegmentation
import json
from huggingface_hub import cached_download, hf_hub_url

# Load the id -> label mapping from a local JSON file in the working directory.
# (The same file name was previously fetched from the hub repo
# "datasets/huggingface/label-files"; that download path is disabled here.)
filename = "ade20k-id2label.json"
# Use a context manager so the file handle is closed deterministically.
with open(filename, "r") as label_file:
    id2label = json.load(label_file)
# JSON object keys are always strings; convert them to integer class ids.
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

# Define the model on top of the "nvidia/mit-b0" checkpoint.
# num_labels is derived from the mapping (the training cell uses
# len(id2label) as well) so the two can never drift apart.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)


from datasets import load_metric

# Metric object used by the training loop: predictions are accumulated with
# `add_batch` and "mean_iou" / "mean_accuracy" are read from `compute`.
metric = load_metric("mean_iou")

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.3.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.weight', 'decode_head.batch_norm.bias', 'decode_head.linear_c.0

Downloading builder script:   0%|          | 0.00/3.14k [00:00<?, ?B/s]

In [9]:
import torch
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

# define optimizer
# --- training configuration -------------------------------------------------
NUM_EPOCHS = 200
LEARNING_RATE = 0.00006
LOG_EVERY_N_BATCHES = 100
CHECKPOINT_DIR = "./checkpoint"

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first epoch finishes. `os` is imported
# in the dataset cell this loop already depends on.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# move model to GPU (falls back to CPU when CUDA is unavailable)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.train()
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for idx, batch in enumerate(tqdm(train_dataloader)):
        # get the inputs
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()

        # evaluate on the batch that was just used for the update
        with torch.no_grad():
            # upsample the logits to the label resolution before taking the argmax
            upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)

            # note that the metric expects predictions + labels as numpy arrays
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())

        # Print loss and metrics every LOG_EVERY_N_BATCHES batches. With the
        # 5 batches per epoch shown in the progress bars only idx == 0 triggers
        # a report.
        # NOTE(review): batches added after a report are presumably carried
        # into the next compute() call — confirm the metric's reset semantics.
        if idx % LOG_EVERY_N_BATCHES == 0:
            metrics = metric.compute(
                num_labels=len(id2label),
                ignore_index=255,
                reduce_labels=False,  # we've already reduced the labels before
            )

            print("Loss:", loss.item())
            print("Mean_iou:", metrics["mean_iou"])
            print("Mean accuracy:", metrics["mean_accuracy"])

    # one full-model checkpoint per epoch
    torch.save(model, f"{CHECKPOINT_DIR}/model_{epoch}.pth")

Epoch: 0


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.885988712310791
Mean_iou: 0.001108338343730259
Mean accuracy: 0.007746149899091702
Epoch: 1


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.725132942199707
Mean_iou: 0.0035554620313158866
Mean accuracy: 0.02243139804877473
Epoch: 2


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.603793621063232
Mean_iou: 0.013006979638812165
Mean accuracy: 0.08189280203857704
Epoch: 3


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.550928592681885
Mean_iou: 0.020197999324659904
Mean accuracy: 0.11286524757610408
Epoch: 4


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.420210838317871
Mean_iou: 0.0397075922937465
Mean accuracy: 0.22925212784175808
Epoch: 5


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.125049114227295
Mean_iou: 0.056225572204735404
Mean accuracy: 0.19038358429400679
Epoch: 6


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 4.238358020782471
Mean_iou: 0.050585504611021744
Mean accuracy: 0.18904903192378045
Epoch: 7


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.8766393661499023
Mean_iou: 0.08223577672379523
Mean accuracy: 0.2223413107594349
Epoch: 8


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.74088454246521
Mean_iou: 0.12084609717843582
Mean accuracy: 0.28016342249081805
Epoch: 9


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.7442092895507812
Mean_iou: 0.09054087183235143
Mean accuracy: 0.22926905553495763
Epoch: 10


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.2657930850982666
Mean_iou: 0.12049164913609708
Mean accuracy: 0.24076603326830504
Epoch: 11


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.5445048809051514
Mean_iou: 0.13711914695490687
Mean accuracy: 0.23568209193768977
Epoch: 12


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.1043272018432617
Mean_iou: 0.14489756880587998
Mean accuracy: 0.23125018071904604
Epoch: 13


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.751434803009033
Mean_iou: 0.1658224108766598
Mean accuracy: 0.2355966239613236
Epoch: 14


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.453922748565674
Mean_iou: 0.1574993253813144
Mean accuracy: 0.24486848440289447
Epoch: 15


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.5949971675872803
Mean_iou: 0.1878514006056526
Mean accuracy: 0.2602655258110511
Epoch: 16


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 3.3972370624542236
Mean_iou: 0.17154587587142012
Mean accuracy: 0.23321836567115956
Epoch: 17


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.2737514972686768
Mean_iou: 0.2246116615567543
Mean accuracy: 0.29524238707057476
Epoch: 18


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.8995683193206787
Mean_iou: 0.18935672912824503
Mean accuracy: 0.25460998487176784
Epoch: 19


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.8046059608459473
Mean_iou: 0.23823017644689004
Mean accuracy: 0.30432960133057735
Epoch: 20


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.626376152038574
Mean_iou: 0.2316868751862082
Mean accuracy: 0.27534809402898275
Epoch: 21


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.898878574371338
Mean_iou: 0.27537975770708184
Mean accuracy: 0.3303293380466333
Epoch: 22


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.5656933784484863
Mean_iou: 0.22629232928692816
Mean accuracy: 0.2815270038635133
Epoch: 23


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.9332070350646973
Mean_iou: 0.21588899800614755
Mean accuracy: 0.2581436900221101
Epoch: 24


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.569089412689209
Mean_iou: 0.2228533244866479
Mean accuracy: 0.276717665287277
Epoch: 25


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.6637920141220093
Mean_iou: 0.22088933049185705
Mean accuracy: 0.2669222012829067
Epoch: 26


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.4255850315093994
Mean_iou: 0.25115846712771617
Mean accuracy: 0.2979287488593476
Epoch: 27


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.6918587684631348
Mean_iou: 0.30426687231782185
Mean accuracy: 0.3718831956498892
Epoch: 28


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.328789234161377
Mean_iou: 0.22589665415568652
Mean accuracy: 0.2650873353045408
Epoch: 29


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.4222352504730225
Mean_iou: 0.2899298104557454
Mean accuracy: 0.33710182185716914
Epoch: 30


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.335547685623169
Mean_iou: 0.2506809646499852
Mean accuracy: 0.29530766089037375
Epoch: 31


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.118636131286621
Mean_iou: 0.27139679302644787
Mean accuracy: 0.3215277187268468
Epoch: 32


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.729643702507019
Mean_iou: 0.26324055138976715
Mean accuracy: 0.30241518365240866
Epoch: 33


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.4611603021621704
Mean_iou: 0.26169075352768734
Mean accuracy: 0.30573594468310306
Epoch: 34


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.1289539337158203
Mean_iou: 0.2584820240998209
Mean accuracy: 0.3103010620276924
Epoch: 35


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.583867073059082
Mean_iou: 0.23762090832467242
Mean accuracy: 0.2899661585552459
Epoch: 36


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.7704905271530151
Mean_iou: 0.2712068375174568
Mean accuracy: 0.3183047214643409
Epoch: 37


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.9638242721557617
Mean_iou: 0.31933281571171185
Mean accuracy: 0.35910146449171454
Epoch: 38


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.221874475479126
Mean_iou: 0.27404923778751833
Mean accuracy: 0.3124910606918728
Epoch: 39


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.8076571226119995
Mean_iou: 0.28036677424163065
Mean accuracy: 0.3153583880432331
Epoch: 40


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.7410295009613037
Mean_iou: 0.3035600400562499
Mean accuracy: 0.3521177715503467
Epoch: 41


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 2.160517930984497
Mean_iou: 0.2758054471971794
Mean accuracy: 0.3209752453382516
Epoch: 42


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.4704872369766235
Mean_iou: 0.2257247716154136
Mean accuracy: 0.27229220221763734
Epoch: 43


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.8195880651474
Mean_iou: 0.30275460887729555
Mean accuracy: 0.3393982305110408
Epoch: 44


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.59140944480896
Mean_iou: 0.2758471200072617
Mean accuracy: 0.313178614259412
Epoch: 45


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.7049239873886108
Mean_iou: 0.31531932225535214
Mean accuracy: 0.3641385589242523
Epoch: 46


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.7563910484313965
Mean_iou: 0.3482639193251015
Mean accuracy: 0.38864452266260824
Epoch: 47


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.694391131401062
Mean_iou: 0.2490568513549926
Mean accuracy: 0.29899251160929907
Epoch: 48


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.8144384622573853
Mean_iou: 0.38888497357404955
Mean accuracy: 0.4290960350461796
Epoch: 49


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.9885122776031494
Mean_iou: 0.27809102306176675
Mean accuracy: 0.3364765994114571
Epoch: 50


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.9192452430725098
Mean_iou: 0.2780875430319819
Mean accuracy: 0.32603817936328855
Epoch: 51


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.296571969985962
Mean_iou: 0.2744061903047479
Mean accuracy: 0.3372810288311168
Epoch: 52


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.6289998292922974
Mean_iou: 0.33910265063565476
Mean accuracy: 0.37990413925953403
Epoch: 53


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.7797770500183105
Mean_iou: 0.25034768600100143
Mean accuracy: 0.30434878306788554
Epoch: 54


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.5958081483840942
Mean_iou: 0.4747811130369007
Mean accuracy: 0.5150435017351512
Epoch: 55


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.7234823703765869
Mean_iou: 0.23734603973106824
Mean accuracy: 0.2889895763646194
Epoch: 56


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.1009719371795654
Mean_iou: 0.2868178930360583
Mean accuracy: 0.3329585205094886
Epoch: 57


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.0785696506500244
Mean_iou: 0.34472014218954844
Mean accuracy: 0.3786752529466086
Epoch: 58


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.1374708414077759
Mean_iou: 0.30664960867304386
Mean accuracy: 0.339309948687433
Epoch: 59


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.329535722732544
Mean_iou: 0.31127987272650914
Mean accuracy: 0.34774647938200176
Epoch: 60


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.4442685842514038
Mean_iou: 0.35216159502056976
Mean accuracy: 0.39432180242364767
Epoch: 61


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.5084984302520752
Mean_iou: 0.3439891164803589
Mean accuracy: 0.3816117579376965
Epoch: 62


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5826025605201721
Mean_iou: 0.295540496058099
Mean accuracy: 0.3464883714840868
Epoch: 63


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.1386276483535767
Mean_iou: 0.29121207034349256
Mean accuracy: 0.33068423443325706
Epoch: 64


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.3459959030151367
Mean_iou: 0.367560588397828
Mean accuracy: 0.4189850476781917
Epoch: 65


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.0102214813232422
Mean_iou: 0.37690878975276765
Mean accuracy: 0.41681318964032704
Epoch: 66


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5942162871360779
Mean_iou: 0.30636752173205456
Mean accuracy: 0.3520040309464197
Epoch: 67


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.3014960289001465
Mean_iou: 0.29760134285341594
Mean accuracy: 0.3420924106657326
Epoch: 68


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.0575844049453735
Mean_iou: 0.3579616428085801
Mean accuracy: 0.3917559966175284
Epoch: 69


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.4579142332077026
Mean_iou: 0.3454129569416998
Mean accuracy: 0.3969424397918993
Epoch: 70


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.77836674451828
Mean_iou: 0.30899991646114083
Mean accuracy: 0.3685838611599275
Epoch: 71


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8565409183502197
Mean_iou: 0.33178904698366796
Mean accuracy: 0.3698426844000425
Epoch: 72


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.879786491394043
Mean_iou: 0.3679537613646584
Mean accuracy: 0.41612676738399335
Epoch: 73


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.353743076324463
Mean_iou: 0.3155865103071223
Mean accuracy: 0.361593929429391
Epoch: 74


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6652700304985046
Mean_iou: 0.32034402858836336
Mean accuracy: 0.3706957465405417
Epoch: 75


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.9894077181816101
Mean_iou: 0.36212152051741914
Mean accuracy: 0.4007715666771562
Epoch: 76


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.0986864566802979
Mean_iou: 0.38265896984889064
Mean accuracy: 0.4196695760314711
Epoch: 77


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.094161033630371
Mean_iou: 0.30092423864528967
Mean accuracy: 0.3510876543477628
Epoch: 78


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.9086965322494507
Mean_iou: 0.36295092920267685
Mean accuracy: 0.4040782199368995
Epoch: 79


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.9318731427192688
Mean_iou: 0.40297248909574657
Mean accuracy: 0.4454844454505838
Epoch: 80


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6310086250305176
Mean_iou: 0.37643521409132485
Mean accuracy: 0.4075477480658694
Epoch: 81


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.1685681343078613
Mean_iou: 0.3335552430550505
Mean accuracy: 0.36549206093002584
Epoch: 82


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.36545345187187195
Mean_iou: 0.3238208929042852
Mean accuracy: 0.3741628467003568
Epoch: 83


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.9655871391296387
Mean_iou: 0.3238241603457432
Mean accuracy: 0.36495836794236325
Epoch: 84


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.2913546562194824
Mean_iou: 0.37442711293991265
Mean accuracy: 0.42285008606268587
Epoch: 85


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4001641273498535
Mean_iou: 0.35935214702974627
Mean accuracy: 0.3887693664059347
Epoch: 86


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8385242223739624
Mean_iou: 0.33768350090290167
Mean accuracy: 0.3739750910185696
Epoch: 87


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.9849682450294495
Mean_iou: 0.38769536165953966
Mean accuracy: 0.42951652790278344
Epoch: 88


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5379916429519653
Mean_iou: 0.41661300333788714
Mean accuracy: 0.4541258358977237
Epoch: 89


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.2851184904575348
Mean_iou: 0.3877365718813683
Mean accuracy: 0.42116246887181774
Epoch: 90


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8762495517730713
Mean_iou: 0.35829464308710274
Mean accuracy: 0.395833436383818
Epoch: 91


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.970747172832489
Mean_iou: 0.4078145161225567
Mean accuracy: 0.43714408018608863
Epoch: 92


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.1171332597732544
Mean_iou: 0.3425244813343218
Mean accuracy: 0.38341982405811653
Epoch: 93


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8656970858573914
Mean_iou: 0.3241940684146723
Mean accuracy: 0.3668860349910859
Epoch: 94


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.878955602645874
Mean_iou: 0.4326347708316877
Mean accuracy: 0.46246999782135023
Epoch: 95


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.0100836753845215
Mean_iou: 0.3692549708676087
Mean accuracy: 0.41004873675965997
Epoch: 96


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6896365880966187
Mean_iou: 0.4602115314102317
Mean accuracy: 0.49453760733296614
Epoch: 97


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8777745366096497
Mean_iou: 0.3631127848632478
Mean accuracy: 0.39980004421650567
Epoch: 98


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.29986879229545593
Mean_iou: 0.35299616050084404
Mean accuracy: 0.40596307737483406
Epoch: 99


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.7966178059577942
Mean_iou: 0.3412114500642487
Mean accuracy: 0.37845312721622115
Epoch: 100


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.24253107607364655
Mean_iou: 0.3770109116036853
Mean accuracy: 0.40543251679304093
Epoch: 101


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8277925252914429
Mean_iou: 0.3834494051307309
Mean accuracy: 0.41945234505760864
Epoch: 102


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5542042851448059
Mean_iou: 0.40903993273078465
Mean accuracy: 0.43991525790957675
Epoch: 103


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.901894211769104
Mean_iou: 0.36552226404283394
Mean accuracy: 0.41003821015150466
Epoch: 104


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8318941593170166
Mean_iou: 0.42611414441584045
Mean accuracy: 0.46988936687370375
Epoch: 105


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.931637167930603
Mean_iou: 0.32517944565823276
Mean accuracy: 0.38767283346123027
Epoch: 106


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.18074332177639008
Mean_iou: 0.4369325414662308
Mean accuracy: 0.46986718074009204
Epoch: 107


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6808466911315918
Mean_iou: 0.3907088298544126
Mean accuracy: 0.43324341693228496
Epoch: 108


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6175119876861572
Mean_iou: 0.46148790678123486
Mean accuracy: 0.4983447107696269
Epoch: 109


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.7403811812400818
Mean_iou: 0.4211616062589777
Mean accuracy: 0.4613888128395054
Epoch: 110


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.7389628887176514
Mean_iou: 0.44529682753257344
Mean accuracy: 0.4786805295199809
Epoch: 111


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.25571414828300476
Mean_iou: 0.4183277879185514
Mean accuracy: 0.46577475359899845
Epoch: 112


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.15472166240215302
Mean_iou: 0.39901859224607245
Mean accuracy: 0.44199438884456504
Epoch: 113


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6578091382980347
Mean_iou: 0.39170618610234575
Mean accuracy: 0.42164352147838163
Epoch: 114


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.18620382249355316
Mean_iou: 0.423220172671926
Mean accuracy: 0.45918017762111013
Epoch: 115


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.43848058581352234
Mean_iou: 0.39133393358964996
Mean accuracy: 0.42669223367074144
Epoch: 116


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6930124163627625
Mean_iou: 0.406564719789487
Mean accuracy: 0.4521065846512766
Epoch: 117


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.7634966969490051
Mean_iou: 0.3401563024419852
Mean accuracy: 0.3943525039027016
Epoch: 118


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6337383389472961
Mean_iou: 0.4051418790429486
Mean accuracy: 0.44985864965950817
Epoch: 119


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6081811189651489
Mean_iou: 0.37286090820647283
Mean accuracy: 0.4307674453579037
Epoch: 120


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5652259588241577
Mean_iou: 0.4237874854210772
Mean accuracy: 0.45490666937176094
Epoch: 121


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6741742491722107
Mean_iou: 0.39941773382597673
Mean accuracy: 0.4351232331435876
Epoch: 122


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6154605746269226
Mean_iou: 0.4648018747222005
Mean accuracy: 0.5066771394475155
Epoch: 123


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8098996877670288
Mean_iou: 0.434565105078279
Mean accuracy: 0.464046695730103
Epoch: 124


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5209024548530579
Mean_iou: 0.4617030827036037
Mean accuracy: 0.4969587278897309
Epoch: 125


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6353365778923035
Mean_iou: 0.44427073853866533
Mean accuracy: 0.48875061281906546
Epoch: 126


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4355647563934326
Mean_iou: 0.4343808788326983
Mean accuracy: 0.475760904882935
Epoch: 127


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6298298239707947
Mean_iou: 0.4239910584318781
Mean accuracy: 0.4653472835282521
Epoch: 128


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5560383796691895
Mean_iou: 0.4342409409464306
Mean accuracy: 0.46901453763555623
Epoch: 129


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.26101988554000854
Mean_iou: 0.43821251985673493
Mean accuracy: 0.4702355620819941
Epoch: 130


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5445801019668579
Mean_iou: 0.44936266839602373
Mean accuracy: 0.4821529492212214
Epoch: 131


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6035285592079163
Mean_iou: 0.5182527933508585
Mean accuracy: 0.5661007898077879
Epoch: 132


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5450356006622314
Mean_iou: 0.4687077994522603
Mean accuracy: 0.4985390492722174
Epoch: 133


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.6859911680221558
Mean_iou: 0.4238860322995575
Mean accuracy: 0.4664723898760596
Epoch: 134


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.38650083541870117
Mean_iou: 0.4802255762819815
Mean accuracy: 0.5109175957519582
Epoch: 135


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.521648645401001
Mean_iou: 0.45611376172500445
Mean accuracy: 0.4909973959719008
Epoch: 136


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4626729190349579
Mean_iou: 0.4420547107316139
Mean accuracy: 0.4991635373232508
Epoch: 137


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.7313882112503052
Mean_iou: 0.40848125052361406
Mean accuracy: 0.4551406773225589
Epoch: 138


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4799870252609253
Mean_iou: 0.4687199112194236
Mean accuracy: 0.5137108748669594
Epoch: 139


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.17216844856739044
Mean_iou: 0.5671697472942123
Mean accuracy: 0.6028571480831162
Epoch: 140


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3616710305213928
Mean_iou: 0.4539587727030675
Mean accuracy: 0.49842650062067656
Epoch: 141


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.8127738833427429
Mean_iou: 0.4675530693986859
Mean accuracy: 0.4989554603287434
Epoch: 142


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5658367872238159
Mean_iou: 0.5390765812191527
Mean accuracy: 0.5710463550879691
Epoch: 143


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4341752231121063
Mean_iou: 0.47407011497024654
Mean accuracy: 0.5172589558108199
Epoch: 144


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4021514058113098
Mean_iou: 0.503780725250419
Mean accuracy: 0.5353660466974619
Epoch: 145


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5588246583938599
Mean_iou: 0.4413603980188392
Mean accuracy: 0.4852799005451112
Epoch: 146


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.40457963943481445
Mean_iou: 0.5179216453720857
Mean accuracy: 0.5501734334095034
Epoch: 147


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5006874203681946
Mean_iou: 0.44837343228984944
Mean accuracy: 0.4910205437411855
Epoch: 148


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4187666177749634
Mean_iou: 0.47495844355565114
Mean accuracy: 0.501839379665931
Epoch: 149


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.09646811336278915
Mean_iou: 0.5075992593321124
Mean accuracy: 0.5456359929083475
Epoch: 150


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4997323751449585
Mean_iou: 0.48854146074788907
Mean accuracy: 0.5268308425411634
Epoch: 151


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3825819790363312
Mean_iou: 0.5150049259067401
Mean accuracy: 0.556597011892689
Epoch: 152


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.47028517723083496
Mean_iou: 0.47648346266303643
Mean accuracy: 0.5181763632533508
Epoch: 153


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.15590637922286987
Mean_iou: 0.46786365697790533
Mean accuracy: 0.5223724828364621
Epoch: 154


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.16554279625415802
Mean_iou: 0.4853850282409213
Mean accuracy: 0.5139534409941531
Epoch: 155


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5099190473556519
Mean_iou: 0.49312719519813736
Mean accuracy: 0.5268396278945761
Epoch: 156


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.19148105382919312
Mean_iou: 0.5520455611820825
Mean accuracy: 0.5875770902543972
Epoch: 157


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4767568111419678
Mean_iou: 0.4601422624605484
Mean accuracy: 0.49078560421580253
Epoch: 158


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.11740610748529434
Mean_iou: 0.43939719900358115
Mean accuracy: 0.5006183373437191
Epoch: 159


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4702413082122803
Mean_iou: 0.5054972477168935
Mean accuracy: 0.5445392701529114
Epoch: 160


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.30358484387397766
Mean_iou: 0.5100216788051305
Mean accuracy: 0.5370140109119964
Epoch: 161


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3852705955505371
Mean_iou: 0.5205982369276705
Mean accuracy: 0.5474408254656946
Epoch: 162


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.15865421295166016
Mean_iou: 0.49081730426860787
Mean accuracy: 0.5353529969508919
Epoch: 163


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4103953540325165
Mean_iou: 0.4850957454926618
Mean accuracy: 0.5226782996248341
Epoch: 164


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.541419506072998
Mean_iou: 0.4760688760820596
Mean accuracy: 0.5166922898753195
Epoch: 165


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.19783645868301392
Mean_iou: 0.5864109469529684
Mean accuracy: 0.6162869961883952
Epoch: 166


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4615781605243683
Mean_iou: 0.510369616195734
Mean accuracy: 0.5396607922939481
Epoch: 167


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5065624713897705
Mean_iou: 0.5023903060940534
Mean accuracy: 0.5447112187000797
Epoch: 168


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.41413378715515137
Mean_iou: 0.4648989214888115
Mean accuracy: 0.5192328457971597
Epoch: 169


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.37748903036117554
Mean_iou: 0.4997403787123815
Mean accuracy: 0.5340901219101158
Epoch: 170


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3933143615722656
Mean_iou: 0.6062697879866041
Mean accuracy: 0.6380473932053243
Epoch: 171


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5193633437156677
Mean_iou: 0.519846763612525
Mean accuracy: 0.5599392418869321
Epoch: 172


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3317051827907562
Mean_iou: 0.478301974575612
Mean accuracy: 0.5311536193916438
Epoch: 173


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.30979931354522705
Mean_iou: 0.5523623803121498
Mean accuracy: 0.5808837830584139
Epoch: 174


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.42054232954978943
Mean_iou: 0.46900398151154604
Mean accuracy: 0.5157777573967941
Epoch: 175


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.32452335953712463
Mean_iou: 0.6061782375612265
Mean accuracy: 0.6387856865153654
Epoch: 176


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.31448739767074585
Mean_iou: 0.5237083558183803
Mean accuracy: 0.552303479614111
Epoch: 177


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.2987029552459717
Mean_iou: 0.5355054767169081
Mean accuracy: 0.564135031664022
Epoch: 178


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4921465516090393
Mean_iou: 0.4829803283999065
Mean accuracy: 0.5270043141477971
Epoch: 179


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3944723904132843
Mean_iou: 0.6134166300994515
Mean accuracy: 0.6756293885091048
Epoch: 180


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3746565282344818
Mean_iou: 0.5751876060190025
Mean accuracy: 0.6073511082012643
Epoch: 181


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.35406357049942017
Mean_iou: 0.5062967870887927
Mean accuracy: 0.547421940504534
Epoch: 182


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3105689287185669
Mean_iou: 0.5443110746333464
Mean accuracy: 0.5743057567810884
Epoch: 183


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.302610844373703
Mean_iou: 0.5213021036519739
Mean accuracy: 0.5664815213108275
Epoch: 184


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3880035877227783
Mean_iou: 0.48915932344519686
Mean accuracy: 0.5316598112724041
Epoch: 185


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.32361122965812683
Mean_iou: 0.5969645535425456
Mean accuracy: 0.6491509900680639
Epoch: 186


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3530969023704529
Mean_iou: 0.5425614186971365
Mean accuracy: 0.5702261128947265
Epoch: 187


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.5431556701660156
Mean_iou: 0.5335069386724938
Mean accuracy: 0.5610333069947939
Epoch: 188


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4134800136089325
Mean_iou: 0.5095885943996517
Mean accuracy: 0.5680695732805897
Epoch: 189


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.22731567919254303
Mean_iou: 0.5353506854858335
Mean accuracy: 0.5926375622296514
Epoch: 190


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.08865748345851898
Mean_iou: 0.5310160800991731
Mean accuracy: 0.5739762316530084
Epoch: 191


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.43424251675605774
Mean_iou: 0.5273000853333908
Mean accuracy: 0.5606784145906843
Epoch: 192


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3403221070766449
Mean_iou: 0.6517024015127886
Mean accuracy: 0.6946901192340396
Epoch: 193


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.21183714270591736
Mean_iou: 0.5649606650250351
Mean accuracy: 0.593340258993544
Epoch: 194


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.2613163888454437
Mean_iou: 0.5694456373872124
Mean accuracy: 0.6176489108632796
Epoch: 195


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3841126263141632
Mean_iou: 0.5627813916266778
Mean accuracy: 0.6063052344556524
Epoch: 196


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3303564488887787
Mean_iou: 0.4750391635377197
Mean accuracy: 0.5252772699843878
Epoch: 197


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.2713817358016968
Mean_iou: 0.6256939996288091
Mean accuracy: 0.6791405346065466
Epoch: 198


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.4910604655742645
Mean_iou: 0.5374251316792966
Mean accuracy: 0.5883710562664433
Epoch: 199


  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.3044281601905823
Mean_iou: 0.6612256482909546
Mean accuracy: 0.7333916666653908


In [None]:

# Build the path from `root_dir` (defined together with the datasets) instead
# of a hard-coded absolute Colab path, so the cell also works when the
# notebook is not run from /content.
image_path = os.path.join(root_dir, "images", "training", "ADE_train_00000001.jpg")
image = Image.open(image_path)
image

In [None]:
# prepare the image for the model
# Only the image is passed to the feature extractor here (no segmentation map).
encoding = feature_extractor(image, return_tensors="pt")
# Move the input tensor to the device the model was moved to.
pixel_values = encoding.pixel_values.to(device)
print(pixel_values.shape)

In [9]:

# forward pass
# The training cell leaves the model in train mode. Switch to eval mode so
# mode-dependent layers (e.g. the decode head's batch norm) use their inference
# behaviour, and disable gradient tracking since no backward pass follows.
model.eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)

# logits are of shape (batch_size, num_labels, height/4, width/4)
logits = outputs.logits.cpu()
print(logits.shape)

def ade_palette():
    """Return the ADE20K colour palette.

    Returns:
        list[list[int]]: 150 ``[R, G, B]`` triples; the position of a triple
        in the list is the class index it colours. A fresh list (with fresh
        inner lists) is built on every call, so callers may mutate the result.
    """
    # Kept as an immutable table; converted to lists on the way out.
    rgb_table = (
        (120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50),
        (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255),
        (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7),
        (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82),
        (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3),
        (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255),
        (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220),
        (255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224),
        (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255),
        (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7),
        (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153),
        (6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255),
        (140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0),
        (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255),
        (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255),
        (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255),
        (0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0),
        (255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0),
        (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255),
        (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255),
        (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20),
        (255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255),
        (255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255),
        (0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255),
        (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0),
        (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0),
        (8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255),
        (255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112),
        (92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160),
        (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163),
        (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0),
        (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0),
        (10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255),
        (255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204),
        (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255),
        (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255),
        (184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194),
        (102, 255, 0), (92, 0, 255),
    )
    return [list(rgb) for rgb in rgb_table]


NameError: name 'model' is not defined

In [None]:
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

# Step 1: resize the logits to the resolution of the input image.
# PIL reports size as (width, height); reversing it gives (height, width).
target_hw = image.size[::-1]
upsampled_logits = nn.functional.interpolate(
    logits,
    size=target_hw,
    mode='bilinear',
    align_corners=False,
)

# Step 2: the predicted class of a pixel is the index of its largest logit
# along the class dimension; [0] picks the first item of the batch.
seg = upsampled_logits.argmax(dim=1)[0]

# Step 3: paint each pixel with the palette colour of its predicted class.
palette = np.array(ade_palette())
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
for label, color in enumerate(palette):
    class_mask = seg == label
    color_seg[class_mask, :] = color

# Reverse the channel order (RGB -> BGR).
# NOTE(review): plt.imshow treats an (H, W, 3) array as RGB, so the overlay
# shows the palette colours with red and blue swapped - confirm this is
# intended. The ground-truth cell does the same flip, so the two overlays
# stay comparable.
color_seg = color_seg[:, :, ::-1]

# Step 4: blend photo and colour mask 50/50 and display the result.
img = (np.array(image) * 0.5 + color_seg * 0.5).astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()

In [None]:

# Load the ground-truth segmentation map that belongs to the image above.
# The path is relative to the working directory because download_data()
# extracts the archive there (z.extractall() with no target directory).
# On Colab, whose default working directory is /content, this resolves to the
# same file as the previous absolute path.
# NOTE: the name `map` shadows Python's builtin map(); it is kept because the
# next cell reads the annotation from this name.
map = Image.open('ADE20k_toy_dataset/annotations/training/ADE_train_00000001.png')
map

In [None]:
# Turn the PIL annotation into a NumPy array and shift its labels.
# NOTE: `map` is rebound from the PIL image to the shifted array, so
# re-running this cell without first re-running the cell that loads the
# annotation applies the shift a second time.
map = np.array(map)
map[map == 0] = 255  # background class is replaced by ignore_index
map = map - 1  # other classes are reduced by one
map[map == 254] = 255  # the ignore value was shifted to 254 just above; restore it

# List the class names present in this annotation (None for the ignore value).
classes_map = np.unique(map).tolist()
unique_classes = [None if idx == 255 else model.config.id2label[idx] for idx in classes_map]
print("Classes in this image:", unique_classes)

# create coloured map
palette = np.array(ade_palette())
color_seg = np.zeros(map.shape[:2] + (3,), dtype=np.uint8)  # height, width, 3
for label, color in enumerate(palette):
    class_mask = map == label
    color_seg[class_mask, :] = color

# Reverse the channel order (RGB -> BGR).
# NOTE(review): plt.imshow treats an (H, W, 3) array as RGB, so the overlay
# shows the palette colours with red and blue swapped - confirm this is
# intended. The prediction overlay does the same flip, so the two overlays
# stay comparable.
color_seg = color_seg[:, :, ::-1]

# Blend photo and colour mask 50/50 and display the result.
img = (np.array(image) * 0.5 + color_seg * 0.5).astype(np.uint8)

plt.figure(figsize=(15, 10))
plt.imshow(img)
plt.show()