In [None]:
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"
import numpy as np
import segmentation_models as sm
import cv2
from tensorflow import keras

# Define the models with their weights paths and corresponding crop sizes
#    'densenet201_1e-05_Unet_1024_noaug': {'weights_path': './tosave/densenet201_1e-05_Unet_1024_noaug/Unet_densenet201_1e-05.h5','crop_size': 1024},

'''
    'resnet50_1e-05_Unet_1024_aug': {
        'weights_path': './tosave/resnet50_1e-05_Unet_1024_aug/Unet_resnet50_1e-05.h5',
        'crop_size': 1024
    },
    
    'resnet101_1e-05_Unet_128_noaug': {
        'weights_path': './tosave/resnet101_1e-05_Unet_128_noaug/Unet_resnet101_1e-05.h5',
        'crop_size': 128
    },
'vgg19_1e-05_Unet_128_aug': {
        'weights_path': './tosave/vgg19_1e-05_Unet_128_aug/Unet_vgg19_1e-05.h5',
        'crop_size': 128
    },
    'vgg19_1e-05_Unet_256_aug': {
        'weights_path': './tosave/vgg19_1e-05_Unet_256_aug/Unet_vgg19_1e-05.h5',
        'crop_size': 256
    },
    'vgg19_1e-05_Unet_512_aug': {
        'weights_path': './tosave/vgg19_1e-05_Unet_512_aug/Unet_vgg19_1e-05.h5',
        'crop_size': 512
'''    
# Models to evaluate. Each key is a run name whose first '_'-separated token
# is the backbone name; each value gives the weights file to load and the
# square crop size (in pixels) fed to that model.
models_info = {
    'vgg16_1e-05_Unet_256_noaug': {
        'weights_path': './tosave/vgg16_1e-05_Unet_256_noaug/Unet_vgg16_1e-05.h5',
        'crop_size': 256,
    },
    'vgg19_1e-05_Unet_512_noaug': {
        'weights_path': './tosave/vgg19_1e-05_Unet_512_noaug/Unet_vgg19_1e-05.h5',
        'crop_size': 512,
    },
}

# Define input/output directories
input_images_dir = './v2_val/images'
output_dir = './mv_predictions'
# Create the output folder up front so later writes do not fail on a missing path.
os.makedirs(output_dir, exist_ok=True)

# Function to crop images with overlap if not multiple of crop size
def _tile_starts(length, crop_size):
    """Return the start offsets of tiles of `crop_size` covering `length` pixels."""
    if length <= 0:
        return []
    if length <= crop_size:
        # The axis is shorter than one tile: a single (undersized) tile.
        return [0]
    starts = list(range(0, length - crop_size + 1, crop_size))
    if starts[-1] + crop_size < length:
        # Shift the final tile back so it ends exactly on the border; it
        # overlaps its neighbour instead of being truncated.
        starts.append(length - crop_size)
    return starts


def crop_image_with_overlap(image, crop_size):
    """Split an image into square tiles, overlapping at the far edges.

    Tiles are laid out on a regular grid of step `crop_size`. When a dimension
    is not a multiple of `crop_size`, the last tile along that dimension is
    shifted back so that it is still full-size and overlaps the previous tile,
    rather than being cut short. A dimension smaller than `crop_size` yields a
    single tile spanning that whole dimension.

    Args:
        image: Array whose first two axes are (height, width).
        crop_size: Side length of each tile in pixels; must be positive.

    Returns:
        A list of `(crop, x, y)` tuples in row-major order, where `crop` is a
        view into `image` and `(x, y)` is the tile's top-left corner.

    Raises:
        ValueError: If `crop_size` is not positive.
    """
    if crop_size <= 0:
        raise ValueError(f"crop_size must be positive, got {crop_size}")

    height, width = image.shape[:2]
    return [
        (image[y:y + crop_size, x:x + crop_size], x, y)
        for y in _tile_starts(height, crop_size)
        for x in _tile_starts(width, crop_size)
    ]

# Function to stitch predictions together into one image
def stitch_predictions(predictions, full_image_size):
    """Merge per-tile masks into a single full-size uint8 mask.

    Args:
        predictions: Iterable of `(pred, x, y)` tuples, where `pred` is a 2-D
            array and `(x, y)` is the top-left corner at which it is placed.
        full_image_size: `(height, width)` of the output mask.

    Returns:
        A `(height, width)` uint8 array. Where tiles overlap, the element-wise
        maximum is kept; areas no tile covers stay 0.
    """
    canvas_h, canvas_w = full_image_size
    canvas = np.zeros((canvas_h, canvas_w), dtype=np.uint8)

    for tile, left, top in predictions:
        tile_h, tile_w = tile.shape

        # Clip the tile so it never writes past the bottom/right border.
        rows = min(tile_h, canvas_h - top)
        cols = min(tile_w, canvas_w - left)

        current = canvas[top:top + rows, left:left + cols]
        merged = np.maximum(current, tile[:rows, :cols])
        canvas[top:top + rows, left:left + cols] = merged

    return canvas

# Loop through each model
# For every configured model: build the U-Net, load its weights, then tile,
# predict, stitch and save a binary mask for every image in the input folder.
for model_name, model_info in models_info.items():
    print(f"Processing with model: {model_name}")

    # Release graph/layer state left over from the previously loaded model so
    # memory does not grow with every iteration of this loop.
    keras.backend.clear_session()

    # Load the model; the backbone name is the first token of the run name.
    BACKBONE = model_name.split('_')[0]
    preprocess_input = sm.get_preprocessing(BACKBONE)
    model = sm.Unet(BACKBONE, classes=1, activation='sigmoid')
    model.load_weights(model_info['weights_path'])

    crop_size = model_info['crop_size']

    # Loop through images in the input folder (sorted for a stable order)
    for image_file in sorted(os.listdir(input_images_dir)):
        image_path = os.path.join(input_images_dir, image_file)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None for sub-directories, non-image files
            # and unreadable images; skip them instead of crashing.
            print(f"Skipping unreadable file: {image_path}")
            continue

        original_size = image.shape[:2]
        print(f"Processing image: {image_file}, size: {original_size}, crop size: {crop_size}")

        predictions = []

        # Preprocess image
        preprocessed_image = preprocess_input(image)

        # Crop image into sections
        crops = crop_image_with_overlap(preprocessed_image, crop_size)

        # Loop through each crop and predict
        for crop, x, y in crops:
            crop_h, crop_w = crop.shape[:2]

            # The network is fed crop_size x crop_size inputs; undersized
            # crops are stretched to that size.
            resized_crop = cv2.resize(crop, (crop_size, crop_size))
            batch = np.expand_dims(resized_crop, axis=0)
            prediction = model.predict(batch)

            # Threshold the sigmoid output into a 0/255 mask.
            mask = (prediction[0, :, :, 0] > 0.5).astype(np.uint8) * 255

            # Undo the stretch so the mask lines up with the pixels the crop
            # actually covered; nearest-neighbour keeps the mask binary.
            if mask.shape != (crop_h, crop_w):
                mask = cv2.resize(mask, (crop_w, crop_h), interpolation=cv2.INTER_NEAREST)

            predictions.append((mask, x, y))

        # Stitch predictions together
        stitched_image = stitch_predictions(predictions, original_size)

        # Save the stitched image
        output_path = os.path.join(output_dir, f"{model_name}_{image_file}_pred_{crop_size}.png")
        if not cv2.imwrite(output_path, stitched_image):
            print(f"Failed to write prediction for {model_name} at {output_path}")
            continue

        print(f"Saved prediction for {model_name} at {output_path}")

print("All predictions saved.")


Segmentation Models: using `tf.keras` framework.
Processing with model: resnet101_1e-05_Unet_128_noaug
Processing image: image16.tif, size: (5000, 5000), crop size: 128
Saved prediction for resnet101_1e-05_Unet_128_noaug at ./mv_predictions\resnet101_1e-05_Unet_128_noaug_image16.tif_pred_128.png
Processing image: image17.tif, size: (5000, 5000), crop size: 128
Saved prediction for resnet101_1e-05_Unet_128_noaug at ./mv_predictions\resnet101_1e-05_Unet_128_noaug_image17.tif_pred_128.png
Processing image: image7.tif, size: (5000, 5000), crop size: 128
Saved prediction for resnet101_1e-05_Unet_128_noaug at ./mv_predictions\resnet101_1e-05_Unet_128_noaug_image7.tif_pred_128.png
Processing image: image9.tif, size: (5000, 5000), crop size: 128
Saved prediction for resnet101_1e-05_Unet_128_noaug at ./mv_predictions\resnet101_1e-05_Unet_128_noaug_image9.tif_pred_128.png
Processing with model: resnet50_1e-05_Unet_1024_aug
Processing image: image16.tif, size: (5000, 5000), crop size: 1024
Saved 