# Unsupervised Image Segmentation

In [66]:
import numpy as np
from PIL import Image
import os
from glob import glob
from tqdm import tqdm
import datetime

import matplotlib.pyplot as plt

from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.metrics import silhouette_score

from skimage.segmentation import slic
from skimage.future import graph

import torch
import torchvision

import wnet_src.wnet as wnet
import wnet_src.wnet as wnet
import wnet_src.network as network
import wnet_src.loss as loss
from wnet_src.crf import crf_fit_predict, crf_batch_fit_predict

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

## Utility functions

In [98]:
# Creates an RGB representation of a segmentation mask using the mean color of each cluster
def mask_to_rgb(mask, image):
    mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3))
    for i in range(0, mask.max().astype(int)+1):
        color = image[mask==i].mean(0)
        mask_rgb[mask==i] = color
    return np.round(mask_rgb).astype(np.uint8)

# Saves segmentation mask as RGB image and CSV
def save_segmentation(mask, image, method_name, id):
    png_path = os.path.join("output", method_name, "png", id + ".png")
    csv_path = os.path.join("output", method_name, "csv", id + ".csv")
    os.makedirs(os.path.dirname(png_path), exist_ok=True)
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    mask_rgb = mask_to_rgb(mask, image)
    mask_rgb = Image.fromarray(mask_rgb)
    mask_rgb.save(png_path)
    np.savetxt(csv_path, mask, delimiter=',')

## Locate and Load Images

In [4]:
BSDS500_dir = 'BSDS500'
gt_dir = os.path.join(BSDS500_dir, 'gt')
test_image_dir = os.path.join(BSDS500_dir, 'BSDS500', 'data', 'images', 'test')
test_image_paths = glob(os.path.join(test_image_dir, '*.jpg'))
os.makedirs("output", exist_ok=True)

class BSDS500_dataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, subset):
        self.root_dir = root_dir
        self.image_paths = glob(os.path.join(root_dir, 'BSDS500', 'data', 'images', subset, '*.jpg'))
        self.image_sizes = [np.take(torchvision.io.read_image(img_path).shape, [1,2]) for img_path in self.image_paths]
        self.image_size = (224, 224)
        
        self.max_image = None
        self.min_image = None
        for i in range(self.__len__()):
            image = torchvision.io.read_image(self.image_paths[i]).float()
            image = torchvision.transforms.functional.resize(image, self.image_size, interpolation=torchvision.transforms.functional.InterpolationMode.BILINEAR)
            self.max_image = torch.max(self.max_image, image) if self.max_image is not None else image
            self.min_image = torch.min(self.min_image, image) if self.min_image is not None else image

    def __len__(self):
        return len(self.image_paths)

    def size(self, idx=None):
        s = [self.__len__(), 3, self.image_size[0], self.image_size[1]]
        if idx is not None:
            s = s[idx]
        else:
            s = torch.Size(s)
        return s

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = torchvision.io.read_image(img_path)
        image = image.float()
        image = torchvision.transforms.functional.resize(image, self.image_size, interpolation=torchvision.transforms.functional.InterpolationMode.BILINEAR)
        image = (image - self.min_image) / (self.max_image - self.min_image)
        return image.to(device)

class BSDS500():
    def __init__(self, root_dir, batch_size):
        self.trainset = BSDS500_dataset(root_dir, 'train')
        self.trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=batch_size, shuffle=True)
        self.valset = BSDS500_dataset(root_dir, 'val')
        self.valloader = torch.utils.data.DataLoader(self.valset, batch_size=batch_size, shuffle=False)
        self.testset = BSDS500_dataset(root_dir, 'test')
        self.testloader = torch.utils.data.DataLoader(self.testset, batch_size=batch_size, shuffle=False)

    def get_trainloader(self):
        return self.trainloader
    
    def get_valloader(self):
        return self.valloader

    def get_testloader(self):
        return self.testloader

if device == "mps":
    batch_size = 16
else:
    batch_size = 8

bsds = BSDS500('../BSDS500', batch_size)
X_train = bsds.get_trainloader()
y_train = bsds.get_trainloader()
X_val = bsds.get_valloader()
y_val = bsds.get_valloader()
X_test = bsds.get_testloader()
y_test = bsds.get_testloader()

## Test Segmentation Models

### K-Means

In [12]:
# Run K-Means on HSV encoded images
for path in tqdm(test_image_paths):
    # Load image
    id = os.path.splitext(os.path.basename(path))[0]
    imageRGB = Image.open(path)
    image = np.array(imageRGB.convert("HSV"))
    pixels = image.reshape((-1, 3)) # Reshape to 2D array of pixels
    segmentors = []
    sils = []
    # Try k in range 3 to 17 with step size 2
    for k in range(3,17,2):
        segmentor = KMeans(n_clusters=k, random_state=0, init="random")
        segmentor.fit(pixels)
        segmentors.append(segmentor)
        sil = silhouette_score(pixels, segmentor.labels_, sample_size=5000) # Calculate silhouette score of clustering
        sils.append(sil)
        if len(sils) >= 3 and sils[-1] < sils[-2] and sils[-2] < sils[-3]: # Terminate search if silhouette score is decreasing
            break
    segmentor = segmentors[np.argmax(sils)] # Choose segmentor with highest silhouette score
    mask = segmentor.labels_.reshape(image.shape[:2]) # Match shape of image

    save_segmentation(mask, np.array(imageRGB), "KMeans", id)


100%|██████████| 2/2 [00:04<00:00,  2.48s/it]


### EM

In [16]:
# Run EM on HSV encoded images
for path in tqdm(test_image_paths):
    # Load image
    id = os.path.splitext(os.path.basename(path))[0]
    imageRGB = Image.open(path)
    image = np.array(imageRGB.convert("HSV"))
    pixels = image.reshape((-1, 3)) # Reshape to 2D array of pixels
    segmentors = []
    sils = []
    # Try k in range 3 to 17 with step size 2
    for k in range(3,17,2):
        segmentor = GaussianMixture(n_components=k, random_state=0)
        segmentor.fit(pixels)
        segmentors.append(segmentor)
        sil = silhouette_score(pixels, segmentor.predict(pixels), sample_size=5000) # Calculate silhouette score of clustering
        sils.append(sil)
        if len(sils) >= 3 and sils[-1] < sils[-2] and sils[-2] < sils[-3]: # Terminate search if silhouette score is decreasing
            break
    segmentor = segmentors[np.argmax(sils)] # Choose segmentor with highest silhouette score
    mask = segmentor.predict(pixels).reshape(image.shape[:2]) # Match shape of image

    save_segmentation(mask, np.array(imageRGB), "EM", id)


100%|██████████| 2/2 [00:18<00:00,  9.19s/it]


### Normalized Cuts

In [19]:
# Run normalized cuts on RGB images
for path in tqdm(test_image_paths[:2]):
    # Load image
    id = os.path.splitext(os.path.basename(path))[0]
    image = np.array(Image.open(path))

    # Run 400-means on image in RGBXY space
    labels = slic(image, compactness=30, n_segments=400)
    # Assemble region adjacency graph
    g = graph.rag_mean_color(image, labels, mode='similarity')
    # Run normalized cuts
    labels = graph.cut_normalized(labels, g)
    
    save_segmentation(labels, image, "NormalizedCuts", id)

  color = image[mask==i].mean(0)
  ret = um.true_divide(
100%|██████████| 2/2 [00:16<00:00,  8.09s/it]


### W-Net

In [86]:
# Confugre
train = False
epochs = 500
use_checkpoint = True
checkpoint_path = 'wnet-2022-12-02-02-50.pt'

net = wnet.WNet(device_type=device)
net.to(device)
if use_checkpoint:
    # Load checkpoint from previous training for 500 epochs
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net.to(device)
    print('Loaded checkpoint from {}'.format(checkpoint_path))

if train:
    net.fit(
        X_train, y_train,
        X_val, y_val,
        epochs=epochs,
        learn_rate=1e-3,
        weight_decay=1e-5
    )
    date = datetime.now().__str__()
    date = date[:16].replace(':', '-').replace(' ', '-')
    torch.save({'epoch': epochs, 'model_state_dict': net.state_dict()}, f'models/wnet-{date}.pt')

# Run WNet on test images
all_inputs = []
all_masks = []
for i,batch in tqdm(enumerate(X_test), total=len(X_test)):
    inputs = batch
    mask, outputs = net.forward(inputs)
    inputs = inputs.detach().cpu().numpy()
    outputs = outputs.detach().cpu().numpy()
    mask = mask.detach().cpu().numpy()
    # Post process encoded images with conditional random field
    crf_mask = crf_batch_fit_predict(mask, inputs)
    for j in range(inputs.shape[0]):
        all_masks.append(crf_mask[j])

# Save results
for i,mask in tqdm(enumerate(all_masks), total=len(all_masks)):
    image_path = bsds.testset.image_paths[i]
    id = os.path.splitext(os.path.basename(image_path))[0]
    image = np.array(Image.open(image_path))
    mask = torchvision.transforms.functional.resize(torch.tensor(mask), list(bsds.testset.image_sizes[i]), interpolation=torchvision.transforms.functional.InterpolationMode.NEAREST)
    mask = np.array(mask.argmax(0))
    save_segmentation(mask, image, "WNet", id)

Loaded checkpoint from wnet-2022-12-02-02-50.pt


100%|██████████| 13/13 [03:43<00:00, 17.17s/it]
