In [1]:
import pandas as pd
import numpy as np
import datetime
import time
import pickle
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline
from collections import Counter
from random import choices, sample
from gensim.models import Word2Vec
import gc
pd.set_option('display.float_format',lambda x : '%.2f' % x)
from math import log10

In [2]:
from os import listdir
from os.path import isfile, isdir, join

# Directory that holds the ranking CSV files (listed in the next cell's output).
mypath = './dien_ali'
# Every entry name in that directory: files and sub-directories alike.
files = listdir(mypath)
# Keep only regular files whose name ends in '.csv', ordered by name.
csv_list = sorted(
    name for name in files
    if name.endswith('.csv') and isfile(join(mypath, name))
)

In [3]:
# Display the sorted CSV file names that the processing loop below will read.
csv_list

['ali_dien_rank_11-26.csv',
 'ali_dien_rank_11-27.csv',
 'ali_dien_rank_11-28.csv',
 'ali_dien_rank_11-29.csv',
 'ali_dien_rank_11-30.csv',
 'ali_dien_rank_12-01.csv',
 'ali_dien_rank_12-02.csv',
 'ali_dien_rank_12-03.csv']

In [4]:
# Load the pickled KNN lookup tables, stored as parts numbered 1-4, keeping part order.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load tables produced locally / from a trusted source.
knn_tables = []
for part_no in range(1, 5):
    pkl_path = '../Taobao_v3/dict_knn_table/knn_table_{}.pkl'.format(part_no)
    with open(pkl_path, 'rb') as pkl_file:
        knn_tables.append(pickle.load(pkl_file))

In [5]:
def knn_ranking(r, tables=None):
    """Return the 1-based position of ``r.next`` in the KNN list stored for ``r.target``.

    The tables are searched in list order and the first one whose keys contain
    ``str(r.target)`` is used; later tables are not consulted even if the next
    item is missing from that first list. Each list entry is a pair whose first
    element is an item id (compared as a string) and whose second element is
    ignored here.

    Parameters
    ----------
    r : row-like
        Object exposing ``target`` and ``next`` attributes, e.g. the row Series
        that ``DataFrame.apply(..., axis=1)`` passes in.
    tables : list of dict, optional
        Lookup tables to search. Defaults to the notebook-level ``knn_tables``,
        looked up at call time rather than bound as a default value, so that a
        later ``del knn_tables`` really releases the memory.

    Returns
    -------
    int or float
        The rank (1 = first entry), or ``np.nan`` when no table has the target
        or the next item is not in the target's list (including an empty list).
    """
    if tables is None:
        tables = knn_tables
    target_key = str(r.target)
    for table in tables:
        if target_key in table:
            next_key = str(r.next)
            # Linear scan instead of zip(*...) + tuple.index under a bare except:
            # no unrelated exception is swallowed, and an empty list yields NaN
            # instead of an unpacking error.
            for rank, (item_id, _score) in enumerate(table[target_key], start=1):
                if item_id == next_key:
                    return rank
            return np.nan
    return np.nan

In [6]:
# Load the pickled NN lookup tables, stored as parts numbered 1-4, keeping part order.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load tables produced locally / from a trusted source.
NN_tables = []
for part_no in range(1, 5):
    pkl_path = '../Taobao_v3/dict_NN_table/NN_table_{}.pkl'.format(part_no)
    with open(pkl_path, 'rb') as pkl_file:
        NN_tables.append(pickle.load(pkl_file))

In [7]:
def NN_ranking(r, tables=None):
    """Return the 1-based position of ``r.next`` in the NN list stored for ``r.target``.

    The tables are searched in list order and the first one whose keys contain
    ``str(r.target)`` is used; later tables are not consulted even if the next
    item is missing from that first list. Each list entry is a pair whose first
    element is an item id (compared as a string) and whose second element (a
    score) is ignored here.

    Parameters
    ----------
    r : row-like
        Object exposing ``target`` and ``next`` attributes, e.g. the row Series
        that ``DataFrame.apply(..., axis=1)`` passes in.
    tables : list of dict, optional
        Lookup tables to search. Defaults to the notebook-level ``NN_tables``,
        looked up at call time rather than bound as a default value, so that a
        later ``del NN_tables`` really releases the memory.

    Returns
    -------
    int or float
        The rank (1 = first entry), or ``np.nan`` when no table has the target
        or the next item is not in the target's list (including an empty list).
    """
    if tables is None:
        tables = NN_tables
    target_key = str(r.target)
    for table in tables:
        if target_key in table:
            next_key = str(r.next)
            # Linear scan instead of zip(*...) + tuple.index under a bare except:
            # no unrelated exception is swallowed, and an empty list yields NaN
            # instead of an unpacking error.
            for rank, (item_id, _score) in enumerate(table[target_key], start=1):
                if item_id == next_key:
                    return rank
            return np.nan
    return np.nan

In [8]:
# Accumulator for the per-file DataFrames built by the processing loop.
lt_df_dien_ali = list()

In [9]:
# Re-initialise the accumulator here so that re-running this cell on its own
# does not append a second copy of every frame (which would duplicate rows in
# full_table after the concat).
lt_df_dien_ali = []
for file_name in tqdm(csv_list):
    # Read from the same directory csv_list was built from, and rename
    # last_id -> target and target_iid -> next (the names the ranking helpers read).
    df = pd.read_csv(join(mypath, file_name)).rename(
        columns={'last_id': 'target', 'target_iid': 'next'}
    )
    # 1-based position of `next` in each lookup table's list for `target`
    # (NaN when the target is not found or `next` is not in its list).
    df['knn_rank'] = df.apply(knn_ranking, axis=1)
    df['NN_rank'] = df.apply(NN_ranking, axis=1)
    lt_df_dien_ali.append(df)

100%|██████████| 8/8 [00:12<00:00,  1.51s/it]


In [10]:
# Stack all per-file frames into one table with a fresh 0..n-1 index.
full_table = pd.concat(objs=lt_df_dien_ali, ignore_index=True)

In [11]:
# Display the combined table (all CSVs concatenated, with knn_rank / NN_rank added).
# NOTE(review): this renders the whole frame (pandas truncates it); .head() would
# keep the saved notebook smaller.
full_table

Unnamed: 0,uid,hist_iid,target,next,dien_rank,knn_rank,NN_rank
0,339,3437026/3523692/1850821/1687195/2487310/203675...,222342,1692072,61,4,1
1,817,187317/2596619/1914934/3696279/4025973/2525212...,524677,4872789,64,70,51
2,889,3594856/2907526/4697594/1838487/4697594/805967...,347712,2907526,15,20,23
3,1010,5000746/1556415/896383/2202649/3459004/1320429...,139140,4808594,60,20,34
4,1123,5046379/2645331/2645331/2239418/2645331/183186...,132534,1140169,72,27,30
...,...,...,...,...,...,...,...
111331,1016608,4783009/4946821/2185553/4946821/4946821,4946821,2185553,79,25,29
111332,1017023,4489720/685183/493664/4261030/1957519/4846705,4846705,1894664,26,27,34
111333,1017720,1311676/715050/3636765/778799/140470/895691/22...,4723710,3804709,14,8,5
111334,1017898,2275478/3811147/2275478/4852928,4852928,4632023,74,37,6


In [12]:
# Summary statistics for the numeric columns. In the saved output the knn_rank and
# NN_rank counts equal the uid count, i.e. no rank came back NaN in that run.
# target/next are item ids used as lookup keys (and uid is presumably a user id),
# so their mean/std carry no meaning.
full_table.describe()

Unnamed: 0,uid,target,next,dien_rank,knn_rank,NN_rank
count,111336.0,111336.0,111336.0,111336.0,111336.0,111336.0
mean,501328.85,2585261.81,2579131.18,50.17,30.02,31.93
std,294152.43,1473382.3,1470020.38,28.83,28.04,27.83
min,5.0,72.0,81.0,1.0,1.0,1.0
25%,247670.25,1321509.0,1332973.0,25.0,6.0,8.0
50%,495206.0,2597539.0,2575122.0,50.0,20.0,24.0
75%,756143.75,3850142.25,3848283.75,75.0,49.0,51.0
max,1018011.0,5163006.0,5163006.0,99.0,99.0,99.0


In [13]:
# Flag rows whose DIEN rank is less than or equal to the KNN rank (rank 1 = first).
# NOTE: despite the '<' in the column name the test is '<=' (ties count as True);
# the name is kept because the filtering cells below select on it.
# Vectorised comparison instead of a row-wise apply: same boolean result (NaN
# ranks compare False), much faster, and it still yields a Series on an empty frame.
full_table['dien<knn'] = full_table['dien_rank'] <= full_table['knn_rank']

In [16]:
# Rows where DIEN's rank is <= the KNN rank. The flag column is already boolean,
# so it is used directly as the mask (no `== True` comparison needed).
full_table.loc[full_table['dien<knn']]

Unnamed: 0,uid,hist_iid,target,next,dien_rank,knn_rank,NN_rank,dien<knn
1,817,187317/2596619/1914934/3696279/4025973/2525212...,524677,4872789,64,70,51,True
2,889,3594856/2907526/4697594/1838487/4697594/805967...,347712,2907526,15,20,23,True
7,2281,912924/4697995/1638888/3286719/4261038/4257767...,820829,2611141,75,93,21,True
9,2585,286908/545214/286908/4723710/1237636/376751/26...,144993,1706729,20,96,76,True
18,4749,2848577/838369,838369,1917033,70,78,77,True
...,...,...,...,...,...,...,...,...
111328,1016387,151902/151902/4295611/956343/4923258,4923258,956343,36,99,68,True
111329,1016423,68708/4846512,4846512,68708,1,1,1,True
111330,1016582,210305/382929/1382069/648516/1029266/1985548/5...,4923964,5067173,11,21,50,True
111332,1017023,4489720/685183/493664/4261030/1957519/4846705,4846705,1894664,26,27,34,True


In [17]:
# Summary statistics restricted to rows where DIEN's rank is <= the KNN rank;
# the boolean flag column is used directly as the mask.
full_table.loc[full_table['dien<knn']].describe()

Unnamed: 0,uid,target,next,dien_rank,knn_rank,NN_rank
count,34751.0,34751.0,34751.0,34751.0,34751.0,34751.0
mean,500517.93,2575717.65,2581362.77,27.93,55.13,39.33
std,293457.76,1475082.42,1472404.11,22.47,27.83,29.07
min,7.0,72.0,142.0,1.0,1.0,1.0
25%,248509.0,1319584.0,1336162.0,9.0,33.0,14.0
50%,491715.0,2573745.0,2573779.0,23.0,57.0,34.0
75%,755219.0,3845720.0,3858178.0,42.0,79.0,63.0
max,1018011.0,5162674.0,5162806.0,99.0,99.0,99.0


In [19]:
# Inspect the first processed file (first name in the sorted csv_list) ordered by uid.
lt_df_dien_ali[0].sort_values('uid')

Unnamed: 0,uid,hist_iid,target,next,dien_rank,knn_rank,NN_rank
3574,7,980099/3427154/2689961/2292610/4034225/516760/...,2063176,1793668,45,53,40
10975,96,3005381/4160584,4160584,13283,94,5,7
7222,152,842950/3048082/3682882,3682882,3739757,18,6,6
10976,239,709710/4826455/709710/1992230/1842641/4586628/...,5030834,24834,11,28,16
0,339,3437026/3523692/1850821/1687195/2487310/203675...,222342,1692072,61,4,1
...,...,...,...,...,...,...,...
7221,1017455,3878908/2029728/4877816/4736367/1387695,1387695,1584827,20,2,24
14740,1017559,969064/7521/1367959/812391/4257874,4257874,2736720,39,34,80
14741,1017591,777173/2503250/4551433,4551433,2157244,51,38,91
14742,1017830,2382143/964233/310906/4676590,4676590,2609472,30,97,67


In [20]:
# Drop the lookup tables now that both rank columns are computed, so their memory
# can be reclaimed. The ranking helpers look these names up by default, so re-running
# the processing cell requires re-loading the tables first.
del knn_tables
del NN_tables

In [21]:
# Run a full garbage-collection pass so memory held by the deleted lookup tables
# is reclaimed; the displayed number is gc.collect()'s return value (the count of
# unreachable objects it found).
gc.collect()

3121989