In [None]:
# Install Hugging Face Transformers straight from its GitHub repository (default
# branch), together with `datasets` and `evaluate`.
# NOTE(review): none of these installs is pinned to a version or commit, so a
# fresh run may pick up different code than the run that produced the results —
# consider pinning.
!pip install -q git+https://github.com/huggingface/transformers.git datasets
!pip install -q evaluate

In [None]:
!pip show torch
!pip show torchvision
!pip show datasets
!pip show albumentations

In [None]:
!pip show transformers

## Load dataset

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip drive/MyDrive/KCDH/data_S1.zip



In [None]:
from datasets import Dataset, DatasetDict
from PIL import Image
import os

def load_data(directory):
    """Load paired images and segmentation masks from ``directory``.

    The function reads ``directory/images`` and ``directory/labels``. Both
    listings are sorted by file name and paired by position, so the two folders
    must contain the same number of files, with names that sort in the same
    order.

    Args:
        directory: Path of a split folder, e.g. ``"data_S1/train"``.

    Returns:
        A dict with the keys ``"image"`` and ``"label"``; each value is a list
        of PIL images returned by ``Image.open``, in matching order.

    Raises:
        ValueError: If the two folders contain a different number of files.
            (Previously the longer listing was silently truncated by ``zip``,
            which can hide missing files and misaligned image/mask pairs.)
    """
    image_dir = os.path.join(directory, "images")
    label_dir = os.path.join(directory, "labels")
    image_files = sorted(os.listdir(image_dir))
    label_files = sorted(os.listdir(label_dir))

    # Pairing is purely positional, so a count mismatch means at least one
    # image would be dropped or matched with the wrong mask.
    if len(image_files) != len(label_files):
        raise ValueError(
            f"Found {len(image_files)} files in {image_dir!r} but "
            f"{len(label_files)} files in {label_dir!r}; "
            "every image needs exactly one label file."
        )

    images = [Image.open(os.path.join(image_dir, name)) for name in image_files]
    labels = [Image.open(os.path.join(label_dir, name)) for name in label_files]
    return {"image": images, "label": labels}

# Read both splits from disk; the "test" folder serves as the validation split.
train_data = load_data("data_S1/train")
validation_data = load_data("data_S1/test")

# Wrap each split in a Hugging Face Dataset and collect them in a DatasetDict.
split_records = {"train": train_data, "validation": validation_data}
dataset_dict = DatasetDict(
    {split: Dataset.from_dict(records) for split, records in split_records.items()}
)

Let's take a look at the dataset in more detail. It has a train and validation split:

In [None]:
dataset_dict

In [None]:
dataset_dict['train'][1]

Let's take a look at one training example (the one at index 1):

In [None]:
example = dataset_dict["train"][1]
image = example["image"]
# image

In [None]:
segmentation_map = example["label"]
# segmentation_map

In [None]:
enhanced_segmentation_map = segmentation_map.point(lambda p: p*25)
# enhanced_segmentation_map

In case of semantic segmentation, every pixel is labeled with a certain class. 0 is the "background" class.

In [None]:
import numpy as np

segmentation_map = np.array(segmentation_map)
segmentation_map

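Let's define the mapping between integer ids and class names. The names below are generic placeholders for the 11 classes in this dataset (0 is the background), not the category names from the [FoodSeg103 dataset card](https://huggingface.co/datasets/EduardoPacheco/FoodSeg103#data-categories) that this text originally referred to.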

In [None]:
# Class id -> class name. Id 0 is the background entry; ids 1 to 10 get
# placeholder names of the form "id_<n>".
id2label = {0: "id_bg"}
id2label.update({class_id: f"id_{class_id}" for class_id in range(1, 11)})

In [None]:
print(id2label)

We can visualize the segmentation map on top of the image, like so:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# # map every class to a random color
# id2color = {k: list(np.random.choice(range(256), size=3)) for k,v in id2label.items()}
# id2color[0] = [0,0,0] # setting bg as black color
# id2color

# Map every class to a fixed colour sampled at evenly spaced positions of a
# colormap (the colours are deterministic, not random).
# `plt.get_cmap` is used because `plt.cm.get_cmap` is no longer available in
# recent Matplotlib releases.
cmap = plt.get_cmap('viridis')  # Choose any colormap you prefer
num_classes = len(id2label)  # one colour per entry of id2label
colors = [cmap(i / num_classes) for i in range(num_classes)]
# cmap returns RGBA floats in [0, 1]; keep RGB and rescale to 0-255 integers.
id2color = {i: [int(c * 255) for c in color[:3]] for i, color in enumerate(colors)}
id2color[0] = [0, 0, 0]  # Setting background as black color
id2color

In [None]:
# Draw one colour swatch per class in a single row, titled with its class id.
plt.figure(figsize=(10, 2))
for position, (class_id, color) in enumerate(id2color.items(), start=1):
    plt.subplot(1, num_classes, position)
    plt.imshow([[color]])
    plt.title(f"class {class_id}")
    plt.axis("off")
plt.show()

In [None]:
def visualize_map(image, segmentation_map):
    """Overlay a colour-coded segmentation map on an image and display it.

    Args:
        image: The picture to draw on; anything ``np.asarray`` can convert to an
            (height, width, 3) array, e.g. a PIL image.
        segmentation_map: Per-pixel class ids with the same height and width as
            ``image``. Anything ``np.asarray`` can convert is accepted, so a
            NumPy array or a CPU torch tensor both work.

    Each class id found in ``id2color`` is painted with its colour; the result
    is blended as ``image + 0.5 * colours`` and shown with Matplotlib.
    """
    image = np.asarray(image)
    segmentation_map = np.asarray(segmentation_map)

    color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3
    for label, color in id2color.items():
        color_seg[segmentation_map == label, :] = color

    # Blend in floating point and clip before casting: the sum can exceed 255,
    # and casting such values straight to uint8 wraps around, which shows up as
    # wrongly coloured pixels in bright regions.
    blended = image.astype(np.float32) + color_seg.astype(np.float32) * 0.5
    img = np.clip(blended, 0, 255).astype(np.uint8)

    plt.figure(figsize=(15, 10))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

visualize_map(image, segmentation_map)

## Create PyTorch dataset

In [None]:
from torch.utils.data import Dataset
import torch

class SegmentationDataset(Dataset):
    """PyTorch dataset that pairs transformed tensors with the untouched originals.

    Args:
        dataset: Indexable collection whose items provide ``"image"`` and
            ``"label"`` entries.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, target, original_image, original_segmentation_map)``.

        ``image`` is the transformed image as a channels-first tensor, ``target``
        the transformed mask as a long tensor; the last two entries are the
        NumPy arrays of the sample before any transform was applied.
        """
        sample = self.dataset[idx]
        original_image = np.array(sample["image"])
        original_segmentation_map = np.array(sample["label"])

        augmented = self.transform(image=original_image, mask=original_segmentation_map)

        # H, W, C -> C, H, W
        image = torch.tensor(augmented["image"]).permute(2, 0, 1)
        target = torch.LongTensor(augmented["mask"])

        return image, target, original_image, original_segmentation_map

Let's create the training and validation datasets (note that random horizontal flips are only applied to training images; both splits are resized to 448×448 and normalized).

In [None]:
import albumentations as A

# Per-channel normalisation statistics, rescaled from the 0-255 range to 0-1.
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255

# Training pipeline: fixed-size resize, random horizontal flip, normalisation.
# A PadIfNeeded(min_height=448, min_width=448) + RandomResizedCrop(height=448, width=448)
# pipeline was tried first but failed on an image that was too small to crop
# (cause unknown), so a plain resize is used instead.
train_transform = A.Compose(
    [
        A.Resize(height=448, width=448),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ],
    is_check_shapes=False,
)

# Validation pipeline: the same resize and normalisation, without the flip.
val_transform = A.Compose(
    [
        A.Resize(height=448, width=448),
        A.Normalize(mean=ADE_MEAN, std=ADE_STD),
    ],
    is_check_shapes=False,
)

train_dataset = SegmentationDataset(dataset_dict["train"], transform=train_transform)
val_dataset = SegmentationDataset(dataset_dict["validation"], transform=val_transform)

In [None]:
pixel_values, target, original_image, original_segmentation_map = train_dataset[3]
print(pixel_values.shape)
print(target.shape)

In [None]:
Image.fromarray(original_image)

In [None]:
[id2label[id] for id in np.unique(original_segmentation_map).tolist()]

## Create PyTorch dataloaders

Next, we create PyTorch dataloaders, which allow us to get batches of data (as neural networks are trained on batches using stochastic gradient descent or SGD). We just stack the various images and labels along a new batch dimension.

In [None]:
from torch.utils.data import DataLoader

def collate_fn(inputs):
    """Combine dataset samples into one batch dictionary.

    Args:
        inputs: Sequence of 4-tuples
            ``(image, target, original_image, original_segmentation_map)``.

    Returns:
        dict with ``"pixel_values"`` and ``"labels"`` stacked along a new first
        dimension, and ``"original_images"`` / ``"original_segmentation_maps"``
        kept as plain lists (they are not stacked).
    """
    return {
        "pixel_values": torch.stack([sample[0] for sample in inputs], dim=0),
        "labels": torch.stack([sample[1] for sample in inputs], dim=0),
        "original_images": [sample[2] for sample in inputs],
        "original_segmentation_maps": [sample[3] for sample in inputs],
    }

# Both loaders use batches of 3 and the custom collate function; only the
# training loader shuffles its samples.
loader_options = {"batch_size": 3, "collate_fn": collate_fn}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)

Let's check a batch:

In [None]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v,torch.Tensor):
    print(k,v.shape)

In [None]:
batch = next(iter(val_dataloader))
for k,v in batch.items():
  if isinstance(v,torch.Tensor):
    print(k,v.shape)

Note that the pixel values are float32 tensors, whereas the labels are long tensors:

In [None]:
batch["pixel_values"].dtype

In [None]:
batch["labels"].dtype

In [None]:
from PIL import Image

# Undo the normalisation of the first image in the batch (x * std + mean, per
# channel), rescale to 0-255 and move the channel axis last for PIL.
channel_std = np.array(ADE_STD)[:, None, None]
channel_mean = np.array(ADE_MEAN)[:, None, None]
unnormalized_image = batch["pixel_values"][0].numpy() * channel_std + channel_mean
unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
unnormalized_image = Image.fromarray(np.moveaxis(unnormalized_image, 0, -1))
unnormalized_image

In [None]:
[id2label[id] for id in torch.unique(batch["labels"][0]).tolist()]

In [None]:
visualize_map(unnormalized_image, batch["labels"][0].numpy())

## Define model

In [None]:
import torch
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

class LinearClassifier(torch.nn.Module):
    """Linear classification head applied to every patch token via a 1x1 convolution.

    Args:
        in_channels: Size of each token embedding.
        tokenW: Number of tokens per row of the token grid.
        tokenH: Number of tokens per column of the token grid.
        num_labels: Number of output channels (one logit map per label).
    """

    def __init__(self, in_channels, tokenW=32, tokenH=32, num_labels=1):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        self.classifier = torch.nn.Conv2d(in_channels, num_labels, kernel_size=1)

    def forward(self, embeddings):
        """Reshape the tokens into a (batch, channels, height, width) grid and classify them."""
        token_grid = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        channels_first = token_grid.permute(0, 3, 1, 2)
        return self.classifier(channels_first)


class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
  """DINOv2 backbone followed by a linear (1x1 convolution) segmentation head.

  The head is built for a fixed 32 x 32 grid of patch tokens, so inputs must
  produce exactly 1024 patch tokens (presumably 448 x 448 images with 14 x 14
  patches — TODO confirm against the checkpoint's patch size).
  """

  def __init__(self, config):
    super().__init__(config)

    self.dinov2 = Dinov2Model(config)
    self.classifier = LinearClassifier(config.hidden_size, 32, 32, config.num_labels)

    # Standard final step for Transformers models: runs weight initialisation
    # and the library's end-of-construction set-up.
    self.post_init()

  def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
    """Compute per-pixel logits and, when ``labels`` is given, the loss.

    Args:
      pixel_values: Batch of images, indexed as (batch, channels, height, width).
      output_hidden_states: Forwarded to the backbone.
      output_attentions: Forwarded to the backbone.
      labels: Optional per-pixel class ids with the same height and width as
        ``pixel_values``. Pixels labelled 0 do not contribute to the loss.

    Returns:
      ``SemanticSegmenterOutput`` with ``loss`` (``None`` without labels),
      ``logits`` upsampled to the input resolution, and the backbone's
      ``hidden_states`` / ``attentions``.
    """
    # NOTE: nothing in this method freezes the backbone; it is only frozen if
    # the caller disables gradients on its parameters.
    outputs = self.dinov2(pixel_values,
                            output_hidden_states=output_hidden_states,
                            output_attentions=output_attentions)
    # get the patch embeddings - so we exclude the CLS token
    patch_embeddings = outputs.last_hidden_state[:,1:,:]

    # convert to logits and upsample to the size of the pixel values
    logits = self.classifier(patch_embeddings)
    logits = torch.nn.functional.interpolate(logits, size=pixel_values.shape[2:], mode="bilinear", align_corners=False)

    loss = None
    if labels is not None:
      # important: 0 is used as ignore index instead of the default -100,
      # as we don't want the model to learn to predict background
      loss_fct = torch.nn.CrossEntropyLoss(ignore_index=0)
      loss = loss_fct(logits, labels)

    return SemanticSegmenterOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

We can instantiate the model as follows:

In [None]:
model = Dinov2ForSemanticSegmentation.from_pretrained("facebook/dinov2-base", id2label=id2label, num_labels=len(id2label))

Important: we don't want to train the DINOv2 backbone, only the linear classification head. Hence we don't want to track any gradients for the backbone parameters. This greatly reduces memory usage:

In [None]:
# Disable gradient tracking for every backbone weight; parameters outside
# `model.dinov2` are left untouched.
for backbone_param in model.dinov2.parameters():
  backbone_param.requires_grad = False

Let's perform a forward pass on a batch to verify the shape of the logits and to check that we can calculate a loss:

In [None]:
outputs = model(pixel_values=batch["pixel_values"], labels=batch["labels"])
print(outputs.logits.shape)
print(outputs.loss)

As can be seen, the logits are of shape (batch_size, num_labels, height, width). We can then just take the highest logit (score) for each pixel as the model's prediction.

## Train the model

We'll train the model in regular PyTorch fashion. We also use the mIoU (mean Intersection-over-Union) metric to evaluate the performance during training.

Note that I made this entire notebook just for demo purposes, I haven't done any hyperparameter tuning, so feel free to improve. You can also of course use other training frameworks (like the 🤗 Trainer, PyTorch Lightning, 🤗 Accelerate, ...).

In [None]:
import evaluate

metric = evaluate.load("mean_iou")

In [None]:
from torch.optim import AdamW
from tqdm.auto import tqdm

# Training hyperparameters.
# NOTE: these are untuned placeholder values — feel free to experiment, and see
# the DINOv2 paper for guidance.
learning_rate = 5e-3
#epochs = 20
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Run on the GPU when one is available (set the runtime to GPU in Google Colab).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)
model.to(device)

In [None]:
import copy

import torch
from torch.optim import AdamW
from hyperopt import hp, fmin, tpe, Trials, STATUS_OK, space_eval

# Define the search space
space = {
    'learning_rate': hp.choice('learning_rate', [1e-3, 5e-4, 1e-4]),
    'max_epochs': hp.choice('max_epochs', [10, 20, 30])
}

# Define the objective function
def objective(params):
    """Train a copy of the model with ``params`` and return its validation loss.

    Args:
        params: dict with ``'learning_rate'`` and ``'max_epochs'`` sampled from
            ``space``.

    Returns:
        dict with the mean validation loss under ``'loss'`` and ``STATUS_OK``
        under ``'status'``, as expected by ``fmin``.

    Every trial works on its own deep copy of the global ``model``, so trials
    start from identical weights and neither affect each other nor the model
    trained later in the notebook. Note the cost: each trial runs a full
    training of ``max_epochs`` epochs.
    """
    learning_rate = params['learning_rate']
    max_epochs = params['max_epochs']

    trial_model = copy.deepcopy(model).to(device)
    # Only optimise parameters that still require gradients.
    trial_optimizer = AdamW(
        [p for p in trial_model.parameters() if p.requires_grad],
        lr=learning_rate,
    )

    for _ in range(max_epochs):
        trial_model.train()
        for batch in train_dataloader:
            trial_optimizer.zero_grad()
            outputs = trial_model(batch["pixel_values"].to(device),
                                  labels=batch["labels"].to(device))
            outputs.loss.backward()
            trial_optimizer.step()

    # The validation loss is computed inline so that this cell does not depend
    # on a helper that is only defined in a later cell.
    trial_model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in val_dataloader:
            outputs = trial_model(batch["pixel_values"].to(device),
                                  labels=batch["labels"].to(device))
            total_loss += outputs.loss.item()
    val_loss = total_loss / len(val_dataloader)

    # Release the per-trial copy before the next trial allocates a new one.
    del trial_model, trial_optimizer

    return {'loss': val_loss, 'status': STATUS_OK}

# Initialize Trials object
trials = Trials()

# Perform Bayesian Optimization
best = fmin(fn=objective,
            space=space,
            algo=tpe.suggest,
            max_evals=10,  # Number of optimization iterations
            trials=trials)

# For hp.choice, `best` holds the *index* of the selected option; space_eval
# translates those indices back into the actual hyperparameter values.
best_params = space_eval(space, best)
best_learning_rate = best_params['learning_rate']
best_max_epochs = best_params['max_epochs']
print(f"Best Learning Rate: {best_learning_rate}, Best Max Epochs: {best_max_epochs}")
# NOTE(review): the training cell below uses the `optimizer` created earlier and
# does not pick up these values automatically.


In [None]:
from tqdm import tqdm

# Upper bound on the number of epochs; early stopping is expected to end
# training before this limit is reached.
max_epochs = 100

# Early stopping: stop once `patience` consecutive epochs fail to lower the
# validation loss by more than `min_delta`.
patience = 5
min_delta = 0.001
best_val_loss = float("inf")
epochs_no_improve = 0

def compute_validation_loss(model, val_dataloader, device):
    """Return the mean per-batch loss of ``model`` over ``val_dataloader``.

    Args:
        model: Callable as ``model(pixel_values, labels=labels)`` returning an
            object with a ``loss`` attribute.
        val_dataloader: Iterable of batches with ``"pixel_values"`` and
            ``"labels"`` tensors; must support ``len()``.
        device: Device the batch tensors are moved to.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()
    running_total = 0
    with torch.no_grad():
        for val_batch in val_dataloader:
            inputs = val_batch["pixel_values"].to(device)
            targets = val_batch["labels"].to(device)
            batch_loss = model(inputs, labels=targets).loss
            running_total += batch_loss.item()
    return running_total / len(val_dataloader)

# Copy of the trainable weights from the epoch with the best validation loss.
best_state = None

for epoch in range(max_epochs):
    # compute_validation_loss() leaves the model in eval mode, so training mode
    # has to be re-enabled at the start of every epoch, not just once.
    model.train()

    print("Epoch:", epoch)
    for idx, batch in enumerate(tqdm(train_dataloader)):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(pixel_values, labels=labels)
        loss = outputs.loss

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Evaluate
        with torch.no_grad():
            predicted = outputs.logits.argmax(dim=1)

            # Note that the metric expects predictions + labels as numpy arrays
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())

        # Print loss and metrics for every batch
        metrics = metric.compute(num_labels=len(id2label),
                                 ignore_index=0,
                                 reduce_labels=False)
        print("idx:", idx)
        print("Loss:", loss.item())
        print("Mean_iou:", metrics["mean_iou"])
        print("Mean accuracy:", metrics["mean_accuracy"])
        print("------------------------------------------")

    # Validation loss computation
    val_loss = compute_validation_loss(model, val_dataloader, device)
    print(f"Validation Loss: {val_loss}")

    # Early stopping check
    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss
        epochs_no_improve = 0
        # Remember the trainable weights of the best epoch so that they can be
        # restored once training stops.
        best_state = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Without this, the model would keep the weights of the last epoch, which are
# `patience` epochs past the best one when early stopping triggers.
if best_state is not None:
    model.load_state_dict(best_state, strict=False)

# Leave the model in eval mode for the inference cells that follow
# (the trailing semicolon suppresses the model repr in the cell output).
model.eval();


## Inference

Once we've trained a model, we can perform inference on new images as follows:

In [None]:
from PIL import Image
test_image = dataset_dict["validation"][6]["image"]
test_image

In [None]:
# Apply the validation transform, then convert the H, W, C array into a
# (batch_size, num_channels, height, width) tensor with a batch size of 1.
transformed = val_transform(image=np.array(test_image))
pixel_values = torch.tensor(transformed["image"]).permute(2, 0, 1).unsqueeze(0)
print(pixel_values.shape)

In [None]:
# forward pass
# Set eval mode explicitly instead of relying on whichever mode the model was
# left in by earlier cells.
model.eval()
with torch.no_grad():
  outputs = model(pixel_values.to(device))

In [None]:
# PIL reports size as (width, height); reverse it to get the (height, width)
# order that `interpolate` expects.
original_size = test_image.size[::-1]
upsampled_logits = torch.nn.functional.interpolate(
    outputs.logits,
    size=original_size,
    mode="bilinear",
    align_corners=False,
)
# The class with the highest logit is the prediction for each pixel.
predicted_map = upsampled_logits.argmax(dim=1)

In [None]:
visualize_map(dataset_dict["validation"][6]["image"], np.array(dataset_dict["validation"][6]["label"]))

In [None]:
visualize_map(test_image, predicted_map.squeeze().cpu())

In [None]:
visualize_map(test_image, predicted_map.squeeze().cpu())

In [None]:
# Serialise the whole model object (not just its weights) to 'model.pt'.
# NOTE(review): presumably loading this file requires the custom model classes
# defined in this notebook to be available — consider saving
# `model.state_dict()` instead; confirm how the file is consumed.
torch.save(model, 'model.pt')

In [None]:
torch.save(model, 'model.pth')

In [None]:
from google.colab import files
files.download('model.pt')

In [None]:
from google.colab import files
files.download('model.pth')