In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, manhattan_distances
import random

### 데이터 로드

In [2]:
# Dataset prefix; available choices: mlsmall, ml1m, lastfm, abook
filename="ml1m"

# Both rating files are tab-separated without a header row. Only the first two
# columns are read and they are named 'user' and 'item' (32-bit integers).
rating_csv_opts = dict(
    sep='\t', header=None, names=['user', 'item'],
    usecols=[0, 1], dtype={0: np.int32, 1: np.int32})

# Training interactions.
train_data_df = pd.read_csv('./data/'+filename+'.train.rating', **rating_csv_opts)

# Test interactions. The 99 negative candidates per user are not stored here;
# they are drawn at random later, during evaluation.
test_data_df = pd.read_csv('./data/'+filename+'.test.rating', **rating_csv_opts)

# User / item counts, taken as (largest id in the training data) + 1.
num_users = train_data_df['user'].max() + 1
num_items = train_data_df['item'].max() + 1

print("n_user : {}, n_item : {}".format(num_users, num_items))
print("train : {}, test : {}".format(len(train_data_df), len(test_data_df)))

n_user : 6040, n_item : 3706
train : 994169, test : 6040


In [5]:
from collections import Counter

# Popularity score per item: how many times the item occurs in the training data.
pred = dict(Counter(train_data_df['item'].tolist()))
print("Asis Counter len : ",len(pred))

# Items in [0, num_items) that never occur in the training data get a score
# of 0, so that every candidate id in that range can be looked up in `pred`.
# The range is derived from num_items rather than a dataset-specific constant
# so the cell keeps working when `filename` is switched to another dataset.
empty_num = set(range(num_items)) - set(pred.keys())
for e_num in empty_num:
    pred[e_num] = 0
print("Counter len : {} <<<< {} 0으로 채움".format(len(pred), empty_num))

Asis Counter len :  3704
Counter len : 3706 <<<< {3569, 1805} 0으로 채움


### evaluate

In [6]:
def hit(gt_item, pred_items):
    """Return 1 when gt_item occurs in pred_items, otherwise 0."""
    return 1 if gt_item in pred_items else 0


def ndcg(gt_item, pred_items):
    """Return 1 / log2(position + 2) for gt_item's position in pred_items.

    `position` is the zero-based index of the first occurrence of gt_item,
    so an item at the head of the list scores 1.0. Returns 0 when gt_item
    is not in pred_items.
    """
    if gt_item not in pred_items:
        return 0
    position = pred_items.index(gt_item)
    discount = np.log2(position + 2)
    return np.reciprocal(discount)

def evaluate(gt_item, full_pred_items, K):
    """Score a ranked list against one ground-truth item using its first K entries.

    Returns a (hit, ndcg) tuple computed on full_pred_items[:K].
    """
    top_k = full_pred_items[:K]
    hit_score = hit(gt_item, top_k)
    ndcg_score = ndcg(gt_item, top_k)
    return hit_score, ndcg_score

def user_test(test_user, K):
    """Evaluate the popularity scores in `pred` for one user.

    Builds a candidate list made of the user's test item(s) plus 99 randomly
    sampled items that appear in neither the user's training nor test rows,
    ranks the candidates by their `pred` score (highest first) and scores the
    top K against the user's first test item.

    Parameters:
        test_user: user id to evaluate.
        K: cut-off for the ranked list.

    Returns:
        The (hit, ndcg) tuple produced by `evaluate`.

    Raises:
        IndexError: if the user has no row in `test_data_df`.
        ValueError: from `random.sample` if fewer than 99 negatives are available.
    """
    # Items the user interacted with in training; never used as negatives.
    asis = train_data_df[train_data_df['user']==test_user]['item'].tolist()
    # Held-out item(s) for this user.
    gt = test_data_df[test_data_df['user']==test_user]['item'].tolist()
    if not gt:
        raise IndexError("user {} has no item in the test data".format(test_user))

    # Sample 99 negatives from the items the user has not interacted with.
    full = set(range(0,num_items))
    test_cand_99 = random.sample(list(full-set(asis)-set(gt)),99)
    test_cand = gt + test_cand_99

    # The sort below is stable, so with the ground-truth item always listed
    # first it would be ranked ahead of every negative that has the same
    # score, inflating the metrics. Shuffling makes tie-breaking random.
    random.shuffle(test_cand)

    # Score each candidate; an item with no entry in `pred` counts as 0.
    test_score = {item: pred.get(item, 0) for item in test_cand}
    ranked = sorted(test_score, key=test_score.get, reverse=True)

    return evaluate(gt[0], ranked, K)

In [7]:
# Mean HR@10 / NDCG@10 of each repetition. Every repetition re-evaluates all
# users; user_test draws fresh random negatives on each call.
fin_hr=[]
fin_ndcg=[]
for epoch in tqdm(range(10)):
    per_user = [user_test(i,10) for i in range(num_users)]
    _hr = [hr_value for hr_value, _unused in per_user]
    _ndcg = [ndcg_value for _unused, ndcg_value in per_user]

    mean_hr = sum(_hr)/len(_hr)
    mean_ndcg = sum(_ndcg)/len(_ndcg)
    fin_hr.append(mean_hr)
    fin_ndcg.append(mean_ndcg)
    print(epoch+1, mean_hr, mean_ndcg)

 10%|████████▎                                                                          | 1/10 [00:41<06:10, 41.22s/it]

1 0.4508278145695364 0.2525558108010725


 20%|████████████████▌                                                                  | 2/10 [01:20<05:24, 40.58s/it]

2 0.4480132450331126 0.25039106386052834


 30%|████████████████████████▉                                                          | 3/10 [01:55<04:32, 38.92s/it]

3 0.45678807947019867 0.2529372930455645


 40%|█████████████████████████████████▏                                                 | 4/10 [02:38<04:01, 40.26s/it]

4 0.45380794701986754 0.25208318627995113


 50%|█████████████████████████████████████████▌                                         | 5/10 [03:22<03:26, 41.39s/it]

5 0.4564569536423841 0.2525803077041824


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [03:59<02:39, 39.87s/it]

6 0.45380794701986754 0.25255536416507146


 70%|██████████████████████████████████████████████████████████                         | 7/10 [04:34<01:55, 38.56s/it]

7 0.4566225165562914 0.25359370357510497


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [05:09<01:15, 37.60s/it]

8 0.4541390728476821 0.25259631394944865


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [05:46<00:37, 37.38s/it]

9 0.45281456953642385 0.2503965149657311


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [06:27<00:00, 38.72s/it]

10 0.4587748344370861 0.2546600896681906





In [None]:
# Show the per-repetition mean hit ratio, then the per-repetition mean NDCG.
print(fin_hr, fin_ndcg, sep='\n')