In [1]:
import os
import numpy as np
from osgeo import gdal
from scipy.interpolate import interp1d
from multiprocessing import Pool
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import LSTM
import time
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
os.environ['PROJ_LIB'] = '/home/ubuntu/anaconda3/envs/icm/share/proj/'

In [2]:
# Function to stack all bands
def stack_allbands(list_of_images):
    image_list = ()
    for i in list_of_images:
        print(i)
        image = gdal.Open(i)
        img_arr = image.ReadAsArray()
        print(img_arr.shape)
        image_list = image_list + (img_arr,)
    stack_arr = np.vstack(image_list)
    stack_arr = np.transpose(stack_arr, (1, 2, 0))
    image = None
    img_arr = None
    return stack_arr

In [3]:
# Set the band order
bands = ['B3','B4','B5','B6','B7','B8','B11','B12','B8A','BSI','GCVI','LSWI','NDRE','NDVI','SAVI']
li = []
in_directory = "/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/"
for b in bands: 
    in_file = os.path.join(in_directory, f'interpolated_{b}.tif')
    li.append(in_file)

In [4]:
# stack them all
stack_arr = stack_allbands(li)

/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B3.tif




(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B4.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B5.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B6.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B7.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B8.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B11.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B12.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_B8A.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_BSI.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bungoma/interpolation/interpolated_GCVI.tif
(13, 8115, 7855)
/home/ubuntu/crop_monitoring/Kenya/Bu

In [5]:
#get the shape of the stacked array
shp = stack_arr.shape
shp

(8115, 7855, 195)

In [6]:
def predict_with_patches(imgarray, model_path, patchsize=1024):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LSTM.LSTM(input_dim=20, num_classes=7, hidden_dims=256, num_layers=2, dropout=0.3, bidirectional=True, use_layernorm=True)
    model.load_state_dict(torch.load(model_path, map_location=device))

    print(f"Performing {patchsize} X {patchsize} patch inferencing ...")
    
    nrows, ncols, nbands = imgarray.shape
    empty_array = np.zeros((nrows, ncols))
    
    # Splitting image into patches using patch size (default = 200) # write four conditions
    last_patch_x = math.ceil(nrows / patchsize)
    last_patch_y = math.ceil(ncols / patchsize)
    
    for i in range(last_patch_x):
        x_end = (i + 1) * patchsize
        if i == last_patch_x-1:
            x_end = -1
        for j in range(last_patch_y):
            start_time = time.time()
            y_end = (j + 1) * patchsize
            if j == last_patch_y-1:
                y_end = -1
            patch_ = imgarray[i * patchsize:x_end, j * patchsize:y_end, :]
            shp = patch_.shape
            print(shp)
            input_array = patch_.reshape(shp[0]*shp[1], shp[2])
            input_array = input_array.reshape(input_array.shape[0], 13, 15)
            with torch.no_grad():
                input_tensor = torch.tensor(input_array, dtype=torch.float32).to(device)
                logits = model(input_tensor)
                print(torch.argmax(logits, dim=1))
                predicted_array = torch.argmax(logits, dim=1)
    
            predicted_array = predicted_array.reshape(shp[0], shp[1])
            empty_array[i * patchsize:x_end, j * patchsize:y_end] = predicted_array
            end_time = time.time()
            est_time = (end_time - start_time)/60
            print(f"Prediction time for a patch: {est_time:.2f} min")
            
    return empty_array

In [7]:

# def predict_with_patches_parallel(imgarray, model_path, patchsize=256, num_workers=20):

#     # Set up a logger
#     logger = logging.getLogger("PatchPrediction")
#     logger.setLevel(logging.INFO)

#     formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

#     # Log to console
#     ch = logging.StreamHandler()
#     ch.setFormatter(formatter)
#     logger.addHandler(ch)

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = LSTM.LSTM(input_dim=15, num_classes=6, hidden_dims=128, num_layers=5, dropout=0.5713020228087161, bidirectional=True, use_layernorm=True)
#     model.load_state_dict(torch.load(model_path, map_location=device))

#     logger.info(f"Performing {patchsize} X {patchsize} patch inferencing in parallel ...")
    
#     nrows, ncols, nbands = imgarray.shape
#     empty_array = np.zeros((nrows, ncols))
    
#     # Splitting image into patches using patch size (default = 200)
#     last_patch_x = math.ceil(nrows / patchsize)
#     last_patch_y = math.ceil(ncols / patchsize)

#     # Function to process a single patch
#     def process_patch(i, j):
#         x_end = (i + 1) * patchsize if i != last_patch_x - 1 else nrows
#         y_end = (j + 1) * patchsize if j != last_patch_y - 1 else ncols
#         patch_ = imgarray[i * patchsize:x_end, j * patchsize:y_end, :]
#         shp = patch_.shape
#         input_array = patch_.reshape(shp[0] * shp[1], shp[2])
#         input_array = input_array.reshape(input_array.shape[0], 13, 15)
#         with torch.no_grad():
#             input_tensor = torch.tensor(input_array, dtype=torch.float32).to(device)
#             logits = model(input_tensor)
#             pred_probs = F.softmax(logits, dim=1)
#             _, predicted_array = torch.max(pred_probs, dim=1)
#         predicted_array = predicted_array.reshape(shp[0], shp[1])
#         return i, j, predicted_array

#     # Use ThreadPoolExecutor for parallel processing
#     with ThreadPoolExecutor(max_workers=num_workers) as executor:
#         future_to_patch = {
#             executor.submit(process_patch, i, j): (i, j) for i in range(last_patch_x) for j in range(last_patch_y)
#         }
#         for future in as_completed(future_to_patch):
#             i, j, predicted_array = future.result()
#             x_end = (i + 1) * patchsize if i != last_patch_x - 1 else nrows
#             y_end = (j + 1) * patchsize if j != last_patch_y - 1 else ncols
#             empty_array[i * patchsize:x_end, j * patchsize:y_end] = predicted_array

#             # Log the progress
#             logger.info(f"Processed patch ({i}, {j})")

#     logger.info("Patch inferencing completed.")
#     return empty_array


In [8]:
model_path = 'LSTM_model_10ep.pt'
predicted_arr = predict_with_patches(stack_arr, model_path, patchsize=512)

Performing 512 X 512 patch inferencing ...
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min


tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 174, 195)
tensor([0, 0, 0,  ..., 0, 0, 0])
Prediction time for a patch: 0.07 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 4])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 0, 0,  ..., 1, 1, 1])
Prediction

tensor([1, 1, 1,  ..., 4, 4, 4])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([1, 1, 1,  ..., 4, 4, 4])
Prediction time for a patch: 0.23 min
(512, 174, 195)
tensor([1, 1, 1,  ..., 4, 4, 4])
Prediction time for a patch: 0.08 min
(512, 512, 195)
tensor([4, 4, 4,  ..., 1, 1, 1])
Prediction time for a patch: 0.24 min
(512, 512, 195)
tensor([1, 1, 1,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([5, 5, 5,  ..., 0, 0, 0])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([5, 1, 5,  ..., 1, 0, 4])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([1, 1, 1,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([0, 5, 5,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([4, 4, 5,  ..., 0, 0, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([4, 5, 5,  ..., 1, 1, 1])
Prediction time for a patch: 0.23 min
(512, 512, 195)
tensor([4, 5, 5,  ..., 1, 1, 1])
Prediction

In [9]:
def save_image(array, ref_img, outfile, data_type=gdal.GDT_Byte):
    
    # metadata
    meta_image = gdal.Open(ref_img)
    trans = meta_image.GetGeoTransform()
    proj = meta_image.GetProjection()

        
    # save classified image as tif image
    outdriver = gdal.GetDriverByName("GTiff")
    shp = array.shape
    outdata = outdriver.Create(outfile, shp[1], shp[0], 1, data_type)
    outdata.GetRasterBand(1).WriteArray(array)

    
    outdata.SetGeoTransform(trans)
    outdata.SetProjection(proj)
    del outdata
    

In [10]:
output_image = 'Kenya/Bungoma/classified.tif'
save_image(predicted_arr, li[0], output_image)

In [None]:

def predict_with_patches(imgarray, model, patchsize=1024):


    print(f"Performing parallel prediction CPUs...")
    
    nrows, ncols, nbands = imgarray.shape
    empty_array = np.zeros((nrows, ncols))
    
    # Splitting image into patches using patch size (default = 200) # write four conditions
    last_patch_x = math.ceil(nrows / patchsize)
    last_patch_y = math.ceil(ncols / patchsize)
    
    for i in range(last_patch_x):
        x_end = (i + 1) * patchsize
        if i == last_patch_x-1:
            x_end = -1
        for j in range(last_patch_y):
            start_time = time.time()
            y_end = (j + 1) * patchsize
            if j == last_patch_y-1:
                y_end = -1
            patch_ = imgarray[i * patchsize:x_end, j * patchsize:y_end, :]
            shp = patch_.shape
            print(shp)
            input_array = patch_.reshape(shp[0]*shp[1], shp[2])
            predicted_array = model.predict(input_array)
            predicted_array = predicted_array.reshape(shp[0], shp[1])
            empty_array[i * patchsize:x_end, j * patchsize:y_end] = predicted_array
            end_time = time.time()
            est_time = (end_time - start_time)/60
            print(f"Prediction time for a chunk {est_time:.2f} min.")
            
    return empty_array


In [None]:
import lightgbm as lgb
import joblib
model = joblib.load('Kenya_model_lgbm.pkl')
predicted_arr = predict_with_patches(stack_arr, model, patchsize=1024)

In [None]:
output_image = 'Uasin_Gishu/results/classified_0000008192-0000008192.tif'
save_image(predicted_arr, li[0], output_image)

In [None]:
import math
math.ceil(8192 / 1024)

In [None]:
li[0]

In [None]:
import os
os.cpu_count()