In [2]:
from fastai.vision.all import *
from fastai.callback.tracker import SaveModelCallback
# from fastprogress.fastprogress import master_bar, progress_bar

from gepcore.utils import convolution
from gepcore.utils import cell_graph
from gepcore.entity import Gene, Chromosome
from gepcore.symbol import PrimitiveSet
from nas_seg.seg_model import get_net, arch_config, Network
from nas_seg.utils import code_to_rgb
from nas_seg.isprs_dataset import ISPRSDataset, img_to_mask, mask_to_img
from pygraphviz import AGraph
import glob

#from tqdm import tqdm
from skimage import io
from sklearn.metrics import confusion_matrix

# Ask cuDNN to benchmark its algorithms (set once for the whole session).
torch.backends.cudnn.benchmark = True

# Report whether PyTorch can see a CUDA device.
print("Great! Good to go!" if torch.cuda.is_available() else 'CUDA is not up!')

Great! Good to go!


In [None]:
from gepcore.utils import convolution
from gepcore.utils import cell_graph
from gepcore.entity import Gene, Chromosome, KExpressionGraph
from gepcore.symbol import PrimitiveSet
#from gepcore.operators import *
from nas_seg.seg_model import get_net, arch_config
from ptflops import get_model_complexity_info

# import fastai deep learning tools
from fastai.vision.all import *


# imports and stuff
import numpy as np
from skimage import io
from glob import glob
from pathlib import Path

from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix
import random
import itertools
# Matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from torch.autograd import Variable

In [None]:
# ISPRS color palette
# Class index -> RGB triplet; the index is the position in the list below.
palette = dict(enumerate([
    (255, 255, 255),  # 0: Impervious surfaces (white)
    (0, 0, 255),      # 1: Buildings (blue)
    (0, 255, 255),    # 2: Low vegetation (cyan)
    (0, 255, 0),      # 3: Trees (green)
    (255, 255, 0),    # 4: Cars (yellow)
    (255, 0, 0),      # 5: Clutter (red)
    (0, 0, 0),        # 6: Undefined (black)
]))

# Reverse lookup: RGB triplet -> class index.
invert_palette = dict(zip(palette.values(), palette.keys()))

def convert_to_color(arr_2d, palette=palette):
    """Turn a 2-D array of class indices into an RGB image.

    Every entry of `arr_2d` equal to a key of `palette` is painted with the
    matching colour; entries matching no key stay (0, 0, 0). The result is a
    uint8 array with the first two dimensions of `arr_2d` plus a 3-channel axis.
    """
    n_rows, n_cols = arr_2d.shape[0], arr_2d.shape[1]
    rgb = np.zeros((n_rows, n_cols, 3), dtype=np.uint8)

    for label, colour in palette.items():
        rgb[arr_2d == label] = colour

    return rgb

def convert_from_color(arr_3d, palette=invert_palette):
    """Turn an RGB-encoded label image into a 2-D array of class indices.

    `palette` maps RGB triplets to class indices. Pixels whose colour matches
    no key of `palette` keep the initial value 0. The result is a uint8 array
    with the first two dimensions of `arr_3d`.
    """
    n_rows, n_cols = arr_3d.shape[0], arr_3d.shape[1]
    labels = np.zeros((n_rows, n_cols), dtype=np.uint8)

    for colour, label in palette.items():
        # Broadcast the colour over the image and keep pixels equal on all 3 channels.
        same_colour = (arr_3d == np.array(colour).reshape(1, 1, 3)).all(axis=2)
        labels[same_colour] = label

    return labels

# # We load one tile from the dataset and we display it
# img = io.imread('/home/cliff/rs_imagery/ISPRS-DATASETS/Vaihingen/top/top_mosaic_09cm_area11.tif')
# fig = plt.figure()
# fig.add_subplot(121)
# plt.imshow(img)

# # We load the ground truth
# gt = io.imread('/home/cliff/rs_imagery/ISPRS-DATASETS/Vaihingen/gts_for_participants/top_mosaic_09cm_area11.tif')
# fig.add_subplot(122)
# plt.imshow(gt)
# plt.show()

# # We also check that we can convert the ground truth into an array format
# array_gt = convert_from_color(gt)
# print("Ground truth in numerical format has shape ({},{}) : \n".format(*array_gt.shape[:2]), array_gt)

In [None]:
# Utils

def get_random_pos(img, window_shape):
    """Pick the corners of a random window of shape `window_shape` inside `img`.

    Args:
        img: array whose shape is (W, H, channels); only the shape is read.
        window_shape: (w, h) extent of the window along the first two axes.

    Returns:
        (x1, x2, y1, y2) such that img[x1:x2, y1:y2] is a w-by-h window that
        lies fully inside the image.

    Raises:
        ValueError: if the window is larger than the image along either axis.
    """
    w, h = window_shape
    W, H = img.shape[:-1]
    if w > W or h > H:
        raise ValueError(
            "window_shape {} does not fit in image of shape {}".format((w, h), (W, H)))
    # random.randint includes both endpoints, so the largest valid offset is
    # W - w (not W - w - 1); this also makes a window equal to the image valid.
    x1 = random.randint(0, W - w)
    y1 = random.randint(0, H - h)
    return x1, x1 + w, y1, y1 + h


def accuracy(input, target):
    """Return the percentage of positions where `input` equals `target`."""
    n_matching = float(np.count_nonzero(input == target))
    return 100 * n_matching / target.size

def sliding_window(top, step=10, window_size=(20,20)):
    """Yield (x, y, w, h) windows covering `top` with a stride of `step`.

    Offsets that would make the window overrun the image are pulled back so the
    window ends exactly on the border; consecutive offsets can therefore be
    clamped to the same position and yielded more than once.
    """
    n_rows, n_cols = top.shape[0], top.shape[1]
    win_rows, win_cols = window_size[0], window_size[1]
    for x in range(0, n_rows, step):
        x = min(x, n_rows - win_rows)
        for y in range(0, n_cols, step):
            yield x, min(y, n_cols - win_cols), win_rows, win_cols
            
def count_sliding_window(top, step=10, window_size=(20,20)):
    """Count the windows visited when striding over `top` with `step`.

    One window is counted per (row offset, column offset) pair; clamping an
    offset to the border does not change how many pairs there are, so
    `window_size` has no influence on the count.
    """
    n_row_offsets = len(range(0, top.shape[0], step))
    n_col_offsets = len(range(0, top.shape[1], step))
    return n_row_offsets * n_col_offsets

def grouper(n, iterable):
    """Yield successive tuples of up to `n` items taken from `iterable`.

    The last tuple may be shorter than `n`; iteration stops at the first empty
    chunk.
    """
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling until an empty tuple comes back.
    yield from iter(lambda: tuple(itertools.islice(source, n)), ())

def metrics(predictions, gts, label_values=None):
    """Print confusion matrix, accuracy, per-class F1 and kappa; return accuracy.

    Args:
        predictions: flat sequence of predicted class indices.
        gts: flat sequence of ground-truth class indices, same length.
        label_values: class names indexed by class id. When omitted, the
            notebook-level `LABELS` list is looked up at call time, so this
            function can be defined before the cell that assigns `LABELS`.

    Returns:
        Overall accuracy in percent.
    """
    if label_values is None:
        label_values = LABELS
    n_labels = len(label_values)

    # `labels` must be passed by keyword: recent scikit-learn releases reject
    # it as a positional argument.
    cm = confusion_matrix(gts, predictions, labels=list(range(n_labels)))

    print("Confusion matrix :")
    print(cm)

    print("---")

    # Compute global accuracy
    total = np.sum(cm)
    overall_accuracy = np.trace(cm) * (100 / float(total))
    print("{} pixels processed".format(total))
    print("Total accuracy : {}%".format(overall_accuracy))

    print("---")

    # Compute F1 score: 2*TP / (row sum + column sum) for each class
    F1Score = np.zeros(n_labels)
    for i in range(n_labels):
        denominator = np.sum(cm[i, :]) + np.sum(cm[:, i])
        # A class absent from both predictions and ground truth keeps a score
        # of 0; numpy division would otherwise produce nan instead of raising.
        if denominator > 0:
            F1Score[i] = 2. * cm[i, i] / denominator
    print("F1Score :")
    for l_id, score in enumerate(F1Score):
        print("{}: {}".format(label_values[l_id], score))

    print("---")

    # Compute kappa coefficient
    pa = np.trace(cm) / float(total)
    pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total * total)
    # Kappa is undefined when expected agreement is already 1.
    kappa = (pa - pe) / (1 - pe) if pe != 1 else float('nan')
    print("Kappa: " + str(kappa))
    return overall_accuracy

# IoU = TP / (TP + FP + FN)

In [None]:
   
# Root folder of the 256-pixel Vaihingen patches and its train sub-folders.
data_path = Path('/home/cliff/rs_imagery/ISPRS-DATASETS/Vaihingen/vaihingen_256')
img_path = data_path / 'images' / 'train'
msk_path = data_path / 'masks' / 'train'

# Class names; the number of classes is derived from this list.
mask_labels = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"]
num_classes = len(mask_labels)
print(f"{img_path} \n {msk_path}")

In [None]:
from pygraphviz import AGraph
#import glob

# Load every exported cell graph and turn them into computation graphs.
# NOTE(review): glob returns files in arbitrary order; if the checkpoint loaded
# later depends on graph order, confirm the order is stable on this machine.
graph = [AGraph(g) for g in glob('../graphs/*.dot')]
_, comp_graphs = cell_graph.generate_comp_graph(graph)

# Use `num_classes` from the data-path cell above: `N_CLASSES` is only assigned
# in a later cell, so referencing it here fails on a fresh top-to-bottom run.
conf = arch_config(comp_graphs=comp_graphs, channels=64, classes=num_classes)
net = get_net(conf)
net.cuda()

In [None]:
import torchvision.transforms as transforms

# Convert to a tensor, then normalise each channel with fixed mean / std values.
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4776, 0.3226, 0.3189],
                         std=[0.1816, 0.1224, 0.1185]),
])

In [None]:
# Load the datasets
# NOTE(review): DATASET, LABEL_FOLDER and BATCH_SIZE are assigned in the
# upper-case "Parameters" cell further down; that cell has to run before this one.
if DATASET == 'Potsdam':
    all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
    # File names look like top_potsdam_<r>_<c>_label.tif; the tile id is "<r>_<c>".
    # (Splitting on an empty separator, as before, always raises ValueError.)
    all_ids = ["_".join(Path(f).name.split('_')[2:4]) for f in all_files]
elif DATASET == 'Vaihingen':
    all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
    # File names end in area<id>.tif; keep what sits between "area" and the extension.
    all_ids = [f.split('area')[-1].split('.')[0] for f in all_files]

# Test split on Vaihingen :
test_ids = ['5', '7', '23', '30']

# NOTE(review): `ISPRS_dataset` is not defined or imported anywhere in the visible
# notebook (the import cell brings in `ISPRSDataset`) - confirm which class and
# which constructor arguments are intended.
train_set = ISPRS_dataset(data_dir=img_path, label_dir=msk_path, transforms=tfms)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=BATCH_SIZE)

In [None]:
# Parameters
window_size = 256

msk_labels = np.array(["roads", "buildings", "low veg.", "trees", "cars", "clutter"])
num_classes = msk_labels.size


dataset = 'Vaihingen'  # alternative: 'Potsdam'
# NOTE(review): the other cells spell this folder "ISPRS-DATASETS" (hyphen);
# confirm that the underscore spelling used here is intended.
dataset_dir = Path.home() / f'rs_imagery/ISPRS_DATASETS/{dataset}'

# File-name templates ("{}" placeholders are left unfilled) and held-out tile ids.
if dataset == 'Potsdam':
    tiles = dataset_dir / 'Ortho_IRRG' / 'top_potsdam_{}_{}_IRRG.tif'
    masks = dataset_dir / 'Labels_for_participants' / 'top_potsdam_{}_{}_label.tif'
    e_masks = dataset_dir / 'Labels_for_participants_no_Boundary' / 'top_potsdam_{}_label_noBoundary.tif'
    trainset_dir = f'{dataset.lower()}_{window_size}'
    testset_ids = ['2_11', '2_12', '4_10', '5_11', '6_7', '7_8', '7_10']  # ['7_8', '4_10', 2 11, 5 11]
elif dataset == 'Vaihingen':
    tiles = dataset_dir / 'top' / 'top_mosaic_09cm_area{}.tif'
    masks = dataset_dir / 'gts_for_participants' / 'top_mosaic_09cm_area{}.tif'
    e_masks = dataset_dir / 'gts_eroded_for_participants' / 'top_mosaic_09cm_area{}_noBoundary.tif'
    trainset_dir = f'{dataset.lower()}_{window_size}'
    testset_ids = ['5', '7', '23', '30']

In [None]:
# Parameters
WINDOW_SIZE = (128, 128)  # Patch size
STRIDE = 32               # Stride for testing
IN_CHANNELS = 3           # Number of input channels (e.g. RGB)
BATCH_SIZE = 10           # Number of samples in a mini-batch
CACHE = True              # Store the dataset in-memory

# Replace with your "/path/to/the/ISPRS/dataset/folder/"
FOLDER = "/home/cliff/rs_imagery/ISPRS-DATASETS/"

LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter"]  # Label names
N_CLASSES = len(LABELS)          # Number of classes
WEIGHTS = torch.ones(N_CLASSES)  # Weights for class balancing

DATASET = 'Vaihingen'

# Path templates: the literal "{}" is filled in later with a tile id.
if DATASET == 'Potsdam':
    MAIN_FOLDER = f'{FOLDER}{DATASET}/'
    # For IRRG data use instead:
    # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif'
    # For RGB data
    DATA_FOLDER = f'{MAIN_FOLDER}Ortho_RGB/top_potsdam_{{}}_RGB.tif'
    LABEL_FOLDER = f'{MAIN_FOLDER}Labels_for_participants/top_potsdam_{{}}_label.tif'
    ERODED_FOLDER = f'{MAIN_FOLDER}Labels_for_participants_no_Boundary/top_potsdam_{{}}_label_noBoundary.tif'
elif DATASET == 'Vaihingen':
    MAIN_FOLDER = f'{FOLDER}{DATASET}/'
    DATA_FOLDER = f'{MAIN_FOLDER}top/top_mosaic_09cm_area{{}}.tif'
    LABEL_FOLDER = f'{MAIN_FOLDER}gts_for_participants/top_mosaic_09cm_area{{}}.tif'
    ERODED_FOLDER = f'{MAIN_FOLDER}gts_eroded_for_participants/top_mosaic_09cm_area{{}}_noBoundary.tif'

In [None]:
from IPython.display import clear_output

def _show_tile_prediction(img, label_map, gt):
    """Replace the cell output with: input tile | colour-coded labels | ground truth."""
    clear_output()
    fig = plt.figure()
    fig.add_subplot(1, 3, 1)
    plt.imshow(np.asarray(img, dtype='uint8'))
    fig.add_subplot(1, 3, 2)
    plt.imshow(convert_to_color(label_map))
    fig.add_subplot(1, 3, 3)
    plt.imshow(gt)
    plt.show()


def test(net, test_ids, all=False, stride=WINDOW_SIZE[0], batch_size=BATCH_SIZE, window_size=WINDOW_SIZE):
    """Run sliding-window inference over whole tiles and report metrics.

    Each tile is cut into windows, the windows are normalised and pushed
    through `net` in batches on the GPU, and the per-class outputs are summed
    per pixel before taking the argmax.

    Args:
        net: model called on a batch of patches; each output is treated as an
            (N_CLASSES, w, h) score map for the corresponding window.
        test_ids: tile ids substituted into DATA_FOLDER, LABEL_FOLDER and
            ERODED_FOLDER to locate the image, the colour ground truth (shown
            only) and the eroded ground truth (used for the metrics).
        all: when True, also return the per-tile predictions and ground truths.
        stride: step between consecutive windows.
        batch_size: number of windows per forward pass.
        window_size: extent of a window along the first two image axes.

    Returns:
        The accuracy returned by `metrics` over all processed tiles (None when
        `test_ids` is empty), or `(accuracy, all_preds, all_gts)` if `all`.
    """
    # Local import: nothing else in the notebook binds the name `tf`.
    import torchvision.transforms.functional as tf

    norm_mean = [0.4776, 0.3226, 0.3189]
    norm_std = [0.1816, 0.1224, 0.1185]

    # Lazily read the tiles so only one is held in memory at a time.
    test_images = (np.asarray(io.imread(DATA_FOLDER.format(tile_id))) for tile_id in test_ids)
    test_labels = (np.asarray(io.imread(LABEL_FOLDER.format(tile_id)), dtype='uint8') for tile_id in test_ids)
    eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(tile_id))) for tile_id in test_ids)
    all_preds = []
    all_gts = []
    accuracy = None  # stays None when there is no tile to process

    # Switch the network to inference mode
    net.eval()

    for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False):
        pred = np.zeros(img.shape[:2] + (N_CLASSES,))

        n_windows = count_sliding_window(img, step=stride, window_size=window_size)
        # Number of batches, rounded up so a final partial batch is counted.
        total = -(-n_windows // batch_size)
        batches = grouper(batch_size, sliding_window(img, step=stride, window_size=window_size))
        for i, coords in enumerate(tqdm(batches, total=total, leave=False)):
            # Display in progress results roughly every 10% of the batches
            if i > 0 and total > 10 and i % int(10 * total / 100) == 0:
                _show_tile_prediction(img, np.argmax(pred, axis=-1), gt)

            # Build the normalised batch tensor
            image_patches = [tf.normalize(tf.to_tensor(img[x:x+w, y:y+h]), norm_mean, norm_std)
                             for x, y, w, h in coords]
            image_patches = torch.stack(image_patches).cuda()

            # Do the inference; no autograd graph is needed here
            with torch.no_grad():
                outs = net(image_patches)
            outs = outs.detach().cpu().numpy()

            # Accumulate the scores of overlapping windows
            for out, (x, y, w, h) in zip(outs, coords):
                pred[x:x+w, y:y+h] += out.transpose((1, 2, 0))
            del outs

        pred = np.argmax(pred, axis=-1)

        # Display the result
        _show_tile_prediction(img, pred, gt)

        all_preds.append(pred)
        all_gts.append(gt_e)

        clear_output()
        # Metrics for this tile, then cumulative metrics over the tiles so far
        metrics(pred.ravel(), gt_e.ravel())
        accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]),
                           np.concatenate([g.ravel() for g in all_gts]).ravel())
    if all:
        return accuracy, all_preds, all_gts
    else:
        return accuracy

In [None]:
# Deserialize the checkpoint onto the CPU so loading does not depend on the
# device it was saved from; load_state_dict then copies the values into the
# parameters of `net`.
net.load_state_dict(torch.load('./model.pth', map_location='cpu'))

In [None]:
# Run inference on the held-out tiles. With all=True, test() also returns the
# per-tile predicted label maps and the eroded ground truths; the accuracy
# (first return value) is discarded here.
_, all_preds, all_gts = test(net, test_ids, all=True)

In [None]:
# Colour-code each predicted label map, display it and write it to disk.
for p, id_ in zip(all_preds, test_ids):
    img = convert_to_color(p)
    # Two plain statements: chaining with `and` only reached plt.show() because
    # the object returned by imshow happens to be truthy.
    plt.imshow(img)
    plt.show()
    io.imsave('./inference_tile_{}.png'.format(id_), img)