In [1]:
import os
import json
import pickle

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib import cm
from PIL import Image
from torchsummary import summary

# Evaluate Network

### Load Network

In [None]:
# define which network to load
cur_dir = os.getcwd()
path_to_network = os.path.join(cur_dir, 'pollennet.pt')

# Initialise a blank network.
# NOTE(review): FCNN is not defined in this notebook; its definition must be
# run/imported before this cell.
pollen_network = FCNN()
# Update with the saved weights. map_location='cpu' allows a checkpoint that
# was saved from a GPU to be loaded on a CPU-only machine; the blank network
# is not moved to a GPU anywhere in this notebook.
saved_state = torch.load(path_to_network, map_location='cpu')
pollen_network.load_state_dict(saved_state)
# Put layers such as dropout / batch-norm (if FCNN has any) into inference mode.
pollen_network.eval()
# Define criterion and optimizer. In this notebook the optimizer is only
# inspected (its state_dict is printed further down); nothing is trained.
criterion = nn.MSELoss()
optimizer = optim.SGD(pollen_network.parameters(), lr=0.001, momentum=0.9)

# TODO: load the objects later cells rely on:
#   - `printloader` (validation DataLoader), e.g. via torch.load()
#   - `train_losses`, `val_losses`, `F1_scores`, `accuracies`, e.g. via pickle

### Show loaded weights

In [None]:
print("\n-----------------------------\n")
print("Weights of the first layer:")
# NOTE(review): assumes FCNN names its first layer `conv0` — confirm against the FCNN definition.
print(pollen_network.conv0.weight)

print("\n-----------------------------\n")
print("Model's state_dict:")
# Build the state_dict once instead of once per parameter.
for param_name, param_value in pollen_network.state_dict().items():
    print(param_name, "\t", param_value.size())

print("\n-----------------------------\n")
print("Optimizer's state_dict:")
for var_name, var_value in optimizer.state_dict().items():
    print(var_name, "\t", var_value)

### Show how the data's shape changes as it passes through the network

In [None]:
# torchsummary prints every layer's output shape for an input of the given
# (channels, height, width) size. device="cpu" matches the loaded network,
# which is never moved to a GPU in this notebook (torchsummary's default
# device is "cuda").
summary(pollen_network, (1, 64, 64), device="cpu")
print("-----------------------------")
# NOTE(review): a full-size 4000x3000 input needs a lot of memory on the CPU.
summary(pollen_network, (1, 4000, 3000), device="cpu")

### Plot losses and metrics

In [None]:
# Pair every curve with its legend entry so the order of plots and labels
# cannot drift apart.
# NOTE(review): these four lists are not created in this notebook yet (see the TODO in the loading cell).
curves = (
    (train_losses, "Train Loss"),
    (val_losses, "Validation Loss"),
    (F1_scores, "F1-Score"),
    (accuracies, "Accuracy"),
)
for series, series_label in curves:
    plt.plot(series, label=series_label)
plt.legend()
plt.show()

### Plot some misclassified samples

In [None]:
# This function was used to find the weaknesses of the network.

def print_false_positive(network, testloader):
    """
    Evaluates the network and prints the false positives samples. 
    
    Parameters:
    
    Returns:
    
    """
    
    counter = 0
    # We do not need any gradiants here, since we do not train the network.
    # We are only interested in the predictions of the network on the testdata. 
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(testloader):
            outputs = network(torch.transpose(inputs[...,None],1,3)).view(-1)
            predicted = (outputs >= threshold) # Predicted is a tensor of booleans 
            predicted = predicted.view(predicted.size(0))
            labels = labels == 1
            if (predicted and not labels):
                title = 'Label: False, Predicted: True'
                fig, ax = plt.subplots()
                plt.imshow(np.array(inputs[0]), cmap='gray')
                ax.set_title(title)
                plt.show()
                counter += 1
                if counter == 20:
                    break
    return counter

def print_false_negative(network, testloader):
    # We do not need any gradiants here, since we do not train the network.
    # We are only interested in the predictions of the network on the testdata. 
    counter = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(testloader):
            outputs = network(torch.transpose(inputs[...,None],1,3)).view(-1)
            predicted = (outputs >= threshold) # Predicted is a tensor of booleans 
            predicted = predicted.view(predicted.size(0))
            labels = labels == 1
            if (not predicted and labels):
                title = 'Label: True, Predicted: False'
                fig, ax = plt.subplots()
                plt.imshow(np.array(inputs[0]), cmap='gray')
                ax.set_title(title)
                plt.show()
                counter += 1
                if counter == 20:
                    break
    return counter

In [None]:
# NOTE(review): `printloader` is not created in this notebook yet (see the TODO in the loading cell).
print_false_positive(pollen_network, printloader)
print("-----------------------------")
print_false_negative(pollen_network, printloader)

# Create Heatmaps

### Load Data

In [None]:
def _index_annotations(ann_dir):
    """Return (exact_index, entries) for every file below `ann_dir`.

    `entries` lists (directory, filename) pairs in os.walk order; `exact_index`
    maps each filename to the path of its first occurrence in that order.
    """
    entries = [(ann_root, ann_file)
               for ann_root, _, ann_files in os.walk(ann_dir)
               for ann_file in ann_files]
    exact_index = {}
    for ann_root, ann_file in entries:
        exact_index.setdefault(ann_file, os.path.join(ann_root, ann_file))
    return exact_index, entries


def _find_annotation(img_file, exact_index, entries):
    """Return the annotation path for `img_file`, or None if there is none.

    '<img_file>.json' is preferred; otherwise the first annotation whose name
    merely contains `img_file` is used (the original, looser rule). This stops
    '1.png' from being paired with '11.png.json' when '1.png.json' exists.
    """
    exact = exact_index.get(img_file + '.json')
    if exact is not None:
        return exact
    for ann_root, ann_file in entries:
        if img_file in ann_file:
            return os.path.join(ann_root, ann_file)
    return None


def get_samples(dir):
    """
    Return (image, coordinates) samples found below the given directory.

    Every folder named 'img' below `dir` is paired with its sibling 'ann'
    folder. Each image file is matched with a JSON annotation file, and images
    whose annotation lists at least one object are loaded.

    Parameters:
        dir (str): root directory to search (name kept for backward
            compatibility, even though it shadows the builtin).

    Returns:
        list of (np.ndarray, list): one tuple per annotated image. The first
        item is the image read by cv2.imread(..., cv2.IMREAD_GRAYSCALE); the
        second is the list of the first 'exterior' point of every object.

    Raises:
        ValueError: if an annotated image cannot be read by OpenCV.
    """
    root_dir = dir
    samples = []
    for root, folders, _ in os.walk(root_dir):
        if 'img' not in folders:
            continue
        # Index the annotation folder once per dataset instead of re-walking it for every image.
        exact_index, ann_entries = _index_annotations(os.path.join(root, 'ann'))
        for img_root, _, img_files in os.walk(os.path.join(root, 'img')):
            for img_file in img_files:
                ann_path = _find_annotation(img_file, exact_index, ann_entries)
                if ann_path is None:
                    continue
                # JSON text is UTF-8; don't depend on the platform's default encoding.
                with open(ann_path, encoding='utf-8') as ann_json:
                    ann_data = json.load(ann_json)
                objects = ann_data['objects']
                if not objects:
                    continue
                # The first exterior point of each object is used as its coordinate.
                coords = [obj['points']['exterior'][0] for obj in objects]
                img_path = os.path.join(img_root, img_file)
                image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                if image is None:
                    # cv2.imread signals failure by returning None instead of raising.
                    raise ValueError(f"could not read image {img_path!r}")
                samples.append((np.array(image), coords))
    return samples

### Create Heatmaps

In [None]:
def heat(network, samples):
    """
    Feed images of any size into the network and return their heatmaps.

    Parameters:
        network (FCNN): fully convolutional network used to compute the
            heatmaps; any callable accepting a (1, 1, W, H) float tensor and
            returning a (1, C, W', H') tensor works. Only channel 0 is kept.
        samples (list of (np.ndarray, coords)): e.g. the output of
            `get_samples`. Each image must be 2D; `coords` is passed through unchanged.

    Returns:
        list of (np.ndarray, coords): one (heatmap, coords) tuple per sample.
        Each heatmap is a 2D array that is transposed back, so its rows
        correspond to the image rows.
    """
    heatmaps = []
    print('Creating heatmaps...')
    # Inference only: no gradients are needed.
    with torch.no_grad():
        for fs_image, fs_coords in samples:
            # (H, W) -> (1, H, W, 1) -> (1, 1, W, H): the layout the network is fed with.
            t_image = torch.as_tensor(fs_image, dtype=torch.float32)[None, ..., None]
            heatmap = network(torch.transpose(t_image, 1, 3))
            # Undo the transpose so the heatmap is oriented like the input image.
            heatmaps.append((heatmap[0, 0].transpose(0, 1).cpu().numpy(), fs_coords))

    print('Heatmaps created.')
    return heatmaps

### Find local maxima with max-pooling

In [None]:
def non_max_suppression_single(heatmap, local_size = 8, threshold = 0.8):
    """
    Keep only the local maxima of a single 2D heatmap.

    A pixel survives when it equals the maximum of the square window of width
    2 * local_size - 1 centred on it AND its value is at least `threshold`.

    Parameters:
        heatmap (np.ndarray): 2D heatmap, e.g. one produced by `heat`.
        local_size (int): window half-width plus one; must be >= 1.
        threshold (float): minimum score of a kept peak (default 0.8, the
            value that was previously hard-coded).

    Returns:
        np.ndarray of int with the same shape as `heatmap`: 1 at kept peaks, 0 elsewhere.

    Raises:
        ValueError: if `local_size` < 1.
    """
    if local_size < 1:
        raise ValueError(f"local_size must be >= 1, got {local_size}")
    heatmap = np.asarray(heatmap)
    # stride 1 and padding local_size - 1 keep the output the same size as the
    # input; MaxPool2d pads with -inf, so border pixels only compete with real values.
    pooling = nn.MaxPool2d(2 * local_size - 1, stride=1, padding=local_size - 1)
    local_max = pooling(torch.tensor(heatmap)[None, ...])[0].numpy()
    return ((heatmap == local_max) & (heatmap >= threshold)).astype(int)


def non_max_suppression(heatmaps, local_size = 8, threshold = 0.8):
    """
    Apply `non_max_suppression_single` to every heatmap in a list.

    Parameters:
        heatmaps (list of np.ndarray): plain 2D heatmaps. Note that `heat`
            returns (heatmap, coords) tuples; pass only the heatmaps.
        local_size (int): see `non_max_suppression_single`.
        threshold (float): see `non_max_suppression_single`.

    Returns:
        list of np.ndarray: one suppressed 0/1 heatmap per input heatmap.
    """
    return [non_max_suppression_single(heatmap, local_size, threshold)
            for heatmap in heatmaps]

### Extract coordinates of predicted pollen from the heatmap

In [None]:
def pollen_coordinates(heatmap, stride=8, offset=28):
    """
    Map the peaks of a suppressed heatmap back to full-size image pixels.

    Parameters:
        heatmap (np.ndarray): 2D array in which 1 marks a predicted pollen,
            e.g. the output of `non_max_suppression_single`.
        stride (int): image pixels per heatmap pixel. The default 8 (= 2*2*2)
            is the previously hard-coded value. NOTE(review): presumably three
            stride-2 downsampling stages in FCNN — confirm.
        offset (int): pixel offset added after scaling. The default 28 is the
            previously hard-coded value. NOTE(review): presumably related to
            the network's receptive field — confirm.

    Returns:
        list of (row, col) tuples in image pixels, in row-major order.
        For heatmaps produced by `heat`, rows follow the image rows, so each
        tuple is (y, x). NOTE(review): matplotlib's Circle expects (x, y)
        centres — swap before plotting these against (x, y) annotations.
    """
    peaks = np.argwhere(np.asarray(heatmap) == 1)
    return [(int(row) * stride + offset, int(col) * stride + offset)
            for row, col in peaks]

### Plot image with actual pollen and predicted locations

In [None]:
def guess_plotter(img, network_guess, actual_points):
    """
    Show an image with the annotated pollen and the network's predicted locations.

    Annotated pollen are drawn as large red circles (radius 35), predictions as
    smaller blue circles (radius 20).

    Parameters:
        img (str | os.PathLike | array-like): path to an image file (opened
            with PIL) or an image array, such as the arrays stored by `get_samples`.
        network_guess (iterable of pairs): predicted locations. coord[0] and
            coord[1] are used as the circle centre (x, y).
        actual_points (iterable of pairs): ground-truth locations, same convention.
    """
    if isinstance(img, (str, os.PathLike)):
        # Use a context manager so PIL's lazily opened file handle is closed.
        with Image.open(img) as pil_img:
            np_img = np.array(pil_img, dtype=np.uint8)
    else:
        np_img = np.asarray(img)

    fig, ax = plt.subplots(1)
    # 2D arrays are single-channel; show them in gray instead of the default colormap.
    ax.imshow(np_img, cmap='gray' if np_img.ndim == 2 else None)

    for coord in actual_points:
        crcl = patches.Circle((coord[0], coord[1]), 35, linewidth=4, edgecolor='r', facecolor='none')
        ax.add_patch(crcl)

    for coord in network_guess:
        crcl = patches.Circle((coord[0], coord[1]), 20, linewidth=2, edgecolor='b', facecolor='none')
        ax.add_patch(crcl)

    plt.show()

### Run the pipeline defined above

In [None]:
# Load Images
cur_dir = os.getcwd()
full_size_path = os.path.join(curd_dir, 'Fullsize')
full_size_images = get_samples(full_size_path)

# Create heatmap of images
heatmaps = heat(network,full_size_images)

# Find local maxima on heatmaps
non_max_heatmaps = [(non_max_suppression_single(heatmap),coords) for heatmap, coords in heatmaps]

# TODO
# plot images with predicted and real pollen marked