In [None]:
import matplotlib.pyplot as plt
from segments.utils import get_semantic_bitmap
from segments import SegmentsClient
from segments.huggingface import release2dataset
from datasets import load_dataset

import requests
from transformers import pipeline
import numpy as np
from PIL import Image, ImageDraw

### Setting up environment - logging into Hugging Face and Segments.ai API, defining environment variables

In [None]:
import os
from getpass import getpass

from huggingface_hub import notebook_login

# Segments.ai API key: taken from the SEGMENTS_API_KEY environment variable when it is
# set, otherwise requested through a prompt that does not echo the input. Keeping the
# key out of the cell source means it cannot be saved or committed with the notebook.
api_key = os.environ.get("SEGMENTS_API_KEY") or getpass("Segments.ai API key: ")

client = SegmentsClient(api_key)  # initializing segments.ai client
notebook_login()  # logging into HF

In [None]:
# Identifiers of the labeled data. `dataset_identifier`, `name` and `release_name` are
# presumably the Segments.ai dataset and its release -- TODO confirm; they are not read
# by any other cell shown in this notebook.
dataset_identifier = "dskong07/chargers-full"
name = "chargers-labeled-full-v0.1"
release_name = "chargers-labeled-full-v0.1"

# Hugging Face Hub dataset repo: passed to load_dataset and recorded as the "dataset"
# entry of the model-card metadata. (Plain string: the former f-prefix had no placeholders.)
hf_dataset_identifier = "dskong07/chargers-full-v0.1"

# Class index -> class name. Index 0 is the one compute_metrics passes as ignore_index.
id2label = {
    0: "unlabeled",
    1: "screen",
    2: "body",
    3: "cable",
    4: "plug",
    5: "void-background",
}
# Inverse mapping, class name -> class index.
label2id = {label: idx for idx, label in id2label.items()}

### Loading the dataset 

In [None]:
from datasets import load_dataset

# Fetch the labeled dataset from the Hugging Face Hub repo named in the config cell.
# The result is indexed by split name further down (ds["train"]).
ds = load_dataset(path=hf_dataset_identifier)

In [None]:
# Creating the train/test split (80% train / 20% test).
#
# The split itself is seeded: train_test_split shuffles before splitting by default, and
# without a seed it draws a new permutation on every run, so the test set (and every
# metric computed on it) would differ between runs. Seeding it also makes a separate
# .shuffle() call unnecessary.
SPLIT_SEED = 1

# The result gets its own name instead of overwriting `ds`, so re-running this cell
# splits the full dataset again rather than re-splitting the previous train portion.
split_ds = ds["train"].train_test_split(test_size=0.2, seed=SPLIT_SEED)
train_ds = split_ds["train"]
test_ds = split_ds["test"]

### Importing baseline model

In [None]:
from transformers import SegformerForSemanticSegmentation

# Checkpoint to start fine-tuning from; also recorded later as "finetuned_from" in the
# model-card metadata.
pretrained_model_name = "nvidia/mit-b3"

# The label maps are handed over at load time so the model is configured with this
# dataset's classes (len(id2label) of them).
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name, id2label=id2label, label2id=label2id
)

### Perform data augmentation on the training dataset to make the training process more robust

In [None]:
from torchvision.transforms import ColorJitter
from transformers import (
    SegformerImageProcessor,
)

# Image processor built with its default settings. Both transform functions call it, and
# compute_metrics reads its do_reduce_labels flag.
processor = SegformerImageProcessor()

# Random colour perturbation; only train_transforms applies it.
jitter = ColorJitter(
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
    hue=0.1,
)

def train_transforms(example_batch):
    """Batch transform for the training split.

    Runs ``jitter`` over every entry of ``example_batch['pixel_values']``, then passes the
    jittered images together with the entries of ``example_batch['label']`` to
    ``processor``.

    Args:
        example_batch: mapping with a ``'pixel_values'`` and a ``'label'`` sequence.

    Returns:
        Whatever ``processor(images, labels)`` returns for the batch.
    """
    augmented = list(map(jitter, example_batch['pixel_values']))
    masks = list(example_batch['label'])
    return processor(augmented, masks)


def val_transforms(example_batch):
    """Batch transform for the evaluation split (no augmentation).

    Passes the entries of ``example_batch['pixel_values']`` and ``example_batch['label']``
    to ``processor`` unchanged.

    Args:
        example_batch: mapping with a ``'pixel_values'`` and a ``'label'`` sequence.

    Returns:
        Whatever ``processor(images, labels)`` returns for the batch.
    """
    images = list(example_batch['pixel_values'])
    masks = list(example_batch['label'])
    return processor(images, masks)


# Attach the batch transforms to the splits: jitter + processor for training,
# processor only for evaluation.
train_ds.set_transform(transform=train_transforms)
test_ds.set_transform(transform=val_transforms)

### Declaring training arguments

In [None]:
from transformers import TrainingArguments

# Optimisation hyper-parameters.
epochs = 50
lr = 6e-5
batch_size = 2  # used for both the train and the eval per-device batch size

# Hub repo the fine-tuned model is pushed to.
hub_model_id = "segformer-b3-finetuned-segments-chargers-full-v3.1"

# NOTE(review): the output directory name says "b1" while the hub repo id says "b3" --
# confirm whether the directory name is intentional.
training_args = TrainingArguments(
    output_dir="segformer-b1-finetuned-segments-chargers-outputs-v0.1",
    # optimisation
    num_train_epochs=epochs,
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # evaluate and checkpoint on the same 20-step cadence; keep at most 3 checkpoints
    evaluation_strategy="steps",
    eval_steps=20,
    eval_accumulation_steps=5,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=3,
    load_best_model_at_end=True,
    # log every step
    logging_steps=1,
    # Hub upload
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="end",
)

### Defining the evaluation metric - here we use mean Intersection over Union (IoU), computed with the PyTorch and `evaluate` libraries

In [None]:
import torch
from torch import nn
import evaluate
import multiprocessing

# Load the "mean_iou" metric through evaluate.load; compute_metrics below uses it.
metric = evaluate.load("mean_iou")

def compute_metrics(eval_pred):
    """Turn raw evaluation outputs into mean-IoU metrics.

    Args:
        eval_pred: pair ``(logits, labels)``. ``logits`` must be a NumPy array (it is
            wrapped with ``torch.from_numpy``); ``labels`` must expose ``.shape``.

    Returns:
        The dict produced by the metric, with the ``per_category_accuracy`` and
        ``per_category_iou`` arrays replaced by one scalar entry per class, keyed
        ``accuracy_<label>`` and ``iou_<label>`` using the names in ``id2label``.
    """
    logits, labels = eval_pred

    with torch.no_grad():
        # Resize the logits to the spatial size of the labels (their last two
        # dimensions), then take the arg-max over dim 1 (assumed to be the class
        # dimension) to get one predicted class id per pixel.
        upsampled = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = upsampled.argmax(dim=1).detach().cpu().numpy()

        # The metric's underscore-prefixed _compute is called directly, with class 0 ignored.
        metrics = metric._compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=0,
            reduce_labels=processor.do_reduce_labels,
        )

    # Flatten the per-category arrays into individual key-value pairs.
    accuracy_by_class = metrics.pop("per_category_accuracy").tolist()
    iou_by_class = metrics.pop("per_category_iou").tolist()
    for prefix, values in (("accuracy", accuracy_by_class), ("iou", iou_by_class)):
        for class_id, value in enumerate(values):
            metrics[f"{prefix}_{id2label[class_id]}"] = value

    return metrics

### Now training the model

In [None]:
from transformers import Trainer

# Wire the model, the training configuration, the two dataset splits and the metric
# function into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)

In [None]:
# Run fine-tuning with the settings in training_args. Kept as the cell's last expression
# so the notebook displays the value returned by train().
trainer.train()

### Uploading the model to Hugging Face

In [None]:
# Take the repo id from the training configuration instead of repeating the string
# literal: the Trainer pushes to training_args.hub_model_id, so reading it here
# guarantees the image processor is uploaded to the same Hub repo as the model.
hub_model_id = training_args.hub_model_id

# Metadata passed through to trainer.push_to_hub.
kwargs = {
    "tags": ["vision", "image-segmentation"],
    "finetuned_from": pretrained_model_name,
    "dataset": hf_dataset_identifier,
}

processor.push_to_hub(hub_model_id)
trainer.push_to_hub(**kwargs)