# In [None]:
from PIL import Image, ImageEnhance
import requests
from tempfile import NamedTemporaryFile
from sklearn.cluster import MiniBatchKMeans
from scipy import ndimage
import numpy
import pathlib
import matplotlib.pyplot as plt
import PIL
import rioxarray as xr
import rasterio
from tqdm import tqdm 



# Input raster that is opened below and later classified.
PATH_RASTER = '../../data/1.tiff'
# Destination file for the classified raster written at the end of the script.
PATH_OUTPUT = '../../data/1_outu.tiff'
# Number of clusters requested from MiniBatchKMeans.
n_k = 6


# Open the raster with rioxarray (imported as `xr`); `ds` is the object that
# every later section reads from and writes back into.
ds = xr.open_rasterio(PATH_RASTER)


# # Alter image
# Try to make the classes as different as possible. Adjusting color works best in many cases.

# SPECIFY FILTERS (1 = UNCHANGED)
constrast = 1.5
color = 1.5
sharpness = 1.5
brightness = 1.0


# Alter the input as soon as ANY factor differs from the neutral value 1.0.
# Requiring all four factors to differ would silently skip the whole step
# whenever a single one (e.g. brightness) is left at its default.
if any(factor != 1.0 for factor in (constrast, color, sharpness, brightness)):
    # load data: move axis 0 (presumably the band axis -- TODO confirm) to the
    # end so PIL receives a channels-last array
    im = ds.data
    im = numpy.moveaxis(im, 0, 2)
    im = PIL.Image.fromarray(im).convert('RGB')

    # apply modifications; the order (contrast, color, sharpness, brightness)
    # is kept fixed because each enhancer works on the previous result
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, constrast),
        (ImageEnhance.Color, color),
        (ImageEnhance.Sharpness, sharpness),
        (ImageEnhance.Brightness, brightness),
    ):
        enhancer = enhancer_cls(im)
        im = enhancer.enhance(factor)

    # save preview for debugging
    im.save('temp.png')

    new = numpy.array(im)

    # write the enhanced R, G, B channels back into the first three bands,
    # leaving any other bands of the raster untouched
    for band_index in range(3):
        ds[dict(band=band_index)] = new[:, :, band_index]





# Choose the training sample: small rasters are used as they are, larger ones
# are first reduced by averaging 20x20 windows (incomplete edge windows are
# trimmed) so that fitting stays cheap.
if ds.size >= 15000000:
    im = ds.coarsen(x=20, y=20, boundary='trim').mean().compute().to_numpy()
else:
    im = ds.to_numpy()

# Move axis 0 to the end and flatten the two remaining leading axes, giving
# one row per pixel and one column per entry of the former axis 0.
im = numpy.moveaxis(im, 0, 2)
dim_a, dim_b, n_features = im.shape
image_2d = im.reshape(dim_a * dim_b, n_features)


# COMPUTE CLUSTERS
kmeans_cluster = MiniBatchKMeans(
    n_clusters=n_k,
    random_state=0,
    batch_size=150,
    verbose=1,
    compute_labels=False,
)
kmeans_cluster.fit(image_2d)


# Find the horizontal / vertical coordinate names used by this raster.
# Candidates are listed in order of preference: when several are present the
# later entries of the original search lists ('x' / 'y') win.
coord_names = list(ds.coords)
xc = next((name for name in ('x', 'longitude', 'lon') if name in coord_names), None)
yc = next((name for name in ('y', 'latitude', 'lat') if name in coord_names), None)
if xc is None or yc is None:
    raise ValueError(
        f"could not find x/y coordinates among {coord_names!r}"
    )

# Normalise the names to 'x' / 'y'; identity renames are skipped.
renames = {old: new for old, new in ((xc, 'x'), (yc, 'y')) if old != new}
if renames:
    ds = ds.rename(renames)

TILE_RES = 512

shape_x = ds['x'].shape[0]
shape_y = ds['y'].shape[0]


# Classify the raster tile by tile. Stepping through the tile origins with
# range(0, size, TILE_RES) never produces an empty trailing tile, which would
# otherwise make `predict` fail when a dimension is an exact multiple of
# TILE_RES.
for i0 in tqdm(range(0, shape_x, TILE_RES)):
    for j0 in range(0, shape_y, TILE_RES):
        tile = dict(x=slice(i0, i0 + TILE_RES), y=slice(j0, j0 + TILE_RES))

        # EXTRACT PART OF TIF
        A = ds.isel(**tile).to_numpy()

        # move axis 0 last and flatten to one row per pixel for prediction
        A = numpy.moveaxis(A, 0, -1)
        A_reshaped = A.reshape(A.shape[0] * A.shape[1], A.shape[2])
        labels = kmeans_cluster.predict(A_reshaped)
        labels = labels.reshape(A.shape[0], A.shape[1])

        # store the cluster labels in band index 0 of the DataArray
        ds[dict(band=0, **tile)] = labels


ds.rio.to_raster(PATH_OUTPUT)





