In [1]:
import os
from copy import deepcopy

import numpy as np

import io_data

In [6]:
def gaussian(x, mu, Sigma, floor=np.finfo(float).tiny):
    """Evaluate the multivariate normal density N(x | mu, Sigma) for every row of ``x``.

    Parameters
    ----------
    x : array-like, shape (n, k)
        Points, one per row.
    mu : array-like, shape (k,)
        Mean vector.
    Sigma : array-like, shape (k, k)
        Covariance matrix.
    floor : float, optional
        Value substituted for densities that are NaN or underflow to exactly 0,
        so that callers normalising over components never divide by zero.
        It must stay below any genuine density value: a large floor (e.g. 0.01
        when ``det(Sigma)`` is large and real densities are ~1e-6) would make a
        point with numerically zero likelihood look *more* likely than points
        that actually fit the component. Pass ``floor=0.01`` to reproduce the
        previous fixed fallback.

    Returns
    -------
    ndarray, shape (n,)
        Density of each row of ``x``. If ``Sigma`` is singular or has a
        negative determinant, no valid density exists and every entry is ``floor``.
    """
    x = np.asarray(x, dtype=float)
    n, k = x.shape

    sign, logdet = np.linalg.slogdet(Sigma)
    if sign <= 0:
        # Singular or negative-determinant covariance: fall back to the floor everywhere.
        return np.full(n, float(floor))

    diff = x - np.asarray(mu, dtype=float)
    # Squared Mahalanobis distance per row; solve() avoids forming an explicit inverse.
    maha = np.sum(diff * np.linalg.solve(Sigma, diff.T).T, axis=1)
    # Work in log space so a tiny determinant cannot overflow the normalising constant.
    prob = np.exp(-0.5 * (k * np.log(2 * np.pi) + logdet + maha))

    # Replace NaN and underflowed zeros with the (small) floor.
    return np.where(np.isnan(prob) | (prob == 0), floor, prob)

In [3]:
def _e_step(x, pi, mu, Sigma):
    """Return the (n_points, K) responsibility matrix for the current mixture parameters.

    Column k is ``pi[k] * gaussian(x, mu[k], Sigma[k])``; each row is then
    normalised to sum to 1. ``gaussian`` replaces zero/NaN densities with a
    positive floor, which is what keeps the row sums from being zero.
    """
    weighted = np.column_stack([pi[k] * gaussian(x, mu[k], Sigma[k]) for k in range(len(pi))])
    return weighted / weighted.sum(axis=1, keepdims=True)


def _m_step(x, gamma, mu, Sigma, reg_covar):
    """Return updated ``(pi, mu, Sigma)`` from responsibilities ``gamma``.

    ``reg_covar`` is added to each covariance diagonal to keep it invertible.
    A component whose total responsibility is zero keeps its previous mean and
    covariance instead of dividing by zero.
    """
    n, d = x.shape
    n_k = gamma.sum(axis=0)  # effective number of points owned by each component
    pi = n_k / n
    mu = mu.copy()
    Sigma = Sigma.copy()
    ridge = reg_covar * np.eye(d)
    for k in range(gamma.shape[1]):
        if n_k[k] <= 0:
            continue
        mu[k] = gamma[:, k] @ x / n_k[k]
        diff = x - mu[k]
        # Responsibility-weighted covariance: sum_i gamma_ik (x_i - mu_k)(x_i - mu_k)^T / N_k
        Sigma[k] = (gamma[:, k] * diff.T) @ diff / n_k[k] + ridge
    return pi, mu, Sigma


def EM(file, K=2, iters=100, rng=None, reg_covar=1e-6):
    """Fit a K-component Gaussian mixture to the pixel colours of ``file`` with EM.

    Parameters
    ----------
    file : str
        Pixel file read via ``io_data.read_data(file, True)``. Columns 2 onwards
        of each row are the feature vector. The mean initialisation below
        assumes three Lab colour channels there (columns 0-1 are presumably
        pixel coordinates -- not used here).
    K : int
        Number of mixture components.
    iters : int
        Number of EM iterations to run (no early stopping). Must be >= 1.
    rng : None, int or numpy.random.Generator, optional
        Randomness for the initial means. ``None`` uses the global ``np.random``
        state (so ``np.random.seed`` still applies); anything else is passed to
        ``np.random.default_rng`` for a reproducible run.
    reg_covar : float, optional
        Non-negative value added to the diagonal of every covariance matrix so
        that it stays invertible.

    Returns
    -------
    gamma : ndarray, shape (n_pixels, K)
        Responsibilities from the final E-step; each row sums to 1. Component
        order is arbitrary because the means are initialised at random.

    Raises
    ------
    ValueError
        If ``iters < 1`` (no E-step would run, so there is nothing to return).
    """
    if iters < 1:
        raise ValueError(f"iters must be at least 1, got {iters}")
    print(f"Training model on {file}.")

    data, _ = io_data.read_data(file, True)
    x = np.array([row[2:] for row in data], dtype=float)
    _, d = x.shape
    rng = np.random if rng is None else np.random.default_rng(rng)

    # Uniform initial mixing weights.
    pi = np.full(K, 1 / K)

    # Random initial means in the valid Lab ranges (L in [0, 100], a/b in [-128, 127]).
    # This must be a float array: an int array would silently truncate the M-step means.
    mu = np.empty((K, 3))
    for k in range(K):
        mu[k, 0] = rng.choice(np.arange(101))
        mu[k, 1:] = rng.choice(np.arange(-128, 128), 2)

    # Every component starts from the covariance of the whole data set.
    Sigma = np.repeat((np.cov(x.T) + reg_covar * np.eye(d))[np.newaxis], K, axis=0)

    for _ in range(iters):
        gamma = _e_step(x, pi, mu, Sigma)
        pi, mu, Sigma = _m_step(x, gamma, mu, Sigma, reg_covar)
    return gamma

In [9]:
def _write_layer_image(data, colours, stem):
    """Write a copy of ``data`` with each row's columns 2 onwards replaced by ``colours``.

    The copy is written with ``io_data.write_data`` to ``<stem>.txt``. Then
    ``io_data.read_data`` is called on that file with ``<stem>.jpg`` as the
    image path (presumably to render and save the JPEG). ``data`` itself is
    not modified.
    """
    layer_data = deepcopy(data)
    for row, colour in zip(layer_data, colours):
        row[2:] = colour
    txt_file = f"{stem}.txt"
    io_data.write_data(layer_data, txt_file)
    io_data.read_data(txt_file, False, False, True, f"{stem}.jpg")


def generate_images(file, gamma):
    """Write mask, layer-1 and layer-0 images of ``file`` from EM responsibilities.

    Each pixel is hard-assigned to its most responsible component
    (``argmax`` over ``gamma``'s columns). Three .txt/.jpg pairs are written
    next to ``file``:

    - ``<stem>_mask``: layer-0 pixels black, all other pixels white;
    - ``<stem>_fg``:   layer-0 pixels black, all other pixels keep their colour;
    - ``<stem>_bg``:   layer-1 pixels black, all other pixels keep their colour.

    NOTE(review): component indices come from EM's random initialisation, so
    which layer is really the foreground is not guaranteed.

    Parameters
    ----------
    file : str
        The pixel file ``gamma`` was computed from (read via ``io_data.read_data(file, True)``).
    gamma : array-like, shape (n_pixels, K)
        Responsibilities, e.g. as returned by ``EM``. Row i must correspond to row i of ``file``.

    Raises
    ------
    ValueError
        If ``gamma`` does not have one row per pixel of ``file``.
    """
    print(f"Generating images for {file}")

    data, _ = io_data.read_data(file, True)
    x = np.array([row[2:] for row in data])

    # NOTE(review): a/b are +/-0.01 rather than 0 -- reason not visible here; kept as-is.
    white = [100, 0.01, -0.01]
    black = [0, 0, 0]

    layer = np.argmax(gamma, axis=1)
    if len(layer) != len(x):
        raise ValueError(f"gamma has {len(layer)} rows but {file} has {len(x)} pixels")

    # Works for any extension, unlike slicing off a fixed 4 characters.
    stem = os.path.splitext(file)[0]

    _write_layer_image(data, [black if c == 0 else white for c in layer], f"{stem}_mask")
    _write_layer_image(data, [black if c == 0 else px for c, px in zip(layer, x)], f"{stem}_fg")
    _write_layer_image(data, [black if c == 1 else px for c, px in zip(layer, x)], f"{stem}_bg")

In [10]:
# Fit the mixture to the fox image (default K), then write its mask/fg/bg images.
gamma_fox = EM(file="data/fox.txt")
generate_images(file="data/fox.txt", gamma=gamma_fox)

Training model on data/fox.txt.
Generating images for data/fox.txt


In [19]:
# Fit the mixture to the owl image (default K), then write its mask/fg/bg images.
gamma_owl = EM(file="data/owl.txt")
generate_images(file="data/owl.txt", gamma=gamma_owl)

Training model on data/owl.txt.
Generating images for data/owl.txt


In [15]:
# Fit the mixture to the zebra image (default K), then write its mask/fg/bg images.
gamma_zebra = EM(file="data/zebra.txt")
generate_images(file="data/zebra.txt", gamma=gamma_zebra)

Training model on data/zebra.txt.
Generating images for data/zebra.txt


In [16]:
# Fit the mixture to the cow image (default K), then write its mask/fg/bg images.
gamma_cow = EM(file="data/cow.txt")
generate_images(file="data/cow.txt", gamma=gamma_cow)

Training model on data/cow.txt.
Generating images for data/cow.txt
