In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git datasets
!pip install -q evaluate
!pip install accelerate

In [None]:
# Print the installed versions of the main dependencies (useful when reporting results).
!pip show torch
!pip show torchvision
!pip show datasets
!pip show albumentations

In [None]:
# transformers is installed from source above, so record which version was picked up.
!pip show transformers

## Load dataset

In [None]:
# Mount Google Drive (Colab only) so the dataset archive can be read from it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract the dataset archive from the mounted Drive into the working directory
# (-o: overwrite existing files without prompting). Produces data_S1/train and data_S1/test.
!unzip -o drive/MyDrive/KCDH/data_S1.zip



In [None]:

from datasets import Dataset, DatasetDict
from PIL import Image
import os

# Function to load images and labels from directory
def load_data(directory):
    """Load paired images and segmentation masks from ``directory``.

    Expects two sub-folders, ``images/`` and ``labels/``. Files are paired by
    their position after sorting each folder by name, so both folders must
    hold the same number of files with corresponding names.

    Args:
        directory: path containing ``images/`` and ``labels/``.

    Returns:
        dict with keys ``"image"`` and ``"label"``, each a list of PIL images
        (as returned by ``Image.open``) in matching order.

    Raises:
        ValueError: if the two folders contain a different number of files.
            Previously ``zip`` silently truncated to the shorter list, which
            mis-paired or dropped samples without any warning.
    """
    image_dir = os.path.join(directory, "images")
    label_dir = os.path.join(directory, "labels")

    def _visible_files(folder):
        # Only regular, non-hidden files: entries such as `.DS_Store` or
        # `.ipynb_checkpoints/` are not images and would break Image.open.
        return sorted(
            name for name in os.listdir(folder)
            if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
        )

    image_files = _visible_files(image_dir)
    label_files = _visible_files(label_dir)
    if len(image_files) != len(label_files):
        raise ValueError(
            f"{directory}: found {len(image_files)} images but {len(label_files)} labels; "
            "every image needs exactly one label file."
        )

    images = []
    labels = []
    for image_file, label_file in zip(image_files, label_files):
        # NOTE: Image.open is lazy; each image keeps its file handle open until
        # its pixel data is actually read.
        images.append(Image.open(os.path.join(image_dir, image_file)))
        labels.append(Image.open(os.path.join(label_dir, label_file)))
    return {"image": images, "label": labels}

# Read both splits from disk. There is no separate validation folder, so the
# test split plays the role of validation throughout this notebook.
train_data = load_data("data_S1/train")
validation_data = load_data("data_S1/test")

# Wrap the two splits as 🤗 Datasets inside one DatasetDict.
dataset_dict = DatasetDict(
    train=Dataset.from_dict(train_data),
    validation=Dataset.from_dict(validation_data),
)

Let's take a look at the dataset in more detail. It has a train and validation split:

In [None]:
# Show the splits, their features and row counts.
dataset_dict

In [None]:
# A single raw example: a dict with an "image" and a "label" (mask) entry.
dataset_dict['train'][1]

Let's take a look at one training example (the one at index 1):

In [None]:
# Keep one training example (index 1) around for the visualisations below.
example = dataset_dict["train"][1]
image = example["image"]
# image

In [None]:
# The matching mask: every pixel holds a class id.
segmentation_map = example["label"]
# segmentation_map

In [None]:
# Multiply the class ids by 25 so they show up as distinguishable grey levels
# (assuming ids stay in 0-10, as in id2label below, the values stay <= 250).
enhanced_segmentation_map = segmentation_map.point(lambda p: p*25)
# enhanced_segmentation_map

In case of semantic segmentation, every pixel is labeled with a certain class. 0 is the "background" class.

In [None]:
import numpy as np

# Convert the mask to a numpy array of per-pixel class ids.
segmentation_map = np.array(segmentation_map)
segmentation_map

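Let's define the mapping between integers and their classes. For this dataset the names are placeholders: 0 is the background and 1–10 are the foreground classes. (The original tutorial took its mapping from the FoodSeg103 [dataset card](https://huggingface.co/datasets/EduardoPacheco/FoodSeg103#data-categories) and asked an LLM to turn it into a dictionary; that mapping does not apply here.)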

In [None]:
# Class-id -> name mapping: 0 is the background, 1-10 get placeholder names
# for the ten foreground classes.
id2label = {0: "id_bg"}
id2label.update({class_id: f"id_{class_id}" for class_id in range(1, 11)})

In [None]:
# Quick check of the class mapping.
print(id2label)

We can visualize the segmentation map on top of the image, like so:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Give every class a distinct colour sampled evenly from a matplotlib colormap.
# `plt.get_cmap` is used because `plt.cm.get_cmap` was deprecated in
# Matplotlib 3.7 and removed in 3.9.
cmap = plt.get_cmap('viridis')  # Choose any colormap you prefer
num_classes = len(id2label)  # one colour per class, background included (was hard-coded to 11)
colors = [cmap(i / num_classes) for i in range(num_classes)]
# Colormap entries are RGBA floats in [0, 1]; keep RGB and scale to 0-255 ints.
id2color = {i: [int(c * 255) for c in color[:3]] for i, color in enumerate(colors)}
id2color[0] = [0, 0, 0]  # Setting background as black color
id2color

In [None]:
# display the colors for each class with their class IDs
plt.figure(figsize=(10, 2))
for i, (class_id, color) in enumerate(id2color.items()):
    plt.subplot(1, num_classes, i + 1)
    plt.imshow([[color]])
    plt.title(f"class {class_id}")
    plt.axis("off")
plt.show()

In [None]:
def visualize_map(image, segmentation_map):
    """Overlay a colour-coded segmentation map on an image and display it.

    Colours come from the module-level ``id2color`` table.

    Args:
        image: PIL image or (H, W, 3) / (H, W) uint8 array.
        segmentation_map: (H, W) class-id map (numpy array or torch tensor on
            any device) with the same height/width as ``image``.
    """
    if hasattr(segmentation_map, "detach"):  # torch tensor, possibly on the GPU
        segmentation_map = segmentation_map.detach().cpu().numpy()
    segmentation_map = np.asarray(segmentation_map)

    color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
    for label, color in id2color.items():
        color_seg[segmentation_map == label, :] = color

    # Show image + half-strength mask. Add in float and clip before the uint8
    # cast: e.g. 255 + 0.5 * 255 would otherwise overflow and wrap around.
    img = np.array(image, dtype=np.float64)
    if img.ndim == 2:  # greyscale -> 3 channels so it can be added to color_seg
        img = np.stack([img] * 3, axis=-1)
    img = np.clip(img + color_seg * 0.5, 0, 255).astype(np.uint8)

    plt.figure(figsize=(15, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

# Overlay the example's ground-truth mask on its image.
visualize_map(image, segmentation_map)

## Create PyTorch dataset

In [None]:

import numpy as np
import torch
# Import torch's Dataset under an alias: binding it to the bare name `Dataset`
# would shadow the 🤗 `datasets.Dataset` imported earlier, so re-running the
# dataset-construction cell (Dataset.from_dict) after this one would fail.
from torch.utils.data import Dataset as TorchDataset


class SegmentationDataset(TorchDataset):
    """PyTorch view over a 🤗 split of {"image", "label"} examples.

    Each item is transformed with an albumentations-style callable
    ``transform(image=..., mask=...) -> {"image": ..., "mask": ...}``.

    Returns, per item:
        image: float tensor (C, H, W) built from the transformed image.
        target: int64 tensor (H, W) built from the transformed mask.
        original_image: untransformed image as a numpy array (H, W, 3).
        original_segmentation_map: untransformed mask as a numpy array.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        original_image = np.array(item["image"])
        original_segmentation_map = np.array(item["label"])
        if original_image.ndim == 2:
            # Greyscale image: replicate to 3 channels so normalisation with
            # per-channel statistics and the (H, W, C) -> (C, H, W) permute work.
            original_image = np.stack([original_image] * 3, axis=-1)

        transformed = self.transform(image=original_image, mask=original_segmentation_map)
        # convert to C, H, W
        image = torch.as_tensor(transformed["image"]).permute(2, 0, 1)
        target = torch.as_tensor(transformed["mask"], dtype=torch.long)

        return image, target, original_image, original_segmentation_map

Let's create the training and validation datasets. Both splits are resized to 448×448 and normalized; random augmentation (a horizontal flip) is applied only to the training images.

In [None]:
import albumentations as A

# Per-channel normalisation statistics (0-255 values rescaled to 0-1). Despite
# the ADE_ prefix these are the usual ImageNet mean/std.
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

# Square model input size. Keep it a multiple of the backbone patch size
# (presumably 14 for facebook/dinov2-base — check `model.config.patch_size`)
# so the image splits into whole patches: 448 / 14 = 32 patches per side.
IMAGE_SIZE = 448

# is_check_shapes=False: image and mask are allowed to differ in size on disk;
# both are resized to IMAGE_SIZE. NOTE(review): this also hides genuinely
# mis-paired image/mask files.
train_transform = A.Compose(
    [
        # Random crops were tried, but some images were too small to crop and
        # PadIfNeeded did not fix it, so a plain resize is used for now:
        # A.PadIfNeeded(min_height=IMAGE_SIZE, min_width=IMAGE_SIZE),
        # A.RandomResizedCrop(height=IMAGE_SIZE, width=IMAGE_SIZE),
        A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ],
    is_check_shapes=False,
)

# Validation: same resize + normalisation, no random augmentation.
val_transform = A.Compose(
    [
        A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ],
    is_check_shapes=False,
)

train_dataset = SegmentationDataset(dataset_dict["train"], transform=train_transform)
val_dataset = SegmentationDataset(dataset_dict["validation"], transform=val_transform)

In [None]:
# Sanity check one transformed item: pixel_values should be (3, 448, 448)
# (channels first) and target (448, 448), given the Resize in the transforms.
pixel_values, target, original_image, original_segmentation_map = train_dataset[3]
print(pixel_values.shape)
print(target.shape)

In [None]:
# The untransformed image of that item.
Image.fromarray(original_image)

In [None]:
# Names of the classes present in that item's mask.
[id2label[id] for id in np.unique(original_segmentation_map).tolist()]

## Create PyTorch dataloaders

Next, we create PyTorch dataloaders, which allow us to get batches of data (as neural networks are trained on batches using stochastic gradient descent or SGD). We just stack the various images and labels along a new batch dimension.

In [None]:
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

# Initialize the Accelerator. In this notebook it is only used to wrap the
# dataloaders below (accelerator.prepare), e.g. so batches land on the active device.
accelerator = Accelerator()

# Define your collate function
def collate_fn(inputs):
    """Merge SegmentationDataset items into one batch dict.

    Tensors (positions 0 and 1 of each item) are stacked along a new batch
    dimension; the untransformed image/mask (positions 2 and 3) stay as plain
    lists because their sizes can differ between items.
    """
    def column(position):
        return [sample[position] for sample in inputs]

    return {
        "pixel_values": torch.stack(column(0), dim=0),
        "labels": torch.stack(column(1), dim=0),
        "original_images": column(2),
        "original_segmentation_maps": column(3),
    }

# Build both dataloaders (only the training one is shuffled) and let Accelerate
# wrap them in a single call; prepare() returns them in the order given.
train_dataloader, val_dataloader = accelerator.prepare(
    DataLoader(train_dataset, batch_size=3, shuffle=True, collate_fn=collate_fn),
    DataLoader(val_dataset, batch_size=3, shuffle=False, collate_fn=collate_fn),
)


Let's check a batch:

In [None]:
# Print the shape of every tensor in one training batch.
batch = next(iter(train_dataloader))
for key, value in batch.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(key, value.shape)

In [None]:
# Same check for a validation batch (this `batch` is reused by later cells).
batch = next(iter(val_dataloader))
for key, value in batch.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(key, value.shape)

Note that the pixel values are float32 tensors, whereas the labels are long tensors:

In [None]:
# Expected: torch.float32 (output of A.Normalize).
batch["pixel_values"].dtype

In [None]:
# Expected: torch.int64, as required by CrossEntropyLoss targets.
batch["labels"].dtype

In [None]:
import numpy as np
from PIL import Image

from IPython.display import display

# Take the first image of the batch back to the CPU so numpy can work with it.
pixel_values_cpu = batch["pixel_values"][0].cpu()

# Undo A.Normalize (x * std + mean, per channel, in (C, H, W) layout).
unnormalized_image = (pixel_values_cpu.numpy() * np.array(ADE_STD)[:, None, None]) + np.array(ADE_MEAN)[:, None, None]
# Round and clip before the uint8 cast so float error cannot truncate or wrap values.
unnormalized_image = np.clip(np.rint(unnormalized_image * 255), 0, 255).astype(np.uint8)
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)  # (C, H, W) -> (H, W, C)
unnormalized_image = Image.fromarray(unnormalized_image)
# Image.show() opens an external viewer and does not render inside the
# notebook; display() renders the image inline.
display(unnormalized_image)


In [None]:

 # Names of the classes present in the first label map of the batch.
 [id2label[id] for id in torch.unique(batch["labels"][0]).tolist()]

In [None]:
# Ensure the labels tensor is on the CPU
labels_cpu = batch["labels"][0].cpu().numpy()

# Overlay the first (resized) label map of the batch on the un-normalised image from the previous cell.
visualize_map(unnormalized_image, labels_cpu)


## Define model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput



class DeconvDecoder(nn.Module):
    """Turn a sequence of patch embeddings into per-pixel class logits.

    Three stride-2 transposed convolutions, so the output has 8x the spatial
    resolution of the patch grid in each dimension.
    """

    def __init__(self, in_channels, out_channels):
        super(DeconvDecoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1)
        )

    def forward(self, embeddings, height=None, width=None):
        """Decode patch tokens.

        Args:
            embeddings: (B, N, C) patch tokens (CLS token already removed).
            height, width: optional patch-grid size. When both are omitted the
                grid is assumed to be square; when only one is given the other
                is derived from N.

        Returns:
            (B, out_channels, 8 * height, 8 * width) logits.

        Raises:
            ValueError: if the grid size does not account for exactly N patches.
        """
        B, N, C = embeddings.shape
        if height is None and width is None:
            # round() guards against float error in the square root for large N.
            height = width = int(round(N ** 0.5))
            if height * width != N:
                raise ValueError("The number of patches is not a perfect square")
        else:
            if height is None:
                height = N // width
            elif width is None:
                width = N // height
            if height * width != N:
                raise ValueError(f"A {height}x{width} patch grid does not match {N} patches")
        # (B, N, C) -> (B, C, N) -> (B, C, height, width): tokens are row-major over the grid.
        embeddings = embeddings.permute(0, 2, 1).reshape(B, C, height, width)
        return self.decoder(embeddings)

class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
    """DINOv2 backbone with a DeconvDecoder head for semantic segmentation.

    Args:
        config: Dinov2 configuration used to build the backbone.
        num_labels: number of output channels (classes) of the decoder.
            NOTE(review): a `num_labels=` passed to `from_pretrained` is
            presumably absorbed into `config` rather than forwarded here, so
            this default (11) is what sizes the decoder unless passed
            explicitly — keep it in sync with id2label.
        ignore_index: label value excluded from the loss. The default (0)
            means class 0 (background in this notebook) is never supervised;
            pass e.g. -100 to train on every class.
        **kwargs: accepted and ignored.
    """

    def __init__(self, config, num_labels=11, ignore_index=0, **kwargs):
        super().__init__(config)
        self.dinov2 = Dinov2Model(config)
        self.decoder = DeconvDecoder(config.hidden_size, num_labels)
        self.ignore_index = ignore_index
        # Standard 🤗 hook that every PreTrainedModel subclass should call at
        # the end of __init__: runs weight initialisation and final setup.
        self.post_init()

    def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
        """Segment a batch of images.

        Args:
            pixel_values: (B, 3, H, W) normalised images; H and W should give
                a square patch grid (DeconvDecoder assumes one).
            output_hidden_states, output_attentions: forwarded to the backbone.
            labels: optional (B, H, W) int64 class ids; enables the loss.

        Returns:
            SemanticSegmenterOutput whose logits are (B, num_labels, H, W),
            i.e. already resized to the input resolution.
        """
        outputs = self.dinov2(pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions)
        # Token 0 is the CLS token; the rest are the patch tokens.
        patch_embeddings = outputs.last_hidden_state[:, 1:, :]
        logits = self.decoder(patch_embeddings)

        # The decoder output is 8x the patch grid; resize to the input size so
        # logits and labels line up pixel for pixel.
        logits = torch.nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

        loss = None
        if labels is not None:
            assert torch.max(labels) < logits.shape[1], f"Max label {torch.max(labels)} is greater than the number of classes {logits.shape[1]}"
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_index)
            loss = loss_fct(logits, labels)

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )






# Example usage (this is exactly what the next cell does):
# model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))





We can instantiate the model as follows:

In [None]:
# Load the pre-trained DINOv2 backbone weights; the decoder head is presumably
# not in the checkpoint and therefore starts freshly initialised.
# NOTE(review): `id2label` / `num_labels` here likely end up in the config; the
# decoder width comes from the class's own `num_labels` default (11) — keep the two in sync.
model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))


Important: we don't want to train the DINOv2 backbone, only the deconvolution decoder head. Hence we don't track any gradients for the backbone parameters, which greatly reduces the memory used:

In [None]:
# Freeze the DINOv2 backbone: only the decoder head keeps requires_grad=True.
for param in model.dinov2.parameters():
    param.requires_grad_(False)

Let's perform a forward pass on a batch (the last one fetched above, from the validation dataloader) to verify the shape of the logits and check that we can compute a loss:

In [None]:
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the same device as the input tensors
model.to(device)

# Ensure that the input tensors are on the same device as the model
# (`batch` is the last batch fetched above, i.e. from val_dataloader).
pixel_values = batch["pixel_values"].to(device)
labels = batch["labels"].to(device)

# Pass the tensors to the model; because labels are given, the output also carries the loss.
outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.logits.shape)  # (batch_size, num_labels, height, width)
print(outputs.loss)


In [None]:
# transformers is already installed (from source) in the first cell. Upgrading
# it here, after the library has been imported and the model loaded, would only
# take effect after a kernel restart, so the upgrade is intentionally not run:
# %pip install -q --upgrade transformers

As can be seen, the logits are of shape (batch_size, num_labels, height, width) and are already upsampled to the input resolution. We can then just take the highest logit (score) for each pixel as the model's prediction.

## Train the model

We'll train the model in regular PyTorch fashion. We also use the mIoU (mean Intersection-over-Union) metric to evaluate the performance during training.

Note that I made this entire notebook just for demo purposes, I haven't done any hyperparameter tuning, so feel free to improve. You can also of course use other training frameworks (like the 🤗 Trainer, PyTorch Lightning, 🤗 Accelerate, ...).

In [None]:
import evaluate
# NOTE(review): `metric` is not referenced anywhere below; the training cell
# reports its metrics through `compute_metrics` instead.
metric = evaluate.load("mean_iou")

In [None]:

import torch
from torch.optim import AdamW
from tqdm.auto import tqdm

# The model classes are defined in the "Define model" section above.
# Do not (re)install torch here: torch is already imported in this kernel, so a
# reinstall only takes effect after a restart, and forcing an old CUDA build
# (e.g. cu116) can leave torch and the Colab GPU driver mismatched.

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model.to(device)

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm
# torch is already installed and imported at this point; reinstalling it from
# inside the running kernel would only take effect after a restart, so it is
# not done here.

# training hyperparameters: the learning rate and weight decay are searched with
# Optuna in the next cells (the number of epochs is fixed inside the objective);
# see also the DINOv2 paper for reference values.

# put model on GPU (set runtime to GPU in Google Colab)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
model.to(device)

In [None]:
pip install optuna


In [None]:
import numpy as np
import optuna
import torch
import torch.optim as optim
# transformers.AdamW is deprecated (and gone from recent releases, which the
# from-source install above picks up); use PyTorch's implementation.
from torch.optim import AdamW
from tqdm.auto import tqdm

warmup_steps = 800  # length of the linear LR warm-up, in optimiser steps

# Define a cosine learning rate scheduler
class WarmupCosineSchedule(optim.lr_scheduler.LambdaLR):
    """Linear warm-up followed by cosine decay, expressed as an LR multiplier.

    `min_lr` and `max_lr` are multipliers of the optimiser's base learning rate,
    not absolute learning rates. The multiplier rises linearly from 0 to
    `max_lr` over `warmup_steps`, follows a half cosine from `max_lr` down to
    `min_lr` at `total_steps`, and stays at `min_lr` afterwards.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr, max_lr):
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # Ramp up to the cosine's starting value; ramping to 1.0 would make
                # the multiplier jump from ~1.0 down to max_lr at the end of warm-up.
                return max_lr * float(current_step) / float(warmup_steps)
            progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            # Clamp so the cosine does not start rising again past total_steps.
            progress = min(progress, 1.0)
            return min_lr + 0.5 * (max_lr - min_lr) * (1. + np.cos(np.pi * progress))
        super().__init__(optimizer, lr_lambda)

# Define the momentum scheduler
class MomentumSchedule:
    """Cosine schedule that rises from `start_momentum` to `end_momentum`.

    `get_momentum(0)` returns `start_momentum`, `get_momentum(total_steps)`
    (and any later step) returns `end_momentum`.
    NOTE: nothing in this notebook feeds the value into an optimiser (AdamW has
    no momentum parameter); it is kept for experimentation.
    """

    def __init__(self, start_momentum, end_momentum, total_steps):
        self.start_momentum = start_momentum
        self.end_momentum = end_momentum
        self.total_steps = total_steps

    def get_momentum(self, step):
        # max(1, ...) avoids dividing by zero; min(..., 1) holds the final value.
        progress = min(float(step) / float(max(1, self.total_steps)), 1.0)
        # 0.5 * (1 + cos(pi * progress)) goes 1 -> 0, so this moves from start to
        # end (the previous form went the other way, from end to start).
        return self.end_momentum - (self.end_momentum - self.start_momentum) * 0.5 * (1. + np.cos(np.pi * progress))

# Run hyperparameter optimization with Optuna

# Length of the LR schedule in optimiser steps. Module-level so the schedule
# rebuilt after the study uses the same value as every trial.
total_steps = 4000  # Reduced from 625000

# Weights before any tuning. The same `model` object is shared by all trials,
# so each trial restarts from this copy (otherwise trial N would continue
# training trial N-1's weights). Kept on the CPU to avoid doubling GPU memory.
initial_state_dict = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}

# Best epoch over *all* trials: its mIoU, a CPU copy of its weights, and the trial number.
best_run = {"mean_iou": -1.0, "state_dict": None, "trial": None}


def compute_metrics(predictions, true_labels, num_labels, ignore_index=None):
    """Pixel-level segmentation metrics from a confusion matrix.

    Args:
        predictions: iterable of predicted class-id maps (arrays, any shape).
        true_labels: iterable of ground-truth maps, same shapes as predictions.
        num_labels: number of classes; ids outside [0, num_labels) are skipped.
        ignore_index: optional ground-truth id to leave out of every metric.

    Returns:
        dict with
        - per_class_iou_list / per_class_accuracy_list: one float per class
          (NaN when the class never occurs, so it has no defined score);
        - mean_iou / mean_accuracy: means over the classes that occur;
        - binary_iou / binary_accuracy: foreground (id != 0) vs background.
    """
    confusion = np.zeros((num_labels, num_labels), dtype=np.int64)  # rows: truth, cols: prediction
    for pred, true in zip(predictions, true_labels):
        pred = np.asarray(pred, dtype=np.int64).ravel()
        true = np.asarray(true, dtype=np.int64).ravel()
        valid = (true >= 0) & (true < num_labels) & (pred >= 0) & (pred < num_labels)
        if ignore_index is not None:
            valid &= true != ignore_index
        confusion += np.bincount(
            num_labels * true[valid] + pred[valid], minlength=num_labels * num_labels
        ).reshape(num_labels, num_labels)

    true_positive = np.diag(confusion).astype(np.float64)
    label_area = confusion.sum(axis=1)
    pred_area = confusion.sum(axis=0)
    union = label_area + pred_area - true_positive
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class_iou = np.where(union > 0, true_positive / union, np.nan)
        per_class_accuracy = np.where(label_area > 0, true_positive / label_area, np.nan)

    # Foreground = any class other than 0, regardless of which one.
    fg_true_positive = confusion[1:, 1:].sum()
    fg_union = confusion[1:, :].sum() + confusion[:, 1:].sum() - fg_true_positive
    total = confusion.sum()

    def _mean_over_present(values):
        present = values[~np.isnan(values)]
        return float(present.mean()) if present.size else 0.0

    return {
        "mean_iou": _mean_over_present(per_class_iou),
        "mean_accuracy": _mean_over_present(per_class_accuracy),
        "per_class_iou_list": per_class_iou.tolist(),
        "per_class_accuracy_list": per_class_accuracy.tolist(),
        "binary_iou": float(fg_true_positive / fg_union) if fg_union > 0 else 0.0,
        "binary_accuracy": float((confusion[0, 0] + fg_true_positive) / total) if total > 0 else 0.0,
    }


def _train_one_epoch(model, dataloader, optimizer, lr_scheduler):
    """One pass over `dataloader` in train mode; returns the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch in tqdm(dataloader):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # the schedule is defined per optimiser step, not per epoch

        total_loss += loss.item()
    return total_loss / max(1, len(dataloader))


def _evaluate(model, dataloader):
    """Predict every batch in eval mode; returns (predicted maps, true maps) as numpy arrays."""
    model.eval()
    predictions = []
    true_labels = []
    with torch.no_grad():
        for val_batch in dataloader:
            val_pixel_values = val_batch["pixel_values"].to(device)
            val_labels = val_batch["labels"].to(device)
            val_outputs = model(val_pixel_values)
            predictions.extend(torch.argmax(val_outputs.logits, dim=1).cpu().numpy())
            true_labels.extend(val_labels.cpu().numpy())
    return predictions, true_labels


def objective(trial):
    """Train from the initial weights with sampled hyper-parameters; return the best validation mIoU."""
    learning_rate = trial.suggest_float("learning_rate", 0.00887, 0.1, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)

    model.load_state_dict(initial_state_dict)  # independent trials

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = WarmupCosineSchedule(optimizer, warmup_steps, total_steps, min_lr=0.04, max_lr=0.2)

    epochs = 8  # Further reduced the number of epochs for faster tuning
    early_stopping_patience = 2  # Reduced the patience for early stopping for faster tuning
    best_mean_iou = -1
    early_stopping_counter = 0

    for epoch in range(epochs):
        # Train first, then evaluate, so the reported metrics describe this epoch's weights.
        total_loss = _train_one_epoch(model, train_dataloader, optimizer, lr_scheduler)
        predictions, true_labels = _evaluate(model, val_dataloader)

        metrics = compute_metrics(predictions, true_labels, num_labels=len(id2color))
        binary_iou = metrics["binary_iou"]
        binary_accuracy = metrics["binary_accuracy"]
        mean_iou = metrics["mean_iou"]
        mean_accuracy = metrics["mean_accuracy"]
        per_class_iou_list = metrics["per_class_iou_list"]
        per_class_accuracy_list = metrics["per_class_accuracy_list"]

        print(f"Epoch: {epoch+1}, Loss: {total_loss:.4f}, Mean IoU (Multi-class): {mean_iou:.4f}, Accuracy (Multi-class): {mean_accuracy:.4f}, Mean IoU (Binary): {binary_iou:.4f}, Accuracy (Binary): {binary_accuracy:.4f}")
        # Skip class 0 (background) in BOTH lists so IoU and accuracy refer to the same class.
        for cls, (iou, accuracy) in enumerate(zip(per_class_iou_list[1:], per_class_accuracy_list[1:]), start=1):
            print(f"Class {cls}: IoU = {iou:.4f}, Accuracy = {accuracy:.4f}")

        trial.report(mean_iou, epoch)

        if mean_iou > best_mean_iou:
            best_mean_iou = mean_iou
            early_stopping_counter = 0
            if mean_iou > best_run["mean_iou"]:
                # Copy the tensors: state_dict() returns live references that
                # later optimiser steps would keep modifying.
                best_run["mean_iou"] = mean_iou
                best_run["trial"] = trial.number
                best_run["state_dict"] = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break

    return best_mean_iou


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=3)  # Adjust the number of trials as needed

print("Best trial:")
print(study.best_trial)

# Get best hyperparameters
best_params = study.best_params
best_learning_rate = best_params["learning_rate"]
best_weight_decay = best_params["weight_decay"]

# Load the weights of the best epoch of the best trial.
best_model_state_dict = best_run["state_dict"]
if best_model_state_dict is not None:
    model.load_state_dict(best_model_state_dict)

# Optimiser and schedules configured with the best hyper-parameters, for anyone
# who wants to continue training from here. No further training happens in this cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=best_learning_rate, weight_decay=best_weight_decay)
lr_scheduler = WarmupCosineSchedule(optimizer, warmup_steps, total_steps, min_lr=0.04, max_lr=0.2)
momentum_scheduler = MomentumSchedule(0.994, 1.0, total_steps)  # not consumed by AdamW




## Inference

Once we've trained a model, we can perform inference on new images as follows:

In [None]:
from PIL import Image
# Pick one validation image to run inference on.
test_image = dataset_dict["validation"][6]["image"]
test_image

In [None]:
# Preprocess exactly like the validation set, move channels first and add a
# batch dimension: (H, W, C) -> (1, C, H, W).
pixel_values = torch.tensor(val_transform(image=np.array(test_image))["image"]).permute(2, 0, 1)[None]
print(pixel_values.shape)

In [None]:
# forward pass. The training cells leave the model in train mode, so switch to
# eval mode explicitly (disables train-only behaviour such as dropout).
model.eval()
with torch.no_grad():
    outputs = model(pixel_values.to(device))

In [None]:
# The logits are at the resized network-input resolution; resize them to the
# original image (PIL .size is (W, H), hence the reversal to (H, W)).
upsampled_logits = torch.nn.functional.interpolate(outputs.logits,
                                                   size=test_image.size[::-1],
                                                   mode="bilinear", align_corners=False)
# Per-pixel argmax, moved to the CPU so the numpy-based plotting can index with it
# (a CUDA tensor cannot be converted to a numpy mask).
predicted_map = upsampled_logits.argmax(dim=1).cpu()

In [None]:
def visualize_map(image, segmentation_map, id2color):
    """Overlay a colour-coded segmentation map on an image and display it.

    NOTE: this redefines the earlier two-argument `visualize_map`; from here on
    the colour table must be passed explicitly.

    Args:
        image: PIL image or (H, W, 3) / (H, W) uint8 array.
        segmentation_map: (H, W) class-id map as a numpy array or a torch
            tensor on any device. A flattened map of side*side values is
            reshaped to a square.
        id2color: dict mapping class id -> [R, G, B] (0-255).

    Raises:
        ValueError: if the map is not 2D (after the optional reshape) or a
            flattened map does not have a square number of entries.
    """
    if hasattr(segmentation_map, "detach"):  # torch tensor, possibly on the GPU
        segmentation_map = segmentation_map.detach().cpu().numpy()
    segmentation_map = np.asarray(segmentation_map)

    # Reshape segmentation_map to 2D if needed
    if segmentation_map.ndim == 1:
        side_length = int(round(np.sqrt(segmentation_map.shape[0])))
        if side_length * side_length != segmentation_map.shape[0]:
            raise ValueError("Flattened segmentation map does not have a square number of pixels.")
        segmentation_map = segmentation_map.reshape((side_length, side_length))

    # Ensure segmentation_map is a 2D array
    if segmentation_map.ndim != 2:
        raise ValueError("Segmentation map must be a 2D array.")

    color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
    for label, color in id2color.items():
        color_seg[segmentation_map == label, :] = color

    # Show image + half-strength mask. Add in float and clip before the uint8
    # cast: e.g. 255 + 0.5 * 255 would otherwise overflow and wrap around.
    img = np.array(image, dtype=np.float64)
    if img.ndim == 2:  # greyscale -> 3 channels so it can be added to color_seg
        img = np.stack([img] * 3, axis=-1)
    img = np.clip(img + color_seg * 0.5, 0, 255).astype(np.uint8)

    plt.figure(figsize=(15, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()







In [None]:
# Overlay the predicted class map (batch dimension squeezed away) on the original test image.
visualize_map(test_image, predicted_map.squeeze(), id2color)


In [None]:
# Pickle the whole nn.Module; loading it back needs the model classes importable.
# NOTE(review): model.state_dict() or model.save_pretrained() would be more portable.
torch.save(obj=model, f='model.pt')

In [None]:
# Same pickled model again under a .pth name. NOTE(review): duplicate of model.pt.
torch.save(obj=model, f='model.pth')

In [None]:
# Download the saved model from the Colab VM to the local machine.
from google.colab import files
files.download('model.pt')

In [None]:
# Download the .pth copy as well (same contents as model.pt).
from google.colab import files
files.download('model.pth')