In [136]:
import numpy as np
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
from time import time

In [8]:
from utils import mnist_reader

In [19]:
x_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train')
x_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')

In [25]:
tshirts = []
for i in range(len(y_train)):
    if y_train[i] == 0:
        tshirts.append(x_train[i])
for j in range(len(y_test)):
    if y_test[j] == 0:
        tshirts.append(x_test[j])

In [595]:
tshirts = np.array(tshirts)
A = tshirts.reshape((7000, 784))
beta = 0.0001
k = 5

In [992]:
def f(W, ZT):
    return np.linalg.norm(A - W @ ZT, ord='fro') ** 2 + beta * np.linalg.norm(W, ord='fro') ** 2 + beta * np.linalg.norm(ZT.T, ord='fro') ** 2

In [993]:
def grad_f_wrt_W(W, ZT):
    return 2 * ((beta * W) - ((A - W @ ZT) @ ZT.T))

In [994]:
def grad_f_wrt_ZT(W, ZT):
    return 2 * (beta * ZT + W.T @ W @ ZT - W.T @ A)

In [995]:
def project(B):
    return np.clip(B, 0, np.inf)

In [1018]:
def find_updated_factor(f, W, ZT, p, starting_alpha, wrt_W):
    eval_f = f(W, ZT)
    alpha = starting_alpha
    tol = 1e-6
    if wrt_W:
        while (f(W+alpha*p, ZT) > eval_f and alpha > tol):
            alpha = alpha / 2
        print(alpha)
        return project(W+alpha*p)
    else:
        while (f(W, ZT+alpha*p) > eval_f and alpha > tol):
            alpha = alpha / 2 
        print(alpha)
        return project(ZT+alpha*p)

In [1019]:
def nmf(steps, k):
    f_k = 0
    W_k = np.random.uniform(0, 1, (7000, k))
    ZT_k = np.random.uniform(0, 1, (k, 784))
    for iteration in range(steps):
        # Take gradient step with projection
        W_k = find_updated_factor(f, W_k, ZT_k, -grad_f_wrt_W(W_k, ZT_k), 1, True)
        ZT_k = find_updated_factor(f, W_k, ZT_k, -grad_f_wrt_ZT(W_k, ZT_k), 1, False)
        f_k = f(W_k, ZT_k)
        print("Iteration: "+ str(iteration+1))
        print(f_k)
    return f_k, ZT_k

In [1020]:
val, Z = nmf(100, 5)

0.0009765625
9.5367431640625e-07
Iteration: 1
163567937133558.4
9.5367431640625e-07
1
Iteration: 2
81838845188.71066
9.5367431640625e-07
7.62939453125e-06
Iteration: 3
13464102330.508884
9.5367431640625e-07
1.52587890625e-05
Iteration: 4
12416338568.49592
9.5367431640625e-07
1.52587890625e-05
Iteration: 5
12920822306.65198
9.5367431640625e-07
1.52587890625e-05
Iteration: 6
20105695108.91629
9.5367431640625e-07
7.62939453125e-06
Iteration: 7
20483580306.124283
1.9073486328125e-06
3.814697265625e-06
Iteration: 8
14981721519.893646
3.814697265625e-06
1.9073486328125e-06
Iteration: 9
12237778464.243551
3.814697265625e-06
1.9073486328125e-06
Iteration: 10
11558601409.780651
3.814697265625e-06
1.9073486328125e-06
Iteration: 11
11454427490.144361
3.814697265625e-06
1.9073486328125e-06
Iteration: 12
11433833363.307571
3.814697265625e-06
1.9073486328125e-06
Iteration: 13
11428505340.944698
3.814697265625e-06
1.9073486328125e-06
Iteration: 14
11426653226.948294
7.62939453125e-06
7.62939453125e-0

In [1021]:
val

7643579352.319479

In [1022]:
np.max(Z), np.min(Z)

(25.568807740890577, 0.0)

In [897]:
from PIL import Image

In [1029]:
img = Image.fromarray((255 / 25.568807740890577 * Z[4, :].reshape((28, -1))).astype(np.uint8), 'L')
img.show()

In [341]:
Z.shape

(784, 5)

In [702]:
Z

array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       ...,
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

In [767]:
img = Image.fromarray(A[444].reshape((28, -1)), 'L')
img.show()