<a href="https://colab.research.google.com/github/liangchow/zindi-amazon-secret-runway/blob/dylan/unet_starter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Read this first (a 10-15 minute read):
# https://amaarora.github.io/posts/2020-09-13-unet.html

# Then, read these:
# https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/
# https://www.kaggle.com/code/quadeer15sh/how-to-perform-semantic-segmentation-using-u-net/
# https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch#Class-Distribution
# https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/

In [1]:
!pip install rasterio geopandas numpy matplotlib



In [104]:
# Import packages
import torch
import os
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd

In [105]:
# Data locations on the mounted Google Drive (Colab mount point).
path = os.path.join("/content", "drive", "MyDrive", "training")

image_path = os.path.join(path, "images")
mask_path = os.path.join(path, "masks")

# Fraction of samples held out for testing
test_split = 0.15

# Seed the NumPy and torch RNGs so later random steps (e.g. a train/test
# split or weight initialization) are reproducible across runs.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
# The comparison already yields a bool. Presumably this is passed as
# DataLoader(pin_memory=...) later — TODO confirm.
memory = device == "cuda"

In [106]:
def load_tifs_as_arrays(image_path, extensions=(".tif",)):
  """Read every matching raster in a directory into a list of NumPy arrays.

  Files are visited in sorted filename order. ``os.listdir`` returns entries
  in arbitrary order, so sorting makes the result reproducible. It also lets a
  parallel call on another directory (presumably the masks) line up
  index-for-index with the images, assuming the files share names — TODO confirm.

  Args:
    image_path: Directory to scan (not recursive).
    extensions: Filename suffix or tuple of suffixes to load. Matching is
      case-sensitive. Defaults to (".tif",).

  Returns:
    A list with one array per matching file, in sorted filename order. Each
    array is what rasterio's ``read()`` returns with no band indexes, i.e.
    shaped (bands, rows, cols).
  """
  # Accept a bare string too; tuple(".tif") would split it into characters.
  suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
  image_arrays = []
  for filename in sorted(os.listdir(image_path)):
    if not filename.endswith(suffixes):
      continue
    image_filepath = os.path.join(image_path, filename)
    with rasterio.open(image_filepath) as src:
      image_arrays.append(src.read())
  return image_arrays

# Load every training image from image_path (sorted by filename).
image_arrays = load_tifs_as_arrays(image_path)
# Fail fast here rather than with an obscure "need at least one array" error from np.stack later.
if not image_arrays:
  raise FileNotFoundError(f"No .tif files found in {image_path}")


In [None]:
def normalize_image_channels(image_arrays):
  """Min-max scale every channel of every image to the range [0, 1].

  Scaling is per image and per channel: each channel uses its own NaN-aware
  minimum and maximum, so values are not comparable across images. NaN pixels
  and channels with no spread (constant or all-NaN) come out as 0.0.

  Args:
    image_arrays: An iterable of image arrays, where each array has shape
      (C, H, W).

  Returns:
    A list of normalized image arrays, one per input, with the same shapes.
  """
  normalized_image_arrays = []
  for image in image_arrays:
    image = np.asarray(image)
    # Reduce over every axis except the first so each channel gets its own range.
    per_channel_axes = tuple(range(1, image.ndim))
    channel_min = np.nanmin(image, axis=per_channel_axes, keepdims=True)
    channel_range = np.nanmax(image, axis=per_channel_axes, keepdims=True) - channel_min
    # A zero range would give 0/0 = NaN plus a RuntimeWarning. Dividing by 1
    # instead maps constant channels straight to 0, matching the NaN -> 0 fill.
    safe_range = np.where(channel_range > 0, channel_range, np.ones_like(channel_range))
    normalized_image = (image - channel_min) / safe_range
    normalized_image_arrays.append(np.nan_to_num(normalized_image, nan=0.0))

  return normalized_image_arrays

# Scale each loaded image's channels to [0, 1] independently.
normalized_image_arrays = normalize_image_channels(image_arrays=image_arrays)

In [103]:
# Some rasters are one pixel wider than the rest (e.g. (9, 512, 513)). Crop
# every image to a common height/width so they stack into one (N, C, H, W) array.
# NOTE(review): the matching masks presumably need the same crop — confirm.
crop_height, crop_width = 512, 512

trimmed_normalized_image_arrays = []
for image_array in normalized_image_arrays:
  height, width = image_array.shape[-2:]
  if height < crop_height or width < crop_width:
    raise ValueError(
        f"image of shape {image_array.shape} is smaller than "
        f"({crop_height}, {crop_width}) and cannot be cropped"
    )
  trimmed_normalized_image_arrays.append(image_array[..., :crop_height, :crop_width])

training_stack = np.stack(trimmed_normalized_image_arrays)
print(training_stack.shape)

(112, 9, 512, 512)


In [None]:
# Define number of input channels, number of classes, and number of levels in U-Net.
# Input channels must equal the band count of the stacked training data
# (training_stack is (N, C, H, W)); a hard-coded 1 did not match the 9 bands.
num_channels = training_stack.shape[1]
# Presumably a single-channel (binary) mask output — confirm against the loss used.
num_classes = 1
num_levels = 3

# Define learning rate, number of epochs, and batch size
learning_rate = 0.001
num_epochs = 100
batch_size = 16

# Define input image dimensions
input_image_height = 512
input_image_width = 512
if training_stack.shape[-2:] != (input_image_height, input_image_width):
  raise ValueError(
      f"training_stack spatial size {training_stack.shape[-2:]} does not match "
      f"({input_image_height}, {input_image_width})"
  )

# Define threshold (presumably the probability cutoff for binarizing predictions)
threshold = 0.5

# Define base output directory (relative to the working directory) and create
# it now so later saves into it do not fail on a missing directory.
# NOTE(review): on Colab this is not persisted; consider a Drive path.
base_output_dir = "output"
os.makedirs(base_output_dir, exist_ok=True)

# Define paths to the serialized model and the model training plot
# NOTE(review): "unet_tgs_salt.pth" looks like a leftover name from the TGS salt tutorial.
output_model_path = os.path.join(base_output_dir, "unet_tgs_salt.pth")
output_plot_path = os.path.join(base_output_dir, "plot.png")