In [1]:
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import os

from datetime import datetime
from tqdm import tqdm

from collections import defaultdict
import math
import numpy as np

In [2]:
path = '../competition/data/'
train_sessions = pd.read_csv(path + 'train_sessions.csv')
train_purchase = pd.read_csv(path + 'train_purchases.csv')
test_sessions = pd.read_csv(path + 'test_leaderboard_sessions.csv')

In [None]:
train_sessions = train_sessions.append(train_purchase)

In [4]:
train_sessions.head()

Unnamed: 0,session_id,item_id,date
0,3,9655,2020-12-18 21:25:00.373
1,3,9655,2020-12-18 21:19:48.093
2,13,15654,2020-03-13 19:35:27.136
3,18,18316,2020-08-26 19:18:30.833
4,18,2507,2020-08-26 19:16:31.211


# 共现矩阵计算


In [5]:
user_item_dict = train_sessions.groupby('session_id')['item_id'].agg(list).to_dict()

In [6]:
sim_item = {}
item_cnt = defaultdict(int)
for user, items in tqdm(user_item_dict.items()):
    for i in items:
        item_cnt[i] += 1
        sim_item.setdefault(i,{})
        for relate_item in items:
            # 同一个item -》 continue
            if i == relate_item:
                continue
            sim_item[i].setdefault(relate_item,0)
            # 打压特别长的item 防止过长的session带来的影响
            sim_item[i][relate_item] += 1 / math.log(1 + len(items))


100%|█████████████████████████████████████████████████████████████████████| 1000000/1000000 [01:02<00:00, 16054.33it/s]


In [7]:
# 归一化
sim_item_corr = sim_item.copy()

for i,relate_items in tqdm(sim_item.items()):
    for j, cij in relate_items.items():
        sim_item_corr[i][j] = cij / math.sqrt(item_cnt[i]*item_cnt[j])

100%|██████████████████████████████████████████████████████████████████████████| 23618/23618 [00:14<00:00, 1593.71it/s]


In [8]:
sim_item_corr[434]

{2927: 0.002149112324538758,
 11662: 0.025689842519649616,
 28075: 0.0028485531614825944,
 16064: 0.016475000846382072,
 10414: 0.0017580769296347352,
 18539: 0.0004364981913779412,
 18476: 0.002477708349192527,
 13226: 0.02709998247758032,
 26148: 0.040273551654669004,
 25907: 0.003989024691715343,
 24315: 0.005448753018741564,
 5672: 0.000506406677507254,
 22407: 0.0006114367865865373,
 26381: 0.0004145713010395974,
 19034: 0.0002703454947962331,
 20399: 0.0007007592461226406,
 2447: 0.0569650002783693,
 18844: 0.14346404509777622,
 21423: 0.005726575532317345,
 11742: 0.00823704607786334,
 18156: 0.008134012370969494,
 6588: 0.007598379703034857,
 7317: 0.0005852903563022654,
 20862: 0.0015905920421237965,
 16660: 0.004643775121237228,
 8060: 0.012153045518773524,
 26362: 0.0040565237852868506,
 13840: 0.0017045085987224592,
 23451: 0.02031060647150104,
 27151: 0.004124643391193836,
 20770: 0.042640038780711424,
 16963: 0.0074438280872069105,
 24921: 0.007266955100199445,
 10673: 0.

# 根据根据相似矩阵推荐top-100


In [9]:
order = train_sessions['item_id'].value_counts()

In [10]:
popular_items = list(order.index)

In [11]:
popular_items

[8060,
 26853,
 2447,
 19882,
 8622,
 1644,
 17089,
 18156,
 11742,
 7963,
 20770,
 4028,
 23088,
 9184,
 2072,
 4193,
 6736,
 19912,
 21616,
 2915,
 21215,
 27225,
 24921,
 7640,
 26691,
 17239,
 27613,
 15501,
 18657,
 27555,
 2814,
 8861,
 12179,
 1018,
 14392,
 972,
 4130,
 27852,
 2173,
 20236,
 12251,
 23789,
 14550,
 13922,
 10390,
 26249,
 19150,
 21781,
 4917,
 27151,
 26301,
 22704,
 2855,
 1368,
 22607,
 16922,
 15249,
 9427,
 12540,
 18981,
 16064,
 13596,
 7792,
 21668,
 21890,
 15140,
 27556,
 23451,
 2098,
 20629,
 18723,
 11053,
 11923,
 16660,
 13409,
 7999,
 15738,
 14378,
 24243,
 8755,
 14306,
 11565,
 25415,
 17740,
 9522,
 7096,
 14927,
 7367,
 27442,
 13081,
 434,
 22886,
 17648,
 22747,
 5367,
 17431,
 12959,
 21152,
 21423,
 2188,
 3774,
 20028,
 26433,
 15777,
 1148,
 16218,
 5704,
 551,
 28133,
 26180,
 2845,
 12555,
 18947,
 15403,
 12958,
 6187,
 18801,
 6588,
 26565,
 3233,
 12662,
 107,
 19227,
 19992,
 14881,
 8648,
 25118,
 1818,
 25273,
 20599,
 2410,


In [12]:
test_sessions_dict = test_sessions.groupby('session_id')['item_id'].agg(list).to_dict()

In [13]:
session_item_list = test_sessions_dict[1178]

In [14]:
session_item_list

[15862, 2915, 23864, 2915]

In [42]:
def recommand(session_item_list):
    rank = {}
    for i in session_item_list:
        if i not in sim_item_corr.keys():
            continue

        for j, wij in sorted(sim_item_corr[i].items(),key=lambda d:d[1],reverse=True):
            if j not in session_item_list:
                rank.setdefault(j,0)
                rank[j] += wij
    if len(rank) == 0:
        item_list = popular_items[:100]
        score_list = 0
    else:
        rank = sorted(rank.items(),key=lambda d:d[1],reverse=True)[:100]
        rank = np.array(rank)
        item_list = list(rank[:,0].astype('int32'))
        score_list = rank[:,1]

        if len(item_list) < 100:
            index = 0
            while(len(item_list)<100):
                item_list.append(popular_items[index])
                item_list = list(set(item_list))
                index += 1
    return list(item_list)

In [37]:
temp_item_list = recommand(session_item_list)

In [38]:
temp_item_list

[7548,
 19882,
 5070,
 17479,
 3037,
 12959,
 3423,
 15403,
 972,
 9199,
 28105,
 3740,
 2674,
 7545,
 26281,
 4450,
 5577,
 16924,
 10991,
 4130,
 6636,
 14934,
 19263,
 14306,
 27400,
 14392,
 17376,
 8755,
 4981,
 15746,
 4131,
 23502,
 22752,
 5276,
 14381,
 10277,
 12163,
 20770,
 375,
 20629,
 13914,
 14725,
 19920,
 20471,
 28004,
 25270,
 8119,
 13845,
 10163,
 25111,
 25668,
 15870,
 7591,
 4917,
 12373,
 26785,
 1368,
 16432,
 25939,
 27225,
 2814,
 23427,
 1789,
 27349,
 23774,
 22704,
 11378,
 15932,
 2353,
 17089,
 13403,
 9870,
 19150,
 15479,
 21781,
 4758,
 9381,
 13544,
 20599,
 8090,
 13644,
 8060,
 9522,
 24730,
 20495,
 5083,
 7926,
 1441,
 18657,
 19840,
 25549,
 12251,
 3052,
 18647,
 23703,
 4028,
 17202,
 20208,
 23764,
 23221]

In [43]:
session_id_list = []
item_id_list = []
rank_list = []
for session_id, session_item_list in tqdm(test_sessions_dict.items()):
    temp_item_list = recommand(session_item_list)
    session_id_list += [session_id for _ in range(100)]
    item_id_list += temp_item_list
    rank_list += [x for x in range(1,101)]

100%|████████████████████████████████████████████████████████████████████████████| 50000/50000 [09:14<00:00, 90.17it/s]


In [40]:
res_df = pd.DataFrame()
res_df['session_id'] = session_id_list
res_df['item_id'] = item_id_list
res_df['rank'] = rank_list
res_df.to_csv('baseline.csv',index=False)

# 参考资料
## KDD Debias
- https://www.logicjake.xyz/2020/06/16/KDD-debias-TOP13/#排序
- https://blog.csdn.net/fengdu78/article/details/106990993/
- https://tianchi.aliyun.com/forum/postDetail?spm=5176.12586969.1002.12.6c3f29e8MNZSeb&postId=104936
- https://tianchi.aliyun.com/forum/postDetail?postId=103530
- https://zhuanlan.zhihu.com/p/137085716