In [1]:
from keras.layers import Dot
from keras.layers import Input, Dense, Reshape
from keras.layers.embeddings import Embedding
from keras.models import Model
import keras.backend as K

Using TensorFlow backend.


In [2]:
# 词汇表大小
vocab_size = 5000
# 输出向量维度
embed_size = 300

# 构建词向量网络模型
word_input = Input(shape=(1, vocab_size))
word_embedding = Embedding(vocab_size, embed_size,
                         embeddings_initializer="glorot_uniform",
                         input_length=1)(word_input)
word_embedding_reshape = Reshape((embed_size,))(word_embedding)


# 构建上下文向量网络模型
context_input = Input(shape=(2, vocab_size))
context_embedding = Embedding(vocab_size, embed_size,
                         embeddings_initializer="glorot_uniform",
                         input_length=1)(context_input)
context_embedding_reshape = Reshape((embed_size,))(context_embedding)

# 合并
merge_layer = Dot(axes=1)([word_embedding_reshape, context_embedding_reshape])

outputs_layer = Dense(1, kernel_initializer="glorot_uniform", activation="sigmoid")(merge_layer)

model = Model(inputs=[word_input, context_input], outputs=outputs_layer)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 1, 5000)      0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 2, 5000)      0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 300)       1500000     input_1[0][0]                    
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, 1, 300)       1500000     input_2[0][0]                    
__________________________________________________________________________________________________
reshape_1 

In [3]:
from keras.utils import plot_model
plot_model(model, to_file='skip_gram.png', show_shapes=True)