In [5]:
import os
import datetime
import random
import time
import gc
import sys
import numpy as np

import scipy.spatial.distance as spd

from skimage import io
from skimage import util

from sklearn import metrics

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import torch.nn.functional as F

from base import *
from models import *
from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d, LovaszLoss, FocalLoss2d

import copy

import warnings
warnings.filterwarnings("ignore")

cudnn.benchmark = True

In [6]:
'''
############################################
Vaihingen/Potsdam classes:
    0 = Street
    1 = Building
    2 = Grass
    3 = Tree
    4 = Car
    5 = Surfaces
    6 = Boundaries
############################################

############################################
iSAID classes:
     0 = Background
     1 = Ship
     2 = Small Vehicle
     3 = Helicopter
     4 = Swimming Pool
     5 = Baseball Court
     6 = Storage Tank
     7 = Tennis Court
     8 = Basketball Court
     9 = Ground Track Field
    10 = Harbor
    11 = Bridge
    12 = Large Vehicle
    13 = Soccerball Field
    14 = Plane
    15 = Roundabout
############################################

############################################
GRSS classes:
     0 = Unclassified
     1 = Healthy grass
     2 = Stressed grass
     3 = Artificial turf
     4 = Evergreen trees
     5 = Deciduous trees
     6 = Bare earth
     7 = Water
     8 = Residential buildings
     9 = Non-residential buildings
    10 = Roads
    11 = Sidewalks
    12 = Crosswalks
    13 = Major thoroughfares
    14 = Highways
    15 = Railways
    16 = Paved parking lots
    17 = Unpaved parking lots
    18 = Cars
    19 = Trains
    20 = Stadium seats
############################################
'''

# Predefining directories.
ckpt_path = './ckpt'
outp_path = './outputs'

# Setting predefined arguments.
args = {
    'epoch_num': 1200,            # Number of epochs.
    'lr': 1e-3,                   # Learning rate.
    'weight_decay': 5e-6,         # L2 penalty.
    'momentum': 0.9,              # Momentum.
    'batch_size': 1,              # Batch size.
    'num_workers': 4,             # Number of workers on data loader.
    'print_freq': 1,              # Printing frequency for mini-batch loss.
    'w_size': 224,                # Width size for image resizing.
    'h_size': 224,                # Height size for image resizing.
    'test_freq': 1,               # Test each test_freq epochs.
    'save_freq': 1200,            # Save model each save_freq epochs.
    'input_channels': 4,          # Number of input channels in samples/DNN.
    'num_classes': 5,             # Number of original output classes in dataset.
}

# Reading system parameters.
conv_name = 'unet'
args['hidden_classes'] = '0'
print('hidden: ' + args['hidden_classes'])

dataset_name = 'Vaihingen'

# Potsdam: fewer epochs and in-process data loading (num_workers = 0).
if dataset_name == 'Potsdam':
    args['epoch_num'] = 600
    args['num_workers'] = 0

# Hidden (held-out) classes are given as underscore-separated indices, e.g. '0' or '2_4'.
# split('_') already handles the single-index case; empty pieces are skipped so that
# an empty string means "no hidden classes" instead of crashing in int('').
hidden = [int(h) for h in args['hidden_classes'].split('_') if h != '']

num_known_classes = args['num_classes'] - len(hidden)
num_unknown_classes = len(hidden)

# Per-class loss weights over the known classes (default: uniform).
weights = [1.0] * num_known_classes
if dataset_name == 'iSAID':
    # When background (class 0) is known, weight every other known class 100x.
    if 0 not in hidden and weights:
        weights = [100.0] * num_known_classes
        weights[0] = 1.0
elif dataset_name != 'GRSS':
    # Vaihingen/Potsdam: double the weight of Car (class 4) when it is known.
    # Assumes known classes keep their original order after remapping, so Car
    # (the highest index in range(num_classes)) is the last known class.
    if 4 not in hidden and weights:
        weights[-1] = 2.0

global_weights = torch.FloatTensor(weights)

# Setting experiment name.
exp_name = conv_name + '_' + dataset_name + '_base_dsm_' + args['hidden_classes']

# Device selection is currently implicit (.cuda() uses the default device).
# args['device'] = 0



hidden: 0


In [7]:
# Training function.
def train(train_loader, net, criterion, optimizer, epoch, num_known_classes, num_unknown_classes, hidden, args):
    """Train ``net`` for one epoch and print epoch-level metrics.

    Parameters
    ----------
    train_loader : DataLoader
        Yields ``(inps, labs, true)`` when the module-level ``dataset_name`` is
        'GRSS' and ``(inps, labs, true, img_name)`` otherwise. ``true`` and
        ``img_name`` are not used by training.
    net : nn.Module
        Segmentation network; ``net(inps)`` must return logits with classes on dim 1.
    criterion : callable
        Loss called as ``criterion(logits, labels)``.
    optimizer : torch.optim.Optimizer
        Stepped once per batch.
    epoch : int
        Current epoch; metrics are printed when ``epoch % args['print_freq'] == 0``.
    num_known_classes : int
        Forwarded to ``evaluate``.
    num_unknown_classes, hidden :
        Unused; kept for signature parity with ``test``.
    args : dict
        Reads ``args['print_freq']``.

    Side effects: updates ``net``'s parameters and prints ``evaluate(...)`` output
    followed by the mean training loss of the epoch.
    """
    # Setting network for training mode.
    net.train()

    # Per-batch loss values; their mean is reported with the epoch metrics.
    train_loss = []

    prds_all = []
    labs_all = []

    for data in train_loader:

        # Explicit unpacking keeps the arity check; only inps/labs are needed here.
        if dataset_name == 'GRSS':
            inps, labs, _true = data
        else:
            inps, labs, _true, _img_name = data

        # `true` is not needed for the loss or the metrics, so it stays on the CPU.
        inps = inps.cuda()
        labs = labs.cuda()

        if dataset_name == 'iSAID':
            # Merge the batch axis with the per-sample patch axis.
            inps = inps.view(-1, *inps.shape[2:])
            labs = labs.view(-1, *labs.shape[2:])
        else:
            # Drop the DataLoader batch axis (args['batch_size'] is 1 here).
            inps = inps.squeeze(0)
            labs = labs.squeeze(0)

        # Clears the gradients of optimizer.
        optimizer.zero_grad()

        # Forwarding.
        outs = net(inps)

        # Predicted class per pixel. Softmax is monotonic, so argmax of the logits
        # gives the same class without the extra softmax pass.
        prds = outs.detach().argmax(dim=1)

        # Computing loss and backpropagation.
        loss = criterion(outs, labs)
        loss.backward()
        optimizer.step()

        # Squeeze predictions and labels the same way so evaluate() sees matching shapes.
        prds_all.append(prds.squeeze(0).cpu().numpy())
        labs_all.append(labs.detach().squeeze(0).cpu().numpy())

        train_loss.append(loss.item())

    if epoch % args['print_freq'] == 0:
        print(evaluate(prds_all, labs_all, num_known_classes))
        if train_loss:
            print('[epoch %d], [train loss %.5f]' % (epoch, float(np.mean(train_loss))))

    sys.stdout.flush()

def _save_patch_images(out_dir, base_name, j, k, inps_np, labs_np, true_np, prds_np):
    """Write input, label, true-label and prediction of patch (j, k) as PNG files.

    Files are named ``<base_name>_{img,msk,tru,prd}_<j>_<k>.png`` inside ``out_dir``.
    ``inps_np`` is channel-first; the label/prediction arrays are class-index maps.
    """
    suffix = '_' + str(j) + '_' + str(k) + '.png'

    def _path(tag):
        return os.path.join(out_dir, base_name + '_' + tag + suffix)

    # Same de-normalisation as before ((x + 0.5) * 255), but clipped and cast straight
    # to uint8: img_as_ubyte() on a signed int64 array rescales from the full int64
    # range, which collapses 0..255 to 0 and produced all-black images.
    img = np.clip((np.transpose(inps_np, (1, 2, 0)) + 0.5) * 255, 0, 255).astype(np.uint8)
    io.imsave(_path('img'), img)

    # Class-index maps are cast directly so index values are preserved
    # (assumes fewer than 256 classes).
    io.imsave(_path('msk'), labs_np.astype(np.uint8))
    io.imsave(_path('tru'), true_np.astype(np.uint8))
    io.imsave(_path('prd'), prds_np.astype(np.uint8))


def test(test_loader, net, criterion, optimizer, epoch, num_known_classes, num_unknown_classes, hidden, args, save_images, save_model):
    """Evaluate ``net`` patch by patch on ``test_loader`` and print the metrics.

    Parameters
    ----------
    test_loader : DataLoader
        Expected to use batch_size=1 and yield ``(inps, labs, true[, img_name])``,
        where each tensor holds a grid of patches indexed as ``[row, col]`` after
        the batch axis is dropped.
    net : nn.Module
        Segmentation network; ``net(inps)`` returns logits with classes on dim 1.
    optimizer : torch.optim.Optimizer
        Only used to save its state when ``save_model`` is true.
    epoch : int
        Used in checkpoint and output-directory names.
    num_known_classes : int
        Forwarded to ``evaluate``.
    criterion, num_unknown_classes, hidden, args :
        Unused; kept for call compatibility.
    save_images : bool
        Write per-patch PNGs under ``outp_path/exp_name/validation/epoch_<epoch>``.
    save_model : bool
        Save ``net`` and ``optimizer`` state dicts under ``ckpt_path/exp_name``.

    Returns
    -------
    Element 2 of the tuple returned by ``evaluate`` (the caller treats it as IoU).

    Reads the module-level ``ckpt_path``, ``outp_path`` and ``exp_name``.
    """
    epoch_dir = os.path.join(outp_path, exp_name, 'validation', 'epoch_' + str(epoch))

    if save_model:
        torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, 'model_' + str(epoch) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + str(epoch) + '.pth'))

    # Setting network for evaluation mode.
    net.eval()
    labs_all = []
    prds_all = []

    with torch.no_grad():

        if save_images:
            check_mkdir(epoch_dir)

        for i, data in enumerate(test_loader):

            inps_batch, labs_batch, true_batch = data[0], data[1], data[2]
            # Some loaders (cf. the GRSS branch in train()) yield no image names.
            img_name = data[3] if len(data) > 3 else None
            base_name = os.path.splitext(img_name[0])[0] if img_name is not None else 'sample_' + str(i)

            # Drop only the DataLoader batch axis. A bare squeeze() would also drop a
            # patch-grid axis of length 1 and break the [j, k] indexing below.
            inps_batch = inps_batch.squeeze(0)
            labs_batch = labs_batch.squeeze(0)
            true_batch = true_batch.squeeze(0)

            # Iterating over the patch grid.
            for j in range(inps_batch.size(0)):
                for k in range(inps_batch.size(1)):

                    inps = inps_batch[j, k].unsqueeze(0).cuda()

                    # Only the logits are needed; this is the same call train() makes.
                    outs = net(inps)
                    prds = outs.argmax(dim=1).cpu().squeeze().numpy()

                    # Copy so the stored arrays do not pin the loader's (possibly
                    # shared-memory) batch tensors for the whole evaluation.
                    labs_np = labs_batch[j, k].numpy().copy()

                    prds_all.append(prds)
                    labs_all.append(labs_np)

                    if save_images:
                        _save_patch_images(epoch_dir, base_name, j, k,
                                           inps_batch[j, k].numpy(), labs_np,
                                           true_batch[j, k].numpy(), prds)

        sys.stdout.flush()
        print('#######################')
        print('Test evaluation')
        results = evaluate(prds_all, labs_all, num_known_classes)
        print(results[:-1])
        print('#######################')
        return results[2]

In [8]:
# Setting network architecture.
if conv_name == 'unet':
    net = UNet(args['input_channels'], num_classes=args['num_classes'], hidden_classes=hidden).cuda()
elif conv_name == 'fcnwideresnet50':
    net = FCNWideResNet50(args['input_channels'], num_classes=args['num_classes'], pretrained=False, skip=True, hidden_classes=hidden).cuda()
elif conv_name == 'fcndensenet121':
    net = FCNDenseNet121(args['input_channels'], num_classes=args['num_classes'], pretrained=False, skip=True, hidden_classes=hidden).cuda()
else:
    # Fail fast here instead of with a NameError on `net` further down.
    raise ValueError('Unsupported conv_name: ' + str(conv_name))

print(conv_name)
sys.stdout.flush()

curr_epoch = 1
args['best_record'] = {'epoch': 0, 'lr': 1e-4, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'iou': 0}

# Setting datasets.
train_set = ListDataset(dataset_name, 'Train', (args['h_size'], args['w_size']), 'statistical', hidden, overlap=False, use_dsm=True, dataset_path='../datasets/')
train_loader = DataLoader(train_set, batch_size=args['batch_size'], num_workers=args['num_workers'], shuffle=True)

test_set = ListDataset(dataset_name, 'Validate', (args['h_size'], args['w_size']), 'statistical', hidden, overlap=True, use_dsm=True, dataset_path='../datasets/')
test_loader = DataLoader(test_set, batch_size=1, num_workers=args['num_workers'], shuffle=False)

# Setting criterion. The label value num_classes - len(hidden) (== num_known_classes)
# is passed as ignore_index.
criterion = CrossEntropyLoss2d(weight=global_weights, size_average=False, ignore_index=args['num_classes'] - len(hidden)).cuda()

# Setting optimizer: biases get twice the base LR and no weight decay.
bias_params, weight_params = [], []
for param_name, param in net.named_parameters():
    if param_name.endswith('bias'):
        bias_params.append(param)
    else:
        weight_params.append(param)

optimizer = optim.Adam([
    {'params': bias_params, 'lr': 2 * args['lr']},
    {'params': weight_params, 'lr': args['lr'], 'weight_decay': args['weight_decay']},
], betas=(args['momentum'], 0.99))

# Multiply the LR by 0.2 every epoch_num // lr_step_div epochs (5 decays' worth of
# steps for iSAID, 3 otherwise). max(1, ...) avoids a zero step size for short runs.
lr_step_div = 5 if dataset_name == 'iSAID' else 3
scheduler = optim.lr_scheduler.StepLR(optimizer, max(1, args['epoch_num'] // lr_step_div), 0.2)

# Making sure checkpoint and output directories are created.
check_mkdir(ckpt_path)
check_mkdir(os.path.join(ckpt_path, exp_name))
check_mkdir(outp_path)
check_mkdir(os.path.join(outp_path, exp_name))

# Writing training args to experiment log file (closed deterministically).
with open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w') as log_file:
    log_file.write(str(args) + '\n\n')


def _cpu_state_dict(module):
    """Return a CPU copy of ``module``'s state dict (parameters and buffers)."""
    return {key: value.detach().cpu().clone() for key, value in module.state_dict().items()}


# Track the best weights as a CPU snapshot instead of deep-copying the whole CUDA
# module on every improvement (which doubles GPU memory during training).
best_state = _cpu_state_dict(net)
best_iou = -1
best_epoch = 0

# Iterating over epochs.
for epoch in range(curr_epoch, args['epoch_num'] + 1):

    print('Epoch ', epoch)
    train(train_loader, net, criterion, optimizer, epoch, num_known_classes, num_unknown_classes, hidden, args)

    if epoch % args['test_freq'] == 0:

        current_iou = test(test_loader, net, criterion, optimizer, epoch, num_known_classes, num_unknown_classes, hidden, args, False, False)
        if current_iou > best_iou:
            best_iou = current_iou
            best_state = _cpu_state_dict(net)
            best_epoch = epoch
            print('New best IoU found: ', current_iou)
        else:
            print('Previous best IoU {0} at epoch {1}'.format(best_iou, best_epoch))

    scheduler.step()

# Materialise the best model once: `net` keeps the last-epoch weights and
# `best_model` holds the best-IoU weights, as before.
best_model = copy.deepcopy(net)
best_model.load_state_dict(best_state)

print('Final test')
# NOTE(review): the saved checkpoint is named after args['epoch_num'], not best_epoch.
test(test_loader, best_model, criterion, optimizer, args['epoch_num'], num_known_classes, num_unknown_classes, hidden, args, False, True)


unet
self.n_classes 4
self.hidden_classes [0]
self.n_classes 4
self.hidden_classes [0]
Epoch  1
(0.570134372418266, 0.48920732061879857, 0.33335205273284535, array([0.5572341 , 0.34277792, 0.4035657 , 0.02983049]), 0.4373203948544799)
#######################
Test evaluation
(0.5223507838141513, 0.5701997676628833, 0.2796065461732598, array([0.279873  , 0.27652125, 0.54394564, 0.01808629]))
#######################
New best IoU found:  0.2796065461732598
Epoch  2
(0.7157686422838635, 0.546820093459593, 0.4262437499528323, array([0.73607507, 0.44996797, 0.4855483 , 0.03338367]), 0.5732261160377429)
#######################
Test evaluation
(0.7519596494992054, 0.657776038584521, 0.5080872273736017, array([0.80668025, 0.50643774, 0.60693295, 0.11229797]))
#######################
New best IoU found:  0.5080872273736017
Epoch  3
(0.7421290112984154, 0.6308882108551878, 0.48938725560020807, array([0.83931326, 0.46217732, 0.50139119, 0.15466725]), 0.6102264138120168)
#######################
Test

KeyboardInterrupt: 