In [1]:
import datetime
import os
import random
import time
import gc
import sys
import numpy as np

import scipy.spatial.distance as spd

from skimage import io
from skimage import util

from sklearn import metrics

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import torch.nn.functional as F

from base import *
from models import *
from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d, LovaszLoss, FocalLoss2d

import warnings
warnings.filterwarnings("ignore")

cudnn.benchmark = True

In [2]:
def get_curr_metric(msk, prd, n_known):
    """Compute known/unknown-class metrics from flattened labels and predictions.

    Labels ``0 .. n_known - 1`` are treated as known classes and label
    ``n_known`` as the unknown class. Pixels whose ground-truth label is
    greater than ``n_known`` are discarded before any metric is computed.

    Args:
        msk: NumPy array of ground-truth labels (any shape).
        prd: NumPy array of predicted labels with the same number of elements.
        n_known: Number of known classes.

    Returns:
        A list ``[acc_known, acc_unknown, pre_unknown, rec_unknown, bal, kap]``:
        accuracy over the known-class rows, accuracy over known + unknown rows
        together, precision and recall of the unknown class, scikit-learn's
        balanced accuracy and Cohen's kappa. A ratio whose denominator is zero
        is reported as 0.0.
    """
    tru_np = msk.ravel()
    prd_np = prd.ravel()

    keep = tru_np < (n_known + 1)
    tru_valid = tru_np[keep]
    prd_valid = prd_np[keep]

    # Pin the row/column order of the confusion matrix. Without ``labels`` the
    # matrix only contains the labels that actually occur, so a missing class
    # would shift every index below (or make cm[n_known, n_known] fail).
    # Labels outside 0..n_known that do occur are appended after the pinned
    # ones so they still count in the row/column totals.
    core_labels = np.arange(n_known + 1)
    extra_labels = np.setdiff1d(np.union1d(tru_valid, prd_valid), core_labels)
    labels = np.concatenate([core_labels, extra_labels])

    print('        Computing CM...')
    cm = metrics.confusion_matrix(tru_valid, prd_valid, labels=labels)

    print('        Computing Accs...')
    # Correct known-class pixels and total pixels whose ground truth is known.
    tru_known = float(np.trace(cm[:n_known, :n_known]))
    sum_known = float(cm[:n_known, :].sum())

    acc_known = tru_known / sum_known if sum_known != 0.0 else 0.0

    tru_unknown = float(cm[n_known, n_known])
    sum_unknown_real = float(cm[n_known, :].sum())
    sum_unknown_pred = float(cm[:, n_known].sum())

    pre_unknown = tru_unknown / sum_unknown_pred if sum_unknown_pred != 0.0 else 0.0
    rec_unknown = tru_unknown / sum_unknown_real if sum_unknown_real != 0.0 else 0.0

    # Despite its name this is the accuracy over known and unknown rows together.
    sum_all = sum_known + sum_unknown_real
    acc_unknown = (tru_known + tru_unknown) / sum_all if sum_all != 0.0 else 0.0

    print('        Computing Balanced Acc...')
    bal = metrics.balanced_accuracy_score(tru_valid, prd_valid)

    print('        Computing Kappa...')
    kap = metrics.cohen_kappa_score(tru_valid, prd_valid)

    return [acc_known, acc_unknown, pre_unknown, rec_unknown, bal, kap]

def get_metrics(msk, prd, num_classes):
    """Convert labels and predictions to NumPy arrays and delegate to get_curr_metric.

    Args:
        msk: Ground-truth labels (anything ``np.array`` accepts).
        prd: Predicted labels (anything ``np.array`` accepts).
        num_classes: Forwarded to ``get_curr_metric`` as its ``n_known`` argument.

    Returns:
        Whatever ``get_curr_metric`` returns for the converted arrays.
    """
    truth_arr = np.array(msk)
    pred_arr = np.array(prd)
    return get_curr_metric(truth_arr, pred_arr, num_classes)


def _fast_hist(label_pred, label_true, num_classes):
    mask = (label_true >= 0) & (label_true < num_classes)
    hist = np.bincount(
        num_classes * label_true[mask].astype(int) +
        label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return hist


def confusion_matrix(predictions, gts, num_classes):
    """Accumulate the per-sample histograms of ``_fast_hist`` over all samples.

    Args:
        predictions: Iterable of prediction arrays.
        gts: Iterable of ground-truth arrays, paired with ``predictions``.
        num_classes: Side length of the resulting square matrix.

    Returns:
        Float array of shape ``(num_classes, num_classes)``.
    """
    total = np.zeros((num_classes, num_classes))
    for pred, truth in zip(predictions, gts):
        sample_hist = _fast_hist(pred.flatten(), truth.flatten(), num_classes)
        total = total + sample_hist

    return total


def kappa_score(confusion):
    """Compute Cohen's kappa from a square confusion matrix.

    Args:
        confusion: Square NumPy array of counts.

    Returns:
        ``1 - observed_disagreement / expected_disagreement``, where the
        expected counts come from the outer product of the matrix marginals.
        The result is NaN when the expected disagreement is zero.
    """
    n_classes = confusion.shape[0]
    sum0 = np.sum(confusion, axis=0)
    sum1 = np.sum(confusion, axis=1)
    expected = np.outer(sum0, sum1) / np.sum(sum0)

    # Weight 1 for every off-diagonal cell, 0 on the diagonal.
    # Built-in ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
    w_mat = np.ones([n_classes, n_classes], dtype=int)
    w_mat.flat[:: n_classes + 1] = 0

    k = np.sum(w_mat * confusion) / np.sum(w_mat * expected)
    return 1 - k


def evaluate(predictions, gts, num_classes):
    """Print the get_metrics summary and return segmentation scores.

    Args:
        predictions: Iterable of prediction arrays.
        gts: Iterable of ground-truth arrays, paired with ``predictions``.
        num_classes: Number of classes used for the confusion matrix.

    Returns:
        Tuple ``(acc, acc_cls, mean_iu, iu, fwavacc, kappa)`` computed from the
        accumulated confusion matrix (rows: ground truth, columns: prediction).
    """
    hist = np.zeros((num_classes, num_classes))
    for pred, truth in zip(predictions, gts):
        hist = hist + _fast_hist(pred.flatten(), truth.flatten(), num_classes)

    print('metrics')
    summary = get_metrics(gts, predictions, num_classes)
    print('[acc_known, acc_unknown, pre_unknown, rec_unknown, bal, kap]')
    print(summary)

    correct = np.diag(hist)
    gt_totals = hist.sum(axis=1)
    pred_totals = hist.sum(axis=0)
    n_pixels = hist.sum()

    # Overall and mean per-class accuracy (classes with no pixels are skipped by nanmean).
    acc = correct.sum() / n_pixels
    acc_cls = np.nanmean(correct / gt_totals)

    # Per-class intersection over union and its mean.
    iu = correct / (gt_totals + pred_totals - correct)
    mean_iu = np.nanmean(iu)

    # IoU weighted by ground-truth class frequency, over classes that occur.
    freq = gt_totals / n_pixels
    present = freq > 0
    fwavacc = (freq[present] * iu[present]).sum()

    kappa = kappa_score(hist)
    return acc, acc_cls, mean_iu, iu, fwavacc, kappa


In [3]:

'''
############################################
Vaihingen/Potsdam classes:
    0 = Street
    1 = Building
    2 = Grass
    3 = Tree
    4 = Car
    5 = Surfaces
    6 = Boundaries
############################################

'''

# Root directories for checkpoints and for saved predictions.
ckpt_path = './ckpt'
outp_path = './outputs'

# Default experiment arguments.
args = dict(
    epoch_num=1200,       # Number of epochs.
    lr=1e-3,              # Learning rate.
    weight_decay=5e-6,    # L2 penalty.
    momentum=0.9,         # Momentum.
    batch_size=1,         # Batch size.
    num_workers=8,        # Number of workers on data loader.
    print_freq=1,         # Printing frequency for mini-batch loss.
    w_size=224,           # Width size for image resizing.
    h_size=224,           # Height size for image resizing.
    test_freq=1200,       # Test each test_freq epochs.
    save_freq=1200,       # Save model each save_freq epochs.
    input_channels=4,     # Number of input channels in samples/DNN.
    num_classes=5,        # Number of original output classes in dataset.
)

# Architecture and hidden classes (class indices joined by '_', e.g. '0_2').
conv_name = 'unet'
args['hidden_classes'] = '0'
print('hidden: ' + args['hidden_classes'])

dataset_name = 'Vaihingen'

# Potsdam uses a shorter schedule and a single loader worker.
if dataset_name == 'Potsdam':
    args['epoch_num'] = 600
    args['test_freq'] = args['epoch_num']
    args['save_freq'] = args['epoch_num']
    args['num_workers'] = 1

# split('_') also yields a one-element list when there is no underscore.
hidden = [int(token) for token in args['hidden_classes'].split('_')]

num_unknown_classes = len(hidden)
num_known_classes = args['num_classes'] - num_unknown_classes

# Uniform class weights; the last known class is doubled unless class 4 is hidden.
weights = [1.0] * num_known_classes
if 4 not in hidden:
    weights[-1] = 2.0

global_weights = torch.FloatTensor(weights)

# Experiment name and the checkpoint expected for the final epoch.
exp_name = '_'.join([conv_name, dataset_name, 'base_dsm', args['hidden_classes']])

pretrained_path = os.path.join(ckpt_path, exp_name, 'model_{}.pth'.format(args['epoch_num']))
# # Setting device [0|1|2].
# args['device'] = 0

hidden: 0


In [4]:

def _save_patch_images(out_dir, base_name, j, k, inps, labs, true_np, prds):
    """Write input, training mask, ground truth and prediction of one patch as PNG files."""

    def patch_path(tag):
        # '<name>.tif' -> '<name>_<tag>_<j>_<k>.png' inside the epoch directory.
        suffix = '_' + tag + '_' + str(j) + '_' + str(k) + '.png'
        return os.path.join(out_dir, base_name.replace('.tif', suffix))

    inps_np = inps.detach().squeeze(0).cpu().numpy()
    labs_np = labs.detach().squeeze(0).cpu().numpy()

    # Only the first three channels are written when the input has four.
    if inps_np.shape[0] == 4:
        inps_np = inps_np[:3, :, :]

    # presumably inputs are normalized to about [-0.5, 0.5] — TODO confirm.
    inps_np = ((np.transpose(inps_np, (1, 2, 0)) + 0.5) * 255).astype(np.uint8)

    # NOTE(review): img_as_ubyte rescales according to the input dtype range;
    # for int64 label maps this likely maps small class indices to 0. Confirm
    # the saved mask/prediction images are what is intended.
    io.imsave(patch_path('img'), inps_np)
    io.imsave(patch_path('msk'), util.img_as_ubyte(labs_np))
    io.imsave(patch_path('tru'), util.img_as_ubyte(true_np))
    io.imsave(patch_path('prd'), util.img_as_ubyte(prds.cpu().squeeze().numpy()))


def test(test_loader, net, criterion, optimizer, epoch, num_known_classes, num_unknown_classes, hidden, args, save_images, save_model):
    """Run the network over the test loader, optionally save outputs, and print scores.

    Reads the module-level ``ckpt_path``, ``outp_path``, ``exp_name`` and
    ``dataset_name``.

    Args:
        test_loader: Iterable of batches. Each batch unpacks to
            ``(inputs, labels, true_labels)`` when ``dataset_name == 'GRSS'``
            and to ``(inputs, labels, true_labels, image_names)`` otherwise.
        net: Model called as ``net(x, feat=True)``; the first element of its
            result is used as the class scores.
        criterion: Unused; kept for interface compatibility.
        optimizer: Its state dict is saved when ``save_model`` is true.
        epoch: Used to name the checkpoint files and the output directory.
        num_known_classes: Number of classes passed to ``evaluate``.
        num_unknown_classes: Unused; kept for interface compatibility.
        hidden: Unused; kept for interface compatibility.
        args: Unused; kept for interface compatibility.
        save_images: When true, PNGs of every patch are written to the epoch
            output directory.
        save_model: When true, model and optimizer state dicts are saved first.

    Returns:
        None. The scores returned by ``evaluate`` are printed.
    """
    if save_model:
        torch.save(net.state_dict(), os.path.join(ckpt_path, exp_name, 'model_' + str(epoch) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join(ckpt_path, exp_name, 'opt_' + str(epoch) + '.pth'))

    # Setting network for evaluation mode.
    net.eval()

    # Creating output directory (path built once instead of once per file).
    epoch_dir = os.path.join(outp_path, exp_name, 'epoch_' + str(epoch))
    check_mkdir(epoch_dir)

    preds_all = []
    gts_all = []

    with torch.no_grad():

        # Iterating over batches.
        for i, data in enumerate(test_loader):

            print('Test Batch %d/%d' % (i + 1, len(test_loader)))
            sys.stdout.flush()

            if dataset_name == 'GRSS':
                # GRSS batches carry no file name; a fixed one is used.
                inps_batch, labs_batch, true_batch = data
                base_name = 'image.tif'
            else:
                inps_batch, labs_batch, true_batch, img_name = data
                base_name = img_name[0]

            # NOTE(review): squeeze() drops every singleton dimension, not just
            # the loader's batch axis; the [j, k] indexing below would break if
            # another axis had length 1 — confirm against the dataset.
            inps_batch = inps_batch.squeeze()
            labs_batch = labs_batch.squeeze()
            true_batch = true_batch.squeeze()

            # Iterating over the first two axes of the batch (patch indices j, k).
            for j in range(inps_batch.size(0)):
                for k in range(inps_batch.size(1)):

                    # Only the network input needs to live on the GPU.
                    inps = inps_batch[j, k].unsqueeze(0).cuda()
                    labs = labs_batch[j, k].unsqueeze(0)
                    true = true_batch[j, k].unsqueeze(0)

                    # Forwarding. Only the class scores (first output) are used
                    # here, so the auxiliary feature maps are discarded.
                    outs = net(inps, feat=True)[0]

                    # Computing probabilities and taking the most likely class.
                    soft_outs = F.softmax(outs, dim=1)
                    prds = soft_outs.data.max(1)[1]

                    # Copy so the stored array does not keep the loader's batch
                    # tensor alive.
                    true_np = true.detach().squeeze(0).cpu().numpy().copy()

                    # Appending patch prediction and ground truth for evaluation.
                    preds_all.append(prds.detach().squeeze(0).cpu().numpy())
                    gts_all.append(true_np)

                    # Saving predictions.
                    if save_images:
                        _save_patch_images(epoch_dir, base_name, j, k, inps, labs, true_np, prds)

    sys.stdout.flush()

    acc, acc_cls, mean_iou, iou, fwavacc, kappa = evaluate(preds_all, gts_all, num_known_classes)
    print('[acc %.4f], [acc_cls %.4f], [iou %.4f], [fwavacc %.4f], [kappa %.4f]' % (acc, acc_cls, mean_iou, fwavacc, kappa))


In [5]:
# Setting network architecture.
if conv_name == 'unet':

    net = UNet(args['input_channels'], num_classes=args['num_classes'], hidden_classes=hidden).cuda()

elif conv_name == 'fcnwideresnet50':

    net = FCNWideResNet50(args['input_channels'], num_classes=args['num_classes'], pretrained=False, skip=True, hidden_classes=hidden).cuda()

elif conv_name == 'fcndensenet121':

    net = FCNDenseNet121(args['input_channels'], num_classes=args['num_classes'], pretrained=False, skip=True, hidden_classes=hidden).cuda()

else:

    # Fail here with a clear message instead of a NameError on `net` below.
    raise ValueError('Unsupported conv_name: "' + conv_name + '"')

print('Loading pretrained weights from file "' + pretrained_path + '"')
net.load_state_dict(torch.load(pretrained_path))

sys.stdout.flush()

curr_epoch = 1
args['best_record'] = {'epoch': 0, 'lr': 1e-4, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'iou': 0}

# Setting datasets.
train_set = ListDataset(dataset_name, 'Train', (args['h_size'], args['w_size']), 'statistical', hidden, overlap=False, use_dsm=True, dataset_path='../datasets/')
train_loader = DataLoader(train_set, batch_size=args['batch_size'], num_workers=args['num_workers'], shuffle=True)

test_set = ListDataset(dataset_name, 'Test', (args['h_size'], args['w_size']), 'statistical', hidden, overlap=True, use_dsm=True, dataset_path='../datasets/')
test_loader = DataLoader(test_set, batch_size=1, num_workers=args['num_workers'], shuffle=False)

# Setting criterion.
criterion = CrossEntropyLoss2d(weight=global_weights, size_average=False, ignore_index=args['num_classes']).cuda()

# Setting optimizer: bias parameters get twice the learning rate and no
# weight decay; all other parameters use the base rate with weight decay.
optimizer = optim.Adam([
    {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
     'lr': 2 * args['lr']},
    {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
     'lr': args['lr'], 'weight_decay': args['weight_decay']}
], betas=(args['momentum'], 0.99))

# Making sure checkpoint and output directories are created.
check_mkdir(ckpt_path)
check_mkdir(os.path.join(ckpt_path, exp_name))
check_mkdir(outp_path)
check_mkdir(os.path.join(outp_path, exp_name))

# Writing the args to a timestamped experiment log file; the context manager
# makes sure the file is flushed and closed.
with open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w') as log_file:
    log_file.write(str(args) + '\n\n')

# Computing test: save images, do not re-save the model.
test(test_loader, net, criterion, optimizer, args['epoch_num'], num_known_classes, num_unknown_classes, hidden, args, True, False)

Loading pretrained weights from file "./ckpt/unet_Vaihingen_base_dsm_0/model_1200.pth"
self.n_classes 4
self.hidden_classes [0]
self.n_classes 4
self.hidden_classes [0]
Test Batch 1/5
Test Batch 2/5
Test Batch 3/5
Test Batch 4/5
Test Batch 5/5
metrics
        Computing CM...
        Computing Accs...
        Computing Balanced Acc...
        Computing Kappa...
[acc_known, acc_unknown, pre_unknown, rec_unknown, bal, kap]
[0.8681293560935579, 0.6466960467338152, 0.0, 0.0, 0.6695770367890546, 0.5391024317265145]
[acc 0.8681], [acc_cls 0.8370], [iou 0.7509], [fwavacc 0.7725], [kappa 0.8048]
