In [6]:
import os
import pandas as pd
from datetime import datetime
from tqdm import tqdm
from collections import defaultdict
import math
from gensim.models import Word2Vec
import numpy as np
from sklearn.manifold import TSNE
import seaborn as sns
import random
from matplotlib import pyplot as plt
import logging
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
 

In [7]:
# Folder that holds the raw CSV exports (relative to the notebook location).
path = '../data/'
sessions_csv = path + 'train_sessions.csv'
purchases_csv = path + 'train_purchases.csv'

# One frame with the items viewed in each session, one with the purchases.
train_sessions = pd.read_csv(sessions_csv)
train_purchases = pd.read_csv(purchases_csv)


In [8]:
# Peek at the first five rows to check the column layout.
train_sessions.head(n=5)


Unnamed: 0,session_id,item_id,date
0,3,9655,2020-12-18 21:25:00.373
1,3,9655,2020-12-18 21:19:48.093
2,13,15654,2020-03-13 19:35:27.136
3,18,18316,2020-08-26 19:18:30.833
4,18,2507,2020-08-26 19:16:31.211


In [14]:
# Parse the raw `date` column of both frames into pandas datetimes.
for frame in (train_sessions, train_purchases):
    frame['date'] = pd.to_datetime(frame['date'])


In [15]:
# Seconds since the Unix epoch as float64, computed in one vectorised step.
# The timezone-naive `date` values are read as UTC, which matches what the
# previous per-row `Timestamp.timestamp()` call produced (equal up to
# floating-point rounding).
# `Series.dt.to_pydatetime()` is avoided on purpose: it is deprecated, and once
# it returns plain `datetime` objects their `.timestamp()` would interpret the
# naive values in the machine's local timezone instead.
train_sessions['ts'] = (
    (train_sessions['date'] - pd.Timestamp('1970-01-01')) / pd.Timedelta(seconds=1)
)


In [16]:
# Seconds since the Unix epoch as float64, computed in one vectorised step.
# The timezone-naive `date` values are read as UTC, which matches what the
# previous per-row `Timestamp.timestamp()` call produced (equal up to
# floating-point rounding).
# `Series.dt.to_pydatetime()` is avoided on purpose: it is deprecated, and once
# it returns plain `datetime` objects their `.timestamp()` would interpret the
# naive values in the machine's local timezone instead.
train_purchases['ts'] = (
    (train_purchases['date'] - pd.Timestamp('1970-01-01')) / pd.Timedelta(seconds=1)
)


In [17]:
# Stack the purchase rows under the viewed-item rows so that each session's
# purchase becomes part of its event sequence. `DataFrame.append` was removed
# in pandas 2.0; `pd.concat` is the equivalent and, like `append`, keeps the
# original row labels (so labels repeat across the two sources).
# NOTE: this rebinds `train_sessions`; re-running the cell without re-running
# the load cell stacks the purchases a second time.
train_sessions = pd.concat([train_sessions, train_purchases])


In [18]:
# Order the events chronologically inside every session; the walk generation
# relies on this ordering to know which item follows which.
sort_keys = ['session_id', 'ts']
df = train_sessions.sort_values(sort_keys, ascending=True)
df


Unnamed: 0,session_id,item_id,date,ts
1,3,9655,2020-12-18 21:19:48.093,1.608326e+09
0,3,9655,2020-12-18 21:25:00.373,1.608327e+09
0,3,15085,2020-12-18 21:26:47.986,1.608327e+09
2,13,15654,2020-03-13 19:35:27.136,1.584128e+09
1,13,18626,2020-03-13 19:36:15.507,1.584128e+09
...,...,...,...,...
4743804,4440001,19539,2020-10-30 23:37:09.460,1.604101e+09
4743815,4440001,20409,2020-10-30 23:37:20.658,1.604101e+09
4743818,4440001,27852,2020-10-30 23:39:55.186,1.604101e+09
4743806,4440001,20449,2020-10-30 23:40:28.149,1.604101e+09


In [29]:
# Distinct item ids and session ids, in order of first appearance in `df`.
item_list = pd.unique(df['item_id'])
user_list = pd.unique(df['session_id'])


In [27]:
# dic_item: 'item_<id>' -> list of 'user_<session_id>' keys
# dic_user: 'user_<session_id>' -> list of 'item_<id>' keys
# Both start empty and are filled by the loop in the next cell.
dic_item, dic_user = defaultdict(list), defaultdict(list)


In [28]:
# Walk over the time-ordered (session, item) pairs once and record the link in
# both directions, so each list keeps the order of the rows in `df`.
for session_id, item_id in tqdm(df[['session_id','item_id']].values):
    user_key = f'user_{session_id}'
    item_key = f'item_{item_id}'
    dic_item[item_key].append(user_key)
    dic_user[user_key].append(item_key)


100%|██████████| 5743820/5743820 [00:14<00:00, 386123.81it/s]


In [36]:
# Session keys recorded for item 9655; a key repeats when the item occurs in
# more than one row of the same session.
dic_item['item_9655']


['user_3',
 'user_3',
 'user_97085',
 'user_152155',
 'user_210806',
 'user_241795',
 'user_241795',
 'user_351746',
 'user_361368',
 'user_376975',
 'user_392308',
 'user_409797',
 'user_427384',
 'user_494934',
 'user_500224',
 'user_500224',
 'user_565440',
 'user_568659',
 'user_568659',
 'user_580810',
 'user_590009',
 'user_661938',
 'user_679781',
 'user_679781',
 'user_679781',
 'user_679929',
 'user_736041',
 'user_739306',
 'user_739306',
 'user_763132',
 'user_776357',
 'user_799484',
 'user_808882',
 'user_853913',
 'user_888461',
 'user_889362',
 'user_934553',
 'user_953098',
 'user_962579',
 'user_977860',
 'user_1054579',
 'user_1058683',
 'user_1060718',
 'user_1089357',
 'user_1158804',
 'user_1166260',
 'user_1194260',
 'user_1290510',
 'user_1315142',
 'user_1335512',
 'user_1335512',
 'user_1340718',
 'user_1375780',
 'user_1376628',
 'user_1420117',
 'user_1452140',
 'user_1452140',
 'user_1452140',
 'user_1513697',
 'user_1529794',
 'user_1606283',
 'user_1724434

In [34]:
# Find every index at which a given element occurs in a list. Comparing the
# array against the element gives a boolean mask; np.nonzero returns the
# positions where the mask is True (as a one-element tuple of index arrays).
random_user = 'user_3'
last_item = 'item_9655'
session_items = np.array(dic_user[random_user])
s = np.nonzero(session_items == last_item)
s


(array([0, 1]),)

In [43]:
path_length = 16          # maximum number of items in one generated walk
num_sentences = 1000000   # number of walks; 500k+ is recommended for a real run (~20k distinct items)
walk_seed = 42            # fixed seed so the generated corpus is reproducible

# Dead-end case that makes an attempt cap necessary:
#     item_a    -> appears only in session_1
#     session_1 -> [item_b, item_a]
# item_a is never followed by another item, so without a cap on the number of
# sampling attempts the walk would loop forever.
max_repeat_times = 2 * path_length


def _item_after_last_occurrence(session_items, current_item):
    """Return the item that follows the last occurrence of ``current_item``.

    Args:
        session_items: ordered list of item keys belonging to one session.
        current_item: item key to look for.

    Returns:
        The element right after the last occurrence of ``current_item``, or
        ``None`` when that occurrence is the final element of the session or
        when ``current_item`` does not occur at all.
    """
    # Scan from the end with plain list indexing; this avoids building a NumPy
    # string array from the session on every step of every walk.
    for pos in range(len(session_items) - 1, -1, -1):
        if session_items[pos] == current_item:
            if pos + 1 < len(session_items):
                return session_items[pos + 1]
            return None
    return None


def generate_walks(item_ids, item_to_sessions, session_to_items,
                   num_walks, walk_length, max_attempts, seed=None):
    """Generate random item walks over the item <-> session graph.

    Each walk starts at a uniformly chosen item. At every step a session that
    contains the current item is sampled from ``item_to_sessions``, and the
    item following the current item's last occurrence in that session is
    appended to the walk.

    Args:
        item_ids: iterable of raw item ids; start items are ``'item_<id>'``.
        item_to_sessions: mapping item key -> list of session keys.
        session_to_items: mapping session key -> ordered list of item keys.
        num_walks: number of walks to generate.
        walk_length: a walk stops once it holds this many items.
        max_attempts: a walk is abandoned once more than this many sampling
            attempts have been made, even if it is shorter than
            ``walk_length`` (it may then contain only the start item).
        seed: seed for a private ``random.Random``; ``None`` seeds from the OS.

    Returns:
        A list with ``num_walks`` walks, each a list of item keys.
    """
    rng = random.Random(seed)
    # Format the start keys once, as a plain list: this hoists the string
    # formatting out of the loop and sidesteps interpreter versions on which
    # random.choice rejects NumPy arrays.
    start_pool = [f'item_{item_id}' for item_id in item_ids]

    walks = []
    for _ in tqdm(range(num_walks)):
        walk = [rng.choice(start_pool)]
        attempts = 0
        while len(walk) < walk_length:
            current_item = walk[-1]
            # Pick one of the sessions that contain the current item.
            session_key = rng.choice(item_to_sessions[current_item])
            # Continue with the item that follows it in that session, if any.
            following_item = _item_after_last_occurrence(
                session_to_items[session_key], current_item
            )
            # Explicit None check: an index or key must never be truth-tested.
            if following_item is not None:
                walk.append(following_item)
            attempts += 1
            if attempts > max_attempts:
                break
        walks.append(walk)
    return walks


sentences = generate_walks(
    item_list, dic_item, dic_user,
    num_walks=num_sentences,
    walk_length=path_length,
    max_attempts=max_repeat_times,
    seed=walk_seed,
)


100%|██████████| 1000000/1000000 [02:04<00:00, 8018.55it/s]


In [44]:
# Number of walks collected in `sentences`.
len(sentences)


1000000

In [64]:
# Inspect one generated walk as a sanity check.
sentences[1]


['item_26941',
 'item_20202',
 'item_18536',
 'item_16576',
 'item_5390',
 'item_11691',
 'item_27937',
 'item_11762',
 'item_5649',
 'item_8279',
 'item_3873',
 'item_10716',
 'item_23689',
 'item_20505',
 'item_95',
 'item_3488']

In [60]:
import gensim
class LossLogger(gensim.models.callbacks.CallbackAny2Vec):
    """Training callback that records the reported loss after every epoch.

    Attributes are set in ``__init__``:
        losses: list of the values returned by
            ``model.get_latest_training_loss()``, one per finished epoch.
        logfile: path of the text file each value is appended to.
    """

    def __init__(self):
        self.losses = []
        self.logfile = 'train_loss.log'

    def on_epoch_end(self, model):
        """Append the latest reported training loss to ``losses`` and the log file.

        Args:
            model: the model being trained; must expose
                ``get_latest_training_loss()``.
        """
        # The context manager closes the log file even if reading or
        # formatting the loss raises.
        with open(self.logfile, 'a') as file:
            # NOTE(review): gensim's value appears to be a running total over
            # the training call rather than a per-epoch loss; confirm before
            # comparing epochs.
            loss = model.get_latest_training_loss()
            self.losses.append(loss)
            print('Loss after epoch {}: {}'.format(len(self.losses), loss),file=file)


In [71]:

# Train skip-gram (sg) Word2Vec embeddings on the generated walks; loss
# tracking is enabled so the LossLogger callback has something to record.
losslog = LossLogger()
w2v_params = dict(
    vector_size=64,
    epochs=20,
    compute_loss=True,
    callbacks=[losslog],
    sg=True,
)
model = Word2Vec(sentences, **w2v_params)
print(model.get_latest_training_loss())


2023-12-18 20:17:55,998 : INFO : collecting all words and their counts
2023-12-18 20:17:56,001 : INFO : PROGRESS: at sentence #0, processed 0 words, keeping 0 word types
2023-12-18 20:17:56,095 : INFO : PROGRESS: at sentence #10000, processed 159139 words, keeping 19737 word types
2023-12-18 20:17:56,158 : INFO : PROGRESS: at sentence #20000, processed 318121 words, keeping 21731 word types
2023-12-18 20:17:56,220 : INFO : PROGRESS: at sentence #30000, processed 477491 words, keeping 22553 word types
2023-12-18 20:17:56,279 : INFO : PROGRESS: at sentence #40000, processed 636357 words, keeping 22998 word types
2023-12-18 20:17:56,337 : INFO : PROGRESS: at sentence #50000, processed 795427 words, keeping 23238 word types
2023-12-18 20:17:56,394 : INFO : PROGRESS: at sentence #60000, processed 954444 words, keeping 23379 word types
2023-12-18 20:17:56,452 : INFO : PROGRESS: at sentence #70000, processed 1113597 words, keeping 23468 word types
2023-12-18 20:17:56,510 : INFO : PROGRESS: at

122757560.0


Reference: word2vec usage guide (word2vec使用说明): https://zhuanlan.zhihu.com/p/194263854

In [72]:
from tempfile import mkstemp
# mkstemp creates a temp file and returns (open OS-level file descriptor, path).
# Unpack both and close the descriptor right away so it is not leaked;
# `temp_path` is now the path string rather than the whole tuple.
# NOTE(review): the temp file is not used by this cell; consider dropping it.
temp_fd, temp_path = mkstemp("word2vec")
os.close(temp_fd)
model_path = 'w2v.pt'
model.save(model_path)  # persist the trained model to disk

# # Reload the saved model with:
# new_model = Word2Vec.load(model_path)


2023-12-18 21:02:58,136 : INFO : Word2Vec lifecycle event {'fname_or_handle': 'w2v.pt', 'separately': 'None', 'sep_limit': 10485760, 'ignore': frozenset(), 'datetime': '2023-12-18T21:02:58.136070', 'gensim': '4.3.2', 'python': '3.10.9 (main, Mar  8 2023, 10:47:38) [GCC 11.2.0]', 'platform': 'Linux-6.2.0-33-generic-x86_64-with-glibc2.35', 'event': 'saving'}
2023-12-18 21:02:58,138 : INFO : not storing attribute cum_table
2023-12-18 21:02:58,204 : INFO : saved w2v.pt


In [75]:
# The 20 items whose trained vectors are most similar to item 9655.
model.wv.most_similar('item_9655',topn=20)


[('item_8420', 0.8633750677108765),
 ('item_4183', 0.8454250693321228),
 ('item_21804', 0.840990424156189),
 ('item_16934', 0.8408167362213135),
 ('item_24005', 0.83484947681427),
 ('item_15085', 0.832846999168396),
 ('item_21336', 0.824638843536377),
 ('item_2797', 0.8221443295478821),
 ('item_7874', 0.8206487894058228),
 ('item_10340', 0.8188933730125427),
 ('item_13969', 0.818332314491272),
 ('item_9582', 0.8181937336921692),
 ('item_24802', 0.8170543313026428),
 ('item_27457', 0.8162038326263428),
 ('item_18101', 0.8085795640945435),
 ('item_9273', 0.7969487309455872),
 ('item_23861', 0.7952935695648193),
 ('item_13052', 0.794991135597229),
 ('item_5030', 0.7899231314659119),
 ('item_6894', 0.7863765954971313)]

In [76]:
# Embedding vector learned for a single item.
model.wv.get_vector("item_9655")


array([ 0.5474127 ,  0.10113707,  0.06519491, -0.20486458, -0.44360718,
       -0.37973052,  0.19479772, -0.298581  , -0.31641227, -0.42177755,
        0.6568746 , -0.5503279 , -0.7552047 , -0.37866205, -0.3773759 ,
        0.31284085, -0.01405071, -0.29185936, -0.30418834,  0.5309945 ,
        0.7505086 ,  0.48758724,  0.683042  , -0.73200566, -0.5583907 ,
       -0.08589131,  0.22740822,  0.3769146 , -0.00691084, -0.02080834,
       -0.5460711 , -0.15433544, -0.29834038,  0.26222408, -0.0800554 ,
       -0.08941248, -0.21070714,  0.2098801 ,  0.43608046, -0.8117922 ,
        0.39504394, -0.14192897, -0.54636693, -0.62741274, -0.30292308,
        0.51958436,  0.2946213 , -0.5071558 , -0.25500712,  0.39684674,
       -0.0276969 ,  0.2900035 ,  0.27569714,  0.439764  ,  0.3187852 ,
       -0.22543488, -0.10924769, -0.1578216 , -0.06648488,  0.8571492 ,
       -0.18774283,  0.10959336, -0.05288515,  0.5622437 ], dtype=float32)