# Filter pixel classification results by segments


In [None]:
import os
import sys
import gdal
import shutil
import xarray as xr
import geopandas as gpd
import subprocess as sp
from datacube.utils.cog import write_cog
from rsgislib.segmentation import segutils
from scipy.ndimage.measurements import _stats

sys.path.append('../../Scripts')
from deafrica_classificationtools import HiddenPrints

%load_ext autoreload
%autoreload 2

# Analysis Parameters

In [None]:
# Minimum segment size, in number of pixels (passed to the segmentation as `minPxls`)
min_seg_size = 100

# Model tag embedded in the names of the prediction GeoTIFFs that are read and written
model_type = 'gm_mads_two_seasons_20201123'

# Root folder under which inputs, predictions and temporary files live
results = 'results/classifications/'

# Vector file holding the testing tiles (read with geopandas)
test_shapefile = 'data/eastern_testing_sites_2.geojson'

### Open testing tile shapefile

In [None]:
# Read the testing tiles from the file named by `test_shapefile`; the
# segmentation loop iterates over the `GRID_ID` column of this table.
gdf = gpd.read_file(test_shapefile)

## Image segmentation

1. Generate image segments using the RSGISLib Shepherd segmentation (`segutils.runShepherdSegmentation`)
2. Find the majority (mode) predicted value within each segment
3. Write the object-based classification to disk

In [None]:
%%time
for g_id in gdf['GRID_ID'][0:1].values:
    print('working on grid: ' + g_id)
    
    #store temp files somewhere
    directory=results+'tmp_'+g_id
    if not os.path.exists(directory):
        os.mkdir(directory)
    
    tmp='tmp_'+g_id+'/'
    
    nc = results+'input/Eastern_tile_'+g_id+'_inputs.nc'
    ds = xr.open_dataset(nc)
    ds=ds[['NDVI_S1', 'NDVI_S2']]
    write_cog(ds.to_array(), results+'Eastern_tile_'+g_id+'_NDVI.tif',overwrite=True)
    
    #inputs to image seg
    tiff_to_segment = results+'Eastern_tile_'+g_id+'_NDVI.tif'
    kea_file = results+'Eastern_tile_'+g_id+'_NDVI.kea'
    segmented_kea_file = results+'Eastern_tile_'+g_id+'_segmented.kea'

    #convert tiff to kea
    gdal.Translate(destName=kea_file,
                   srcDS=tiff_to_segment,
                   format='KEA',
                   outputSRS='EPSG:6933')
    
    #run image seg
    print('   image segmentation...')
    with HiddenPrints():
        segutils.runShepherdSegmentation(inputImg=kea_file,
                                             outputClumps=segmented_kea_file,
                                             tmpath=results+tmp,
                                             numClusters=60,
                                             minPxls=min_seg_size)
    
    #open segments, and predictions
    segments=xr.open_rasterio(segmented_kea_file).squeeze().values
    t = results+ 'predicted/Eastern_tile_'+g_id+'_prediction_pixel_'+model_type+'.tif'
    pred = xr.open_rasterio(t).squeeze().drop_vars('band')
    
    #calculate mode
    count, _sum =_stats(pred, labels=segments, index=segments)
    mode = _sum > (count/2)
    mode = xr.DataArray(mode,  coords=pred.coords, dims=pred.dims, attrs=pred.attrs).astype(np.int16)
    
    #write to disk
    write_cog(mode, results+ 'predicted/Eastern_tile_'+g_id+'_prediction_object_'+model_type+'.tif', overwrite=True)
    
    #remove the tmp folder
    shutil.rmtree(results+tmp)
    os.remove(kea_file)
    os.remove(segmented_kea_file)
    os.remove(results+'Eastern_tile_'+g_id+'_NDVI.tif')

In [None]:
# xr.open_rasterio(results+ 'predicted/Eastern_tile_'+g_id+'_prediction_object_'+model_type+'.tif').plot(size=12);

***
## RSGISLib Shepherd segmentation: tiled examples

### single cpu, tiled

In [None]:
# from rsgislib.rastergis import populateRATWithMode
# from rsgislib.rastergis import ratutils

# populateRATWithMode(valsimage=results+'Eastern_tile_'+g_id+'_prediction_pixel_'+model_type+'.tif',
#                    clumps=results+'Eastern_tile_'+g_id+'_segmented.kea',
#                    outcolsname ='mode')

# ratutils.populateImageStats(
#               inputImage=results+'Eastern_tile_'+g_id+'_prediction_pixel_'+model_type+'.tif',
#               clumpsFile=results+'Eastern_tile_'+g_id+'_segmented.kea',
#               calcSum=True
# )

In [None]:
%time
# #run the segmentation
with HiddenPrints():
    tiledsegsingle.performTiledSegmentation(kea_file,
                                    segmented_kea_file,
                                    tmpDIR=temp,
                                    numClusters=60,
                                    validDataThreshold=validDataTileFraction, 
                                    tileWidth=width,
                                    tileHeight=height,
                                    minPxls=9)

In [None]:
# Write the per-segment (zonal) mean of the input image to a GeoTIFF.
# NOTE(review): `meanImage`, `rsgislib` and `segments_zonal_mean` are not
# defined in the visible part of this notebook -- TODO confirm where they
# come from before running.
meanImage(
    tiff_to_segment,
    segmented_kea_file,
    segments_zonal_mean,
    "GTIFF",
    rsgislib.TYPE_32FLOAT,
)

### n cpus, tiled

In [None]:
# Multi-CPU, tiled Shepherd segmentation followed by the zonal-mean export.
# (cell timing is switched off here: the time magic is left commented out)
# NOTE(review): `tiledSegParallel`, `temp`, `validDataTileFraction`, `width`,
# `height`, `ncpus`, `meanImage`, `rsgislib` and `segments_zonal_mean` are not
# defined in the visible part of this notebook -- TODO confirm before running.
with HiddenPrints():
    tiledSegParallel.performTiledSegmentation(
        kea_file,
        segmented_kea_file,
        tmpDIR=temp,
        numClusters=60,
        validDataThreshold=validDataTileFraction,
        tileWidth=width,
        tileHeight=height,
        minPxls=9,
        ncpus=ncpus,
    )

# Attach the zonal mean of the input image to the segments; output is a GeoTIFF
meanImage(
    tiff_to_segment,
    segmented_kea_file,
    segments_zonal_mean,
    "GTIFF",
    rsgislib.TYPE_32FLOAT,
)