In [None]:
import torch
import random
import os
import json

import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from tqdm import tqdm

from dataloaders.dataset import CityscapesSegmentation
from utils.custom_utils import calculate_dice_iou

In [2]:
# Experiment configuration.
# NOTE(review): apart from CH4 (read just below), none of these entries are read
# by later cells in this notebook — the loader batch size, optimizer LR, epoch
# count and seed are hard-coded where they are used. Confirm which values are intended.
CFG = dict(
    NUM_CLASS=8,
    IMG_SIZE=800,
    EPOCHS=200,
    LR=3e-4,
    BATCH_SIZE=2,
    SEED=41,
    CLASS_SEGMENTATION=True,
    CH4=True,
)

# Per-channel normalisation statistics: ImageNet RGB mean/std, extended with an
# identity fourth channel (mean 0, std 1) when CFG['CH4'] is enabled.
# NOTE(review): image_mean / image_std are not referenced later in this notebook.
image_mean = [0.485, 0.456, 0.406] + ([0] if CFG['CH4'] else [])
image_std = [0.229, 0.224, 0.225] + ([1] if CFG['CH4'] else [])


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # first CUDA device when present, else CPU

In [4]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    # Inherited by child processes only; the running interpreter's hash seed
    # was already fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seed every visible GPU, not just the current one
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different algorithms from run
    # to run, which defeats `deterministic` above; keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False

# NOTE(review): CFG['SEED'] is 41, but this notebook seeds with 42 — confirm which is intended.
seed_everything(42)  # fix the seed

In [None]:
# Sizes forwarded to CityscapesSegmentation (presumably resize / crop sizes in
# pixels — see that class for their exact meaning).
base_size = 128
crop_size = 128

cityscapes_train = CityscapesSegmentation(base_size, crop_size, split='train')
cityscapes_val = CityscapesSegmentation(base_size, crop_size, split='val')
# cityscapes_test = CityscapesSegmentation(base_size, crop_size, split='test')

# NOTE(review): batch_size is 1 here although CFG['BATCH_SIZE'] is 2.
train_dataloader = DataLoader(cityscapes_train, batch_size=1, shuffle=True, num_workers=4)
# Validation metrics are averaged over all batches, so visiting order does not
# matter; shuffling would only make validation runs nondeterministic.
val_dataloader = DataLoader(cityscapes_val, batch_size=1, shuffle=False, num_workers=4)
# test_dataloader = DataLoader(cityscapes_test, batch_size=2, shuffle=False, num_workers=4)

In [6]:
# from torch.utils.data import Dataset, DataLoader, random_split

# dataset_size = len(cityscapes_train)
# train_size = int(dataset_size * 0.99)
# validation_size = int(dataset_size * 0.005)
# test_size = dataset_size - train_size - validation_size
# _, train_dataset, val_dataset = random_split(cityscapes_train, [train_size, validation_size, test_size])


# train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)
# val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True, num_workers=4)

In [None]:
import torch.utils.model_zoo as model_zoo  # NOTE(review): not used anywhere in this notebook

num_class = cityscapes_train.NUM_CLASSES

# UNet++ with an ImageNet-pretrained ResNet-50 encoder, moved to `device`.
# NOTE(review): `classes` is NUM_CLASSES + 1 — presumably one extra class beyond
# the dataset's count; confirm against the dataset's label encoding.
# NOTE(review): `patch_attention` is not a stock segmentation_models_pytorch
# argument; this presumably relies on a patched smp install — verify.
model = smp.UnetPlusPlus(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,  # 3 input channels; the 4-channel stats enabled by CFG['CH4'] are not used here
    classes=num_class + 1,
    patch_attention=False,
).to(device)  # nn.Module.to returns the module itself; assigning avoids dumping its repr

# NOTE(review): lr here is 1e-3, while CFG['LR'] is 3e-4 — confirm which is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Plain cross-entropy (no focal loss); pixels labelled 255 are excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=255).to(device)

model_save_dir = '/home/vision/gyuil/lab/Segmentation/save'
# torch.save and open(..., "w") in the training cell fail if the directory is missing.
os.makedirs(model_save_dir, exist_ok=True)

In [None]:
# Per-epoch metric history. Every list is rewritten to JSON in `model_save_dir`
# after each epoch, so an interrupted run still leaves its curves on disk.
train_losses = []
test_losses = []
train_dice_scores = []
test_dice_scores = []
train_iou_scores = []
test_iou_scores = []
sigmoid = nn.Sigmoid()  # NOTE(review): not used by the loop below; kept in case a later cell expects it

NUM_EPOCHS = 100  # NOTE(review): CFG['EPOCHS'] is 200; this loop runs 100 epochs
EVAL_EVERY = 10   # validate + checkpoint on epochs 0, 10, 20, ... and on the final epoch

os.makedirs(model_save_dir, exist_ok=True)  # writes below fail if the directory is missing


def _run_epoch(loader, train):
    """Run `model` over every batch of `loader` and return per-batch mean metrics.

    Args:
        loader: sized iterable of dicts with 'image' and 'label' tensors.
        train: if True, use train mode and step `optimizer` after each batch;
            otherwise use eval mode with autograd disabled.

    Returns:
        (loss, dice, iou): each summed over batches and divided by len(loader).
    """
    model.train(train)
    total_loss = total_dice = total_iou = 0.0
    # Evaluation must not build autograd graphs: they waste memory and can OOM.
    grad_context = torch.enable_grad() if train else torch.no_grad()
    with grad_context:
        for data in tqdm(loader):
            # Use `device` rather than .cuda() so the CPU fallback also works.
            images = data['image'].to(device)
            masks = data['label'].to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.float(), masks.long())
            if train:
                loss.backward()
                optimizer.step()

            # Predicted class per pixel; softmax is monotonic, so argmax over the
            # raw logits picks the same class without the extra work.
            preds = outputs.argmax(dim=1).cpu().numpy()
            dice, iou = calculate_dice_iou(preds, masks.squeeze(1).cpu().numpy())

            total_loss += loss.item()
            total_dice += dice.item()
            total_iou += iou.item()

    num_batches = max(len(loader), 1)  # avoid ZeroDivisionError on an empty loader
    return total_loss / num_batches, total_dice / num_batches, total_iou / num_batches


def _dump_history():
    """Write each metric history list to its own JSON file in `model_save_dir`."""
    histories = {
        "unetpp_trainloss.json": train_losses,
        "unetpp_valloss.json": test_losses,
        "unetpp_traindice.json": train_dice_scores,
        "unetpp_testdice.json": test_dice_scores,
        "unetpp_trainiou.json": train_iou_scores,
        "unetpp_testiou.json": test_iou_scores,
    }
    for filename, values in histories.items():
        with open(os.path.join(model_save_dir, filename), "w") as file:
            json.dump(values, file)


for epoch in range(NUM_EPOCHS):
    train_loss, train_dice, train_iou = _run_epoch(train_dataloader, train=True)
    print(f"Epoch {epoch+1}, Train Loss : {train_loss}, Train Dice Score : {train_dice}, Train IoU Score : {train_iou}")
    train_losses.append(train_loss)
    train_dice_scores.append(train_dice)
    train_iou_scores.append(train_iou)

    # Periodic validation + checkpoint. The final epoch is included so the fully
    # trained weights are saved; otherwise the last checkpoint would be epoch 90.
    if epoch % EVAL_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        test_loss, test_dice, test_iou = _run_epoch(val_dataloader, train=False)
        print(f"Epoch {epoch+1}, Test Loss : {test_loss}, Test Dice Score : {test_dice}, Test IoU Score : {test_iou}")
        test_losses.append(test_loss)
        test_dice_scores.append(test_dice)
        test_iou_scores.append(test_iou)

        model_name = str(epoch) + "_unetpp"
        torch.save(model.state_dict(), os.path.join(model_save_dir, model_name))

    _dump_history()