# Preparation

In [None]:
# pytorch-lightning is held below 2.0 because this notebook relies on 1.x-only
# APIs: `Trainer(gpus=..., resume_from_checkpoint=...)` and the
# `validation_epoch_end` / `test_epoch_end` hooks.
!pip install -q "pytorch-lightning<2.0.0"
# NOTE(review): the packages below are unpinned, so a fresh install may pull
# releases whose APIs differ from the ones used here - consider pinning them.
!pip install -q transformers datasets roboflow evaluate python-dotenv

In [None]:
import os
import shutil
import evaluate
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from roboflow import Roboflow
from google.colab import drive
from roboflow import Roboflow
from dotenv import load_dotenv
from PIL import Image

In [None]:
# Make the entries of a local .env file visible through the process environment.
load_dotenv()

# Roboflow credentials and dataset coordinates.
rf_api_key = os.environ.get("RF_API_KEY")
workspace = os.environ.get("WORKSPACE")
project = os.environ.get("PROJECT")
version = int(os.environ.get("VERSION", 1))  # falls back to 1 when VERSION is unset
dataset = os.environ.get("DATASET")

# Where checkpoints are copied from / to once training has finished.
source_dir = os.environ.get("SOURCE_DIRECTORY")
save_dir = os.environ.get("SAVE_DIRECTORY")

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive.
drive.mount('/content/drive')

# Resolve the Roboflow project/version and download the dataset export.
# The Roboflow handles get their own names so the plain settings loaded from the
# environment (`project`, `version`) are not overwritten: `version` is
# interpolated into checkpoint paths further down, and rebinding either name
# would make this cell fail when it is run a second time.
rf = Roboflow(api_key=rf_api_key)
rf_project = rf.workspace(workspace).project(project)
rf_version = rf_project.version(version)

# `dataset` is rebound to the downloaded dataset because its `.location` is used
# when the datasets are built. The export format is therefore re-read from the
# environment instead of from the variable that is being overwritten, which
# keeps the cell re-runnable.
dataset = rf_version.download(os.getenv("DATASET"))

# Dataloader Preparation

In [None]:
class SemanticSegmentationDataset(Dataset):
    """Image/mask pairs read from one directory and run through a feature extractor.

    ``root_dir`` must contain ``_classes.csv`` (one header row followed by
    ``<id>,<label>`` rows), ``.jpg`` images and ``.png`` masks. Images and masks
    are paired by position after sorting their file names.

    Attributes:
        id2label: dict mapping the id column (kept as a string) to the label,
            both stripped of surrounding whitespace.
        images: sorted ``.jpg`` file names found in ``root_dir``.
        masks: sorted ``.png`` file names found in ``root_dir``.

    Raises:
        ValueError: if the number of images and masks differs.
    """

    def __init__(self, root_dir, feature_extractor):
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.classes_csv_file = os.path.join(self.root_dir, "_classes.csv")

        self.id2label = {}
        with open(self.classes_csv_file, 'r') as fid:
            next(fid, None)  # the first row is the column header
            for line in fid:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                fields = [field.strip() for field in line.split(',')]
                # Stripping removes the line terminator and padding spaces that
                # would otherwise end up inside the label names.
                self.id2label[fields[0]] = fields[1]

        # Match on the real extension: a substring test would also pick up
        # names such as "x.jpg.txt".
        file_names = os.listdir(self.root_dir)
        self.images = sorted(f for f in file_names if f.endswith('.jpg'))
        self.masks = sorted(f for f in file_names if f.endswith('.png'))

        # Pairing is positional, so unequal counts mean wrong image/mask pairs.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks in '{self.root_dir}'"
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the feature extractor's encoding of pair ``idx`` without the batch axis."""
        image_path = os.path.join(self.root_dir, self.images[idx])
        mask_path = os.path.join(self.root_dir, self.masks[idx])

        # Context managers close the underlying files once the pixels are consumed.
        with Image.open(image_path) as image, Image.open(mask_path) as segmentation_map:
            encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        # Drop only the leading batch axis added by return_tensors="pt".
        for tensor in encoded_inputs.values():
            tensor.squeeze_(0)

        return encoded_inputs

In [None]:
# Start from the preprocessing configuration stored with the pretrained checkpoint.
feature_extractor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b3-finetuned-ade-512-512"
)

# Override two of its settings for this dataset.
feature_extractor.size = 512
feature_extractor.do_reduce_labels = False

In [None]:
batch_size = 8
num_workers = 2

# Options shared by every split; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, num_workers=num_workers, prefetch_factor=8)

train_dataset = SemanticSegmentationDataset(f"{dataset.location}/train/", feature_extractor)
val_dataset = SemanticSegmentationDataset(f"{dataset.location}/valid/", feature_extractor)
test_dataset = SemanticSegmentationDataset(f"{dataset.location}/test/", feature_extractor)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, **loader_options)
test_dataloader = DataLoader(test_dataset, **loader_options)

# Trainer Preparation

In [None]:
class SegformerFinetuner(pl.LightningModule):
    """Lightning wrapper that fine-tunes a pretrained SegFormer checkpoint.

    Args:
        id2label: mapping from class id to class name; its size sets the number
            of output labels.
        train_dataloader, val_dataloader, test_dataloader: loaders handed back by
            the matching Lightning dataloader hooks.
        metrics_interval: training metrics are computed and logged on every
            batch whose index is a multiple of this value.
    """

    def __init__(self, id2label, train_dataloader=None, val_dataloader=None, test_dataloader=None, metrics_interval=100):
        super().__init__()
        self.id2label = id2label
        self.metrics_interval = metrics_interval
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader
        self.num_classes = len(id2label)
        self.label2id = {label: class_id for class_id, label in self.id2label.items()}

        # ignore_mismatched_sizes lets the checkpoint load even though the
        # number of labels differs from the one it was trained with.
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b3-finetuned-ade-512-512",
            return_dict=False,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True,
        )

        # One accumulator per stage so the stages never mix their batches.
        self.train_mean_iou = evaluate.load("mean_iou")
        self.val_mean_iou = evaluate.load("mean_iou")
        self.test_mean_iou = evaluate.load("mean_iou")

    def forward(self, images, masks):
        """Run the wrapped model on ``images`` with ``masks`` as labels."""
        return self.model(pixel_values=images, labels=masks)

    def _accumulate(self, batch, metric):
        """Forward ``batch``, add its predictions to ``metric`` and return the loss."""
        images, masks = batch["pixel_values"], batch["labels"]
        outputs = self(images, masks)
        loss, logits = outputs[0], outputs[1]

        # Resize the logits to the mask resolution before taking the arg-max class.
        upsampled = nn.functional.interpolate(
            logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        predicted = upsampled.argmax(dim=1)

        metric.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        return loss

    def _compute(self, metric):
        """Compute ``metric`` over everything accumulated so far."""
        return metric.compute(
            num_labels=self.num_classes, ignore_index=255, reduce_labels=False
        )

    def _log_all(self, metrics):
        """Log every entry of ``metrics`` under its own key."""
        for name, value in metrics.items():
            self.log(name, value)

    def training_step(self, batch, batch_nb):
        """Return the loss, plus IoU/accuracy on every ``metrics_interval``-th batch."""
        loss = self._accumulate(batch, self.train_mean_iou)

        if batch_nb % self.metrics_interval != 0:
            return {"loss": loss}

        scores = self._compute(self.train_mean_iou)
        metrics = {
            "loss": loss,
            "mean_iou": scores["mean_iou"],
            "mean_accuracy": scores["mean_accuracy"],
        }
        self._log_all(metrics)
        return metrics

    def validation_step(self, batch, batch_nb):
        """Accumulate validation predictions and return the batch loss."""
        return {"val_loss": self._accumulate(batch, self.val_mean_iou)}

    def validation_epoch_end(self, outputs):
        """Log and return the mean validation loss, mean IoU and mean accuracy."""
        scores = self._compute(self.val_mean_iou)
        losses = [step["val_loss"] for step in outputs]

        metrics = {
            "val_loss": torch.stack(losses).mean(),
            "val_mean_iou": scores["mean_iou"],
            "val_mean_accuracy": scores["mean_accuracy"],
        }
        self._log_all(metrics)
        return metrics

    def test_step(self, batch, batch_nb):
        """Accumulate test predictions and return the batch loss."""
        return {"test_loss": self._accumulate(batch, self.test_mean_iou)}

    def test_epoch_end(self, outputs):
        """Print per-category IoU, then log and return the aggregated test metrics."""
        scores = self._compute(self.test_mean_iou)
        losses = [step["test_loss"] for step in outputs]
        avg_test_loss = torch.stack(losses).mean()

        print(scores["per_category_iou"])

        metrics = {
            "test_loss": avg_test_loss,
            "test_mean_iou": scores["mean_iou"],
            "test_mean_accuracy": scores["mean_accuracy"],
        }
        self._log_all(metrics)
        return metrics

    def configure_optimizers(self):
        """Adam over the trainable parameters, with a plateau scheduler watching ``val_loss``."""
        trainable = [param for param in self.parameters() if param.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=1e-03, eps=1e-08)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.3, patience=5, verbose=True
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}

    def train_dataloader(self):
        """Return the training loader given at construction."""
        return self.train_dl

    def val_dataloader(self):
        """Return the validation loader given at construction."""
        return self.val_dl

    def test_dataloader(self):
        """Return the test loader given at construction."""
        return self.test_dl


In [None]:
# Build the module around the label map read from the training split and the
# three loaders; training metrics are computed every 10th batch.
segformer_finetuner = SegformerFinetuner(
    id2label=train_dataset.id2label,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    metrics_interval=10,
)

In [None]:
# Stop once val_loss has not improved by at least min_delta for 15 validation runs.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.0005,
    patience=15,
    verbose=False,
    mode="min",
)

# Keep the single best checkpoint (lowest val_loss) plus the most recent one.
# The fixed filename makes the best checkpoint land in "best.ckpt", which is the
# name the testing cell loads; without it the file is named after the epoch and
# step and that lookup fails.
checkpoint_callback = ModelCheckpoint(
    filename="best",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    save_last=True,
)

In [None]:
# Checkpoint left behind by a previous run, if there was one.
last_ckpt = f"{save_dir}/version_{version}/checkpoints/last.ckpt"

trainer = pl.Trainer(
    gpus=1,
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=100,
    # An integer interval equal to the number of training batches.
    val_check_interval=len(train_dataloader),
    # Resume only when the checkpoint file is actually there; on a first run it
    # does not exist yet and training has to start from the pretrained weights.
    resume_from_checkpoint=last_ckpt if os.path.isfile(last_ckpt) else None,
)

# Training

In [None]:
# The module supplies its own dataloaders through its *_dataloader hooks.
trainer.fit(model=segformer_finetuner)

# Save checkpoints

In [None]:
destination_dir = save_dir

# Check both ends BEFORE deleting anything: the destination is wiped prior to
# the copy, so a missing or unset source would otherwise destroy the previous
# backup and then fail without having written a new one.
if not source_dir or not os.path.isdir(source_dir):
    raise FileNotFoundError(f"Source directory '{source_dir}' does not exist; nothing was copied.")
if not destination_dir:
    raise ValueError("SAVE_DIRECTORY is not set; refusing to copy checkpoints.")

# copytree requires that the destination does not exist yet, so replace it.
if os.path.exists(destination_dir):
    shutil.rmtree(destination_dir)

shutil.copytree(source_dir, destination_dir)
print(f"Copied '{source_dir}' to '{destination_dir}' successfully!")

# Testing

In [None]:
# Checkpoint with the best validation loss, as copied into the save directory.
best_ckpt = f"{save_dir}/version_{version}/checkpoints/best.ckpt"

# The constructor arguments are passed again as keyword arguments alongside the
# checkpoint path.
model = SegformerFinetuner.load_from_checkpoint(
    best_ckpt,
    id2label=train_dataset.id2label,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    metrics_interval=10,
)

# A fresh trainer without callbacks is enough for evaluation.
trainer = pl.Trainer(gpus=1)
trainer.test(model=model, dataloaders=test_dataloader)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b3-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([6, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([6]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[0.9740799  0.79912776 0.27521162 0.30796264 0.03126659 0.40513434]


[{'test_loss': 0.1446959227323532,
  'test_mean_iou': 0.4654638073219914,
  'test_mean_accuracy': 0.5417490271867611}]