In [1]:
# Automatically reload edited Python modules (e.g. the khumeia package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Interactive matplotlib backend for the classic notebook (selected once, here, for the whole notebook).
%matplotlib notebook
# Raise the IPython application's log level to INFO.
%config Application.log_level="INFO"

In [3]:
import os

# Root data directory used by khumeia for downloads (raw data ends up in $TP_ISAE_DATA/raw).
# A TP_ISAE_DATA already exported in the environment takes precedence; the fallback is the
# original author's local checkout.
# TODO: point the fallback at your own data directory instead of hardcoding an absolute path.
os.environ.setdefault("TP_ISAE_DATA", "/home/fchouteau/repositories/tp_isae/data/")
# Echo the effective value, as the `%env` magic used to.
print("env: TP_ISAE_DATA={}".format(os.environ["TP_ISAE_DATA"]))

env: TP_ISAE_DATA=/home/fchouteau/repositories/tp_isae/data/


In [4]:
# Visualisation imports. The matplotlib backend is selected once in the configuration cell,
# so the duplicate `%matplotlib notebook` magic that used to sit here has been dropped.
from khumeia.utils import list_utils
from khumeia import visualisation
from matplotlib import pyplot as plt

In [5]:
# Standard-library helpers first, then the project package.
import os
import json

import khumeia

# Download and extract the evaluation data; per the log output, it lands under
# $TP_ISAE_DATA/raw.
khumeia.download_eval_data()

[2018-11-08 23:56:23,840][tp-isae][get_data][INFO] Downloading evaluation data
[2018-11-08 23:56:23,841][tp-isae][get_data][INFO] Downloading data from tp_isae_eval_data.tar.gz to /home/fchouteau/repositories/tp_isae/data/
[2018-11-08 23:56:23,842][tp-isae][get_data][INFO] Extracting tar gz
[2018-11-08 23:56:24,285][tp-isae][get_data][INFO] Done. Your data is located here /home/fchouteau/repositories/tp_isae/data/raw



In [6]:
import os

from khumeia.data.collection import SatelliteImagesCollection

# os.environ[...] fails with a KeyError naming TP_ISAE_DATA when the variable is unset;
# os.environ.get would return None and os.path.join would raise an opaque TypeError instead.
RAW_DATA_DIR = os.path.join(os.environ["TP_ISAE_DATA"], "raw")
EVAL_DATA_DIR = os.path.join(RAW_DATA_DIR, "eval")

# Build the evaluation collection (one item per image/label file pair in EVAL_DATA_DIR).
eval_collection = SatelliteImagesCollection.from_path(EVAL_DATA_DIR)

# The collection's textual summary lists each item's id, files, shape and label count.
print(eval_collection)

--- Item collection ---
collection_id: eval
Number of items: 3
--- Item description ---
image_id: USGS_DEN
image_file: /home/fchouteau/repositories/tp_isae/data/raw/eval/USGS_DEN.jpg
label_file: /home/fchouteau/repositories/tp_isae/data/raw/eval/USGS_DEN.json
image_shape: (7685, 6205, 3)
number of labels: 52
--- Item description ---
image_id: USGS_LAX
image_file: /home/fchouteau/repositories/tp_isae/data/raw/eval/USGS_LAX.jpg
label_file: /home/fchouteau/repositories/tp_isae/data/raw/eval/USGS_LAX.json
image_shape: (7930, 6689, 3)
number of labels: 92
--- Item description ---
image_id: USGS_MSY
image_file: /home/fchouteau/repositories/tp_isae/data/raw/eval/USGS_MSY.jpg
label_file: /home/fchouteau/repositories/tp_isae/data/raw/eval/USGS_MSY.json
image_shape: (8211, 6514, 3)
number of labels: 47



In [7]:
# Inference machinery: the engine drives prediction, the sliding window cuts images into
# tiles, and Predictor is the base class for the predictor defined below.
from khumeia.inference.engine import InferenceEngine
from khumeia.data.sliding_window import SlidingWindow
from khumeia.inference.predictor import Predictor



In [8]:
import random

class DemoPredictor(Predictor):
    """
    Dummy predictor randomly returning aircraft or background.

    Tile pixels are ignored. Each prediction draws one uniform number in [0, 1):
    the tile is labelled "aircraft" when the draw exceeds ``threshold`` and
    "background" otherwise, so roughly ``1 - threshold`` of the tiles come out
    as "aircraft".

    Args:
        threshold: cut-off on the uniform draw (higher -> fewer "aircraft").
            It is read at prediction time, so changing ``.threshold`` later
            takes effect.
        batch_size: number of tiles per prediction batch, stored as an
            attribute (presumably read by the inference engine; TODO confirm).
        seed: optional seed for a private random generator, which makes the
            predictions reproducible. When None, the global ``random`` module
            is used, so ``random.seed(...)`` still controls the output.
    """

    def __init__(self, threshold=0.9, batch_size=128, seed=None):
        self.threshold = threshold
        self.batch_size = batch_size
        # Both the random module and a random.Random instance expose .random().
        self._rng = random if seed is None else random.Random(seed)
        # Look up self.threshold on every call. The original closed over the
        # constructor argument, so later edits to .threshold were silently ignored.
        self.model = lambda x: "aircraft" if self._rng.random() > self.threshold else "background"

    def predict_on_tile(self, tile_data):
        """Return a random label, "aircraft" or "background", for a single tile."""
        return self.model(tile_data)

    def predict_on_tiles(self, tiles_data):
        """Return one random label per tile, in the same order as ``tiles_data``."""
        return [self.model(tile_data) for tile_data in tiles_data]

In [9]:
# Dummy-predictor settings: draws above the threshold become "aircraft" (about 25% of tiles).
AIRCRAFT_THRESHOLD = 0.75
BATCH_SIZE = 128

predictor = DemoPredictor(batch_size=BATCH_SIZE, threshold=AIRCRAFT_THRESHOLD)

In [10]:
# Non-overlapping square tiles (stride == tile size, no padding).
# With these settings a 7685x6205 image gives 120 * 96 = 11520 tiles, matching the engine log.
# Presumably any remainder at the right/bottom edges is therefore not tiled (TODO confirm).
# Tiles without labels are kept (discard_background=False).
TILE_SIZE = 64

sliding_window = SlidingWindow(
    tile_size=TILE_SIZE,
    stride=TILE_SIZE,
    padding=0,
    discard_background=False,
    label_assignment_mode="center",
)

In [11]:
# Inference engine over the evaluation collection; predictions are requested item by item
# via predict_on_item.
inference_engine = InferenceEngine(items=eval_collection)

In [12]:
# Tile the first evaluation image with the sliding window and run the predictor on every tile.
results = inference_engine.predict_on_item(
    eval_collection[0],
    predictor=predictor,
    sliding_windows=sliding_window,
)

[2018-11-08 23:56:25,798][tp-isae][engine][INFO] Generating tiles to predict


HBox(children=(IntProgress(value=0, description='Applying slider', max=1, style=ProgressStyle(description_widt…


[2018-11-08 23:56:26,667][tp-isae][engine][INFO] Generating predicting on item USGS_DEN with 11520 tiles


HBox(children=(IntProgress(value=0, description='Predicting on batch', max=90, style=ProgressStyle(description…




In [13]:
from matplotlib.patches import Patch

# Visualise ground truth and prediction outcomes for the first evaluation image.
item = eval_collection.items[0]
image = item.image
labels = item.labels

# Keep only the predicted tiles that belong to this item, then split them by outcome.
tiles = [tile for tile in results if tile.item_id == item.key]
true_positives = [tile for tile in tiles if tile.is_true_positive]
false_positives = [tile for tile in tiles if tile.is_false_positive]
false_negatives = [tile for tile in tiles if tile.is_false_negative]

# (legend name, boxes, colour) drawn in this order. Colours are written as channel triples,
# and imshow displays them as RGB.
layers = [
    ("ground truth", labels, (255, 255, 255)),
    ("true positive", true_positives, (0, 255, 0)),
    ("false positive", false_positives, (0, 0, 255)),
    ("false negative", false_negatives, (255, 0, 0)),
]
for _, boxes, color in layers:
    image = visualisation.draw_bboxes_on_image(image, boxes, color=color)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.set_title("{} (TP={}, FP={}, FN={})".format(
    item.image_id, len(true_positives), len(false_positives), len(false_negatives)))
# Proxy patches so the figure explains its own colour code.
ax.legend(
    handles=[
        Patch(facecolor=tuple(channel / 255.0 for channel in color), edgecolor="black", label=name)
        for name, _, color in layers
    ],
    loc="upper right",
)
plt.show()

<IPython.core.display.Javascript object>