In [98]:
import pandas as pd
import numpy as np

import tqdm
import datetime

import tensorflow as tf
from tensorflow.keras import Model,Sequential
from tensorflow.keras.layers import Activation,Dense,Dot,Flatten,GlobalAveragePooling1D,Reshape,Embedding
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

In [62]:
print(tf.__version__)

2.4.1


In [63]:
movies=pd.read_csv(r'./ml-latest-small/movies.csv')
data=pd.read_csv(r'./ml-latest-small/ratings.csv')

In [64]:
movies.head()

Unnamed: 0,movieId,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama|Romance
4,5,Father of the Bride Part II (1995),Comedy


In [65]:
data.head()

Unnamed: 0,userId,movieId,rating,timestamp
0,1,1,4.0,964982703
1,1,3,4.0,964981247
2,1,6,4.0,964982224
3,1,47,5.0,964983815
4,1,50,5.0,964982931


In [66]:
# 过滤到评分低于3.5分以下的
data=data[data.rating>=3.5]
data=data.merge(movies,left_on='movieId',right_on='movieId')
data=data.sort_values(['userId','timestamp'])
data.reset_index(drop=True,inplace=True)

In [67]:
data.head()

Unnamed: 0,userId,movieId,rating,timestamp,title,genres
0,1,804,4.0,964980499,She's the One (1996),Comedy|Romance
1,1,1210,5.0,964980499,Star Wars: Episode VI - Return of the Jedi (1983),Action|Adventure|Sci-Fi
2,1,2018,5.0,964980523,Bambi (1942),Animation|Children|Drama
3,1,2628,4.0,964980523,Star Wars: Episode I - The Phantom Menace (1999),Action|Adventure|Sci-Fi
4,1,2826,4.0,964980523,"13th Warrior, The (1999)",Action|Adventure|Fantasy


In [68]:
mvlist=data.groupby('userId')['title'].apply(list)

In [69]:
sequences=list(mvlist)

In [70]:
voc_data=data.groupby(by='title').count().sort_values(by='userId',ascending=False)

In [71]:
voc_data

Unnamed: 0_level_0,userId,movieId,rating,timestamp,genres
title,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
"Shawshank Redemption, The (1994)",289,289,289,289,289
Forrest Gump (1994),276,276,276,276,276
Pulp Fiction (1994),256,256,256,256,256
"Matrix, The (1999)",240,240,240,240,240
"Silence of the Lambs, The (1991)",239,239,239,239,239
...,...,...,...,...,...
Honey (Miele) (2013),1,1,1,1,1
Honey (2003),1,1,1,1,1
Home for the Holidays (1995),1,1,1,1,1
Home Alone 3 (1997),1,1,1,1,1


In [72]:
voc_index={}
index_voc={}
i=0
for movie in list(voc_data.index):
    voc_index[movie]=i
    index_voc[i]=movie
    i+=1

In [73]:
voc_index

{'Shawshank Redemption, The (1994)': 0,
 'Forrest Gump (1994)': 1,
 'Pulp Fiction (1994)': 2,
 'Matrix, The (1999)': 3,
 'Silence of the Lambs, The (1991)': 4,
 'Star Wars: Episode IV - A New Hope (1977)': 5,
 'Fight Club (1999)': 6,
 "Schindler's List (1993)": 7,
 'Braveheart (1995)': 8,
 'Raiders of the Lost Ark (Indiana Jones and the Raiders of the Lost Ark) (1981)': 9,
 'Star Wars: Episode V - The Empire Strikes Back (1980)': 10,
 'American Beauty (1999)': 11,
 'Usual Suspects, The (1995)': 12,
 'Terminator 2: Judgment Day (1991)': 13,
 'Lord of the Rings: The Fellowship of the Ring, The (2001)': 14,
 'Jurassic Park (1993)': 15,
 'Star Wars: Episode VI - Return of the Jedi (1983)': 16,
 'Toy Story (1995)': 17,
 'Godfather, The (1972)': 18,
 'Saving Private Ryan (1998)': 19,
 'Lord of the Rings: The Return of the King, The (2003)': 20,
 'Seven (a.k.a. Se7en) (1995)': 21,
 'Lord of the Rings: The Two Towers, The (2002)': 22,
 'Fugitive, The (1993)': 23,
 'Fargo (1996)': 24,
 'Back to

In [74]:
index_voc

{0: 'Shawshank Redemption, The (1994)',
 1: 'Forrest Gump (1994)',
 2: 'Pulp Fiction (1994)',
 3: 'Matrix, The (1999)',
 4: 'Silence of the Lambs, The (1991)',
 5: 'Star Wars: Episode IV - A New Hope (1977)',
 6: 'Fight Club (1999)',
 7: "Schindler's List (1993)",
 8: 'Braveheart (1995)',
 9: 'Raiders of the Lost Ark (Indiana Jones and the Raiders of the Lost Ark) (1981)',
 10: 'Star Wars: Episode V - The Empire Strikes Back (1980)',
 11: 'American Beauty (1999)',
 12: 'Usual Suspects, The (1995)',
 13: 'Terminator 2: Judgment Day (1991)',
 14: 'Lord of the Rings: The Fellowship of the Ring, The (2001)',
 15: 'Jurassic Park (1993)',
 16: 'Star Wars: Episode VI - Return of the Jedi (1983)',
 17: 'Toy Story (1995)',
 18: 'Godfather, The (1972)',
 19: 'Saving Private Ryan (1998)',
 20: 'Lord of the Rings: The Return of the King, The (2003)',
 21: 'Seven (a.k.a. Se7en) (1995)',
 22: 'Lord of the Rings: The Two Towers, The (2002)',
 23: 'Fugitive, The (1993)',
 24: 'Fargo (1996)',
 25: 'Bac

In [75]:
sequences_int=[]
for seq in sequences:
    movielist=[]
    for movie in seq:
        movielist.append(voc_index[movie])
    sequences_int.append(movielist)
        

In [76]:
sequences_int[:5]

[[2970,
  16,
  773,
  144,
  1561,
  30,
  797,
  1853,
  911,
  450,
  11,
  6189,
  106,
  258,
  576,
  1,
  392,
  180,
  413,
  142,
  381,
  91,
  1531,
  1037,
  232,
  427,
  696,
  1070,
  621,
  213,
  941,
  458,
  678,
  200,
  109,
  126,
  41,
  3448,
  907,
  1334,
  184,
  238,
  5,
  83,
  86,
  198,
  1005,
  147,
  467,
  403,
  538,
  758,
  146,
  681,
  1890,
  1413,
  859,
  501,
  2080,
  2636,
  4533,
  10,
  9,
  40,
  153,
  63,
  35,
  521,
  1393,
  19,
  3,
  704,
  23,
  175,
  93,
  36,
  497,
  8,
  1947,
  241,
  325,
  703,
  221,
  112,
  821,
  339,
  1171,
  97,
  1060,
  302,
  1316,
  430,
  314,
  194,
  700,
  485,
  1607,
  15,
  1051,
  3559,
  343,
  111,
  1245,
  923,
  693,
  1031,
  137,
  165,
  751,
  330,
  118,
  43,
  161,
  1269,
  814,
  3166,
  577,
  610,
  4427,
  885,
  641,
  1222,
  1010,
  1868,
  17,
  320,
  439,
  6581,
  930,
  719,
  2163,
  1197,
  1140,
  837,
  411,
  646,
  376,
  2183,
  1311,
  1582,
  808,
  27

In [77]:
vocab_size=len(voc_index)

In [78]:
vocab_size

7359

In [79]:
def generate_training_data(sequences,vocab_size,window_size=5, num_ns=5,seed=44):
    targets, contexts, labels = [], [], []
    # 采样时对zipf分布数据加权,对于高频的以更小的概率被采样到。
    #sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(vocab_size)
    for seq in tqdm.tqdm(sequences):
        # 对每一个会话生成正样本
        positive_skip_grams,_=tf.keras.preprocessing.sequence.skipgrams(
                                seq,
                                vocabulary_size=vocab_size,
                                #sampling_table=sampling_table #不做限制时表示采样窗口内的所有正样本，否则按照sampling_table中的概率进行采样
                                window_size=window_size,
                                negative_samples=0 # 不进行负采样  
                            )
        # 按照log-uniform 进行采样,采样时更加偏向稀有样本
        for target_word,context_word in positive_skip_grams:
            #[1,2]扩展维度[[1],[2]]:(1,)->(1,1)
            #正样本y
            context_class=tf.expand_dims(tf.constant([context_word],dtype='int64'),1)
            #负样本y
            negative_sampling_candidates,_,_=tf.random.log_uniform_candidate_sampler(
                true_classes=context_class,
                num_true=1,
                num_sampled=num_ns,
                unique=True,
                range_max=vocab_size,
                seed=seed,
                name='negative_sampling'
            )
            
            negative_sampling_candidates=tf.expand_dims(negative_sampling_candidates,1)
            
            # 正负样本y拼接
            context=tf.concat([context_class,negative_sampling_candidates],0)
            
            #创建标签
            label=tf.constant([1]+[0]*num_ns,dtype='int64')
            
            targets.append(target_word)
            contexts.append(context)
            labels.append(label)
            
    return targets,contexts,labels

In [80]:
targets,contexts,labels=generate_training_data(sequences_int,vocab_size,window_size=5, num_ns=5,seed=44)

100%|██████████| 609/609 [00:34<00:00, 17.86it/s]


In [81]:
targets[0]

10

In [82]:
contexts[0]

<tf.Tensor: shape=(6, 1), dtype=int64, numpy=
array([[ 501],
       [3488],
       [  50],
       [ 897],
       [   1],
       [   8]])>

In [83]:
labels[0]

<tf.Tensor: shape=(6,), dtype=int64, numpy=array([1, 0, 0, 0, 0, 0])>

In [84]:
print(len(targets),len(contexts),len(labels))

593842 593842 593842


In [85]:
579*1024

592896

In [86]:
BATCH_SIZE = 1024
BUFFER_SIZE = 10000
# 打包成(input,output),对于有多个input时((input1,input2,...),label),其中input1
dataset = tf.data.Dataset.from_tensor_slices(((targets, contexts), labels))
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
print(dataset)

<BatchDataset shapes: (((1024,), (1024, 6, 1)), (1024, 6)), types: ((tf.int32, tf.int64), tf.int64)>


In [87]:
AUTOTUNE = tf.data.AUTOTUNE
# cache(),prefetch() 预加载提升速度
dataset = dataset.cache().prefetch(buffer_size=AUTOTUNE) 
print(dataset)

<PrefetchDataset shapes: (((1024,), (1024, 6, 1)), (1024, 6)), types: ((tf.int32, tf.int64), tf.int64)>


In [89]:
num_ns=5
class Word2Vec(Model):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        self.target_embedding = Embedding(vocab_size, 
                                      embedding_dim,
                                      input_length=1,
                                      name="w2v_embedding", )
        self.context_embedding = Embedding(vocab_size, 
                                       embedding_dim, 
                                       input_length=num_ns+1)
        self.dots = Dot(axes=(3,2))
        self.flatten = Flatten()

    def call(self, pair):
        target, context = pair
        we = self.target_embedding(target)
        ce = self.context_embedding(context)
        dots = self.dots([ce, we])
        return self.flatten(dots)

In [90]:
targets[0]

10

In [93]:
embedding_dim = 128
word2vec = Word2Vec(vocab_size, embedding_dim)
word2vec.compile(optimizer='adam',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [102]:
# 生成日志文件名称
log_dir = "./logs/fit/item2vec" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

In [103]:
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

In [104]:
pwd

'/Users/vccandice/Documents/myGithub/item2vec'

In [106]:
word2vec.fit(dataset, epochs=10, callbacks=[tensorboard_callback])

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f9b43f53220>