In [9]:
####### Inputs #######
# `model` is used further down as a key into the `model_classes` dictionary,
# so it must match one of that dictionary's keys exactly.
model = "UnetPlusPlus"
# Encoder name; it becomes part of the "Models/<architecture>_<backbone>/"
# folder from which the trained network file is loaded.
backbone = "vgg19"

In [None]:

# A __future__ import has to be the first statement of its compilation unit;
# placing it after the list assignments below is a SyntaxError.
from __future__ import absolute_import, division, print_function

# Reference lists of selectable values for the `model` and `backbone` inputs.
# The architecture names mirror the keys of `model_classes`; the backbone names
# are presumably segmentation_models_pytorch encoder identifiers - verify
# against the installed version.
model_list = ["UnetPlusPlus", "Unet", "MAnet", "Linknet", "PSPNet", "FPN"]
backbone_list = ["vgg16", "vgg19", "resnet50", "resnet101", "resnet152", "mobilenet_v2", "efficientnet-b4"]

from IPython.display import clear_output
import time
from datetime import timedelta
from sklearn.model_selection import KFold
import re
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
import os
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import cv2
import numpy as np
import numpy.ma as ma
from numpy import ndarray
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.transforms.functional import pad

import random
import csv

from torchinfo import summary
import segmentation_models_pytorch as smp


# Takes in a .png and returns a 2D numpy array of 0's and 1's

def image2nparray(image_file):
    """Load an image file as a 2D binary mask.

    The file is read as single-channel grayscale; every non-zero pixel maps
    to 1 and every zero pixel maps to 0.

    Args:
        image_file: Path of the image to read.

    Returns:
        A 2D integer numpy array of 0's and 1's with the image's height and
        width. An all-black image yields an all-zero 2D array.

    Raises:
        OSError: If OpenCV could not read the file (missing or undecodable).
    """
    image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image file: {image_file}")

    # A plain comparison is used rather than numpy.ma.make_mask: make_mask
    # shrinks a mask without any True entry to the scalar `nomask`, which
    # turned an all-black image into a 0-d array instead of a 2D one.
    return (image != 0).astype(int)

# Takes a numpy array prediction from the AI and returns png

def nparray2image(nparray, filename, directory):
    """Write a prediction array to disk as a PNG and return it re-read.

    Args:
        nparray: Prediction array; it is multiplied by 255 before writing.
        filename: Source file name. Everything before its first '.' is used
            as the stem of the output name.
        directory: Folder in which "<stem>_pred.png" is written.

    Returns:
        The written file as loaded back by cv2.imread.
    """
    stem = filename.partition('.')[0]
    path = os.path.join(directory, stem + "_pred.png")

    cv2.imwrite(path, nparray * 255)
    return cv2.imread(path)

# Concatenates and saves two images:
    # 1. Takes prediction image and overlays it on the raw image
    # 2. Takes manually labeled image and overlays it on the raw image

def pred2comp(pred, raw, labeled, filename, directory):
    """Build and save a side-by-side comparison image.

    The result places, from left to right: the raw image, the raw image
    blended with the manual label, and the raw image blended with the
    prediction, separated by 5-pixel-wide white strips.

    Args:
        pred: Prediction image to overlay on `raw`.
        raw: Raw image.
        labeled: Manually labeled image to overlay on `raw`.
        filename: Source file name. Everything before its first '.' is used
            as the stem of the output name.
        directory: Folder in which "<stem>_comp.png" is written.

    Returns:
        The combined image array that was written to disk.
    """
    def _blend_onto_raw(layer):
        # Both inputs are weighted 0.8 with no added offset.
        return cv2.addWeighted(raw, 0.8, layer, 0.8, 0.0)

    pred_overlay = _blend_onto_raw(pred)
    orig_overlay = _blend_onto_raw(labeled)

    # White separator strip with the same height as the raw image.
    spacer = np.full((np.shape(raw)[0], 5, 3), 255, dtype=np.uint8)
    combined = np.hstack((raw, spacer, orig_overlay, spacer, pred_overlay))

    path = os.path.join(directory, filename.partition('.')[0] + "_comp.png")
    cv2.imwrite(path, combined)

    return combined


# Creates the training and testing data sets for our three different methods which are 'Control', 'Random', and 'Triple'
# Method to run is selected in the next cell 

class TrainingDataset(Dataset):
    """Dataset pairing raw images with same-named label images.

    The augmentation scheme is selected by the module-level ``model_to_run``
    string, which is read on every call and therefore must be defined before
    the dataset is used:

    * ``'Control'`` - every image is returned without augmentation.
    * ``'Random'``  - each access picks, with equal probability, one of:
      original, colour jitter, horizontal flip, or jitter plus flip.
    * ``'Triple'``  - the dataset is three times as long as the folder:
      the first third of the indices returns originals, the second third
      colour-jittered copies and the remainder horizontally flipped copies.

    Every item is a dict with these entries:

    * ``'filename'``     - file name inside the raw folder.
    * ``'categoryname'`` - (Triple/Random only) file name tagged with the
      augmentation applied (``_coljit``, ``_hflip`` or ``_both``).
    * ``'rawimg'``       - augmented raw image tensor, not padded.
    * ``'raw'``          - the same tensor zero-padded so height and width
      are multiples of 32.
    * ``'labelimg'``     - label image tensor (flipped when the raw image
      is flipped), not padded.
    * ``'labeled'``      - binary label mask as a numpy array with a leading
      axis of size 1, zero-padded like ``'raw'``.
    * ``'pad'``          - dict with the padding applied on each side and the
      original (height, width).

    NOTE(review): images are read with cv2.imread, which yields BGR channel
    order, whereas preprocess_image loads RGB through PIL - confirm that this
    difference between training and inference input is intended.
    """

    def __init__(self, raw_folder, label_folder):
        """Index both folders and create the transforms.

        Args:
            raw_folder: Folder containing the raw images.
            label_folder: Folder containing the label images; each label is
                looked up under the same file name as its raw image.
        """
        self.raw_images_list = os.listdir(raw_folder)
        self.raw_images_dir = raw_folder
        self.labeled_images = os.listdir(label_folder)
        self.labeled_images_dir = label_folder

        # All transforms are created regardless of the current mode, so the
        # dataset stays usable if `model_to_run` changes after construction.
        self.to_tensor = transforms.ToTensor()
        self.jitter = transforms.ColorJitter(brightness = 0.25, contrast = 0.4)
        # p=1.0 makes the flip deterministic; it is only invoked on purpose.
        self.flip = transforms.RandomHorizontalFlip(p=1.0)
        self.combined = transforms.Compose([
            transforms.ColorJitter(brightness = 0.25, contrast = 0.4),
            transforms.RandomHorizontalFlip(p=1.0)])

    def __len__(self):
        """Return the number of items (three per image in 'Triple' mode)."""
        image_count = len(self.raw_images_list)
        if model_to_run == 'Triple':
            return image_count * 3
        return image_count

    def _select_augmentation(self, index):
        """Resolve `index` to (file index, augmentation tag) for the mode.

        The tag is None (no augmentation), 'coljit', 'hflip' or 'both'.

        Raises:
            ValueError: If `model_to_run` is not a supported mode.
        """
        if model_to_run == 'Triple':
            true_length = len(self.raw_images_list)
            category = index // true_length
            file_index = index % true_length
            if category == 0:
                return file_index, None
            if category == 1:
                return file_index, 'coljit'
            return file_index, 'hflip'

        if model_to_run == 'Random':
            random_number = random.random()
            if random_number < 0.25:
                return index, None
            if random_number < 0.50:
                return index, 'coljit'
            if random_number < 0.75:
                return index, 'hflip'
            return index, 'both'

        if model_to_run == 'Control':
            return index, None

        raise ValueError(
            f"Unsupported model_to_run {model_to_run!r}; "
            "expected 'Control', 'Random' or 'Triple'")

    @staticmethod
    def _pad_to_multiple(item_data, multiple=32):
        """Zero-pad 'raw' and 'labeled' in place to multiples of `multiple`.

        The padding is split between both sides of each axis (the extra
        pixel, if any, goes to the bottom/right) and recorded under 'pad'
        so that it can be cropped away again later.
        """
        # A flipped mask arrives as a torch tensor; work on a numpy view.
        labeled = np.asarray(item_data['labeled'])
        h, w = labeled.shape  # the mask is still 2D at this point

        pad_h = (multiple - (h % multiple)) % multiple
        pad_w = (multiple - (w % multiple)) % multiple

        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        item_data['pad'] = {
            'top': pad_top,
            'bottom': pad_bottom,
            'left': pad_left,
            'right': pad_right,
            'orig_hw': (h, w)
        }

        labeled = np.pad(labeled,
                         ((pad_top, pad_bottom), (pad_left, pad_right)),
                         mode='constant')
        item_data['labeled'] = np.expand_dims(labeled, axis=0)

        # torchvision's pad takes (left, top, right, bottom).
        item_data['raw'] = pad(item_data['raw'],
                               (pad_left, pad_top, pad_right, pad_bottom),
                               padding_mode='constant')

    def __getitem__(self, index):
        """Load, augment and pad the item at `index`.

        Raises:
            ValueError: If `model_to_run` is not a supported mode.
        """
        mode = model_to_run
        index, augment = self._select_augmentation(index)

        orig_filename = self.raw_images_list[index]
        item_data = {'filename': orig_filename}
        if mode != 'Control':
            if augment is None:
                item_data['categoryname'] = orig_filename
            else:
                stem = orig_filename.split('.')[0]
                item_data['categoryname'] = stem + "_" + augment + ".png"

        raw_path = os.path.join(self.raw_images_dir, orig_filename)
        label_path = os.path.join(self.labeled_images_dir, orig_filename)

        raw_image_data = self.to_tensor(cv2.imread(raw_path))
        label_image_data = self.to_tensor(cv2.imread(label_path))
        label_image_mask = image2nparray(label_path)

        # Colour jitter touches the raw image only; a flip must be mirrored
        # on both label representations to keep them aligned with it.
        if augment == 'coljit':
            raw_image_data = self.jitter(raw_image_data)
        elif augment == 'hflip':
            raw_image_data = self.flip(raw_image_data)
        elif augment == 'both':
            raw_image_data = self.combined(raw_image_data)

        if augment in ('hflip', 'both'):
            label_image_data = self.flip(label_image_data)
            label_image_mask = torch.squeeze(
                self.flip(self.to_tensor(label_image_mask)))

        item_data['rawimg'] = raw_image_data
        item_data['raw'] = raw_image_data
        item_data['labelimg'] = label_image_data
        item_data['labeled'] = label_image_mask

        # Applied in every mode so that all training items get the same
        # multiple-of-32 padding that preprocess_image applies at inference,
        # and so that every mode yields the same keys and array shapes.
        self._pad_to_multiple(item_data)

        return item_data
    
def reset_all_weights(model: nn.Module) -> None:
    """Re-initialise every submodule of `model` that supports it, in place.

    Each module (including `model` itself) that exposes a callable
    ``reset_parameters`` has it invoked with gradient tracking disabled;
    modules without one are left untouched.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    def _reinitialise(module: nn.Module) -> None:
        if not callable(getattr(module, "reset_parameters", None)):
            return
        with torch.no_grad():
            module.reset_parameters()

    # nn.Module.apply visits every submodule recursively, then the module itself.
    model.apply(fn=_reinitialise)



# Function to load the model
def load_model(model_path, device):
    """Load a fully pickled model from disk and switch it to eval mode.

    Args:
        model_path: Path of the file to load. `.eval()` is called on the
            loaded object, so it must be a whole module, not a state dict.
        device: Device onto which the stored tensors are mapped.

    Returns:
        The loaded model in evaluation mode.

    WARNING: unpickling a whole model can execute arbitrary code. Only load
    checkpoint files that come from a trusted source.
    """
    try:
        # PyTorch 2.6 changed the default to weights_only=True, which refuses
        # to unpickle a complete module; request full unpickling explicitly.
        net = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch versions do not know the `weights_only` keyword.
        net = torch.load(model_path, map_location=device)
    net.eval()
    return net

# Function to preprocess the image
def preprocess_image(image_path, device):
    """Read an image and turn it into a padded, batched model input.

    The image is opened as RGB, converted to a tensor, zero-padded so that
    height and width are multiples of 32 (padding split between both sides,
    the extra pixel going to the bottom/right), given a batch dimension and
    moved to `device`.

    Args:
        image_path: Path of the image file.
        device: Device the resulting tensor is moved to.

    Returns:
        A tuple (tensor, pad_info): the (1, C, H', W') input tensor and a
        dict with the keys 'top', 'bottom', 'left', 'right' (padding per
        side) and 'orig_hw' (original height and width).
    """
    tensor = transforms.ToTensor()(Image.open(image_path).convert("RGB"))
    _, H, W = tensor.shape

    # Smallest non-negative amount that rounds each side up to a multiple of 32.
    pad_h = -H % 32
    pad_w = -W % 32

    pad_top, pad_left = pad_h // 2, pad_w // 2
    pad_bottom, pad_right = pad_h - pad_top, pad_w - pad_left

    pad_info = dict(
        top=int(pad_top),
        bottom=int(pad_bottom),
        left=int(pad_left),
        right=int(pad_right),
        orig_hw=(int(H), int(W)),
    )

    if pad_h or pad_w:
        # F.pad lists the last dimension first: (left, right, top, bottom).
        tensor = F.pad(tensor, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)

    batch = tensor.unsqueeze(0).to(device)  # (1,C,H',W')
    return batch, pad_info

# Function to postprocess the output and save the segmentation result
def postprocess_and_save(output, save_path, pad_info=None):
    """Binarise a model output, strip the padding and save it as an image.

    Args:
        output: Model output tensor; a sigmoid is applied and values above
            0.5 become foreground.
        save_path: Destination path of the mask image (0 / 255 pixel values).
        pad_info: Optional dict with 'top', 'bottom', 'left' and 'right'
            padding sizes to crop away; missing keys count as 0. When None,
            no cropping is done.
    """
    probabilities = output.sigmoid().cpu().numpy().squeeze()
    mask = (probabilities > 0.5).astype(np.uint8)

    if pad_info is not None:
        top, bottom, left, right = (
            int(pad_info.get(side, 0))
            for side in ('top', 'bottom', 'left', 'right')
        )
        height, width = mask.shape
        # Only a positive bottom/right padding shortens the kept region.
        row_stop = height - max(bottom, 0)
        col_stop = width - max(right, 0)
        mask = mask[top:row_stop, left:col_stop]

    Image.fromarray(mask * 255).save(save_path)

# Main function to segment images
def segment_images(model_path, input_folder, output_folder, device):
    """Segment every image of a folder and save the predicted masks.

    Each .png/.jpg/.jpeg file in `input_folder` (extension matched without
    regard to case) is run through the model, and its mask is written to
    `output_folder` under the same file name. Files are processed in sorted
    name order.

    Args:
        model_path: Path of the saved model to load.
        input_folder: Folder containing the images to segment.
        output_folder: Folder receiving the masks; created if missing.
        device: Device used for the model and the input tensors.
    """
    os.makedirs(output_folder, exist_ok=True)
    net = load_model(model_path, device)

    # Lower-casing the name keeps files such as "IMG_001.PNG" from being
    # skipped; sorting makes the otherwise arbitrary listdir order stable.
    image_files = sorted(
        f for f in os.listdir(input_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    total_images = len(image_files)
    print(f"Total test images: {total_images}")

    # Inference only: no gradients are needed for any of the images.
    with torch.no_grad():
        for image_name in image_files:
            image_path = os.path.join(input_folder, image_name)
            save_path = os.path.join(output_folder, image_name)
            image, pad_info = preprocess_image(image_path, device)

            output = net(image)
            postprocess_and_save(output, save_path, pad_info)

    print("Segmentation finished")


# Lookup from architecture name to its segmentation_models_pytorch class.
model_classes = dict(
    FPN=smp.FPN,
    Unet=smp.Unet,
    MAnet=smp.MAnet,
    Linknet=smp.Linknet,
    PSPNet=smp.PSPNet,
    UnetPlusPlus=smp.UnetPlusPlus,
)

NNModel = model_classes[model]  # class selected by the `model` input
nn_name = NNModel.__name__  # class name as a str, used in the checkpoint folder name
print("Training", nn_name, "with backbone", backbone)
# Mode string read by TrainingDataset.
model_to_run = 'Control'
base_model_path = f"Models/{nn_name.lower()}_{backbone}/"
model_path = f"{base_model_path}net.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Root folder that is searched for image folders named "Raw".
main_path = "PredictionDataset/"

# Every directory below main_path whose own name is exactly "Raw".
raw_folders = [
    root
    for root, _dirs, _files in os.walk(main_path)
    if os.path.basename(root) == "Raw"
]

# Each "Raw" folder gets a sibling "Segmentation/" folder for its predictions.
pred_paths = [os.path.join(os.path.dirname(folder), "Segmentation/") for folder in raw_folders]

for i, (input_folder, output_folder) in enumerate(zip(raw_folders, pred_paths)):
    # Make sure the destination exists before segmenting into it.
    os.makedirs(output_folder, exist_ok=True)
    segment_images(model_path, input_folder, output_folder, device)
    print(f"Processing folder {i + 1} of {len(raw_folders)}")

