In [2]:
from tensorflow.keras import losses
import tensorflow as tf
import os
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import cv2
import rasterio as rs
import fiona
from osgeo import osr
from fiona.crs import from_epsg
from shapely.geometry import Point,mapping
from skimage import filters
from skimage.feature import blob_log
from math import sqrt
import earthpy.plot as ep

from utility.postprocess import Binary, Find_threshold_otsu, Gaussian_filter

In [None]:
# Restore the pre-trained U-Net weights/architecture from disk without its saved
# training configuration, then compile it with a low-learning-rate Adam
# optimizer, binary cross-entropy loss and binary classification metrics.
model = load_model(filepath='./models/UNet.h5', compile=False)
model.compile(
    optimizer=Adam(learning_rate=1e-5),
    loss=losses.binary_crossentropy,
    metrics=['binary_accuracy', 'Precision', 'Recall'],
)
print("Model is loaded..")

In [11]:
# Load the test tile and keep its georeferencing (CRS + affine transform) so
# detections can later be mapped back to map coordinates.
# The dataset is opened in a `with` block so the file handle is released, and
# `image` is only ever bound to the model-ready array.
with rs.open('./images/plot_12.tif') as src:
    crs = src.crs
    transform = src.transform
    # rasterio returns bands 1-4 in band order: (bands, rows, cols)
    image_rgba = src.read([1, 2, 3, 4])

# Model input: channels-last with a leading batch axis -> (1, rows, cols, 4).
image = np.expand_dims(np.moveaxis(image_rgba, 0, -1), axis=0)

In [None]:
# Run the U-Net on the single-image batch and drop the batch axis.
pred = np.squeeze(model.predict(image), axis=0)

# Post-process the prediction into a binary mask using the project helpers:
# Gaussian smoothing, an Otsu threshold, then binarisation of the smoothed map.
blur = Gaussian_filter(pred, sigma=1)
# NOTE(review): the threshold is derived from the raw `pred` but applied to the
# smoothed `blur` — confirm this is intended.
t = Find_threshold_otsu(pred)
binary = Binary(blur, t)

# Laplacian-of-Gaussian blob detection on the mask. For a 2-D input each row is
# (row, col, sigma); the sigma column is rescaled by sqrt(2) to approximate
# the blob radius in pixels.
blobs_log = blob_log(
    binary,
    min_sigma=1,
    max_sigma=3,
    num_sigma=10,
    threshold=0.25,
    overlap=0.25,
    exclude_border=True,
)
blobs_log[:, 2] = blobs_log[:, 2] * sqrt(2)

In [None]:
# Visualise the detections: source image | binary prediction | prediction with
# one circle per detected blob. While iterating the blobs, also convert their
# pixel positions to coordinates in the raster's CRS for the GeoJSON export.
lon = []
lat = []
val = []

# One row of three panels; a wide, short figure avoids large empty margins.
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(24, 8))

# image_rgba is already in rasterio band order (bands, rows, cols), which is
# what earthpy's plot_rgb expects — transposing it would scramble the axes.
band_indices = [3, 2, 1]
ep.plot_rgb(
    image_rgba,
    rgb=band_indices,
    title="Satellite Image",
    stretch=True,
    ax=ax[0],
)

mask = binary.squeeze()
ax[1].imshow(mask, cmap='binary')
ax[1].set_title("Prediction")

ax[2].imshow(mask, cmap='binary')
# Loop through each blob: draw it, then record its map position and radius.
for row, col, radius in blobs_log:
    ax[2].add_patch(plt.Circle((col, row), radius + 1, color='red', linewidth=2, fill=False))
    # rasterio's xy() returns (x, y) in the raster CRS. NOTE: the lists keep
    # swapped names (`lat` <- x, `lon` <- y) because the GeoJSON export stacks
    # [lat, lon] to rebuild (x, y) points — change both cells together.
    map_x, map_y = rs.transform.xy(transform=transform, rows=row, cols=col)
    lat.append(map_x)
    lon.append(map_y)
    val.append(radius)
ax[2].set_title("No. of trees: " + str(len(blobs_log)))
plt.show()

In [125]:
# Save the binary prediction as a single-band uint16 GeoTIFF on the same grid
# (transform) and CRS as the source image.
os.makedirs("results", exist_ok=True)

# image_rgba is (bands, rows, cols): raster width is the column count and
# height is the row count (using the row count for both only works for square tiles).
n_rows, n_cols = image_rgba.shape[1:]

# reshape drops singleton axes and fails loudly if the mask does not match the
# source grid; cast explicitly to the band dtype declared below.
binary_band = np.reshape(binary, (n_rows, n_cols)).astype(rs.uint16)

with rs.open("results/res_2.tif", "w", driver='GTiff', count=1, dtype=rs.uint16,
             width=n_cols, height=n_rows, transform=transform, crs=crs) as raschip:
    raschip.write(binary_band, 1)

In [126]:
# Export the detected blobs as GeoJSON points: one feature per blob with a
# 1-based `id` and the blob size from blob_log (sigma * sqrt(2), pixels) as `Value`.
# Stack the coordinates into an (n_blobs, 3) array of [x, y, size].
# NOTE: the plotting cell fills `lat` with the x coordinate and `lon` with the y
# coordinate, so stacking [lat, lon] yields (x, y) points.
point_df = np.column_stack([np.array(lat), np.array(lon), np.array(val)])
schema = {
    'geometry': 'Point',
    'properties': {'id': 'int', 'Value': 'float:15.2'},
}

# Coordinates were produced with the source raster's affine transform, so they
# are in the raster's CRS: label the output with that CRS rather than assuming
# EPSG:3857. Fall back to EPSG:3857 only if the raster carries no CRS.
if crs is not None:
    crs_kwargs = {'crs_wkt': crs.to_wkt()}
else:
    srs = osr.SpatialReference()
    srs.SetFromUserInput("EPSG:3857")
    crs_kwargs = {'crs': srs.ExportToProj4()}

os.makedirs("results", exist_ok=True)
with fiona.open("results/plt_2.geojson", 'w', driver="GeoJSON", schema=schema, **crs_kwargs) as sink:
    for feature_id, (px, py, size) in enumerate(point_df, start=1):
        sink.write({
            'geometry': mapping(Point(px, py)),
            'properties': {'id': feature_id, 'Value': float(size)},
        })