## Matrix factorization

### Singular Value Decomposition (SVD)
- U (user features)
- Sigma (singular values)
- V (item features)

In [3]:
from surprise import Dataset, Reader, SVD
from surprise.model_selection import train_test_split
from surprise.accuracy import rmse
import os

In [6]:
# Path to the ratings file inside the 'ml-1m' dataset folder, relative to the
# notebook's working directory.
# NOTE(review): expanduser() only expands a leading '~', so for this relative
# path it returns the string unchanged.

file_path = os.path.expanduser('ml-1m/ratings.dat')

In [7]:
# Define a reader.
# Each line of the file holds 'user item rating timestamp', with the fields
# separated by the '::' string.

# NOTE(review): `columns` is not referenced by any cell visible here — confirm it is used later.
columns = ['user_id','item_id','rating','timestamp']
reader = Reader(line_format = 'user item rating timestamp', sep = '::')

# Parse the ratings file with the reader defined above.
data = Dataset.load_from_file(file_path, reader=reader)
# Trainset built from the whole dataset, with nothing held out.
# NOTE(review): `full_data` is not used by the cells visible here — confirm it is needed.
full_data = data.build_full_trainset()

# Hold out 20% of the ratings for testing; the fixed random_state makes the
# split repeatable across runs.
train_set, test_set = train_test_split(data, test_size=0.2, random_state=42)

In [8]:
# Matrix-factorization model (Surprise's SVD algorithm) with its default
# hyperparameters. SVD learns latent user and item factors; the user-based vs.
# item-based choice belongs to the neighbourhood (KNN) algorithms, not to SVD.
# A fixed random_state seeds the factor initialisation so that re-running the
# notebook reproduces the same model, in line with the seeded train/test split.

svd_model = SVD(random_state=42)

In [9]:
# Train the model on the training set.
# fit() is the last expression in the cell, so its return value (the model
# object itself) is echoed as the cell output.

svd_model.fit(train_set)

<surprise.prediction_algorithms.matrix_factorization.SVD at 0x10809d490>

In [12]:
# Make predictions on the held-out test set.
# Each prediction unpacks into five fields (user id, item id, true rating,
# estimated rating, details), which is how get_top_n consumes them further down.

predictions = svd_model.test(test_set)

In [13]:
# Evaluate the model using RMSE (root mean squared error; lower is better).
# rmse() prints its own "RMSE: ..." line by default; verbose=False silences it
# so the score is reported exactly once, by the formatted print below.
# Note: despite its name, `accuracy` holds an error value, not an accuracy score.

accuracy = rmse(predictions, verbose=False)
print(f'RMSE on the test set: {accuracy:.4f}')

RMSE: 0.8734
RMSE on the test set: 0.8734


In [14]:
from collections import defaultdict

def get_top_n(predictions, n=10):
    """Collect the ``n`` highest-estimated items for every user.

    Args:
        predictions: Iterable of 5-field records, each unpacked as
            ``(uid, iid, true_rating, estimate, details)``.
        n: Maximum number of items kept per user (default 10).

    Returns:
        A ``defaultdict(list)`` mapping each uid to a list of
        ``(iid, estimate)`` pairs ordered from highest to lowest estimate
        and truncated to at most ``n`` entries.
    """
    per_user = defaultdict(list)

    # Bucket every (item, estimate) pair under the user it was predicted for.
    for user, item, _true_rating, estimate, _details in predictions:
        per_user[user].append((item, estimate))

    # Rank each bucket by estimate, best first, and keep only the head.
    # Only existing keys are reassigned, so iterating the dict here is safe.
    for user in per_user:
        ranked = sorted(per_user[user], key=lambda pair: pair[1], reverse=True)
        per_user[user] = ranked[:n]

    return per_user

# Build the per-user top-10 lists, then print one line per user: the user id
# followed by that user's recommended item ids (the estimates are dropped).
top_n = get_top_n(predictions, n=10)
for uid, user_ratings in top_n.items():
    print(uid, [iid for iid, _est in user_ratings])
        

1841 ['34', '318', '1247', '356', '1784', '1584', '296', '508', '1265', '2501']
3715 ['780', '2021', '1270', '1073', '2706', '2683', '2716', '2987', '3114', '1']
2002 ['903', '2203', '951', '930', '1269', '1267', '3307', '3801', '1934', '1278']
3332 ['223', '2692', '3160', '1198', '2734', '2890', '3114', '3163', '1961', '2693']
3576 ['553', '356', '593', '2000', '1250', '3801', '1378', '368', '2944', '3836']
2092 ['858', '1193', '908', '3435', '1234', '2997', '21', '946', '3801', '922']
5283 ['1148', '2997', '2908', '1257', '1036', '1210', '1580', '1921', '3654', '1914']
4610 ['745', '1223', '1136', '2918', '1278', '3000', '3745', '3039', '3793', '3702']
398 ['608', '1230', '2064', '50', '3730', '111', '3317', '47', '1276', '924']
4533 ['3178', '508', '2997', '3408', '2599', '1089', '357', '1097', '3534']
76 ['50', '2762', '3897', '1148', '337', '2686', '2692', '2289', '1', '2336']
921 ['593', '678', '1240', '2571', '1291', '2858', '1411', '1231', '1358', '2071']
4473 ['593', '2248', '