<a href="https://colab.research.google.com/github/liangchow/zindi-amazon-secret-runway/blob/dylan/unet_starter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Read this first (10-15 minutes read)
# https://amaarora.github.io/posts/2020-09-13-unet.html

# Then, read these:
# https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/
# https://www.kaggle.com/code/quadeer15sh/how-to-perform-semantic-segmentation-using-u-net/
# https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch#Class-Distribution
# https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/

In [2]:
# Install Dependencies

!pip install rasterio geopandas numpy matplotlib



In [3]:
# Import packages

import torch
import os
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import geopandas as gpd

In [4]:
# Mount google drive

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [9]:
### CONSTANTS & CONFIGURATIONS

# Set path
path = os.path.join("/content", "drive", "MyDrive", "training")

image_path = os.path.join(path, "images")
mask_path = os.path.join(path, "masks")

# Define test split
test_split = 0.15

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
memory = True if device == "cuda" else False

# Define number of channels, number of classes, and number of level in U-Net
num_channels = 1
num_classes = 1
num_levels = 3

# Define learning ratem number of epochs, and batch size
learning_rate = 0.001
num_epochs = 100
batch_size = 16

# Define input image dimensions
input_image_height = 512
input_image_width = 512

# Define threshold
threshold = 0.5

# Define path to base output directory
base_output_dir = "output"

# Define path to output serialized model, model training plot, and test image path
output_model_path = os.path.join(base_output_dir, "unet_tgs_salt.pth")
output_plot_path = os.path.join(base_output_dir, "plot.png")

In [6]:
# Load files

BLACKLISTED_FILES = ["Sentinel_AllBands_Training_Id_127.tif","Sentinel_AllBands_Training_Id_140.tif","Sentinel_AllBands_Training_Id_142.tif","Sentinel_AllBands_Training_Id_174.tif"]

def load_tifs_as_arrays(image_path):
  image_arrays = []
  filename_map = []
  for filename in os.listdir(image_path):
      if filename.endswith(".tif") and filename not in BLACKLISTED_FILES:
          image_filepath = os.path.join(image_path, filename)
          with rasterio.open(image_filepath) as src:
            img_array = src.read()
            image_arrays.append(img_array)
            filename_map.append(filename)
  return image_arrays, filename_map

image_arrays, filename_map = load_tifs_as_arrays(image_path)
print(filename_map)


['Sentinel_AllBands_Training_Id_1.tif', 'Sentinel_AllBands_Training_Id_2.tif', 'Sentinel_AllBands_Training_Id_9.tif', 'Sentinel_AllBands_Training_Id_10.tif', 'Sentinel_AllBands_Training_Id_11.tif', 'Sentinel_AllBands_Training_Id_17.tif', 'Sentinel_AllBands_Training_Id_18.tif', 'Sentinel_AllBands_Training_Id_19.tif', 'Sentinel_AllBands_Training_Id_20.tif', 'Sentinel_AllBands_Training_Id_21.tif', 'Sentinel_AllBands_Training_Id_22.tif', 'Sentinel_AllBands_Training_Id_23.tif', 'Sentinel_AllBands_Training_Id_28.tif', 'Sentinel_AllBands_Training_Id_30.tif', 'Sentinel_AllBands_Training_Id_31.tif', 'Sentinel_AllBands_Training_Id_32.tif', 'Sentinel_AllBands_Training_Id_33.tif', 'Sentinel_AllBands_Training_Id_36.tif', 'Sentinel_AllBands_Training_Id_37.tif', 'Sentinel_AllBands_Training_Id_38.tif', 'Sentinel_AllBands_Training_Id_39.tif', 'Sentinel_AllBands_Training_Id_40.tif', 'Sentinel_AllBands_Training_Id_41.tif', 'Sentinel_AllBands_Training_Id_42.tif', 'Sentinel_AllBands_Training_Id_44.tif', 'S

In [7]:

def trim_arrays(untrimmed_image_arrays):
  trimmed_images = []
  for image in untrimmed_image_arrays:
    if image.shape == (9, 512, 512):
      trimmed_images.append(image)
      continue

    if image.shape == (9, 512, 513):
      nan_counts_front = np.sum(np.isnan(image[:, :, 0]))
      nan_counts_back = np.sum(np.isnan(image[:, :, -1]))

      if nan_counts_front > nan_counts_back:
        trimmed_array = image[:, :, 1:]
      else:
        trimmed_image = image[:, :, :-1]
      trimmed_images.append(trimmed_image)
    else:
      print("ERROR: Unexpected size detected", image.shape)
  return np.stack(trimmed_images)

trimmed_image_arrays = trim_arrays(image_arrays)
assert(np.sum(np.isnan(trimmed_image_arrays) == 0)) # Ensure there are no NaN values
print(trimmed_image_arrays.shape) # (Number of Images, Channels per Image, Width of Image, Height of Image)

(108, 9, 512, 512)


In [8]:
def normalize_image_channels(image_arrays):
  """Normalizes all channels of a list of image arrays.

  Args:
    image_arrays: A list of image arrays, where each array has shape (C, H, W).

  Returns:
    A list of normalized image arrays.
  """
  normalized_image_arrays = []
  for image in image_arrays:
    normalized_image = []
    for channel in image:
        # Normalize the channel
        channel_min = np.nanmin(channel)
        channel_max = np.nanmax(channel)
        normalized_channel = (channel - channel_min) / (channel_max - channel_min)
        normalized_image.append(normalized_channel)
    normalized_image_arrays.append(normalized_image)

  return np.stack(normalized_image_arrays)

normalized_image_arrays = normalize_image_channels(trimmed_image_arrays)
print(normalized_image_arrays.shape)

(108, 9, 512, 512)
