In [None]:
%matplotlib inline
import localizer
from localizer import models, keras_helpers, util, visualization
from localizer.localizer import Localizer
import matplotlib.pyplot as plt
from scipy.misc import imread
from os.path import join
import os
from pylab import rcParams
import h5py
import keras
import numpy as np
# Default size (in inches) for every inline figure in this notebook.
rcParams['figure.figsize'] = (15, 15)

In [None]:
# Root directory of the deeplocalizer training data. Set the
# DEEPLOCALIZER_DATA_DIR environment variable to point elsewhere; without it
# the original machine-specific path is used.
data_dir = os.environ.get('DEEPLOCALIZER_DATA_DIR', '/home/ben/deeplocalizer_data/data')

In [None]:
# Skip loading the filter network: below, only loc.saliency_network and
# loc.detect_tags are used.
loc = Localizer(data_dir, load_filter_network=False)

In [None]:
# Train / test / validation images and labels.
# (Name suggests a cached copy is restored when available — TODO confirm in util.)
(X_train, y_train,
 X_test, y_test,
 X_val, y_val) = util.load_or_restore_data(data_dir)

In [None]:
# Downscale every split to `filtersize`, the input size fed to
# loc.saliency_network further below.
filtersize = (16, 16)
Xs_train, Xs_val, Xs_test = (
    util.resize_data(split, filtersize) for split in (X_train, X_val, X_test)
)

# One shape per line: train, test, val.
print(*(resized.shape for resized in (Xs_train, Xs_test, Xs_val)), sep='\n')

In [None]:
# Data generator built from the resized training set; it is handed to
# keras_helpers.predict_model below. (Presumably it holds training-set
# normalisation statistics — TODO confirm in keras_helpers.get_datagen.)
saliency_datagen = keras_helpers.get_datagen(Xs_train)

In [None]:
# Saliency-network predictions for the resized test split.
ys_out = keras_helpers.predict_model(
    loc.saliency_network,
    Xs_test,
    saliency_datagen,
)

In [None]:
# Score the test predictions against binarised labels: a label above 0.75
# counts as a positive. visualize=True additionally renders the evaluation
# plots (presumably PR / ROC curves — TODO confirm in keras_helpers).
(precision, recall, average_precision,
 thresholds, fpr, tpr, roc_auc) = keras_helpers.evaluate_model(
    y_test > 0.75,
    ys_out,
    visualize=True,
)

In [None]:
# Choose an operating threshold from the precision/recall curve
# (min_value=0.90 with optimize='precision' — exact semantics live in
# keras_helpers.select_threshold).
# NOTE(review): the tag-extraction loop further down hard-codes its own
# threshold instead of using this value — confirm which one is intended.
saliency_threshold = keras_helpers.select_threshold(
    precision, recall, thresholds,
    min_value=0.90,
    optimize='precision',
)

In [None]:
# Root of the BeesBook recordings; override with BEESBOOK_DATA_DIR.
beesbook_dir = os.environ.get("BEESBOOK_DATA_DIR", "/home/beesbook/beesbook-data")
# Directory name kept verbatim ("preprocces"); presumably it matches the
# directory on disk — do not correct the spelling without checking.
beesbook_2015 = join(beesbook_dir, "season_2015_preprocces")

# images.txt holds one image path per line. Blank / whitespace-only lines are
# skipped: an empty path would otherwise reach detect_tags and os.remove in
# the extraction loop and abort it midway.
with open(join(beesbook_2015, "images.txt")) as f:
    beesbook_images = [
        path for path in (line.rstrip('\n') for line in f) if path.strip()
    ]

# Output directory for the extracted tags (created if missing).
beesbook_tag_dir = join(beesbook_dir, "season_2015_tags")
os.makedirs(beesbook_tag_dir, exist_ok=True)

In [None]:
# HDF5 file that will hold the "tags" and "saliency" datasets created below.
h5_fname = os.path.join(beesbook_tag_dir, "tags.hdf5")

In [None]:
# Initial number of rows; the datasets are resizable along axis 0
# (maxshape=None) and are grown by the extraction loop as needed.
init_samples = 1024
nb_chunks = 1024
# Open explicitly in append mode. h5py >= 3.0 defaults to read-only ('r'),
# which makes create_dataset fail; 'a' matches the implicit default of older
# h5py. Like before, re-running this cell fails if the datasets already exist
# instead of silently truncating previously extracted tags.
h5file = h5py.File(h5_fname, 'a')
# One (1, 64, 64) float32 crop per detected tag.
tags = h5file.create_dataset("tags", shape=(init_samples, 1, 64, 64),
                             maxshape=(None, 1, 64, 64),
                             chunks=(nb_chunks, 1, 64, 64),
                             dtype='float32')
# One saliency value per row of `tags`.
saliency_dset = h5file.create_dataset("saliency", shape=(init_samples,),
                                      maxshape=(None,), chunks=(2048,),
                                      dtype='float32')

In [None]:
def write_batch(batch, tags, saliency_dset, nb_tags, grow_samples):
    """Append one batch of per-image detections to the HDF5 datasets.

    Args:
        batch: list of ``(saliencies, rois)`` tuples, one per image, where
            ``saliencies`` is 1-D with one entry per row of ``rois``.
        tags, saliency_dset: resizable h5py datasets that receive the rois
            and the saliencies respectively.
        nb_tags: number of rows already written (the write offset).
        grow_samples: rows added per resize step when a dataset is too small.

    Returns:
        The new number of written rows. An empty batch, or one without any
        detections, writes nothing and returns ``nb_tags`` unchanged.
    """
    if not batch:
        return nb_tags
    for sali, rois in batch:
        assert len(sali) == len(rois)
    batch_sali = np.concatenate([sali for sali, _ in batch], axis=0)
    batch_tags = np.concatenate([rois for _, rois in batch], axis=0)
    assert len(batch_tags) == len(batch_sali), "{}, {}".format(len(batch_tags), len(batch_sali))
    if len(batch_tags) == 0:
        return nb_tags

    end = nb_tags + len(batch_tags)
    while end > len(tags):
        tags.resize(len(tags) + grow_samples, axis=0)
        saliency_dset.resize(len(saliency_dset) + grow_samples, axis=0)

    # Reorder tags and saliencies with the same random permutation.
    # (np.random.shuffle works in place and returns None, so it cannot supply
    # an index array.)
    indices = np.random.permutation(len(batch_tags))
    tags[nb_tags:end] = batch_tags[indices]
    saliency_dset[nb_tags:end] = batch_sali[indices]
    return end


progbar = keras.utils.generic_utils.Progbar(len(beesbook_images))
nb_tags = 0
grow_samples = nb_chunks * 10
nb_batch_samples = 32
batch = []
batch_fnames = []
# NOTE(review): hard-coded detection threshold; `saliency_threshold` selected
# earlier in the notebook is not used here — confirm which one is intended.
threshold = 0.99500811100006104
last_index = len(beesbook_images) - 1
for i, imfname in enumerate(beesbook_images):
    saliencies, candidates, rois = loc.detect_tags(imfname, threshold)
    assert len(saliencies.reshape(-1)) == len(rois)
    assert len(tags) == len(saliency_dset)
    batch.append((saliencies.reshape((-1,)), rois))
    batch_fnames.append(imfname)

    # Write every full batch and also the final, possibly partial one.
    if len(batch) == nb_batch_samples or i == last_index:
        nb_tags = write_batch(batch, tags, saliency_dset, nb_tags, grow_samples)
        # The source images are deleted, so only remove them once their tags
        # have been flushed to the HDF5 file.
        h5file.flush()
        for fname in batch_fnames:
            os.remove(fname)
        batch, batch_fnames = [], []
    progbar.add(1)

# Drop the unused rows left over from growing the datasets in steps of
# grow_samples, so every stored row is a real detection.
tags.resize(nb_tags, axis=0)
saliency_dset.resize(nb_tags, axis=0)
h5file.flush()

In [None]:
# Shape of the HDF5 "tags" dataset after the extraction loop.
print(tags.shape)

In [None]:
def show(imfname, threshold):
    """Detect tags in one image and plot the candidates over it.

    Prints the shape of the saliency array returned by ``loc.detect_tags``,
    then overlays the candidates on the image (pixel values scaled by 1/255)
    and shows the figure.
    """
    saliencies, candidates, _rois = loc.detect_tags(imfname, threshold)
    print(saliencies.shape)
    image = imread(imfname) / 255.
    overlay = visualization.get_roi_overlay(candidates, image)
    plt.imshow(overlay)
    plt.show()


# Two example frames (cam 1 and cam 3) used for the threshold sweep.
cam1 = "/home/ben/deeplocalizer_data/images/season_2015/cam1/Cam_1_20150911120849_847258_wb.jpeg"
cam3 = "/home/ben/deeplocalizer_data/images/season_2015/cam3/Cam_3_20150915235539_739596_wb.jpeg"
for threshold in (0.75, 0.80, 0.85, 0.90, 0.95, 0.995):
    print(threshold)
    show(cam1, threshold)
    show(cam3, threshold)

In [None]:
# Overlay the detected tag candidates on one example image. The image path and
# the candidates are computed here so the cell does not depend on variables
# left over from other cells.
# NOTE(review): uses the threshold chosen by select_threshold above — confirm
# this is the intended threshold for this plot.
example_image = "/home/ben/deeplocalizer_data/images/season_2015/cam1/Cam_1_20150911120849_847258_wb.jpeg"
_, example_candidates, _ = loc.detect_tags(example_image, saliency_threshold)
_ = plt.imshow(visualization.get_roi_overlay(example_candidates, imread(example_image) / 255.))