In [1]:
%matplotlib inline
from IPython.display import clear_output

In [2]:
# imports
from gepcore.utils import cell_graph
from nas_seg.seg_model import *
from gepnet.utils import count_parameters
from nas_seg.isprs_dataset import img_to_mask, mask_to_img
from nas_seg.utils import *

from fastai.vision.all import *
import torchvision.transforms.functional as tf

from skimage import io
import glob
from pygraphviz import AGraph
from tqdm.notebook import tqdm as tqdm
from sklearn.metrics import confusion_matrix

In [3]:
### dataset
window_size = 128   # sliding-window patch size in pixels (also the network input size)
stride = 32         # step in pixels between consecutive sliding windows
batch_size = 10     # number of patches per forward pass

labels = np.array(["imp. surf.", "buildings", "low veg.", "trees", "cars", "clutter"])
num_classes = len(labels)

dataset = 'Vaihingen'  # or 'Potsdam'
dataset_dir = Path.home()/'rs_imagery/ISPRS_DATASETS/{}'.format(dataset)

# Path templates below contain '{}', which is filled with a tile id from test_ids.
# e_masks are the eroded ("noBoundary") ground truths that evaluate() scores against.
# (An earlier setup used the participant-only label folders with test_ids
#  ['2_11', '2_12', '4_10', '5_11', '6_7', '7_8', '7_10'] for Potsdam and
#  ['5', '17', '3', '32'] for Vaihingen.)
if dataset == 'Potsdam':
    tiles = dataset_dir/'Ortho_IRRG/top_potsdam_{}_IRRG.tif'  # or 'top_potsdam_{}_RGB.tif' for rgb images
    masks = dataset_dir/'Labels_all/top_potsdam_{}_label.tif'
    e_masks = dataset_dir/'Labels_all_noBoundary/top_potsdam_{}_label_noBoundary.tif'
    test_ids = ['2_13', '2_14', '3_13', '3_14', '4_13', '4_14', '4_15',
                '5_13', '5_14', '5_15', '6_13', '6_14', '6_15', '7_13']
elif dataset == 'Vaihingen':
    tiles = dataset_dir/'top/top_mosaic_09cm_area{}.tif'
    masks = dataset_dir/'gts_complete/top_mosaic_09cm_area{}.tif'
    e_masks = dataset_dir/'gts_eroded_complete/top_mosaic_09cm_area{}_noBoundary.tif'
    test_ids = ['2', '4', '6', '8', '10', '12', '14', '16', '20',
                '22', '24', '27', '29', '31', '33', '35', '38']
else:
    # Fail here rather than with a NameError on `tiles` when evaluate() is defined.
    raise ValueError("Unknown dataset {!r}; expected 'Potsdam' or 'Vaihingen'".format(dataset))

In [4]:
# Per-channel normalization applied to every patch before inference.
# NOTE(review): presumably the statistics used at training time — confirm they match.
_NORM_MEAN = [0.4776, 0.3226, 0.3189]
_NORM_STD = [0.1816, 0.1224, 0.1185]


def _show_triplet(img, pred, gt):
    """Clear the cell output and show image, colourised prediction and ground truth side by side."""
    clear_output()
    fig, axes = plt.subplots(1, 3)
    axes[0].imshow(np.asarray(img, dtype='uint8'))
    axes[0].set_title('Image')
    axes[1].imshow(mask_to_img(pred))
    axes[1].set_title('Prediction')
    axes[2].imshow(gt)
    axes[2].set_title('Ground truth')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures on non-inline backends


def evaluate(net, test_ids, stride=stride, batch_size=batch_size, window_size=window_size,
             img_files=tiles, mask_files=masks, eroded_masks=e_masks):
    """Run sliding-window inference on test tiles and report segmentation metrics.

    Each tile is cut into windows via ``sliding_window(img, step=stride,
    window_size=window_size)``. The windows are normalized and sent through
    ``net`` in groups of ``batch_size``. The per-window outputs are summed into
    a per-pixel class-score map, and the predicted label map is its argmax over
    classes. ``metrics`` is called once per tile and once on all tiles pooled.

    Args:
        net: segmentation network. Each per-sample output must be shaped
            (num_classes, w, h), as implied by the transpose/accumulate below.
        test_ids: tile ids substituted into the ``'{}'`` of the path templates.
        stride, batch_size, window_size: sliding-window parameters.
        img_files, mask_files, eroded_masks: path templates for the input
            tiles, the colour ground truths (display only) and the eroded
            ground truths (used for scoring).

    Returns:
        (accuracy, all_preds, all_gts):
        - ``accuracy`` is whatever ``metrics`` returns for the pooled
          predictions (its printed report shows total accuracy, F1, IoU, kappa).
        - ``all_preds`` is the list of per-tile 2-D label maps.
        - ``all_gts`` is the list of eroded ground truths converted with
          ``img_to_mask``.
    """
    # Generators: tiles are read lazily, one at a time, as the loop consumes them.
    test_images = (np.asarray(io.imread(str(img_files).format(tile_id))) for tile_id in test_ids)
    test_labels = (np.asarray(io.imread(str(mask_files).format(tile_id)), dtype='uint8') for tile_id in test_ids)
    eroded_labels = (img_to_mask(io.imread(str(eroded_masks).format(tile_id))) for tile_id in test_ids)

    all_preds = []
    all_gts = []

    # Send patches to wherever the network lives instead of hard-coding .cuda().
    device = next(net.parameters()).device
    net.eval()

    for img, gt, gt_e in tqdm(zip(test_images, test_labels, eroded_labels), total=len(test_ids), leave=False):
        pred = np.zeros(img.shape[:2] + (num_classes,))

        n_windows = count_sliding_window(img, step=stride, window_size=window_size)
        # grouper also yields a final partial batch, so round up.
        total = (n_windows + batch_size - 1) // batch_size
        batches = grouper(batch_size, sliding_window(img, step=stride, window_size=window_size))

        for i, coords in enumerate(tqdm(batches, total=total, leave=False)):
            # Show in-progress results roughly every 10% of the batches.
            if i > 0 and total > 10 and i % int(10 * total / 100) == 0:
                _show_triplet(img, np.argmax(pred, axis=-1), gt)

            # tf.normalize returns a new tensor, so no clone is needed.
            image_patches = torch.stack([
                tf.normalize(tf.to_tensor(img[x:x+w, y:y+h]), _NORM_MEAN, _NORM_STD)
                for x, y, w, h in coords
            ]).to(device)

            # Inference only: skip building the autograd graph (saves memory).
            with torch.no_grad():
                outs = net(image_patches).cpu().numpy()

            # Accumulate class scores; overlapping windows add up.
            for out, (x, y, w, h) in zip(outs, coords):
                pred[x:x+w, y:y+h] += out.transpose((1, 2, 0))

            del outs

        pred = np.argmax(pred, axis=-1)
        _show_triplet(img, pred, gt)

        all_preds.append(pred)
        all_gts.append(gt_e)

        clear_output()

        # Per-tile metrics
        metrics(pred.ravel(), gt_e.ravel(), labels)

    accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]),
                       np.concatenate([g.ravel() for g in all_gts]),
                       labels)

    return accuracy, all_preds, all_gts

In [5]:
# Rebuild the searched architecture from its saved cell graphs.
# glob.glob returns files in arbitrary, filesystem-dependent order; sorting makes
# the cell order (and hence the parameter layout) deterministic.
# NOTE(review): the order must match the one used when model.pth was trained — confirm.
graph = [AGraph(g) for g in sorted(glob.glob('nas_seg/comp_graphs/*.dot'))]
_, comp_graphs = cell_graph.generate_comp_graph(graph)

conf = arch_config(comp_graphs=comp_graphs,
                   channels=32,
                   input_size=window_size,  # keep in sync with the sliding-window patch size
                   classes=num_classes)

net = Network(conf)
net.cuda()

Network(
  (stem_): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (cell_0): Cell(
    (branch_0): Layer(
      (dilconv3x3_0): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), groups=64, bias=False)
        (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
      (maxpool3x3_1): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (sepconv5x5_2): Sequential(
        (0): Conv2d(64, 

In [None]:
# def get_mask(x):
#     dset = x.parent.name
#     path = x.parent.parent.parent/'masks'/dset
#     name = x.name
#     return path/name

# def overall_acc(preds, target):
#     target = target.squeeze(1)
#     return (preds.argmax(dim=1)==target).float().mean()

# net = load_learner('nas_seg/model.pkl')

In [6]:
# Load trained weights stored alongside the dataset. This replaces a hard-coded
# absolute path under /home/cliff; it is the same file for that user when
# dataset == 'Vaihingen'.
# map_location loads tensors straight onto the device the network is on.
# torch.load unpickles the file — only load checkpoints from trusted sources.
model_path = dataset_dir / 'model.pth'
net.load_state_dict(torch.load(model_path, map_location=next(net.parameters()).device))
print(count_parameters(net))

61.111696


In [7]:
# Run sliding-window inference on every test tile. evaluate() prints the
# per-tile and pooled metrics reports shown below. Its first return value
# (pooled metrics result) is discarded here. The per-tile label maps and
# eroded ground truths are kept for later cells.
_, all_preds, all_gts = evaluate(net, test_ids)

Confusion matrix :
[[2170806   86680   63684    7977    3035     943]
 [  39010 2174916   78446    2165    1087       0]
 [  34551   27982 2165547  174996       2       0]
 [   9439    1286  158402 1480353     125       0]
 [   7480    2053     722     113   52841       0]
 [      0       0       0       0       0       0]]

8744641 pixels processed
Total accuracy : 0.9199306180779748

F1Score :
imp. surf. -- 0.9449768425158307
buildings -- 0.9479771456765888
low veg. -- 0.8893637809070821
trees -- 0.8930676768794968
cars -- 0.8784944180749632
clutter -- 0.0
mean F1Score : 0.9107759728107923

IoU :
imp. surf. -- 0.8956929862745786
buildings -- 0.9010993837070796
low veg. -- 0.8007696540217695
trees -- 0.8067951926472704
cars -- 0.7833170268908062
clutter -- 0.0
mean IoU : 0.8375348487083007

Kappa:  0.8930516151957302
----------------------------------------

Confusion matrix :
[[20945896   747271   758654   167610    48216     4021]
 [  824878 20728869   341421    43122    10604      

In [None]:
# Show and save a colour rendering of each predicted tile.
# mask_to_img is the label->colour conversion evaluate() uses for display.
# The original called convert_to_color, which this notebook never defines or
# explicitly imports.
for pred_mask, tile_id in zip(all_preds, test_ids):
    color_img = mask_to_img(pred_mask)
    fig, ax = plt.subplots()
    ax.imshow(color_img)
    ax.set_title('Prediction, tile {}'.format(tile_id))
    plt.show()
    plt.close(fig)
    io.imsave('./inference_tile_{}.png'.format(tile_id), color_img)

In [None]:
from datetime import datetime

In [None]:
# Record a wall-clock start time.
# NOTE(review): these timing cells come after the evaluation cell and have
# never been executed (In [None]), so they time nothing. To time evaluate(),
# run this before that cell, or use %%time on it.
x = datetime.now()

In [None]:
# Elapsed wall-clock time (a datetime.timedelta) since `x` was recorded.
y = datetime.now() - x

In [None]:
# Display the elapsed time.
print(y)