In [None]:
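# This notebook applies the trained weed-segmentation model to a GeoTIFF orthomosaic patch by patch,
# saves the binary weed mask as a PNG, and converts that PNG to a GeoTIFF georeferenced like the input.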

In [None]:
import tensorflow as tf
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import SimpleITK as sitk

In [None]:
# TensorFlow version in use (displayed as the cell output).
tf.version.VERSION

In [None]:
# Report how many GPUs TensorFlow can see; 0 means inference runs on the CPU.
# tf.config.list_physical_devices is the stable API (the experimental alias is deprecated).
gpus = tf.config.list_physical_devices('GPU')
print("Number of Available GPUs: ", len(gpus))

In [None]:
# Hyper-parameters of the training run whose checkpoint should be applied -- edit to match.
# batch_size and start_neurons only select the checkpoint file below; image_size is
# additionally used as the inference patch size further down.
batch_size = 16
image_size = 256
start_neurons = 32

# Load the trained model. compile=False: only the forward pass is needed here, so the
# training loss/metrics/optimizer are not restored (custom training losses or metrics
# therefore do not need to be registered just to run inference).
model_path = '../trained_models/weed_detector_batch_size_{}_image_size_{}_start_neurons_{}_best_loss.h5'.format(batch_size, image_size, start_neurons)
model = tf.keras.models.load_model(model_path, compile=False)


In [None]:
# Input orthomosaic (GeoTIFF) to segment -- edit this path for a different image.
image_name = './pigweed_21_JUL_2023_ortho_rgb_Clipped.tif'
reader = sitk.ImageFileReader()
reader.SetFileName(image_name)
weed_img = reader.Execute()

In [None]:
# Convert the SimpleITK image to a NumPy array (rows, cols, bands) and keep only the
# first three bands; any further band (e.g. a possible alpha band) is dropped.
weed = sitk.GetArrayFromImage(weed_img)[:,:,0:3]
# Set every zero-valued sample to 255 so black (zero) areas turn white and are skipped
# by the "mean < 200" brightness filter used during patch prediction.
# Assumes 8-bit imagery -- TODO confirm for other bit depths.
# NOTE(review): this acts per channel, not per pixel: a real pixel with any single
# channel equal to 0 is altered as well. Confirm this matches the preprocessing used
# during training before changing it.
weed[weed==0] = 255
# Swap channels 0 and 2 (BGR2RGB and RGB2BGR are the same permutation).
# NOTE(review): SimpleITK does not reorder bands the way cv2.imread does, so an RGB-stored
# TIFF ends up in B, G, R order here -- presumably to match training images that were
# read with cv2.imread (BGR); confirm.
weed = cv2.cvtColor(weed, cv2.COLOR_BGR2RGB)

In [None]:
# Image dimensions: rows (height), columns (width) and number of bands.
height, width, channel = (int(dim) for dim in weed.shape)

In [None]:
# Predict on square patches of the size the model was trained on.
patch_size = image_size
# Number of whole patches along the rows (grid_size_x) and columns (grid_size_y).
# Floor division already excludes a trailing partial patch; the former "- 1" also
# dropped the last *full* row/column of patches, leaving that strip unpredicted.
grid_size_x = height // patch_size
grid_size_y = width // patch_size
# Per-pixel model output, later thresholded at 0.5. float32 halves the memory of the
# float64 default, which matters for large orthomosaics.
mask = np.zeros((height, width), dtype=np.float32)

In [None]:
# Make predictions patch by patch. Origins step by patch_size; when the image size is not a
# multiple of patch_size, one extra patch is placed flush with the bottom/right edge so the
# remaining strip is predicted too (it overlaps, and overwrites, the previous patch there).
def _patch_starts(length, size):
    """Return window offsets of length `size` that together cover [0, length); [] if length < size."""
    if length < size:
        return []
    starts = list(range(0, length - size + 1, size))
    if starts[-1] + size < length:
        starts.append(length - size)
    return starts


for row in _patch_starts(height, patch_size):
    for col in _patch_starts(width, patch_size):
        patch_image = weed[row:row + patch_size, col:col + patch_size, :]

        # Skip bright patches (mean >= 200), e.g. areas whose zero pixels were painted
        # 255 during preprocessing; their mask values stay 0.
        if np.mean(patch_image) < 200:
            # Scale to [0, 1] and add a batch axis of size 1.
            slice_2d = np.expand_dims(patch_image / 255.0, axis=0)
            seg_2d = np.squeeze(model(slice_2d, training=False).numpy())
            mask[row:row + patch_size, col:col + patch_size] = seg_2d

In [None]:
# Save the binary prediction (mask > 0.5 -> 255, else 0) as an 8-bit PNG next to the input,
# named "<input name without extension>_predicted_weed_mask.png".
import os

mask_png_path = os.path.splitext(image_name)[0] + '_predicted_weed_mask.png'
# Build uint8 directly instead of relying on OpenCV to downcast an int64 array.
binary_mask = (mask > 0.5).astype(np.uint8) * 255
# cv2.imwrite reports failure via its return value rather than raising.
if not cv2.imwrite(mask_png_path, binary_mask):
    raise IOError(f"Failed to write predicted mask to '{mask_png_path}'.")

In [None]:
# convert prediction from png to tif format, with projected coordinate system same as the input tif image for mapping in GIS
import os
import numpy as np
import rasterio
from PIL import Image

def create_tif_from_png(tif_path, png_path, output_tif_path):
    """Write a PNG raster as a GeoTIFF that reuses the metadata of a reference GeoTIFF.

    The rasterio metadata of ``tif_path`` is copied to the output, with only the band
    count replaced by the PNG's channel count, so the result can be overlaid on the
    source image in GIS.

    NOTE(review): the output also inherits the reference image's dtype and nodata
    value; if the reference declares nodata=0, a 0/255 mask's background will read as
    nodata in GIS -- confirm that is intended.

    Args:
        tif_path: Reference GeoTIFF providing the metadata.
        png_path: PNG to convert; must have the same height/width as the reference.
        output_tif_path: Destination path of the new GeoTIFF.

    Raises:
        ValueError: If the PNG's pixel dimensions differ from the reference GeoTIFF's.
        PIL/rasterio errors propagate if a file cannot be read or written.
    """
    # Only the metadata is needed from the reference image.
    with rasterio.open(tif_path) as tif_src:
        meta = tif_src.meta

    # Raise PIL's decompression-bomb limit so very large mask PNGs can be opened.
    Image.MAX_IMAGE_PIXELS = 1 << 31

    # Decode fully inside the context manager so the file handle is closed afterwards.
    with Image.open(png_path) as png_file:
        png_image = np.array(png_file)

    if png_image.shape[:2] != (meta['height'], meta['width']):
        raise ValueError(
            f"'{png_path}' is {png_image.shape[1]}x{png_image.shape[0]} pixels but "
            f"'{tif_path}' is {meta['width']}x{meta['height']}."
        )

    # A 2-D array is one band; a 3-D (rows, cols, channels) array has one band per channel.
    meta['count'] = png_image.shape[2] if png_image.ndim == 3 else 1

    # Save the PNG pixels as a GeoTIFF with the reference metadata.
    with rasterio.open(output_tif_path, 'w', **meta) as dst:
        if png_image.ndim == 2:
            dst.write(png_image, 1)
        else:
            # rasterio expects (bands, rows, cols); this writes every band at once.
            dst.write(np.moveaxis(png_image, -1, 0))

if __name__ == "__main__":
    # Derive every path from image_name (set in the input cell) so this conversion always
    # targets the image that was just segmented; hard-coded copies went stale whenever
    # the input path was changed.
    stem = os.path.splitext(image_name)[0]

    # Input .tif file providing the coordinate system
    input_tif = image_name

    # Predicted-mask .png written by the save cell
    input_png = stem + "_predicted_weed_mask.png"

    # Output .tif file path
    output_tif = stem + "_predicted_weed_mask.tif"

    # Check if the input files exist
    if not os.path.isfile(input_tif) or not os.path.isfile(input_png):
        print(f"Input file '{input_tif}' or '{input_png}' not found.")
    else:
        # Create the .tif image with the same physical coordinate system as the original .tif
        create_tif_from_png(input_tif, input_png, output_tif)
        print(f"Created '{output_tif}' with the same physical coordinate system as '{input_tif}'.")
