In [1]:
from typing import Optional, Tuple, List, Dict, Any
from dataclasses import dataclass, field

import numpy as np
import torch
import wandb
import evaluate
from torch import optim, nn
from datasets import load_dataset
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from torch.utils.data import DataLoader
import albumentations as A

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
@dataclass
class TrainingConfig:
    """
    Training configuration for the model.

    Args:
        batch_size: Batch size for training.
        epochs: Number of epochs to train for.
        learning_rate: Learning rate for the optimizer.
        background_weight: Weight for the background class.
        other_classes_weight: Weight for the other classes.
        lr_decay_rate: Learning rate decay rate.
        seed: Random seed for reproducibility.
        model_name: Name of the model to use.
        project_name: Name of the project for wandb.
        device: Device to use for training.
    """
    batch_size: int = 8
    epochs: int = 6
    learning_rate: float = 1e-4
    background_weight: float = 1.0
    other_classes_weight: float = 3.0
    lr_decay_rate: float = 0.9998
    seed: int = 42
    model_name: str = "nvidia/mit-b2"
    project_name: str = "Clothes segmentation"
    device: Optional[str] = field(
        default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu"
    )

    def as_dict(self) -> Dict[str, Any]:
        return vars(self)


config = TrainingConfig(
    batch_size=8,
    epochs=6,
    learning_rate=1e-4,
    background_weight=1.0,
    other_classes_weight=3.0,
    lr_decay_rate=0.9998,
    model_name="nvidia/mit-b2",
    project_name="Clothes segmentation",
)
wandb_config = config.as_dict()

In [3]:
# Dataset Loading and Preparation
ds = load_dataset("mattmdjaga/human_parsing_dataset", split="train[:100%]", num_proc=8)
ds.shuffle(seed=config.seed)

split_ratio: float = 0.006
split_size: int = int(len(ds) * split_ratio)
print(f"Split size: {split_size}")
ds_split = ds.train_test_split(test_size=split_ratio, seed=config.seed)
train_ds = ds_split["train"]
test_ds = ds_split["test"]

Split size: 106


In [4]:
id2label: Dict[str, str] = {
    "0": "Background",
    "1": "Hat",
    "2": "Hair",
    "3": "Sunglasses",
    "4": "Upper-clothes",
    "5": "Skirt",
    "6": "Pants",
    "7": "Dress",
    "8": "Belt",
    "9": "Left-shoe",
    "10": "Right-shoe",
    "11": "Face",
    "12": "Left-leg",
    "13": "Right-leg",
    "14": "Left-arm",
    "15": "Right-arm",
    "16": "Bag",
    "17": "Scarf",
}
label2id: Dict[str, str] = {v: k for k, v in id2label.items()}
num_labels: int = len(id2label)

In [5]:
# Model and Tokenizer Initialization
tokenizer = SegformerImageProcessor.from_pretrained(config.model_name)
model = AutoModelForSemanticSegmentation.from_pretrained(
    config.model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
).to(config.device)



Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b2 and are newly initialized: ['decode_head.classifier.weight', 'decode_head.linear_c.3.proj.weight', 'decode_head.batch_norm.bias', 'decode_head.linear_fuse.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.2.proj.bias', 'decode_head.batch_norm.running_mean', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_var', 'decode_head.linear_c.0.proj.weight', 'decode_head.batch_norm.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Image Transformations
img_transforms = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.OneOf(
            [
                A.Downscale(p=0.1, scale_min=0.4, scale_max=0.6),
                A.GaussNoise(p=0.2),
            ],
            p=0.1,
        ),
        A.OneOf(
            [
                A.RandomBrightnessContrast(p=0.2),
                A.ColorJitter(p=0.2),
                A.HueSaturationValue(p=0.2),
            ],
            p=0.1,
        ),
        A.OneOf([A.PixelDropout(p=0.2), A.RandomGravel(p=0.2)], p=0.15),
    ]
)



In [7]:
# Dataset Transformation Functions
def train_transforms(example_batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Transform the dataset for training.

    Args:
        example_batch: Batch of examples from the dataset.

    Returns:
        A dictionary of the inputs to the model.
    """
    trans = [
        img_transforms(image=np.array(x), mask=np.array(m))
        for x, m in zip(example_batch["image"], example_batch["mask"])
    ]
    images = [x["image"] for x in trans]
    labels = [x["mask"] for x in trans]
    inputs = tokenizer(images, labels)
    return inputs


def val_transforms(example_batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
    """
    Transform the dataset for validation.

    Args:
        example_batch: Batch of examples from the dataset.

    Returns:
        A dictionary of the inputs to the model.
    """
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["mask"]]
    inputs = tokenizer(images, labels)
    return inputs


train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

In [8]:
# Metrics and Validation
metric = evaluate.load("mean_iou")


@torch.no_grad()
def compute_metrics(eval_pred: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
    """
    Compute the IOU and accuracy metrics.

    Args:
        eval_pred: Tuple of logits and labels.

    Returns:
        A dictionary of the metrics.
    """
    logits, labels = eval_pred
    logits_tensor = logits.argmax(dim=1)

    pred_labels = logits_tensor

    mets = metric.compute(
        predictions=pred_labels,
        references=labels,
        num_labels=num_labels,
        ignore_index=255,
        reduce_labels=False,
    )
    for key, value in mets.items():
        if type(value) is np.ndarray:
            mets[key] = value.tolist()

    return mets


@torch.no_grad()
def validation(
    model: nn.Module, val_loader: DataLoader
) -> Tuple[torch.Tensor, torch.Tensor, List[float]]:
    """
    Perform validation on the model.

    Args:
        model: Model to validate.
        val_loader: Validation data loader.

    Returns:
        Tuple of logits, labels, and validation losses.
    """
    model.eval()
    val_losses = []
    all_labels = []
    all_logits = []

    for i, batch in enumerate(val_loader, 1):
        inputs = batch["pixel_values"].to(config.device)
        labels = batch["labels"].to(config.device)

        outputs = model(inputs)
        logits_tensor = nn.functional.interpolate(
            outputs.logits,  # Detach to avoid saving gradients
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

        loss = loss_func(logits_tensor, labels)
        val_losses.append(loss.item())

        # Storing logits and labels as CPU tensors to save GPU memory
        all_labels.append(labels.cpu())
        all_logits.append(logits_tensor.cpu())

    # Concatenate all logits and labels
    logits = torch.cat(all_logits, dim=0).to(config.device)
    labels = torch.cat(all_labels, dim=0).to(config.device)

    model.train()
    return logits, labels, val_losses

In [9]:
def save_model(model: nn.Module, tokenizer: SegformerImageProcessor, name: str) -> None:
    model.save_pretrained(name)
    tokenizer.save_pretrained(name)

In [10]:
# DataLoader, Optimizer, and Scheduler
train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False)
optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)
lambda1 = lambda step: config.lr_decay_rate**step
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)

weights = torch.tensor(
    [config.background_weight] + [config.other_classes_weight] * 17
).to(config.device)
loss_func = nn.CrossEntropyLoss(weight=weights)

In [None]:
running_loss: List[float] = []

# Generating a run name using the configuration parameters
run_name: str = "_".join([f"{key}_{value}" for key, value in wandb_config.items()])
t_steps: int = 0  # Used to get metrics at an interval

# Initialize wandb with the configuration
wandb.init(project=config.project_name, config=wandb_config)
wandb.run.name = run_name
wandb.run.save()

In [None]:
for epoch in range(config.epochs):
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        inputs = batch["pixel_values"].to(config.device)
        labels = batch["labels"].to(config.device)
        outputs = model(inputs)
        # The models predicts small masks, so we need to upsample them to the correct size like in inference
        logits_tensor = nn.functional.interpolate(
            outputs.logits,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        loss = loss_func(logits_tensor, labels)
        running_loss.append(loss.item())

        loss.backward()
        optimizer.step()
        scheduler.step()

        if t_steps % 600 == 0:
            wandb_logs = {}

            last_lr = scheduler.get_last_lr()[0]

            logits_tensor, labels, val_losses = validation(model, val_loader)
            r_loss = sum(running_loss) / len(running_loss)
            val_loss = sum(val_losses) / len(val_losses)

            mets = compute_metrics((logits_tensor, labels))

            wandb_logs["training_loss"] = r_loss
            wandb_logs["val_loss"] = val_loss

            for key, value in mets.items():
                if isinstance(value, float):
                    wandb_logs[key] = value

            print(f"\nEpoch {epoch} Iteration {i}")
            for key, score in wandb_logs.items():
                print(f"{key}: {score:.3f}")

            print(f"LR: {last_lr}")
            wandb_logs["LR"] = last_lr
            wandb.log(wandb_logs)
            running_loss = []
        t_steps += 1
    save_model(model, tokenizer, f"{epoch}")