Skip to content

Commit

Permalink
Merge pull request #31 from cosanlab/mult_update
Browse files Browse the repository at this point in the history
use binary mask to improve NMF_mult fitting
  • Loading branch information
ejolly committed Jul 29, 2021
2 parents bd51959 + bce527b commit f5d02d3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
22 changes: 13 additions & 9 deletions neighbors/_fit.py
Expand Up @@ -118,8 +118,9 @@ def sgd(


@nb.njit(cache=True, nogil=True)
def mult(X, W, H, data_range, eps, tol, n_iterations, verbose):
"""Lee & Seung (2001) multiplicative update rule"""
def mult(X, M, W, H, data_range, eps, tol, n_iterations, verbose):
"""Lee & Seung (2001) multiplicative update rule extended with
Zhu (2016) binary 'weight' matrix to handle missing data."""

last_e = 0
error_history = np.zeros((n_iterations))
Expand All @@ -143,20 +144,23 @@ def mult(X, W, H, data_range, eps, tol, n_iterations, verbose):
tol,
)

# Update H
numer = W.T @ X
denom = W.T @ W @ H + eps
# The np.multiply calls below have the effect of using only observed (non-missing)
# ratings when performing the factor matrix updates

# Update H (factor x item)
numer = W.T @ np.multiply(M, X)
denom = W.T @ np.multiply(M, W @ H) + eps
H *= numer
H /= denom

# Update W
numer = X @ H.T
denom = W @ H @ H.T + eps
# Update W (user x factor)
numer = np.multiply(M, X) @ H.T
denom = np.multiply(M, W @ H) @ H.T + eps
W *= numer
W /= denom

# Make prediction and get error
errors = X - W @ H
errors = np.multiply(M, X) - np.multiply(M, W @ H)
rmse = np.sqrt(np.mean(np.power(errors, 2)))

# Normalize current error with respect to max of dataset
Expand Down
8 changes: 6 additions & 2 deletions neighbors/models.py
Expand Up @@ -351,15 +351,19 @@ def fit(
# Whereas in SGD we explicitly pass in indices of training data for fitting, here we set testing indices to 0 so they have no impact on the multiplicative update. See Zhu, 2016 for more details: https://arxiv.org/pdf/1612.06037.pdf
self.dilate_mask(n_samples=dilate_by_nsamples)

# fillna(0) is equivalent to hadamard (element-wise) product with a binary mask
X = self.masked_data.fillna(0).to_numpy()
# Generate a binary mask matrix for observed ratings
M = self.mask.to_numpy().astype(float)
# mult() computes the Hadamard (element-wise) product of the binary mask
# with the data to zero out masked entries, so we pass the full data here
X = self.data.to_numpy()

# Run multiplicative updating
# Silence numba warning until this issue gets fixed: https://github.com/numba/numba/issues/4585
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NumbaPerformanceWarning)
error_history, converged, n_iter, delta, norm_rmse, W, H = mult(
X,
M,
self.W,
self.H,
self.data_range,
Expand Down

0 comments on commit f5d02d3

Please sign in to comment.