In [None]:
import numpy as np
from sklearn.cluster import DBSCAN
from tifffile import imshow, imsave
from tqdm import tqdm
import matplotlib.pyplot as plt
from time import time

In [None]:
def parse_image(path, all_channels_last=False):
    """Load the ``arr_0`` array stored in an ``.npz`` archive.

    Args:
        path: path to the ``.npz`` file.
        all_channels_last: if True, move axis 0 to position 2 and merge it with all
            trailing axes, giving an array of shape (shape[1], shape[2], -1).

    Returns:
        The loaded ndarray (fully read into memory).
    """
    # np.load on an .npz returns an NpzFile that holds the archive open; the context
    # manager closes it as soon as the array has been read.
    with np.load(path) as archive:
        img = archive['arr_0']
    if all_channels_last:
        img = np.moveaxis(img, 0, 2)
        img = img.reshape(img.shape[0], img.shape[1], -1)
    return img

In [None]:
def check_intersection(img):
    """Count positions that are zero (missing) in every image of a stack.

    Args:
        img: array-like whose first axis indexes images; the remaining axes are
            compared element-wise across the images.

    Returns:
        The number of positions where all images equal 0. The count is also printed.
    """
    stack = np.asarray(img)
    # Reduce over the image axis so the mask shape follows the input instead of a
    # fixed size; plain ``bool`` semantics, no removed ``np.bool`` alias needed.
    always_missing = np.all(stack == 0, axis=0)
    count = np.count_nonzero(always_missing)
    print(count)
    return count

In [None]:
# Input series with gaps; later cells treat pixels equal to 0 as missing.
path = 'D:/Docs/Visillect/agrofields/adc/restoration/data/red_nir/input/L8_0.npz'
img_volume = parse_image(path, all_channels_last=False)

In [None]:
from catboost import Pool, CatBoostRegressor
from skimage.morphology import dilation, disk
from catboost import CatBoostError

In [None]:
def create_train_data(target_image, target_mask, train_mask, j, window_size, training_images=None):
    """Build regression training/prediction arrays for one horizontal window.

    The window covers columns ``j:j + window_size``. Each pixel selected by
    ``train_mask`` becomes a training sample: its features are that pixel's values in
    every image of ``training_images`` and its label is its value in ``target_image``.
    Each pixel selected by ``target_mask`` becomes a sample to predict. Feature values
    equal to 0 are replaced with NaN.

    Args:
        target_image: 2-D image being restored.
        target_mask: boolean mask (same shape) of pixels to restore.
        train_mask: boolean mask (same shape) of pixels to learn from.
        j: first column of the window.
        window_size: number of columns in the window.
        training_images: 3-D stack (images, rows, cols) providing the features. When
            omitted, the notebook-level global ``training_images`` is used (the
            original behaviour of this function).

    Returns:
        ``(train_data, train_label, restore_data, target_mask_window, target_image_window)``
        with ``train_data`` of shape (n_train, n_images), ``train_label`` of shape
        (n_train,), ``restore_data`` of shape (n_restore, n_images), and the window
        slices of ``target_mask`` / ``target_image``. ``target_image_window`` is a view,
        so writing into it modifies ``target_image``. If the window contains no pixel
        to restore, ``(0, 0, 0, 0, 0)`` is returned instead.
    """
    target_image_window = target_image[:, j: j + window_size]
    target_mask_window = target_mask[:, j: j + window_size]
    if np.count_nonzero(target_mask_window) == 0:
        return 0, 0, 0, 0, 0
    if training_images is None:
        # Backward compatibility: callers that do not pass the stack rely on the global.
        training_images = globals()['training_images']
    train_mask_window = train_mask[:, j: j + window_size]
    training_images_window = training_images[:, :, j: j + window_size]
    # Boolean-mask the pixel axes of all images at once -> (n_images, n_pixels), then
    # transpose to (samples, features). Advanced indexing copies, so the input is untouched.
    train_data = training_images_window[:, train_mask_window].T
    train_label = target_image_window[train_mask_window]
    restore_data = training_images_window[:, target_mask_window].T
    # Zeros mark missing values in the other images; CatBoost treats NaN as missing.
    train_data[train_data == 0] = np.nan
    restore_data[restore_data == 0] = np.nan
    return train_data, train_label, restore_data, target_mask_window, target_image_window


In [None]:
def create_order(img):
    """Return image indices ordered by their number of zero-valued (missing) pixels.

    Images with fewer zeros come first; ties keep their original relative order.
    """
    zero_counts = [np.count_nonzero(frame == 0) for frame in img]
    return sorted(range(img.shape[0]), key=lambda idx: zero_counts[idx])

def create_horizontal_iteration(img, mask, threshold=1, n=5, total_max_length=65, max_covered_length=20):
    """Split the columns of ``img`` into horizontal windows for piecewise model fitting.

    Columns are scanned left to right. For each column, the mean of ``img`` over the
    pixels selected by ``mask`` is compared with a running average. A boundary is added
    when that mean differs from the running average by at least ``threshold`` standard
    deviations of ``img[mask]``; the running average is then reset to the column mean.
    Otherwise the running average is updated as an exponential moving average with
    weight ``1/n`` on the new column. A boundary is also forced once a window reaches
    ``max_covered_length`` columns. If there are too many boundaries, they are thinned
    so that roughly ``total_max_length`` remain.

    Args:
        img: 2-D array indexed as (rows, columns).
        mask: boolean array of the same shape selecting the reference pixels.
        threshold: boundary sensitivity, in units of ``np.std(img[mask])``.
        n: smoothing factor of the running average.
        total_max_length: soft cap on the number of returned boundaries.
        max_covered_length: maximum window width, in columns.

    Returns:
        Non-decreasing list of column indices starting with 0 and ending with
        ``img.shape[-1]``; consecutive pairs delimit windows. Zero-width windows
        (repeated indices) can occur. If ``mask`` selects no pixel, ``[0, width]``.
    """
    if not np.any(mask):
        # No reference pixels: nothing to split on (and ``img[mask][0]`` below would
        # raise IndexError), so use the full width as a single window.
        return [0, img.shape[-1]]
    total_max_length -= 2  # presumably leaves room for the leading 0 and trailing width
    boundaries = [0]
    # NOTE(review): seeded with the first masked pixel in row-major order, not with
    # the mean of the first non-empty column.
    sliding_average = img[mask][0]
    abs_threshold = threshold * np.std(img[mask])
    for j in range(0, img.shape[-1]):
        column = img[:, j][mask[:, j]]
        if column.size == 0:
            continue
        column_mean = np.average(column)
        if np.abs(sliding_average - column_mean) >= abs_threshold:
            boundaries.append(j)
            sliding_average = column_mean
        else:
            sliding_average = (sliding_average * (n - 1) + column_mean) / n
        if j - boundaries[-1] == max_covered_length:
            boundaries.append(j)
    ratio = len(boundaries) / total_max_length
    if ratio > 1:
        # Keep roughly every ``ratio``-th boundary.
        thinned = [0]
        main_number, lagging_number = 0, 0
        for elem in boundaries:
            main_number += 1
            if main_number // ratio > lagging_number:
                lagging_number += 1
                thinned.append(elem)
        boundaries = thinned
    boundaries.append(img.shape[-1])
    return boundaries

In [None]:
# Indices of the images in the series to restore (currently only image 2).
range_object = range(2, 3)

In [None]:
# Restore every zero-valued (missing) pixel of the selected target images, one channel
# at a time. Missing pixels are grouped into spatial clusters. For each cluster, a
# CatBoost regressor is fitted per horizontal window on the valid pixels bordering the
# cluster (features = the same pixel in the other images of the series), then predicts
# the missing values. Results are written into ``img_volume`` through views.
# NOTE: ``training_images`` is kept as a notebook-level global because
# ``create_train_data`` can read it from there.


def _split_gap_clusters(target_mask, eps=1.5):
    """Split missing pixels into spatial clusters, each paired with its 1-pixel border.

    Returns ``(train_masks, target_masks)``: for every DBSCAN cluster of the dilated
    missing-pixel mask, the border pixels to learn from and the missing pixels to fill.
    Both lists are empty when ``target_mask`` contains no missing pixel.
    """
    dilated = dilation(target_mask, disk(1))
    if not dilated.any():
        # DBSCAN rejects an empty sample set; an image without gaps needs no work.
        return [], []
    rows, cols = np.where(dilated)
    labels = DBSCAN(eps=eps).fit_predict(np.column_stack((rows, cols)))
    train_masks, target_masks = [], []
    # DBSCAN's noise label (-1), if present, is handled as one more cluster.
    for label in set(labels):
        in_cluster = labels == label
        cluster_mask = np.zeros(dilated.shape, dtype=bool)
        cluster_mask[rows[in_cluster], cols[in_cluster]] = True
        cluster_targets = cluster_mask & target_mask
        target_masks.append(cluster_targets)
        train_masks.append(cluster_mask ^ cluster_targets)
    return train_masks, target_masks


threshold = 0.73   # window-split sensitivity, in std units (see create_horizontal_iteration)
max_length = 36    # maximum window width, in columns

for channel in range(0, 2):
    img = img_volume[:, :, :, channel]  # basic slice -> view, writes land in img_volume
    for target_image_number in range_object:
        target_image = img[target_image_number]  # view into img
        training_images = np.concatenate([img[:target_image_number], img[target_image_number + 1:]], axis=0)
        target_mask = target_image == 0
        train_masks, target_masks = _split_gap_clusters(target_mask)
        for train_mask_current, target_mask_current in tqdm(zip(train_masks, target_masks), total=len(target_masks)):
            model = CatBoostRegressor(
                depth=6,
                learning_rate=0.2,
                loss_function='RMSE',
                verbose=0, num_trees=40
            )
            horizontal_iteration = create_horizontal_iteration(target_image,
                                                               train_mask_current,
                                                               threshold=threshold,
                                                               n=5, total_max_length=57,
                                                               max_covered_length=max_length)
            print(len(horizontal_iteration))
            for j, window_end in zip(horizontal_iteration[:-1], horizontal_iteration[1:]):
                window_size = window_end - j
                (train_data, train_label, restore_data,
                 target_mask_window, target_image_window) = create_train_data(target_image,
                                                                              target_mask_current,
                                                                              train_mask_current,
                                                                              j, window_size)
                if isinstance(train_data, int):
                    # Sentinel from create_train_data: nothing to restore in this window.
                    continue
                if train_label.size == 0:
                    # No valid border pixels in this window: nothing to learn from and no
                    # constant to fall back to, so these pixels are left unfilled.
                    continue
                try:
                    model.fit(Pool(train_data, train_label))
                    res = model.predict(Pool(restore_data))
                except CatBoostError:
                    # NOTE(review): presumably raised for degenerate windows (e.g. constant
                    # labels); fall back to a constant prediction.
                    res = np.full((restore_data.shape[0],), train_label[0])
                target_image_window[target_mask_window] = res
                img[target_image_number, :, j: j + window_size] = target_image_window
        img[target_image_number] = np.clip(img[target_image_number], 0, 1)
    img_volume[:, :, :, channel] = img

In [None]:
# Reference series (same file name under output/), used below to score the restoration.
path = 'D:/Docs/Visillect/agrofields/adc/restoration/data/red_nir/output/L8_0.npz'
img2 = parse_image(path, all_channels_last=False)

In [None]:
# Visual check of channel 0: restored image first, reference image second.
for i in range_object:
    restored_channel0 = img_volume[i, :, :, 0]
    reference_channel0 = img2[i, :, :, 0]
    imshow(restored_channel0, vmin=0, vmax=1)
    imshow(reference_channel0, vmin=0, vmax=1)

In [None]:
from sklearn.metrics import mean_absolute_error, mean_squared_error

In [None]:
# Mean absolute error between reference and restored stacks (all channels, all pixels).
# NOTE(review): pixels outside the gaps presumably already match the reference, so this
# average likely understates the error on the restored pixels themselves.
for target_image_number in range_object:
    reference_flat = img2[target_image_number].ravel()
    restored_flat = img_volume[target_image_number].ravel()
    print(mean_absolute_error(reference_flat, restored_flat))

In [None]:
def NDVI(img, eps=1E-6):
    """Normalized Difference Vegetation Index, (NIR - Red) / (NIR + Red).

    Channel 0 of the last axis is used as red and channel 1 as NIR. Any number of
    leading axes is supported (single image or a whole series).

    Args:
        img: array whose last axis holds at least the (red, NIR) channels.
        eps: added to the denominator to avoid division by zero where both are 0.

    Returns:
        Array of shape ``img.shape[:-1]``.
    """
    red = img[..., 0]
    nir = img[..., 1]
    return (nir - red) / (red + nir + eps)

In [None]:
# NDVI of the restored series and of the reference series used for comparison.
VI_restored = NDVI(img_volume)
VI_good = NDVI(img2)


In [None]:
# NDVI error and visual comparison (restored first, reference second) for every image
# in the series, including images that were not restored.
for target_image_number in range(len(img_volume)):
    restored_vi = VI_restored[target_image_number]
    reference_vi = VI_good[target_image_number]
    print(mean_absolute_error(reference_vi, restored_vi))
    imshow(restored_vi)
    imshow(reference_vi)

In [None]:
# from tifffile import imread
#
# for im_num in range(7):
#     a = imread('res{}.tif'.format(im_num))
#     imshow(a)
#     # a = NDVI(a)
#     # imsave('res{}.tif'.format(im_num), a)
#
# imshow(a)
# imshow(img_volume[0])