In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# 라이브러리
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import albumentations as A
from albumentations.pytorch import transforms

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm

In [3]:
# device setting
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)

print(device)

cuda


In [4]:
# directory setting
class ROOTDIR:
    train_dir = '/content/drive/MyDrive/Colab Notebooks/Github/COSE474Project/cityscapes_data/train/'
    val_dir = '/content/drive/MyDrive/Colab Notebooks/Github/COSE474Project/cityscapes_data/val/'

In [5]:
# image file list
train_filenames = sorted(glob.glob(ROOTDIR.train_dir+"*.jpg"))
val_filenames = sorted(glob.glob(ROOTDIR.val_dir+"*.jpg"))

len(train_filenames), len(val_filenames)

(790, 0)

In [None]:
val_filenames, test_filenames = train_test_split(val_filenames, test_size=0.1)

In [None]:
# visualize
def show(img_list, train, loop=1):
    if train:
        dir = ROOTDIR.train_dir
    else:
        dir = ROOTDIR.val_dir

    for i in range(loop):
        img_dir = os.path.join(dir, img_list[i])
        img = Image.open(img_dir)
        print(f"image path: {dir + img_list[i]}")
        plt.imshow(img)
        plt.title("image")
        plt.show()

In [None]:
show(img_list=train_filenames, train=True)

In [None]:
# Dataset
class Cityscape(Dataset):
    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform

    # img list 길이
    def __len__(self):
        return len(self.img_path)

    # get image, label
    def __getitem__(self, idx):
        data = Image.open(self.img_path[idx]).convert("RGB")
        data = np.array(data)

        img, label = data[:, :256, :], data[:, 256:, :]

        if self.transform:
            augmented = self.transform(image = img, mask = label)
            img, label = augmented['image'], augmented['mask']
            img = torch.from_numpy(img).permute(2, 0, 1)
            label = torch.from_numpy(label).permute(2, 0, 1)

        return img, label

In [None]:
# Define Transforms
train_Transforms = A.Compose([
    A.ElasticTransform(),
])

val_Transforms = None

In [None]:
# Data Augmentation
train_dataset = Cityscape(train_filenames, train_Transforms)
val_dataset = Cityscape(val_filenames, val_Transforms)

In [None]:
# Batch
batch_size = 16

train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, 1, shuffle=True)

In [None]:
x, y = next(iter(train_loader))
plt.imshow(y[0].permute(1, 2, 0))
plt.show()

### Build Model ###

In [None]:
class UNet(nn.Module):
    def __init__(self, num_classes):
        super(UNet, self).__init__()

        self.num_classes = num_classes

        # down convolution
        self.down_conv_1 = self.convBlock(in_channels=3, out_channels=64)
        self.down_conv_2 = self.convBlock(in_channels=64, out_channels=128)
        self.down_conv_3 = self.convBlock(in_channels=128, out_channels=256)
        self.down_conv_4 = self.convBlock(in_channels=256, out_channels=512)
        self.down_conv_5 = self.convBlock(in_channels=512, out_channels=1024)

        # Max Pooling
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Dropout
        self.dropout = nn.Dropout2d(0.2)

        # up convolution transpose
        self.up_conv_trans_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_trans_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_trans_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_trans_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)

        # up convolution
        self.up_conv_1 = self.convBlock(in_channels=1024, out_channels=512)
        self.up_conv_2 = self.convBlock(in_channels=512, out_channels=256)
        self.up_conv_3 = self.convBlock(in_channels=256, out_channels=128)
        self.up_conv_4 = self.convBlock(in_channels=128, out_channels=64)

        # output
        self.output = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def convBlock(self, in_channels, out_channels):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1)
            )
        return block

    def padder(self, tensorL, tensorR):
        if tensorL.shape != tensorR.shape:
            padded = torch.zeros(tensorL.shape)
            padded[:, :, :tensorR.shape[2], :tensorR.shape[3]] = tensorR
            return padded.to(device)
        return tensorR.to(device)

    def forward(self, x):
        # encoder
        x1 = self.down_conv_1(x)
        p1 = self.maxPool(x1)
        drop1 = self.dropout(p1)
        x2 = self.down_conv_2(drop1)
        p2 = self.maxPool(x2)
        drop2 = self.dropout(p2)
        x3 = self.down_conv_3(drop2)
        p3 = self.maxPool(x3)
        drop3 = self.dropout(p3)
        x4 = self.down_conv_4(drop3)
        p4 = self.maxPool(x4)
        drop5 = self.dropout(p4)
        x5 = self.down_conv_5(drop5)

        # decoder
        d1 = self.up_conv_trans_1(x5)
        #print(f"d1 shape: {d1.shape}")
        concat1 = torch.cat([d1, x4], dim=1)
        #print(f"concat1 shape: {concat1.shape}")
        u1 = self.up_conv_1(concat1)
        #print(f"u1 shape: {u1.shape}")
        d2 = self.up_conv_trans_2(u1)
        #print(f"d2 shape: {d2.shape}")
        concat2 = torch.cat([d2, x3], dim=1)
        #print(f"concat2 shape: {concat2.shape}")
        u2 = self.up_conv_2(concat2)
        #print(f"u2 shape: {u2.shape}")
        d3 = self.up_conv_trans_3(u2)
        #print(f"d3 shape: {d3.shape}")
        concat3 = torch.cat([d3, x2], dim=1)
        #print(f"concat3 shape: {concat3.shape}")
        u3 = self.up_conv_3(concat3)
        #print(f"u3 shape: {u3.shape}")
        d4 = self.up_conv_trans_4(u3)
        #print(f"d4 shape: {d4.shape}")
        concat4 = torch.cat([d4, x1], dim=1)
        #print(f"concat4 shape: {concat4.shape}")
        u4 = self.up_conv_4(concat4)
        #print(f"u4 shape: {u4.shape}")
        output = self.output(u4)
        #print(f"output shape: {output.shape}")

        return output

### Train Model ###

In [None]:
# tensorboard
writer = SummaryWriter()

In [None]:
def dice(pred, target):
    pred = (pred > 0).float()
    return 2.0 * (pred*target).sum() / (pred+target).sum()

In [None]:
def train_model(model, dataloader, criterion, optimizer, i):
    model.train()
    train_running_loss = 0.0
    train_running_dice = 0.0

    for j, data in enumerate(tqdm(dataloader)):
        optimizer.zero_grad()

        img = data[0].float().to(device)
        mask = data[1].float().to(device)

        y_pred = model(img)

        loss = criterion(y_pred, mask)

        loss.backward()

        writer.add_scalar("Loss/train", loss, j+i*len(dataloader))

        train_running_loss += loss.item()*batch_size
        train_running_dice += dice(y_pred, mask)

        optimizer.step()

    train_loss = train_running_loss / (j+1)
    train_dice = train_running_dice / (j+1)

    return train_loss, train_dice

In [None]:
def val_model(model, dataloader, criterion, i):
    model.eval()
    val_running_loss = 0.0
    val_running_dice = 0.0

    with torch.no_grad():
        for j, data in enumerate(tqdm(dataloader)):
            img = data[0].float().to(device)
            mask = data[1].float().to(device)

            y_pred = model(img)

            loss = criterion(y_pred, mask)

            writer.add_scalar("Loss/validation", loss, j+i*len(dataloader))

            val_running_loss += loss.item()*batch_size
            val_running_dice += dice(y_pred, mask)

        val_loss = val_running_loss / (j+1)
        val_dice = val_running_dice / (j+1)

    return val_loss, val_dice, model

In [None]:
# Early Stopping
class EarlyStopping:
    def __init__(self, patience=20, verbose=False, delta=0, trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.val_loss = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.trace_func = trace_func

    def __call__(self, val_loss, model, file_name):
        if self.val_loss is None:
            self.val_loss = val_loss
            self.save_checkpoint(val_loss, model, file_name)
        elif val_loss > self.val_loss + self.delta:
            self.counter += 1
            self.trace_func(f"Early Stopping Counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.val_loss = val_loss
            self.save_checkpoint(val_loss, model, file_name)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, file_name):
        if self.verbose:
            self.trace_func(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
        torch.save(model.state_dict(), file_name)
        self.val_loss_min = val_loss

In [None]:
epochs = 30
lr = 0.001
num_classes = 3
model = UNet(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
es = EarlyStopping(patience=20, verbose=False, delta=0.01)

In [None]:
train_loss_arr = []
train_dice_arr = []
val_loss_arr = []
val_dice_arr = []

In [None]:
for i in tqdm(range(epochs)):
    train_loss, train_dice = train_model(model=model, dataloader=train_loader, criterion=criterion, optimizer=optimizer, i=i)
    val_loss, val_dice, model = val_model(model=model, dataloader=val_loader, criterion=criterion, i=i)
    train_loss_arr.append(train_loss)
    train_dice_arr.append(train_dice)
    val_loss_arr.append(val_loss)
    val_dice_arr.append(val_dice)

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Train Dice score: {train_dice:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print(f"Validation Dice score: {val_dice:.4f}")

    TRAINED_FILE = f"/home/kmk/DL_final_project/checkpoint/unet_scratch_{i}_epoch.pth"

    es(val_loss, model, TRAINED_FILE)

    if es.early_stop:
        writer.close()
        break