In [1]:
import pandas as pd
import numpy as np
import matrix_modules


In [3]:
# load in the data
ratings, news, users = matrix_modules.load_dataset(full=True)

# create the matrix and user 
R = matrix_modules.create_item_cluster_mat(ratings, news, num_users=len(users), num_clusters=len(news['cluster'].unique()), isALS=False)

In [15]:
# initialize U and V
K = 5 # five latent factors tentatively 
I = len(users) # number of users
M = 30 # number of items
np.random.seed(42)
U = np.random.uniform(0, 1, size=K*I).reshape((I, K))
V = np.random.uniform(0, 1, size=K*M).reshape((M, K))


# initialize a dataframe of the matrix to look at data
df = pd.DataFrame(R)
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,0,0,1,10,0,0,0,0,2,0,...,0,0,1,1,3,3,2,0,27,12
1,1,2,3,1,1,1,0,0,1,0,...,2,0,1,0,3,3,0,3,4,1
2,0,0,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,2,0
3,12,0,1,1,3,0,0,0,3,4,...,0,0,3,3,5,2,0,2,14,0
4,10,1,2,5,2,0,0,1,2,2,...,0,0,5,0,5,2,1,1,7,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
255985,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0
255986,0,0,1,1,1,0,1,0,2,0,...,0,0,1,0,1,0,0,0,3,0
255987,0,0,0,0,2,0,0,0,1,1,...,0,0,1,0,1,0,0,0,5,1
255988,0,0,0,1,0,0,0,0,1,0,...,0,0,0,1,0,1,0,0,4,0


In [5]:
from tqdm import tqdm

In [11]:
def rmse(X):
    """
    Computes root-mean-square-error, ignoring nan values
    """
    return np.sqrt(np.nanmean(X**2))

def max_update(X, Y, relative=True):
    """
    Compute elementwise maximum update
    
    parameters:
    - X, Y: numpy arrays or vectors
    - relative: [True] compute relative magnitudes
    
    returns
    - maximum difference between X and Y (relative to Y) 
    
    """
    if relative:
        updates = np.nan_to_num((X - Y)/Y)
    else:
        updates = np.nan_to_num(X - Y)
            
    return np.linalg.norm(updates.ravel(), np.inf)

In [16]:
def SGD(R,U,V,rate=0.01,max_iterations=10,lam=0.1, diff_threshold=1e-3):

    Uold = np.zeros_like(U)
    Vold = np.zeros_like(V)
    
    track_update = []

    for t in range(1, max_iterations): # , total=max_iterations, desc="Starting descent"):
        
        for i, m in tqdm(zip(*np.where(R != 0)), total=len(np.where(R != 0)[0]), desc="Iterating Over R", leave=True):

            # Calculate the penalty terms.
            u_penalty = 2 * rate * lam * U[i]
            v_penalty = 2 * rate * lam * V[m]

            # Calculate the error and then update U[i]
            error = R[i, m] - (V[m] * U[i])
            U[i] = U[i] + ((2 * rate * error) * V[m]) - u_penalty

            # Calculate the error again with updated U[i] and then update V[m]
            error = R[i, m] - (V[m] * U[i])
            V[m] = V[m] + ((2 * rate * error) * U[i]) - v_penalty

        # track_rmse += [{
        #     'iteration':i, 
        #     'rmse': rmse(Gnew),
        #     'max residual change': max_update(Gnew, G, relative=False)
        # }]
        track_update += [{
            'iteration':t, 
            'max update':max(max_update(U, Uold), max_update(V, Vold))
        }]
        Uold = U.copy()
        Vold = V.copy()
        if track_update[-1]['max update'] < diff_threshold:
            print("Threshold reached, stopping descent")
            break
        
        # compute error after one sweep of updates
        # error += [(t, matrix_modules.rmse(R, U @ V.T))]
        
        # keep track of how much U and V changes
        # update += [(t, max(matrix_modules.max_update(Uold, U), matrix_modules.max_update(Vold, V)))]

    # error = pd.DataFrame(error, columns=['iteration', 'rmse'])
    # update = pd.DataFrame(update , columns=['iteration', 'maximum update'])

    return U, V  # , error, update

# U, V = SGD(R, U, V)

U, V = SGD(R, U, V)
U

Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74291.28it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74524.15it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74303.64it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74157.44it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74488.67it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74256.69it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74410.53it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 73896.21it/s]
Iterating Over R: 100%|██████████| 2436236/2436236 [00:32<00:00, 74125.06it/s]


array([[2.04343031, 2.04055529, 2.04170713, 2.04752202, 2.04165545],
       [0.87613552, 0.87512058, 0.87577629, 0.87820275, 0.87547165],
       [0.46561442, 0.49376399, 0.48897153, 0.47342004, 0.47040606],
       ...,
       [0.75691737, 0.7562522 , 0.75706478, 0.75882424, 0.75685735],
       [0.50472598, 0.49225474, 0.49238879, 0.50589916, 0.49851509],
       [0.43180489, 0.43080645, 0.4327978 , 0.43619102, 0.43454948]])

In [17]:
V

array([[1.95728308, 1.95714621, 1.95498019, 1.95504413, 1.95797742],
       [1.94004485, 1.94561932, 1.94062996, 1.9407626 , 1.94649783],
       [1.4297694 , 1.43194288, 1.43016562, 1.42954554, 1.43146526],
       [2.17971834, 2.18492451, 2.1849389 , 2.17583093, 2.18686407],
       [1.35505139, 1.35618066, 1.35500737, 1.35351007, 1.35614137],
       [1.54508846, 1.54418927, 1.54376478, 1.5420002 , 1.54552348],
       [1.85226521, 1.85254687, 1.85094859, 1.85211695, 1.85411855],
       [1.11730366, 1.12374627, 1.11834714, 1.11833421, 1.12370341],
       [2.28895492, 2.29401255, 2.29037905, 2.28421618, 2.29351544],
       [0.97856918, 0.97830678, 0.97721072, 0.97752144, 0.97836393],
       [3.71870145, 3.7254471 , 3.72161568, 3.71215183, 3.72395701],
       [1.63161309, 1.63344364, 1.63081262, 1.63180513, 1.63472346],
       [1.379167  , 1.38112095, 1.38543213, 1.37881613, 1.38543001],
       [1.71917469, 1.71949835, 1.71766955, 1.71710614, 1.72003798],
       [2.05911225, 2.06134211, 2.

In [58]:
pd.DataFrame(U @ V.T)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,-4.666942,56.276332,0.993091,39.528913,7.713917,11.649680,24.955317,9.705748,27.514584,11.119666,...,-24.263911,-7.360216,-16.093142,-6.860031,-53.930400,-54.058669,-20.548281,-22.042722,-57.627190,57.423576
1,-0.387409,4.671571,0.082438,3.281346,0.640342,0.967055,2.071573,0.805687,2.284021,0.923058,...,-2.014178,-0.610981,-1.335912,-0.569460,-4.476832,-4.487480,-1.705739,-1.829795,-4.783707,4.766805
2,0.808715,-9.751893,-0.172089,-6.849802,-1.336713,-2.018725,-4.324404,-1.681869,-4.767889,-1.926881,...,4.204593,1.275421,2.788714,1.188746,9.345376,9.367603,3.560727,3.819692,9.985978,-9.950694
3,5.643521,-68.052423,-1.200900,-47.800527,-9.328091,-14.087431,-30.177336,-11.736723,-33.272142,-13.446509,...,29.341251,8.900376,19.460709,8.295526,65.215593,65.370704,24.848106,26.655267,69.685955,-69.439733
4,-1.179005,14.217039,0.250884,9.986153,1.948760,2.943048,6.304439,2.451954,6.950984,2.809151,...,-6.129770,-1.859405,-4.065596,-1.733044,-13.624388,-13.656793,-5.191093,-5.568633,-14.558305,14.506866
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
255985,0.404685,-4.879898,-0.086114,-3.427676,-0.668898,-1.010180,-2.163954,-0.841616,-2.385876,-0.964221,...,2.104000,0.638227,1.395487,0.594855,4.676475,4.687597,1.781806,1.911394,4.997035,-4.979379
255986,1.209942,-14.590087,-0.257467,-10.248185,-1.999894,-3.020272,-6.469864,-2.516293,-7.133375,-2.882862,...,6.290612,1.908195,4.172275,1.778518,13.981885,14.015140,5.327305,5.714751,14.940308,-14.887519
255987,-0.375782,4.531368,0.079964,3.182866,0.621124,0.938032,2.009401,0.781506,2.215473,0.895355,...,-1.953729,-0.592644,-1.295819,-0.552369,-4.342473,-4.352802,-1.654547,-1.774879,-4.640139,4.623744
255988,1.611729,-19.435036,-0.342964,-13.651314,-2.664002,-4.023218,-8.618321,-3.351881,-9.502164,-3.840178,...,8.379544,2.541851,5.557768,2.369113,18.624868,18.669166,7.096351,7.612456,19.901555,-19.831237
