[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mikeagz/Fine-tune-a-Segformer/blob/main/Segformer.ipynb)

In [None]:
# @title # Necessary dependencies

!pip install -q transformers datasets evaluate accelerate
!pip install -q --upgrade segments-ai

In [None]:
# @title HuggingFace Login

from huggingface_hub import notebook_login

notebook_login()


The dataset lives in my Segments.ai workspace, so it is accessed through the Segments.ai API with an API key. Contact me if you need access.

In [None]:
from segments import SegmentsClient

api_key = "XXXXXXXXXXXXXXXXXX"

client = SegmentsClient(api_key)
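
Hardcoding the key in a notebook makes it easy to leak when the notebook is shared. A minimal sketch of one alternative, assuming the key is stored in an environment variable (the variable name `SEGMENTS_API_KEY` and the helper `load_api_key` are hypothetical and not part of the Segments.ai SDK):

```python
import os


def load_api_key(env_var="SEGMENTS_API_KEY"):
    """Return the API key stored in `env_var`, or raise if it is missing."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable first.")
    return key
```

The returned value could then be passed to `SegmentsClient` in place of the literal string.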

In [None]:
dataset_identifier = "neo/greenhouse"
name = "v0.1"

In [None]:
from segments.huggingface import release2dataset

release = client.get_release(dataset_identifier, name)
hf_dataset = release2dataset(release)

In [None]:
hf_dataset.features

{'name': Value(dtype='string', id=None),
 'uuid': Value(dtype='string', id=None),
 'status': Value(dtype='string', id=None),
 'image': Image(decode=True, id=None),
 'label.annotations': [{'id': Value(dtype='int32', id=None),
   'category_id': Value(dtype='int32', id=None)}],
 'label.segmentation_bitmap': Image(decode=True, id=None)}

In [None]:
from segments.utils import get_semantic_bitmap


def convert_segmentation_bitmap(example):
    return {
        "label.segmentation_bitmap":
            get_semantic_bitmap(
                example["label.segmentation_bitmap"],
                example["label.annotations"],
            )
    }


semantic_dataset = hf_dataset.map(
    convert_segmentation_bitmap,
)
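
The bitmap exported by Segments.ai stores an instance id per pixel, while semantic segmentation needs a category id per pixel. The sketch below illustrates that mapping on a tiny nested-list bitmap. It is a simplified stand-in for `get_semantic_bitmap`, not the library's implementation; the helper name `to_semantic` and the use of 0 for unannotated pixels are assumptions.

```python
def to_semantic(instance_bitmap, annotations, background=0):
    """Replace each instance id with the category id of its annotation."""
    id_to_category = {a["id"]: a["category_id"] for a in annotations}
    return [
        [id_to_category.get(pixel, background) for pixel in row]
        for row in instance_bitmap
    ]


# Toy 2x3 bitmap: each value is an instance id (0 = no annotation).
instance_bitmap = [
    [1, 1, 2],
    [0, 3, 2],
]
annotations = [
    {"id": 1, "category_id": 3},  # instance 1 belongs to category 3
    {"id": 2, "category_id": 2},  # instance 2 belongs to category 2
    {"id": 3, "category_id": 3},  # instance 3 also belongs to category 3
]
semantic_bitmap = to_semantic(instance_bitmap, annotations)
print(semantic_bitmap)  # [[3, 3, 2], [0, 3, 2]]
```

Two separate instances of the same category (ids 1 and 3 here) end up with the same value, which is the behavior a semantic segmentation model expects.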

In [None]:
semantic_dataset = semantic_dataset.rename_column('image', 'pixel_values')
semantic_dataset = semantic_dataset.rename_column('label.segmentation_bitmap', 'label')
semantic_dataset = semantic_dataset.remove_columns(['name', 'uuid', 'status', 'label.annotations'])

In [None]:
semantic_dataset.features

I plan to train other models on this data, so it is convenient to store the dataset on the Hugging Face Hub.

In [None]:
semantic_dataset.push_to_hub("MexicanVanGogh/greenhouse")

Load the dataset back from the Hub to verify that the upload worked.

In [None]:
from datasets import load_dataset

hf_dataset_identifier = "MexicanVanGogh/greenhouse"
ds = load_dataset(hf_dataset_identifier)

In [None]:
ds

DatasetDict({
    train: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 17
    })
})

In [None]:
# train_test_split shuffles by default; seeding it makes the partition reproducible
ds = ds["train"].train_test_split(test_size=0.2, seed=1)
train_ds = ds["train"]
test_ds = ds["test"]

In [None]:
ds

DatasetDict({
    train: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 13
    })
    test: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 4
    })
})
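
The 13/4 split follows from rounding: 20% of 17 examples is 3.4, which cannot be a whole number of images. A quick check of the arithmetic, assuming the test size is rounded up and the remainder goes to training (this matches the counts printed above):

```python
import math

num_examples = 17
test_fraction = 0.2

# Round the test share up, then give everything else to training.
num_test = math.ceil(num_examples * test_fraction)
num_train = num_examples - num_test
print(num_train, num_test)  # 13 4
```

With only 4 evaluation images, the reported metrics are likely to be noisy, and some classes may not appear in the test split at all.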

In [None]:
import json
from huggingface_hub import hf_hub_download

filename = "id2label.json"
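id2label_path = hf_hub_download(repo_id=hf_dataset_identifier, filename=filename, repo_type="dataset")
# Use a context manager so the file handle is closed after reading
with open(id2label_path, "r") as f:
    id2label = json.load(f)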
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

num_labels = len(id2label)
print("Id2label:", id2label)

Id2label: {0: 'unlabeled', 1: 'object', 2: 'road', 3: 'plant', 4: 'iron', 5: 'wood', 6: 'wall', 7: 'raw_road', 8: 'bottom_wall', 9: 'roof', 10: 'grass'}
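
JSON object keys are always strings, which is why the keys are cast back to `int` before the mapping is inverted. A small self-contained illustration, using a shortened JSON string written for this example rather than the real `id2label.json` (the `_demo` names are only for illustration):

```python
import json

raw = '{"0": "unlabeled", "1": "object", "2": "road"}'  # illustrative subset

# json.loads yields string keys; convert them to integer class ids.
id2label_demo = {int(k): v for k, v in json.loads(raw).items()}
# Invert the mapping to look up a class id by name.
label2id_demo = {v: k for k, v in id2label_demo.items()}

print(id2label_demo)  # {0: 'unlabeled', 1: 'object', 2: 'road'}
print(label2id_demo)  # {'unlabeled': 0, 'object': 1, 'road': 2}
```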


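Apply preprocessing: `ColorJitter` augments the training images, and `SegformerImageProcessor` converts images and label maps into model inputs.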

In [None]:
from torchvision.transforms import ColorJitter
from transformers import (
    SegformerImageProcessor,
)

processor = SegformerImageProcessor()
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

def train_transforms(example_batch):
    images = [jitter(x) for x in example_batch['pixel_values']]
    labels = [x for x in example_batch['label']]
    inputs = processor(images, labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch['pixel_values']]
    labels = [x for x in example_batch['label']]
    inputs = processor(images, labels)
    return inputs


# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

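Load the pre-trained SegFormer encoder (`nvidia/mit-b0`). As the warning below notes, the decode head is newly initialized, so the model has to be fine-tuned before it can be used for prediction.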

In [None]:
from transformers import SegformerForSemanticSegmentation

pretrained_model_name = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    id2label=id2label,
    label2id=label2id
)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.2.proj.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.batch_norm.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.batch_norm.running_mean', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.batch_norm.running_var', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.0.proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Specify training parameters

In [None]:
from transformers import TrainingArguments

epochs = 30
lr = 0.00006
batch_size = 2

hub_model_id = "segformer-b0-finetuned-segments-greenhouse-oct-23"

training_args = TrainingArguments(
    "segformer-b0-finetuned-segments-greenhouse-outputs",
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="end",
)
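
These settings determine how many optimizer steps and evaluations the run performs. A rough check of the arithmetic, assuming a single device, no gradient accumulation, and the 13 training images from the split above:

```python
import math

num_train_examples = 13  # size of the training split
batch_size = 2
epochs = 30
eval_steps = 20

# The last batch of each epoch is smaller, so round up.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = steps_per_epoch * epochs
num_evaluations = total_steps // eval_steps

print(steps_per_epoch, total_steps, num_evaluations)  # 7 210 10
```

The total agrees with the `global_step=210` reported by `trainer.train()`, and the ten evaluations correspond to the logged steps 20 through 200.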

Adding metrics:

Mean Intersection over Union (mean IoU) is a standard metric for image segmentation. For each class, IoU is the area where the predicted and ground-truth regions overlap divided by the area of their union. Mean IoU averages the per-class IoU scores, so values closer to 1 indicate better segmentation.
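
To make the definition concrete, here is a minimal pure-Python sketch of per-class IoU and its mean on a tiny flattened mask. It is an illustration only, not the implementation used by the `evaluate` metric; the function name `mean_iou` and the choice to skip classes that appear in neither mask are assumptions of this sketch.

```python
import math


def mean_iou(pred, target, num_labels, ignore_index=None):
    """Return per-class IoU and the mean over classes that occur at all."""
    intersect = [0] * num_labels
    union = [0] * num_labels
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue  # pixels with the ignored label do not count
        for c in range(num_labels):
            in_pred, in_target = p == c, t == c
            intersect[c] += in_pred and in_target
            union[c] += in_pred or in_target
    # A class absent from both masks has an empty union: report NaN.
    ious = [i / u if u else math.nan for i, u in zip(intersect, union)]
    present = [v for v in ious if not math.isnan(v)]
    return ious, sum(present) / len(present)


# Flattened 2x3 prediction and ground-truth masks with 3 classes.
pred = [0, 1, 1, 2, 2, 2]
target = [0, 1, 2, 2, 2, 1]
ious, miou = mean_iou(pred, target, num_labels=3)
print(ious, miou)
```

Here the per-class IoU values are 1, 1/3, and 1/2, so the mean is 11/18 (about 0.61). A class that occurs in neither mask yields NaN rather than 0, which is one way empty per-class entries can arise when the evaluation set is small.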

In [None]:
import torch
from torch import nn
import evaluate

metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        pred_labels = logits_tensor.detach().cpu().numpy()
        metrics = metric._compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=0,
            reduce_labels=processor.do_reduce_labels,
        )

        # add per category metrics as individual key-value pairs
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()

        metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

        return metrics

Training model

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

Step,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Accuracy Unlabeled,Accuracy Object,Accuracy Road,Accuracy Plant,Accuracy Iron,Accuracy Wood,Accuracy Wall,Accuracy Raw Road,Accuracy Bottom Wall,Accuracy Roof,Accuracy Grass,Iou Unlabeled,Iou Object,Iou Road,Iou Plant,Iou Iron,Iou Wood,Iou Wall,Iou Raw Road,Iou Bottom Wall,Iou Roof,Iou Grass
20,1.8756,2.00631,0.141535,0.226872,0.821646,,,0.788192,0.967378,0.0,0.059408,0.0,,0.0,0.0,0.0,0.0,0.0,0.776012,0.755249,0.0,0.02562,0.0,0.0,0.0,0.0,0.0
40,1.3624,1.091003,0.171529,0.238002,0.899108,,,0.920603,0.97573,0.0,0.007686,0.0,,0.0,0.0,0.0,0.0,,0.888833,0.821988,0.0,0.004464,0.0,0.0,0.0,0.0,0.0
60,1.4095,0.903261,0.173382,0.23921,0.906838,,,0.926412,0.987272,0.0,0.0,0.0,,0.0,0.0,0.0,0.0,,0.900025,0.833794,0.0,0.0,0.0,0.0,0.0,0.0,0.0
80,0.8802,0.77843,0.176394,0.241431,0.916452,,,0.946958,0.982251,0.0,0.002242,0.0,,0.0,0.0,0.0,0.0,,0.91548,0.846328,0.0,0.002132,0.0,0.0,0.0,0.0,0.0
100,1.0936,0.805968,0.194641,0.240491,0.913179,,,0.939998,0.983933,0.0,0.0,0.0,,0.0,0.0,0.0,,,0.909965,0.841809,0.0,0.0,0.0,0.0,0.0,0.0,0.0
120,0.8086,0.778607,0.193962,0.240228,0.911474,,,0.936098,0.985169,0.0,0.0,0.0,,0.0,0.000555,0.0,,,0.907066,0.838033,0.0,0.0,0.0,0.0,0.0,0.000555,0.0
140,1.0669,0.746192,0.207243,0.256183,0.908785,,,0.928217,0.98535,0.0,0.0,0.011271,,0.0,0.124624,0.0,,,0.901043,0.838471,0.0,0.0,0.010153,0.0,0.0,0.115517,0.0
160,0.7399,0.732849,0.213662,0.266223,0.907985,,,0.929022,0.978761,0.0,0.0,0.08145,,0.0,0.140548,0.0,,,0.899705,0.838856,0.0,0.0,0.06629,0.0,0.0,0.118109,0.0
180,0.808,0.729625,0.2218,0.279671,0.907248,,,0.927687,0.97424,0.0,0.0,0.183959,,0.0,0.151482,0.0,,,0.898078,0.840387,0.0,0.0,0.142261,0.0,0.0,0.115473,0.0
200,0.8494,0.705815,0.222686,0.280428,0.910078,,,0.93781,0.966699,0.0,0.0,0.193214,,0.0,0.145698,0.0,,,0.903893,0.842072,0.0,0.0,0.152068,0.0,0.0,0.106141,0.0


TrainOutput(global_step=210, training_loss=1.1565589444977897, metrics={'train_runtime': 784.1574, 'train_samples_per_second': 0.497, 'train_steps_per_second': 0.268, 'total_flos': 6840159153684480.0, 'train_loss': 1.1565589444977897, 'epoch': 30.0})
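
With `load_best_model_at_end=True` and no `metric_for_best_model` set, the `Trainer` ranks checkpoints by evaluation loss. The sketch below replays that selection on three columns copied from the evaluation table above, and compares it with ranking by mean IoU; the list name `history` is only for illustration.

```python
# (step, validation loss, mean IoU) copied from the evaluation log
history = [
    (20, 2.00631, 0.141535),
    (40, 1.091003, 0.171529),
    (60, 0.903261, 0.173382),
    (80, 0.77843, 0.176394),
    (100, 0.805968, 0.194641),
    (120, 0.778607, 0.193962),
    (140, 0.746192, 0.207243),
    (160, 0.732849, 0.213662),
    (180, 0.729625, 0.2218),
    (200, 0.705815, 0.222686),
]

# Lower loss is better; higher mean IoU is better.
best_by_loss = min(history, key=lambda row: row[1])
best_by_miou = max(history, key=lambda row: row[2])
print(best_by_loss[0], best_by_miou[0])  # 200 200
```

Both criteria pick step 200 here. Since the best values occur at the last evaluation, the model may not have converged yet, and longer training could plausibly help.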

In [None]:
kwargs = {
    "tags": ["vision", "image-segmentation"],
    "finetuned_from": pretrained_model_name,
    "dataset": hf_dataset_identifier,
}

processor.push_to_hub(hub_model_id)
trainer.push_to_hub(**kwargs)

'https://huggingface.co/MexicanVanGogh/segformer-b0-finetuned-segments-greenhouse-oct-23/tree/main/'