In [1]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import argparse
import matplotlib.pyplot as plt
import albumentations
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils as smp_utils
import torch
import torch.nn.functional as F
from tqdm import tqdm as tqdm
import sys
from torch.utils.data import DataLoader
import math
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
import glob
import rasterio
import json
import re
import tifffile
from skimage import measure
from rasterio.features import shapes
torch.__version__  # Show the installed PyTorch version (displayed as this cell's output).

In [2]:
#////////////////////////Pre Processing\\\\\\\\\\\\\\\\\\\\\\\\
def get_preprocessing(preprocessing_fn):
    """Wrap an encoder-specific normalization function in an albumentations pipeline.

    Args:
        preprocessing_fn (callable): normalization function applied to the image
            (in this notebook, the one returned by
            ``smp.encoders.get_preprocessing_fn``).

    Returns:
        albumentations.Compose: a one-step pipeline that applies
            ``preprocessing_fn`` to the ``image`` target only.
    """
    normalize_step = albumentations.Lambda(image=preprocessing_fn)
    return albumentations.Compose([normalize_step])



In [2]:
#//////////////////////// Evaluate on Large Tiff File(s) \\\\\\\\\\\\\\\\\\\\\\\\
# Tiled inference over one or more large GeoTIFFs.
#
# Each image is cut into overlapping tiles of `target_size`. Only the central
# (target_size - 2 * padding_pixels) window of every tile prediction is written
# into the output mosaic, so predictions near tile borders are discarded.
# For each input image the mosaic is saved to `output_dir` as
# <name>_preds.npy and <name>_preds.tiff.

# `input_imgs` is passed to glob.glob, so it may be one file or a pattern such as r"...\*.tif".
input_imgs = r"C:\Users\AhmadWaseem\Desktop\bhalli\PD-Seg Work\datasets for training - 20Z\tiff files\2017\2017 reproject.tif"
model_path = r"C:\Users\AhmadWaseem\Desktop\bhalli\PD-Seg Work\model v6 - DL 20Z - TS1024\model h5"
output_dir = r"C:\Users\AhmadWaseem\Desktop\bhalli\PD-Seg Work\model v6 - DL 20Z - TS1024\output"

ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'

target_size = (1024, 1024)    # (height, width) of each tile fed to the model
padding_pixels = (100, 100)   # (y, x) overlap trimmed from each side of a tile prediction
padding_value = 0             # constant fill used when padding the input image
downsampling_factor = 1       # resize factor applied to tiles and to the output mosaic
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PIX_VALUE_MAX = 255      # The max data value we have
PIX_VALUE_MAX_REQ = 255  # The max data value we need

# Useful (kept) window of each tile, in input pixels.
k_y = target_size[0] - 2 * padding_pixels[0]
k_x = target_size[1] - 2 * padding_pixels[1]
assert k_y > 0 and k_x > 0, "padding_pixels must be smaller than half of target_size"

# NOTE(review): torch.load unpickles a whole model object -- only load trusted files.
# PyTorch >= 2.6 defaults to weights_only=True, so this call may need weights_only=False there.
model = torch.load(os.path.join(model_path, 'best_model.h5'), map_location=DEVICE)
model.eval()

imgs = [file for file in glob.glob(input_imgs) if file.endswith('.tif')]
print(imgs)
assert len(imgs) > 0, "The number of images equal to zero"
print("Running on {} images".format(len(imgs)))

os.makedirs(output_dir, exist_ok=True)

# The preprocessing pipeline does not depend on the image, so build it once.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing = get_preprocessing(preprocessing_fn)

# cv2.resize takes dsize as (width, height), i.e. (target_size[1], target_size[0]).
model_input_size = (int(downsampling_factor * target_size[1]),
                    int(downsampling_factor * target_size[0]))
# Window of each (scaled) tile prediction that is copied into the mosaic.
keep_rows = slice(int(downsampling_factor * padding_pixels[0]),
                  int(downsampling_factor * (target_size[0] - padding_pixels[0])))
keep_cols = slice(int(downsampling_factor * padding_pixels[1]),
                  int(downsampling_factor * (target_size[1] - padding_pixels[1])))

for img_path in imgs:
    # splitext keeps dots inside the file name (split('.')[0] would truncate them).
    file_name = os.path.splitext(os.path.basename(img_path))[0]
    print("Running for {}".format(file_name))
    save_path = os.path.join(output_dir, file_name + "_preds.npy")

    # rasterio reads (bands, rows, cols); cv2 / albumentations expect (rows, cols, bands).
    # The context manager closes the dataset once the pixels are in memory.
    with rasterio.open(img_path) as src:
        img = np.transpose(src.read(), (1, 2, 0))
    print("Actual Image Size: {}".format(img.shape))

    # Number of tiles along y (`cols`) and along x (`rows`).
    cols = math.ceil(img.shape[0] / k_y)
    rows = math.ceil(img.shape[1] / k_x)

    # First padding: make the image an exact multiple of the useful window.
    pad_bottom = cols * k_y - img.shape[0]   # pixels to add in y direction
    pad_right = rows * k_x - img.shape[1]    # pixels to add in x direction
    if pad_bottom > 0 or pad_right > 0:
        print("Running cv2 padding..")
        img = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=padding_value)
    print("Image Size after making divisible by ({}, {}): {}".format(k_x, k_y, img.shape))

    output_image = np.zeros((int(img.shape[0] * downsampling_factor), int(img.shape[1] * downsampling_factor)),
                            dtype=np.uint8)
    print("Size of output image after downsampling factor of {}: {}".format(downsampling_factor, output_image.shape))

    # Second padding: add a border so edge tiles also get full surrounding context.
    img = cv2.copyMakeBorder(img, padding_pixels[0], padding_pixels[0], padding_pixels[1],
                             padding_pixels[1], cv2.BORDER_CONSTANT, value=padding_value)
    print("Image Size after adding ({}, {}) boundary pixels: {}".format(padding_pixels[0], padding_pixels[1], img.shape))

    total_patches = rows * cols
    print("Total {} patches for the given image {}".format(total_patches, file_name))
    for patch_idx in tqdm(range(total_patches), desc=file_name):
        y_idx, x_idx = divmod(patch_idx, rows)   # row-major tile order
        y0, x0 = y_idx * k_y, x_idx * k_x        # tile origin in the border-padded image

        img_crop = img[y0:y0 + target_size[0], x0:x0 + target_size[1]]
        # Rescale to the required range; clip so out-of-range values do not wrap around in uint8.
        img_crop = np.clip((img_crop / PIX_VALUE_MAX) * PIX_VALUE_MAX_REQ, 0, 255).astype(np.uint8)

        sample = preprocessing(image=img_crop)
        image = cv2.resize(sample['image'], model_input_size, interpolation=cv2.INTER_AREA)
        x_tensor = torch.Tensor(image).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            pred_mask = model(x_tensor)
        pr_mask = pred_mask.squeeze().cpu().numpy().round()

        # Keep only the central window; its position in the mosaic is the tile origin
        # in the unpadded image, scaled to output resolution.
        output_image[int(downsampling_factor * y0): int(downsampling_factor * (y0 + k_y)),
                     int(downsampling_factor * x0): int(downsampling_factor * (x0 + k_x))] = pr_mask[keep_rows, keep_cols]

    # Remove the divisibility padding (expressed in output-resolution pixels).
    output_image = output_image[:output_image.shape[0] - int(downsampling_factor * pad_bottom),
                                :output_image.shape[1] - int(downsampling_factor * pad_right)]
    print("Final shape of downsampled output image: {}".format(output_image.shape))

    # Save inside the loop so every input image gets its own prediction files.
    np.save(save_path, output_image)
    tifffile.imwrite(os.path.splitext(save_path)[0] + ".tiff", output_image)

print("Completed!")