# vit_rgb_segmentation Notebook

_Jupyter notebook for training SegFormer ViT on RGB data_

## 1. Импорты

In [None]:
# import os
# import csv
# from pathlib import Path
# import numpy as np
# import cv2
# from tqdm import tqdm

# import torch
# from torch.utils.data import Dataset, DataLoader
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# from transformers import SegformerConfig, SegformerForSemanticSegmentation

# import torch.nn.functional as F


import csv
from pathlib import Path

import numpy as np
import cv2
from tqdm import tqdm

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

import albumentations as A
from albumentations.pytorch import ToTensorV2


In [None]:
import torch

# Report the runtime environment before anything heavy is built.
print(torch.__version__)
print(torch.cuda.is_available())
# Asking for the GPU name raises when no CUDA device is usable, so only do it
# when CUDA is actually available (the notebook otherwise falls back to CPU).
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

# Device used by the model and by every batch moved in the train/eval loops.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device.type

2.7.1+cu126
True
NVIDIA GeForce RTX 3050 Laptop GPU


In [None]:

class PatchEmbed(nn.Module):
    """Patch embedding: a strided convolution followed by LayerNorm over channels.

    Args:
        in_channels: number of channels of the incoming map.
        embed_dim: number of channels produced by the projection.
        patch_size: kernel size of the projection convolution.
        stride: stride of the projection convolution.
        padding: padding of the projection convolution; ``patch_size // 2``
            is used when this is left as ``None``.
    """

    def __init__(self, in_channels, embed_dim, patch_size, stride, padding=None):
        super().__init__()
        pad = patch_size // 2 if padding is None else padding
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=pad,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Project a ``(B, C_in, H, W)`` map and return it channel-normalised as ``(B, embed_dim, H', W')``."""
        feat = self.proj(x)
        # LayerNorm normalises the last axis, so put channels last, normalise,
        # then restore the channels-first layout.
        channels_last = feat.permute(0, 2, 3, 1)
        return self.norm(channels_last).permute(0, 3, 1, 2)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block that works directly on ``(B, C, H, W)`` feature maps.

    Each spatial position is treated as one token: the map is flattened, run
    through multi-head self-attention and an MLP (both with residual
    connections), and reshaped back to its original layout.

    Args:
        dim: channel count of the map, used as the token width.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP expressed as a multiple of ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        mlp_width = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, dim),
        )
        # Identity for now, so the residual branches are added unchanged.
        self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply attention and MLP to ``x`` and return a map of the same shape."""
        batch, channels, height, width = x.shape
        tokens = x.flatten(2).permute(0, 2, 1)  # (B, H*W, C)

        # Self-attention sub-layer: normalise, attend, add residual.
        normed = self.norm1(tokens)
        tokens = tokens + self.drop_path(self.attn(normed, normed, normed)[0])

        # Feed-forward sub-layer: normalise, MLP, add residual.
        tokens = tokens + self.drop_path(self.mlp(self.norm2(tokens)))

        # Fold the token sequence back into a feature map.
        return tokens.permute(0, 2, 1).reshape(batch, channels, height, width)

class CustomSegFormer(nn.Module):
    """SegFormer-style segmentation network: 4-stage encoder plus a light fusion head.

    Each encoder stage is a ``PatchEmbed`` followed by ``depths[i]``
    ``TransformerBlock`` layers. The four stage outputs are projected to
    ``decoder_dim`` channels with 1x1 convolutions, resized to the first
    stage's resolution, concatenated and classified by a 1x1 convolution.
    The logits are finally resized to the spatial size of the input.

    Args:
        in_channels: number of channels of the input image.
        num_classes: number of output classes (channels of the logits).
        embed_dims: channel width of each of the four stages.
        num_heads: attention heads of each of the four stages.
        depths: number of transformer blocks in each of the four stages.
        decoder_dim: channel width every stage is projected to before fusion.
    """

    def __init__(self, in_channels=3, num_classes=9,
                 embed_dims=(64, 128, 320, 512),
                 num_heads=(1, 2, 5, 8),
                 depths=(3, 4, 6, 3),
                 decoder_dim=256):
        super().__init__()
        # Kernel size / stride of the patch embedding that opens each stage.
        patch_sizes = (7, 3, 3, 3)
        strides = (4, 2, 2, 2)

        self.stages = nn.ModuleList()
        ch = in_channels
        for i in range(4):
            layers = [PatchEmbed(ch, embed_dims[i], patch_sizes[i], strides[i])]
            for _ in range(depths[i]):
                layers.append(TransformerBlock(embed_dims[i], num_heads[i]))
            self.stages.append(nn.Sequential(*layers))
            ch = embed_dims[i]

        # One 1x1 projection per stage, bringing every feature map to decoder_dim channels.
        self.proj_convs = nn.ModuleList([
            nn.Conv2d(embed_dims[i], decoder_dim, kernel_size=1)
            for i in range(4)
        ])
        self.head = nn.Conv2d(decoder_dim * 4, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape ``(B, num_classes, H, W)`` for an input of shape ``(B, C, H, W)``."""
        # Remember the input resolution before ``x`` is overwritten by the encoder.
        out_size = tuple(x.shape[-2:])

        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)

        # Fuse all stages at the resolution of the first (finest) one.
        fuse_size = tuple(feats[0].shape[-2:])
        proj = []
        for conv, f in zip(self.proj_convs, feats):
            p = conv(f)
            if tuple(p.shape[-2:]) != fuse_size:
                p = F.interpolate(p, size=fuse_size,
                                  mode='bilinear', align_corners=False)
            proj.append(p)

        x_dec = self.head(torch.cat(proj, dim=1))
        # The fused map sits at the first stage's resolution (stride 4), not at
        # the last stage's, so scaling by the product of all strides (32) would
        # overshoot the input by 8x and break the loss against the mask.
        # Resizing to the recorded input size is exact for any input shape.
        return F.interpolate(x_dec, size=out_size,
                             mode='bilinear', align_corners=False)


## 2. Конфигурация

In [17]:
class CFG:
    """Experiment settings read by the data, model and training cells."""

    # Dataset root; both splits are derived from it so they cannot drift apart.
    data_root = Path("/mnt/d/Agriculture-Vision-2021 2")
    train_dir = data_root / "train"
    val_dir = data_root / "val"

    # Folder that receives the training log and the best checkpoint.
    checkpoint_dir = Path("runs/vit_rgb")

    num_classes = 9
    img_size = 256   # side length images are resized to; 512 was also tried
    batch_size = 16
    epochs = 10      # 50 was also tried
    lr = 6e-5

    # One of "b0", "b1", "b2", "b3".
    # NOTE(review): not referenced by the model-construction cell, which sets
    # widths/depths explicitly — presumably a leftover; confirm before removing.
    model_size = "b2"


CFG.checkpoint_dir.mkdir(parents=True, exist_ok=True)


## 3. Пути к данным

In [18]:
# Let cuDNN benchmark its convolution algorithms.
torch.backends.cudnn.benchmark = True

# Image and mask folders of the training split.
train_rgb = CFG.train_dir.joinpath("images", "rgb")
train_mask = CFG.train_dir.joinpath("masks")

# Image and mask folders of the validation split.
val_rgb = CFG.val_dir.joinpath("images", "rgb")
val_mask = CFG.val_dir.joinpath("masks")


## 4. Определение датасета

In [None]:
class RGBDataset(Dataset):
    """RGB image / label-mask pairs for a 3-channel segmentation model.

    Every file in ``rgb_dir`` is one sample; its mask is the file with the same
    stem and a ``.png`` suffix in ``mask_dir``. Mask pixel values are returned
    unchanged as ``int64`` class ids.

    Args:
        rgb_dir: folder with the RGB images.
        mask_dir: folder with the label masks.
        size: side length both image and mask are resized to.
        augment: enable random flips, 90-degree rotations and colour jitter.
    """

    # Number of channels of the array returned by ``load_image``.
    num_channels = 3

    def __init__(self, rgb_dir, mask_dir, size=512, augment=False):
        self.rgb_paths = sorted(Path(rgb_dir).glob("*"))
        self.mask_dir = Path(mask_dir)
        self.size = size
        self.augment = augment
        self.tf = self.build_tf()

    def build_tf(self):
        """Build the albumentations pipeline applied jointly to image and mask."""
        channels = self.num_channels
        transforms = [A.Resize(self.size, self.size)]
        if self.augment:
            transforms += [
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
            ]
            # ColorJitter is documented for 1- or 3-channel images only, so it
            # is part of the joint pipeline only for plain RGB input.
            if channels == 3:
                transforms.append(A.ColorJitter(p=0.3))
        transforms += [
            # Images are already scaled to [0, 1]. With the default
            # max_pixel_value=255 the normalisation would compute
            # (x - 127.5) / 127.5 and squash every channel to about -1.
            A.Normalize(mean=(0.5,) * channels,
                        std=(0.5,) * channels,
                        max_pixel_value=1.0),
            ToTensorV2(transpose_mask=True),
        ]
        return A.Compose(transforms)

    def __len__(self):
        return len(self.rgb_paths)

    @staticmethod
    def _imread(path, flags):
        """Read an image with OpenCV, raising instead of silently returning ``None``."""
        img = cv2.imread(str(path), flags)
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {path}")
        return img

    def load_image(self, rgb_path):
        """Return the sample image as a float32 ``(H, W, 3)`` RGB array in [0, 1]."""
        bgr = self._imread(rgb_path, cv2.IMREAD_COLOR)
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

    def __getitem__(self, idx):
        rgb_path = self.rgb_paths[idx]
        mask_path = self.mask_dir / rgb_path.with_suffix(".png").name

        image = self.load_image(rgb_path)
        mask = self._imread(mask_path, cv2.IMREAD_GRAYSCALE)

        transformed = self.tf(image=image, mask=mask)
        return transformed["image"].float(), transformed["mask"].long()


class RGBNDVIDataset(RGBDataset):
    """RGB + NDVI (4-channel) variant of ``RGBDataset``.

    The near-infrared image is looked up in ``nir_dir`` under the same file
    name as the RGB image. NDVI is computed from the red channel and the NIR
    image, rescaled from [-1, 1] to [0, 1] and appended as a fourth channel.

    Args:
        rgb_dir: folder with the RGB images.
        nir_dir: folder with the single-channel NIR images.
        mask_dir: folder with the label masks.
        size: side length both image and mask are resized to.
        augment: enable random flips, 90-degree rotations and colour jitter.
    """

    num_channels = 4

    def __init__(self, rgb_dir, nir_dir, mask_dir, size=512, augment=False):
        self.nir_dir = Path(nir_dir)
        super().__init__(rgb_dir, mask_dir, size=size, augment=augment)
        # Colour jitter cannot run on the 4-channel stack, so it is applied to
        # the RGB part alone, after NDVI has been derived from the raw values.
        self.rgb_jitter = A.ColorJitter(p=0.3) if augment else None

    @staticmethod
    def ndvi_channel(rgb, nir):
        """Return NDVI rescaled to [0, 1] as an ``(H, W, 1)`` array.

        Args:
            rgb: ``(H, W, 3)`` float array in [0, 1], red channel first.
            nir: ``(H, W)`` float array in [0, 1].
        """
        nir = nir[..., None]
        red = rgb[..., :1]
        # The small epsilon avoids division by zero on fully dark pixels.
        ndvi = (nir - red) / (nir + red + 1e-6)
        return (ndvi + 1.0) / 2.0

    def load_image(self, rgb_path):
        """Return the sample as a float32 ``(H, W, 4)`` array: RGB followed by NDVI."""
        rgb = super().load_image(rgb_path)
        nir_path = self.nir_dir / rgb_path.name
        nir = self._imread(nir_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0

        ndvi = self.ndvi_channel(rgb, nir)
        if self.rgb_jitter is not None:
            rgb = self.rgb_jitter(image=rgb)["image"]
        return np.concatenate([rgb, ndvi], axis=-1)



## 5. Метрики

In [None]:
@torch.no_grad()
def compute_metrics(logits: torch.Tensor, target: torch.Tensor, num_classes: int,
                    ignore_index: int = 255):
    """Compute pixel accuracy and mean IoU of one batch.

    Args:
        logits: raw class scores of shape ``(B, C, H', W')``; resized bilinearly
            to the mask resolution when the spatial sizes differ.
        target: integer label mask of shape ``(B, H, W)``.
        num_classes: class ids ``0 .. num_classes - 1`` are scored for IoU.
        ignore_index: label value excluded from both metrics (default 255).

    Returns:
        ``(pix_acc, miou)`` as Python floats. ``pix_acc`` is the share of
        non-ignored pixels predicted correctly. ``miou`` is the mean IoU over
        the classes that occur in the prediction or the target of this batch.
        Both are ``0.0`` when there is nothing to score.
    """
    if logits.shape[-2:] != target.shape[-2:]:
        logits = F.interpolate(
            logits,
            size=target.shape[-2:],      # (H, W) of the mask
            mode="bilinear",
            align_corners=False
        )
    preds = logits.argmax(1)  # [B,H,W]
    valid = target != ignore_index

    # Gather every count as a tensor first and move them to the host in one
    # transfer, instead of synchronising several times per class.
    counts = [valid.sum(), ((preds == target) & valid).sum()]
    for cls in range(num_classes):
        pred_c = (preds == cls) & valid
        targ_c = (target == cls) & valid
        inter = (pred_c & targ_c).sum()
        counts.append(inter)
        counts.append(pred_c.sum() + targ_c.sum() - inter)
    total, correct, *per_class = torch.stack(counts).tolist()

    pix_acc = correct / total if total > 0 else 0.0

    # per_class alternates intersection, union; classes absent from both the
    # prediction and the target (union == 0) are left out of the mean.
    ious = [i / u for i, u in zip(per_class[0::2], per_class[1::2]) if u > 0]
    miou = float(np.mean(ious)) if ious else 0.0

    return pix_acc, miou


## 6. Тренировочный и валидационный циклы

In [None]:
# Pixel-wise cross-entropy shared by the train and eval loops; label 255 is ignored.
LOSS = nn.CrossEntropyLoss(ignore_index=255)

def train_epoch(model, loader, opt, device, num_classes):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: network mapping an image batch to per-pixel class logits.
        loader: iterable yielding ``(image, mask)`` batches.
        opt: optimizer stepped once per batch.
        device: device the batches are moved to.
        num_classes: number of classes passed on to ``compute_metrics``.

    Returns:
        ``(loss, pixel_accuracy, miou)``: per-batch values averaged over the
        samples actually processed.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    tot_loss = tot_acc = tot_iou = 0.0
    seen = 0
    for img, mask in tqdm(loader, desc="Train", leave=False):
        # non_blocking only has an effect for pinned host memory and is a no-op otherwise.
        img = img.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)
        opt.zero_grad()
        logits = model(img)                                    # B×C×H'×W'
        loss = LOSS(logits, mask)                              # CrossEntropy
        loss.backward()
        opt.step()
        acc, miou = compute_metrics(logits.detach(), mask, num_classes)
        n = img.size(0)
        tot_loss += loss.item() * n
        tot_acc += acc * n
        tot_iou += miou * n
        seen += n
    # Divide by the samples that were really visited: len(loader.dataset) is
    # wrong as soon as the loader drops the last batch or uses a sampler.
    if seen == 0:
        raise ValueError("train_epoch: the loader yielded no samples")
    return tot_loss / seen, tot_acc / seen, tot_iou / seen

@torch.no_grad()
def eval_epoch(model, loader, device, num_classes):
    """Evaluate ``model`` on one pass over ``loader`` without computing gradients.

    Args:
        model: network mapping an image batch to per-pixel class logits.
        loader: iterable yielding ``(image, mask)`` batches.
        device: device the batches are moved to.
        num_classes: number of classes passed on to ``compute_metrics``.

    Returns:
        ``(loss, pixel_accuracy, miou)``: per-batch values averaged over the
        samples actually processed.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    tot_loss = tot_acc = tot_iou = 0.0
    seen = 0
    for img, mask in tqdm(loader, desc="Val", leave=False):
        # non_blocking only has an effect for pinned host memory and is a no-op otherwise.
        img = img.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)
        logits = model(img)
        loss = LOSS(logits, mask)
        acc, miou = compute_metrics(logits, mask, num_classes)
        n = img.size(0)
        tot_loss += loss.item() * n
        tot_acc += acc * n
        tot_iou += miou * n
        seen += n
    # Divide by the samples that were really visited: len(loader.dataset) is
    # wrong as soon as the loader drops the last batch or uses a sampler.
    if seen == 0:
        raise ValueError("eval_epoch: the loader yielded no samples")
    return tot_loss / seen, tot_acc / seen, tot_iou / seen



## 7. Инициализация модели и загрузчики

In [None]:
# Placeholder class names, one per class index.
id2label = {idx: f"class_{idx}" for idx in range(CFG.num_classes)}

# Segmentation network, built shallower and with a narrower decoder than the class defaults.
model = CustomSegFormer(
    in_channels=3,                   # RGB only
    num_classes=CFG.num_classes,
    embed_dims=[64, 128, 320, 512],
    num_heads=[1, 2, 5, 8],
    depths=[2, 2, 3, 1],             # can be reduced further to save resources
    decoder_dim=128,                 # can be varied
)
model = model.to(device)

# Datasets: augmentation is enabled for the training split only.
train_ds = RGBDataset(train_rgb, train_mask, size=CFG.img_size, augment=True)
val_ds = RGBDataset(val_rgb, val_mask, size=CFG.img_size, augment=False)

# Loaders: only the training split is shuffled.
train_loader = DataLoader(
    train_ds, batch_size=CFG.batch_size, shuffle=True, num_workers=4, pin_memory=True
)
val_loader = DataLoader(
    val_ds, batch_size=CFG.batch_size, shuffle=False, num_workers=4, pin_memory=True
)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b2 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 8. Обучение

In [None]:
# AdamW with weight decay, plus a scheduler that halves the learning rate
# once the monitored value (the validation loss) stops improving.
opt = AdamW(model.parameters(), lr=CFG.lr, weight_decay=0.01)
sched = ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=3)

# Start a fresh CSV log that holds only the header row; one row per epoch is appended later.
log_path = CFG.checkpoint_dir.joinpath("train.csv")
with log_path.open("w", newline="") as f:
    header = ["epoch", "loss", "val_loss", "val_acc", "val_miou"]
    csv.writer(f).writerow(header)

# Best validation mIoU seen so far; a checkpoint is written whenever it is beaten.
best_miou = 0.0



SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)

In [None]:

for epoch in range(1, CFG.epochs + 1):
    print(f"\nEpoch {epoch}/{CFG.epochs}")

    # One optimisation pass followed by one validation pass.
    tr_loss, tr_acc, tr_iou = train_epoch(
        model, train_loader, opt, device, CFG.num_classes
    )
    val_loss, val_acc, val_iou = eval_epoch(
        model, val_loader, device, CFG.num_classes
    )

    # The scheduler watches the validation loss.
    sched.step(val_loss)

    # Append this epoch to the CSV log and echo the same numbers.
    with open(log_path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([epoch, tr_loss, val_loss, val_acc, val_iou])
    print(f"loss={tr_loss:.3f}  val_loss={val_loss:.3f}  val_acc={val_acc:.3f}  val_mIoU={val_iou:.3f}")

    # Keep only the checkpoint with the best validation mIoU so far.
    if val_iou > best_miou:
        best_miou = val_iou
        checkpoint = {"model": model.state_dict(), "epoch": epoch, "miou": val_iou}
        torch.save(checkpoint, CFG.checkpoint_dir / "best_model.pt")
        print(f"✔ Saved best mIoU {val_iou:.3f} at epoch {epoch}")

print(f"Training finished. Best mIoU = {best_miou:.3f}")



Epoch 1/10


Train:   0%|          | 0/3559 [00:00<?, ?it/s]