In [1]:
%matplotlib inline
import os
import random
import sys
import time
import pickle
import numpy as np
import scipy.sparse as sp
import matplotlib.pyplot as plt

In [2]:
# Dimensions of the rating matrices: M users (rows) by N films (columns).
M = 10000
N = 1000

In [35]:
# Load the cached train/test split, or build it from data_train.csv.
# NOTE: pickle.load can execute arbitrary code; only load a cache file that
# this notebook wrote itself, never one obtained from elsewhere.
try:
    with open("data_train.pickle", "rb") as file:
        data, test = pickle.load(file)

except FileNotFoundError:
    def parse_line(line):
        """Parse one CSV row of the form '<u><user>_<f><film>,<rating>'.

        The one-character prefixes of the user and film ids are dropped.
        Returns the 1-based (user, film, rating) as ints.
        """
        coord, rating = line.split(',')
        user, film = coord.split('_')
        return (int(user[1:]), int(film[1:]), int(rating))

    def to_csr(entries):
        """Build an (M, N) float CSR matrix from a {(row, col): rating} dict.

        Zero ratings are skipped so that no explicit zeros are stored.
        """
        kept = [(rc, rating) for rc, rating in entries.items() if rating != 0]
        rows = np.array([r for (r, _), _ in kept], dtype=np.int64)
        cols = np.array([c for (_, c), _ in kept], dtype=np.int64)
        vals = np.array([rating for _, rating in kept], dtype=float)
        return sp.csr_matrix((vals, (rows, cols)), shape=(M, N))

    # Seeded RNG so that rebuilding the cache reproduces the same split.
    split_rng = random.Random(0)
    train_entries, test_entries = {}, {}
    with open("data_train.csv") as file:
        next(file)  # skip header
        for line in file:
            if not line.strip():
                continue  # tolerate blank lines (e.g. at end of file)
            user, film, rating = parse_line(line)
            # Hold out roughly 1 rating in 10 for the test set.
            target = test_entries if split_rng.randrange(10) == 0 else train_entries
            # A repeated (user, film) pair keeps its last rating.
            target[user - 1, film - 1] = rating

    # Build each sparse matrix in one shot from index arrays instead of one
    # Python-level sparse write per rating.
    data = to_csr(train_entries)
    test = to_csr(test_entries)
    with open("data_train.pickle", mode="bw") as file:
        pickle.dump((data, test), file)

In [19]:
# The slow version: pure-Python/NumPy reference implementation (the notebook
# later calls skm.k_means, presumably a compiled equivalent).
def k_means(data, test, K=10, tol=0.01, centroids=None, max_iter=None):
    """Cluster users by their observed ratings with K-means; score on `test`.

    Each user is assigned to the centroid with the smallest Euclidean distance
    measured only over the films that user rated (zero entries are treated as
    unrated). Each centroid is then the per-film mean of its members' ratings,
    or 0 for films no member rated. Iteration stops once the relative
    improvement in test RMSE, ``(err - newerr) / err``, falls below `tol`.

    Parameters
    ----------
    data : sparse or dense (M, N) matrix
        Training ratings; zeros mean "not rated".
    test : sparse or dense (M, N) matrix
        Held-out ratings used for the stopping criterion.
    K : int
        Number of clusters.
    tol : float
        Relative-improvement threshold for stopping.
    centroids : array-like of shape (K, N), optional
        Initial centroids; uniform random in [0.5, 6.5) when omitted.
    max_iter : int, optional
        Maximum number of iterations; ``None`` (default) means no limit.

    Returns
    -------
    err : float
        Test RMSE of the returned centroids and assignments.
    centroids : ndarray of shape (K, N)
    assignments : ndarray of shape (M,), int32
        Cluster index of each user.

    Raises
    ------
    ValueError
        If `test` contains no ratings (the stopping criterion would be NaN
        and the loop would never terminate).

    Prints the train and test RMSE after every iteration.
    """
    data = sp.csr_matrix(data)
    M, N = data.shape

    def observed(ratings):
        """Return (rows, cols, values) of the non-zero entries of `ratings`."""
        coo = sp.coo_matrix(ratings)
        keep = coo.data != 0
        return coo.row[keep], coo.col[keep], coo.data[keep]

    def get_assignments(centroids):
        """Index of the nearest centroid for each user, over rated films only."""
        assignments = np.zeros(M, dtype=np.int32)
        for user in range(M):
            start, end = data.indptr[user], data.indptr[user + 1]
            vals = data.data[start:end]
            rated = vals != 0
            cols = data.indices[start:end][rated]
            # Unrated films contribute 0 to the distance, so only the rated
            # columns of each centroid are compared.
            dists = np.linalg.norm(centroids[:, cols] - vals[rated], axis=1)
            assignments[user] = np.argmin(dists)
        return assignments

    def get_centroids(assignments):
        """Per-cluster, per-film mean of the members' observed ratings."""
        # membership[k, u] == 1 iff user u belongs to cluster k, so one sparse
        # product gives every cluster's per-film rating sums / rating counts.
        membership = sp.csr_matrix(
            (np.ones(M), (assignments, np.arange(M))), shape=(K, M))
        sums = membership.dot(data).toarray()
        counts = membership.dot(data.sign()).toarray()
        counts[counts == 0] = 1  # films nobody in the cluster rated -> 0
        return sums / counts

    def rmse(obs, assignments, centroids):
        """RMSE of the cluster predictions over every observed rating.

        Exact hits are counted too (a sparse `test - prediction` difference
        drops zero results from its stored entries, shrinking the denominator).
        Returns NaN when there are no ratings.
        """
        rows, cols, vals = obs
        if vals.size == 0:
            return np.nan
        pred = centroids[assignments[rows], cols]
        return np.sqrt(np.mean(np.square(vals - pred)))

    train_obs = observed(data)
    test_obs = observed(test)
    if test_obs[2].size == 0:
        raise ValueError("test contains no ratings; cannot compute its RMSE")

    if centroids is None:
        centroids = np.random.rand(K, N) * 6 + 0.5
    centroids = np.asarray(centroids, dtype=float)

    err = 999
    iteration = 0
    while True:
        assignments = get_assignments(centroids)
        centroids = get_centroids(assignments)
        newerr = rmse(test_obs, assignments, centroids)
        print(rmse(train_obs, assignments, centroids), newerr)
        iteration += 1
        # err == 0 guard: a perfect previous fit would otherwise divide by zero.
        converged = err == 0 or (err - newerr) / err < tol
        # Keep the error of the centroids we are about to return.
        err = newerr
        if converged or (max_iter is not None and iteration >= max_iter):
            break
    return err, centroids, assignments

In [6]:
import pyximport
# Register pyximport's import hook so that `import skm` in the next cell can
# presumably build skm from a .pyx source on import. reload_support=True is
# meant to let that module be reloaded after editing it (see pyximport docs).
pyximport.install(reload_support=True)

(None, <pyximport.pyximport.PyxImporter at 0x1096fa278>)

In [7]:
import skm

In [40]:
# Run skm's k-means with 6 clusters; presumably it mirrors the slow k_means
# defined earlier and returns (err, centroids, assignments). Show err.
err, centroids, assignments = skm.k_means(
    data,
    test,
    K=6,
    tol=1e-4,
)
err

1.0119077297787558

In [42]:
# Value for film index 9 in each cluster centroid (one entry per cluster).
centroids.T[9]

array([ 3.626703  ,  3.82683983,  3.24137931,  3.70652174,  3.26075949,
        4.17274939])

In [43]:
# Training rating of the first user (row 0) for film index 9; absent entries read as 0.
data[(0, 9)]

5.0

In [None]:
# Histogram of the training ratings given to film index 9.
film9_ratings = data[:, 9].data
# Ratings are parsed as integers, so use unit-width bins centred on each value
# (the default 10 equal-width bins would split integer ratings unevenly).
if film9_ratings.size:
    bins = np.arange(film9_ratings.min() - 0.5, film9_ratings.max() + 1.5)
else:
    bins = 10
fig, ax = plt.subplots()
ax.hist(film9_ratings, bins=bins)
ax.set(title="Training ratings for film index 9", xlabel="rating", ylabel="number of users");