In [133]:
%matplotlib notebook
import matplotlib.pyplot as plt 
import numpy as np
import torch
from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer

import nmf.mult
import nmf.pgrad
import nmf.nesterov

import nmf_torch.mult
import nmf_torch.pgrad
import nmf_torch.nesterov

import read_data.reading as reading
from theory.represent import rescale_WH

In [78]:
def get_random_lowrank_matrix(m, r, n):
    """Return a random ``m x n`` matrix built as the product of two random factors.

    The factors have shapes ``(m, r)`` and ``(r, n)`` and are drawn with
    ``np.random.rand`` (left factor first), so the result has rank at most ``r``.
    """
    left_factor = np.random.rand(m, r)
    right_factor = np.random.rand(r, n)
    return left_factor @ right_factor

In [138]:
class HashTfidfBectoriser:
    """Two-stage text vectoriser: feature hashing followed by TF-IDF weighting.

    Wraps a ``HashingVectorizer`` (created with ``alternate_sign=False``) and a
    ``TfidfTransformer`` and applies them one after the other.
    """

    def __init__(self, n_features):
        """Create both stages; ``n_features`` is forwarded to the hashing stage."""
        self.tfidf_transformer = TfidfTransformer()
        self.hashing_vectoriser = HashingVectorizer(
            alternate_sign=False,
            n_features=n_features,
        )

    def fit_transform(self, data):
        """Hash ``data``, then fit the TF-IDF stage on the hashed output and return its transform."""
        hashed = self.hashing_vectoriser.fit_transform(data)
        return self.tfidf_transformer.fit_transform(hashed)
    
    
def unroll_images(data):
    """Flatten a stack of images into a matrix with one row per image.

    Parameters
    ----------
    data : array-like with ``.shape`` and ``.reshape``
        Stack of images; axis 0 indexes the images, all remaining axes belong
        to a single image.

    Returns
    -------
    unrolled : array of shape ``(data.shape[0], prod(data.shape[1:]))``
    image_shape : tuple of int
        ``data.shape[1:]``, i.e. exactly what ``roll_images`` needs to undo
        the flattening. For a 3-D stack this is ``(shape[1], shape[2])``.
    """
    n_images = data.shape[0]
    # Keep every per-image axis so the returned shape always agrees with the
    # flattened row width, including images with extra trailing axes.
    image_shape = tuple(data.shape[1:])
    # Spell the row width out instead of using -1, which cannot be inferred
    # when the stack contains zero images.
    row_width = int(np.prod(image_shape, dtype=int))
    return data.reshape(n_images, row_width), image_shape


def roll_images(data, original_image_shape):
    """Reshape flattened image data back to ``original_image_shape``.

    A 1-D ``data`` is treated as a single image and reshaped to exactly
    ``original_image_shape``. Any other input is reshaped to
    ``(-1, *original_image_shape)``, i.e. a stack with an inferred leading axis.
    """
    # Only a stack gets the inferred leading axis.
    leading_axes = () if data.ndim == 1 else (-1,)
    return data.reshape(*leading_axes, *original_image_shape)
    
def images_matrix_grid(data, grid_shape):
    """Tile images from ``data`` into a single mosaic array.

    ``grid_shape`` is ``(rows, columns)``. Images ``data[0]``, ``data[1]``, ...
    are placed left to right, then top to bottom, so the first
    ``rows * columns`` images of ``data`` are used.
    """
    mosaic_rows = []
    for row_index in range(0, grid_shape[0]):
        n_columns = grid_shape[1]
        first_image = n_columns * row_index
        tiles = [data[k, :, :] for k in range(first_image, first_image + n_columns)]
        # Images of one grid row sit side by side.
        mosaic_rows.append(np.hstack(tiles))
    # Grid rows are stacked on top of each other.
    return np.vstack(mosaic_rows)

# Loading data

In [None]:
# Load the Reuters-21578 corpus. The reading module is imported as `reading`;
# the name `reader` is not defined anywhere in this notebook.
reuters_data = reading.read_reuters21578("data/reuters21578")
# Display `.bytes` scaled by 2**20 (MiB, assuming `.bytes` is a byte count — TODO confirm).
reuters_data.bytes / 2**20

In [67]:
# Load the Indian Pines images. The reading module is imported as `reading`;
# the name `reader` is not defined anywhere in this notebook.
indian_pines = reading.read_pines("data/indian_pines/images")
# Flatten each image stack and keep the per-image shape for rolling back later.
ip_unrolled_site3, ori_shape_site3 = unroll_images(indian_pines["site3_im"])
ip_unrolled_ns_line, ori_shape_ns_line = unroll_images(indian_pines["ns_line_im"])

In [100]:
# Load the AT&T face images. The reading module is imported as `reading`;
# the name `reader` is not defined anywhere in this notebook.
faces = reading.read_face_images("data/att_faces/images/")
# One row per face, plus the per-image shape needed to roll rows back into images.
faces_unrolled, ori_shape_faces = unroll_images(faces)

# Random Data

In [87]:
# Fix the global NumPy seed so the synthetic matrix and the random starting
# factors are identical on every Restart & Run All.
np.random.seed(0)

# (rows, inner dimension, columns) of the synthetic low-rank matrix.
shape = (100, 50, 100)
A = get_random_lowrank_matrix(*shape)

# Gaussian starting factors with the same shapes as the two factors of A.
W_init = np.random.randn(shape[0], shape[1])
H_init = np.random.randn(shape[1], shape[2])

In [88]:
# Run the same factorisation with the NumPy and the torch implementation from
# identical starting factors so their error histories can be compared.
# `inner_dim` is taken from W_init so it cannot disagree with the shapes of
# the starting factors (it was previously hard-coded to 5 while W_init has
# shape[1] columns).
# NOTE(review): how the library resolves an inner_dim / W_init mismatch is not
# visible here — confirm the intended inner dimension.
W, H, errors_proj_sub = nmf.pgrad.factorise_Fnorm_subproblems(
    A,
    inner_dim=W_init.shape[1],
    record_errors=True,
    n_steps=100,
    epsilon=0,
    W_init=W_init.copy(),
    H_init=H_init.copy(),
)

# W and H are overwritten here; only the error history of each run is kept.
W, H, errors_proj_sub_torch = nmf_torch.pgrad.factorise_Fnorm_subproblems(
    torch.tensor(A),
    inner_dim=W_init.shape[1],
    record_errors=True,
    n_steps=100,
    epsilon=0,
    W_init=torch.tensor(W_init.copy()),
    H_init=torch.tensor(H_init.copy()),
)

In [158]:
f, axs = plt.subplots(2, 1)

# If the two runs recorded a different number of rows, only the rows both of
# them have can be divided element-wise; this is a no-op when lengths match.
n_common = min(len(errors_proj_sub), len(errors_proj_sub_torch))

# Top panel: log of column 0 of each error history against its column 1.
axs[0].plot(errors_proj_sub[:, 1], np.log(errors_proj_sub[:, 0]), label="numpy")
axs[0].plot(errors_proj_sub_torch[:, 1], np.log(errors_proj_sub_torch[:, 0]), label="torch")
axs[0].set_ylabel("log(error)")

# Bottom panel: row-by-row ratio of column 1 between the two runs.
axs[1].plot(errors_proj_sub_torch[:n_common, 1] / errors_proj_sub[:n_common, 1], label="torch / numpy ")

axs[0].legend()
axs[1].legend();  # trailing ';' keeps the Legend repr out of the cell output

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f05503512e8>

# Face Images

In [136]:
# Inner dimension of the face factorisation.
inner_dim_faces = 8

# Gaussian starting factors: W has one row per row of faces_unrolled,
# H has one column per column of faces_unrolled.
W_init = np.random.randn(faces_unrolled.shape[0], inner_dim_faces)
H_init = np.random.randn(inner_dim_faces, faces_unrolled.shape[1])

# NumPy implementation; copies are passed so the starting factors stay
# available unchanged for later runs.
W, H, errors_proj_sub = nmf.pgrad.factorise_Fnorm_subproblems(
    faces_unrolled,
    inner_dim=inner_dim_faces,
    record_errors=True,
    n_steps=100,
    epsilon=1e-3,
    W_init=W_init.copy(),
    H_init=H_init.copy(),
)

In [148]:
# Torch implementation on the same data and the same starting factors as the
# NumPy run. `inner_dim` uses inner_dim_faces (it was hard-coded to 20) so it
# agrees with the shapes of W_init / H_init and with the NumPy run that this
# result is compared against.
W, H, errors_proj_sub_torch = nmf_torch.pgrad.factorise_Fnorm_subproblems(
    torch.tensor(faces_unrolled, dtype=torch.float64),
    inner_dim=inner_dim_faces,
    record_errors=True,
    n_steps=100,
    epsilon=1e-3,
    W_init=torch.tensor(W_init.copy()),
    H_init=torch.tensor(H_init.copy()),
)

In [150]:
# Convert the torch factors to NumPy arrays for rescaling and plotting.
# The guards make the cell safe to re-run: once W and H are NumPy arrays they
# no longer have a `.numpy()` method.
if torch.is_tensor(W):
    W = W.numpy()
if torch.is_tensor(H):
    H = H.numpy()

In [151]:
W, H = rescale_WH(W, H)

# Roll every row of H back into an image. The image shape returned by
# unroll_images for the faces is stored in `ori_shape_faces`
# (`faces_ori_shape` is not defined anywhere in this notebook).
face_features = roll_images(H, ori_shape_faces)
# Lay the first inner_dim_faces feature images out in a single row.
face_features = images_matrix_grid(face_features, (1, inner_dim_faces))

In [152]:
# Draw on a fresh, explicit figure: a bare plt.imshow targets whichever
# figure is currently active, which may be a figure from an earlier cell.
f, ax = plt.subplots()
ax.imshow(face_features, cmap="gray")
ax.set_title("Face features (rows of H)");  # ';' suppresses the repr output

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f0552e742b0>

In [144]:
# Roll the reconstruction W @ H back into images. The image shape returned by
# unroll_images for the faces is stored in `ori_shape_faces`
# (`faces_ori_shape` is not defined anywhere in this notebook).
faces_recovered = roll_images(W @ H, ori_shape_faces)

# 20 x 20 mosaics of the reconstructed and the original faces.
faces_recovered = images_matrix_grid(faces_recovered, (20, 20))
all_faces = images_matrix_grid(faces, (20, 20))

# Shared axes keep both mosaics aligned when zooming or panning.
f, axs = plt.subplots(1, 2, sharex=True, sharey=True)
axs[0].imshow(faces_recovered, cmap="gray")
axs[0].set_title("Reconstruction (W @ H)")
axs[1].imshow(all_faces, cmap="gray")
axs[1].set_title("Original");  # ';' suppresses the repr output

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7f0552f34710>

In [156]:
f, axs = plt.subplots(2, 1)

# Both face runs use epsilon=1e-3, so they may record a different number of
# rows (assuming epsilon acts as a stopping tolerance — TODO confirm). Only
# the rows both histories have can be divided element-wise; this is a no-op
# when the lengths match.
n_common = min(len(errors_proj_sub), len(errors_proj_sub_torch))

# Top panel: log of column 0 of each error history against its column 1.
axs[0].plot(errors_proj_sub[:, 1], np.log(errors_proj_sub[:, 0]), label="numpy")
axs[0].plot(errors_proj_sub_torch[:, 1], np.log(errors_proj_sub_torch[:, 0]), label="torch")
axs[0].set_ylabel("log(error)")

# Bottom panel: row-by-row ratio of column 1 between the two runs.
axs[1].plot(errors_proj_sub_torch[:n_common, 1] / errors_proj_sub[:n_common, 1], label="torch / numpy ")

axs[0].legend()
axs[1].legend();  # trailing ';' keeps the Legend repr out of the cell output

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f05503e29e8>