In [1]:
# ticks = ['background', 'aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow','dining table','dog','horse','motorbike','person','potted plant','sheep','sofa','train','tv/monitor']
# print(len(ticks))
# plt.figure(figsize=(18, 18))
# cmap=plt.cm.gist_ncar
# norm = matplotlib.colors.BoundaryNorm(np.arange(-0.5,21,1), cmap.N)
# plt.imshow(image, norm=norm, cmap=cmap)
# cbar = plt.colorbar(ticks=np.arange(0, 21, 1))
# cbar.set_ticklabels(ticks)

In [2]:
%run unet_dataset.ipynb
%run unet_modules.ipynb

import os
import matplotlib.pyplot as plt
import cv2
import numpy as np
import matplotlib
import matplotlib.patches as mpatches
import time
import datetime
import copy
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torchvision import transforms
from torchsummary import summary

# Network Architecture        
class Net(nn.Module):
    """U-Net style encoder/decoder producing one score map per class.

    The building blocks ``double_conv``, ``encode`` and ``decode`` come from
    ``unet_modules.ipynb`` (loaded via ``%run``); their internals are not shown here.

    Args:
        n_class: output channels of the final 1x1 convolution (one per class).
        n_channels: channels of the input images.
    """

    def __init__(self, n_class=25, n_channels=3):
        super().__init__()

        self.n_class = n_class
        self.n_channels = n_channels

        # Contracting path: channel width grows 64 -> 128 -> 256 -> 512 -> 1024.
        self.inconv = double_conv(in_channels=n_channels, out_channels=64)
        self.encode1 = encode(in_channels=64, out_channels=128)
        self.encode2 = encode(in_channels=128, out_channels=256)
        self.encode3 = encode(in_channels=256, out_channels=512)
        self.encode4 = encode(in_channels=512, out_channels=1024)

        # Expanding path: each decode stage also receives the matching encoder output.
        self.decode1 = decode(in_channels=1024, out_channels=512, bilinear=False)
        self.decode2 = decode(in_channels=512, out_channels=256, bilinear=False)
        self.decode3 = decode(in_channels=256, out_channels=128, bilinear=False)
        self.decode4 = decode(in_channels=128, out_channels=64, bilinear=False)

        # 1x1 projection from 64 feature channels to per-class scores.
        self.outconv = nn.Conv2d(in_channels=64, out_channels=n_class, kernel_size=1, padding=0)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections; return class scores."""
        skip1 = self.inconv(x)
        skip2 = self.encode1(skip1)
        skip3 = self.encode2(skip2)
        skip4 = self.encode3(skip3)
        bottleneck = self.encode4(skip4)

        upsampled = self.decode1(bottleneck, skip4)
        upsampled = self.decode2(upsampled, skip3)
        upsampled = self.decode3(upsampled, skip2)
        upsampled = self.decode4(upsampled, skip1)
        return self.outconv(upsampled)

##### Device configuration
# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##### Data augmentation
# NOTE(review): `trans` is handed to train_dataset (defined in unet_dataset.ipynb);
# confirm there that the flip is applied identically to the image and its mask.
trans = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5)])

##### Datasets & DataLoaders
# NOTE(review): these assignments rebind train_dataset / test_dataset from the
# factories provided by unet_dataset.ipynb to dataset instances; the factories
# are only available again after the `%run` lines at the top of this cell re-execute.
train_dataset = train_dataset(trans)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
test_dataset = test_dataset()
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)

##### Model
net = Net().to(device)
# Print a layer-by-layer output-shape / parameter table for a 3 x 608 x 416 input.
summary(net, input_size=(3, 608, 416))

##### Loss function & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
# StepLR multiplies the learning rate by gamma=0.1 every 100 calls to scheduler.step();
# train() calls it once per epoch.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

def train(train_loader, test_loader, num_epochs=1000, log_every=38, save_every=5,
          checkpoint_path='unet_best_model.pt'):
    """Train the global ``net`` and checkpoint the weights with the best test accuracy.

    Relies on the module-level ``net``, ``criterion``, ``optimizer``, ``scheduler``,
    ``device`` and the ``test`` function. After every epoch the model is evaluated with
    ``test``; whenever the accuracy improves, ``net.state_dict()`` is written to
    ``checkpoint_path`` and its size is printed.

    Args:
        train_loader: iterable of ``(images, annotations)`` training batches.
        test_loader: loader passed through to ``test`` for the per-epoch evaluation.
        num_epochs: number of passes over ``train_loader``.
        log_every: print the mean loss of each consecutive run of this many mini-batches.
        save_every: on epochs whose 0-based index is divisible by this, ``test`` is asked
            to save prediction figures.
        checkpoint_path: file that receives the best state dict.
    """
    print("Training U-Net has been started!\n")
    best_acc = 0
    for epoch in range(num_epochs):
        # Re-enable training behaviour (dropout, batch-norm statistics updates) each
        # epoch in case the evaluation step switched the model to eval mode.
        net.train()
        running_loss = 0.0

        for i, (images, annotations) in enumerate(train_loader):
            ##### Images (N, C, H, W), Annotations (N, H, W)
            images = images.to(device=device)
            annotations = annotations.to(device=device, dtype=torch.int64)

            ##### forward propagation + backward propagation + optimization
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, annotations)
            loss.backward()
            optimizer.step()

            ##### print the mean loss every `log_every` mini-batches
            running_loss += loss.item()
            if i % log_every == log_every - 1:
                print('[{:3d}, {:4d}] loss: {:.4f}'.format(epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

        ##### Learning rate scheduler
        # Step once per epoch *after* the optimizer updates (PyTorch >= 1.1 ordering);
        # stepping first skips the first value of the schedule and decays one epoch early.
        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step()
        if [group['lr'] for group in optimizer.param_groups] != lrs_before:
            print("Learning rate has been decayed!")

        ##### Save the best model based on test accuracy
        test_acc = test(save=(epoch % save_every == 0), test_loader=test_loader)
        if test_acc > best_acc:
            best_acc = test_acc
            # Saving the live state dict writes the same tensors a deep copy would,
            # without duplicating the whole model in (GPU) memory.
            torch.save(net.state_dict(), checkpoint_path)
            size_mb = os.path.getsize(checkpoint_path) / 1024 / 1024
            print("Best model's size: %.4f MB\n" % size_mb)
        

def test(save, test_loader):
    correct = 0
    total = 0
    accuracy = 0
    
    pic = None
    cmap = plt.cm.nipy_spectral
    norm = matplotlib.colors.BoundaryNorm(np.arange(-0.5, 25, 1), cmap.N)
    ticks = ['background', 'skin', 'hair', 'bag', 'belt', 'boots','coat', 'dress', 
             'glasses', 'gloves', 'hat','jacket', 'necklace', 'pants', 'scarf', 'shirt', 'shoes', 'shorts',
             'skirt', 'socks', 'sweater', 'tights', 'top', 'vest', 'watch']        
    
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            ##### Images (N, C, H, W), Annotations (N, H, W)
            images, annotations = data
            images, annotations = images.to(device=device), annotations.to(device=device, dtype=torch.int64)
            outputs = net(images)
            
            predicted = torch.argmax(outputs, dim=1)
            
            N = annotations.size(0)

            total += N
            correct += (predicted.view(-1) == (annotations.squeeze(1).view(-1))).sum().item()
            accuracy = 100 * correct / total / (416*608)
            
            pic = predicted
            
            if save == True:
                for j, item in enumerate(pic):
                    if (N*i+j)% 25 == 0:
                        plt.figure(figsize=(18, 15))
                        plt.imshow(item.cpu().numpy(), norm=norm, cmap=cmap)
                        cbar = plt.colorbar(ticks = np.arange(0, 25, 1))
                        cbar.set_ticklabels(ticks)
                        plt.savefig(os.path.join('./unet_segmented/{:03d}.png'.format(N*i+j)))
                        plt.clf()
                        plt.close('all')

    print('Accuracy on test images: {:.6f}%\n'.format(accuracy))
    return accuracy

# Inside a notebook kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    train(train_loader=train_loader, test_loader=test_loader)

True
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 608, 416]           1,792
       BatchNorm2d-2         [-1, 64, 608, 416]             128
              ReLU-3         [-1, 64, 608, 416]               0
            Conv2d-4         [-1, 64, 608, 416]          36,928
       BatchNorm2d-5         [-1, 64, 608, 416]             128
              ReLU-6         [-1, 64, 608, 416]               0
         Dropout2d-7         [-1, 64, 608, 416]               0
       double_conv-8         [-1, 64, 608, 416]               0
         MaxPool2d-9         [-1, 64, 304, 208]               0
           Conv2d-10        [-1, 128, 304, 208]          73,856
      BatchNorm2d-11        [-1, 128, 304, 208]             256
             ReLU-12        [-1, 128, 304, 208]               0
           Conv2d-13        [-1, 128, 304, 208]         147,584
      BatchNorm2d-14        [-1, 1


[ 15,   38] loss: 0.9334
[ 15,   76] loss: 0.9335
[ 15,  114] loss: 0.9170
Accuracy on test images: 79.809793%

[ 16,   38] loss: 0.9066
[ 16,   76] loss: 0.9379
[ 16,  114] loss: 0.9124
Accuracy on test images: 80.102769%

Best model's size: 118.4980 MB

[ 17,   38] loss: 0.8901
[ 17,   76] loss: 0.9418
[ 17,  114] loss: 0.9131
Accuracy on test images: 80.009512%

[ 18,   38] loss: 0.9046
[ 18,   76] loss: 0.9243
[ 18,  114] loss: 0.9062
Accuracy on test images: 80.029082%

[ 19,   38] loss: 0.8901
[ 19,   76] loss: 0.9013
[ 19,  114] loss: 0.8852
Accuracy on test images: 80.360450%

Best model's size: 118.4980 MB

[ 20,   38] loss: 0.8919
[ 20,   76] loss: 0.9133
[ 20,  114] loss: 0.8878
Accuracy on test images: 80.281690%

[ 21,   38] loss: 0.8623
[ 21,   76] loss: 0.9178
[ 21,  114] loss: 0.9024
Accuracy on test images: 80.104197%

[ 22,   38] loss: 0.8755
[ 22,   76] loss: 0.8931
[ 22,  114] loss: 0.9074
Accuracy on test images: 80.324639%

[ 23,   38] loss: 0.8660
[ 23,   76] lo

[ 82,   76] loss: 0.7364
[ 82,  114] loss: 0.7474
Accuracy on test images: 82.776306%

[ 83,   38] loss: 0.7270
[ 83,   76] loss: 0.7415
[ 83,  114] loss: 0.7326
Accuracy on test images: 83.246267%

Best model's size: 118.4980 MB

[ 84,   38] loss: 0.7528
[ 84,   76] loss: 0.7473
[ 84,  114] loss: 0.7134
Accuracy on test images: 83.280084%

Best model's size: 118.4980 MB

[ 85,   38] loss: 0.7547
[ 85,   76] loss: 0.6858
[ 85,  114] loss: 0.7647
Accuracy on test images: 82.828078%

[ 86,   38] loss: 0.7385
[ 86,   76] loss: 0.7338
[ 86,  114] loss: 0.7405
Accuracy on test images: 82.908578%

[ 87,   38] loss: 0.7090
[ 87,   76] loss: 0.7489
[ 87,  114] loss: 0.7373
Accuracy on test images: 82.839656%

[ 88,   38] loss: 0.7310
[ 88,   76] loss: 0.7155
[ 88,  114] loss: 0.7210
Accuracy on test images: 83.259775%

[ 89,   38] loss: 0.7212
[ 89,   76] loss: 0.7275
[ 89,  114] loss: 0.7314
Accuracy on test images: 83.327937%

Best model's size: 118.4980 MB

[ 90,   38] loss: 0.7190
[ 90,   

[153,   38] loss: 0.6871
[153,   76] loss: 0.7158
[153,  114] loss: 0.6595
Accuracy on test images: 83.875269%

[154,   38] loss: 0.6709
[154,   76] loss: 0.6761
[154,  114] loss: 0.6594
Accuracy on test images: 84.142873%

[155,   38] loss: 0.7246
[155,   76] loss: 0.6592
[155,  114] loss: 0.6615
Accuracy on test images: 84.175173%

Best model's size: 118.4980 MB

[156,   38] loss: 0.6743
[156,   76] loss: 0.6995
[156,  114] loss: 0.7038
Accuracy on test images: 83.990388%

[157,   38] loss: 0.6819
[157,   76] loss: 0.6878
[157,  114] loss: 0.6726
Accuracy on test images: 84.023078%

[158,   38] loss: 0.6856
[158,   76] loss: 0.6513
[158,  114] loss: 0.6942
Accuracy on test images: 83.971328%

[159,   38] loss: 0.6684
[159,   76] loss: 0.6757
[159,  114] loss: 0.6842
Accuracy on test images: 84.153381%

[160,   38] loss: 0.6611
[160,   76] loss: 0.6841
[160,  114] loss: 0.6971
Accuracy on test images: 84.046807%

[161,   38] loss: 0.6729
[161,   76] loss: 0.6728
[161,  114] loss: 0.70

[224,   38] loss: 0.6651
[224,   76] loss: 0.6333
[224,  114] loss: 0.6993
Accuracy on test images: 84.044813%

[225,   38] loss: 0.6522
[225,   76] loss: 0.6432
[225,  114] loss: 0.6496
Accuracy on test images: 84.181998%

[226,   38] loss: 0.6552
[226,   76] loss: 0.6727
[226,  114] loss: 0.7123
Accuracy on test images: 84.150713%

[227,   38] loss: 0.6701
[227,   76] loss: 0.6855
[227,  114] loss: 0.6570
Accuracy on test images: 84.066054%

[228,   38] loss: 0.6619
[228,   76] loss: 0.7068
[228,  114] loss: 0.6528
Accuracy on test images: 84.308600%

[229,   38] loss: 0.6575
[229,   76] loss: 0.6482
[229,  114] loss: 0.6771
Accuracy on test images: 84.142181%

[230,   38] loss: 0.6654
[230,   76] loss: 0.6616
[230,  114] loss: 0.6732
Accuracy on test images: 84.062392%

[231,   38] loss: 0.6792
[231,   76] loss: 0.6666
[231,  114] loss: 0.6593
Accuracy on test images: 84.264721%

[232,   38] loss: 0.6409
[232,   76] loss: 0.6746
[232,  114] loss: 0.6584
Accuracy on test images: 84.2

[296,  114] loss: 0.6871
Accuracy on test images: 84.352994%

[297,   38] loss: 0.6443
[297,   76] loss: 0.6567
[297,  114] loss: 0.6708
Accuracy on test images: 84.048214%

[298,   38] loss: 0.6724
[298,   76] loss: 0.6600
[298,  114] loss: 0.6830
Accuracy on test images: 84.214328%

[299,   38] loss: 0.6503
[299,   76] loss: 0.6552
[299,  114] loss: 0.6366
Accuracy on test images: 84.135335%

Learning rate has been decayed!
[300,   38] loss: 0.6649
[300,   76] loss: 0.6731
[300,  114] loss: 0.6567
Accuracy on test images: 84.152725%

[301,   38] loss: 0.6583
[301,   76] loss: 0.6653
[301,  114] loss: 0.6591
Accuracy on test images: 84.188582%

[302,   38] loss: 0.6580
[302,   76] loss: 0.6560
[302,  114] loss: 0.6626
Accuracy on test images: 84.272434%

[303,   38] loss: 0.6879
[303,   76] loss: 0.6700
[303,  114] loss: 0.6335
Accuracy on test images: 84.201447%

[304,   38] loss: 0.6566
[304,   76] loss: 0.6697
[304,  114] loss: 0.6580
Accuracy on test images: 84.169549%

[305,   38

[369,   38] loss: 0.6559
[369,   76] loss: 0.6606
[369,  114] loss: 0.6982
Accuracy on test images: 84.161932%

[370,   38] loss: 0.6660
[370,   76] loss: 0.6556
[370,  114] loss: 0.6632
Accuracy on test images: 84.213979%

[371,   38] loss: 0.6712
[371,   76] loss: 0.6674
[371,  114] loss: 0.6557
Accuracy on test images: 84.167265%

[372,   38] loss: 0.6890
[372,   76] loss: 0.6763
[372,  114] loss: 0.6384
Accuracy on test images: 84.139134%

[373,   38] loss: 0.6783
[373,   76] loss: 0.6616
[373,  114] loss: 0.6543
Accuracy on test images: 84.155351%

[374,   38] loss: 0.6374
[374,   76] loss: 0.6824
[374,  114] loss: 0.6789
Accuracy on test images: 84.229268%

[375,   38] loss: 0.6800
[375,   76] loss: 0.6524
[375,  114] loss: 0.6903
Accuracy on test images: 84.297230%

[376,   38] loss: 0.6602
[376,   76] loss: 0.6427
[376,  114] loss: 0.6984
Accuracy on test images: 84.121199%

[377,   38] loss: 0.6900
[377,   76] loss: 0.6783
[377,  114] loss: 0.6821
Accuracy on test images: 84.1


[442,   38] loss: 0.6519
[442,   76] loss: 0.6728
[442,  114] loss: 0.6595
Accuracy on test images: 84.244129%

[443,   38] loss: 0.6904
[443,   76] loss: 0.6614
[443,  114] loss: 0.6360
Accuracy on test images: 84.127770%

[444,   38] loss: 0.6452
[444,   76] loss: 0.6926
[444,  114] loss: 0.6366
Accuracy on test images: 84.221894%

[445,   38] loss: 0.6611
[445,   76] loss: 0.6610
[445,  114] loss: 0.6905
Accuracy on test images: 84.125960%

[446,   38] loss: 0.6748
[446,   76] loss: 0.6664
[446,  114] loss: 0.6538
Accuracy on test images: 84.123699%

[447,   38] loss: 0.6682
[447,   76] loss: 0.6345
[447,  114] loss: 0.6322
Accuracy on test images: 84.247379%

[448,   38] loss: 0.6756
[448,   76] loss: 0.6738
[448,  114] loss: 0.6596
Accuracy on test images: 84.226742%

[449,   38] loss: 0.6916
[449,   76] loss: 0.6723
[449,  114] loss: 0.6475
Accuracy on test images: 84.163139%

[450,   38] loss: 0.6822
[450,   76] loss: 0.6614
[450,  114] loss: 0.6663
Accuracy on test images: 84.


[515,   38] loss: 0.6443
[515,   76] loss: 0.6518
[515,  114] loss: 0.6665
Accuracy on test images: 84.239801%

[516,   38] loss: 0.6579
[516,   76] loss: 0.6941
[516,  114] loss: 0.6505
Accuracy on test images: 84.233912%

[517,   38] loss: 0.6615
[517,   76] loss: 0.6594
[517,  114] loss: 0.6508
Accuracy on test images: 84.239744%

[518,   38] loss: 0.6576
[518,   76] loss: 0.6909
[518,  114] loss: 0.6424
Accuracy on test images: 84.262981%

[519,   38] loss: 0.6749
[519,   76] loss: 0.6647
[519,  114] loss: 0.6527
Accuracy on test images: 84.259766%

[520,   38] loss: 0.6441
[520,   76] loss: 0.6670
[520,  114] loss: 0.6655
Accuracy on test images: 84.232641%

[521,   38] loss: 0.6908
[521,   76] loss: 0.6609
[521,  114] loss: 0.6171
Accuracy on test images: 84.207462%

[522,   38] loss: 0.6605
[522,   76] loss: 0.6734
[522,  114] loss: 0.6834
Accuracy on test images: 84.204635%

[523,   38] loss: 0.6656
[523,   76] loss: 0.6448
[523,  114] loss: 0.6618
Accuracy on test images: 84.

[588,   38] loss: 0.6721
[588,   76] loss: 0.6565
[588,  114] loss: 0.6735
Accuracy on test images: 84.306779%

[589,   38] loss: 0.6760
[589,   76] loss: 0.6591
[589,  114] loss: 0.6524
Accuracy on test images: 84.123778%

[590,   38] loss: 0.6641
[590,   76] loss: 0.6835
[590,  114] loss: 0.6695
Accuracy on test images: 84.159807%

[591,   38] loss: 0.6537
[591,   76] loss: 0.6707
[591,  114] loss: 0.6538
Accuracy on test images: 84.201100%

[592,   38] loss: 0.6873
[592,   76] loss: 0.6588
[592,  114] loss: 0.6994
Accuracy on test images: 84.276705%

[593,   38] loss: 0.6452
[593,   76] loss: 0.6865
[593,  114] loss: 0.6755
Accuracy on test images: 84.212077%

[594,   38] loss: 0.6888
[594,   76] loss: 0.6736
[594,  114] loss: 0.6646
Accuracy on test images: 84.182661%

[595,   38] loss: 0.6530
[595,   76] loss: 0.6729
[595,  114] loss: 0.6656
Accuracy on test images: 84.270429%

[596,   38] loss: 0.6705
[596,   76] loss: 0.6714
[596,  114] loss: 0.6641
Accuracy on test images: 84.1

PermissionError: [Errno 13] Permission denied: './unet_segmented/175.png'