In [1]:
%matplotlib notebook
import matplotlib.pyplot as plt 
import numpy as np
import torch
from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer

import nmf.mult
import nmf.pgrad
import nmf.nesterov

import nmf_torch.mult
import nmf_torch.pgrad
import nmf_torch.nesterov
import nmf_torch.norms

import read_data.reading as reading
from theory.represent import rescale_WH

In [2]:
# Reload the torch-backed NMF modules so that edits made to them on disk are
# picked up without restarting the kernel. The value of the last reload call
# is echoed as the cell output.
import importlib
importlib.reload(nmf_torch.mult)
importlib.reload(nmf_torch.pgrad)
importlib.reload(nmf_torch.nesterov)
importlib.reload(nmf_torch.norms)

<module 'nmf_torch.norms' from '/home/aiadmin/maxim/pytorch_project/finalYearProjectNMF/nmf_torch/norms.py'>

In [3]:
def get_random_lowrank_matrix(m, r, n):
    """Return an (m, n) matrix built as the product of two uniform random factors.

    The factors have shapes (m, r) and (r, n) and are drawn, in that order,
    from the global NumPy random state via ``np.random.rand``.
    """
    left_factor = np.random.rand(m, r)
    right_factor = np.random.rand(r, n)
    return left_factor @ right_factor

In [4]:
class HashTfidfVectoriser:
    """Two-stage text vectoriser: feature hashing followed by TF-IDF weighting.

    Exposes only ``fit_transform``, which chains the two scikit-learn steps.
    """

    def __init__(self, n_features):
        # alternate_sign=False: hashed values are not sign-flipped.
        self.hashing_vectoriser = HashingVectorizer(
            n_features=n_features,
            alternate_sign=False,
        )
        self.tfidf_transformer = TfidfTransformer()

    def fit_transform(self, data):
        """Hash ``data`` into term features, then fit and apply TF-IDF to them."""
        hashed_terms = self.hashing_vectoriser.fit_transform(data)
        return self.tfidf_transformer.fit_transform(hashed_terms)
    
    
def unroll_images(data):
    """Flatten a stack of images so that each image becomes one row.

    Returns a pair ``(flat, image_shape)`` where ``flat`` is ``data`` reshaped
    to ``(data.shape[0], -1)`` and ``image_shape`` is
    ``(data.shape[1], data.shape[2])``, i.e. what ``roll_images`` needs to
    undo the flattening.
    """
    dims = data.shape
    image_shape = (dims[1], dims[2])
    flat = data.reshape(dims[0], -1)
    return flat, image_shape


def roll_images(data, original_image_shape):
    """Reshape flattened image data back to ``original_image_shape``.

    A 1-D input is treated as a single image; any other input is treated as a
    batch and gains a leading axis whose length is inferred.
    """
    if data.ndim != 1:
        return data.reshape(-1, *original_image_shape)
    return data.reshape(*original_image_shape)
    
def images_matrix_grid(data, grid_shape):
    """Tile the leading images of ``data`` into one 2-D mosaic.

    ``grid_shape`` is ``(rows, columns)``. Images ``data[i, :, :]`` are laid
    out in row-major order: the first ``columns`` images form the top strip,
    the next ``columns`` the strip below it, and so on.
    """
    n_rows = grid_shape[0]
    n_cols = grid_shape[1]
    strips = []
    for row in range(0, n_rows):
        first = n_cols * row
        tiles = [data[k, :, :] for k in range(first, first + n_cols)]
        strips.append(np.hstack(tiles))
    return np.vstack(strips)


def get_time_ratio(errors_0, errors_1, n_levels=100):
    """Ratio of the times two runs need to reach the same cost-function value.

    Parameters
    ----------
    errors_0, errors_1 : 2-D arrays
        Column 0 is read as the time stamp and column 1 as the cost value.
    n_levels : int, optional
        Number of evenly spaced log-cost levels to evaluate (default 100).

    Returns
    -------
    ndarray of shape (n_levels, 2)
        Column 0 holds the log-cost levels, column 1 holds
        ``time_0 / time_1``, where ``time_k`` is the time stamp of the first
        row of ``errors_k`` whose log-cost is at or below the level. The entry
        is NaN when either run never reaches the level.
    """
    # The logs do not depend on the level, so compute them once.
    log_errors_0 = np.log(errors_0[:, 1])
    log_errors_1 = np.log(errors_1[:, 1])

    # The upper bound skips the first record of each run.
    max_log_error = max(np.max(log_errors_0[1:]), np.max(log_errors_1[1:]))
    # The lower bound is the smallest cost reached by EITHER run.
    min_log_error = min(np.min(log_errors_0), np.min(log_errors_1))

    error_space = np.linspace(min_log_error, max_log_error, n_levels)
    time_rates = np.full(n_levels, np.nan)
    for err_i, level in enumerate(error_space):
        reached_0 = np.flatnonzero(log_errors_0 <= level)
        reached_1 = np.flatnonzero(log_errors_1 <= level)
        # Leave NaN when one of the runs never gets down to this level.
        if reached_0.size and reached_1.size:
            time_0 = errors_0[reached_0[0], 0]
            time_1 = errors_1[reached_1[0], 0]
            time_rates[err_i] = time_0 / time_1
    return np.array([error_space, time_rates]).T

In [5]:
def compare_perofrmance(V, inner_dim, time_limit,
                        W_init, H_init,
                        algo_dict_to_test):
    """Run every algorithm on the same problem and collect its error record.

    Each callable in ``algo_dict_to_test`` is invoked with keyword arguments
    only: ``record_errors=True``, ``max_steps=np.inf``, ``epsilon=0``, the
    given ``time_limit``, and fresh copies of ``W_init`` and ``H_init``. The
    CUDA cache is emptied before each call.

    Returns a dict mapping algorithm name to the third element of that
    algorithm's return value.
    """
    def _error_record(factorise):
        torch.cuda.empty_cache()
        _W, _H, record = factorise(
            V=V,
            inner_dim=inner_dim,
            record_errors=True,
            time_limit=time_limit,
            max_steps=np.inf,
            epsilon=0,
            W_init=W_init.copy(),
            H_init=H_init.copy(),
        )
        return record

    collected = {}
    for label in algo_dict_to_test:
        collected[label] = _error_record(algo_dict_to_test[label])
    return collected

def plot_perofrmance(errors):
    """Plot log-cost against time for every entry of ``errors`` on the current axes.

    ``errors`` maps a name to a 2-D array whose column 0 is used as the
    x-values and whose column 1 is log-transformed for the y-values. Entries
    are drawn in sorted-name order; names containing "torch" are dashed, the
    rest solid. A legend is added at the end.
    """
    for label in sorted(errors.keys()):
        record = errors[label]
        line_style = "-" if "torch" not in label else "--"
        plt.plot(record[:, 0], np.log(record[:, 1]), label=label, ls=line_style)
    plt.legend()

# Loading data

In [6]:
# Read the Reuters-21578 files through the project reader, handing it a hashed
# TF-IDF vectoriser with 12000 features.
reuters_data = reading.read_reuters21578("data/reuters21578", 
                                         vectorizer=HashTfidfVectoriser(12000))

data/reuters21578/reut2-000.sgm
data/reuters21578/reut2-001.sgm
data/reuters21578/reut2-002.sgm
data/reuters21578/reut2-003.sgm
data/reuters21578/reut2-004.sgm
data/reuters21578/reut2-005.sgm
data/reuters21578/reut2-006.sgm
data/reuters21578/reut2-007.sgm
data/reuters21578/reut2-008.sgm
data/reuters21578/reut2-009.sgm
data/reuters21578/reut2-010.sgm
data/reuters21578/reut2-011.sgm
data/reuters21578/reut2-012.sgm
data/reuters21578/reut2-013.sgm
data/reuters21578/reut2-014.sgm
data/reuters21578/reut2-015.sgm
data/reuters21578/reut2-016.sgm
data/reuters21578/reut2-017.sgm
data/reuters21578/reut2-018.sgm
data/reuters21578/reut2-019.sgm
data/reuters21578/reut2-020.sgm
data/reuters21578/reut2-021.sgm


In [None]:
# Show the matrix dimensions, then the size in GiB of a dense float32 copy.
# Note: this materialises the full dense array just to measure it.
print(reuters_data.shape)
reuters_data.toarray().astype(np.float32).nbytes / 2**30

In [None]:
# Read the Indian Pines images and flatten the two image stacks to one row per
# image, keeping each original image shape for later reconstruction.
indian_pines = reading.read_pines("data/indian_pines/images")
ip_unrolled_site3, ori_shape_site3 = unroll_images(indian_pines["site3_im"])
ip_unrolled_ns_line, ori_shape_ns_line = unroll_images(indian_pines["ns_line_im"])

In [None]:
# Read the face images and flatten them to one row per image;
# ori_shape_faces keeps the per-image shape needed by roll_images.
faces = reading.read_face_images("data/att_faces/images/")
faces_unrolled, ori_shape_faces = unroll_images(faces)

# Random Data

In [None]:
# Small synthetic problem: a 10x10 product of 10x5 and 5x10 uniform factors,
# plus starting factors of matching shapes.
shape = (10, 5, 10)
A = get_random_lowrank_matrix(*shape)
# NOTE(review): randn draws can be negative, whereas the benchmark cells
# further down initialise with np.random.rand. Confirm the solvers accept
# signed initial factors.
W_init = np.random.randn(shape[0], shape[1])
H_init = np.random.randn(shape[1], shape[2])

In [None]:
# Run the NumPy projected-gradient solver on A, recording errors.
W, H, errors_proj_sub = nmf.pgrad.factorise_Fnorm_subproblems(
                              A, inner_dim=5, record_errors=True,
                              n_steps=1000, epsilon=0,
                              W_init=W_init.copy(),
                              H_init=H_init.copy())

# Same problem and same starting factors on the torch implementation.
# W and H are overwritten by this second call.
W, H, errors_proj_sub_torch = nmf_torch.pgrad.factorise_Fnorm_subproblems(
                                  torch.tensor(A), inner_dim=5, record_errors=True,
                                  n_steps=1000, epsilon=0,
                                  W_init=torch.tensor(W_init.copy()),
                                  H_init=torch.tensor(H_init.copy()))

In [None]:
# Top panel: log of column 1 against column 0 for both runs
# (presumably cost against elapsed time; confirm against the solvers).
# Bottom panel: element-wise ratio of the two column-0 series, which assumes
# both error records have the same number of rows.
f, axs = plt.subplots(2, 1)
axs[0].plot(errors_proj_sub[:, 0], np.log(errors_proj_sub[:, 1]), label="numpy")
axs[0].plot(errors_proj_sub_torch[:, 0], np.log(errors_proj_sub_torch[:, 1]), label="torch")
axs[1].plot(errors_proj_sub_torch[:, 0] / errors_proj_sub[:, 0], label="torch / numpy")
axs[0].legend(); axs[1].legend()

In [None]:
# Time ratio (torch time / numpy time) per log-cost level. The x-axis is
# inverted so that lower cost lies to the right.
rate_data = get_time_ratio(errors_proj_sub_torch, errors_proj_sub)
plt.plot(rate_data[:, 0], rate_data[:, 1])
plt.gca().invert_xaxis()

In [None]:
# Scratch check of torch.sqrt on an integer tensor; the result is not
# assigned, so nothing downstream depends on it.
torch.sqrt(torch.tensor([10]))

# Face Images

In [None]:
# Factorise the flattened face images with 8 components using the NumPy
# projected-gradient solver.
inner_dim_faces = 8
# NOTE(review): randn draws can be negative. Confirm the solver accepts
# signed initial factors.
W_init = np.random.randn(faces_unrolled.shape[0], inner_dim_faces)
H_init = np.random.randn(inner_dim_faces, faces_unrolled.shape[1])

W, H, errors_proj_sub = nmf.pgrad.factorise_Fnorm_subproblems(
                                faces_unrolled, 
                                inner_dim=inner_dim_faces,
                                record_errors=True,
                                n_steps=100, epsilon=1e-3,
                                W_init=W_init.copy(),
                                H_init=H_init.copy())

In [None]:
# Same face factorisation on the torch implementation, from the same starting
# factors. inner_dim must equal inner_dim_faces because W_init and H_init were
# built with that many components, and the feature grid plotted later expects
# the same number of rows in H.
W, H, errors_proj_sub_torch = nmf_torch.pgrad.factorise_Fnorm_subproblems(
                                            torch.tensor(faces_unrolled, dtype=torch.float64),
                                            inner_dim=inner_dim_faces,
                                            record_errors=True,
                                            n_steps=100, epsilon=1e-3,
                                            W_init=torch.tensor(W_init.copy()),
                                            H_init=torch.tensor(H_init.copy()))

In [None]:
# Convert the torch factors to NumPy arrays for plotting.
# Not re-runnable: a second execution would call .numpy() on an ndarray.
W = W.numpy()
H = H.numpy()

In [None]:
# Rescale the factors with the project helper, then view each row of H as an
# image and lay the 8 feature images out in a single row.
W, H = rescale_WH(W, H)

# ori_shape_faces is the per-image shape returned by unroll_images(faces).
face_features = roll_images(H, ori_shape_faces)
face_features = images_matrix_grid(face_features, (1, 8))

In [None]:
# Display the mosaic of learned feature images in greyscale.
plt.imshow(face_features, cmap="gray")

In [None]:
# Rebuild every face from the factorisation and compare with the originals.
# ori_shape_faces is the per-image shape returned by unroll_images(faces).
faces_recovered = roll_images(W @ H, ori_shape_faces)

# Tile the first 400 reconstructed and original images into 20x20 mosaics.
faces_recovered = images_matrix_grid(faces_recovered, (20, 20))
all_faces = images_matrix_grid(faces, (20, 20))

# Left: reconstruction. Right: originals. Axes are shared so zoom stays in sync.
f, axs = plt.subplots(1, 2, sharex=True, sharey=True)
axs[0].imshow(faces_recovered, cmap="gray")
axs[1].imshow(all_faces, cmap="gray")

In [None]:
# Same two-panel comparison as for the random data, now for the face runs:
# log of column 1 against column 0 on top, ratio of the column-0 series below.
# The ratio assumes both error records have the same number of rows.
f, axs = plt.subplots(2, 1)
axs[0].plot(errors_proj_sub[:, 0], np.log(errors_proj_sub[:, 1]), label="numpy")
axs[0].plot(errors_proj_sub_torch[:, 0], np.log(errors_proj_sub_torch[:, 1]), label="torch")
axs[1].plot(errors_proj_sub_torch[:, 0] / errors_proj_sub[:, 0], label="torch / numpy ")
axs[0].legend(); axs[1].legend()

# Performance evaluation

In [6]:
def torch_algo_wrapper(algo, device="cpu"):
    """Wrap a torch-based factoriser so it can be called with array inputs.

    The returned callable requires a ``V`` keyword argument. It converts
    ``V`` (and ``W_init`` / ``H_init`` when supplied as keywords) to tensors
    on ``device``, forwards everything to ``algo``, and returns the first two
    results moved to the CPU as NumPy arrays, together with the third result
    unchanged.
    """
    def algo_w(*args, **kwargs):
        # V is mandatory; a missing key raises KeyError here.
        kwargs["V"] = torch.tensor(kwargs["V"], device=device)
        # The initial factors are optional.
        for optional_key in ("W_init", "H_init"):
            if optional_key in kwargs:
                kwargs[optional_key] = torch.tensor(kwargs[optional_key], device=device)
        W_t, H_t, errors = algo(*args, **kwargs)
        W_np = W_t.to("cpu").numpy()
        H_np = H_t.to("cpu").numpy()
        return W_np, H_np, errors
    return algo_w

# Algorithms under comparison: the three NumPy solvers and their torch
# counterparts wrapped to run on the GPU. Names containing "torch" are drawn
# dashed by plot_perofrmance.
algo_dict_to_test = {
    "mult": nmf.mult.factorise_Fnorm,
    "pgrad": nmf.pgrad.factorise_Fnorm_subproblems,
    "nesterov": nmf.nesterov.factorise_Fnorm,

    "mult_torch": torch_algo_wrapper(nmf_torch.mult.factorise_Fnorm, 
                                     device="cuda"),
    "pgrad_torch": torch_algo_wrapper(nmf_torch.pgrad.factorise_Fnorm_subproblems, 
                                      device="cuda"),
    "nesterov_torch": torch_algo_wrapper(nmf_torch.nesterov.factorise_Fnorm, 
                                        device="cuda")
}

In [10]:
# Seed the global NumPy RNG so the benchmark inputs are reproducible.
np.random.seed(1)

shape = (500, 40, 500)
# NOTE(review): this low-rank V is overwritten on the next line by a plain
# uniform 500x500 matrix, so only the second assignment reaches the benchmark.
# The call still advances the RNG, so removing it changes W_init and H_init.
V = get_random_lowrank_matrix(*shape)
V = np.random.rand(shape[0], shape[2])
W_init = np.random.rand(shape[0], shape[1])
H_init = np.random.rand(shape[1], shape[2])

In [11]:
# Run every algorithm on V with inner dimension 40 and time_limit=100 each,
# all starting from the same initial factors.
errors = compare_perofrmance(V, 40, time_limit=100,
                             W_init=W_init, H_init=H_init, 
                             algo_dict_to_test=algo_dict_to_test)


In [9]:
# New figure with the log-cost curves of all algorithms.
plt.figure()
plot_perofrmance(errors)

<IPython.core.display.Javascript object>

In [None]:
# Time ratio pgrad / nesterov per log-cost level, once for the torch (GPU)
# pair and once for the NumPy (CPU) pair. Values above 1 mean nesterov
# reached that level sooner.
plt.figure()

time_ratio = get_time_ratio(errors["pgrad_torch"], errors["nesterov_torch"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="gpu")
time_ratio = get_time_ratio(errors["pgrad"], errors["nesterov"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="cpu")
plt.gca().invert_xaxis()
plt.gca().legend()

In [None]:
# For each algorithm, the ratio NumPy time / torch time per log-cost level.
# Values above 1 mean the torch (GPU) run reached that level sooner.
plt.figure()

time_ratio = get_time_ratio(errors["pgrad"], errors["pgrad_torch"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="pgrad")
time_ratio = get_time_ratio(errors["mult"], errors["mult_torch"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="mult")
time_ratio = get_time_ratio(errors["nesterov"], errors["nesterov_torch"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="nesterov")

plt.gca().invert_xaxis()
plt.gca().set_title("gpu advantage")
plt.gca().legend()

In [None]:
# Densify the Reuters matrix and place the dense copy on the GPU.
reuters_data_torch = torch.tensor(reuters_data.toarray(), device="cuda")

In [None]:
# Single Nesterov run on the GPU-resident Reuters matrix with inner dimension
# 135 and time_limit=60. No initial factors are passed.
r = 135
# NOTE(review): `shape` is not read anywhere in this cell.
shape = (reuters_data_torch.shape[0], r, reuters_data_torch.shape[1])

_, _, errors = nmf_torch.nesterov.factorise_Fnorm(reuters_data_torch, r, 
                                                 record_errors=True,
                                                 time_limit=60,
                                                 max_steps=np.inf,
                                                 epsilon=0)

In [14]:
# Benchmark all six algorithms on the dense Reuters matrix.
# NOTE(review): this dictionary repeats the one defined next to
# torch_algo_wrapper; the redefinition keeps this cell self-describing.
algo_dict_to_test = {
    "mult": nmf.mult.factorise_Fnorm,
    "pgrad": nmf.pgrad.factorise_Fnorm_subproblems,
    "nesterov": nmf.nesterov.factorise_Fnorm,

    "mult_torch": torch_algo_wrapper(nmf_torch.mult.factorise_Fnorm, 
                                     device="cuda"),
    "pgrad_torch": torch_algo_wrapper(nmf_torch.pgrad.factorise_Fnorm_subproblems, 
                                      device="cuda"),
    "nesterov_torch": torch_algo_wrapper(nmf_torch.nesterov.factorise_Fnorm, 
                                        device="cuda")
}

inner_dim = 135
shape = (reuters_data.shape[0], inner_dim, reuters_data.shape[1])

# Initial factors come from the global NumPy RNG; no seed is set in this cell.
W_init = np.random.rand(shape[0], shape[1])
H_init = np.random.rand(shape[1], shape[2])

# The matrix is densified before being handed to the solvers.
errors = compare_perofrmance(reuters_data.toarray(), inner_dim, time_limit=100,
                             W_init=W_init, H_init=H_init, 
                             algo_dict_to_test=algo_dict_to_test)

In [15]:
# For each algorithm on the Reuters data, the ratio NumPy time / torch time
# per log-cost level. Values above 1 mean the torch run was faster.
plt.figure()

time_ratio = get_time_ratio(errors["nesterov"], errors["nesterov_torch"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="nesterov")
time_ratio = get_time_ratio(errors["pgrad"], errors["pgrad_torch"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="pgrad")
time_ratio = get_time_ratio(errors["mult"], errors["mult_torch"])
plt.plot(time_ratio[:, 0], time_ratio[:, 1], label="mult")
plt.gca().invert_xaxis()
plt.gca().legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f0d78915fd0>

In [16]:
# New figure with the log-cost curves of all algorithms on the Reuters data.
plt.figure()
plot_perofrmance(errors)

<IPython.core.display.Javascript object>