# Environment
- Ubuntu 24.04
- Python 3.10
- 해당 노트북 파일과 같은 디렉터리 상에 ./Wildfire 폴더에 대회 데이터셋이 존재해야합니다.
  - ./Wildfire/train_img
  - ./Wildfire/train_mask
  - ./Wildfire/test_img

# 주의사항
max_epoch은 150으로 잡혀있습니다.
단일 RTX4090에서도 이틀넘게 걸리는 긴 시간의 학습이 필요합니다.
하지만, 10epoch마다 체크포인트가 "Logs/{학습시작시간}" 폴더에 저장되는데, 50epoch정도면 충분히 최고 성능으로 수렴하는 편입니다.
그러므로, 빠르게 재현성을 확인하시고자 한다면, 한 번씩 중간 epoch의 체크포인트로 점수를 확인해보시길 권장드립니다.

In [1]:
# 필요 패키지 설치 (현시점 (2024-03-27) 최신버전으로 설치시 문제 없음)
! pip install numpy opencv-contrib-python albumentations --upgrade
! pip install torch torchvision openmim --upgrade
! pip install timm --upgrade
! mim install mmengine --upgrade
! pip install rasterio --upgrade

Collecting numpy
  Using cached numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting opencv-contrib-python
  Using cached opencv_contrib_python-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Collecting albumentations
  Using cached albumentations-1.4.2-py3-none-any.whl.metadata (36 kB)
Collecting scipy>=1.10.0 (from albumentations)
  Using cached scipy-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting scikit-image>=0.21.0 (from albumentations)
  Using cached scikit_image-0.22.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting scikit-learn>=1.3.2 (from albumentations)
  Using cached scikit_learn-1.4.1.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting opencv-python-headless>=4.9.0 (from albumentations)
  Using cached opencv_python_headless-4.9.0.80-cp37-abi3-manylinux_2_17_x86_64.man

# Lovasz Loss
https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py

위의 공식 사이트의 코드를 그대로 가져왔고, lovasz_hinge에서 relu대신 elu를 사용하는 option만 추가하였습니다.

In [2]:
"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable

try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse


def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension with respect to sorted errors.

    gt_sorted: 1-D tensor of ground-truth values, already ordered by
        descending error. See Alg. 1 in the Lovasz-Softmax paper.
    Returns a tensor with one entry per element of gt_sorted.
    """
    n_pixels = len(gt_sorted)
    total_fg = gt_sorted.sum()
    fg_seen = gt_sorted.float().cumsum(0)
    bg_seen = (1 - gt_sorted).float().cumsum(0)
    # Jaccard loss after including the first k sorted pixels, for every k.
    jaccard = 1.0 - (total_fg - fg_seen) / (total_fg + bg_seen)
    if n_pixels <= 1:  # cover 1-pixel case
        return jaccard
    # Turn the cumulative values into per-step increments.
    jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard


def iou_binary(preds, labels, EMPTY=1.0, ignore=None, per_image=True):
    """
    Foreground IoU (in percent) for binary masks: 1 is foreground, 0 background.

    preds, labels: batched predictions and ground truth; when per_image is
        False the whole batch is scored as one sample.
    EMPTY: value used when the union is empty.
    ignore: label value excluded from the union on the prediction side.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)

    def _single_iou(pred, label):
        overlap = ((label == 1) & (pred == 1)).sum()
        combined = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        return float(overlap) / float(combined) if combined else EMPTY

    scores = [_single_iou(pred, label) for pred, label in zip(preds, labels)]
    # mean across images if per_image
    return 100 * mean(scores)


def iou(preds, labels, C, EMPTY=1.0, ignore=None, per_image=False):
    """
    Per-class IoU (in percent) for every class in range(C) except `ignore`.

    Returns a numpy array with one entry per kept class, averaged across
    images when per_image is True.
    """
    if not per_image:
        preds, labels = (preds,), (labels,)

    # The ignored label is sometimes among predicted classes (ENet - CityScapes)
    kept_classes = [cls for cls in range(C) if cls != ignore]

    per_image_scores = []
    for pred, label in zip(preds, labels):
        class_scores = []
        for cls in kept_classes:
            overlap = ((label == cls) & (pred == cls)).sum()
            combined = ((label == cls) | ((pred == cls) & (label != ignore))).sum()
            class_scores.append(float(overlap) / float(combined) if combined else EMPTY)
        per_image_scores.append(class_scores)

    # mean across images if per_image
    class_means = [mean(per_class) for per_class in zip(*per_image_scores)]
    return 100 * np.array(class_means)


# --------------------------- BINARY LOSSES ---------------------------


def lovasz_hinge(logits, labels, per_image=True, ignore=None, use_elu=False):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] tensor, logits at each pixel (between -inf and +inf)
      labels: [B, H, W] tensor, binary ground truth masks (0 or 1)
      per_image: compute the loss per image instead of per batch
      ignore: void class id
      use_elu: forwarded to lovasz_hinge_flat (ELU + 1 instead of ReLU)
    """
    if not per_image:
        flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
        return lovasz_hinge_flat(flat_logits, flat_labels, use_elu=use_elu)

    def _image_loss(img_logits, img_labels):
        flat = flatten_binary_scores(img_logits.unsqueeze(0), img_labels.unsqueeze(0), ignore)
        return lovasz_hinge_flat(*flat, use_elu=use_elu)

    return mean(_image_loss(lg, lb) for lg, lb in zip(logits, labels))


def lovasz_hinge_flat(logits, labels, use_elu=False):
    """
    Binary Lovasz hinge loss on flattened inputs
      logits: [P] tensor, logits at each prediction (between -inf and +inf)
      labels: [P] tensor, binary ground truth labels (0 or 1)
      use_elu: apply ELU(x) + 1 to the sorted errors instead of ReLU(x)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.0

    # Map labels {0, 1} to signs {-1, +1} and compute the hinge margins.
    signs = 2.0 * labels.float() - 1.0
    margins = 1.0 - logits * Variable(signs)
    margins_sorted, order = torch.sort(margins, dim=0, descending=True)
    grad = lovasz_grad(labels[order.data])

    if use_elu:
        activated = F.elu(margins_sorted) + 1
    else:  # original
        activated = F.relu(margins_sorted)
    return torch.dot(activated, Variable(grad))


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flatten batch predictions and labels to 1-D (binary case).

    When `ignore` is given, entries whose label equals it are dropped from
    both returned tensors.
    """
    flat_scores, flat_labels = scores.view(-1), labels.view(-1)
    if ignore is not None:
        keep = flat_labels != ignore
        flat_scores, flat_labels = flat_scores[keep], flat_labels[keep]
    return flat_scores, flat_labels


class StableBCELoss(torch.nn.modules.Module):
    """Binary cross entropy on logits, written so exp() only sees non-positive inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return the mean of max(x, 0) - x * t + log(1 + exp(-|x|)) over all elements."""
        softplus_term = torch.log(1 + torch.exp(-input.abs()))
        per_element = input.clamp(min=0) - input * target + softplus_term
        return per_element.mean()


def binary_xloss(logits, labels, ignore=None):
    """
    Binary cross entropy loss
      logits: [B, H, W] tensor, logits at each pixel (between -inf and +inf)
      labels: [B, H, W] tensor, binary ground truth masks (0 or 1)
      ignore: void class id
    """
    flat_logits, flat_labels = flatten_binary_scores(logits, labels, ignore)
    targets = Variable(flat_labels.float())
    criterion = StableBCELoss()
    return criterion(flat_logits, targets)


# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes="present", per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] tensor, class probabilities at each prediction (between 0 and 1).
              A [B, H, W] tensor is interpreted as binary (sigmoid) output.
      labels: [B, H, W] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if not per_image:
        flat_probas, flat_labels = flatten_probas(probas, labels, ignore)
        return lovasz_softmax_flat(flat_probas, flat_labels, classes=classes)

    per_image_losses = (
        lovasz_softmax_flat(
            *flatten_probas(img_probas.unsqueeze(0), img_labels.unsqueeze(0), ignore),
            classes=classes,
        )
        for img_probas, img_labels in zip(probas, labels)
    )
    return mean(per_image_losses)


def lovasz_softmax_flat(probas, labels, classes="present"):
    """
    Multi-class Lovasz-Softmax loss on flattened inputs
      probas: [P, C] tensor, class probabilities at each prediction (between 0 and 1)
      labels: [P] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.

    Returns the mean of the per-class losses. For empty input it returns a
    scalar zero that is still attached to `probas`.
    Raises ValueError when C == 1 and `classes` has more than one element;
    this includes the strings 'all' and 'present', so single-channel
    (sigmoid) input needs an explicit one-element class list.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0; reduce to a scalar so
        # the result can be averaged with other losses and back-propagated
        return probas.sum() * 0.0
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        # compare by value: identity comparison with a str literal is unreliable
        if classes == "present" and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError("Sigmoid output possible only with 1 class")
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flatten batch predictions to [P, C] and labels to [P].

    probas: [B, C, H, W] tensor, or [B, H, W] which is treated as C == 1.
    labels: tensor with B * H * W elements.
    ignore: when given, positions whose label equals it are removed from
        both outputs.
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    # Boolean-mask indexing always keeps the [P', C] shape; indexing with
    # nonzero().squeeze() drops the pixel axis when exactly one pixel is valid.
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels


def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss

    logits: class scores accepted by F.cross_entropy.
    labels: integer class targets.
    ignore: label value excluded from the loss. When None, 255 is used,
        which was the previously hard-coded value.
    """
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)


# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    """Return ``x != x``, which is truthy for NaN because NaN never equals itself."""
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

    l: any iterable of addable values.
    ignore_nan: drop values for which isnan() is truthy before averaging.
    empty: value returned for an empty iterable, or "raise" to raise
        ValueError instead.
    A single value is returned as-is, without division.
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(isnan, values)
    try:
        total = next(values)
    except StopIteration:
        if empty == "raise":
            raise ValueError("Empty mean")
        return empty
    count = 1
    for item in values:
        total += item
        count += 1
    return total if count == 1 else total / count

  if classes is "present" and fg.sum() == 0:


# Dataset
학습데이터 로딩을 위한 dataset class입니다.
k-fold 교차검증이 가능하도록 하여 실험에서 사용하였습니다.
본 노트북에서는 전체 학습데이터를 사용하여 최종 모델을 학습하는 케이스만 보여주기 때문에, k-fold 교차 검증을 사용하진 않습니다.
(WildfireDataset을 초기화할때 kfold_N=0으로 k-fold교차 검증을 사용하지 않고 전체 학습데이터를 로딩)

해당 노트북 파일과 같은 디렉터리 상에 Wildfire 폴더에 대회 데이터셋이 존재해야합니다.

train 모드에서는 데이터 augmentation이 이루어집니다.
"transform = ..." 코드를 보면 확인할 수 있고, padding, random crop, horizontal flip, vertical flip을 사용합니다.

WildfireDataset을 초기화 할 때, input_chs로 이미지의 어떤 채널을 이용할 지 선택할 수 있습니다.
dataset을 로딩하면 총 10개의 채널이 있고, 뒷쪽 3개 채널은 큰 의미가 없어보였기에, 본 실험에서는 앞쪽 7개 채널만 이용하였습니다.
이미지를 로딩하면 기본적으로 uint16타입이기에, float32로 바꾸면서 65535로 나누어 주었습니다.

7개 채널의 mean값을 입력 이미지의 채널에 추가 하였습니다. (코드상에 # Add mean channels 부분)
일반 이미지에 비해서 샘플 마다의 data 분포의 변화가 큰 것 같아서, 딥러닝 네트워크에 정보를 주고자 하였습니다.
그래서 최종적으로는 14x256x256 크기의 이미지와, 1x256x256 크기의 마스크를 출력합니다.

WildfireDataset을 초기화 할 때, epoch_scale_factor를 조절하면, 데이터셋의 epoch 양을 조절할 수 있습니다.
본 실험에서는 epoch_scale_factor=10으로 두었습니다.
데이터셋의 크기가 비교적 작아서 학습 epoch이 금방 변화하였기에, epoch 변경 단계마다 있는 약간의 오버헤드를 피하고 싶어서 10으로 두었습니다.
본 코드에서 100epoch은 실제 데이터셋 기준으로는 1000epoch에 해당한다고 보면 됩니다.

In [3]:
import math
import random
from pathlib import Path

import albumentations as A
import cv2
import numpy as np
import rasterio
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset

# Dataset locations, relative to the current working directory.
DATA_ROOT = Path("Wildfire")
TRAIN_IMG_DIR = DATA_ROOT / "train_img"  # training images
TRAIN_MASK_DIR = DATA_ROOT / "train_mask"  # training masks
TEST_IMG_DIR = DATA_ROOT / "test_img"  # test images

# Training-time augmentation: pad to 512x512 with one of two border modes,
# take a random 256x256 crop, then flip horizontally / vertically with p=0.5 each.
transform = A.Compose(
    [
        A.OneOf(
            [
                A.PadIfNeeded(
                    min_height=512,
                    min_width=512,
                    border_mode=cv2.BORDER_REFLECT_101,
                ),
                A.PadIfNeeded(
                    min_height=512,
                    min_width=512,
                    border_mode=cv2.BORDER_WRAP,
                ),
            ],
            p=1.0,
        ),
        A.RandomCrop(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ]
)


def _imread_float(f: str | Path, input_chs: list[int]) -> np.ndarray:
    """Read the selected bands of a raster file as a float32 (H, W, C) array.

    Args:
        f: path of the raster file to open with rasterio.
        input_chs: zero-based indices of the bands to keep, in output order.

    Returns:
        The selected bands divided by 65535, as float32, channels last.
    """
    # Close the dataset handle as soon as the pixels are in memory; the
    # dataset calls this once per sample, so unclosed handles pile up.
    with rasterio.open(f) as src:
        bands = src.read()

    img = bands[input_chs].transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    img = img / 65535
    return img.astype(np.float32)


class WildfireDataset(Dataset):
    """Dataset over the images in TRAIN_IMG_DIR and their masks in TRAIN_MASK_DIR.

    Each sample is a dict with:
      "img":  float32 array of shape (2 * len(input_chs), H, W); the selected
              channels followed by one constant plane per channel holding the
              mean of that channel's positive pixels.
      "mask": float32 array of shape (1, H, W).

    Args:
        mode: "train" (augmentation applied) or "val" (no augmentation).
        epoch_scale_factor: multiplier on the reported dataset length; a value
            of 10 makes one epoch iterate the underlying images ten times.
        kfold_N: number of stratified folds; 0 disables k-fold and uses every
            image (only allowed with mode="train").
        kfold_I: index of the fold to use when kfold_N > 0.
        input_chs: zero-based band indices to load from each image.
    """

    # Cache file for the stratification labels computed from the masks.
    _COVERAGE_CACHE = "/tmp/coverage_labels.npy"

    def __init__(
        self,
        mode: str,
        epoch_scale_factor: float,
        kfold_N: int,
        kfold_I: int,
        input_chs: list[int],
    ):
        super().__init__()

        assert mode in ["train", "val"]

        img_paths = sorted(TRAIN_IMG_DIR.glob("*.tif"))
        mask_paths = [TRAIN_MASK_DIR / x.name.replace("img", "mask") for x in img_paths]
        num_imgs = len(img_paths)

        if kfold_N == 0:  # Not use kfold, Use all training data.
            self.idx_map = list(range(num_imgs))
            assert mode == "train"
        else:
            coverage_labels = self._load_coverage_labels(mask_paths)

            # Fixed random_state keeps the folds identical across runs.
            kf = StratifiedKFold(n_splits=kfold_N, shuffle=True, random_state=910103)
            train_idx, val_idx = list(kf.split(np.zeros(num_imgs), coverage_labels))[kfold_I]

            if mode == "train":
                self.idx_map = train_idx
            elif mode == "val":
                self.idx_map = val_idx
            else:
                raise ValueError(f"Unknown Mode {mode}")

        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.epoch_scale_factor = epoch_scale_factor
        self.mode = mode
        self.input_chs = input_chs

    @classmethod
    def _load_coverage_labels(cls, mask_paths):
        """Return one stratification label per mask, using the on-disk cache when valid."""
        if Path(cls._COVERAGE_CACHE).exists():
            cached = np.load(cls._COVERAGE_CACHE)
            # A cache written for a different number of images cannot be
            # paired with the current file list, so it is recomputed.
            if len(cached) == len(mask_paths):
                return cached

        coverages = [cv2.imread(str(x), cv2.IMREAD_UNCHANGED).sum() for x in mask_paths]
        coverage_labels = cls._coverage_to_labels(coverages)
        np.save(cls._COVERAGE_CACHE, coverage_labels)
        return coverage_labels

    @staticmethod
    def _coverage_to_labels(coverages):
        """Bin log-coverage into histogram bins and return the bin index of each value."""
        # Clamp to 1 before the log: an empty mask has coverage 0, log(0) is
        # -inf and np.histogram rejects a non-finite range. Positive
        # coverages are unaffected.
        counts = np.maximum(np.asarray(coverages, dtype=np.float64), 1.0)
        log_coverage = np.log(counts)
        bin_edges = np.histogram(log_coverage)[1]
        return np.digitize(log_coverage, bin_edges, right=True)

    def __getitem__(self, _idx):
        """Load, augment (train mode only) and return the sample for index `_idx`.

        Raises:
            IndexError: if `_idx` is not smaller than len(self).
        """
        if _idx >= len(self):
            raise IndexError()

        if self.epoch_scale_factor < 1:
            # The reported length is shorter than the image list; jump by a
            # random multiple of it so later images can be drawn as well.
            _idx += len(self) * random.randrange(math.ceil(1 / self.epoch_scale_factor))

        # Wrap around so a scaled-up length revisits the same images.
        idx = self.idx_map[_idx % len(self.idx_map)]

        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]

        img = _imread_float(img_path, self.input_chs)
        mask = cv2.imread(str(mask_path), cv2.IMREAD_UNCHANGED)

        if self.mode == "train":
            transformed = transform(image=img, mask=mask)
            img = transformed["image"]
            mask = transformed["mask"]

        # (H, W, C) -> (C, H, W)
        img = np.transpose(img, (2, 0, 1))
        mask = np.expand_dims(mask, axis=0).astype(np.float32)

        # Add mean channels: one constant plane per channel, filled with the
        # mean of that channel's positive pixels (left at 0 if there are none).
        img_mean = np.zeros_like(img)
        for i, each_ch in enumerate(img):
            if (each_ch > 0).sum() > 0:
                img_mean[i] = each_ch[each_ch > 0].mean()

        img = np.concatenate([img, img_mean], axis=0)

        sample = {"img": img, "mask": mask}

        return sample

    def __len__(self):
        """Number of selected images scaled by epoch_scale_factor, rounded."""
        return round(len(self.idx_map) * self.epoch_scale_factor)

# UNet Encoder
UNet의 인코더 파트입니다.
기본적으로 timm라이브러리의 regnetx_002를 가져와서 사용하였습니다.
timm에서 제공하는 pretrained weight로 초기화 시켰습니다.
해당 pretrained weight는 imagenet 데이터셋으로 학습된것으로 보입니다.
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/regnet.py
https://github.com/facebookresearch/pycls/blob/main/MODEL_ZOO.md

regnet encoder에서 UNet 생성에 필요 없는 final_conv와 head layer는 제거합니다.

conv0 레이어를 추가하여, regnet 이전에 붙였습니다.
해당 레이어는 2개의 1x1 conv layer로 이루어져있습니다.
일반 이미지와는 달리 Wildfire 영상의 scale 변화가 샘플마다 컸기 때문에, 1x1 conv layer로 우선 각 픽셀에서 어떠한 정규화가 일어나길 기대했습니다.

regnet의 conv1레이어는 원래 3채널의 RGB 값을 받는 레이어이기 때문에, conv0의 output인 32 채널을 받을 수 있도록, 수정을 가하였습니다.

In [4]:
import timm
import torch
from torch import nn


class RegNetEncoder(nn.Module):
    """UNet encoder built from a pretrained timm RegNetX backbone.

    Two 1x1 convolutions (`conv0`) are placed in front of the backbone stem,
    and the stem convolution is replaced so it accepts their 32 channels.

    Args:
        name: one of "regnetx_002", "regnetx_004", "regnetx_006", "regnetx_008".
        in_ch: number of selected image channels; `conv0` takes `in_ch * 2`
            input channels.
        empty_out_depths: stage indices whose feature map is replaced by a
            zero-channel tensor (and whose reported width is 0).
    """

    def __init__(
        self,
        name: str,
        in_ch: int,
        empty_out_depths: list[int],
    ):
        super().__init__()

        supported = ("regnetx_002", "regnetx_004", "regnetx_006", "regnetx_008")
        if name not in supported:
            raise ValueError(name)
        self.model = timm.create_model(name, pretrained=True)

        # Drop the parts of the backbone that are not used here.
        del self.model.final_conv
        del self.model.head

        # conv0: per-pixel 1x1 convolutions applied before the stem
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_ch * 2, 32, kernel_size=1, padding=0, bias=True),
            nn.ReLU(True),
            nn.Conv2d(32, 32, kernel_size=1, padding=0, bias=True),
            nn.ReLU(True),
        )

        # Replace the stem convolution with a 32-input, stride-1 version whose
        # weights are the pretrained kernel tiled along the input-channel
        # axis and rescaled by 3 / 32.
        stem_in = 32
        with torch.no_grad():
            pretrained_kernel = self.model.stem.conv.weight.detach()

            patched = nn.Conv2d(stem_in, 32, kernel_size=3, stride=1, padding=1, bias=False)
            repeats = (stem_in + 2) // 3
            tiled = pretrained_kernel.repeat(1, repeats, 1, 1)[:, :stem_in]
            patched.weight[:] = tiled * 3 / stem_in
            self.model.stem.conv = patched

        self.empty_out_depths = empty_out_depths

    def _get_stages(self) -> list[nn.Module]:
        """Return the five stages in forward order; the first is conv0 followed by the stem."""
        backbone = self.model
        first_stage = nn.Sequential(self.conv0, backbone.stem)
        return [first_stage, backbone.s1, backbone.s2, backbone.s3, backbone.s4]

    @property
    def out_channels(self) -> list[int]:
        """Channel count per stage, with 0 for stages listed in empty_out_depths."""
        backbone = self.model
        widths = [backbone.stem.conv.out_channels]
        widths += [
            stage.b1.conv1.conv.out_channels
            for stage in (backbone.s1, backbone.s2, backbone.s3, backbone.s4)
        ]
        return [
            0 if depth in self.empty_out_depths else width
            for depth, width in enumerate(widths)
        ]

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run all stages in sequence and return one feature tensor per stage."""
        outputs = []
        for depth, stage in enumerate(self._get_stages()):
            x = stage(x)

            if depth not in self.empty_out_depths:
                outputs.append(x)
                continue

            # Placeholder with zero channels but the same batch and spatial size.
            batch, _, height, width = x.shape
            outputs.append(
                torch.zeros((batch, 0, height, width), dtype=x.dtype, device=x.device)
            )

        return outputs

  from .autonotebook import tqdm as notebook_tqdm


# UNet Decoder

UNet Decoder 파트입니다.
업샘플링 layer와 conv layer로 이루어져있는 전형적인 디코더 구조입니다.
인코더의 같은 level에서 skip connection 또한 수신하는 구조입니다.

In [5]:
import torch
from torch import nn


class Conv2dReLU(nn.Sequential):
    """Conv2d -> (BatchNorm2d or Identity) -> in-place ReLU.

    The convolution has a bias only when batch norm is disabled.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_batchnorm,
            ),
            nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity(),
            nn.ReLU(inplace=True),
        ]
        super().__init__(*layers)

class DecoderBlock(nn.Module):
    """One decoder step: 2x upsample, optional skip concatenation, two 3x3 Conv2dReLU layers.

    Args:
        in_ch: channels of the incoming feature map.
        skip_ch: channels of the skip feature concatenated after upsampling.
        out_ch: channels produced by both convolutions.
        upsample_mode: `mode` passed to nn.Upsample.
        use_batchnorm: forwarded to Conv2dReLU.
    """

    def __init__(
        self,
        in_ch: int,
        skip_ch: int,
        out_ch: int,
        upsample_mode: str,
        use_batchnorm=True,
    ):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode=upsample_mode)

        conv_kwargs = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv1 = Conv2dReLU(in_ch + skip_ch, out_ch, **conv_kwargs)
        self.conv2 = Conv2dReLU(out_ch, out_ch, **conv_kwargs)

    def forward(self, x, skip=None):
        """Upsample `x`, concatenate `skip` on the channel axis if given, then convolve twice."""
        upsampled = self.upsample(x)
        fused = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(fused))


class CenterBlock(nn.Sequential):
    """Two stacked 3x3 Conv2dReLU layers mapping in_ch -> out_ch -> out_ch."""

    def __init__(self, in_ch: int, out_ch: int, use_batchnorm=True):
        conv_kwargs = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        first = Conv2dReLU(in_ch, out_ch, **conv_kwargs)
        second = Conv2dReLU(out_ch, out_ch, **conv_kwargs)
        super().__init__(first, second)


class UNetDecoder(nn.Module):
    """Chain of DecoderBlocks that consumes the deepest feature plus one skip feature per block.

    Args:
        center_ch: channels of the deepest encoder feature.
        skip_chs: channels of each skip feature, in decoding order.
        decoder_chs: output channels of each decoder block; must have the
            same length as `skip_chs`.
        upsample_mode: `mode` passed to each block's nn.Upsample.
        use_batchnorm: forwarded to every convolution block.
        use_center_block: apply a CenterBlock to the deepest feature first;
            otherwise it passes through unchanged.
    """

    def __init__(
        self,
        center_ch: int,
        skip_chs: list[int],
        decoder_chs: list[int],
        upsample_mode: str,
        use_batchnorm=True,
        use_center_block=False,
    ):
        super().__init__()

        self.center = (
            CenterBlock(center_ch, center_ch, use_batchnorm=use_batchnorm)
            if use_center_block
            else nn.Identity()
        )

        # Each block takes the previous block's output; the first takes the center.
        in_chs = [center_ch] + decoder_chs[:-1]
        self.blocks = nn.ModuleList()
        for in_ch, skip_ch, out_ch in zip(in_chs, skip_chs, decoder_chs, strict=True):
            self.blocks.append(
                DecoderBlock(in_ch, skip_ch, out_ch, upsample_mode, use_batchnorm=use_batchnorm)
            )

    def forward(self, center_feat, skip_feats):
        """Return the output of every decoder block, in decoding order."""
        x = self.center(center_feat)
        outputs = []
        for block, skip in zip(self.blocks, skip_feats, strict=True):
            x = block(x, skip)
            outputs.append(x)
        return outputs

# UNet
위에서 만든 인코더와 디코더를 합쳐 하나의 UNet을 만드는 코드입니다.

추가적으로 mmengine에서 사용하는 BaseModel도 정의되어 있습니다.
BaseModel 내부에 loss를 계산하는 부분을 확인할 수 있습니다.

In [6]:
import torch
import torch.nn.functional as F
from mmengine.model import BaseModel as MMBaseModel
from mmengine.registry import MODELS
from torch import nn

class Conv2dUpsample(nn.Sequential):
    """Conv2d with `kernel_size // 2` padding, followed by bilinear upsampling when `upsampling` > 1."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int, upsampling: int):
        projection = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=kernel_size // 2)
        if upsampling > 1:
            resize = nn.Upsample(scale_factor=upsampling, mode="bilinear")
        else:
            resize = nn.Identity()
        super().__init__(projection, resize)

class UNet(nn.Module):
    """Encoder + decoder + single-channel head, built from registry configs.

    Args:
        encoder: config passed to MODELS.build; the built module must expose
            `out_channels` and return a list of features from forward().
        decoder: config passed to MODELS.build after `center_ch` and
            `skip_chs` are filled in from the encoder's channel list. It must
            contain `decoder_chs`. The caller's dict is not modified.
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = MODELS.build(encoder)
        encoder_chs = self.encoder.out_channels

        # Work on a shallow copy so the derived keys do not leak into the
        # config object owned by the caller.
        decoder_cfg = dict(decoder)
        decoder_cfg["center_ch"] = encoder_chs[-1]
        decoder_cfg["skip_chs"] = encoder_chs[::-1][1:]
        self.decoder = MODELS.build(decoder_cfg)

        # head: 3x3 convolution to one channel, no upsampling
        decoder_chs = decoder_cfg["decoder_chs"]
        self.head0 = Conv2dUpsample(decoder_chs[-1], 1, kernel_size=3, upsampling=1)

    def forward(self, inputs):
        """Return the head output for `inputs["img"]`; no sigmoid is applied here."""
        x = inputs["img"]  # (B, C, H, W)

        encoded = self.encoder(x)
        # Deepest feature is the center; the rest are skips, deepest first.
        decoder_out = self.decoder(encoded[-1], encoded[::-1][1:])

        # Only the last decoder output feeds the head.
        output_h0 = self.head0(decoder_out[-1])

        return output_h0


class OurBaseModel(MMBaseModel):
    """mmengine model wrapper around the UNet that computes the training loss.

    Args:
        unet: config passed to MODELS.build to create the network.
    """

    def __init__(self, unet):
        super().__init__()

        self.unet = MODELS.build(unet)

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

    def forward(self, mode, **inputs):
        """Run the UNet and post-process according to `mode`.

        Args:
            mode: "loss", "predict" or "tensor".
            **inputs: batch dict; "img" is read by the UNet, "mask" is
                required for "loss" and optional for "predict".

        Returns:
            "loss":    {"lovasz_loss": <Lovasz hinge loss with ELU>}.
            "predict": (sigmoid probabilities, mask or None if not provided).
            "tensor":  the raw network output.

        Raises:
            KeyError: in "loss" mode when "mask" is missing.
            ValueError: for any other `mode`.
        """
        y_pred = self.unet(inputs)  # (B, 1, H, W)

        if mode == "loss":
            y_gt = inputs["mask"]  # (B, 1, H, W)
            lovasz_loss = lovasz_hinge(y_pred, y_gt, use_elu=True)
            return {"lovasz_loss": lovasz_loss}
        if mode == "predict":
            # The mask is only passed through, so unlabeled batches are allowed.
            return torch.sigmoid(y_pred), inputs.get("mask")
        if mode == "tensor":
            return y_pred
        # Fail loudly instead of silently returning None for a mistyped mode.
        raise ValueError(f"Unsupported forward mode: {mode!r}")

# Run Training
학습과 관련된 config를 세팅하고 학습을 시작하는 코드입니다.

config의 네트워크 모델 구성과 데이터셋 설정은 위에서 이미 설명했으므로 생략하겠습니다.

batchsize는 64이고 단일 GPU에서 15~16GB의 VRAM을 요구합니다.

max_epoch은 150으로 잡혀있습니다.
단일 RTX4090에서도 이틀정도 걸리는 긴 시간의 학습이 필요합니다.
하지만, 10epoch마다 체크포인트가 "Logs/{학습시작시간}" 폴더에 저장되는데, 50epoch정도면 충분히 최고 성능으로 수렴하는 편입니다.
그러므로, 재현성을 살펴보실때 굳이 끝까지 학습하실 필요는 없을 것 같습니다.

AdamW 옵티마이저를 사용하였고, 10epoch 단위로 코사인어닐링 LR 스케쥴링을 사용합니다.

In [7]:
from datetime import datetime
from pathlib import Path

from mmengine.dataset import DefaultSampler, default_collate
from mmengine.hooks import CheckpointHook, LoggerHook
from mmengine.optim.scheduler import CosineRestartLR
from mmengine.runner import Runner
from torch.optim import AdamW

def main():
    """Assemble the mmengine Runner configuration and start training.

    The model is OurBaseModel wrapping a UNet with a RegNetEncoder
    ("regnetx_002") and a UNetDecoder. Data comes from WildfireDataset in
    "train" mode without k-fold. The work directory is "Logs/<timestamp>".
    """
    input_chs = list(range(7))

    env_cfg = dict(
        cudnn_benchmark=True,
        mp_cfg=dict(mp_start_method="fork", opencv_num_threads=0),
        dist_cfg=dict(backend="nccl"),
    )

    encoder_cfg = dict(
        type=RegNetEncoder,
        name="regnetx_002",
        in_ch=len(input_chs),
        empty_out_depths=[],
    )
    decoder_cfg = dict(
        type=UNetDecoder,
        decoder_chs=[128, 64, 48, 32],
        upsample_mode="nearest",
        use_batchnorm=True,
    )
    model_cfg = dict(
        type=OurBaseModel,
        unet=dict(type=UNet, encoder=encoder_cfg, decoder=decoder_cfg),
    )

    dataset_cfg = dict(
        type=WildfireDataset,
        mode="train",
        epoch_scale_factor=10.0,
        kfold_N=0,
        kfold_I=0,
        input_chs=input_chs,
    )
    train_dataloader_cfg = dict(
        dataset=dataset_cfg,
        batch_size=64,
        sampler=dict(type=DefaultSampler, shuffle=True),
        collate_fn=dict(type=default_collate),
        num_workers=8,
        pin_memory=True,
        persistent_workers=True,
    )

    # Cosine schedule with restarts; each period is 10 epochs.
    scheduler_cfg = dict(
        type=CosineRestartLR,
        periods=[10] * 100,
        restart_weights=[1] * 100,
        eta_min=1e-6,
        by_epoch=True,
        convert_to_iter_based=True,
    )

    hooks_cfg = dict(
        checkpoint=dict(type=CheckpointHook, interval=10),
        logger=dict(type=LoggerHook, interval=1000),
    )

    work_dir = str(Path("Logs") / datetime.now().strftime("%y%m%d_%H%M%S"))

    runner = Runner(
        env_cfg=env_cfg,
        model=model_cfg,
        train_dataloader=train_dataloader_cfg,
        train_cfg=dict(by_epoch=True, max_epochs=150),
        optim_wrapper=dict(optimizer=dict(type=AdamW, lr=1e-3)),
        param_scheduler=scheduler_cfg,
        default_hooks=hooks_cfg,
        work_dir=work_dir,
    )

    runner.train()

In [None]:
# Build the runner and start training.
main()

03/27 13:23:12 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
    CUDA available: True
    MUSA available: False
    numpy_random_seed: 1516942946
    GPU 0,1,2,3: NVIDIA GeForce RTX 4090
    CUDA_HOME: /usr/local/cuda
    NVCC: Cuda compilation tools, release 12.4, V12.4.99
    GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
    PyTorch: 2.2.1+cu121
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 12.1
  - NVCC architecture flags: