In [None]:
import inspect
import json

import evaluate
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
# NOTE(review): `cached_download` was removed in recent huggingface_hub releases;
# `hf_hub_download` is the supported replacement. Drop cached_download from this
# import if it raises ImportError in your environment.
from huggingface_hub import notebook_login, cached_download, hf_hub_url
from huggingface_hub import hf_hub_download
from torchvision.transforms import ColorJitter
from transformers import AutoImageProcessor
from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer

In [None]:
repo_id = "user_id/dataset"  # input repository id to download datasets
file_label = "id2label.json"  # input file label name

# Load and preprocess the dataset.
# NOTE(review): if the dataset repo is private, run notebook_login() (later cell)
# before this one, otherwise load_dataset cannot authenticate.
ds = load_dataset(repo_id)  # load dataset from Hugging Face Hub
train_ds = ds["train"]
val_ds = ds["validation"]

# Download the class-id -> class-name mapping stored in the dataset repo and read it
# inside a context manager so the file handle is closed deterministically.
filename = file_label
id2label_path = hf_hub_download(repo_id, filename, repo_type="dataset")
with open(id2label_path, "r", encoding="utf-8") as id2label_file:
    id2label = json.load(id2label_file)
# JSON object keys are always strings; the model config wants integer class ids.
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

# load checkpoint
checkpoint = "nvidia/mit-b0"  # input checkpoint name
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)  # training-time color augmentation
# do_reduce_labels=True shifts mask ids down by one and maps 0 to 255, which
# compute_metrics treats as ignore_index. This assumes id2label does NOT include a
# background class at id 0 — TODO confirm against the dataset.
image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
metric = evaluate.load("mean_iou")  # load metric

In [None]:
def train_transforms(example_batch, label_key=None):
    """
    Transform a batch of training examples for the model, applying random color jitter.

    Parameters
    ----------
    example_batch : Mapping[str, list]
        Batch as handed over by ``Dataset.set_transform``: column name -> list of values.
        Must contain an ``"image"`` column and a segmentation-mask column.
    label_key : str, optional
        Name of the mask column. When ``None`` (default), the first of ``"label"``,
        ``"annotation"`` or ``"validation"`` present in the batch is used.

    Returns
    -------
    Whatever ``image_processor(images, labels)`` returns (presumably a BatchFeature
    with ``pixel_values`` and ``labels`` — confirm for the chosen checkpoint).

    Raises
    ------
    KeyError
        If no mask column can be found in the batch.
    """
    if label_key is None:
        # The original code read the split name "validation" as if it were a column;
        # auto-detect the common mask column names instead (keeping "validation" as
        # a last resort for backward compatibility).
        label_key = next(
            (key for key in ("label", "annotation", "validation") if key in example_batch),
            None,
        )
        if label_key is None:
            raise KeyError(
                "no segmentation-mask column found; expected one of "
                "'label', 'annotation', 'validation' (or pass label_key=...)"
            )
    # Jitter only the images; masks must stay pixel-exact.
    images = [jitter(image) for image in example_batch["image"]]
    labels = list(example_batch[label_key])
    return image_processor(images, labels)


def val_transforms(example_batch, label_key=None):
    """
    Transform a batch of validation examples for the model (no augmentation).

    Parameters
    ----------
    example_batch : Mapping[str, list]
        Batch as handed over by ``Dataset.set_transform``: column name -> list of values.
        Must contain an ``"image"`` column and a segmentation-mask column.
    label_key : str, optional
        Name of the mask column. When ``None`` (default), the first of ``"label"``,
        ``"annotation"`` or ``"validation"`` present in the batch is used.

    Returns
    -------
    Whatever ``image_processor(images, labels)`` returns (presumably a BatchFeature
    with ``pixel_values`` and ``labels`` — confirm for the chosen checkpoint).

    Raises
    ------
    KeyError
        If no mask column can be found in the batch.
    """
    if label_key is None:
        # "validation" is a split name, not a column; auto-detect the common mask
        # column names, keeping "validation" last for backward compatibility.
        label_key = next(
            (key for key in ("label", "annotation", "validation") if key in example_batch),
            None,
        )
        if label_key is None:
            raise KeyError(
                "no segmentation-mask column found; expected one of "
                "'label', 'annotation', 'validation' (or pass label_key=...)"
            )
    images = list(example_batch["image"])
    labels = list(example_batch[label_key])
    return image_processor(images, labels)


def compute_metrics(eval_pred):
    """
    Score an evaluation pass with the ``mean_iou`` metric.

    ``eval_pred`` is unpacked as ``(logits, labels)`` numpy arrays. The logits are
    bilinearly resized to the spatial size of ``labels`` (its last two dims) and
    reduced to class ids by argmax over dim 1. Pixels labelled 255 are ignored.
    Any numpy-array entries in the result are turned into plain lists, presumably
    so the Trainer can serialise/log them.
    """
    logits, labels = eval_pred
    target_hw = labels.shape[-2:]

    with torch.no_grad():
        upsampled = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=target_hw,
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = upsampled.argmax(dim=1).detach().cpu().numpy()

        raw_metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=num_labels,
            ignore_index=255,
            reduce_labels=False,
        )

    # Exact-type check (not isinstance) mirrors the original conversion rule.
    return {
        name: (score.tolist() if type(score) is np.ndarray else score)
        for name, score in raw_metrics.items()
    }

In [None]:
# Log in to the Hugging Face Hub. A token is required for push_to_hub=True in the
# TrainingArguments and for the push_to_hub calls in the final cell.
# NOTE(review): load_dataset(repo_id) runs in an earlier cell; if that dataset repo is
# private, this login must happen first — TODO consider moving this cell to the top.
notebook_login()  # Login to Hugging Face Hub

In [None]:
# Attach the on-the-fly preprocessing to each split. set_transform applies the function
# lazily whenever rows are accessed, so the jitter in train_transforms is re-sampled
# every time a training example is read.
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)

In [None]:
# Define the segmentation model from the checkpoint (the image processor is created in
# the data-loading cell). Passing id2label/label2id sets the model config's label
# mapping, from which the number of output classes is derived.
# NOTE(review): nvidia/mit-b0 is presumably an encoder-only checkpoint, so the decode
# head is newly initialised — a "newly initialized weights" warning is expected.
model = AutoModelForSemanticSegmentation.from_pretrained(
    checkpoint,
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# define training arguments and trainer and start training

# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers releases,
# and the old name was later removed (passing it then raises TypeError). Use whichever
# keyword the installed TrainingArguments accepts.
_eval_strategy_kw = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="segment_50ep",  # also the default Hub repo name when push_to_hub=True
    learning_rate=6e-5,
    num_train_epochs=50,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    save_total_limit=3,
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,  # kept equal to save_steps so every checkpoint has an evaluation
    logging_steps=1,
    # Move accumulated eval outputs off the device every 5 steps to bound memory use.
    eval_accumulation_steps=5,
    # Keep the raw "image"/mask columns; set_transform needs them and they are not
    # arguments of the model's forward().
    remove_unused_columns=False,
    push_to_hub=True,
    **{_eval_strategy_kw: "steps"},
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# push model and image_processor to hub
trainer.push_to_hub()

# Push the image processor to the same repo the Trainer pushed the model to, so both
# artifacts stay together. `trainer.hub_model_id` is set by the Trainer when
# push_to_hub=True (transformers internals — verify for your version); otherwise fall
# back to the explicit placeholder below.
hub_repo_id = getattr(trainer, "hub_model_id", None) or "user_id/segment_50ep"  # input your repo id
image_processor.push_to_hub(hub_repo_id)