# SegFormer with RoPE model at b0 and 512px input images.


## mode: axial


# 0. imports


In [1]:
pip install -q datasets transformers evaluate icecream

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/480.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
!pip list | grep datasets

In [1]:
import gc
import json
import math
from functools import partial
from pathlib import Path

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from icecream import ic
from torch import nn
from torch.nn import functional as F
from torchvision.transforms import ColorJitter
from tqdm.notebook import trange
from transformers import (
  AutoImageProcessor,
  AutoModelForSemanticSegmentation,
  SegformerConfig,
  SegformerForSemanticSegmentation,
  SegformerImageProcessor,
  SegformerModel,
  Trainer,
  TrainingArguments,
)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Colab only: mount Google Drive under ./drive in the current working directory.
from google.colab import drive

drive.mount("./drive")

# 1. dataset


In [2]:
# Download the ADE20K id -> label mapping (a JSON file) from the Hugging Face Hub dataset repo.
repo_id = "huggingface/label-files"
filename = "ade20k-id2label.json"
id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
# JSON object keys are strings; convert them to integer class ids.
id2label = {int(k): v for k, v in id2label.items()}
# Reverse mapping: label name -> class id.
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)

In [3]:
# Earlier experiment, kept for reference: train on a 10k-sample subset with a 90/10 split.
# dataset = dataset.shuffle(seed=42).select(range(10000)).train_test_split(test_size=0.1)
# train_dataset = dataset["train"]
# test_dataset = dataset["test"]

# Load all splits of the "scene_parse_150" dataset.
dataset = load_dataset("scene_parse_150")
ic(dataset)
# Fixed seed so the shuffled order (and the validation subset below) is reproducible.
train_dataset = dataset["train"].shuffle(seed=42)
# Evaluate on the first 1000 shuffled validation samples only.
validation_dataset = dataset["validation"].shuffle(seed=42).select(range(1000))

ic(train_dataset)
ic(validation_dataset);

ic| dataset: DatasetDict({
                 train: Dataset({
                     features: ['image', 'annotation', 'scene_category'],
                     num_rows: 20210
                 })
                 test: Dataset({
                     features: ['image', 'annotation', 'scene_category'],
                     num_rows: 3352
                 })
                 validation: Dataset({
                     features: ['image', 'annotation', 'scene_category'],
                     num_rows: 2000
                 })
             })
ic| train_dataset: Dataset({
                       features: ['image', 'annotation', 'scene_category'],
                       num_rows: 20210
                   })
ic| validation_dataset: Dataset({
                            features: ['image', 'annotation', 'scene_category'],
                            num_rows: 1000
                        })


# 2. Preprocess


In [4]:
# Only the image processor (preprocessing configuration) is loaded from this checkpoint here.
checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
# do_reduce_labels=True is passed through to the processor (see the printed config below).
image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
image_processor

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


SegformerImageProcessor {
  "do_normalize": true,
  "do_reduce_labels": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SegformerImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 512,
    "width": 512
  }
}

In [5]:
# Color augmentation used by `train_transforms` only (validation images are not jittered).
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)


def train_transforms(example_batch):
  """Augment and preprocess one training batch.

  Every image that is not already in "RGB" mode is converted to RGB (the
  previous version converted only grayscale "L" images, so palette / RGBA /
  CMYK inputs reached the processor unconverted). Each image is then
  color-jittered and passed to ``image_processor`` with its annotation.

  Args:
    example_batch: mapping with an "image" sequence (objects exposing
      ``.mode`` and ``.convert``, presumably PIL images) and an "annotation"
      sequence.

  Returns:
    The result of ``image_processor(images, labels)``.
  """
  images = [
    jitter(image if image.mode == "RGB" else image.convert("RGB"))
    for image in example_batch["image"]
  ]
  labels = list(example_batch["annotation"])
  return image_processor(images, labels)


def val_transforms(example_batch):
  """Preprocess one validation batch (no augmentation).

  Every image that is not already in "RGB" mode is converted to RGB (the
  previous version converted only grayscale "L" images), then the images and
  their annotations are passed to ``image_processor``.

  Args:
    example_batch: mapping with an "image" sequence (objects exposing
      ``.mode`` and ``.convert``, presumably PIL images) and an "annotation"
      sequence.

  Returns:
    The result of ``image_processor(images, labels)``.
  """
  images = [
    image if image.mode == "RGB" else image.convert("RGB")
    for image in example_batch["image"]
  ]
  labels = list(example_batch["annotation"])
  return image_processor(images, labels)


# Attach the preprocessing functions to the datasets via `set_transform`.
train_dataset.set_transform(train_transforms)
validation_dataset.set_transform(val_transforms)

In [12]:
# Sanity check: inspect one preprocessed label map and the label shape after the transform.
ic(validation_dataset[0]["labels"])
ic(train_dataset[0]["labels"].shape);

ic| validation_dataset[0]["labels"]: array([[255, 255, 255, ..., 255, 255, 255],
                                            [255, 255, 255, ...,   2,   2,   2],
                                            [255,   2,   2, ...,   2,   2,   2],
                                            ...,
                                            [255,  11,  11, ...,  11,  11, 255],
                                            [255,  11,  11, ...,  11,  11, 255],
                                            [255, 255, 255, ...,  11,  11, 255]], shape=(512, 512))
ic| train_dataset[0]["labels"].shape: (512, 512)


# 3. Compute Metrics


In [None]:
# Mean-IoU metric object used by `compute_metrics` below.
metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
  """Score a ``(logits, labels)`` evaluation pair with the module-level ``metric``.

  Logits are processed in chunks of 100 samples: each chunk is bilinearly
  upsampled to the label resolution and arg-maxed over the class axis. The
  predictions and labels are cast to uint8 and handed to ``metric.compute``
  with ``ignore_index=255``. Entries whose value is a NumPy array are removed
  from the returned dict.

  Args:
    eval_pred: pair ``(logits, labels)`` of NumPy arrays; logits are indexed
      as (sample, class, height, width) and labels as (sample, height, width).

  Returns:
    The metric dict without its array-valued entries.
  """
  chunk_size = 100

  with torch.no_grad():
    logits, labels = eval_pred
    labels = labels.astype(np.uint8)
    target_hw = labels.shape[-2:]
    total = logits.shape[0]

    chunks = []
    for start in trange(0, total, chunk_size, leave=False):
      # Upsample one chunk at a time to bound peak memory.
      upsampled = nn.functional.interpolate(
        torch.from_numpy(logits[start : start + chunk_size]),
        size=target_hw,
        mode="bilinear",
        align_corners=False,
      )
      class_ids = upsampled.argmax(dim=1)
      chunks.append(class_ids.detach().cpu().numpy().astype(np.uint8))

    metrics = metric.compute(
      predictions=np.concatenate(chunks, axis=0),
      references=labels,
      num_labels=num_labels,
      ignore_index=255,
      reduce_labels=False,
    )

    # Drop array-valued entries so only non-array values are returned.
    array_keys = [name for name, value in metrics.items() if isinstance(value, np.ndarray)]
    for name in array_keys:
      metrics.pop(name)

    gc.collect()

    return metrics

# 4. Model


## 4.1 Implementation


In [None]:
# reference: transformers.segformer


def init_random_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
  """Build per-head 2D RoPE frequency tables for the x and y axes.

  A geometric ladder of ``dim // 4`` magnitudes is shared by all heads. Each
  head gets one angle (uniform in [0, 2*pi) when ``rotate`` is true, otherwise
  0) and the ladder is projected with cos/sin of that angle and of the angle
  shifted by pi/2.

  Note: a function with the same name and behavior is defined again in a later
  cell of this notebook; the later definition is the one in effect at runtime.

  Returns:
    Tensor of shape (2, num_heads, dim // 2); index 0 holds the x
    frequencies, index 1 the y frequencies.
  """
  exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
  magnitudes = 1 / (theta**exponents)  # (dim//4)

  per_head_x, per_head_y = [], []
  for _head in range(num_heads):
    if rotate:
      phase = torch.rand(1) * 2 * torch.pi
    else:
      phase = torch.zeros(1)
    shifted = torch.pi / 2 + phase
    # Two halves per head: base angle and base angle + pi/2 -> (dim//2)
    per_head_x.append(torch.cat([magnitudes * torch.cos(phase), magnitudes * torch.cos(shifted)], dim=-1))
    per_head_y.append(torch.cat([magnitudes * torch.sin(phase), magnitudes * torch.sin(shifted)], dim=-1))

  x_table = torch.stack(per_head_x, dim=0)  # (num_heads, dim//2)
  y_table = torch.stack(per_head_y, dim=0)  # (num_heads, dim//2)
  return torch.stack([x_table, y_table], dim=0)  # (2, num_heads, dim//2)


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

  Each sample (first dimension) is either zeroed with probability
  ``drop_prob`` or kept and rescaled by ``1 / (1 - drop_prob)``.

  Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  argument.

  Args:
    input: tensor whose first dimension indexes samples.
    drop_prob: probability of dropping a sample's path.
    training: when false the input is returned unchanged.

  Returns:
    ``input`` itself when nothing is dropped, otherwise a new tensor of the
    same shape. With ``drop_prob >= 1`` every path is dropped and zeros are
    returned (the rescale would otherwise divide by zero and yield NaN).
  """
  if drop_prob == 0.0 or not training:
    return input
  keep_prob = 1 - drop_prob
  if keep_prob <= 0.0:
    # Nothing survives: avoid input / 0 * 0 == NaN.
    return torch.zeros_like(input)
  shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
  random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
  return input.div(keep_prob) * random_tensor


def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int):
  """Turn mixed-axis frequency tables into unit complex rotations per token.

  For every token the rotation angle is ``t_x * freqs[0] + t_y * freqs[1]``;
  the channel axis is then split across ``num_heads``.

  Note: a function with the same name and behavior is defined again in a later
  cell of this notebook; the later definition is the one in effect at runtime.

  Args:
    freqs: tensor indexed as (2, depth, C // 2); index 0 is x, index 1 is y.
    t_x: 1-D tensor of per-token x coordinates (length N).
    t_y: 1-D tensor of per-token y coordinates (length N).
    num_heads: number of attention heads to split the channel axis into.

  Returns:
    Complex tensor of shape (depth, num_heads, N, C // 2 // num_heads).
  """
  num_tokens = t_x.shape[0]
  depth = freqs.shape[1]

  def _angles(coords, table):
    # (N, 1) @ (depth, 1, C//2) broadcasts to (depth, N, C//2)
    phase = coords.unsqueeze(-1) @ table.unsqueeze(-2)
    # split channels per head -> (depth, num_heads, N, C_per_head//2)
    return phase.view(depth, num_tokens, num_heads, -1).permute(0, 2, 1, 3)

  # No float 16 for this range
  with torch.amp.autocast(device_type="cuda", enabled=False):
    theta_x = _angles(t_x, freqs[0])
    theta_y = _angles(t_y, freqs[1])
    freqs_cis = torch.polar(torch.ones_like(theta_x), theta_x + theta_y)

  return freqs_cis


def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0):
  """Precompute axial 2D RoPE rotations for an ``end_x`` by ``end_y`` token grid.

  The same ladder of ``dim // 4`` frequencies is applied to the x coordinate
  and to the y coordinate of every token; the two resulting blocks of unit
  complex numbers are concatenated along the last axis (x first, then y).

  Note: a function with the same name and behavior is defined again in a later
  cell of this notebook; the later definition is the one in effect at runtime.

  Returns:
    Complex tensor of shape (end_x * end_y, 2 * (dim // 4)).
  """
  exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
  base_freqs = 1.0 / (theta**exponents)  # shared by both axes

  t_x, t_y = init_t_xy(end_x, end_y)

  halves = []
  for positions in (t_x, t_y):
    phase = torch.outer(positions, base_freqs)  # (N, dim//4)
    halves.append(torch.polar(torch.ones_like(phase), phase))
  return torch.cat(halves, dim=-1)


def init_t_xy(end_x: int, end_y: int):
  """Return float32 (x, y) coordinates for a row-major ``end_x`` by ``end_y`` grid.

  Token ``k`` gets ``x = k % end_x`` and ``y = k // end_x``.

  Note: a function with the same name and behavior is defined again in a later
  cell of this notebook; the later definition is the one in effect at runtime.

  Returns:
    Tuple ``(t_x, t_y)`` of 1-D tensors with ``end_x * end_y`` elements each.
  """
  flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
  column = torch.remainder(flat_index, end_x).float()
  row = torch.div(flat_index, end_x, rounding_mode="floor").float()
  return column, row


class SegformerSelfOutput(nn.Module):
  """Output projection of the attention block: a square linear layer followed by dropout."""

  def __init__(self, config, hidden_size):
    super().__init__()
    self.dense = nn.Linear(in_features=hidden_size, out_features=hidden_size)
    self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

  def forward(self, hidden_states, input_tensor):
    # `input_tensor` is accepted for signature compatibility but is not used here.
    projected = self.dense(hidden_states)
    return self.dropout(projected)


class SegformerDropPath(nn.Module):
  """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

  def __init__(self, drop_prob: float | None = None) -> None:
    super().__init__()
    self.drop_prob = drop_prob

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    return drop_path(hidden_states, self.drop_prob, self.training)

  def extra_repr(self) -> str:
    return "p={}".format(self.drop_prob)


class SegformerOverlapPatchEmbeddings(nn.Module):
  """Overlapping patch embedding: strided convolution, flatten to tokens, layer norm."""

  def __init__(self, patch_size, stride, num_channels, hidden_size):
    super().__init__()
    half_patch = patch_size // 2  # padding that makes neighbouring patches overlap
    self.proj = nn.Conv2d(
      in_channels=num_channels,
      out_channels=hidden_size,
      kernel_size=patch_size,
      stride=stride,
      padding=half_patch,
    )
    self.layer_norm = nn.LayerNorm(hidden_size)

  def forward(self, pixel_values):
    """Return ``(tokens, height, width)`` where tokens is (batch, height*width, hidden_size)."""
    feature_map = self.proj(pixel_values)
    _batch, _channels, height, width = feature_map.shape
    # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C) so a Transformer layer can consume it
    tokens = feature_map.flatten(2).transpose(1, 2)
    return self.layer_norm(tokens), height, width


class SegformerDWConv(nn.Module):
  """3x3 depthwise convolution applied to a token sequence viewed as a 2D feature map."""

  def __init__(self, dim=768):
    super().__init__()
    # groups=dim makes the convolution depthwise (one filter per channel).
    self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim)

  def forward(self, hidden_states, height, width):
    batch_size, _seq_len, num_channels = hidden_states.shape
    # (B, N, C) -> (B, C, H, W)
    feature_map = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
    # convolve, then back to (B, N, C)
    return self.dwconv(feature_map).flatten(2).transpose(1, 2)


class SegformerMixFFN(nn.Module):
  """Mix-FFN block: linear -> depthwise conv -> activation -> dropout -> linear -> dropout.

  Args:
    config: object providing ``hidden_act`` (activation name looked up in
      ``ACT2FN``, or a callable) and ``hidden_dropout_prob``.
    in_features: channel size of the input tokens.
    hidden_features: width of the intermediate layer; defaults to
      ``in_features`` when omitted (previously ``None`` crashed in ``nn.Linear``).
    out_features: channel size of the output; defaults to ``in_features``.
  """

  def __init__(self, config, in_features, hidden_features=None, out_features=None):
    super().__init__()
    hidden_features = hidden_features or in_features
    out_features = out_features or in_features
    self.dense1 = nn.Linear(in_features, hidden_features)
    self.dwconv = SegformerDWConv(hidden_features)
    if isinstance(config.hidden_act, str):
      self.intermediate_act_fn = ACT2FN[config.hidden_act]
    else:
      self.intermediate_act_fn = config.hidden_act
    self.dense2 = nn.Linear(hidden_features, out_features)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, hidden_states, height, width):
    hidden_states = self.dense1(hidden_states)
    # The depthwise conv needs the 2D grid size to reshape the token sequence.
    hidden_states = self.dwconv(hidden_states, height, width)
    hidden_states = self.intermediate_act_fn(hidden_states)
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.dense2(hidden_states)
    hidden_states = self.dropout(hidden_states)
    return hidden_states


In [None]:
# reference: vit-rope


def init_random_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
  """Build per-head 2D RoPE frequency tables for the x and y axes.

  A geometric ladder of ``dim // 4`` magnitudes is shared by all heads. Each
  head gets one angle (uniform in [0, 2*pi) when ``rotate`` is true, otherwise
  0) and the ladder is projected with cos/sin of that angle and of the angle
  shifted by pi/2.

  Note: a function with the same name and behavior also appears in the earlier
  SegFormer helper cell; this later definition shadows it.

  Returns:
    Tensor of shape (2, num_heads, dim // 2); index 0 holds the x
    frequencies, index 1 the y frequencies.
  """
  exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
  magnitudes = 1 / (theta**exponents)  # (dim//4)

  per_head_x, per_head_y = [], []
  for _head in range(num_heads):
    if rotate:
      phase = torch.rand(1) * 2 * torch.pi
    else:
      phase = torch.zeros(1)
    shifted = torch.pi / 2 + phase
    # Two halves per head: base angle and base angle + pi/2 -> (dim//2)
    per_head_x.append(torch.cat([magnitudes * torch.cos(phase), magnitudes * torch.cos(shifted)], dim=-1))
    per_head_y.append(torch.cat([magnitudes * torch.sin(phase), magnitudes * torch.sin(shifted)], dim=-1))

  x_table = torch.stack(per_head_x, dim=0)  # (num_heads, dim//2)
  y_table = torch.stack(per_head_y, dim=0)  # (num_heads, dim//2)
  return torch.stack([x_table, y_table], dim=0)  # (2, num_heads, dim//2)


def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int):
  """Turn mixed-axis frequency tables into unit complex rotations per token.

  For every token the rotation angle is ``t_x * freqs[0] + t_y * freqs[1]``;
  the channel axis is then split across ``num_heads``.

  Note: a function with the same name and behavior also appears in the earlier
  SegFormer helper cell; this later definition shadows it.

  Args:
    freqs: tensor indexed as (2, depth, C // 2); index 0 is x, index 1 is y.
    t_x: 1-D tensor of per-token x coordinates (length N).
    t_y: 1-D tensor of per-token y coordinates (length N).
    num_heads: number of attention heads to split the channel axis into.

  Returns:
    Complex tensor of shape (depth, num_heads, N, C // 2 // num_heads).
  """
  num_tokens = t_x.shape[0]
  depth = freqs.shape[1]

  def _angles(coords, table):
    # (N, 1) @ (depth, 1, C//2) broadcasts to (depth, N, C//2)
    phase = coords.unsqueeze(-1) @ table.unsqueeze(-2)
    # split channels per head -> (depth, num_heads, N, C_per_head//2)
    return phase.view(depth, num_tokens, num_heads, -1).permute(0, 2, 1, 3)

  # No float 16 for this range
  with torch.amp.autocast(device_type="cuda", enabled=False):
    theta_x = _angles(t_x, freqs[0])
    theta_y = _angles(t_y, freqs[1])
    freqs_cis = torch.polar(torch.ones_like(theta_x), theta_x + theta_y)

  return freqs_cis


def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0):
  """Precompute axial 2D RoPE rotations for an ``end_x`` by ``end_y`` token grid.

  The same ladder of ``dim // 4`` frequencies is applied to the x coordinate
  and to the y coordinate of every token; the two resulting blocks of unit
  complex numbers are concatenated along the last axis (x first, then y).

  Note: a function with the same name and behavior also appears in the earlier
  SegFormer helper cell; this later definition shadows it.

  Returns:
    Complex tensor of shape (end_x * end_y, 2 * (dim // 4)).
  """
  exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
  base_freqs = 1.0 / (theta**exponents)  # shared by both axes

  t_x, t_y = init_t_xy(end_x, end_y)

  halves = []
  for positions in (t_x, t_y):
    phase = torch.outer(positions, base_freqs)  # (N, dim//4)
    halves.append(torch.polar(torch.ones_like(phase), phase))
  return torch.cat(halves, dim=-1)


def init_t_xy(end_x: int, end_y: int):
  """Return float32 (x, y) coordinates for a row-major ``end_x`` by ``end_y`` grid.

  Token ``k`` gets ``x = k % end_x`` and ``y = k // end_x``.

  Note: a function with the same name and behavior also appears in the earlier
  SegFormer helper cell; this later definition shadows it.

  Returns:
    Tuple ``(t_x, t_y)`` of 1-D tensors with ``end_x * end_y`` elements each.
  """
  flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
  column = torch.remainder(flat_index, end_x).float()
  row = torch.div(flat_index, end_x, rounding_mode="floor").float()
  return column, row


In [None]:
# implementation


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
  """View ``freqs_cis`` so that it broadcasts against ``x``.

  ``freqs_cis`` must match either the last two dimensions of ``x`` (one table
  shared by all heads) or the last three (one table per head). All leading
  dimensions of ``x`` become size 1 in the returned view.

  Args:
    freqs_cis: (N, C_per_head//2) or (num_heads, N, C_per_head//2).
    x: tensor such as (B, num_heads, N, C_per_head//2).

  Returns:
    A view of ``freqs_cis`` with ``x.ndim`` dimensions.

  Raises:
    ValueError: if ``x`` has fewer than two dimensions or the shapes do not
      match. (A 2-D ``x`` with a mismatching table used to surface as an
      IndexError from ``x.shape[-3]``.)
  """
  ndim = x.ndim

  if ndim <= 1:
    raise ValueError("ndim must be greater than 1")

  if freqs_cis.shape == (x.shape[-2], x.shape[-1]):
    keep_from = ndim - 2
  elif ndim >= 3 and freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]):
    keep_from = ndim - 3
  else:
    msg = f"Invalid shape for `freqs_cis {freqs_cis.shape}` and `x {x.shape}`"
    raise ValueError(msg)

  # Keep the matched trailing dims, collapse everything before them to 1.
  shape = [size if axis >= keep_from else 1 for axis, size in enumerate(x.shape)]
  return freqs_cis.view(*shape)  # (1, 1 or num_heads, N, C_per_head//2)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, sr_ratio: int, freqs_cis: torch.Tensor, height, width):
  """Rotate queries and keys with 2D rotary position embeddings.

  Adjacent channel pairs of ``xq`` / ``xk`` are treated as complex numbers and
  multiplied by ``freqs_cis``. When ``sr_ratio > 1`` the keys live on a grid
  reduced by ``sr_ratio`` per axis, so the rotation table is average-pooled
  (real and imaginary parts separately) over ``sr_ratio`` x ``sr_ratio``
  windows to match.

  Args:
    xq: queries, (B, num_heads, N, C_per_head) with N == height * width.
    xk: keys, (B, num_heads, N / sr_ratio^2, C_per_head).
    sr_ratio: sequence-reduction ratio applied to the keys.
    freqs_cis: complex table, (N, C_per_head//2) or (num_heads, N, C_per_head//2).
    height: token-grid height of the queries.
    width: token-grid width of the queries.

  Returns:
    ``(xq_out, xk_out)`` with the shapes, dtypes and devices of ``xq`` and ``xk``.
  """

  def _as_complex(tensor):
    # (..., C_per_head) -> (..., C_per_head//2) complex, computed in float32
    return torch.view_as_complex(tensor.float().reshape(*tensor.shape[:-1], -1, 2))

  xq_ = _as_complex(xq)
  xk_ = _as_complex(xk)

  q_rot = reshape_for_broadcast(freqs_cis, xq_)  # (1, 1 or num_heads, N, C_per_head//2)
  k_rot = q_rot

  if sr_ratio > 1:
    lead, heads, _tokens, pairs = q_rot.shape

    # tokens -> 2D grid, channels first: (1, 1 or num_heads, C_per_head//2, H, W)
    grid = k_rot.view(lead, heads, height, width, pairs).permute(0, 1, 4, 2, 3)
    # merge the two leading dims so avg_pool2d sees a 4D tensor: (1 or num_heads, C_per_head//2, H, W)
    grid = grid.view(lead * heads, pairs, height, width)

    # avg_pool2d is applied to the real and imaginary parts separately
    real_part = grid.real
    imag_part = grid.imag
    pooled = torch.complex(
      F.avg_pool2d(real_part, kernel_size=sr_ratio, stride=sr_ratio),
      F.avg_pool2d(imag_part, kernel_size=sr_ratio, stride=sr_ratio),
    )

    pooled = pooled.view(lead, heads, pairs, height // sr_ratio, width // sr_ratio)
    # back to token-major layout: (1, 1 or num_heads, H*W/sr^2, C_per_head//2)
    k_rot = pooled.permute(0, 1, 3, 4, 2).reshape(lead, heads, -1, pairs)

  # complex product, then unpack pairs back into the channel axis
  xq_out = torch.view_as_real(xq_ * q_rot).flatten(3)  # (B, num_heads, N, C_per_head)
  xk_out = torch.view_as_real(xk_ * k_rot).flatten(3)  # (B, num_heads, N/sr^2, C_per_head)
  return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)


class SegformerWithRoPEEfficientSelfAttention(nn.Module):
  """SegFormer efficient self-attention with rotary position embeddings on queries and keys.

  Keys and values are computed from a spatially reduced copy of the input
  (strided convolution + layer norm) when ``sequence_reduction_ratio > 1``.

  Args:
    config: object providing ``attention_probs_dropout_prob``.
    hidden_size: channel size C of the tokens.
    num_attention_heads: number of heads; must divide ``hidden_size``.
    sequence_reduction_ratio: spatial reduction factor for keys/values.

  Raises:
    ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads``,
      or if the resulting head size is odd (RoPE pairs adjacent channels into
      complex numbers, so an odd head size would only fail later inside
      ``forward``).
  """

  def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
    super().__init__()
    self.hidden_size = hidden_size
    self.num_attention_heads = num_attention_heads

    if self.hidden_size % self.num_attention_heads != 0:
      msg = (
        f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
        f"heads ({self.num_attention_heads})"
      )
      raise ValueError(msg)

    self.attention_head_size = self.hidden_size // self.num_attention_heads
    if self.attention_head_size % 2 != 0:
      msg = (
        f"The attention head size ({self.attention_head_size}) must be even so that RoPE can pair "
        "channels into complex numbers"
      )
      raise ValueError(msg)
    self.all_head_size = self.num_attention_heads * self.attention_head_size  # == hidden_size

    self.query = nn.Linear(self.hidden_size, self.all_head_size)  # (B, H*W, C) -> (B, H*W, C)
    self.key = nn.Linear(self.hidden_size, self.all_head_size)  # (B, H*W/sr^2, C) -> (B, H*W/sr^2, C)
    self.value = nn.Linear(self.hidden_size, self.all_head_size)  # (B, H*W/sr^2, C) -> (B, H*W/sr^2, C)

    self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    self.sr_ratio = sequence_reduction_ratio
    if sequence_reduction_ratio > 1:
      # Non-overlapping strided conv that shrinks the key/value grid by sr_ratio per axis.
      self.sr = nn.Conv2d(
        hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
      )
      self.layer_norm = nn.LayerNorm(hidden_size)

  def transpose_for_scores(self, hidden_states):
    """Split channels into heads: (B, N, C) -> (B, num_heads, N, C_per_head)."""
    new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    hidden_states = hidden_states.view(new_shape)
    return hidden_states.permute(0, 2, 1, 3)

  def forward(
    self,
    hidden_states,
    height,
    width,
    freqs_cis,  # added (N, C_per_head//2) or (num_heads, N, C_per_head//2)
    output_attentions=False,
  ):
    """Attend over ``hidden_states`` of shape (B, H*W, C).

    Returns:
      ``(context,)`` or ``(context, attention_probs)`` when ``output_attentions``
      is true; ``context`` has shape (B, H*W, C).
    """
    query_layer = self.transpose_for_scores(
      self.query(hidden_states)
    )  # (B, H*W, C) -> (B, num_attention_heads, H*W, C_per_head)

    if self.sr_ratio > 1:
      batch_size, seq_len, num_channels = hidden_states.shape  # (B, H*W, C)
      # Reshape to (batch_size, num_channels, height, width)
      hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)  # (B, C, H, W)
      # Apply sequence reduction
      hidden_states = self.sr(hidden_states)  # (B, C, H, W) -> (B, C, H/sr, W/sr)
      # Reshape back to (batch_size, seq_len, num_channels)
      hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(
        0, 2, 1
      )  # (B, C, H/sr*W/sr) -> (B, H*W/sr^2, C)
      hidden_states = self.layer_norm(hidden_states)

    key_layer = self.transpose_for_scores(
      self.key(hidden_states)
    )  # (B, H*W/sr^2, C) -> (B, num_attention_heads, H*W/sr^2, C_per_head)
    value_layer = self.transpose_for_scores(
      self.value(hidden_states)
    )  # (B, H*W/sr^2, C) -> (B, num_attention_heads, H*W/sr^2, C_per_head)

    # === Apply RoPE ===
    #   query_layer: (B, num_attention_heads, H*W, C_per_head)
    #   key_layer:   (B, num_attention_heads, H*W/sr^2, C_per_head)
    # Shapes are unchanged by the rotation.
    query_layer, key_layer = apply_rotary_emb(
      query_layer,
      key_layer,
      self.sr_ratio,
      freqs_cis,
      height,
      width,
    )
    # === End of RoPE ===

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)

    # Normalize the attention scores to probabilities.
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs)

    context_layer = torch.matmul(attention_probs, value_layer)

    # Merge heads back: (B, num_heads, N, C_per_head) -> (B, N, C)
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    return outputs


class SegformerWithRoPEAttention(nn.Module):
  """RoPE self-attention followed by the output projection.

  Head pruning is not implemented here: ``pruned_heads`` is created empty and
  this class never modifies it.
  """

  def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
    super().__init__()
    self.self = SegformerWithRoPEEfficientSelfAttention(
      config=config,
      hidden_size=hidden_size,
      num_attention_heads=num_attention_heads,
      sequence_reduction_ratio=sequence_reduction_ratio,
    )
    self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
    self.pruned_heads = set()

  def forward(
    self,
    hidden_states,
    height,
    width,
    freqs_cis,  # added (N, C_per_head//2) or (num_heads, N, C_per_head//2)
    output_attentions=False,
  ):
    """Return ``(attention_output, *extras)``; extras carry attention weights when requested."""
    context, *extras = self.self(hidden_states, height, width, freqs_cis, output_attentions)
    projected = self.output(context, hidden_states)
    return (projected, *extras)


class SegformerWithRoPELayer(nn.Module):
  """One pre-norm Transformer layer: RoPE attention and Mix-FFN, each with a residual connection.

  This corresponds to the Block class in the original implementation.
  """

  def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
    super().__init__()
    self.layer_norm_1 = nn.LayerNorm(hidden_size)
    self.attention = SegformerWithRoPEAttention(
      config,
      hidden_size=hidden_size,
      num_attention_heads=num_attention_heads,
      sequence_reduction_ratio=sequence_reduction_ratio,
    )
    # The same drop-path module is used on both residual branches.
    if drop_path > 0.0:
      self.drop_path = SegformerDropPath(drop_path)
    else:
      self.drop_path = nn.Identity()
    self.layer_norm_2 = nn.LayerNorm(hidden_size)
    self.mlp = SegformerMixFFN(
      config,
      in_features=hidden_size,
      hidden_features=int(hidden_size * mlp_ratio),
    )

  def forward(
    self,
    hidden_states,
    height,
    width,
    freqs_cis,  # added (N, C_per_head//2) or (num_heads, N, C_per_head//2)
    output_attentions=False,
  ):
    """Return ``(layer_output, *extras)``; extras carry attention weights when requested."""
    # in Segformer, layernorm is applied before self-attention
    normed = self.layer_norm_1(hidden_states)
    attention_output, *extras = self.attention(
      normed,
      height,
      width,
      freqs_cis,
      output_attentions=output_attentions,
    )

    # first residual connection (with stochastic depth)
    hidden_states = self.drop_path(attention_output) + hidden_states

    # second residual connection (with stochastic depth)
    mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)
    layer_output = self.drop_path(mlp_output) + hidden_states

    return (layer_output, *extras)


class SegformerWithRoPEEncoder(nn.Module):
  """Hierarchical SegFormer encoder whose attention layers receive 2D rotary position embeddings.

  For each of the ``config.num_encoder_blocks`` stages the constructor builds a
  patch embedding, ``config.depths[i]`` Transformer layers and a final layer
  norm, and prepares the RoPE data for the token grid implied by a square
  ``config.image_size`` input:

  * axial mode (``config.rope_mixed`` false): a fixed complex table of shape
    (N, C_per_head//2) per stage, kept in the plain list ``self.freqs``;
  * mixed mode (``config.rope_mixed`` true): learnable frequencies of shape
    (2, depth, C//2) per stage in ``self.freqs`` (a ``ParameterList``) plus
    token coordinates in the buffers ``t_x_{i}`` / ``t_y_{i}``.

  If the input at ``forward`` time produces a different token grid than the one
  prepared in the constructor, the rotations are rebuilt for the actual grid
  using the per-stage callables in ``self.compute_cis`` (previously such inputs
  failed with a shape error, or silently used wrong positions when only the
  token count matched).
  """

  def __init__(self, config):
    super().__init__()
    self.config = config

    # stochastic depth decay rule: drop-path rate grows linearly over all layers
    drop_path_decays = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]

    # patch embeddings
    embeddings = []
    for i in range(config.num_encoder_blocks):  # one per stage
      embeddings.append(
        SegformerOverlapPatchEmbeddings(
          patch_size=config.patch_sizes[i],
          stride=config.strides[i],
          num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
          hidden_size=config.hidden_sizes[i],
        )
      )
    self.patch_embeddings = nn.ModuleList(embeddings)

    # Transformer blocks
    blocks = []

    cur = 0  # index of the first layer of the current stage in drop_path_decays
    for i in range(config.num_encoder_blocks):
      # each block consists of layers
      layers = []
      if i != 0:
        cur += config.depths[i - 1]
      for j in range(config.depths[i]):
        layers.append(
          SegformerWithRoPELayer(
            config,
            hidden_size=config.hidden_sizes[i],
            num_attention_heads=config.num_attention_heads[i],
            drop_path=drop_path_decays[cur + j],
            sequence_reduction_ratio=config.sr_ratios[i],
            mlp_ratio=config.mlp_ratios[i],
          )
        )
      blocks.append(nn.ModuleList(layers))

    # Per-stage RoPE data:
    #   axial: complex table (H*W, C_per_head//2)
    #   mixed: learnable frequencies (2, depth, C//2)
    freqs = []
    embedding_size = config.image_size
    self.compute_cis = []
    # Side length of the (square) token grid each stage was prepared for.
    self.grid_sizes = []

    # === compute frequency ===
    for i in range(config.num_encoder_blocks):
      # Output size of the stage's patch-embedding convolution (padding = patch_size // 2).
      embedding_size = (
        (embedding_size + 2 * (config.patch_sizes[i] // 2) - config.patch_sizes[i]) // config.strides[i]
      ) + 1
      self.grid_sizes.append(embedding_size)

      if self.config.rope_mixed:
        compute_cis = partial(compute_mixed_cis, num_heads=config.num_attention_heads[i])
        self.compute_cis.append(compute_cis)

        f = []  # one (2, num_heads, C_per_head//2) table per layer of this stage

        for _ in range(config.depths[i]):
          f.append(
            init_random_2d_freqs(
              dim=config.hidden_sizes[i] // config.num_attention_heads[i],
              num_heads=config.num_attention_heads[i],
              theta=config.rope_theta,
            )  # (2, num_heads, C_per_head//2)
          )

        # stack over layers -> (2, depth, num_heads, C_per_head//2) -> merge heads -> (2, depth, C//2)
        f = torch.stack(f, dim=1).view(2, config.depths[i], -1)

        freqs.append(nn.Parameter(f.clone()))

        _t_x, _t_y = init_t_xy(end_x=embedding_size, end_y=embedding_size)

        self.register_buffer(f"t_x_{i}", _t_x)
        self.register_buffer(f"t_y_{i}", _t_y)

      else:
        compute_cis = partial(
          compute_axial_cis,
          theta=self.config.rope_theta,
          dim=config.hidden_sizes[i] // config.num_attention_heads[i],
        )
        self.compute_cis.append(compute_cis)

        freqs_cis = compute_cis(end_x=embedding_size, end_y=embedding_size)  # (N, C_per_head//2)

        freqs.append(freqs_cis)

    if self.config.rope_mixed:
      self.freqs = nn.ParameterList(freqs)
    else:
      # Plain list of tensors: not registered on the module, moved to the input device in forward.
      self.freqs = freqs

    # === end of compute frequency ===

    self.block = nn.ModuleList(blocks)

    # Layer norms
    self.layer_norm = nn.ModuleList([nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)])

  def forward(
    self,
    pixel_values: torch.FloatTensor,
    output_attentions: bool | None = False,
    output_hidden_states: bool | None = False,
    return_dict: bool | None = True,
  ) -> tuple | BaseModelOutput:
    """Encode ``pixel_values`` of shape (B, num_channels, H, W).

    Returns:
      A ``BaseModelOutput`` (or, when ``return_dict`` is false, a tuple of its
      non-None fields) holding the last stage output and, when requested, the
      per-stage hidden states and per-layer attention weights.
    """
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None

    batch_size = pixel_values.shape[0]
    last_stage = len(self.patch_embeddings) - 1

    hidden_states = pixel_values
    for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm, strict=False)):
      embedding_layer, block_layer, norm_layer = x

      # first, obtain patch embeddings
      hidden_states, height, width = embedding_layer(hidden_states)  # (B, H*W, C)

      # === compute cis ===

      # True when the actual token grid matches the one prepared in __init__.
      grid_matches = height == self.grid_sizes[idx] and width == self.grid_sizes[idx]

      if self.config.rope_mixed:
        if grid_matches:
          t_x = getattr(self, f"t_x_{idx}")
          t_y = getattr(self, f"t_y_{idx}")
        else:
          # Different input resolution: rebuild token coordinates for the actual grid.
          t_x, t_y = init_t_xy(end_x=width, end_y=height)
          t_x = t_x.to(self.freqs[idx].device)
          t_y = t_y.to(self.freqs[idx].device)

        compute_cis = self.compute_cis[idx]

        # freqs (2, depth, C//2) -> freqs_cis (depth, num_heads, H*W, C_per_head//2)
        freqs_cis = compute_cis(freqs=self.freqs[idx], t_x=t_x, t_y=t_y)
      else:
        if grid_matches:
          freqs_cis = self.freqs[idx]
        else:
          # Different input resolution: rebuild the axial table for the actual grid.
          freqs_cis = self.compute_cis[idx](end_x=width, end_y=height)
        freqs_cis = freqs_cis.to(pixel_values.device)  # (N, C_per_head//2)

      # === end of compute cis ===

      # second, send embeddings through blocks
      for i, blk in enumerate(block_layer):
        layer_outputs = blk(
          hidden_states, height, width, freqs_cis[i] if self.config.rope_mixed else freqs_cis, output_attentions
        )

        hidden_states = layer_outputs[0]

        if output_attentions:
          all_self_attentions = (*all_self_attentions, layer_outputs[1])

      # third, apply layer norm
      hidden_states = norm_layer(hidden_states)

      # fourth, optionally reshape back to (batch_size, num_channels, height, width)
      if idx != last_stage or self.config.reshape_last_stage:
        hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
      if output_hidden_states:
        all_hidden_states = (*all_hidden_states, hidden_states)

    if not return_dict:
      return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
    return BaseModelOutput(
      last_hidden_state=hidden_states,
      hidden_states=all_hidden_states,
      attentions=all_self_attentions,
    )


class SegformerWithRoPEModel(SegformerModel):
  """`SegformerModel` whose `encoder` attribute is replaced by `SegformerWithRoPEEncoder`."""

  def __init__(self, config):
    super().__init__(config)
    # Overwrite the encoder set up by the parent constructor with the RoPE variant.
    self.encoder = SegformerWithRoPEEncoder(config)
    # Called after the swap — presumably so the new encoder's weights get initialized; confirm against `post_init`.
    self.post_init()


class SegformerWithRoPEForSemanticSegmentation(SegformerForSemanticSegmentation):
  """`SegformerForSemanticSegmentation` whose `segformer` backbone is replaced by `SegformerWithRoPEModel`."""

  def __init__(self, config):
    super().__init__(config)
    # Overwrite the backbone set up by the parent constructor with the RoPE variant.
    self.segformer = SegformerWithRoPEModel(config)
    # Called after the swap — presumably so the new backbone's weights get initialized; confirm against `post_init`.
    self.post_init()


In [None]:
# config

from transformers import SegformerConfig


class SegformerWithRoPEConfig(SegformerConfig):
  """The configuration class of a Segformer model with a RoPE module.

  Args:
      rope_theta: Base passed as ``theta`` to the RoPE frequency helpers by the encoder.
      image_size: Side length of the square input from which the encoder derives
          the per-stage token-grid sizes when it prepares the RoPE tables.
      rope_mixed: Keyword-only. ``True`` makes the encoder use learnable "mixed"
          frequencies, ``False`` fixed axial frequencies.
      **kwargs: Forwarded unchanged to ``SegformerConfig``.

  """

  model_type = "segformer-with-rope"

  def __init__(self, rope_theta: float = 100.0, image_size: int = 512, *, rope_mixed: bool = False, **kwargs):
    super().__init__(**kwargs)
    # RoPE-specific settings, stored as plain attributes after the base config is set up.
    self.rope_theta = rope_theta
    self.image_size = image_size
    self.rope_mixed = rope_mixed


## 4.2 Instance


In [None]:
# Baseline: stock SegFormer with default architecture settings and the ADE20K label maps.
# NOTE(review): `configuration` and `model` are rebound by the RoPE cell that follows; when both
# cells are run in order, the RoPE model is the one that gets trained.
configuration = SegformerConfig(num_labels=150, id2label=id2label, label2id=label2id)
model = SegformerForSemanticSegmentation(configuration)

In [None]:
# RoPE variant in axial mode (rope_mixed=False); rope_theta and image_size keep their defaults (100.0, 512).
configuration = SegformerWithRoPEConfig(num_labels=150, id2label=id2label, label2id=label2id, rope_mixed=False)
model = SegformerWithRoPEForSemanticSegmentation(configuration)

# 5. Training


In [None]:
!wandb login --relogin

In [None]:
# Run name: reused for the tracked run and for the output/log directory names.
# The model instantiated earlier is built with rope_mixed=False (axial RoPE), so the name says
# "axial" (it previously read "ropoe ... mixed").
NAME = "segformer-with-rope-b0-axial"
# NOTE(review): with an empty DIRECTORY the paths below resolve to "/outputs/..." and "/logs/..."
# (filesystem root); set this to the intended base directory.
DIRECTORY = ""

TRAIN_DATASET_SIZE = len(train_dataset)
EVAL_DATASET_SIZE = len(validation_dataset)

NUM_EPOCK = 128  # number of training epochs
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 16

GRADIENT_ACCUMULATION = 1
EVAL_ACCUMULATION = 4

# Evaluate / save / log once every 5% of the total optimizer steps.
EVAL_LATE = 5 / 100

# steps per epoch * epochs * fraction; clamped to at least 1 so small datasets never yield 0.
EVAL_STEPS = max(
  1,
  int(TRAIN_DATASET_SIZE // (TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION) * NUM_EPOCK * EVAL_LATE),
)

training_args = TrainingArguments(
  run_name=NAME,
  output_dir=f"{DIRECTORY}/outputs/{NAME}",
  logging_dir=f"{DIRECTORY}/logs/{NAME}",
  learning_rate=1e-3,
  lr_scheduler_type="polynomial",
  warmup_ratio=0.1,
  num_train_epochs=NUM_EPOCK,
  per_device_train_batch_size=TRAIN_BATCH_SIZE,
  per_device_eval_batch_size=EVAL_BATCH_SIZE,
  eval_accumulation_steps=EVAL_ACCUMULATION,
  gradient_accumulation_steps=GRADIENT_ACCUMULATION,
  save_total_limit=3,
  eval_strategy="steps",
  save_strategy="steps",
  save_steps=EVAL_STEPS,
  eval_steps=EVAL_STEPS,
  logging_steps=EVAL_STEPS,
  # The datasets use `set_transform`, which needs the raw "image"/"annotation" columns to stay available.
  remove_unused_columns=False,
  dataloader_num_workers=4,
  # fp16=True,
  # use_cpu=True
)

trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=train_dataset,
  eval_dataset=validation_dataset,
  compute_metrics=compute_metrics,
)

ic(EVAL_STEPS);

In [None]:
# Run a garbage-collection pass before training starts.
gc.collect()

In [None]:
# Start training with the arguments configured above.
trainer.train()

In [None]:
# Evaluate on `validation_dataset` using `compute_metrics`.
trainer.evaluate()