In [None]:
import numpy as np
from sklearn.cluster import DBSCAN
from tifffile import imshow, imsave
from tqdm import tqdm
import matplotlib.pyplot as plt
from time import time

In [None]:
def parse_image(path, all_channels_last=False):
    """Load the array stored under the key 'arr_0' of an ``.npz`` archive.

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npz`` file.
    all_channels_last : bool, optional
        When True, axis 0 of the loaded array is moved to position 2 and all
        axes from position 2 onwards are flattened into one, giving an array
        of shape ``(shape[1], shape[2], -1)`` in terms of the stored array.

    Returns
    -------
    numpy.ndarray
        The stored array, optionally rearranged as described above.
    """
    # The archive object keeps a file handle open; the context manager closes
    # it as soon as the array has been read into memory.
    with np.load(path) as archive:
        img = archive['arr_0']
    if all_channels_last:
        img = np.moveaxis(img, 0, 2)
        img = img.reshape(img.shape[0], img.shape[1], -1)
    return img

In [None]:
def check_intersection(img):
    """Print how many array elements are zero in every image of the stack.

    Parameters
    ----------
    img : array-like
        Stack of equally shaped images, indexed by image along axis 0.

    The count is printed, nothing is returned.
    """
    stack = np.asarray(img)
    # An element counts only if it equals zero in each image along axis 0.
    # The mask shape follows the input instead of a fixed (955, 955, 3).
    zero_everywhere = np.all(stack == 0, axis=0)
    print(np.count_nonzero(zero_everywhere))

In [None]:
# Archive with the image stack to be restored.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
path = 'D:/Docs/Visillect/agrofields/adc/restoration/data/red_nir/input/L8_0.npz'

# `img` is indexed further down as img[image, row, column, channel].
img = parse_image(path)
# Disabled: would append a trailing axis of length 1 to the loaded array.
# img = img.reshape(*img.shape, 1)

In [None]:
from catboost import Pool, CatBoostRegressor
from skimage.morphology import dilation, disk, square

In [None]:
def create_train_data(target_image, target_mask, train_mask, j, window_size, training_images=None):
    """Collect regression samples for one vertical strip of the image being restored.

    The strip covers columns ``j`` to ``j + window_size`` (exclusive). Every
    sample is one pixel of one channel; its features are the values of that
    pixel in each of the ``training_images``. Feature values equal to 0 are
    replaced by NaN.

    Parameters
    ----------
    target_image : numpy.ndarray
        Image being restored, indexed as ``[row, column, channel]``.
    target_mask : numpy.ndarray of bool
        ``[row, column]`` mask of the pixels whose values must be predicted.
    train_mask : numpy.ndarray of bool
        ``[row, column]`` mask of the pixels used as training samples.
    j : int
        First column of the strip.
    window_size : int
        Number of columns in the strip.
    training_images : numpy.ndarray, optional
        Stack of predictor images indexed as ``[image, row, column, channel]``.
        Must have a floating-point dtype so that NaN can be stored. When
        omitted, the module-level variable ``training_images`` is used, which
        keeps existing five-argument calls working.

    Returns
    -------
    tuple
        ``(train_data, train_label, restore_data, target_mask_window,
        target_image_window)`` where

        * ``train_data`` has shape ``(n_channels * n_train_pixels, n_images)``,
        * ``train_label`` has shape ``(n_channels * n_train_pixels,)``,
        * ``restore_data`` is a list with one
          ``(n_target_pixels, n_images)`` array per channel,
        * ``target_mask_window`` is ``target_mask`` restricted to the strip,
        * ``target_image_window`` is the strip of ``target_image`` (a view).

        If ``target_mask`` has no True pixel inside the strip, the tuple
        ``(0, 0, 0, 0, 0)`` is returned instead.
    """
    columns = slice(j, j + window_size)
    target_mask_window = target_mask[:, columns]
    # Nothing to restore in this strip: hand back the integer placeholders
    # that callers test for. Checked once instead of once per channel.
    if np.count_nonzero(target_mask_window) == 0:
        return 0, 0, 0, 0, 0

    if training_images is None:
        # Fall back to the notebook-level variable of the same name.
        try:
            training_images = globals()["training_images"]
        except KeyError:
            raise NameError("name 'training_images' is not defined") from None

    train_mask_window = train_mask[:, columns]

    train_data_expanded = []
    train_label_expanded = []
    restore_data_expanded = []
    for channel in range(target_image.shape[-1]):
        target_image_window = target_image[:, columns, channel]
        training_images_window = training_images[:, :, columns, channel]

        # Boolean indexing over the two spatial axes yields (n_images, n_pixels)
        # as a fresh copy; transpose so that each row is one pixel.
        train_data = training_images_window[:, train_mask_window].T
        train_data[train_data == 0] = np.nan
        restore_data = training_images_window[:, target_mask_window].T
        restore_data[restore_data == 0] = np.nan

        train_data_expanded.append(train_data)
        train_label_expanded.append(target_image_window[train_mask_window])
        restore_data_expanded.append(restore_data)

    # Pool the samples of all channels into one training set.
    train_data_expanded = np.array(train_data_expanded).reshape(-1, len(training_images))
    train_label_expanded = np.array(train_label_expanded).reshape(-1)

    return train_data_expanded, train_label_expanded, restore_data_expanded,\
           target_mask_window, target_image[:, columns]

Presets for `window_size` (set in the next cell):
quality mode = 19, performance mode = 96.

In [None]:
# Width, in pixel columns, of the vertical strip that gets its own regression model.
window_size = 19
# Largest strip start that still leaves a full-width strip. Strips are cut along
# axis 1 of each image, i.e. axis 2 of `img`, so that axis must be used here.
# Clamped at 0 so a strip wider than the image cannot yield a negative start.
threshold = max(img.shape[2] - window_size, 0)

# Restore the images one by one, last to first. `img` is updated in place, so an
# image restored earlier in this loop is used in its restored form as a
# predictor for the images that follow.
for target_image_number in range(len(img) - 1, -1, -1):

    target_image = img[target_image_number]
    # All other images act as predictors. `create_train_data` reads this
    # module-level name, so it has to be kept.
    training_images = np.concatenate([img[:target_image_number], img[target_image_number+1:]], axis=0)

    # A pixel has to be restored when at least one of its channels is exactly 0.
    target_mask_volumetric = target_image == 0
    target_mask = target_mask_volumetric.any(axis=-1)
    if not target_mask.any():
        # No missing pixels: there is nothing to cluster or predict, and DBSCAN
        # rejects an empty sample set.
        continue

    # Missing pixels plus a 10-pixel neighbourhood around them.
    train_mask = dilation(target_mask, disk(10))
    restore_indexes = np.where(train_mask)
    indexes = np.moveaxis(np.array(restore_indexes), 0, 1)
    # eps=1.5 links pixels that touch horizontally, vertically or diagonally
    # (distance <= sqrt(2)), splitting the dilated mask into separate regions.
    labels = DBSCAN(eps=1.5).fit_predict(indexes)
    cluster_labels = np.unique(labels)

    train_masks = []
    target_masks = []
    # NOTE(review): a noise label (-1), if DBSCAN emits one, is handled like
    # any other region here, exactly as before.
    for label in cluster_labels:
        members = np.where(labels == label)
        region_mask = np.zeros(shape=train_mask.shape, dtype=bool)
        region_mask[restore_indexes[0][members], restore_indexes[1][members]] = True
        # Pixels of the region that must be predicted ...
        target_masks.append(region_mask & target_mask)
        # ... and the remaining pixels of the region, used for training.
        train_masks.append(region_mask ^ target_masks[-1])
    print(len(cluster_labels))

    for j in tqdm(range(0, img.shape[2], window_size)):
        # Shift the last strip to the left so that it keeps the full width.
        j = min(j, threshold)
        for train_mask_current, target_mask_current in zip(train_masks, target_masks):
            train_data, train_label, restore_data, target_mask_window, target_image_window = create_train_data(
                target_image, target_mask_current, train_mask_current, j, window_size)
            # Integer placeholders are returned when this region has nothing
            # to restore inside the strip.
            if isinstance(train_data, int):
                continue

            train_pool = Pool(train_data, train_label)
            model = CatBoostRegressor(
                depth=6,
                learning_rate=0.2,
                loss_function='RMSE',
                verbose=0, num_trees=40
            )
            model.fit(train_pool)

            # One model per strip and region, applied to every channel in turn.
            for channel, data in enumerate(restore_data):
                restore_pool = Pool(data)
                res = model.predict(restore_pool)
                target_image_window[:, :, channel][target_mask_window] = res
            img[target_image_number, :, j: j + window_size] = target_image_window

    # Keep the restored image inside the [0, 1] value range.
    img[target_image_number] = np.clip(img[target_image_number], 0, 1)


In [None]:
# Archive loaded into `img2`, which the cells below compare against `img`.
# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.
# NOTE(review): presumably the gap-free ground truth for the input stack - confirm.
path = 'D:/Docs/Visillect/agrofields/adc/restoration/data/red_nir/output/L8_0.npz'

img2 = parse_image(path)
# Disabled: would append a trailing axis of length 1 to the loaded array.
# img2 = img2.reshape(*img2.shape, 1)

In [None]:
# Display every image of `img` immediately followed by its counterpart in `img2`,
# both drawn on the same fixed 0..1 intensity scale.
for i, restored in enumerate(img):
    imshow(restored, vmin=0, vmax=1)
    imshow(img2[i], vmin=0, vmax=1)

In [None]:
from sklearn.metrics import mean_absolute_error, mean_squared_error

In [None]:
# Mean absolute error between `img2` and `img`, printed once per image and
# computed over all pixels and channels of that image.
for target_image_number in range(len(img)):
    reference_values = img2[target_image_number].ravel()
    restored_values = img[target_image_number].ravel()
    print(mean_absolute_error(reference_values, restored_values))

In [None]:
def NDVI(img):
    """Return the normalised difference of channel 1 against channel 0.

    ``img`` is indexed as a 4-D array whose last axis holds the channels. The
    result drops that axis and equals ``(c1 - c0) / (c0 + c1 + 1e-6)``; the
    small constant keeps the denominator non-zero where both channels are 0.

    NOTE(review): channel 0 / channel 1 are presumably red / near-infrared -
    confirm against the data layout.
    """
    channel_0 = img[:, :, :, 0]
    channel_1 = img[:, :, :, 1]
    numerator = channel_1 - channel_0
    denominator = channel_0 + channel_1 + 1E-6
    return numerator / denominator

In [None]:
# Index computed from `img2`, the stack that `img` is compared against.
VI_good = NDVI(img2)
# Index computed from `img`, the stack processed by the restoration loop.
VI_restored = NDVI(img)


In [None]:
# For every image: print the mean absolute error between the two index maps,
# display both maps, and save the restored one as res<N>.tif in the working
# directory. The count follows the data instead of a fixed 7.
for target_image_number in range(len(VI_restored)):
    print(mean_absolute_error(VI_good[target_image_number], VI_restored[target_image_number]))
    imshow(VI_restored[target_image_number])
    imshow(VI_good[target_image_number])
    imsave('res{}.tif'.format(target_image_number), VI_restored[target_image_number])