In [None]:
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.cluster import DBSCAN
from tifffile import imshow, imsave
from tqdm import tqdm
import matplotlib.pyplot as plt
from time import time
from utils import create_mask
from catboost import Pool, CatBoostRegressor, CatBoostError
from skimage.morphology import dilation, disk

In [None]:
def NDVI(img, last_channel=False):
    """Compute the normalised difference of channel 1 and channel 0.

    Element-wise the result is ``(c1 - c0) / (c1 + c0 + 1E-6)``, where ``c0``
    and ``c1`` are indices 0 and 1 along the fourth axis of ``img``. The small
    constant keeps the denominator non-zero when both channels are 0.

    Parameters
    ----------
    img : array with four axes, indexable as ``img[:, :, :, k]``.
    last_channel : bool
        When True a trailing axis of length 1 is appended to the result.

    Returns
    -------
    Array spanning the first three axes of ``img`` (plus a trailing singleton
    axis when ``last_channel`` is True).
    """
    channel_zero = img[:, :, :, 0]
    channel_one = img[:, :, :, 1]
    index = (channel_one - channel_zero) / (channel_one + channel_zero + 1E-6)
    return np.expand_dims(index, -1) if last_channel else index


def parse_image(path, all_channels_last=False):
    """Load the ``arr_0`` array stored in an ``.npz`` archive.

    Parameters
    ----------
    path : path (or file object) accepted by ``np.load``.
    all_channels_last : bool
        When True, axis 0 of the loaded array is moved to position 2 and all
        axes from position 2 onwards are flattened into one, giving an array
        of shape ``(shape[1], shape[2], -1)`` of the stored array.

    Returns
    -------
    The loaded array, optionally rearranged as described above.
    """
    # An NpzFile keeps its underlying file open until it is closed; the
    # context manager releases the handle as soon as the array is read.
    with np.load(path) as archive:
        img = archive['arr_0']
    if all_channels_last:
        img = np.moveaxis(img, 0, 2)
        img = img.reshape(img.shape[0], img.shape[1], -1)
    return img


def check_intersection(img):
    """Count positions that are zero in every slice of ``img``.

    Parameters
    ----------
    img : array-like stacked along axis 0 (one 2-D slice per entry).

    Returns
    -------
    int
        Number of positions whose value equals 0 in all slices along axis 0.
    """
    # The slice shape is taken from the data itself rather than being fixed to
    # 955x955, so stacks of any spatial size are supported.
    zero_everywhere = np.all(np.asarray(img) == 0, axis=0)
    return np.count_nonzero(zero_everywhere)

In [None]:
def _window_features(images_window, mask_window, geospatial):
    """Return one feature row per True pixel of ``mask_window``.

    Columns are the values of every image in ``images_window`` at that pixel,
    followed (when ``geospatial`` is True) by the pixel's row and column index
    inside the window.
    """
    # Boolean indexing over the two spatial axes gives (n_images, n_pixels);
    # transpose to one row per pixel.
    features = images_window[:, mask_window].T
    if geospatial:
        features = np.concatenate([features, np.argwhere(mask_window)], axis=1)
    return features


def _build_window_samples(target_image_window, training_images_window,
                          target_mask_window, train_mask_window, geospatial):
    """Assemble training and restoration samples for one already-sliced window."""
    if np.count_nonzero(target_mask_window) == 0:
        # Sentinel used by callers (they test whether the first item is an int).
        return 0, 0, 0, 0, 0
    train_data = _window_features(training_images_window, train_mask_window, geospatial)
    train_label = target_image_window[train_mask_window]
    restore_data = _window_features(training_images_window, target_mask_window, geospatial)
    return train_data, train_label, restore_data, target_mask_window, target_image_window


def create_train_data(target_image, training_images, target_mask, train_mask, j, window_size, geospatial=True):
    """Build regression samples for a full-height window of columns ``j: j + window_size``.

    Parameters
    ----------
    target_image : 2-D array holding the image being restored (one channel).
    training_images : 3-D array ``(n_images, rows, cols)`` used as features.
    target_mask : 2-D bool array, True where ``target_image`` must be restored.
    train_mask : 2-D bool array, True where training samples are taken.
    j : first column of the window.
    window_size : number of columns in the window.
    geospatial : when True, window-local (row, column) indices are appended
        to every feature row.

    Returns
    -------
    ``(train_data, train_label, restore_data, target_mask_window,
    target_image_window)``. The two windows are slices of the inputs. When the
    window contains no target pixel the tuple ``(0, 0, 0, 0, 0)`` is returned.
    """
    columns = slice(j, j + window_size)
    return _build_window_samples(
        target_image[:, columns],
        training_images[:, :, columns],
        target_mask[:, columns],
        train_mask[:, columns],
        geospatial,
    )


def create_train_data_quick(target_image, training_images, target_mask, train_mask, i, j, window_size, geospatial=True):
    """Build regression samples for the square window starting at ``(i, j)``.

    Same contract as ``create_train_data`` except that the window is
    ``window_size`` rows by ``window_size`` columns, starting at row ``i`` and
    column ``j``.
    """
    rows = slice(i, i + window_size)
    columns = slice(j, j + window_size)
    return _build_window_samples(
        target_image[rows, columns],
        training_images[:, rows, columns],
        target_mask[rows, columns],
        train_mask[rows, columns],
        geospatial,
    )


def create_order(img):
    """Return the indices along axis 0 of ``img``, fewest zeros first.

    For every entry ``img[k]`` the number of elements equal to 0 is counted;
    the indices ``0 .. img.shape[0] - 1`` are then returned as a list sorted by
    ascending count. Ties keep their original relative order.
    """
    image_count = img.shape[0]
    zero_counts = [np.count_nonzero(img[k] == 0) for k in range(image_count)]
    return sorted(range(image_count), key=lambda k: zero_counts[k])

def create_horizontal_iteration(img, mask, epsilon=1, n=5, total_max_segments=65, max_covered_length=35):
    """Split the columns of ``img`` into windows along the last axis.

    Columns are scanned left to right. For each column that has masked pixels,
    the mean of those pixels is compared with a running average; a new window
    boundary is recorded when the difference reaches
    ``epsilon * std(img[mask])`` or when ``max_covered_length`` columns have
    passed since the last boundary. If more boundaries were found than
    ``total_max_segments - 2`` allows, they are thinned out proportionally.

    Parameters
    ----------
    img : 2-D array.
    mask : 2-D bool array of the same shape selecting the pixels to inspect.
    epsilon : multiplier of the standard deviation used as jump threshold.
    n : weight of the running average update
        (``avg = (avg * (n - 1) + value) / n``).
    total_max_segments : budget for the number of boundaries (2 is subtracted
        before use).
    max_covered_length : maximal column distance before a boundary is forced.

    Returns
    -------
    list of int
        Column indices, starting with 0 and ending with ``img.shape[-1]``.
        ``[0, img.shape[-1]]`` is returned when the budget is exhausted
        (``total_max_segments - 2 == 0``) or when ``mask`` selects nothing.
    """
    width = img.shape[-1]
    total_max_segments -= 2
    masked_values = img[mask]
    # A zero budget or no pixels to inspect means the whole width forms a
    # single window. The width comes from the image itself rather than being
    # fixed to 955.
    if total_max_segments == 0 or masked_values.size == 0:
        return [0, width]

    horizontal_iteration = [0]
    sliding_average = masked_values[0]
    threshold = epsilon * np.std(masked_values)
    for j in range(0, width):
        column_values = img[:, j][mask[:, j]]
        if column_values.size == 0:
            continue
        current_value = np.average(column_values)
        if np.abs(sliding_average - current_value) >= threshold:
            horizontal_iteration.append(j)
            sliding_average = current_value
        else:
            sliding_average = (sliding_average * (n - 1) + current_value) / n
        if j - horizontal_iteration[-1] == max_covered_length:
            horizontal_iteration.append(j)

    ratio = len(horizontal_iteration) / total_max_segments
    if ratio > 1:
        # Too many boundaries: keep roughly one out of every ``ratio`` of them.
        horizontal_iteration_temp = [0]
        main_number, lagging_number = 0, 0
        for elem in horizontal_iteration:
            main_number += 1
            if main_number // ratio > lagging_number:
                lagging_number += 1
                horizontal_iteration_temp.append(elem)
        horizontal_iteration = horizontal_iteration_temp
    horizontal_iteration.append(width)
    return horizontal_iteration

In [None]:
def dummy_restore(img_volume, mask_volume, max_search=30):
    """Fill zero-valued masked pixels with a value copied from the same column.

    For every channel, pixels that are both equal to 0 and True in
    ``mask_volume`` are overwritten in place: a pixel in row ``i`` takes the
    value of row ``i - 1`` when ``i - 1 > 0``; otherwise the value is taken
    from the first row at or below ``i`` (searching at most ``max_search``
    rows) where ``mask_volume`` is False.

    Parameters
    ----------
    img_volume : 4-D array ``(image, row, column, channel)``; modified in place.
    mask_volume : 3-D bool array ``(image, row, column)``.
    max_search : number of rows inspected by the downward search.

    Returns
    -------
    The same ``img_volume`` object.
    """
    def roll_down(mask, i):
        # First row at or below ``i`` that is not masked. The search is bounded
        # by ``max_search`` rows and by the end of the column. If every
        # inspected row is masked, the last inspected row is returned.
        stop = min(i + max_search, len(mask))
        for m in range(i, stop):
            if not mask[m]:
                return m
        return max(stop - 1, i)

    for channel in range(img_volume.shape[3]):
        mask = (img_volume[:, :, :, channel] == 0) * mask_volume
        if np.count_nonzero(mask) == 0:
            continue
        numbers, i_s, j_s = np.where(mask)
        # np.where yields rows in ascending order, so the row above has
        # already been filled when it is copied from.
        for number, i, j in zip(numbers.tolist(), i_s.tolist(), j_s.tolist()):
            # NOTE(review): ``i - 1 > 0`` sends row 1 through the downward
            # search as well; confirm whether ``i > 0`` was intended.
            if i - 1 > 0:
                source_row = i - 1
            else:
                source_row = roll_down(mask_volume[number, :, j], i)
            img_volume[number, i, j, channel] = img_volume[number, source_row, j, channel]
    return img_volume

# def dummy_restore_2(img_volume, mask_volume):
#     for channel in range(img_volume.shape[3]):
#         if np.count_nonzero(mask_volume) != 0:
#             indexes = np.where(np.rot90(mask_volume, axes=(1,2)))
#             (numbers, j_s, i_s) = indexes
#             for number in range(len(mask_volume)):
#                 i_s_sup = i_s[numbers == number]
#                 j_s_sup = j_s[numbers == number]
#                 for j in range(mask_volume.shape[2]):
#                     i_s_sup_sup = i_s_sup[j_s_sup == mask_volume.shape[2] - j - 1]
#                     to_restore = []
#                     i1 = None
#                     for i in range(len(i_s_sup_sup) - 1):
#                         if i1 is None:
#                             i1 = i_s_sup_sup[i] - 1
#                             continue
#                         if i_s_sup_sup[i + 1] - i_s_sup_sup[i] > 1:
#                             to_restore.append([i1, i_s_sup_sup[i] + 1])
#                             i1 = None
#                     if not i1 is None:
#                         to_restore.append([i1, i_s_sup_sup[-1] + 1])
#
#                     for edges in to_restore:
#                         if edges[1] - edges[0] > 30:
#                             continue
#                         img_volume[number, edges[0] + 1: edges[1], j, channel] = (img_volume[number, edges[0], j, channel] if edges[0] != -1 else img_volume[number, edges[1], j, channel] + img_volume[number, edges[1], j, channel] if edges[1] <= img_volume.shape[1] else img_volume[number, edges[0], j, channel] ) / 2
#             return img_volume


In [None]:
def quick_restore(
        img_volume,
        mask_volume,
        restoration_order,
        window_size=479,
        min_value=0.,
        max_value=1.,
        verbose=True,
        geospatial=True
):
    """Restore masked pixels window by window with a CatBoost regressor.

    For every image index in ``restoration_order`` whose mask is not empty,
    the image is tiled with ``window_size`` x ``window_size`` windows (the
    last window of a row/column is shifted back so that it stays inside the
    image). For each window and channel a model is fitted on the one-pixel
    ring around the masked area (``dilation(mask, disk(1)) ^ mask``), using
    the values of all other images of the volume as features, and the
    predictions are written into the masked pixels.

    Parameters
    ----------
    img_volume : 4-D array ``(image, row, column, channel)``; written in place.
    mask_volume : 3-D bool array ``(image, row, column)``, True where a pixel
        has to be restored.
    restoration_order : iterable of image indices to process, in order.
    window_size : side length of the square windows.
    min_value, max_value : bounds applied to the result with ``np.clip``.
    verbose : show a tqdm progress bar over window rows.
    geospatial : forwarded to ``create_train_data_quick``.

    Returns
    -------
    The volume after ``np.nan_to_num`` and ``np.clip`` (a new array; the input
    volume has additionally been modified in place).

    Raises
    ------
    ValueError
        If ``img_volume`` and ``mask_volume`` differ along the first axis.
    """
    if img_volume.shape[0] != mask_volume.shape[0]:
        raise ValueError('First dimensions of img_volume and mask_volume must be equal!')

    model = CatBoostRegressor(
        learning_rate=0.2,
        depth=5,
        loss_function='RMSE',
        verbose=0,
        num_trees=300
    )

    for target_image_number in restoration_order:
        target_image = img_volume[target_image_number]
        target_mask = mask_volume[target_image_number]
        if np.count_nonzero(target_mask) == 0:
            continue
        # Training pixels: the one-pixel border just outside the masked area.
        train_mask = dilation(target_mask, disk(1)) ^ target_mask
        training_images = np.concatenate(
            [img_volume[:target_image_number], img_volume[target_image_number + 1:]],
            axis=0
        )
        # Largest admissible window start per axis. Clamped at 0 so that a
        # window larger than the image cannot produce a negative start, which
        # would silently drop the leading rows/columns.
        last_row_start = max(target_image.shape[0] - window_size, 0)
        last_col_start = max(target_image.shape[1] - window_size, 0)
        for row_start in tqdm(range(0, target_image.shape[0], window_size), disable=not verbose):
            i = min(row_start, last_row_start)
            for col_start in range(0, target_image.shape[1], window_size):
                j = min(col_start, last_col_start)
                for channel in range(img_volume.shape[-1]):
                    train_data, train_label, restore_data, \
                        target_mask_window, target_image_window = create_train_data_quick(
                            target_image[:, :, channel],
                            training_images[:, :, :, channel],
                            target_mask,
                            train_mask, i, j,
                            window_size, geospatial=geospatial
                        )
                    if isinstance(train_data, int):
                        # Sentinel: nothing to restore in this window.
                        continue
                    if train_label.size == 0:
                        # No training pixel in this window, so neither a model
                        # nor a fallback value can be derived; leave as is.
                        continue
                    train_pool = Pool(train_data, train_label)
                    try:
                        model.fit(train_pool)
                        res = model.predict(Pool(restore_data))
                    except CatBoostError:
                        # Fitting failed: fall back to the mean training label.
                        res = np.full((restore_data.shape[0],), np.average(train_label))
                    target_image_window[target_mask_window] = res
                    img_volume[target_image_number, i:i + window_size, j: j + window_size, channel] = target_image_window
    img_volume = np.nan_to_num(img_volume, nan=0.0)
    img_volume = np.clip(img_volume, min_value, max_value)
    return img_volume



In [None]:
def restore_images(
        img_volume,
        mask_volume,
        restoration_order,
        epsilon=0.8,
        max_length=150,
        masks_to_sum=1,
        average_segment_length = 318,
        min_value=0.,
        max_value=1.,
        verbose=True,
        geospatial=True
):
    """Restore masked pixels region by region with a CatBoost regressor.

    For every image index in ``restoration_order`` whose mask is not empty:

    1. the mask is dilated by one pixel and the resulting pixels are clustered
       with ``DBSCAN(eps=1.5)``;
    2. each cluster is split into target pixels (inside the mask) and training
       pixels (outside the mask); clusters are merged in runs of
       ``masks_to_sum``;
    3. each group is cut into full-height column windows by
       ``create_horizontal_iteration``; per window and channel a model is
       fitted on the training pixels, using all other images of the volume as
       features, and its predictions are written into the target pixels.

    Parameters
    ----------
    img_volume : 4-D array ``(image, row, column, channel)``; written in place.
    mask_volume : 3-D bool array ``(image, row, column)``, True where a pixel
        has to be restored.
    restoration_order : iterable of image indices to process, in order.
    epsilon : forwarded to ``create_horizontal_iteration``.
    max_length : forwarded as ``max_covered_length``.
    masks_to_sum : number of consecutive clusters merged into one group.
    average_segment_length : divisor of ``img_volume.shape[1]`` giving the
        base segment budget.
    min_value, max_value : bounds applied to the result with ``np.clip``.
    verbose : show a tqdm progress bar over the groups.
    geospatial : forwarded to ``create_train_data``.

    Returns
    -------
    The volume after ``np.nan_to_num`` and ``np.clip`` (a new array; the input
    volume has additionally been modified in place).

    Raises
    ------
    ValueError
        If ``img_volume`` and ``mask_volume`` differ along the first axis.
    """
    if img_volume.shape[0] != mask_volume.shape[0]:
        raise ValueError('First dimensions of img_volume and mask_volume must be equal!')

    total_max_segments = int(img_volume.shape[1] // average_segment_length)
    model = CatBoostRegressor(
        learning_rate=0.2,
        depth=4,
        loss_function='RMSE',
        verbose=0,
        num_trees=210
    )
    for target_image_number in restoration_order:
        target_image = img_volume[target_image_number]
        target_mask = mask_volume[target_image_number]
        if np.count_nonzero(target_mask) == 0:
            continue

        # Cluster the dilated mask so that separate regions are handled apart.
        dilated_mask = dilation(target_mask, disk(1))
        dilated_rows, dilated_cols = np.where(dilated_mask)
        points = np.moveaxis(np.array((dilated_rows, dilated_cols)), 0, 1)
        labels = DBSCAN(eps=1.5).fit_predict(points)
        train_masks = []
        target_masks = []
        for label in set(labels):
            members = labels == label
            cluster_mask = np.zeros(shape=dilated_mask.shape, dtype=bool)
            cluster_mask[(dilated_rows[members], dilated_cols[members])] = True
            target_masks.append(cluster_mask * target_mask)
            train_masks.append((cluster_mask ^ target_masks[-1]) * ~target_mask)

        # Image width; the horizontal extent of a group is scaled against it.
        max_area = target_masks[0].shape[-1]

        # Merge consecutive clusters in runs of ``masks_to_sum``.
        train_masks_aux = [np.zeros(train_masks[0].shape, dtype=bool)]
        target_masks_aux = [np.zeros(train_masks[0].shape, dtype=bool)]
        merged = 0
        for cluster_train_mask, cluster_target_mask in zip(train_masks, target_masks):
            if merged < masks_to_sum:
                train_masks_aux[-1] = train_masks_aux[-1] + cluster_train_mask
                target_masks_aux[-1] = target_masks_aux[-1] + cluster_target_mask
                merged += 1
            else:
                merged = 1
                train_masks_aux.append(cluster_train_mask)
                target_masks_aux.append(cluster_target_mask)

        groups = []
        for group_train_mask, group_target_mask in zip(train_masks_aux, target_masks_aux):
            target_columns = np.where(group_target_mask)[1]
            if target_columns.size == 0 or not group_train_mask.any():
                # Nothing to restore, or nothing to learn from.
                continue
            # np.where returns indices in row-major order, so the column array
            # is not sorted; the horizontal extent is max - min.
            extent = target_columns.max() - target_columns.min()
            groups.append((group_train_mask, group_target_mask, extent))

        training_images = np.concatenate(
            [img_volume[:target_image_number], img_volume[target_image_number + 1:]],
            axis=0
        )
        for train_mask_current, target_mask_current, current_area in tqdm(groups, total=len(groups), disable=not verbose):
            horizontal_iteration = create_horizontal_iteration(
                target_image[:, :, 0],
                train_mask_current,
                epsilon=epsilon,
                n=1,
                total_max_segments=int(np.round(total_max_segments * current_area / max_area + 1)),
                max_covered_length=max_length
            )

            for m in range(len(horizontal_iteration) - 1):
                j = horizontal_iteration[m]
                window_size = horizontal_iteration[m + 1] - j
                for channel in range(img_volume.shape[-1]):
                    train_data, train_label, restore_data, \
                        target_mask_window, target_image_window = create_train_data(
                            target_image[:, :, channel],
                            training_images[:, :, :, channel],
                            target_mask_current,
                            train_mask_current,
                            j, window_size, geospatial=geospatial
                        )
                    if isinstance(train_data, int):
                        # Sentinel: nothing to restore in this window.
                        continue
                    if train_label.size == 0:
                        # No training pixel in this window, so neither a model
                        # nor a fallback value can be derived; leave as is.
                        continue
                    train_pool = Pool(train_data, train_label)
                    try:
                        model.fit(train_pool)
                        res = model.predict(Pool(restore_data))
                    except CatBoostError:
                        # Fitting failed: fall back to the mean training label.
                        res = np.full((restore_data.shape[0],), np.average(train_label))
                    target_image_window[target_mask_window] = res
                    img_volume[target_image_number, :, j: j + window_size, channel] = target_image_window
    img_volume = np.nan_to_num(img_volume, nan=0.0)
    img_volume = np.clip(img_volume, min_value, max_value)
    return img_volume

In [None]:
# Batch driver: for every L8_<n>.npz volume run the quick restoration, save it,
# then run the full restoration on top of it and save that as well.
INPUT_TEMPLATE = 'D:/Docs/Visillect/agrofields/adc/restoration/data/extended/input/L8_{}.npz'
QUICK_RESULT_TEMPLATE = 'D:/Docs/Visillect/agrofields/adc/restoration/data/extended/results_ndvi_2_quick/L8_{}.npz'
FULL_RESULT_TEMPLATE = 'D:/Docs/Visillect/agrofields/adc/restoration/data/extended/results_ndvi_2/L8_{}.npz'

not_full_restoration = False
for img_num in tqdm(range(28), disable=False):
    path = INPUT_TEMPLATE.format(img_num)
    img_volume = parse_image(path)
    tick1 = time()
    # Non-zero pixels of channel 0; turned into restoration masks below.
    mask_volume = img_volume[:, :, :, 0] != 0
    range_object = create_order(img_volume)
    # Skip volumes in which some position is zero in channel 0 of every image.
    if check_intersection(img_volume[:, :, :, 0]) != 0:
        continue
    img_volume = NDVI(img_volume, last_channel=True)
    for i, single_mask in enumerate(mask_volume):
        mask_volume[i] = create_mask(single_mask)
    img_volume = dummy_restore(img_volume, mask_volume)
    tick2 = time()
    img_volume = quick_restore(
        img_volume,
        mask_volume,
        range_object,
        min_value=-1.,
        verbose=not_full_restoration
    )
    np.savez(QUICK_RESULT_TEMPLATE.format(img_num), img_volume)
    tick3 = time()
    img_volume = restore_images(
        img_volume,
        mask_volume,
        restoration_order=range_object,
        min_value=-1.,
        verbose=not_full_restoration
    )
    tick4 = time()
    # tick4 - tick2 covers the quick pass, the intermediate save and the full pass.
    print('{:.1f} seconds total, {:.1f} seconds full restoration, {:.1f} seconds quick restoration'.format(tick4 - tick1, tick4 - tick2, tick3 - tick2))
    np.savez(FULL_RESULT_TEMPLATE.format(img_num), img_volume)

In [None]:
# megamask = np.zeros((955, 955), bool)
# for msk in mask:
#     megamask = megamask ^ msk
# imshow(megamask)

In [None]:
# img_num = 0
# Load the volume stored under `output/` for the index currently held in
# `img_num`. That name is only defined by the processing loop above, so this
# cell depends on that loop having run (it then holds the last index iterated).
# NOTE(review): presumably this is the ground-truth counterpart of the input
# volume — confirm.
path = 'D:/Docs/Visillect/agrofields/adc/restoration/data/extended/output/L8_{}.npz'.format(img_num)

img2 = parse_image(path)
# img2 = img2[:,:,:,0]


In [None]:
# img3 = np.array(img2)
# img3[2, :, :, 0][megamask] = 0
# # imshow(img3[1, :, :, 0])

In [None]:
# img_volume = np.clip(img_volume, 0, 1)
# chnl = 0
# for i in range(10):
#     imshow(img_volume[i,:,:,chnl], vmax=1, vmin=0)
#     # imshow(img3[i,:,:,chnl], vmax=1, vmin=0)
#     imshow(img2[i,:,:,chnl], vmax=1, vmin=0)

In [None]:
# for target_image_number in range(img_volume.shape[0]):
#     print(mean_absolute_error(img2[target_image_number].reshape(-1), img_volume[target_image_number].reshape(-1)))
#     print(mean_squared_error(img2[target_image_number, :, :, 0].reshape(-1), img_volume[target_image_number].reshape(-1)) ** (1 / 2))
#     # imsave('res{}.tif'.format(target_image_number), img_volume[target_image_number])

In [None]:
# VI_restored = NDVI(img_volume, last_channel=True)
# `img_volume` is taken as-is from the processing loop (no NDVI is applied
# here), while NDVI is computed for the volume loaded into `img2`.
# NDVI() is called without last_channel, so VI_good has no trailing channel axis.
VI_restored = img_volume
VI_good = NDVI(img2)
print(img_volume.shape, range_object)

In [None]:

# Report the mean absolute error per image and show restored vs. reference.
# At most 10 images are inspected; the bound is capped by the number of images
# actually available so that smaller volumes do not raise an IndexError.
for target_image_number in range(min(10, len(VI_restored), len(VI_good))):
    print(mean_absolute_error(VI_good[target_image_number].reshape(-1), VI_restored[target_image_number, :, :].reshape(-1)))
    # print(r2_score(VI_good[target_image_number].reshape(-1), VI_restored[target_image_number, :, :].reshape(-1)))
    imshow(VI_restored[target_image_number, :, :, 0], vmin=-0.5, vmax=0.5)
    imshow(VI_good[target_image_number], vmin=-0.5, vmax=0.5)
    # imsave('res{}.tif'.format(target_image_number), VI_restored[target_image_number])

In [None]:
# mask1 = VI_restored[0, 739:769, 750:780, 0] == 0
# mask2 = dilation(mask1, disk(3))
# mask2 = mask1 ^ mask2
# visualization = np.zeros(mask1.shape)
# visualization[mask1] = 1
# visualization[mask2] = 0.5
# imshow(visualization)

In [None]:
# mask1 = VI_restored[0, :200, 700:900, 0] == 0
# aux = VI_restored[0, :200, 700:900, 0]
# mask2 = ~mask1
# mask2 = create_mask(mask2)
# visualization = np.zeros((*(aux.shape), 3))
# visualization[:, :, 0] = aux
# visualization[:, :, 1][mask2] = 1
# visualization[:, :, 0][mask2] = 1
# imshow(visualization)