In [17]:
import numpy as np
import pandas as pd
from svd_engine import read_data, keep_movies_rated_by_at_least, create_pivot_table, split

In [9]:
# # create a sample ratings matrix
# ratings = pd.DataFrame({
#     1: [5, 3, 0, 1],
#     2: [4, 0, 0, 1],
#     3: [1, 1, 0, 5],
#     4: [1, 0, 0, 4],
#     5: [0, 1, 5, 4]
# }, index=[1, 2, 3, 4])


In [20]:
data = read_data()
df = keep_movies_rated_by_at_least(data, 0.33)
ratings = create_pivot_table(df)

In [21]:
# perform CUR decomposition
U, S, V = np.linalg.svd(ratings)

k = 30
C = V[:k, :]
R = U[:, :k]
S = np.diag(S[:k])

C = np.dot(S, C)
R = np.dot(R, S)


In [22]:
# predict missing ratings using SGD
n_users, n_movies = ratings.shape
n_factors = 2
lr = 0.01
reg = 0.1
n_iterations = 100

P = np.random.normal(scale=1./n_factors, size=(n_users, n_factors))
Q = np.random.normal(scale=1./n_factors, size=(n_movies, n_factors))

def sgd(ratings, P, Q, n_factors, lr, reg, n_iterations):
    for _ in range(n_iterations):
        for i in range(n_users):
            for j in range(n_movies):
                if ratings[i, j] > 0:
                    e = ratings[i, j] - np.dot(P[i, :], Q[j, :].T)
                    P[i, :] += lr * (e * Q[j, :] - reg * P[i, :])
                    Q[j, :] += lr * (e * P[i, :] - reg * Q[j, :])
    return P, Q

ratings_array = np.array(ratings)
P, Q = sgd(ratings_array, P, Q, n_factors, lr, reg, n_iterations)
ratings_pred = np.dot(P, Q.T)


In [23]:
# convert predicted ratings matrix to pandas DataFrame
ratings_pred = pd.DataFrame(ratings_pred, index=ratings.index, columns=ratings.columns)


In [24]:
ratings_pred

Unnamed: 0,1,7,50,56,69,79,98,100,117,121,...,237,258,269,286,288,294,300,313,405,748
1,4.150881,4.198998,4.736453,4.480854,4.124027,4.465935,4.813824,4.441210,3.699303,3.644987,...,4.101359,4.099165,4.232586,4.212325,3.748580,3.446198,3.811052,4.549926,3.686197,3.217483
2,3.724076,3.892959,4.266132,4.386375,3.710891,4.038988,4.532579,4.279934,3.231882,3.060594,...,3.677470,3.745636,4.114116,4.113334,3.354092,2.969527,3.254760,4.044674,3.226649,2.684701
3,2.438031,2.309437,2.761138,2.174858,2.408648,2.582835,2.560742,2.240003,2.281406,2.402422,...,2.411658,2.322861,2.090821,2.057216,2.213028,2.176760,2.443592,2.719096,2.265568,2.141776
4,4.444281,4.208730,5.033121,3.961231,4.390621,4.707952,4.666037,4.080625,4.159557,4.381265,...,4.396226,4.233724,3.808489,3.747073,4.034205,3.969117,4.455906,4.956972,4.130627,3.906065
5,3.609219,3.477155,4.095283,3.389488,3.570780,3.838543,3.890002,3.453057,3.336981,3.459290,...,3.569168,3.470242,3.242116,3.200436,3.271931,3.165704,3.541191,4.007949,3.316561,3.076984
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
939,4.666725,4.499803,5.295717,4.393787,4.617357,4.964226,5.036299,4.473820,4.312067,4.466473,...,4.614873,4.489102,4.201723,4.148362,4.230336,4.089529,4.573749,5.181142,4.285863,3.972384
940,3.782291,3.661766,4.294037,3.604089,3.743559,4.027196,4.106921,3.660631,3.484624,3.595375,...,3.740010,3.646311,3.442610,3.401404,3.427543,3.300119,3.687626,4.194822,3.464153,3.195825
941,4.194843,4.162038,4.775802,4.291068,4.160639,4.492340,4.726365,4.296913,3.794866,3.819363,...,4.146205,4.098558,4.072245,4.040502,3.794141,3.561932,3.957923,4.622350,3.777396,3.382373
942,3.906821,3.390355,4.383363,2.580778,3.832803,4.059224,3.575746,2.860164,3.870766,4.367296,...,3.869930,3.554469,2.568378,2.471577,3.568605,3.790177,4.321724,4.449597,3.829278,3.930726


In [25]:
ratings

Unnamed: 0,1,7,50,56,69,79,98,100,117,121,...,237,258,269,286,288,294,300,313,405,748
1,5.0,4.0,5.0,4.0,3.0,4.0,4.0,5.0,3.0,4.0,...,2.0,5.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,4.0,0.0,5.0,0.0,0.0,0.0,0.0,5.0,0.0,0.0,...,4.0,3.0,4.0,4.0,3.0,1.0,4.0,5.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,2.0,0.0,0.0,2.0,2.0,2.0,0.0,0.0,0.0
4,0.0,0.0,5.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,5.0,0.0,0.0,4.0,5.0,5.0,0.0,0.0,0.0
5,4.0,0.0,4.0,0.0,1.0,3.0,3.0,5.0,0.0,4.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
939,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,...,5.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,0.0
940,0.0,4.0,4.0,5.0,2.0,0.0,4.0,3.0,0.0,0.0,...,0.0,5.0,4.0,3.0,0.0,4.0,5.0,5.0,0.0,0.0
941,5.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,0.0,...,0.0,4.0,0.0,0.0,0.0,4.0,4.0,0.0,0.0,0.0
942,0.0,0.0,5.0,0.0,0.0,5.0,0.0,0.0,4.0,0.0,...,0.0,4.0,2.0,0.0,0.0,0.0,5.0,3.0,0.0,0.0
