# Segmentierung der Marsoberfläche mit Hilfe von unüberwachtem tiefem Clustering

## Eingabeparameter

In [None]:
# Path to a single image file (op_mode "tile"/"single") or to a directory
# of images (op_mode "batch"); the combination is validated further below.
input_path = "images/p03.png"
# Operating mode: one directory of images, one image cut into tiles, or one
# image processed as a whole.
op_mode = "tile" # {batch,tile,single}
# Tile size in pixels; element 0 is applied to image axis 0 (rows) and
# element 1 to image axis 1 (columns). Only used in "tile" mode.
tile_dims = (640, 625)
# Upper bound on the number of tiles/images that are processed; None is
# replaced by infinity (no limit) during the parameter check.
max_tiles = 5
# When True every image is converted to a single gray channel first.
grayscale = True
# Pre-segmentation method that provides the `cluster` function.
clustering = "mr" # {tsugf,slic,lm,s,mr}
# Network selector.
# NOTE(review): the visible code always instantiates MNet and never reads
# this value - confirm whether it is still needed.
network = "mnet" # {mnet}
# When True, losses and preview images are logged through SummaryWriter.
tensorboard = True

## Imports und Optionen

In [None]:
import os
from time import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import torch
from IPython.display import display
from joblib import Parallel, delayed
from scipy.stats import mode
from skimage.color import label2rgb
from skimage.segmentation import felzenszwalb, slic
from sklearn.cluster import KMeans
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

# Timestamp of the run; reused below as a unique component of the
# TensorBoard log directory and of the output file name.
starttime = time()
# IPython magic: draw matplotlib figures inline in the notebook.
%matplotlib inline
# Default size of matplotlib figures.
plt.rcParams['figure.figsize'] = [5, 5]

## Parameter prüfen

In [None]:
def _read_rgb(path):
    """Read an image with OpenCV and return it in RGB channel order.

    Returns None when OpenCV cannot read or decode the file (cv2.imread
    signals that by returning None instead of raising).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        return None
    # OpenCV delivers BGR; reversing the last axis yields RGB.
    return bgr[:, :, ::-1]


# The operating mode has to match the kind of path that was given.
if os.path.isdir(input_path) and op_mode != "batch":
    raise ValueError(
        "input_path '{}' is a directory, which requires op_mode 'batch' "
        "(got '{}')".format(input_path, op_mode))
if os.path.isfile(input_path) and op_mode not in ["tile", "single"]:
    raise ValueError(
        "input_path '{}' is a file, which requires op_mode 'tile' or "
        "'single' (got '{}')".format(input_path, op_mode))
if max_tiles is None:
    max_tiles = np.inf


filelist = []
if op_mode == "single":
    source = _read_rgb(input_path)
    if source is None:
        raise ValueError("Could not read an image from '{}'".format(input_path))
    images = [source]
elif op_mode == "tile":
    source = _read_rgb(input_path)
    if source is None:
        raise ValueError("Could not read an image from '{}'".format(input_path))
    if source.shape[0] % tile_dims[0] != 0 or source.shape[1] % tile_dims[1] != 0:
        raise ValueError("Image size has to be a divisible by the tile size!")
    # Cut the image into non-overlapping tiles, row by row.
    images = [
        source[y:y + tile_dims[0], x:x + tile_dims[1]]
        for y in range(0, source.shape[0], tile_dims[0])
        for x in range(0, source.shape[1], tile_dims[1])
    ]
    # Compare by value (not identity) so any finite limit is honoured and
    # any representation of infinity means "no limit".
    if np.isfinite(max_tiles):
        images = images[:int(max_tiles)]
elif op_mode == "batch":
    source = []
    for file in os.listdir(input_path):
        image = _read_rgb(os.path.join(input_path, file))
        if image is None:
            # Not a readable image (e.g. a sub-directory or another file
            # type): skip it, as before, but without hiding other errors.
            continue
        source.append(image)
        filelist.append(file)
    # The following cells work on `images`; it was never assigned in batch
    # mode. A separate list keeps `source` untouched when the
    # pre-processing replaces entries of `images`.
    images = list(source)
else:
    raise ValueError(
        "Unknown op_mode '{}'; expected 'batch', 'tile' or 'single'".format(op_mode))

## Pre-Processing

In [None]:
if grayscale:
    # Replace every image in place by its gray version; the trailing axis
    # is re-added so the arrays stay three-dimensional (H, W, 1).
    for position, rgb_image in enumerate(images):
        gray_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
        images[position] = gray_image[:, :, np.newaxis]

## Clustering-Algorithmen

### Gabor Filters

In [None]:
def applygaborfilter(A, g, mr=False, n_clusters=100):
  """Cluster the pixels of an image by their filter-bank responses.

  Every filter of the bank is enlarged, applied to the gray image and the
  response is smoothed. The responses are stacked with the colour
  information and the pixel coordinates, standardised per feature and
  clustered with k-means.

  Args:
    A: Image array of shape (rows, cols, channels) with 1 or 3 channels.
      Three channels are converted with COLOR_RGB2GRAY, i.e. treated as RGB.
    g: Filter bank; g[:, :, i] is used as the i-th convolution kernel.
    mr: If True, keep only 6 of the leading filter channels (per size
      factor) plus the last two channels of the bank.
    n_clusters: Number of k-means clusters (default 100, as before).

  Returns:
    One-dimensional array with one cluster label per pixel, in the order
    produced by reshaping the (rows, cols) grid to rows * cols.

  Raises:
    ValueError: If A does not have 1 or 3 channels.
  """
  n_channels = A.shape[2]
  if n_channels not in (1, 3):
    # Previously this case failed later with an unbound local variable.
    raise ValueError(
      "applygaborfilter expects an image with 1 or 3 channels, got {}".format(n_channels))

  if n_channels == 3:
    Agray = cv2.cvtColor(A, cv2.COLOR_RGB2GRAY)
  else:
    Agray = A

  numRows, numCols = Agray.shape[0], Agray.shape[1]
  sizeFactors = np.array([4])
  n_scales = len(sizeFactors)
  n_filters = g.shape[2]

  gabormag = np.empty((numRows, numCols, n_filters * n_scales))
  sigmas = np.empty(n_filters * n_scales)

  for i in range(n_filters):
    kernel = g[:, :, i]
    for s, factor in enumerate(sizeFactors):
      channel = i * n_scales + s
      enlarged = cv2.resize(
        kernel,
        (int(kernel.shape[0] * factor), int(kernel.shape[1] * factor)),
        interpolation=cv2.INTER_LANCZOS4)
      # NOTE(review): ddepth=-1 keeps the depth of the input image; if A is
      # an 8-bit image the responses are saturated to that range - confirm
      # that this is intended.
      gabormag[:, :, channel] = cv2.filter2D(
        Agray, -1, enlarged, borderType=cv2.BORDER_REPLICATE)
      sigmas[channel] = np.sqrt(2) * g.shape[0] * factor / 49

  # Smooth every response map with a Gaussian of three times its sigma.
  for channel in range(gabormag.shape[2]):
    gabormag[:, :, channel] = cv2.GaussianBlur(
      gabormag[:, :, channel], (0, 0), 3 * sigmas[channel])

  featureSet = gabormag
  if mr is True:
    # Rank all channels except the last two by their summed response.
    # np.argsort is ascending, so the channels with the smallest sums are
    # kept. NOTE(review): a maximum-response selection would need the
    # reverse order - confirm the intent.
    order = np.argsort(featureSet[:, :, :-2].sum(axis=(0, 1)))
    selected = featureSet[:, :, order[:6 * n_scales]]
    featureSet = np.concatenate((selected, featureSet[:, :, -2:]), axis=2)

  # Colour features: the raw channels for 3-channel input, a blurred copy
  # of the single channel otherwise.
  if n_channels == 3:
    color = A[:, :, :3]
  else:
    color = cv2.GaussianBlur(A[:, :, 0], (0, 0), 15)[:, :, np.newaxis]

  # 1-based pixel coordinates as the last two features.
  X, Y = np.meshgrid(np.arange(1, numCols + 1), np.arange(1, numRows + 1))
  featureSet = np.concatenate(
    (featureSet, color, X[:, :, np.newaxis], Y[:, :, np.newaxis]), axis=2)

  # Standardise each feature and drop features that became NaN/inf
  # (e.g. constant features with zero standard deviation).
  features = featureSet.reshape(numRows * numCols, -1)
  features = features - features.mean(axis=0)
  features = features / features.std(axis=0, ddof=1)
  features = features[:, ~np.isnan(features).any(axis=0)]
  features = features[:, ~np.isinf(features).any(axis=0)]

  # The coordinates get weight 0, the colour features weight 0.75.
  features[:, -2:] = 0.0
  # The colour columns sit directly in front of the two coordinate columns.
  # The former slice -5:-3 covered only two of the three RGB columns and
  # left the third one unweighted.
  n_color = 3 if n_channels == 3 else 1
  features[:, -2 - n_color:-2] *= 0.75

  # `n_jobs` was removed from KMeans in scikit-learn 1.0; the previous
  # value 1 matches the single-process default, so it is simply dropped.
  return KMeans(n_clusters=n_clusters, n_init=3, max_iter=50).fit(features).labels_

# All filter banks are stored in one MATLAB file.
mat = sio.loadmat("./filterbanks/filterbanks.mat")

# Clustering method -> (key of the filter bank in `mat`, value of `mr`).
_filterbank_settings = {
  "tsugf": ("TSUGFfilters", False),
  "lm": ("LMfilters", False),
  "s": ("Sfilters", False),
  "mr": ("RFSfilters", True),
}

if clustering in _filterbank_settings:
  _bank_key, _use_mr = _filterbank_settings[clustering]

  def cluster(image):
    """Cluster `image` with the filter bank selected by `clustering`."""
    # The bank is looked up on every call, as before.
    return applygaborfilter(image, mat[_bank_key], mr=_use_mr)
  

### SLIC Superpixels

In [None]:
if clustering == "slic":
  def cluster(image):
    """Segment `image` into roughly 100 SLIC superpixels.

    The compactness depends on the channel count of the image that is
    passed in (0.1 for one channel, 15 for three channels). It was read
    from the global `images[0]` before, which tied the function to the
    notebook state instead of its own argument.

    Raises:
      ValueError: If `image` has neither 1 nor 3 channels (the function
        used to return None silently in that case).
    """
    n_channels = image.shape[2]
    if n_channels == 1:
      return slic(image, n_segments=100, compactness=0.1, enforce_connectivity=True)
    if n_channels == 3:
      return slic(image, n_segments=100, compactness=15, enforce_connectivity=True)
    raise ValueError(
      "SLIC clustering expects an image with 1 or 3 channels, got {}".format(n_channels))

In [None]:
# Names of all clustering methods that define a `cluster` function above.
_known_clusterings = ("tsugf", "slic", "mr", "lm", "s")
if clustering not in _known_clusterings:
    _message = "Clusteringmethode " + clustering + " ist nicht implementiert!"
    raise NotImplementedError(_message)

### Clustermatrix konvertieren

In [None]:
def gen_cells(tags):
  """Group pixel positions by cluster label.

  Args:
    tags: Array of cluster labels, either already flat or as a label map
      with several dimensions.

  Returns:
    List with one integer array per distinct label (labels in ascending
    order); each array holds the flat (row-major) positions of the pixels
    carrying that label.
  """
  # Flatten first: the cells are used to index flat per-pixel arrays. For a
  # two-dimensional label map the former np.where(...)[0] returned row
  # indices only.
  flat_tags = np.asarray(tags).ravel()
  return [np.flatnonzero(flat_tags == label) for label in np.unique(flat_tags)]

## Neuronale Netze

### MNet

#### Architektur

In [None]:
class MNet(nn.Module):
  """Small fully convolutional network that scores every pixel.

  Four 5x5 convolution stages (convolution, ReLU, batch normalisation) are
  followed by a 1x1 convolution with batch normalisation. All convolutions
  use stride 1 with padding that keeps the spatial size.

  Args:
    input_dim: Number of channels of the input tensor.
    feature_dim: Number of channels of every layer and of the output.
  """

  def __init__(self, input_dim, feature_dim):
    super().__init__()
    self.input_dim = input_dim
    self.feature_dim = feature_dim
    self.fc_dim = feature_dim

    def conv5x5(in_channels):
      # 5x5 convolution that preserves the spatial size.
      return nn.Conv2d(in_channels, self.feature_dim, kernel_size=5, stride=1, padding=2)

    def batch_norm():
      return nn.BatchNorm2d(self.feature_dim)

    # The attribute order below is the registration order of the
    # sub-modules and therefore deliberately left as it was.
    self.conv1 = conv5x5(self.input_dim)
    self.act1 = nn.ReLU()
    self.bn1 = batch_norm()
    self.pool1 = nn.MaxPool2d(2)  # constructed, but not applied in forward()
    self.conv2a = conv5x5(self.feature_dim)
    self.act2a = nn.ReLU()
    self.bn2a = batch_norm()
    self.pool2a = nn.MaxPool2d(2)  # constructed, but not applied in forward()
    self.conv2b = conv5x5(self.feature_dim)
    self.act2b = nn.ReLU()
    # bn2c is registered here and bn2b after conv2c, whereas forward()
    # applies bn2b after conv2b and bn2c after conv2c.
    self.bn2c = batch_norm()
    self.conv2c = conv5x5(self.feature_dim)
    self.act2c = nn.ReLU()
    self.bn2b = batch_norm()
    self.conv3 = nn.Conv2d(self.feature_dim, self.fc_dim, kernel_size=1, stride=1, padding=0)
    self.bn3 = nn.BatchNorm2d(self.fc_dim)
    # The linear layers and their activations are constructed, but not
    # applied in forward().
    self.fc1 = nn.Linear(self.fc_dim, self.fc_dim)
    self.actf1 = nn.ReLU()
    self.fc2 = nn.Linear(self.fc_dim, self.fc_dim)
    self.actf2 = nn.ReLU()
    # Softmax is already included in the loss function

  def forward(self, x):
    """Return the per-pixel scores of the first sample of the batch `x`."""
    stages = (
      (self.conv1, self.act1, self.bn1),
      (self.conv2a, self.act2a, self.bn2a),
      (self.conv2b, self.act2b, self.bn2b),
      (self.conv2c, self.act2c, self.bn2c),
    )
    for convolution, activation, normalisation in stages:
      x = normalisation(activation(convolution(x)))
    x = self.bn3(self.conv3(x))
    # Only the first element along the batch axis is returned.
    return x[0]


## Hilfsfunktionen

In [None]:
def gen_preview(tags, shape, colors):
  """Colour a flat label vector for display.

  Args:
    tags: Label array with shape[0] * shape[1] entries.
    shape: Sequence whose first two entries are the height and the width
      of the preview.
    colors: Colour table passed on to label2rgb.

  Returns:
    The result of label2rgb for the labels arranged as a height x width map.
  """
  height, width = shape[0], shape[1]
  label_map = tags.reshape(height, width)
  return label2rgb(label_map, colors=colors)

In [None]:
def get_output_size(input, network):
  """Determine the shape of the network output for a given input.

  A throw-away model is created and run once on `input`.

  Args:
    input: Input tensor; input.shape[1] is used as the number of input
      channels of the model.
    network: Model class that is called as
      network(input_dim=..., feature_dim=100). It used to be ignored;
      anything that is not a class still falls back to MNet.

  Returns:
    NumPy array with the shape of the model output after permute(1, 2, 0).
    Presumably (height, width, channels) for a model returning
    (channels, height, width) - confirm for other networks.
  """
  model_cls = network if isinstance(network, type) else MNet
  # Follow the device of the input instead of requiring CUDA.
  model = model_cls(input_dim=input.shape[1], feature_dim=100).to(input.device)
  # Only the shape is needed, so no autograd graph has to be built.
  with torch.no_grad():
    output = model(input)
  return np.array(output.permute(1, 2, 0).shape)

## Main Loop

In [None]:
previews = []
labels = []
all_tags = []

# One freshly initialised network is trained per tile/image.
# NOTE: `id` and `input` shadow Python builtins; the names are kept because
# they are notebook globals that other cells may rely on.
for id in range(min(len(images), max_tiles)):
  # Random colour table, used only for the preview images.
  colors = np.random.randint(255, size=(1000, 3))
  model = MNet(input_dim=images[id].shape[2], feature_dim=100).cuda()
  loss_fn = nn.CrossEntropyLoss()
  optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.6)

  if images[id].shape[2] == 3:
    preprocess = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
  else:
    preprocess = transforms.Compose([
      transforms.ToTensor(),
    ])

  # Add the batch axis and move the tensor to the GPU.
  input = preprocess(images[id].copy()).unsqueeze(0).cuda()

  target_shape = get_output_size(input, MNet)
  tags = cluster(images[id])

  # Bring the cluster map to the resolution of the network output if the
  # two differ; nearest-neighbour interpolation keeps the labels intact.
  if target_shape[0] != images[id].shape[0] or target_shape[1] != images[id].shape[1]:
    tags = cv2.resize(
      tags.reshape(images[id].shape[0], images[id].shape[1]),
      dsize=(target_shape[1], target_shape[0]),
      interpolation=cv2.INTER_NEAREST,
    ).flatten()

  cells = gen_cells(tags)

  last_loss = -10
  loss_change = []

  tb = None
  if tensorboard:
    tb = SummaryWriter("{}/{}/{}".format("tensorboard", starttime, id))

  try:
    if tb is not None:
      tb.add_image("input/", images[id], 0, dataformats="HWC")
      tb.add_image("clustered/", gen_preview(tags, target_shape, colors), 0, dataformats="HWC")

    display("+------------+-------------+--------------+------------+-------------------+")
    display("|    Tile    |    Epoch    |    Labels    |    Loss    |    Loss Change    |")
    display("+------------+-------------+--------------+------------+-------------------+")

    max_epochs = 100
    for epoch in range(max_epochs):
      raw = model(input)
      predicted = raw.permute(1, 2, 0).view(target_shape[0]*target_shape[1], -1)
      argmax = torch.argmax(predicted, dim=1).cpu().numpy()
      n_labels = len(np.unique(argmax))

      # Self-training target: every pixel of a cell gets the label that the
      # network currently predicts most often inside that cell.
      target = np.zeros_like(argmax)
      for cell in cells:
        possible = argmax[cell].flatten()
        target[cell] = mode(possible)[0]

      optimizer.zero_grad()
      loss = loss_fn(predicted, torch.from_numpy(target).cuda().long())
      loss.backward()
      optimizer.step()
      current_loss = loss.item()

      # Relative change against the previous epoch, computed once. A
      # previous loss of exactly 0 used to raise ZeroDivisionError.
      if last_loss != 0:
        rel_change = (current_loss - last_loss) / last_loss
      else:
        rel_change = 0.0 if current_loss == 0 else np.inf

      display("|  {: 4}/{: 4} |  {: 4}/{: 4}  |     {: 4}     |  {: 3.5f}  |     {: 3.7f}    |".format((id + 1), len(images), epoch + 1, max_epochs, n_labels, current_loss, rel_change))

      loss_change.append(np.abs(rel_change))

      last_loss = current_loss
      if tb is not None:
        tb.add_scalar("loss/", current_loss, epoch)
        tb.add_scalar("labels/", n_labels, epoch)
        tb.add_scalar("loss_variation/", loss_change[-1], epoch)
        tb.add_image("target/", gen_preview(target, target_shape, colors), epoch, dataformats="HWC")
        tb.add_image("preview/", gen_preview(argmax, target_shape, colors), epoch, dataformats="HWC")
        tb.flush()

      # Stop criteria, unchanged. loss_change only holds absolute values,
      # so `c <= -np.inf` can never be true and the loss threshold is what
      # actually decides. NOTE(review): confirm whether a real tolerance
      # was intended for the plateau test.
      plateaued = all(c <= -np.inf for c in loss_change[-2:-1])
      if ((plateaued or current_loss <= 1.7) and epoch >= 5 and n_labels <= 10) or n_labels <= 4:
        break
  finally:
    # The writer was never closed before; release it even when training
    # of this tile fails.
    if tb is not None:
      tb.close()

  labels.append(argmax.astype(np.uint8).reshape(target_shape[0], target_shape[1]))
  all_tags.append(tags)


In [None]:
 sio.savemat("data_"+str(starttime), {
    "filelist": filelist,
    "mode": op_mode,
    "clustering": clustering,
    "images": images,
    "tags": all_tags,
    "cells": cells,
    "labels": labels,
    "grayscale": grayscale,
    "source": source
})