This notebook predicts field-extent and field-boundary probabilities for the Planet RGB chunks that have validation (ground-truth) data available. Predictions are made with the [FracTAL ResUNet model](https://www.mdpi.com/2072-4292/13/11/2197) using [pre-trained model weights](https://arxiv.org/abs/2201.04771), and the results are exported as GeoTIFFs.

## Load packages and modules

In [1]:
import numpy as np
# import pandas as pd
import os
from tqdm import tqdm
import mxnet as mx
from mxnet import gluon
from mxnet import image
from glob import glob
# import imageio.v2 as imageio
from osgeo import gdal, osr
import sys
from multiprocessing import cpu_count

# Make the local FracTAL ResUNet model and loss modules importable by appending
# their folders (relative to the notebook's working directory) to the module
# search path. Each folder is added at most once, so re-running is harmless.
module_paths = [
    'decode/FracTAL_ResUNet/models/semanticsegmentation',
    'decode/FracTAL_ResUNet/nn/loss',
]
for path_entry in module_paths:
    if path_entry in sys.path:
        continue
    sys.path.append(path_entry)

In [2]:
# import functions from modules
from FracTAL_ResUNet import FracTAL_ResUNet_cmtsk
from datasets import *

## Define parameters

In [3]:
# Model architecture hyperparameters. These must describe the same architecture
# the pre-trained weights below were saved from, otherwise loading them fails.
n_filters = 32
depth = 6
n_classes = 1
# NOTE(review): codes_to_keep and boundary_kernel_size are not referenced in this notebook
codes_to_keep = [1]
boundary_kernel_size = (2,2)

# Inference settings
batch_size = 5
ctx_name = 'cpu'  # 'cpu' or 'gpu'
gpu_id = 0        # GPU index, only used when ctx_name == 'gpu'

# other parameters
country = 'Rwanda'  # file-name prefix used to select input chunks and to name outputs
CPU_COUNT = cpu_count()
# Spatial reference written into every exported GeoTIFF: EPSG:3857 (Web Mercator).
# NOTE(review): confirm this matches the CRS of the input RGB chunks.
srs = osr.SpatialReference()
srs.ImportFromEPSG(3857)

# folder of input RGB chunk geotiffs
input_folder='../0_Data_preparation/results/RGB_chunks'
# folder of validation chunks
groundtruth_folder='../0_Data_preparation/results/groundtruth'
# folder to store output model predictions
out_folder='results'

# Pre-trained model weights: keep exactly one of the lines below uncommented.
# trained_model='model_weights/Planet_france.params' # trained and fine-tuned on Planet data of France
trained_model='model_weights/Planet_pretrained-france_finetuned-india.params' # trained on Planet data of France and fine-tuned on India
# trained_model = 'model_weights/Airbus_pretrained-france_finetuned-india.params' # trained on SPOT (Airbus) data of France and fine-tuned on India

In [4]:
# Create the output folder (and any missing parents); an existing folder is kept as-is.
os.makedirs(out_folder, exist_ok=True)

## Identify RGB chunks with validation data available

In [5]:
# extract chunk ids of validation data
# Collect this country's ground-truth boundary rasters. glob() returns files in
# arbitrary (file-system) order, so sort to make the processing order reproducible.
gt_bound_names = sorted(glob(os.path.join(groundtruth_folder, country + '*crop_field_bound*.tif')))
print('Found {} groundtruth chunks'.format(len(gt_bound_names)))

Found 123 groundtruth chunks


In [6]:
# find Planet RGB chunks corresponding to validation chunks
# For every ground-truth chunk, collect all Planet RGB chunks sharing its chunk id.
image_names = []
for gt_bound_name in gt_bound_names:
    # The chunk id is the last two '_'-separated fields of the file stem.
    stem = os.path.splitext(os.path.basename(gt_bound_name))[0]
    chunk_id = '_'.join(stem.split('_')[-2:])
    # Anchor the id on a leading '_': with a bare '*', chunk '12_34' would also
    # match RGB files of chunk '112_34' (the wildcard would absorb the '1').
    pattern = os.path.join(input_folder, country + '*_' + chunk_id + '.tif')
    image_list = sorted(glob(pattern))
    if not image_list:
        print('no RGB found for chunk {}'.format(chunk_id))
        continue
    image_names.extend(image_list)
print('Found {} RGB images'.format(len(image_names)))

Found 738 RGB images


## Create dataset and dataloader

In [10]:
# Define dataset
test_dataset = Planet_Dataset_No_labels(image_names=image_names)

# Loads data from a dataset and create mini batches
# test_dataloader = gluon.data.DataLoader(test_dataset, batch_size=batch_size,num_workers=CPU_COUNT) # might encounter 'connection refused' issue
test_dataloader = gluon.data.DataLoader(test_dataset, batch_size=batch_size,num_workers=1)

## Load pre-trained model weights and run inference in batches

In [11]:
# Set MXNet ctx
if ctx_name == 'cpu':
    ctx = mx.cpu()
elif ctx_name == 'gpu':
    ctx = mx.gpu(gpu_id)

# initialise model
model = FracTAL_ResUNet_cmtsk(nfilters_init=n_filters, depth=depth, NClasses=n_classes)

# load pre-trained model parameters
model.load_parameters(trained_model, ctx=ctx)

depth:= 0, nfilters: 32, nheads::8, widths::1
depth:= 1, nfilters: 64, nheads::16, widths::1
depth:= 2, nfilters: 128, nheads::32, widths::1
depth:= 3, nfilters: 256, nheads::64, widths::1
depth:= 4, nfilters: 512, nheads::128, widths::1
depth:= 5, nfilters: 1024, nheads::256, widths::1
depth:= 6, nfilters: 512, nheads::256, widths::1
depth:= 7, nfilters: 256, nheads::128, widths::1
depth:= 8, nfilters: 128, nheads::64, widths::1
depth:= 9, nfilters: 64, nheads::32, widths::1
depth:= 10, nfilters: 32, nheads::16, widths::1


In [12]:
%%time
# run model
for batch_i, img_data in enumerate(tqdm(test_dataloader)):
    
    # extract batch data
    imgs,id_dates,geotrans=img_data
    rows, cols= imgs.shape[2],imgs.shape[3]

    # make a copy if the variable currently lives in the wrong context
    imgs = imgs.as_in_context(ctx)

    # predicted outputs: field extent probability, field boundary probability and distance to boundary
    logits, bound, dist = model(imgs)

    # export predictions for all images in the batch
    bt_size=id_dates.asnumpy().shape[0]
    for i in range(bt_size):
        id_date=id_dates[i,:].asnumpy().astype(int)
        str_id_date=[str(id_date[0])] # year
        str_id_date.append(str(id_date[1]).zfill(2)) # month
        str_id_date.extend([str(s).zfill(3) for s in id_date[2:]]) # zfill rows and cols so that output files also have uniform file name length
        gt=geotrans[i,:].asnumpy()
        
        outname_extent=os.path.join(out_folder,country+'_extent_prob_'+'_'.join(str_id_date)+'.tif')
        prj=srs.ExportToWkt()
        export_geotiff(outname_extent,logits[1,:,:].asnumpy().squeeze(),gt,prj,gdal.GDT_Float32)

        outname_bound=os.path.join(out_folder,country+'_bound_prob_'+'_'.join(str_id_date)+'.tif')
        export_geotiff(outname_bound,bound[1,:,:].asnumpy().squeeze(),gt,prj,gdal.GDT_Float32)
    
#         outname_dist=os.path.join(out_folder,country+'_distance'+'_'.join(str_id_date)+'.tif')
#         export_geotiff(outname_dist,dist[1,:,:].asnumpy().squeeze(),gt,prj,gdal.GDT_Float32)


100%|██████████| 148/148 [16:26<00:00,  6.67s/it]

CPU times: user 1h 36min 57s, sys: 8min 55s, total: 1h 45min 52s
Wall time: 16min 26s



