In [1]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

In [7]:
class Recommendation:
    """Latent-factor recommender trained with stochastic gradient descent.

    Each rating R_iu is approximated by the dot product q_i . p_u of a
    K-dimensional item vector and a K-dimensional user vector.  Training
    minimises

        E = sum_(i,u) (R_iu - q_i . p_u)^2
            + lamda * (sum_i ||q_i||^2 + sum_u ||p_u||^2)

    The ratings file is re-read from disk on every pass instead of being
    held in memory.  Each non-blank line must contain three
    whitespace-separated fields: user id, item id, rating.  Ids are used
    as 1-based row indices into the factor matrices, so they are expected
    to be positive integers.

    Attributes:
        file_name: path of the ratings file.
        K: number of latent factors.
        lamda: regularisation weight.
        eta: SGD learning rate.
        q: array of shape (largest item id, K) holding the item factors.
        p: array of shape (largest user id, K) holding the user factors.
        item_list: sorted zero-based indices of the items seen in the file.
        user_list: sorted zero-based indices of the users seen in the file.
    """

    def __init__(self, file_name='ratings.train.txt', K=20, lamda=0.1, eta=0.02, seed=None):
        """Scan the ratings file once and randomly initialise the factors.

        Args:
            file_name: path of the ratings file.
            K: number of latent factors.
            lamda: regularisation weight.
            eta: SGD learning rate.
            seed: optional seed for the factor initialisation.  ``None``
                (the default) draws from NumPy's global random state;
                any other value makes the initial ``q`` and ``p``
                reproducible.

        Raises:
            OSError: if the ratings file cannot be opened.
            IndexError, ValueError: if a non-blank line is malformed.
        """
        self.file_name = file_name
        self.K = K
        self.lamda = lamda
        self.eta = eta

        item_max = 0
        user_max = 0
        item_set = set()
        user_set = set()
        for item, user, _ in self._ratings():
            item_max = max(item_max, item)
            user_max = max(user_max, user)
            # Stored zero-based so they can index rows of q / p directly.
            item_set.add(item - 1)
            user_set.add(user - 1)

        # A private RandomState keeps seeded runs independent of (and from
        # disturbing) the global NumPy random state.
        rng = np.random if seed is None else np.random.RandomState(seed)
        # Upper bound sqrt(5/K): with entries uniform on [0, sqrt(5/K)] the
        # largest possible initial prediction q_i . p_u is 5.
        upper = np.sqrt(5.0 / K)
        self.q = rng.uniform(0, upper, (item_max, K))
        self.p = rng.uniform(0, upper, (user_max, K))
        self.item_list = sorted(item_set)
        self.user_list = sorted(user_set)

    def _ratings(self):
        """Yield ``(item, user, rating)`` for every non-blank line of the file.

        Blank or whitespace-only lines (for example a trailing newline at
        the end of the file) are skipped rather than passed to
        ``parse_line``, which would fail on them.
        """
        with open(self.file_name, 'r') as file:
            for line in file:
                if line.strip():
                    yield Recommendation.parse_line(line)

    @staticmethod
    def parse_line(line):
        """Parse one ``"<user> <item> <rating>"`` line.

        Args:
            line: a text line with at least three whitespace-separated fields.

        Returns:
            Tuple ``(item, user, rating)`` -- note the item comes first --
            with integer ids and a float rating.

        Raises:
            IndexError: if the line has fewer than three fields.
            ValueError: if a field cannot be converted to a number.
        """
        fields = line.split()
        u = int(fields[0])
        i = int(fields[1])
        R = float(fields[2])
        return i, u, R

    def compute_error(self):
        """Return the regularised squared error E over the whole ratings file.

        The regularisation term only covers rows of ``q`` and ``p`` that
        belong to items / users actually present in the file.
        """
        E = self.lamda * (np.sum(self.q[self.item_list, :]**2) + np.sum(self.p[self.user_list, :]**2))
        for i, u, R in self._ratings():
            E += np.square(R - self.q[i-1, :].dot(self.p[u-1, :]))
        return E

    def updates(self):
        """Run one SGD pass over the ratings file and return the new error.

        For each rating both factor vectors are updated from their values
        *before* the step, i.e. the item and user updates are simultaneous.

        Returns:
            The value of ``compute_error()`` after the pass.
        """
        for i, u, R in self._ratings():
            q_i = self.q[i-1, :]
            p_u = self.p[u-1, :]
            eps = 2 * (R - q_i.dot(p_u))

            # Both right-hand sides are evaluated into fresh arrays before
            # either row is overwritten (q_i / p_u are views into q / p).
            new_q = q_i + self.eta * (eps * p_u - 2 * self.lamda * q_i)
            new_p = p_u + self.eta * (eps * q_i - 2 * self.lamda * p_u)

            self.q[i-1, :] = new_q
            self.p[u-1, :] = new_p

        return self.compute_error()

    def main(self, n_iter=40):
        """Train for ``n_iter`` passes, print the error of each, and plot them.

        Args:
            n_iter: number of SGD passes over the ratings file (default 40).
        """
        error_list = []
        for i in range(n_iter):
            error = self.updates()
            error_list.append(error)
            print('iteration {}, error {}'.format(i, error))

        plt.plot(error_list)
        plt.xlabel('iteration')
        plt.ylabel('error')
        plt.show()

In [8]:
# Build the recommender with its default settings (ratings.train.txt, K=20,
# lamda=0.1, eta=0.02); the constructor scans the ratings file and randomly
# initialises the item and user factor matrices.
Recommendation_obj = Recommendation()
# Train for 40 passes over the ratings file, printing the error after each
# pass, then plot error against iteration.
Recommendation_obj.main()

iteration 0, error 88383.30702354417
iteration 1, error 86315.23423924786
iteration 2, error 83365.3679400729
iteration 3, error 80151.50533202277
iteration 4, error 77313.12390575379
iteration 5, error 74561.81077368431
iteration 6, error 71903.18144247934
iteration 7, error 69467.26040395409
iteration 8, error 67328.79475378522
iteration 9, error 65498.3341722907
iteration 10, error 63951.73430958084
iteration 11, error 62651.69514010175
iteration 12, error 61559.24481703808
iteration 13, error 60638.85235496051
iteration 14, error 59860.12501782112
iteration 15, error 59197.893020830335
iteration 16, error 58631.63661349119
iteration 17, error 58144.726754042385
iteration 18, error 57723.693767768214
iteration 19, error 57357.60243897321
iteration 20, error 57037.545481749075
iteration 21, error 56756.2411451589
iteration 22, error 56507.71479211627
iteration 23, error 56287.0462547783
iteration 24, error 56090.168550203714
iteration 25, error 55913.706942512996
iteration 26, error 

<IPython.core.display.Javascript object>