# Segmentation

This notebook walks through the segmentation process for images.

*Note that this was developed using Google Colab and does not neatly fit within the rest of the project structure as is; this notebook is meant to show the process, not necessarily a universal implementation.*

# Setup

Colab's Google Drive integration is used to access a Drive folder containing the cleaned data to be segmented. Video files are listed from that folder and compared with the files that already exist in the target folder, so that videos are not segmented twice.

In [None]:
import os
from google.colab import drive

drive.mount('/content/drive')
folder = "/content/drive/MyDrive/data/CleanedData"
target_folder = "/content/drive/MyDrive/data/SegmentedData2"
# Set to False to re-segment videos that already have an output file in target_folder.
SKIP_COMPLETED = True

os.makedirs(target_folder, exist_ok=True)


def is_video(f):
    """Return True for AVI files (the only container this notebook segments)."""
    return f.lower().endswith('.avi')


all_files = [f for f in os.listdir(folder) if '.' in f]
video_files = [f for f in all_files if is_video(f)]
print(f"Total Videos Found: {len(video_files)}")

# Outputs are written under the same file name, so a name match means "already segmented".
completed_videos = {f for f in os.listdir(target_folder) if '.' in f}
if SKIP_COMPLETED:
    video_files = [f for f in video_files if f not in completed_videos]
print(f"Total Videos to Process: {len(video_files)}")

**Function definitions**

In [None]:
import cv2 as cv
from google.colab.patches import cv2_imshow
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import slic


def mask_segmented_image(binary_image):
    """Turn a two-colour (k=2 clustered) BGR image into a black/white mask.

    The grey level that covers more pixels is treated as background and mapped to 0; the
    other grey level becomes 255. On an exact tie the brighter level is treated as
    background (``np.unique`` sorts ascending and the second entry wins ties).

    Args:
        binary_image: BGR image containing at most two distinct colours after conversion
            to greyscale.

    Returns:
        (mask, color_ratio):
            mask -- 3-channel uint8 image with values 0 or 255.
            color_ratio -- majority pixel count divided by minority pixel count (always
            >= 1.0), or ``inf`` when the image has a single grey level.

    Raises:
        ValueError: if the greyscale image contains more than two distinct values.
    """
    gray = cv.cvtColor(binary_image, cv.COLOR_BGR2GRAY)
    unique_colors, counts = np.unique(gray, return_counts=True)

    if unique_colors.size > 2:
        raise ValueError(
            f"expected at most 2 distinct grey levels, got {unique_colors.size}"
        )
    if unique_colors.size <= 1:
        # Uniform (or empty) frame: everything is background and the split is unbounded.
        empty = np.zeros_like(gray)
        return cv.cvtColor(empty, cv.COLOR_GRAY2BGR), float("inf")

    color1, color2 = unique_colors
    count1, count2 = counts
    more_common_color = color1 if count1 > count2 else color2
    result_image = np.where(gray == more_common_color, 0, 255).astype(np.uint8)
    result_image = cv.cvtColor(result_image, cv.COLOR_GRAY2BGR)
    # Symmetric ratio: how lopsided the split is, regardless of which colour is larger.
    color_ratio = max(count1 / count2, count2 / count1)
    return result_image, color_ratio


def kmeans_segmentation(image):
    """Build a single-channel 0/255 mask by clustering pixel colours into two groups.

    Steps:
      1. k-means (k=2, k-means++ initialisation, 10 attempts) over the pixels, each pixel
         treated as a row of 3 values.
      2. ``mask_segmented_image`` maps the more common cluster to 0 and the other to 255
         and reports how lopsided the split is.
      3. If that ratio exceeds 2.0 the mask is returned unchanged. Otherwise every contour
         whose area is below the largest contour area is filled with 0, and the result is
         morphologically closed (5x5 rectangle, 2 iterations).

    Args:
        image: colour image; it is reshaped to rows of 3 values (presumably H x W x 3 BGR).

    Returns:
        uint8 single-channel mask with the same height and width as ``image``.
    """
    # 1. Cluster the pixel colours.
    pixel_rows = np.float32(image.reshape((-1, 3)))
    stop_when = (cv.TERM_CRITERIA_EPS + cv.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _compactness, pixel_labels, cluster_centres = cv.kmeans(
        pixel_rows, 2, None, stop_when, 10, cv.KMEANS_PP_CENTERS
    )
    quantised = np.uint8(cluster_centres)[pixel_labels.flatten()].reshape(image.shape)

    # 2. Binarise; contour detection and morphology below need a single channel.
    mask_bgr, color_ratio = mask_segmented_image(quantised)
    mask = cv.cvtColor(mask_bgr, cv.COLOR_BGR2GRAY)

    # 3a. One class clearly dominates: skip the cleanup to avoid over-filtering.
    if color_ratio > 2.0:
        return mask

    # 3b. Fill every contour smaller than the largest one with black. Contours are drawn
    #     one at a time on purpose: a single multi-contour fill treats nested contours as holes.
    found = cv.findContours(mask, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
    contours = found[0] if len(found) == 2 else found[1]  # OpenCV 4 returns 2 values, OpenCV 3 returns 3
    areas = [cv.contourArea(contour) for contour in contours]
    largest_area = max(areas, default=0)
    for contour, area in zip(contours, areas):
        if area < largest_area:
            cv.drawContours(mask, [contour], -1, (0, 0, 0), -1)

    # 3c. Morphological closing to fill small gaps.
    closing_kernel = cv.getStructuringElement(cv.MORPH_RECT, (5, 5))
    return cv.morphologyEx(mask, cv.MORPH_CLOSE, closing_kernel, iterations=2)


def binary_slic_segmentation(image):
    """Over-segment ``image`` into approximately 300 SLIC superpixels.

    Despite the name, the result is not binary. It is scikit-image's integer label array,
    with one label value per superpixel.
    """
    return slic(image, sigma=1, compactness=10, n_segments=300)


def classical_segmentation(image, overlap_threshold=0.5):
    """Combine the k-means mask with SLIC superpixels into a superpixel-aligned mask.

    A superpixel is set to 255 when the fraction of its pixels that are non-zero in the
    k-means mask is strictly greater than ``overlap_threshold``. All other pixels are 0.

    Args:
        image: colour frame passed to both ``kmeans_segmentation`` and
            ``binary_slic_segmentation``.
        overlap_threshold: minimum foreground fraction (exclusive) for a superpixel to be
            kept. Defaults to 0.5.

    Returns:
        Mask with the same shape and dtype as the k-means mask, values 0 or 255.
    """
    kmeans_mask = kmeans_segmentation(image)
    superpixels = binary_slic_segmentation(image)

    # Map arbitrary superpixel labels to 0..K-1 so per-superpixel pixel counts can be taken
    # in one pass, instead of building a full-size mask for every superpixel (O(K*H*W)).
    # A 1-D input keeps np.unique's inverse 1-D on every NumPy version.
    _, label_index = np.unique(superpixels.reshape(-1), return_inverse=True)
    foreground = kmeans_mask.reshape(-1) != 0

    pixels_per_label = np.bincount(label_index)
    foreground_per_label = np.bincount(label_index[foreground], minlength=pixels_per_label.size)
    keep_label = foreground_per_label / pixels_per_label > overlap_threshold

    binary_mask = np.zeros_like(kmeans_mask)
    binary_mask[keep_label[label_index].reshape(superpixels.shape)] = 255
    return binary_mask


def segment_image(image, mode=0, model=None):
    """Segment one BGR frame and return a 3-channel uint8 mask (0/255) of the same size.

    Args:
        image: BGR frame, e.g. as returned by ``cv.VideoCapture.read``.
        mode: 0 selects the classical k-means + SLIC pipeline; 1 selects ``model``.
        model: required when ``mode == 1``. Its ``forward`` is called on a 128x128 RGB
            copy of the frame and is expected to return logits (the training cell uses
            ``BCEWithLogitsLoss``), squeezable to a single 2-D map.

    Returns:
        uint8 array of shape (H, W, 3) with values 0 or 255, directly writable by
        ``cv.VideoWriter``.

    Raises:
        ValueError: for an unknown ``mode``, or ``mode == 1`` without a model.
    """
    if mode == 0:
        mask = classical_segmentation(image)
        return cv.cvtColor(mask, cv.COLOR_GRAY2BGR)

    if mode == 1:
        if model is None:
            raise ValueError("mode=1 requires a model")
        import torch  # only needed for the deep-learning path

        height, width = image.shape[:2]
        small = cv.resize(image, (128, 128))
        # OpenCV frames are BGR; training images were loaded with PIL .convert("RGB").
        small = cv.cvtColor(small, cv.COLOR_BGR2RGB)
        with torch.no_grad():
            logits = model.forward(small)
        # Logits > 0 <=> sigmoid > 0.5. (The previous `!= 0` test was true almost everywhere.)
        predicted = (logits > 0).squeeze().cpu().numpy().astype(np.uint8) * 255
        # Nearest-neighbour keeps the upscaled mask strictly 0/255.
        predicted = cv.resize(predicted, (width, height), interpolation=cv.INTER_NEAREST)
        return cv.cvtColor(predicted, cv.COLOR_GRAY2BGR)

    raise ValueError(f"Invalid mode: {mode}")


def segment_video(video_path, output_path, fps=2.0, mode=0, model=None):
    """Segment every frame of a video and write the masks to an XVID-encoded video.

    Args:
        video_path: path of the input video.
        output_path: path of the output video. Frames keep the input's width and height.
        fps: frame rate recorded in the output file. Defaults to 2.0, independent of the
            input's frame rate.
        mode, model: forwarded to ``segment_image`` for every frame. The default is the
            classical pipeline.

    If the input or the output cannot be opened, an error is printed and the function
    returns early. The reader and writer are released even if segmentation raises.
    """
    cap = cv.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        cap.release()
        return

    frame_size = (
        int(cap.get(cv.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)),
    )
    fourcc = cv.VideoWriter_fourcc(*'XVID')
    out = cv.VideoWriter(output_path, fourcc, fps, frame_size)
    if not out.isOpened():
        # Report a missing codec or unwritable path instead of silently writing nothing.
        print(f"Error: Could not open output video: {output_path}")
        cap.release()
        out.release()
        return

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            out.write(segment_image(frame, mode=mode, model=model))
    finally:
        cap.release()
        out.release()


**Segment the video files**

In [None]:
# Segment every pending video. Set max_videos to a small int (e.g. 1) for a quick trial run.
max_videos = None
for f in video_files[:max_videos]:
    video_path = os.path.join(folder, f)
    output_path = os.path.join(target_folder, f)
    print(f"Processing {f}")
    print(f"Video Path: {video_path}")
    print(f"Output Path: {output_path}")
    segment_video(video_path, output_path)

# Data Collection
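A Kaggle dataset ([franciscoescobar/satellite-images-of-water-bodies](https://www.kaggle.com/datasets/franciscoescobar/satellite-images-of-water-bodies)) was used to train the model. Follow the instructions at the link in the following cell to obtain your Kaggle API token (kaggle.json), then upload it when that cell prompts for a file.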

In [None]:
# https://www.kaggle.com/discussions/general/74235
! pip install -q kaggle
from google.colab import files
files.upload()
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

Once the file is uploaded, run the following cell to download the dataset. If you use a different dataset, the image and mask directories will likely need to be changed. The dataset used here has two directories, each containing only JPG files: one contains images of rivers, and the other contains masks in which the river is white and everything else is black.

In [None]:
import kaggle
import os

# Authenticate with Kaggle API (uses the token copied to ~/.kaggle in the previous cell)
kaggle.api.authenticate()

# Download to one explicit directory and derive every path from it, so the paths below
# match the download location regardless of the current working directory.
download_dir = '/content/data'
kaggle.api.dataset_download_files(
    'franciscoescobar/satellite-images-of-water-bodies', path=download_dir, unzip=True
)

# Define the paths (the trailing '' keeps data_dir's trailing slash)
data_dir = os.path.join(download_dir, 'Water Bodies Dataset', '')
image_dir = os.path.join(data_dir, 'Images')
mask_dir = os.path.join(data_dir, 'Masks')

# Data Preprocessing
A Dataset class is used to make the training easier to adapt to different datasets. It gives the training process a uniform interface while abstracting away how the files are stored and any simple transforms applied to the data.

In [None]:
from torch.utils.data import Dataset
import os
from PIL import Image
import torchvision.transforms as transforms

# Same paths as the download cell, re-declared so this cell does not depend on that one having run.
data_dir = '/content/data/Water Bodies Dataset/'
image_dir, mask_dir = (os.path.join(data_dir, sub_dir) for sub_dir in ('Images', 'Masks'))

class RiverDataset(Dataset):
    """Paired images and masks stored as same-named ``.jpg`` files in two directories.

    Every ``*.jpg`` in ``image_dir`` is one example. Its mask is the file with the same
    base name plus ``.jpg`` in ``mask_dir``. Images are loaded as RGB and masks as
    single-channel ("L") PIL images, then passed through the optional transforms.
    """

    def __init__(self, image_dir, mask_dir, transform=None, target_transform=None):
        """
        Args:
            image_dir: directory containing the input ``.jpg`` images.
            mask_dir: directory containing the masks.
            transform: callable applied to each RGB image; skipped when None.
            target_transform: callable applied to each mask; skipped when None.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sorted so that dataset indices are stable across runs.
        self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for example ``idx`` after applying the transforms."""
        img_name = self.images[idx]
        img_base_name = os.path.splitext(img_name)[0]

        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_base_name + '.jpg')

        # `with` closes the file handles. .convert() loads the pixels into a new image,
        # so the results remain valid after the files are closed.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        return image, mask

The only transforms applied at this step are resizing every image to a uniform size (128×128, due to computational limitations) and converting it to a tensor. The same transform is applied to both the images and the masks.

In [None]:
import torch
# Define the transformations
SIZE = 128
# One pipeline shared by images and masks: resize to SIZE x SIZE, then convert to a tensor.
target_transform = transforms.Compose([transforms.Resize((SIZE, SIZE)), transforms.ToTensor()])

# Create the dataset and data loader
dataset = RiverDataset(
    image_dir,
    mask_dir,
    transform=target_transform,
    target_transform=target_transform,
)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=16)

The data was split into 80% training data and 20% testing data, each with its own DataLoader.

In [None]:
from sklearn.model_selection import train_test_split

# 80/20 split of dataset indices; random_state=42 makes the split reproducible.
train_indices, test_indices = train_test_split(
    range(len(dataset)), test_size=0.2, random_state=42
)

# Both subsets view the same underlying dataset (and its transforms); only the indices differ.
train_dataset, test_dataset = (
    torch.utils.data.Subset(dataset, subset_indices)
    for subset_indices in (train_indices, test_indices)
)

# Only the training loader shuffles; the evaluation order stays fixed.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Deep Learning Training
The following model (a SwinV2 image transformer used as an encoder, feeding a convolutional decoder) was found to be ineffective. Across the various learning rates tested, it got stuck in the local optimum of simply predicting a monochromatic mask (either all 'river' or all 'not river'). This gives approximately 70% accuracy, and the model did not improve on it at all over 10 epochs. Code is included to allow training over multiple runs of the cell by checking for uploaded model weights and using them. This also means that anyone who wants to run the training locally, on hardware with more cores or a GPU, can upload the resulting weights to the notebook and use them in the pipeline outlined here.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, Swinv2Model
from torch.nn import functional as F
import torchvision.transforms as transforms

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Decoder(nn.Module):
    """Upsample an 8x8 grid of 768-channel features to a single-channel logit map.

    Three transposed convolutions (strides 4, 4, 2, each with kernel == stride, so the
    output blocks do not overlap) take the 8x8 input to 256x256. ReLU follows the first
    two. No activation is applied at the end, so the output is raw logits.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.conv1 = nn.ConvTranspose2d(768, 512, kernel_size=4, stride=4)  # 8 -> 32
        self.act1 = nn.ReLU()
        self.conv2 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=4)  # 32 -> 128
        self.act2 = nn.ReLU()
        self.conv3 = nn.ConvTranspose2d(128, 1, kernel_size=2, stride=2)    # 128 -> 256

    def forward(self, x):
        """
        Args:
            x: either a token sequence of shape (batch, 64, 768), such as an encoder
               ``last_hidden_state`` whose 64 tokens are assumed to be the 8x8 grid in
               row-major order, or an already-spatial (batch, 768, 8, 8) tensor.

        Returns:
            Logits of shape (batch, 1, 256, 256).
        """
        if x.dim() == 3:
            # (batch, tokens, channels) -> (batch, channels, tokens) -> (batch, 768, 8, 8).
            # A plain .view(-1, 768, 8, 8) would reinterpret memory and mix channels with
            # spatial positions.
            x = x.transpose(1, 2).reshape(-1, 768, 8, 8)
        else:
            x = x.view(-1, 768, 8, 8)
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.conv3(x)
        return x

class SemanticSegmentationModel(nn.Module):
    """SwinV2-tiny encoder followed by ``Decoder``, producing single-channel logits."""

    def __init__(self):
        super(SemanticSegmentationModel, self).__init__()
        # The image processor is not an nn.Module: it runs outside autograd and is not moved
        # by model.to(device), so its outputs are moved explicitly in forward().
        self.image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        self.encoder = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        self.decoder = Decoder()

    def forward(self, x):
        """Preprocess ``x`` with the image processor, encode it, and decode to logits.

        Args:
            x: image(s) accepted by the image processor, e.g. a uint8 HxWx3 array or a
               float (batch, 3, H, W) tensor from ``transforms.ToTensor``.
        """
        # Float tensors (e.g. from transforms.ToTensor) are assumed to be already in [0, 1].
        # Rescaling them by 1/255 again would feed the encoder almost-constant input.
        already_scaled = torch.is_tensor(x) and x.is_floating_point()
        inputs = self.image_processor(images=x, return_tensors="pt", do_rescale=not already_scaled)
        # Processor outputs are created on the CPU; place them on the encoder's device.
        pixel_values = inputs["pixel_values"].to(next(self.encoder.parameters()).device)
        encoder_output = self.encoder(pixel_values=pixel_values)
        return self.decoder(encoder_output.last_hidden_state)

import matplotlib.pyplot as plt

import os
import re

# Initialize the model, criterion, and optimizer
NUM_EPOCHS = 10          # epochs trained per run of this cell
LEARNING_RATE = 0.01
CHECKPOINT_DIR = "."     # where encoder_epoch_N.pth / decoder_epoch_N.pth are read and written

model = SemanticSegmentationModel().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
    list(model.encoder.parameters()) + list(model.decoder.parameters()), lr=LEARNING_RATE
)


def _latest_checkpoint_epoch(directory):
    """Return the highest N with both encoder_epoch_N.pth and decoder_epoch_N.pth in directory, else 0."""
    epochs = []
    for name in os.listdir(directory):
        match = re.fullmatch(r"encoder_epoch_(\d+)\.pth", name)
        if match and os.path.exists(os.path.join(directory, f"decoder_epoch_{match.group(1)}.pth")):
            epochs.append(int(match.group(1)))
    return max(epochs, default=0)


# Resume from previously saved (or uploaded) weights so re-running the cell keeps training.
start_epoch = _latest_checkpoint_epoch(CHECKPOINT_DIR)
if start_epoch:
    model.encoder.load_state_dict(torch.load(
        os.path.join(CHECKPOINT_DIR, f"encoder_epoch_{start_epoch}.pth"),
        map_location=device, weights_only=True))
    model.decoder.load_state_dict(torch.load(
        os.path.join(CHECKPOINT_DIR, f"decoder_epoch_{start_epoch}.pth"),
        map_location=device, weights_only=True))
    print(f"Resumed from epoch {start_epoch} checkpoints")

model.train()
for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
    total_loss = 0.0
    for image, mask in train_dataloader:
        image = image.to(device)
        mask = mask.to(device)

        optimizer.zero_grad()
        output = model(image)
        # The decoder upsamples an 8x8 grid by 4*4*2 = 256, which need not equal the mask
        # size (SIZE); resize the logits so BCEWithLogitsLoss sees matching shapes.
        if output.shape[-2:] != mask.shape[-2:]:
            output = F.interpolate(output, size=mask.shape[-2:], mode="bilinear", align_corners=False)
        loss = criterion(output, mask)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Average loss over the batches of this epoch
    avg_loss = total_loss / len(train_dataloader)
    print(f'Epoch {epoch+1}, Loss: {avg_loss}')
    torch.save(model.encoder.state_dict(), os.path.join(CHECKPOINT_DIR, f"encoder_epoch_{epoch+1}.pth"))
    torch.save(model.decoder.state_dict(), os.path.join(CHECKPOINT_DIR, f"decoder_epoch_{epoch+1}.pth"))
    print(f"Model saved after epoch {epoch+1}")