In [None]:
import os
import random
from functools import partial
from typing import List

import torch
import torch.nn as nn

import numpy as np
import albumentations as A

In [None]:
from torch.utils.data import DataLoader
from resized_dataset import make_dataset
from setseed import set_seed

# Experiment-wide settings; later cells read these names as well.
RANDOM_SEED = 21
BATCH_SIZE = 8
resize = 512
augmentation = None

# Seed first, then build the datasets with the same seed.
set_seed(RANDOM_SEED)
train_dataset, valid_dataset = make_dataset(
    RANDOM_SEED=RANDOM_SEED,
    augmentation=augmentation,
)

# Training batches: shuffled, BATCH_SIZE samples each, incomplete last batch dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=6,
)

# Validation batches: fixed order, two samples per batch, every sample kept.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=2,
    shuffle=False,
    drop_last=False,
    num_workers=2,
)

In [None]:
import timm
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders._base import EncoderMixin

class hrnet_encoder(nn.Module, EncoderMixin):
    """Thin wrapper around timm's ``hrnet_w18`` model.

    The wrapped network is created without pretrained weights and with
    ``num_classes=29``. Any keyword arguments given to the constructor are
    accepted and ignored.

    NOTE(review): ``forward`` returns the wrapped timm model's output as-is.
    smp encoders are normally expected to return a list of feature maps and
    to expose output-channel metadata — confirm this class works before
    selecting "hrnet_w18" as an ``encoder_name``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Attribute name `net` is kept: it determines the state-dict key prefix.
        self.net = timm.create_model('hrnet_w18', pretrained=False, num_classes=29)

    def forward(self, x):
        """Run ``x`` through the wrapped timm model and return its output."""
        return self.net(x)

# Register the custom encoder in smp's encoder table under the key "hrnet_w18".
# NOTE(review): smp presumably forwards the `params` entries to the encoder
# constructor as keyword arguments; hrnet_encoder ignores its kwargs — confirm.
smp.encoders.encoders["hrnet_w18"] = dict(
    encoder=hrnet_encoder,  # encoder class
    pretrained_settings=dict(
        imagenet=dict(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth',
            input_space="RGB",
            input_range=[0, 1],
        ),
    ),
    params=dict(
        pretrain_img_size=224,
        crop_pct=0.95,
    ),
)

In [None]:
import segmentation_models_pytorch as smp

# Re-seed before building the model (presumably so that weight
# initialisation is reproducible — confirm against set_seed).
set_seed(RANDOM_SEED)

# DeepLabV3+ configuration:
#   encoder_name    - encoder backbone (e.g. mobilenet_v2 or efficientnet-b7)
#   encoder_weights - `imagenet` pre-trained weights for encoder initialization
#   in_channels     - model input channels (1 for gray-scale, 3 for RGB, etc.)
#   classes         - model output channels (number of classes in the dataset)
model = smp.DeepLabV3Plus(
    classes=29,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="resnet50",
)

# Label stored on the model; the folder-name cell reads `model.name`.
model.name = 'DeepLabV3Plus_resnet50'

In [None]:
# Training hyperparameters shared by the experiments in this notebook.
epoch = 50       # passed to train() and embedded in the run's folder name
LR = 1e-4        # learning rate given to the optimizer
VAL_EVERY = 1    # validation interval

In [None]:
import torch.optim as optim

set_seed(RANDOM_SEED)

# Adam over all model parameters, with a small weight decay.
optimizer = optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=1e-6,
)

In [None]:
from loss import FocalLoss, DiceLoss

set_seed(RANDOM_SEED)

# Loss configuration: a list of (loss module, number) pairs — presumably the
# number is the weight applied to that loss inside train(); confirm there.
#
# Alternative settings kept for reference:
#   criterion = nn.BCEWithLogitsLoss()
#   criterion = nn.SmoothL1Loss()
#   criterion = [(nn.BCEWithLogitsLoss(), 0.5), (nn.SmoothL1Loss(), 0.5)]
criterion = [
    (nn.BCEWithLogitsLoss(), 0.75),
    (DiceLoss(), 0.25),
]

In [None]:
# Set the model save path (folder name) for this run.
#
# The folder name embeds the model label, input size, loss configuration,
# learning rate, seed and epoch count. If the model object is missing or has
# no `name` attribute, the generic label "model" is used instead. Only those
# two failure modes are caught, so unrelated errors (and KeyboardInterrupt)
# are no longer silently swallowed by a bare `except:`.
try:
    model_label = model.name
except (NameError, AttributeError):
    model_label = "model"

folder_name = "[{}]_[size:{}]_[loss:{}]_[LR:{}]_[seed:{}]_[epoch:{}]".format(
    model_label, (resize, resize), criterion, LR, RANDOM_SEED, epoch
)

In [None]:
from train import train

# Re-seed immediately before training starts.
set_seed(RANDOM_SEED)

# train() returns a folder name; it overwrites `folder_name`, which the
# inference cell reads afterwards.
folder_name = train(
    model,
    train_loader,
    valid_loader,
    criterion,
    optimizer,
    epoch,
    VAL_EVERY,
    folder_name,
)

In [None]:
from inference import inference

# Run inference for the run stored under `folder_name`, passing a square
# resize transform (resize x resize).
inference(
    folder_name,
    A.Resize(height=resize, width=resize),
)

In [None]:
# Second experiment: the same pipeline with a different encoder and loss mix.
set_seed(RANDOM_SEED)

# DeepLabV3+ configuration:
#   encoder_name    - encoder backbone (here the timm EfficientNet-B4 variant)
#   encoder_weights - `imagenet` pre-trained weights for encoder initialization
#   in_channels     - model input channels (1 for gray-scale, 3 for RGB, etc.)
#   classes         - model output channels (number of classes in the dataset)
model = smp.DeepLabV3Plus(
    classes=29,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="timm-efficientnet-b4",
)

# Label stored on the model; the folder-name cell reads `model.name`.
model.name = 'DeepLabV3Plus_efficientnet-b4'

# A fresh Adam optimizer bound to the new model's parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=1e-6,
)

# (loss module, number) pairs — presumably the per-loss weights used by
# train(); confirm there.
criterion = [
    (nn.BCEWithLogitsLoss(), 0.5),
    (nn.SmoothL1Loss(), 0.5),
]

In [None]:
# Set the model save path (folder name) for this run.
#
# The folder name embeds the model label, input size, loss configuration,
# learning rate, seed and epoch count. If the model object is missing or has
# no `name` attribute, the generic label "model" is used instead. Only those
# two failure modes are caught, so unrelated errors (and KeyboardInterrupt)
# are no longer silently swallowed by a bare `except:`.
try:
    model_label = model.name
except (NameError, AttributeError):
    model_label = "model"

folder_name = "[{}]_[size:{}]_[loss:{}]_[LR:{}]_[seed:{}]_[epoch:{}]".format(
    model_label, (resize, resize), criterion, LR, RANDOM_SEED, epoch
)

In [None]:
# Re-seed immediately before training starts.
set_seed(RANDOM_SEED)

# train() returns a folder name, which overwrites `folder_name`.
folder_name = train(
    model,
    train_loader,
    valid_loader,
    criterion,
    optimizer,
    epoch,
    VAL_EVERY,
    folder_name,
)

In [None]:
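# Inference for the run trained in the previous cell is currently disabled; uncomment the line below to run it.
#inference(folder_name, A.Resize(resize, resize))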