In [None]:
from shapely.geometry import Polygon, Point
import itertools
import geopandas as gpd
import pandas as pd
import rasterio as rio
from rasterio.features import rasterize
from rasterio import mask
from rasterio.plot import show
from rasterio.enums import Resampling
import matplotlib.pyplot as plt
from skimage.measure import label, regionprops
import torch.nn as nn
import numpy as np
import shapely.wkt as wkt
import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
from tqdm import tqdm
from skimage.measure import label, regionprops
import os
import dask
import time
import gc
import re
from rioxarray.exceptions import NoDataInBounds

In [None]:
import distributed
# Ask glibc on every dask worker to hand freed memory back to the OS eagerly (threshold 0).
# glibc reads MALLOC_TRIM_THRESHOLD_ only at process start-up, so the value must be in the
# worker's environment before it spawns: newer dask.distributed applies
# `distributed.nanny.pre-spawn-environ` before spawning and `distributed.nanny.environ` after.
# Both keys are set; the `environ` one is also read back by a later cell.
# NOTE(review): confirm against the installed distributed version.
dask.config.set({
    "distributed.nanny.environ.MALLOC_TRIM_THRESHOLD_": 0,
    "distributed.nanny.pre-spawn-environ.MALLOC_TRIM_THRESHOLD_": 0,
})
# Default scheduler for dask collections; a distributed Client created later registers
# itself as the default scheduler instead.
dask.config.set(scheduler='processes')

In [None]:
import ctypes

def trim_memory() -> int:
    """Call glibc's ``malloc_trim(0)`` in the current process and return its result.

    Intended to be run on every dask worker via ``client.run(trim_memory)`` so freed
    heap pages are released back to the OS. Per malloc_trim(3) the result is 1 when
    memory was released and 0 otherwise. Requires glibc (``libc.so.6``).
    """
    glibc = ctypes.CDLL("libc.so.6")
    released = glibc.malloc_trim(0)
    return released

# Mirror the dask-configured trim threshold into this process's environment so that
# worker processes started from here inherit it.
trim_threshold = str(dask.config.get("distributed.nanny.environ.MALLOC_TRIM_THRESHOLD_"))
os.environ["MALLOC_TRIM_THRESHOLD_"] = trim_threshold
print(trim_threshold)

In [None]:
import xarray as xr
import rioxarray as riox
from xrspatial import convolution, focal, hillshade
from skimage.transform import resize
from dask.distributed import LocalCluster, Client

In [None]:
# Local process-based dask cluster: 8 worker processes x 2 threads each.
cluster = LocalCluster(processes=True, n_workers=8, threads_per_worker=2)
client = Client(cluster)
# Start the Active Memory Manager on the scheduler.
client.amm.start()
display(client)

In [None]:
# Min-max normalisation of model inputs to the 0-1 range.
def normalize_fn(image, image_suffix, stats_dict):
    """Scale `image` to 0-1 using fixed or per-image min/max bounds.

    Parameters
    ----------
    image : numpy.ndarray
        Array to scale (may contain NaN for masked no-data).
    image_suffix : str
        Input-layer name (e.g. 'rgb', 'tpi', 'ndvi') looked up in `stats_dict`.
    stats_dict : dict
        Maps layer names to ``{'min': ..., 'max': ...}``. When `image_suffix` is a key, those
        fixed bounds are used; otherwise the bounds are taken from `image` itself.

    Returns
    -------
    numpy.ndarray
        ``(image - min) / (max - min)``. When max == min, an all-zero array (at least
        float32) is returned instead of the inf/NaN a zero division would produce.
    """
    if image_suffix in stats_dict:
        min_tmp = stats_dict[image_suffix]['min']
        max_tmp = stats_dict[image_suffix]['max']
    else:
        # Per-image bounds. Ignore NaN (masked no-data) so a few gaps don't make the
        # whole tile NaN.
        min_tmp = np.nanmin(image)
        max_tmp = np.nanmax(image)
    value_range = max_tmp - min_tmp
    if value_range == 0:
        # Constant input: avoid dividing by zero.
        image = np.asarray(image)
        return np.zeros(image.shape, dtype=np.result_type(image.dtype, np.float32))
    return (image - min_tmp) / value_range

def calc_tpi(dtm, inner_r, outer_r, values=True):
    """Topographic Position Index of an elevation raster.

    TPI = `dtm` minus xrspatial's ``focal.apply`` of `dtm` over an annulus kernel between
    `inner_r` and `outer_r` (radii in a form annulus_kernel accepts, e.g. "0.25m"). The focal
    statistic is focal.apply's default (presumably the mean).

    Returns the result as a numpy array when `values` is True, otherwise as the
    (possibly lazy) DataArray.
    """
    cell_x, cell_y = convolution.calc_cellsize(dtm)
    ring = convolution.annulus_kernel(cell_x, cell_y, outer_r, inner_r)
    neighbourhood = focal.apply(dtm, ring)
    tpi = dtm - neighbourhood
    return tpi.values if values else tpi

def calc_ndvi(ms, values=True):
    """Normalized difference of bands 4 and 3: (b4 - b3) / (b4 + b3).

    `ms` must support ``.sel(band=...)`` with band labels 4 and 3 (presumably NIR and red).
    Both bands are cast to float before the arithmetic. Returns ``.values`` of the result
    when `values` is True, otherwise the result object itself.
    """
    band4 = ms.sel(band=4).astype('float')
    band3 = ms.sel(band=3).astype('float')
    ndvi = (band4 - band3) / (band4 + band3)
    return ndvi.values if values else ndvi

In [None]:
from collections import namedtuple
from operator import mul

# `reduce` is no longer a builtin in Python 3; it lives in functools.
from functools import reduce

# Stack entry used by max_rectangle_size: the bar index where an open rectangle starts,
# and that rectangle's height.
Info = namedtuple('Info', ['start', 'height'])

def max_size(mat, value=0):
    """Find height, width of the largest rectangle containing all `value`'s.

    Rows of `mat` are scanned top to bottom while maintaining, per column, the number of
    consecutive rows (ending at the current row) equal to `value`. Each such histogram is
    solved with max_rectangle_size ("Largest Rectangle in a Histogram" [1]), and the
    largest-area answer wins (ties keep the earlier one).

    Returns a (height, width) tuple; (0, 0) for an empty `mat`.
    [1]: http://blog.csdn.net/arbuckle/archive/2006/05/06/710988.aspx
    """
    rows = iter(mat)
    heights = [cell == value for cell in next(rows, [])]
    best = max_rectangle_size(heights)
    for row in rows:
        heights = [h + 1 if cell == value else 0 for h, cell in zip(heights, row)]
        best = max(best, max_rectangle_size(heights), key=area)
    return best

def max_rectangle_size(histogram):
    """Find height, width of the largest rectangle that fits entirely under
    the histogram.

    Bars are scanned left to right while keeping a stack of ``Info(start, height)``
    entries with strictly increasing heights. When a shorter bar arrives, every taller
    entry is closed (its rectangle ends just before the current bar), and the shorter bar
    inherits the left-most start it closed. Entries still open at the end extend to the
    right edge. On equal area the rectangle found first is kept.

    >>> f = max_rectangle_size
    >>> f([5,3,1])
    (3, 2)
    >>> f([1,3,5])
    (3, 2)
    >>> f([3,1,5])
    (5, 1)
    >>> f([4,8,3,2,0])
    (3, 3)
    >>> f([4,8,3,1,1,0])
    (3, 3)
    >>> f([1,2,1])
    (1, 3)
    Algorithm is "Linear search using a stack of incomplete subproblems" [1].
    [1]: http://blog.csdn.net/arbuckle/archive/2006/05/06/710988.aspx
    """
    pending = []    # open Info(start, height) entries, heights strictly increasing
    best = (0, 0)   # (height, width) of the largest rectangle seen so far
    idx = 0
    for idx, bar in enumerate(histogram):
        begin = idx
        # close every open rectangle taller than the current bar
        while pending and bar < pending[-1].height:
            closed = pending.pop()
            best = max(best, (closed.height, idx - closed.start), key=area)
            begin = closed.start
        # equal heights simply continue the existing entry
        if not pending or bar > pending[-1].height:
            pending.append(Info(begin, bar))

    end = idx + 1
    for begin, bar in pending:
        best = max(best, (bar, end - begin), key=area)

    return best

def area(size):
    """Product of the entries of `size` (height * width for a (height, width) pair)."""
    return reduce(lambda acc, dim: acc * dim, size)

In [None]:
# Output folder for predicted burrow points; created if missing (no check-then-create race).
outDIR = './cnn_pred_results/'
os.makedirs(outDIR, exist_ok=True)

In [None]:
# --- Segmentation model description ---
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['burrow']
ACTIVATION = 'sigmoid'  # None for raw logits, 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'         # switch to 'cpu' when no GPU is available

# --- Which trained model to apply, and how ---
model_fnl = 'deeplabplus'            # architecture tag used in the checkpoint folder name
res_fnl = 5                          # working resolution in cm
inputs_fnl = ['rgb', 'tpi', 'ndvi']  # input layers, stacked in this order
preprocess = True                    # min-max normalise inputs before prediction
prob_thresh = 0.5                    # probability cut-off for the binary burrow mask

In [None]:
# Pastures (keys of img_f_dict) to process; set past_subset = None to process all of them.
#past_subset = None
past_subset = ['22W', '22E', 'CN']

# Input rasters per pasture, split into groups. Each group pairs an RGB ortho and its DSM
# with a multispectral (MS) ortho, and each group is predicted and written to its own output file.
# NOTE(review): several groups reuse one MS mosaic with different RGB flights
# (e.g. '22E' group_2/group_3, 'CN' group_1/group_2); presumably this is the MS flight
# overlapping that RGB flight -- confirm these pairings. '22W' and '22E' share files and
# differ only by the pasture boundary used to clip them.
img_f_dict = {
    '5W': {
        'group_1': {
            'rgb': '/mnt/d/202109/outputs/202109_5W_RGB/CPER_202109_5W_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_5W_MS/CPER_202109_5W_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_5W_RGB/CPER_202109_5W_RGB_dsm.tif'
        }
    },
    '29-30': {
        'group_1': {
            'rgb': '/mnt/d/202109/outputs/202109_29_30_RGB/CPER_202109_29_30_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_29_30_MS/CPER_202109_29_30_North_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_29_30_RGB/CPER_202109_29_30_RGB_DSM.tif'
        },
        'group_2': {
            'rgb': '/mnt/d/202109/outputs/202109_29_30_RGB/CPER_202109_29_30_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_29_30_MS/CPER_202109_29_30_South_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_29_30_RGB/CPER_202109_29_30_RGB_DSM.tif'
        }
    },
    '22W': {
        'group_1': {
            'rgb': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight1_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight1_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight1_RGB_DSM.tif'
        },
        'group_2': {
            'rgb': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight2_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight2_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight2_RGB_DSM.tif'
        }
    },
    '22E': {
        'group_1': {
            'rgb': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight1_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight1_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight1_RGB_DSM.tif'
        },
        'group_2': {
            'rgb': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight2_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight2_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight2_RGB_DSM.tif'
        },
        'group_3': {
            'rgb': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight3_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight2_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_22EW/CPER_202109_22EW_Flight3_RGB_DSM.tif'
        }
    },
    'CN': {
        'group_1': {
            'rgb': '/mnt/d/202109/outputs/202109_CN_RGB/Orthos/CPER_CN_Flight2_202109_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_CN_MS/CPER_202109_CN_Flight2_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_CN_RGB/DSMs/CPER_CN_Flight2_202109_RGB_DSM.tif'
        },
        'group_2': {
            'rgb': '/mnt/d/202109/outputs/202109_CN_RGB/Orthos/CPER_CN_Flight3_202109_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_CN_MS/CPER_202109_CN_Flight2_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_CN_RGB/DSMs/CPER_CN_Flight3_202109_RGB_DSM.tif'
        },
        'group_3': {
            'rgb': '/mnt/d/202109/outputs/202109_CN_RGB/Orthos/CPER_CN_Flight4_202109_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_CN_MS/CPER_202109_CN_Flight3_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_CN_RGB/DSMs/CPER_CN_Flight4_202109_RGB_DSM.tif'
        },
        'group_4': {
            'rgb': '/mnt/d/202109/outputs/202109_CN_RGB/Orthos/CPER_CN_Flight5_202109_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_CN_MS/CPER_202109_CN_Flight3_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_CN_RGB/DSMs/CPER_CN_Flight5_202109_RGB_DSM.tif'
        },
        'group_5': {
            'rgb': '/mnt/d/202109/outputs/202109_CN_RGB/Orthos/CPER_CN_Flight5_202109_RGB_ortho.tif',
            'ms': '/mnt/d/202109/outputs/202109_CN_MS/CPER_202109_CN_Flight4_MS_ortho.tif',
            'dsm': '/mnt/d/202109/outputs/202109_CN_RGB/DSMs/CPER_CN_Flight5_202109_RGB_DSM.tif'
        }
    }
}

# Keep only the pastures listed in past_subset, preserving img_f_dict's order.
if past_subset is not None:
    img_f_dict = {name: groups
                  for name, groups in img_f_dict.items()
                  if name in past_subset}

# Pasture boundary polygons; the prediction loop matches its 'Past_Name_' column
# against the keys of img_f_dict.
cper_f = '/mnt/c/Users/TBGPEA-Sean/Desktop/Pdogs_UAS/cper_pdog_pastures_2017_clip.shp'

In [None]:
# --- Tiling geometry ---
full_tile_size = 100   # side of each prediction tile, in map units (metres in EPSG:32613)
full_buff_size = 25    # not referenced by the prediction loop in this notebook
tile_size = 256        # side of each CNN input chunk, in pixels
buff_size = 64         # context padding around each CNN chunk, in pixels
chunk_size = 300       # dask chunk size used when reading rasters, in pixels

In [None]:
# Load the best checkpoint saved during training. The 2 cm run was saved without a
# resolution suffix in its file name; other resolutions end in '_<res>cm'.
ckpt_suffix = '' if res_fnl == 2 else '_' + str(res_fnl) + 'cm'
ckpt_path = ('./cnn_results_' + model_fnl + '_' + str(res_fnl) + 'cm/best_model_'
             + '_'.join(inputs_fnl) + ckpt_suffix + '.pth')

# torch.load unpickles a full model object: only load checkpoints you trust.
# map_location='cpu' lets a GPU-saved checkpoint load on a host without CUDA.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects full-model
# pickles; pass weights_only=False there if loading fails.
best_model = torch.load(ckpt_path, map_location='cpu' if DEVICE == 'cpu' else None)
if DEVICE == 'cpu':
    best_model = best_model.cpu()
best_model.eval()

# Per-layer min/max computed from the training data; used by normalize_fn.
# NOTE(review): the 2 cm stats file is used regardless of res_fnl -- confirm this is intended.
df_image_stats = pd.read_csv('./_utils/image_stats_2cm.csv').set_index('stat')

# convert the stats dataframe to {layer: {'min': ..., 'max': ...}}
image_stats = {col: {'min': df_image_stats.loc['min', col],
                     'max': df_image_stats.loc['max', col]}
               for col in df_image_stats.columns}

In [None]:
# Preview map: all pasture outlines. The prediction loop draws each processed pasture and
# its analysis bbox onto these axes and refreshes the display handle `hfig`.
cper_gdf = gpd.read_file(cper_f)
fig, ax = plt.subplots()
outline_ax = cper_gdf.plot(ax=ax, color='none', edgecolor='black')
hfig = display(outline_ax, display_id=True, clear=True)

# Tunables for the prediction loop (values unchanged from the former inline literals).
N_WORKERS = 8                # dask workers that must be alive; fewer triggers a rebuild
THREADS_PER_WORKER = 2
RESTART_TIMEOUT_S = 9        # give up on client.restart() after this many seconds
RESTART_SETTLE_S = 10        # pause after a successful restart
COARSE_CELL_M = 10.0         # cell size (m) used to search for a null-free rectangle
NULL_FRAC_MAX = 0.01         # trim a layer's extent only when > 1% of its pixels are null
MIN_BURROW_AREA_M2 = 0.05    # predicted blobs at or below this area are dropped
OUT_CRS = 'EPSG:32613'


def _start_cluster():
    """Create a LocalCluster + Client pair with the Active Memory Manager started."""
    new_cluster = LocalCluster(n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER, processes=True)
    new_client = Client(new_cluster)
    new_client.amm.start()
    return new_cluster, new_client


def _rebuild_cluster(old_client):
    """Shut down `old_client` and its cluster; return a brand-new (cluster, client)."""
    old_client.shutdown()
    old_client.close()
    return _start_cluster()


def _restart_cluster(old_client):
    """Restart the workers to free memory, or rebuild the cluster if the restart times out.

    Returns the (cluster, client) pair to use from now on.
    """
    try:
        old_client.restart(timeout=RESTART_TIMEOUT_S)
        time.sleep(RESTART_SETTLE_S)
        return old_client.cluster, old_client
    except TimeoutError:
        return _rebuild_cluster(old_client)


def _largest_valid_window(valid, res_m):
    """Find the largest valid rectangle on a ~10 m grid; return (y_slice, x_slice) or None.

    `valid` is a 2-D boolean DataArray with `y`/`x` coords (y descending, as everywhere
    in this loop) and `res_m` is its pixel size in metres. A coarse cell counts as valid when
    ANY of its pixels is valid (coarsen().max()). Among all windows of the largest size,
    the bottom-most, then left-most, is chosen. Returns None if no coarse cell is valid.
    """
    factor = int(COARSE_CELL_M / res_m)
    coarse_res = factor * res_m
    coarse = valid.astype('int').coarsen(x=factor, y=factor, boundary='trim').max().compute()
    n_rows, n_cols = max_size(coarse.values, value=1)
    if n_rows == 0 or n_cols == 0:
        return None
    grid = coarse.values == 1
    ys = coarse.y.values
    xs = coarse.x.values
    # Index-based scan: every placement of an n_rows x n_cols window that is entirely valid.
    candidates = []
    for i in range(grid.shape[0] - n_rows + 1):
        for j in range(grid.shape[1] - n_cols + 1):
            if grid[i:i + n_rows, j:j + n_cols].all():
                candidates.append((min(ys[i], ys[i + n_rows - 1]),
                                   min(xs[j], xs[j + n_cols - 1])))
    y_low, x_left = min(candidates)  # lowest y first, then lowest x
    # coarsen() puts coordinates at block centres: pad by half a coarse cell to reach the
    # true window edges instead of shifting the extent up/right by half a cell.
    half = coarse_res / 2.0
    return (slice(y_low + (n_rows - 1) * coarse_res + half, y_low - half),
            slice(x_left - half, x_left + (n_cols - 1) * coarse_res + half))


def _axis_windows(length):
    """Yield (core_start, core_stop, in_start, in_stop) for each CNN chunk along one axis.

    [core_start, core_stop) is the `tile_size` span written into the mask;
    [in_start, in_stop) is the span fed to the model, padded with up to `buff_size`
    pixels of context on each side where the image allows.
    """
    for i in range(int(np.ceil(length / tile_size))):
        if (i + 1) * tile_size > length:
            # last partial chunk: shift back so it ends flush with the edge
            core_start, core_stop, in_stop = max(0, length - tile_size), length, length
        else:
            core_start, core_stop = i * tile_size, (i + 1) * tile_size
            # pad the far side only when a full buffer fits inside the image
            if core_stop + buff_size > length:
                in_stop = core_stop
            else:
                in_stop = core_stop + buff_size
        yield core_start, core_stop, max(0, core_start - buff_size), in_stop


def _load_tile_images(rgb_da, ms_da, dsm_da, x_sl, y_sl, shape):
    """Clip each layer in `inputs_fnl` to one full tile and bilinearly reproject it to `shape`.

    Returns {layer: numpy array}; 'rgb' is (band, row, col), the other layers (row, col).
    """
    images = {}
    if 'rgb' in inputs_fnl:
        images['rgb'] = rgb_da.sel(band=slice(1, 3), x=x_sl, y=y_sl).rio.reproject(
            rgb_da.rio.crs, shape=shape, resampling=Resampling.bilinear).values
    if 'dsm' in inputs_fnl or 'tpi' in inputs_fnl:
        # reproject the DSM once and reuse it for both the DSM and TPI layers
        dsm_tile = dsm_da.sel(x=x_sl, y=y_sl).squeeze().rio.reproject(
            dsm_da.rio.crs, shape=shape, resampling=Resampling.bilinear)
        if 'dsm' in inputs_fnl:
            images['dsm'] = dsm_tile.values
        if 'tpi' in inputs_fnl:
            # annulus from 0.25 m to 0.75 m around each cell (5-15 cells at 5 cm)
            images['tpi'] = calc_tpi(dsm_tile.chunk({'x': chunk_size, 'y': chunk_size}),
                                     inner_r="0.25m", outer_r="0.75m", values=True)
    if 'ndvi' in inputs_fnl:
        images['ndvi'] = calc_ndvi(ms_da.sel(x=x_sl, y=y_sl).rio.reproject(
            ms_da.rio.crs, shape=shape, resampling=Resampling.bilinear), values=True)
    return images


def _predict_mask(image_dict, model):
    """Predict a binary burrow mask for one full tile, chunk by chunk with context padding."""
    ref = image_dict['rgb'] if 'rgb' in image_dict else image_dict[inputs_fnl[0]]
    tshape = ref.shape[-2:]
    mask = np.zeros(tshape)  # zeros (not empty): nothing uninitialised can reach label()
    row_windows = list(_axis_windows(tshape[0]))
    col_windows = list(_axis_windows(tshape[1]))
    for r_min, r_max, r_in0, r_in1 in row_windows:
        for c_min, c_max, c_in0, c_in1 in col_windows:
            layers = []
            for key in inputs_fnl:
                layer = image_dict[key][..., r_in0:r_in1, c_in0:c_in1].astype('float32')
                if layer.ndim == 2:
                    layer = np.expand_dims(layer, 0)
                if preprocess:
                    layer = normalize_fn(layer, key, image_stats)
                # float64 stats may upcast the normalised layer; the model expects float32
                layers.append(np.asarray(layer, dtype='float32'))
            x_tensor = torch.from_numpy(np.concatenate(layers, axis=0)).to(DEVICE).unsqueeze(0)
            pred = model.predict(x_tensor).cpu().detach().numpy().squeeze() >= prob_thresh
            # Crop the chunk core out of the padded prediction using offsets taken from the
            # actual slice bounds, so the crop always matches the core shape.
            top = r_min - r_in0
            left = c_min - c_in0
            mask[r_min:r_max, c_min:c_max] = pred[top:top + (r_max - r_min),
                                                  left:left + (c_max - c_min)]
    return mask


def _tile_records(regions, x_left, y_top, tile_row, tile_col, first_index):
    """Convert labelled regions of one tile into one-row GeoDataFrames of burrow points.

    Regions larger than MIN_BURROW_AREA_M2 become points at their centroid, with pixel
    row/col mapped to map coordinates from the tile's top-left corner. If none qualify, a
    single geometry-less placeholder row is returned so the tile is still recorded as done.
    Rows are indexed consecutively from `first_index`.
    """
    px_m = res_fnl * 0.01
    frames = []
    for region in regions:
        if region.area * (res_fnl / 100) ** 2 > MIN_BURROW_AREA_M2:
            cen_row, cen_col = region.centroid
            frames.append(gpd.GeoDataFrame(
                data=pd.DataFrame({'area': region.area}, index=[first_index + len(frames)]),
                geometry=[Point([x_left + cen_col * px_m, y_top - cen_row * px_m])],
                crs=OUT_CRS))
    if not frames:
        frames.append(gpd.GeoDataFrame(data=pd.DataFrame({'area': ''}, index=[first_index])))
    for frame in frames:
        frame['tile_row'] = tile_row
        frame['tile_col'] = tile_col
        frame['tile_size'] = full_tile_size
    return frames


def _completed_tiles(records):
    """Return the set of '<row>_<col>' keys for tiles already present in `records`.

    Columns are cast to int first: a row-wise apply over an all-numeric CSV upcasts to
    float and yields keys like '3.0_4.0' that never match.
    """
    if len(records) == 0:
        return set()
    return set(records['tile_row'].astype(int).astype(str) + '_'
               + records['tile_col'].astype(int).astype(str))


model = best_model.module if isinstance(best_model, nn.DataParallel) else best_model

for pasture in tqdm(img_f_dict):
    print('\n\n----------\nPasture: ' + pasture)
    for group in tqdm(img_f_dict[pasture]):
        print('---\nGroup: ' + group)

        # input rasters for this pasture-group
        rgb_f = img_f_dict[pasture][group]['rgb']
        ms_f = img_f_dict[pasture][group]['ms']
        dsm_f = img_f_dict[pasture][group]['dsm']

        # Pasture bbox buffered by the CNN context width (buff_size pixels at res_fnl cm,
        # rounded up to whole metres), truncated to integers.
        # Assumes one polygon per pasture name.
        past_gdf = cper_gdf[cper_gdf['Past_Name_'] == pasture]
        past_bounds = past_gdf.buffer(np.ceil(buff_size * res_fnl * 0.01)).bounds.iloc[0]
        past_bbox = {k: float(int(v)) for k, v in past_bounds.items()}

        # open image data and mask the no-data sentinels
        rgb_xr = riox.open_rasterio(rgb_f, chunks={'y': chunk_size,
                                                   'x': chunk_size,
                                                   'band': 1}).sel(band=slice(0, 3))
        rgb_xr = rgb_xr.where(rgb_xr != 255)
        ms_xr = riox.open_rasterio(ms_f, chunks={'y': chunk_size,
                                                 'x': chunk_size,
                                                 'band': 1}).sel(band=[4, 3])
        ms_xr = ms_xr.where(ms_xr != 65535)
        dsm_xr = riox.open_rasterio(dsm_f, chunks={'y': chunk_size,
                                                   'x': chunk_size}).squeeze().drop_vars('band')
        dsm_xr.name = 'DSM'
        dsm_xr = dsm_xr.where(dsm_xr > 0)

        # subset to the pasture bbox (y is descending, so slice top -> bottom)
        y_past = slice(past_bbox['maxy'], past_bbox['miny'])
        x_past = slice(past_bbox['minx'], past_bbox['maxx'])
        rgb_xr = rgb_xr.sel(y=y_past, x=x_past)
        ms_xr = ms_xr.sel(y=y_past, x=x_past)
        dsm_xr = dsm_xr.sel(y=y_past, x=x_past)

        # Too much null MS data: shrink the MS extent to its largest valid ~10 m rectangle.
        ms_null_frac = float(ms_xr.isel(band=0).isnull().mean().compute())
        if ms_null_frac > NULL_FRAC_MAX:
            window = _largest_valid_window(ms_xr.isel(band=1).notnull(), ms_xr.rio.resolution()[0])
            if window is None:
                print('No valid multispectral data within the pasture. Skipping group.')
                continue
            # rechunk for the NDVI calculation later
            ms_xr = ms_xr.sel(y=window[0], x=window[1]).chunk({'y': 500, 'x': 500, 'band': -1})

        # Same for the RGB DSM; the RGB ortho is clipped to the same window.
        dsm_null_frac = float(dsm_xr.isnull().mean().compute())
        if dsm_null_frac > NULL_FRAC_MAX:
            window = _largest_valid_window(dsm_xr.notnull(), dsm_xr.rio.resolution()[0])
            if window is None:
                print('No valid DSM data within the pasture. Skipping group.')
                continue
            rgb_xr = rgb_xr.sel(y=window[0], x=window[1])
            dsm_xr = dsm_xr.sel(y=window[0], x=window[1])

        client.run(gc.collect)
        client.run(trim_memory)

        # shrink the bbox to the extent covered by all three (possibly trimmed) layers
        past_bbox['minx'] = max(rgb_xr.x.min().item(), ms_xr.x.min().item(), dsm_xr.x.min().item(), past_bbox['minx'])
        past_bbox['miny'] = max(rgb_xr.y.min().item(), ms_xr.y.min().item(), dsm_xr.y.min().item(), past_bbox['miny'])
        past_bbox['maxx'] = min(rgb_xr.x.max().item(), ms_xr.x.max().item(), dsm_xr.x.max().item(), past_bbox['maxx'])
        past_bbox['maxy'] = min(rgb_xr.y.max().item(), ms_xr.y.max().item(), dsm_xr.y.max().item(), past_bbox['maxy'])

        total_bounds = {'xmin': past_bbox['minx'],
                        'xmax': past_bbox['maxx'],
                        'ymin': past_bbox['miny'],
                        'ymax': past_bbox['maxy']}

        n_row_tiles = int(np.ceil((total_bounds['ymax'] - total_bounds['ymin']) / full_tile_size))
        n_col_tiles = int(np.ceil((total_bounds['xmax'] - total_bounds['xmin']) / full_tile_size))

        # subset image data to the updated bbox
        y_past = slice(past_bbox['maxy'], past_bbox['miny'])
        x_past = slice(past_bbox['minx'], past_bbox['maxx'])
        rgb_xr = rgb_xr.sel(y=y_past, x=x_past)
        ms_xr = ms_xr.sel(y=y_past, x=x_past)
        dsm_xr = dsm_xr.sel(y=y_past, x=x_past)

        # plot the current pasture and the analysis bbox in the preview figure
        past_gdf.plot(ax=ax)
        gpd.GeoSeries(Polygon([(past_bbox['minx'], past_bbox['miny']),
                               (past_bbox['minx'], past_bbox['maxy']),
                               (past_bbox['maxx'], past_bbox['maxy']),
                               (past_bbox['maxx'], past_bbox['miny'])])).plot(ax=ax, edgecolor='red', color='none')
        fig.canvas.draw()
        hfig.update(fig)

        # Resume from earlier output. A CSV is written while no point geometry exists yet;
        # a shapefile is written afterwards.
        outSHP = os.path.join(outDIR, 'burrow_pts_pred_' + '_'.join([pasture, group]) + '_' + '_'.join(inputs_fnl + [str(res_fnl)]) + 'cm.shp')
        outCSV = os.path.splitext(outSHP)[0] + '.csv'
        if os.path.exists(outSHP):
            gdf_out = gpd.read_file(outSHP)
        elif os.path.exists(outCSV):
            gdf_out = pd.read_csv(outCSV)
        else:
            gdf_out = gpd.GeoDataFrame()
        r_ct_pred = len(gdf_out)
        rc_completed = _completed_tiles(gdf_out)

        cluster, client = _restart_cluster(client)

        for full_r in range(n_row_tiles):
            print('running row: ' + str(full_r) + ' of ' + str(n_row_tiles))
            for full_c in tqdm(range(n_col_tiles)):
                if len(client.cluster.workers) < N_WORKERS:
                    cluster, client = _rebuild_cluster(client)
                if '_'.join([str(full_r), str(full_c)]) in rc_completed:
                    continue
                try:
                    x_left = full_c * full_tile_size + total_bounds['xmin']
                    y_bottom = full_r * full_tile_size + total_bounds['ymin']
                    x_right = x_left + full_tile_size
                    y_top = y_bottom + full_tile_size
                    tile_shape = (int(round((y_top - y_bottom) / (res_fnl * 0.01), 0)),
                                  int(round((x_right - x_left) / (res_fnl * 0.01), 0)))

                    image_dict = _load_tile_images(rgb_xr, ms_xr, dsm_xr,
                                                   slice(x_left, x_right), slice(y_top, y_bottom),
                                                   tile_shape)
                    pr_mask = _predict_mask(image_dict, model)
                    pr_regions = regionprops(label(pr_mask))

                    tile_frames = _tile_records(pr_regions, x_left, y_top, full_r, full_c, r_ct_pred)
                    r_ct_pred += len(tile_frames)
                    gdf_out = pd.concat([gdf_out] + tile_frames)
                    # concat with a geometry-less frame first yields a plain DataFrame
                    if type(gdf_out) is pd.DataFrame and 'geometry' in gdf_out.columns:
                        gdf_out = gpd.GeoDataFrame(gdf_out, geometry='geometry', crs=OUT_CRS)

                    if type(gdf_out) is pd.DataFrame:
                        gdf_out.to_csv(outCSV, index=False)
                    else:
                        gdf_out.to_file(outSHP)

                    del pr_mask, pr_regions, image_dict, tile_frames
                    client.run(gc.collect)
                    client.run(trim_memory)
                    if full_c % 2 == 0:
                        cluster, client = _restart_cluster(client)
                except NoDataInBounds:
                    # tile has no raster data; skip it
                    continue
            # restart between rows unless this row's last tile was skipped as already done
            if n_col_tiles > 0 and '_'.join([str(full_r), str(n_col_tiles - 1)]) not in rc_completed:
                cluster, client = _restart_cluster(client)
        print('Pasture-group finished!')
        