This notebook predicts field-extent and field-boundary probabilities, and performs instance segmentation, for all Planet RGB chunks. It uses the FracTAL ResUNet model with pre-trained weights and fine-tuned hyperparameters.

## Load packages and modules

In [1]:
import numpy as np
# import pandas as pd
import os
from tqdm import tqdm
import mxnet as mx
from mxnet import gluon
from mxnet import image
from glob import glob
# import imageio.v2 as imageio
from osgeo import gdal, osr
import sys
from multiprocessing import cpu_count

# Make the shared project modules and the FracTAL ResUNet model / loss code importable.
# The paths are relative to the notebook's working directory; each is appended at most once.
module_paths = [
    '../1_Identify_months_thresholds_model_evaluation',
    '../1_Identify_months_thresholds_model_evaluation/decode/FracTAL_ResUNet/models/semanticsegmentation',
    '../1_Identify_months_thresholds_model_evaluation/decode/FracTAL_ResUNet/nn/loss',
]
for extra_path in module_paths:
    if extra_path in sys.path:
        continue
    sys.path.append(extra_path)

In [2]:
# import functions from modules
from FracTAL_ResUNet import FracTAL_ResUNet_cmtsk
from datasets import *

## Define parameters

In [3]:
# ---- model architecture / inference settings ----
# initial filter count, depth and number of classes passed to FracTAL_ResUNet_cmtsk
n_filters, depth, n_classes = 32, 6, 1
batch_size = 20
codes_to_keep = [1]
boundary_kernel_size = (2, 2)
# device used for inference: 'cpu', or 'gpu' together with the GPU index gpu_id
ctx_name, gpu_id = 'cpu', 0

# ---- run settings ----
country = 'Rwanda'
CPU_COUNT = cpu_count()

# WKT of EPSG:3857 (WGS 84 / Pseudo-Mercator); passed to export_geotiff when writing predictions
srs = osr.SpatialReference()
srs.ImportFromEPSG(3857)
prj = srs.ExportToWkt()

# months whose predictions are averaged ("best months"; another selection noted: '04', '08', '12')
str_months = ['04', '10', '12']
str_year = '2021'

# folder of input RGB chunk GeoTIFFs, and folder where model predictions are written
input_folder = '../0_Data_preparation/results/RGB_chunks'
out_folder = 'results'

# pre-trained weights: trained on Planet data of France and fine-tuned on India.
# Other weight files available in the same folder:
#   Planet_france.params                            - trained and fine-tuned on Planet data of France
#   Airbus_pretrained-france_finetuned-india.params - SPOT (Airbus) data
trained_model = '../1_Identify_months_thresholds_model_evaluation/model_weights/Planet_pretrained-france_finetuned-india.params'

In [4]:
# Create the output folder. exist_ok=True makes this a single idempotent call (no-op if the
# folder already exists) instead of a separate check followed by a create.
os.makedirs(out_folder, exist_ok=True)

## Collect chunk IDs and per-month input file names (used to build the dataset and dataloader)

In [5]:
# Extract the unique chunk ids from every chunk GeoTIFF of the selected country and year.
# A chunk id is the last two '_'-separated fields of the file-name stem
# (presumably the chunk's row/column indices — TODO confirm against the data-prep step).
file_names = glob(os.path.join(input_folder, country + '*' + str_year + '*.tif'))
# Sorted list rather than a bare set: iteration order of a set of strings depends on the
# per-process hash seed, so sorting keeps the processing order identical across kernel restarts.
chunk_ids = sorted({
    '_'.join(os.path.splitext(os.path.basename(file_name))[0].split('_')[-2:])
    for file_name in file_names
})
print('Found {} unique chunks'.format(len(chunk_ids)))

Found 18113 unique chunks


In [6]:
# For every selected month, build the list of expected input file paths, one per chunk id:
# <input_folder>/<country>_planet_medres_visual_<year>_<month>_<chunk_id>.tif
fn_months = {month: [] for month in str_months}
for month in str_months:
    fn_months[month].extend(
        os.path.join(
            input_folder,
            '_'.join([country, 'planet_medres_visual', str_year, month, cid]) + '.tif',
        )
        for cid in chunk_ids
    )

## Load pre-trained model and run batched inference

In [7]:
# Select the MXNet device context. An unknown ctx_name fails here with a clear message
# instead of leaving `ctx` undefined and failing later with a NameError.
if ctx_name == 'cpu':
    ctx = mx.cpu()
elif ctx_name == 'gpu':
    ctx = mx.gpu(gpu_id)
else:
    raise ValueError("ctx_name must be 'cpu' or 'gpu', got {!r}".format(ctx_name))

# initialise the multitask FracTAL ResUNet (its forward pass returns extent, boundary and
# distance predictions, as unpacked in the inference cell)
model = FracTAL_ResUNet_cmtsk(nfilters_init=n_filters, depth=depth, NClasses=n_classes)

# load pre-trained model parameters onto the selected context
model.load_parameters(trained_model, ctx=ctx)

depth:= 0, nfilters: 32, nheads::8, widths::1
depth:= 1, nfilters: 64, nheads::16, widths::1
depth:= 2, nfilters: 128, nheads::32, widths::1
depth:= 3, nfilters: 256, nheads::64, widths::1
depth:= 4, nfilters: 512, nheads::128, widths::1
depth:= 5, nfilters: 1024, nheads::256, widths::1
depth:= 6, nfilters: 512, nheads::256, widths::1
depth:= 7, nfilters: 256, nheads::128, widths::1
depth:= 8, nfilters: 128, nheads::64, widths::1
depth:= 9, nfilters: 64, nheads::32, widths::1
depth:= 10, nfilters: 32, nheads::16, widths::1


In [None]:
%%time
for month_i in range(len(str_months)):
    
    # create dataset and dataloader
    test_dataset = Planet_Dataset_No_labels(image_names=fn_months[str_months[month_i]])
#     test_dataloader = gluon.data.DataLoader(test_dataset, batch_size=batch_size,num_workers=CPU_COUNT)
    test_dataloader = gluon.data.DataLoader(test_dataset, batch_size=batch_size,num_workers=1)
    
    # run model
    for batch_i, img_data in enumerate(tqdm(test_dataloader)):
        # extract batch data
        imgs,id_dates,geotrans=img_data

        # make a copy if the variable currently lives in the wrong context
        imgs = imgs.as_in_context(ctx)

        # predicted outputs: field extent probability, field boundary probability and distance to boundary
        logits, bound, dist = model(imgs)
        
        # average and export predictions for each batch
        bt_size=id_dates.asnumpy().shape[0]
        for i in range(bt_size):
            # extract predictions
            extent_average=logits[i,:,:].asnumpy().squeeze()
            bound_average=bound[i,:,:].asnumpy().squeeze()
            # extract date and id information
            id_date=id_dates[i,:].asnumpy().astype(int)
            chunk_id='_'.join([str(s).zfill(3) for s in id_date[2:]]) # zfill rows and cols so that output files also have uniform file name length
            # extract spatial information
            gt=geotrans[i,:].asnumpy()
            
            # output file names
            outname_extent='_'.join([country,'average_extent_prob',str_year,'_'.join(str_months),chunk_id])+'.tif'
            outname_extent=os.path.join(out_folder,outname_extent)
            outname_bound='_'.join([country,'average_bound_prob',str_year,'_'.join(str_months),chunk_id])+'.tif'
            outname_bound=os.path.join(out_folder,outname_bound)
            
            # update averaged predictions
            if month_i>0:
                temp_extent=imageio.imread(outname_extent)
                extent_average+=temp_extent
                temp_bound=imageio.imread(outname_bound)
                bound_average+=temp_bound
            if month_i==len(str_months)-1:
                extent_average/=len(str_months)*1.0
                bound_average/=len(str_months)*1.0
            # export as geotiff
            export_geotiff(outname_extent,extent_average,gt,prj,gdal.GDT_Float32)
            export_geotiff(outname_bound,bound_average,gt,prj,gdal.GDT_Float32)

  2%|▏         | 15/906 [05:33<5:29:41, 22.20s/it]