In [None]:
import torch
import random
import os
import json

import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from tqdm import tqdm

from dataloaders.dataset import CityscapesSegmentation
from utils.custom_utils import calculate_miou_mdice

In [2]:
# Hyper-parameters shared by the cells below; keys are looked up as CFG['<KEY>'].
CFG = dict(
    NUM_CLASS=8,
    EPOCHS=500,      # number of passes of the training loop
    LR=0.01,         # initial learning rate given to the optimizer / LR schedule
    BATCH_SIZE=2,
    SEED=41,
)

# Directory that receives model checkpoints and the JSON training log.
model_save_dir = '/home/vision/gyuil/lab/Segmentation/save'

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # cover every GPU, not only the current one
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN auto-tune and pick different kernels from run
    # to run, which defeats determinism; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False


# Fix the seed.
# NOTE(review): CFG['SEED'] is 41 but the literal 42 is used here - confirm
# which value is intended.
seed_everything(42)

In [None]:
# Size tuples handed to the dataset constructor.
# NOTE(review): presumably (height, width) - confirm against CityscapesSegmentation.
base_size = (1024, 2048)
crop_size = (512, 1024)

# Cityscapes 'train' and 'val' splits (the 'test' split is currently disabled).
cityscapes_train = CityscapesSegmentation(base_size, crop_size, split='train')
cityscapes_val = CityscapesSegmentation(base_size, crop_size, split='val')
# cityscapes_test = CityscapesSegmentation(base_size, crop_size, split='test')

# Shuffled loaders: the training loader uses the configured batch size,
# the validation loader yields one sample per batch.
train_dataloader = DataLoader(
    cityscapes_train,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=True,
    num_workers=4,
)
val_dataloader = DataLoader(
    cityscapes_val,
    batch_size=1,
    shuffle=True,
    num_workers=4,
)
# test_dataloader = DataLoader(cityscapes_test, batch_size=2, shuffle=True, num_workers=4)

In [6]:
from torch.utils.data import Dataset, DataLoader, random_split

# Randomly partition the Cityscapes *train* split into three parts of
# 99% / 0.5% / remainder of its samples.
dataset_size = len(cityscapes_train)
train_size = int(dataset_size * 0.99)
validation_size = int(dataset_size * 0.005)
test_size = dataset_size - train_size - validation_size

# NOTE(review): the first (99%) part is discarded. `train_dataset` receives the
# `validation_size` part and `val_dataset` the `test_size` part, so training
# runs on a very small subset. Confirm this is an intentional debug setup and
# not an unpacking-order mistake.
_, train_dataset, val_dataset = random_split(
    cityscapes_train,
    [train_size, validation_size, test_size],
)


# These rebind the loaders created from the full train/val splits above.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=CFG['BATCH_SIZE'],
    shuffle=True,
    num_workers=4,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=CFG["BATCH_SIZE"],
    shuffle=True,
    num_workers=4,
)

In [7]:
# import matplotlib.pyplot as plt

# data = next(iter(train_dataloader))


# origin_image = data[0]
# aug_image = data[1]
# mask = data[2]

# plt.figure(figsize=(15,15))


# plt.subplot(1,3,1)
# plt.imshow(origin_image.squeeze(0))

# plt.subplot(1,3,2)
# plt.imshow(aug_image.squeeze(0).permute(1,2,0))

# plt.subplot(1,3,3)
# plt.imshow(mask.permute(1,2,0))

In [None]:
# Number of classes as reported by the dataset object.
num_class = cityscapes_train.NUM_CLASSES

# UNet++ with an ImageNet-pretrained ResNet-101 encoder. The head has one
# output channel more than `num_class`; the training loop maps label 255 onto
# that extra index. `patch_attention` and `dataset_type` are passed through to
# the model constructor as given.
model = smp.UnetPlusPlus(
    # encoder=resnet50,
    encoder_name="resnet101",
    encoder_weights="imagenet",
    in_channels=3,
    classes=num_class + 1,
    patch_attention=True,
    dataset_type="cityscapes",  # #aitod
)
model = model.to(device)

# Alternative optimizers that were tried:
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=CFG['LR'],
    momentum=0.9,
    weight_decay=0.0005,
)

# Plain cross-entropy (no focal loss); targets equal to 255 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=255).to(device)

In [None]:
# Show which dataset type and patch-attention setting the encoder reports.
print(
    "Dataset : ", model.encoder.dataset_type,
    ", Patch att : ", model.encoder.patch_attention,
)

In [10]:
def poly_learning_rate_epoch(initial_lr, current_epoch, total_epochs, power=0.9):
    """Return the polynomially decayed learning rate for an epoch.

    Computes ``initial_lr * (1 - current_epoch / total_epochs) ** power``.

    Args:
        initial_lr: Learning rate at epoch 0.
        current_epoch: Index of the epoch the rate is computed for.
        total_epochs: Total number of epochs in the schedule.
        power: Exponent of the decay curve.

    Returns:
        The learning rate to use for ``current_epoch``.
    """
    remaining_fraction = 1 - current_epoch / total_epochs
    decay = remaining_fraction ** power
    return initial_lr * decay

In [None]:
# Containers for the JSON logs; `test_log` is only filled by the test-log dump
# that is currently commented out.
train_log = {}
test_log = {}

# Per-epoch averages (train: every epoch, test: every evaluated epoch).
train_losses = []
test_losses = []
train_dice_scores = []
test_dice_scores = []
train_iou_scores = []
test_iou_scores = []

# Not used by this loop; kept so any other cell referencing it keeps working.
sigmoid = nn.Sigmoid()

# Make sure the checkpoint / log directory exists before the first write.
os.makedirs(model_save_dir, exist_ok=True)

for epoch in range(CFG['EPOCHS']):

    # Polynomial LR decay, applied once per epoch to every parameter group.
    current_lr = poly_learning_rate_epoch(CFG['LR'], epoch, CFG['EPOCHS'])
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    # ---- training pass ----
    model.train()
    train_loss = 0
    train_mdice_score = 0
    train_miou_score = 0

    for data in tqdm(train_dataloader):
        # Use the notebook-wide `device` so the CPU fallback works.
        images = data[0].to(device)
        masks = data[1].to(device)
        # Map label 255 onto the extra output channel (index `num_class`).
        masks[masks == 255.] = num_class

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs.float(), masks.long())
        loss.backward()
        optimizer.step()

        # argmax over logits equals argmax over softmax; reducing before the
        # host transfer avoids copying the full logit tensor to the CPU.
        preds = outputs.detach().argmax(dim=1).cpu().numpy()

        # loss
        train_loss += loss.item()

        # dice, iou score
        miou, mdice = calculate_miou_mdice(preds, masks.squeeze(1).cpu().numpy())
        train_mdice_score += mdice
        train_miou_score += miou

        del masks, images

    num_train_batches = len(train_dataloader)
    avg_train_loss = train_loss / num_train_batches
    avg_train_dice = train_mdice_score / num_train_batches
    avg_train_iou = train_miou_score / num_train_batches

    print(f"Epoch {epoch+1}, Train Loss : {avg_train_loss}, Train Dice Score : {avg_train_dice}, Train IoU Score : {avg_train_iou}")
    # Cast to built-in float so the values are always JSON-serializable.
    train_losses.append(float(avg_train_loss))
    train_dice_scores.append(float(avg_train_dice))
    train_iou_scores.append(float(avg_train_iou))

    # ---- periodic evaluation & checkpoint (every 10th epoch, starting at 0) ----
    if epoch % 10 == 0:
        model.eval()
        test_loss = 0
        test_mdice_score = 0
        test_miou_score = 0

        # No gradients are needed for evaluation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for data in tqdm(val_dataloader):
                images = data[0].to(device)
                # NOTE(review): unlike the training pass, label 255 is not
                # remapped to `num_class` here, so train and test numbers
                # treat those pixels differently - confirm this is intended.
                masks = data[1].to(device)

                outputs = model(images)

                # loss
                loss = criterion(outputs.float(), masks.long())
                preds = outputs.argmax(dim=1).cpu().numpy()

                test_loss += loss.item()

                # dice, iou score
                miou, mdice = calculate_miou_mdice(preds, masks.squeeze(1).cpu().numpy())
                test_mdice_score += mdice
                test_miou_score += miou

                del masks, images

        num_val_batches = len(val_dataloader)
        avg_test_loss = test_loss / num_val_batches
        avg_test_dice = test_mdice_score / num_val_batches
        avg_test_iou = test_miou_score / num_val_batches

        print(f"Epoch {epoch+1}, Test Loss : {avg_test_loss}, Test Dice Score : {avg_test_dice}, Test IoU Score : {avg_test_iou}")
        test_losses.append(float(avg_test_loss))
        test_dice_scores.append(float(avg_test_dice))
        test_iou_scores.append(float(avg_test_iou))

        # Checkpoint is named after the zero-based epoch index.
        model_name = str(epoch) + "_unetpp"
        torch.save(model.state_dict(), os.path.join(model_save_dir, model_name))


    # Recording Train log (rewritten in full after every epoch)
    train_log['train_losses'] = train_losses
    train_log['train_dice_scores'] = train_dice_scores
    train_log['train_iou_scores'] = train_iou_scores
    with open(os.path.join(model_save_dir, "unetpp_train_log.json"), "w") as file:
        json.dump(train_log, file)

    # Recording Test log
    # test_log['test_losses'] = test_losses
    # test_log['test_dice_scores'] = test_dice_scores
    # test_log['test_iou_scores'] = test_iou_scores
    # with open(os.path.join(model_save_dir, "unetpp_test_log.json"), "w") as file :
    #     json.dump(test_log, file)

    #train
    # train_losses
    # train_dice_scores
    # train_iou_scores

    #test
    # test_losses
    # test_dice_scores
    # test_iou_scores
    

    # with open(os.path.join(model_save_dir, "unetpp_trainloss.json"), "w") as file :
    #     json.dump(train_losses, file)

    # with open(os.path.join(model_save_dir, "unetpp_valloss.json"), "w") as file :
    #     json.dump(test_losses, file)

    # with open(os.path.join(model_save_dir, "unetpp_traindice.json"), "w") as file :
    #     json.dump(train_dice_scores, file)

    # with open(os.path.join(model_save_dir, "unetpp_testdice.json"), "w") as file :
    #     json.dump(test_dice_scores, file)

    # with open(os.path.join(model_save_dir, "unetpp_trainiou.json"), "w") as file :
    #     json.dump(train_iou_scores, file)

    # with open(os.path.join(model_save_dir, "unetpp_testiou.json"), "w") as file :
    #     json.dump(test_iou_scores, file)