In [None]:
#라이브러리 임포트
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from tensorboardX import SummaryWriter

In [None]:
# Hyperparameters used for training and evaluation.
train_batch_size = 32
test_batch_size = 8
learning_rate = 0.0001
epoch = 100  # number of passes over the training set

# Seed torch's RNGs before the model and data loaders are created, so weight
# initialisation, dropout masks and shuffling are reproducible across runs
# (some CUDA kernels may still be non-deterministic).
seed = 42
torch.manual_seed(seed)

In [None]:
# Transforms, datasets and data loaders.
# Images: resize to the network's 224x224 input and convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
# Masks get their own pipeline, so image-only steps (e.g. a future Normalize) never
# touch them. Nearest-neighbour resizing keeps every pixel an exact class id instead
# of a blend of neighbouring ids. Assuming the masks load as 8-bit palette images,
# ToTensor then yields a (1, H, W) float tensor holding class_id / 255, from which
# the ids can be recovered exactly as round(x * 255).
target_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor()
])

voc_train = datasets.VOCSegmentation(root='../Data/', year='2012', image_set='train', transform=transform, target_transform=target_transform, download=False)
voc_test = datasets.VOCSegmentation(root='../Data/', year='2012', image_set='val', transform=transform, target_transform=target_transform, download=False)

train_loader = DataLoader(voc_train, batch_size=train_batch_size, shuffle=True, num_workers=1, drop_last=True)
# Keep the final partial batch: dropping it would silently leave validation images
# out of the evaluation.
test_loader = DataLoader(voc_test, batch_size=test_batch_size, shuffle=False, num_workers=1, drop_last=False)

In [None]:
def _vgg_stage(in_channels, out_channels, num_convs):
    """Build one VGG-style stage: `num_convs` x (3x3 conv + ReLU), then a 2x2 max-pool.

    The max-pool uses stride 2 with ceil_mode=True, so odd spatial sizes are rounded
    up (ceiling) rather than down (floor) when halved.
    """
    layers = []
    channels = in_channels
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels=channels, out_channels=out_channels, kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
    return nn.Sequential(*layers)


def _pointwise_stage(in_channels, out_channels):
    """1x1 conv + ReLU + channel dropout, used in place of a fully-connected layer."""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Dropout2d(),
    )


class FCN(nn.Module):
    """FCN-32s-like semantic segmentation network (no skip connections).

    Encoder: five VGG16-style conv stages (2, 2, 3, 3, 3 convs), each halving H and W,
    for a total stride of 32. Weights are randomly initialised here; no pretrained
    weights are loaded. Two 1x1 "fully-connected" conv layers follow, then a 1x1
    classifier and a single learned 32x transposed-conv upsampling.

    Args:
        num_classes: number of output channels (per-pixel class logits).

    forward(x) returns logits of shape (N, num_classes, H', W'). H' == H and W' == W
    when H and W are multiples of 32 (e.g. the 224x224 inputs used in this notebook).
    """

    def __init__(self, num_classes=21):
        super().__init__()

        self.conv1 = _vgg_stage(3, 64, 2)     # -> 1/2
        self.conv2 = _vgg_stage(64, 128, 2)   # -> 1/4
        self.conv3 = _vgg_stage(128, 256, 3)  # -> 1/8
        self.conv4 = _vgg_stage(256, 512, 3)  # -> 1/16
        self.conv5 = _vgg_stage(512, 512, 3)  # -> 1/32

        self.fc1 = _pointwise_stage(512, 4096)
        self.fc2 = _pointwise_stage(4096, 4096)

        self.score = nn.Sequential(
            nn.Conv2d(in_channels=4096, out_channels=num_classes, kernel_size=1),
            # Output size: (h - 1) * 32 - 2 * 16 + 64 = 32 * h, undoing the encoder stride.
            nn.ConvTranspose2d(in_channels=num_classes, out_channels=num_classes,
                               kernel_size=64, stride=32, padding=16),
        )

    def forward(self, x):
        stages = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                  self.fc1, self.fc2, self.score)
        for stage in stages:
            x = stage(x)
        return x

In [None]:
# Declare the device, model, loss function and optimizer.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = FCN(21).to(device)

# Per-pixel classification over the 21 class logits. MSE between logits and a mask is
# not a segmentation loss (and silently broadcast (N, 21, H, W) against (N, 1, H, W)).
# 255 is excluded from scoring; in VOC it presumably marks boundary/"void" pixels.
_pixel_cross_entropy = nn.CrossEntropyLoss(ignore_index=255)


def loss_func(output, label):
    """Pixel-wise cross-entropy between FCN logits and a batch of segmentation masks.

    Args:
        output: logits of shape (N, C, H, W).
        label: class-id mask, either as a float tensor scaled by 1/255 (what ToTensor
            produces for an 8-bit mask, shape (N, 1, H, W)) or as integer class ids
            of shape (N, H, W) or (N, 1, H, W).

    Returns:
        Scalar mean cross-entropy over all non-255 pixels.
    """
    if label.is_floating_point():
        # Undo ToTensor's 1/255 scaling; round() guards against float error.
        label = (label * 255).round().long()
    if label.dim() == 4:
        label = label.squeeze(1)  # drop the singleton channel axis -> (N, H, W)
    return _pixel_cross_entropy(output, label.long())


optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Training.
step_loss_arr = []   # loss value of every optimisation step
epoch_loss_arr = []  # mean step loss of each epoch

# Enable dropout; a previous inference cell may have left the model in eval mode.
model.train()
for i in range(epoch):
    # Wrap the loader in a fresh progress bar each epoch. Rebinding train_loader
    # itself nested tqdm wrappers epoch after epoch (and across cell re-runs).
    progress = tqdm(train_loader, desc=f"epoch {i + 1}/{epoch}")
    epoch_loss = 0.0
    for image, label in progress:
        image = image.to(device)
        label = label.to(device)
        optimizer.zero_grad()

        output = model(image)
        # Use the criterion value directly: a square root only makes sense for
        # turning MSE into RMSE and distorts criteria such as cross-entropy.
        loss = loss_func(output, label)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        epoch_loss += step_loss
        progress.set_postfix({"Loss": step_loss})
        step_loss_arr.append(step_loss)

    epoch_loss_arr.append(epoch_loss / max(len(train_loader), 1))

# Create the checkpoint directory first; otherwise torch.save would fail only after
# the whole training run. NOTE: 'modle' looks like a typo for 'model', but the path
# is kept unchanged so anything loading this checkpoint still finds it.
checkpoint_path = 'modle/fcn.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Visualise how the training loss evolves: per optimisation step (left) and the
# per-epoch mean (right).
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].plot(step_loss_arr)
axes[0].set(title="training loss per step", xlabel="step", ylabel="loss")
axes[1].plot(epoch_loss_arr)
axes[1].set(title="mean training loss per epoch", xlabel="epoch", ylabel="loss")

plt.show()

In [None]:
# Visualise image, ground-truth mask and predicted mask for the first test batch.
image, label = next(iter(test_loader))
image = image.to(device)
label = label.to(device)

# Inference: disable dropout and autograd, then restore the previous mode.
was_training = model.training
model.eval()
with torch.no_grad():
    logits = model(image)
model.train(was_training)
num_classes = logits.shape[1]
output = torch.argmax(logits, dim=1)  # (N, H, W) predicted class ids

# A mask converted with ToTensor is a (N, 1, H, W) float tensor holding class_id / 255.
# Recover integer ids of shape (N, H, W) so label and prediction share one colour
# scale. An (H, W, 1) array is also not accepted by imshow.
if label.is_floating_point():
    label = (label * 255).round().long()
if label.dim() == 4:
    label = label.squeeze(1)

num_rows = len(image)
# squeeze=False keeps axes 2-D even for a single-image batch.
fig, axes = plt.subplots(num_rows, 3, figsize=(3 * 5, num_rows * 5), squeeze=False)

for i in range(num_rows):
    image_ = image[i].permute(1, 2, 0).cpu().numpy()
    # Hide the 255 value so it does not stretch the colour scale.
    label_ = np.ma.masked_equal(label[i].cpu().numpy(), 255)
    label_class_predicted = output[i].cpu().numpy()

    axes[i, 0].set_title("image")
    axes[i, 0].imshow(image_)
    axes[i, 1].set_title("label")
    axes[i, 1].imshow(label_, vmin=0, vmax=num_classes - 1)
    axes[i, 2].set_title("label predicted")
    axes[i, 2].imshow(label_class_predicted, vmin=0, vmax=num_classes - 1)

plt.show()

In [None]:
# Evaluation: mean intersection-over-union (mIoU) over the whole validation set.
# A (true class x predicted class) confusion matrix is accumulated over all batches.
# Per-class IoU = TP / (TP + FP + FN), averaged over the classes that occur in either
# the ground truth or the predictions. Pixels whose label is outside [0, C), such as
# the 255 value, are ignored.
confusion = None

# Inference: disable dropout and autograd, then restore the previous mode.
was_training = model.training
model.eval()
with torch.no_grad():
    for image, label in tqdm(test_loader):
        image = image.to(device)
        label = label.to(device)
        logits = model(image)
        num_classes = logits.shape[1]
        output = torch.argmax(logits, dim=1)  # (N, H, W) predicted class ids

        # A ToTensor mask is (N, 1, H, W) float holding class_id / 255; recover the ids.
        if label.is_floating_point():
            label = (label * 255).round().long()
        if label.dim() == 4:
            label = label.squeeze(1)

        valid = (label >= 0) & (label < num_classes)
        # Encode each (true, predicted) pair as one index and count them all at once.
        pair_index = label[valid] * num_classes + output[valid]
        batch_confusion = torch.bincount(pair_index, minlength=num_classes ** 2)
        batch_confusion = batch_confusion.reshape(num_classes, num_classes).cpu()
        confusion = batch_confusion if confusion is None else confusion + batch_confusion
model.train(was_training)

if confusion is None:
    raise RuntimeError("test_loader yielded no batches; nothing to evaluate")

true_positive = confusion.diag().double()
union = confusion.sum(dim=0).double() + confusion.sum(dim=1).double() - true_positive
present = union > 0  # classes absent from both label and prediction are skipped
iou_per_class = true_positive[present] / union[present]
print('IOU score: ' + str(iou_per_class.mean().item()))