In [37]:
import numpy as np
import pandas as pd
from math import sqrt
from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

In [38]:
# Load the raw ratings table (columns: userId, movieId, rating, timestamp) and display it.
ratings = pd.read_csv(filepath_or_buffer='./data/ratings_small.csv')
ratings

Unnamed: 0,userId,movieId,rating,timestamp
0,1,31,2.5,1260759144
1,1,1029,3.0,1260759179
2,1,1061,3.0,1260759182
3,1,1129,2.0,1260759185
4,1,1172,4.0,1260759205
...,...,...,...,...
99999,671,6268,2.5,1065579370
100000,671,6269,4.0,1065149201
100001,671,6365,4.0,1070940363
100002,671,6385,2.5,1070979663


In [39]:
# Hold out 20% of the ratings for evaluation; the fixed seed makes the split reproducible.
train_df, test_df = train_test_split(ratings, random_state=42, test_size=0.2)
(train_df.shape, test_df.shape)

((80003, 4), (20001, 4))

In [40]:
# Inspect the training split; its row index is carried over from `ratings`.
train_df

Unnamed: 0,userId,movieId,rating,timestamp
37865,273,5816,4.5,1466946328
46342,339,2028,4.5,1446663181
64614,461,3895,0.5,1093224965
41974,300,3578,4.5,1086010878
50236,369,292,3.0,847465462
...,...,...,...,...
6265,33,3911,5.0,1032769506
54886,394,377,3.0,1298378869
76820,532,1347,3.5,1076971646
860,12,3408,4.0,968045379


In [41]:
# Movie x user rating matrix built from the TRAINING ratings only (rows: movieId,
# columns: userId). Unrated cells are NaN. Despite the name, this is a dense DataFrame.
# `pivot` reshapes in one vectorized step and raises if a (movieId, userId) pair repeats.
# It replaces a per-movie groupby/apply/unstack, which is slow and changes its return
# shape when every group yields the same index.
sparse_matrix = train_df.pivot(index='movieId', columns='userId', values='rating')
sparse_matrix

userId,1,2,3,4,5,6,7,8,9,10,...,662,663,664,665,666,667,668,669,670,671
movieId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,,,,,,,3.0,,4.0,,...,,4.0,3.5,,,,,,,5.0
2,,,,,,,,,,,...,5.0,,,3.0,,,,,,
3,,,,,,,,,,,...,,,,3.0,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,
5,,,,,,,,,,,...,,,,3.0,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
161830,,,,,,,,,,,...,,,,,,,,,,
161918,,,,,,,,,,,...,,,,,,,,,,
161944,,,,,,,,,,,...,,,,,,,,,,
162542,,,,,,,,,,,...,,,,,,,,,,


In [42]:
# Impute missing ratings in the movie x user matrix in two alternative ways:
#   sparse_matrix_movie: each NaN is replaced by that movie's mean rating (row mean)
#   sparse_matrix_user:  each NaN is replaced by that user's mean rating (column mean)
# DataFrame.fillna(Series) fills each column from the Series entry with the same label.
# Filling the transpose with per-movie means therefore performs the row-wise fill
# without running a Python lambda once per row.
sparse_matrix_movie = sparse_matrix.T.fillna(sparse_matrix.mean(axis=1)).T
sparse_matrix_user = sparse_matrix.fillna(sparse_matrix.mean(axis=0))

In [7]:
# First five movies of the movie-mean-imputed matrix (iloc[:5] is exactly what head() returns).
sparse_matrix_movie.iloc[:5]

userId,1,2,3,4,5,6,7,8,9,10,...,662,663,664,665,666,667,668,669,670,671
movieId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,3.851485,3.851485,3.851485,3.851485,3.851485,3.851485,3.0,3.851485,4.0,3.851485,...,3.851485,4.0,3.5,3.851485,3.851485,3.851485,3.851485,3.851485,3.851485,5.0
2,3.44382,3.44382,3.44382,3.44382,3.44382,3.44382,3.44382,3.44382,3.44382,3.44382,...,5.0,3.44382,3.44382,3.0,3.44382,3.44382,3.44382,3.44382,3.44382,3.44382
3,3.1875,3.1875,3.1875,3.1875,3.1875,3.1875,3.1875,3.1875,3.1875,3.1875,...,3.1875,3.1875,3.1875,3.0,3.1875,3.1875,3.1875,3.1875,3.1875,3.1875
4,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,...,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545,2.454545
5,3.347826,3.347826,3.347826,3.347826,3.347826,3.347826,3.347826,3.347826,3.347826,3.347826,...,3.347826,3.347826,3.347826,3.0,3.347826,3.347826,3.347826,3.347826,3.347826,3.347826


In [8]:
# First five movies of the user-mean-imputed matrix (iloc[:5] is exactly what head() returns).
sparse_matrix_user.iloc[:5]

userId,1,2,3,4,5,6,7,8,9,10,...,662,663,664,665,666,667,668,669,670,671
movieId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,2.5,3.514706,3.552632,4.317647,3.89375,3.225806,3.0,3.877778,4.0,3.694444,...,3.416667,4.0,3.5,3.289617,2.941176,3.653846,3.866667,3.233333,3.695652,5.0
2,2.5,3.514706,3.552632,4.317647,3.89375,3.225806,3.478261,3.877778,3.725,3.694444,...,5.0,3.804348,3.80649,3.0,2.941176,3.653846,3.866667,3.233333,3.695652,3.919355
3,2.5,3.514706,3.552632,4.317647,3.89375,3.225806,3.478261,3.877778,3.725,3.694444,...,3.416667,3.804348,3.80649,3.0,2.941176,3.653846,3.866667,3.233333,3.695652,3.919355
4,2.5,3.514706,3.552632,4.317647,3.89375,3.225806,3.478261,3.877778,3.725,3.694444,...,3.416667,3.804348,3.80649,3.289617,2.941176,3.653846,3.866667,3.233333,3.695652,3.919355
5,2.5,3.514706,3.552632,4.317647,3.89375,3.225806,3.478261,3.877778,3.725,3.694444,...,3.416667,3.804348,3.80649,3.0,2.941176,3.653846,3.866667,3.233333,3.695652,3.919355


In [43]:
def get_svd(s_matrix, k=300):
    """Factorise a fully-filled rating matrix with a rank-k truncated SVD.

    The matrix is transposed before decomposition. With ``s_matrix`` of shape
    (n_items, n_users), the SVD is taken of the (n_users, n_items) matrix
    ``A = U @ diag(s) @ Vh``.

    Parameters
    ----------
    s_matrix : array-like or DataFrame, shape (n_items, n_users)
        Dense matrix without missing values. NaNs must be imputed first.
    k : int, default 300
        Number of singular components to keep. Values larger than the number of
        available components (``min(n_items, n_users)``) are clamped to it.

    Returns
    -------
    item_factors : ndarray, shape (n_items, k)
        Right singular vectors scaled by the singular values, ``(diag(s_k) @ Vh_k).T``.
    user_factors : ndarray, shape (k, n_users)
        Left singular vectors, ``U_k.T``.

    ``item_factors @ user_factors`` is the best rank-k approximation of ``s_matrix``.

    Raises
    ------
    ValueError
        If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Only min(n_users, n_items) singular vectors are ever used. The default
    # full_matrices=True would also build a full n_items x n_items Vh, which can be
    # hundreds of MB for thousands of items and is almost entirely discarded.
    u, s, vh = np.linalg.svd(np.asarray(s_matrix).T, full_matrices=False)

    # Requesting more components than exist previously failed with a broadcast error.
    k = min(k, s.size)

    # Scaling the columns of Vh_k.T by s_k equals (diag(s_k) @ Vh_k).T,
    # without forming a k x k diagonal matrix.
    item_factors = vh[:k, :].T * s[:k]
    user_factors = u[:, :k].T
    return item_factors, user_factors

In [44]:
# Factorise the movie-mean-imputed matrix and rebuild it from the rank-k factors.
# pred_result_df: rows are movieIds, columns are userIds.
item_factors, user_factors = get_svd(sparse_matrix_movie)
pred_result_df = pd.DataFrame(
    item_factors @ user_factors,
    index=sparse_matrix_movie.index.values,
    columns=sparse_matrix_movie.columns.values,
)
# The same predictions laid out with users as rows and movies as columns.
movie_pred_result_df = pred_result_df.T

In [45]:
# item_factors has shape (n_movies, k). Row 0 is the singular-value-scaled latent
# vector of the first movie in sparse_matrix_movie's index.
item_factors[0, :]

array([-9.97778186e+01, -3.60857424e+00, -2.38322701e+00, -9.04368632e-01,
       -1.50765037e+00,  1.19823516e+00, -2.12104783e+00,  1.35826051e-01,
        4.73476715e-01, -1.61377201e+00, -7.98195487e-01,  1.05735170e+00,
       -8.07940076e-01, -2.69756866e-01,  1.95731106e+00, -1.71554095e+00,
        1.47708439e+00, -1.55527088e-01,  1.86630425e+00, -8.60759529e-01,
        4.72967698e-01, -2.94304475e-01, -1.24468573e+00,  1.33137902e+00,
       -7.70183216e-01,  7.46700747e-02,  8.73900193e-01, -4.09929755e-01,
       -7.57898692e-01,  1.09278610e+00, -1.33540519e-02,  1.05103734e+00,
        1.02555643e+00,  1.54967784e+00, -4.22964644e-01,  5.05913016e-02,
        2.94581493e+00, -7.16436333e-01,  6.24077897e-01, -3.80623906e+00,
       -9.49452446e-01, -1.77129455e-02,  3.51039012e-01, -2.27820035e+00,
       -8.11167045e-01,  1.49660465e+00, -4.80806871e-01,  2.06255297e+00,
        8.46386446e-01, -1.51407025e+00, -7.29893035e-01,  1.17591107e+00,
       -3.20103451e-01,  

In [46]:
# user_factors has shape (k, n_users). Row 0 is the first latent component across
# every user, NOT user 1's factor vector (that would be user_factors[:, 0]).
user_factors[0, :]

array([-0.03857693, -0.03860221, -0.03859386, -0.0387798 , -0.03863592,
       -0.03858594, -0.03858764, -0.03861301, -0.03860336, -0.03861145,
       -0.03862167, -0.03856265, -0.03859858, -0.03859913, -0.03725003,
       -0.03861608, -0.03861078, -0.03859405, -0.03855235, -0.03852936,
       -0.03857543, -0.03855963, -0.03854117, -0.03860812, -0.03859506,
       -0.03855306, -0.03860191, -0.03862327, -0.03860068, -0.03891138,
       -0.03863062, -0.03862162, -0.03859362, -0.03861674, -0.03857401,
       -0.038609  , -0.03862057, -0.03862669, -0.03862686, -0.03863203,
       -0.03869556, -0.03862042, -0.03855087, -0.03859953, -0.0386054 ,
       -0.03864519, -0.03862419, -0.03857701, -0.03861104, -0.03859474,
       -0.03862804, -0.03862218, -0.03857695, -0.03860113, -0.03861436,
       -0.03854777, -0.0387013 , -0.03861506, -0.03854818, -0.03863193,
       -0.03856522, -0.03861838, -0.03861739, -0.03861602, -0.03860416,
       -0.03862423, -0.03863218, -0.03858916, -0.03865833, -0.03

In [47]:
# Rank-k reconstructed ratings: rows are movieIds, columns are userIds.
pred_result_df

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,662,663,664,665,666,667,668,669,670,671
1,3.788388,3.767808,3.776663,3.865839,3.828209,3.803000,3.114177,3.823025,3.945936,3.840780,...,3.817372,3.953483,3.421847,3.827313,3.761351,3.883660,3.873830,3.880161,3.820446,5.026974
2,3.255509,3.425875,3.525161,3.389180,3.540596,3.253587,3.388431,3.504871,3.412057,3.463966,...,4.367361,3.438237,3.425260,3.002818,3.502093,3.462110,3.379974,3.398258,3.444519,3.591383
3,3.210529,3.261263,3.259922,3.162579,3.281898,3.180689,3.222245,3.190341,3.060060,3.192234,...,3.221685,3.196402,3.189868,2.984983,3.200677,3.203058,3.204729,3.182508,3.177036,3.246090
4,2.471339,2.463294,2.464101,2.473939,2.470317,2.448745,2.471931,2.457864,2.421379,2.453038,...,2.494429,2.453832,2.463645,2.445318,2.438761,2.448407,2.457455,2.431249,2.420894,2.437669
5,3.382591,3.307403,3.256805,3.311425,3.264043,3.389565,3.311992,3.284511,3.389857,3.395594,...,3.416624,3.346588,3.351123,2.971017,3.272767,3.373862,3.380386,3.378195,3.226933,3.522576
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
161830,1.000000,0.999950,0.999875,1.000032,0.999974,1.000144,1.000081,0.999920,1.000004,1.000352,...,0.999964,1.000164,0.999985,1.000039,1.000063,1.000111,0.999868,0.999958,0.999871,0.999984
161918,1.500000,1.499925,1.499813,1.500048,1.499961,1.500217,1.500122,1.499880,1.500005,1.500528,...,1.499946,1.500247,1.499977,1.500059,1.500095,1.500167,1.499802,1.499937,1.499806,1.499976
161944,5.000000,4.999751,4.999376,5.000161,4.999870,5.000722,5.000406,4.999599,5.000018,5.001761,...,4.999819,5.000822,4.999923,5.000196,5.000317,5.000557,4.999339,4.999789,4.999354,4.999920
162542,5.000000,4.999751,4.999376,5.000161,4.999870,5.000722,5.000406,4.999599,5.000018,5.001761,...,4.999819,5.000822,4.999923,5.000196,5.000317,5.000557,4.999339,4.999789,4.999354,4.999920


In [48]:
# Transposed predictions: rows are userIds, columns are movieIds.
movie_pred_result_df

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,160656,160718,161084,161155,161594,161830,161918,161944,162542,163949
1,3.788388,3.255509,3.210529,2.471339,3.382591,3.956395,3.483888,3.328588,3.072380,3.324786,...,3.500000,4.000000,2.500000,0.500000,3.000000,1.000000,1.500000,5.000000,5.000000,5.000000
2,3.767808,3.425875,3.261263,2.463294,3.307403,3.907488,3.648476,3.336206,3.097623,4.030007,...,3.499826,3.999801,2.499876,0.499975,2.999851,0.999950,1.499925,4.999751,4.999751,4.999751
3,3.776663,3.525161,3.259922,2.464101,3.256805,4.035639,3.487522,3.337107,3.074721,3.545697,...,3.499563,3.999501,2.499688,0.499938,2.999625,0.999875,1.499813,4.999376,4.999376,4.999376
4,3.865839,3.389180,3.162579,2.473939,3.311425,3.887802,3.485964,3.385314,3.064289,3.340674,...,3.500113,4.000129,2.500080,0.500016,3.000097,1.000032,1.500048,5.000161,5.000161,5.000161
5,3.828209,3.540596,3.281898,2.470317,3.264043,3.869228,3.388102,3.300536,3.112431,3.250916,...,3.499909,3.999896,2.499935,0.499987,2.999922,0.999974,1.499961,4.999870,4.999870,4.999870
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
667,3.883660,3.462110,3.203058,2.448407,3.373862,3.929642,3.368835,3.288891,2.984378,3.363669,...,3.500390,4.000446,2.500279,0.500056,3.000334,1.000111,1.500167,5.000557,5.000557,5.000557
668,3.873830,3.379974,3.204729,2.457455,3.380386,3.969988,3.454918,3.316871,3.049228,3.429763,...,3.499537,3.999471,2.499669,0.499934,2.999603,0.999868,1.499802,4.999339,4.999339,4.999339
669,3.880161,3.398258,3.182508,2.431249,3.378195,3.900931,3.479495,3.334253,3.016525,3.420720,...,3.499852,3.999831,2.499894,0.499979,2.999873,0.999958,1.499937,4.999789,4.999789,4.999789
670,3.820446,3.444519,3.177036,2.420894,3.226933,3.942869,3.432726,3.291207,3.117953,3.371568,...,3.499548,3.999483,2.499677,0.499935,2.999613,0.999871,1.499806,4.999354,4.999354,4.999354
