forked from hexiangnan/neural_collaborative_filtering
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Dataset.py
75 lines (68 loc) · 2.52 KB
/
Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
'''
Created on Aug 8, 2016
Processing datasets.
@author: Xiangnan He (xiangnanhe@gmail.com)
'''
import scipy.sparse as sp
import numpy as np
class Dataset(object):
    '''
    Train/test data loaded from three tab-separated files sharing a prefix.

    Attributes set by the constructor:
        trainMatrix   -- scipy dok_matrix (float32) holding 1.0 at
                         [user, item] for every training line with rating > 0.
        testRatings   -- list of [user, item] pairs, one per test line.
        testNegatives -- list of lists of item ids, one list per test line.
        num_users     -- number of rows of trainMatrix (largest user id + 1).
        num_items     -- number of columns of trainMatrix (largest item id + 1).
    '''

    def __init__(self, path):
        '''
        Load "<path>.train.rating", "<path>.test.rating" and
        "<path>.test.negative".

        Raises AssertionError when the test rating file and the negative
        file do not contain the same number of entries.
        '''
        self.trainMatrix = self.load_rating_file_as_matrix(path + ".train.rating")
        self.testRatings = self.load_rating_file_as_list(path + ".test.rating")
        self.testNegatives = self.load_negative_file(path + ".test.negative")
        assert len(self.testRatings) == len(self.testNegatives), \
            "test ratings and test negatives differ in length"
        self.num_users, self.num_items = self.trainMatrix.shape

    @staticmethod
    def _iter_fields(filename):
        '''
        Yield the tab-separated fields of every non-blank line of filename.
        '''
        with open(filename, "r") as f:
            for line in f:
                # A blank line (typically a stray newline at the end of the
                # file) carries no record; parsing it would raise ValueError.
                if not line.strip():
                    continue
                # The last field keeps its newline; int() and float() accept
                # surrounding whitespace, so no further stripping is needed.
                yield line.split("\t")

    def load_rating_file_as_list(self, filename):
        '''
        Read a rating file and return a list of [user, item] pairs.

        Only the first two columns of each line are used.
        '''
        return [[int(fields[0]), int(fields[1])]
                for fields in self._iter_fields(filename)]

    def load_negative_file(self, filename):
        '''
        Read a negative-sample file and return one list of item ids per line.

        The first column of each line is ignored; every remaining column is
        parsed as an integer item id.
        '''
        return [[int(x) for x in fields[1:]]
                for fields in self._iter_fields(filename)]

    def load_rating_file_as_matrix(self, filename):
        '''
        Read a rating file and return it as a dok matrix.

        Each line is "user\\titem\\trating[\\t...]". The matrix shape is
        (largest user id + 1, largest item id + 1), taken over all lines,
        and an entry is set to 1.0 wherever the rating is positive.
        '''
        num_users, num_items = 0, 0
        positives = []
        # Single pass: track the largest ids (needed for the shape) and keep
        # the positive pairs in memory so the file is not parsed a second time.
        for fields in self._iter_fields(filename):
            user, item, rating = int(fields[0]), int(fields[1]), float(fields[2])
            num_users = max(num_users, user)
            num_items = max(num_items, item)
            if rating > 0:
                positives.append((user, item))

        # Construct matrix
        mat = sp.dok_matrix((num_users + 1, num_items + 1), dtype=np.float32)
        for user, item in positives:
            mat[user, item] = 1.0
        return mat