# W-Net for unsupervised image segmentation

Paper: [W-Net: A Deep Model for Fully Unsupervised Image Segmentation](https://arxiv.org/abs/1711.08506)

This implementation is based on the following repositories:

* https://github.com/lwchen6309/unsupervised-image-segmentation-by-WNet-with-NormalizedCut
* https://github.com/zwenaing/unsupervised-image-segmentation 



In [None]:
# Tensorflow < 2.0
import os
import tensorflow as tf
from datetime import datetime
import cv2
import wnet
from input_data import DataReader
from soft_ncut import *

In [None]:
# Print iterations progress
def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█', printEnd = "\r"):
    """
    Draw (or redraw) a single-line text progress bar on stdout.

    Meant to be called once per loop iteration: every call starts the line with
    a carriage return so the bar is rewritten in place.  Once ``iteration``
    equals ``total`` an extra newline is printed so later output starts on a
    fresh line.

    @params:
        iteration   - Required  : number of completed iterations (Int)
        total       - Required  : iteration count corresponding to 100% (Int)
        prefix      - Optional  : text printed before the bar (Str)
        suffix      - Optional  : text printed after the percentage (Str)
        decimals    - Optional  : decimals shown in the percentage (Int)
        length      - Optional  : width of the bar in characters (Int)
        fill        - Optional  : character used for the completed part (Str)
        printEnd    - Optional  : terminator passed to print(), e.g. CR or CRLF (Str)
    """
    pct_fmt = "{0:." + str(decimals) + "f}"
    completed = int(length * iteration // total)
    remaining = length - completed
    bar = "".join((fill * completed, '-' * remaining))
    pct_text = pct_fmt.format(100 * (iteration / float(total)))
    line = '\r%s |%s| %s%% %s' % (prefix, bar, pct_text, suffix)
    print(line, end = printEnd)
    finished = iteration == total
    if finished:
        print()

## Training

In [None]:
import timeit
import numpy as np

# Set this to the directory containing the video frames.
train_vid_dir = ''

# Directory for TensorBoard summaries and checkpoints (a fresh one per run).
logdir = os.path.join('logs', datetime.now().strftime('%Y%m%d-%H%M%S'))
# Optionally set this to a checkpoint path (e.g. '<logdir>/ckpt-1000') to resume training.
ckpt = None

# network parameters
n_batch = 1
learning_rate = 0.003
num_steps = 1000//n_batch
num_epochs = 50
display_step = 1000//n_batch
num_classes =  20

# Hyperparameters used by the model, the soft n-cut loss and the input pipeline.
# (Previously referenced through an `all_params` dict that was never defined.)
all_params = {
    'num_classes': num_classes,
    # NOTE(review): brightness bandwidth for brightness_weight(); 0.05 is an
    # assumed value -- confirm against soft_ncut.
    'sigma_I': 0.05,
    # Spatial bandwidth / radius for gaussian_neighbor(); same values the
    # evaluation cell uses.
    'sigma_X': 4,
    'r': 5,
    # NOTE(review): number of images requested from DataReader; assumed to be
    # one epoch's worth ((num_steps + 1) batches) -- confirm.
    'train_size': (num_steps + 1) * n_batch,
}

global_step = tf.train.get_or_create_global_step()

autoencoder = wnet.Wnet(num_classes=all_params['num_classes'])
data_reader = DataReader(train_vid_dir, batch_size=n_batch)

image = tf.placeholder(tf.float32, [None, 224, 224, 3], name="image")
# Sparse-tensor components of the spatial neighbourhood filter fed to the loss.
neighbor_indices = tf.placeholder(tf.int64, name="neighbor_indices")
neighbor_vals = tf.placeholder(tf.float32, name="neighbor_vals")
neighbor_shape = tf.placeholder(tf.int64, name="neighbor_shape")

with tf.name_scope("Encoding"):
    encoded_image = autoencoder.encode(image)
    argmax = tf.argmax(encoded_image, axis=3)
    # Spread class ids over 0..255 so the segmentation summary image is visible.
    scaled = tf.multiply(argmax, 256//num_classes)
    reshaped = tf.reshape(scaled, [-1, 224, 224, 1])
    seg = tf.cast(reshaped, tf.uint8)

with tf.name_scope("Decoding"):
    decoded_image = autoencoder.decode(encoded_image)

with tf.name_scope("Loss"):
    # Loss has two parts: reconstruction loss and "soft n-cut" loss.
    # 150528 == 224 * 224 * 3, i.e. one flattened image.
    y_pred = tf.reshape(decoded_image, [-1, 150528])
    y_true = tf.reshape(image, [-1, 150528])
    reconstruction_loss = tf.reduce_mean(tf.pow(y_pred - y_true, 2))

    neighbor_filter = (neighbor_indices, neighbor_vals, neighbor_shape)
    _image_weights = brightness_weight(image, neighbor_filter, sigma_I = all_params['sigma_I'])
    image_weights = convert_to_batchTensor(*_image_weights)
    soft_ncut_loss = soft_ncut(image, encoded_image, image_weights)[0]
    loss = reconstruction_loss + soft_ncut_loss

with tf.name_scope("Optimization"):
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # Passing global_step makes every training step increment it, so
    # summaries and checkpoints carry the true step count.
    train_op = optimizer.minimize(loss=loss, global_step=global_step)

orig_img_summary = tf.summary.image('original', image)
dec_img_summary = tf.summary.image('decoded', decoded_image)
seg_img_summary = tf.summary.image('segmentation', seg)
tf.summary.scalar("Loss", loss)
merged_summary = tf.summary.merge_all()

init = tf.global_variables_initializer()
saver = tf.train.Saver(filename="ckpt")

iterator = data_reader.input_data(num_images=all_params['train_size'])
flist, next_items = iterator.get_next()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.99

with tf.Session(config=config) as sess:
    sess.run(init)
    sess.run(iterator.initializer)
    var = sum(v.shape.num_elements() for v in tf.trainable_variables())
    print("Trainable vars: ", var)
    train_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())

    if ckpt is not None:
        saver.restore(sess, ckpt)

    # The neighbourhood filter depends only on the (fixed) image size, so it is
    # computed once instead of on every step.
    image_shape = image.get_shape().as_list()[1:3]
    gauss_indices, gauss_vals = gaussian_neighbor(image_shape,
                                                  sigma_X=all_params['sigma_X'],
                                                  r=all_params['r'])
    weight_shapes = np.prod(image_shape).astype(np.int64)
    neighbor_feed = {neighbor_indices: gauss_indices,
                     neighbor_shape: [weight_shapes, weight_shapes],
                     neighbor_vals: gauss_vals}

    for epoch in range(num_epochs):
        losses = np.zeros((num_steps+1))
        times = np.zeros((num_steps+1))
        steps_done = 0
        epoch_label = 'Epoch {}'.format(str(epoch))
        printProgressBar(0, num_steps+1, prefix = epoch_label, suffix = '', length = 50)
        for i in range(num_steps + 1):
            stime = timeit.default_timer()
            try:
                batch_x = sess.run(next_items)
            except tf.errors.OutOfRangeError:
                # Dataset exhausted: rewind it so the following epochs still get data.
                sess.run(iterator.initializer)
                print()
                break

            feed_dict = dict(neighbor_feed)
            feed_dict[image] = batch_x
            if i != 0 and i % display_step == 0:
                loss_, _, summary = sess.run([loss, train_op, merged_summary], feed_dict=feed_dict)
                step = tf.train.global_step(sess, global_step)
                train_writer.add_summary(summary, global_step=step)
                saver.save(sess, os.path.join(logdir, 'ckpt'), global_step=step)
            else:
                loss_, _ = sess.run([loss, train_op], feed_dict=feed_dict)
            losses[i] = loss_
            times[i] = timeit.default_timer() - stime
            steps_done = i + 1
            printProgressBar(i+1, num_steps+1, prefix = epoch_label, suffix = 'Loss: {:.2f}'.format(loss_), length = 50)

        if steps_done:
            # Average only over the steps that actually ran (unfilled slots are zeros).
            print("Epoch: {} Avg Loss: {:.2f} Avg time/batch: {:.2f} s".format(
                str(epoch), np.average(losses[:steps_done]), np.average(times[:steps_done])))
        else:
            print("Epoch: {} no batches available".format(str(epoch)))
            

## Evaluation

In [None]:
# For converting class labels into RGB.
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb
# Number of segmentation classes; must match the `num_classes` used for training.
# (Previously read from an undefined `params` dict.)
cl = num_classes
cmap = plt.get_cmap('viridis')
# One colour per class label, evenly spaced along the colormap; to_rgb drops alpha.
seg_colors = {i : to_rgb(cmap(i*1/cl)) for i in range(cl)}

In [None]:
# Start from an empty graph; the trained graph is re-imported from its metagraph below.
tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.99

# Change this to the directory containing the test video frames (*.jpg).
test_vid_dir = ''

# Change this to the full path of the checkpoint metagraph.
ckpt_meta = ''

# Change this to the full path of the checkpoint file
ckpt_file = ''

def ssoftmax(z, axis=2):
    """Return a numerically stable softmax of the array ``z`` along ``axis``.

    Args:
        z: array of logits.
        axis: axis to normalise over (default 2, as in the original helper).

    Returns:
        Array of the same shape as ``z`` whose values along ``axis`` sum to 1.

    The maximum is subtracted per slice along ``axis`` rather than globally.
    Softmax is invariant to a per-slice shift, and this prevents a slice whose
    logits are far below the global maximum from underflowing to all zeros
    (which would give 0/0 = NaN).
    """
    shifted = z - np.max(z, axis=axis, keepdims=True)
    ez = np.exp(shifted)
    return ez / ez.sum(axis=axis, keepdims=True)

import glob
import numpy as np

with tf.Session(config=config) as sess:
    tf.train.import_meta_graph(ckpt_meta)
    tf.train.Saver().restore(sess, ckpt_file)
    # Encoder softmax output tensor in the imported graph.
    softmax = 'Encoding/rectangle9/softmax/Softmax:0'

    # Class-label -> RGB lookup table; seg_colors is keyed 0..num_classes-1.
    # Fancy-indexing it replaces a per-pixel np.vectorize with identical results.
    palette = np.array([seg_colors[k] for k in range(len(seg_colors))])

    # Every frame is resized to 224x224, so the neighbourhood filter is the
    # same for all of them and is computed once.
    img_shape = (224, 224)
    gauss_indices, gauss_vals = gaussian_neighbor(img_shape, sigma_X = 4, r = 5)
    weight_shapes = np.prod(img_shape).astype(np.int64)

    files = sorted(glob.glob(os.path.join(test_vid_dir, '*.jpg')))
    for i, f in enumerate(files):
        img = cv2.imread(f)
        if img is None:
            # cv2.imread returns None for unreadable files instead of raising.
            print('Skipping unreadable image: {}'.format(f))
            continue
        img = cv2.resize(img, (224, 224))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_exp_f = img.reshape((-1, 224, 224, 3)).astype(np.float32)
        feed_dict = {'image:0': img_exp_f,
                     'neighbor_indices:0' : gauss_indices,
                     'neighbor_shape:0' : [weight_shapes, weight_shapes],
                     'neighbor_vals:0' : gauss_vals}
        softmax_output = sess.run(softmax, feed_dict=feed_dict)
        # Per-pixel class label (argmax over the last axis of the 4-D output).
        img_seg = np.argmax(softmax_output, axis=3).astype(np.uint8).squeeze()
        img_seg_color = (palette[img_seg] * 255).astype(np.uint8)
        img_seg_color = img_seg_color.reshape((224, 224, 3))
        img_final = cv2.cvtColor(img_seg_color, cv2.COLOR_RGB2BGR)
        # Output index matches the frame's position in the sorted input list.
        cv2.imwrite(str(i).zfill(5) + '.png', img_final)