
Incorporating Word Vectors rather than using an Embedding class #853

Closed
vindiesel opened this issue Oct 19, 2015 · 111 comments

@vindiesel

I am solving an NLP task and I am trying to model it directly as a sequence using different RNN flavors. How can I use my own Word Vectors rather than using an instance of layers.embeddings.Embedding?

@sergeyf

sergeyf commented Oct 20, 2015

UPDATED NOV 13, 2017

You have to pass a weight matrix to the Embedding layer. Here is an example:

Let's say index_dict is a dictionary that maps all the words in your vocabulary to indices from 1 to len(index_dict) (0 is reserved for masking).

So, an example index_dict is the following:

{
 'yellow': 1,
 'four': 2,
 'woods': 3,
 'ornate': 31,
 'woody': 5,
 'cyprus': 6,
 'marching': 7,
 'canes': 8,
 'caned': 9,
 'hermann': 10,
 'lord': 11,
 'meadows': 12,
 'shaving': 13,
 'swivel': 14
...
}

And you also have a dictionary called word_vectors that maps words to vectors like so:

{
 'yellow': array([0.1,0.5,...,0.7]),
 'four': array([0.2,1.2,...,0.9]),
...
}

The following code should do what you want

import numpy as np
from keras.layers import Input, Embedding

# assemble the embedding_weights in one numpy array
vocab_dim = 300 # dimensionality of your word vectors
n_symbols = len(index_dict) + 1 # adding 1 to account for 0th index (for masking)
embedding_weights = np.zeros((n_symbols, vocab_dim))
for word,index in index_dict.items():
    embedding_weights[index, :] = word_vectors[word]

# define inputs here: a batch of variable-length sequences of word indices
input_layer = Input(shape=(None,), dtype='int32')
embedding_layer = Embedding(output_dim=vocab_dim, input_dim=n_symbols, trainable=True)
embedding_layer.build((None,)) # if you don't do this, the next step won't work
embedding_layer.set_weights([embedding_weights])

embedded = embedding_layer(input_layer)
# ... continue model definition here

Note that with this setup your embeddings will be fine-tuned during training, starting from the pretrained values! If you want them fixed, set trainable=False.
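The loop above raises a KeyError if a word in index_dict has no entry in word_vectors, and sizing the matrix as len(index_dict) + 1 assumes the indices are contiguous. A more defensive NumPy sketch, assuming the two dictionaries described above (build_embedding_matrix is a hypothetical helper, not part of Keras):

```python
import numpy as np

def build_embedding_matrix(index_dict, word_vectors, vocab_dim):
    """Stack pretrained vectors into a (max_index + 1, vocab_dim) matrix.

    Row 0 is left as zeros for the padding/mask index. Words that have no
    pretrained vector also keep an all-zeros row and are reported back.
    """
    n_rows = max(index_dict.values()) + 1  # size by the largest index, not the count
    weights = np.zeros((n_rows, vocab_dim), dtype="float32")
    missing = []
    for word, index in index_dict.items():
        vector = word_vectors.get(word)
        if vector is None:
            missing.append(word)
            continue
        weights[index, :] = vector
    return weights, missing

# Toy data in the same format as index_dict / word_vectors above.
index_dict = {"yellow": 1, "four": 2, "ornate": 5}
word_vectors = {"yellow": np.array([0.1, 0.5, 0.7]), "four": np.array([0.2, 1.2, 0.9])}
embedding_weights, missing = build_embedding_matrix(index_dict, word_vectors, vocab_dim=3)
print(embedding_weights.shape)  # (6, 3)
print(missing)                  # ['ornate']
```

The resulting matrix can be handed to set_weights as in the snippet above, with input_dim=embedding_weights.shape[0].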

@farizrahman4u
Contributor

There's no need to skip the Embedding layer; setting word vectors as the initial weights of the Embedding layer is a valid approach. The word vectors will get fine-tuned for the specific NLP task during training.

@dandxy89

Has anybody else attempted to embed the word vectors into a model?

I've managed to create the model, however I'm not able to achieve a worthwhile level of accuracy yet. I've used the "20 newsgroups dataset" from scikit-learn to test this model, with my own w2v vectors. The best accuracy I've achieved so far is 28%, over 5 epochs, which is not great (scikit script best: 85%). I intend to continue experimenting with the network configuration (inner dimensions and epochs initially). I do suspect that the number of dimensions is too high for such a small dataset (1000 samples).

Will update again if the results improve

@viksit

viksit commented Nov 28, 2015

From what I've seen, training your own vectors on top of a custom dataset has given me much better accuracy within that domain.

That said - any updates on this?

@farizrahman4u
Contributor

@viksit

There are 3 approaches:

  • Learn embedding from scratch - simply add an Embedding layer to your model
  • Fine tune learned embeddings - this involves setting word2vec / GloVe vectors as your Embedding layer's weights.
  • Use word2vec / GloVe word vectors as inputs to your model, instead of one-hot encoding.

The third one is the best option (assuming the word vectors were obtained from the same domain as the inputs to your model; e.g., if you are doing sentiment analysis on tweets, you should use GloVe vectors trained on tweets).

In the first option, everything has to be learned from scratch. You don't need it unless you have a rare scenario. The second one is good, but your model will be unnecessarily big with all the word vectors for words that are not frequently used.

@viksit

viksit commented Nov 29, 2015

@farizrahman4u agreed on those counts. The domains are a bit more specific and I have a lot more luck with option (2) than with (1) or (3) so far. An easy way to address the size problem with (2) is to prune out the vocabulary itself to the top k words.
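Pruning to the top k words can be done before building the embedding matrix. A minimal standard-library sketch, assuming the corpus is already tokenized into lists of words (top_k_index is a hypothetical helper name). Index 0 stays reserved for padding, as in the first comment:

```python
from collections import Counter

def top_k_index(tokenized_docs, k):
    """Map the k most frequent words to indices 1..k (0 is reserved for padding)."""
    counts = Counter(word for doc in tokenized_docs for word in doc)
    # most_common orders ties by first occurrence, so the result is deterministic
    return {word: i + 1 for i, (word, _) in enumerate(counts.most_common(k))}

docs = [["the", "cat", "sat"], ["the", "dog", "sat"], ["the", "cat", "ran"]]
index_dict = top_k_index(docs, k=3)
print(index_dict)  # {'the': 1, 'cat': 2, 'sat': 3}
```

Only these k words then need rows in the embedding matrix; everything else can be dropped or mapped to a shared out-of-vocabulary index.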

@XuesongYang

@farizrahman4u Thanks for sharing the ideas. I have a question on your first approach.

Learning embeddings from scratch: each word in the dictionary is represented as one-hot vector, and then this vector is embedded as a continuous vector after applying embedding layer. Is that right?

@farizrahman4u
Contributor

@MagicYoung Yes.
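Conceptually yes, although the one-hot vectors are never built explicitly: multiplying a one-hot row by the embedding matrix just selects one row, so the layer is typically implemented as a plain index lookup. A small NumPy check of that equivalence (random matrix, illustrative only):

```python
import numpy as np

n_symbols, vocab_dim = 5, 3
rng = np.random.default_rng(0)
W = rng.normal(size=(n_symbols, vocab_dim))  # embedding matrix, one row per word index

word_ids = np.array([3, 1, 0])               # a short sequence of word indices
one_hot = np.eye(n_symbols)[word_ids]        # shape (3, n_symbols)

via_matmul = one_hot @ W                     # one-hot vectors times the matrix
via_lookup = W[word_ids]                     # direct row lookup
print(np.allclose(via_matmul, via_lookup))   # True
```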

@dandxy89

To follow up from @sergeyf suggestions... @viksit @MagicYoung @vindiesel

Take a look at my attempt (Script and Data) - it uses Gensim. This could easily all be done in Keras using the skipgram of course to remove dependencies...

@sergeyf

sergeyf commented Dec 17, 2015

Looks cool! Are the results to your liking?

@dandxy89

Results after 2 epochs:
Validation Accuracy: 0.8485
Loss: 0.3442

@liyi193328

@dandxy89
Wonderful examples for me. I'll try them in my context. Thanks.
And if I want to feed pretrained word2vec vectors to an LSTM directly, how do I handle different sequence lengths?
What I tried: setting maxlen=200 (for every sequence) and word2vec dim=600. If a sequence's length is 100, then the first 100 rows ([0:100)) hold float numbers (every row is a word vector), and rows [100:200) are padded with zeros, i.e. the remaining rows are all zeros. But after doing that, I get a NaN loss, which has confused me for a long time. It looks like issue #1360; what can I do?
Thanks.

@viksit

viksit commented Dec 28, 2015

@liyi193328 are you using keras' sequence.pad_sequences(myseq, maxlen=maxlen)?

Looks like your padding is to the right of the vectors, whereas it should be to the left.

from keras.preprocessing import sequence
In [69]: sequence.pad_sequences([[1,2], [1,2,3]], maxlen=10)
Out[69]:
array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2],
       [0, 0, 0, 0, 0, 0, 0, 1, 2, 3]], dtype=int32)

Secondly, you should be using categorical_crossentropy as your loss as opposed to mean_squared_error. See #321

@liyi193328

@viksit Thanks.
I don't use sequence.pad_sequences. I pad the last rows with zeros manually.
The input of pad_sequences is a 2D array, say A, where A[i,j] is the index of a word in the vocabulary and each row of A represents a sentence. But in my case every word is a 600-dim vector, not an index, so I can't use it.
Is my thinking right? How can I pad zeros then?
Thanks.

@viksit

viksit commented Dec 28, 2015

You need to pad before you convert the words to vectors (presumably you have a step where you have only word indexes).
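Padding at the index level and then looking vectors up in a matrix whose row 0 is all zeros produces exactly the zero-padded 3D input discussed here. A NumPy sketch; pad_ids is a hypothetical helper that mimics the default left padding ('pre') of sequence.pad_sequences and also accepts 'post' for right padding. The vectors are the toy ones from the example that follows in this thread:

```python
import numpy as np

def pad_ids(sequences, maxlen, padding="pre"):
    """Left-pad ('pre') or right-pad ('post') integer sequences with 0 up to maxlen.

    Longer sequences are truncated from the front, keeping the last maxlen ids.
    """
    out = np.zeros((len(sequences), maxlen), dtype="int32")
    for row, seq in enumerate(sequences):
        seq = list(seq)[-maxlen:]
        if not seq:
            continue  # an empty sequence stays all padding
        if padding == "pre":
            out[row, -len(seq):] = seq
        else:
            out[row, :len(seq)] = seq
    return out

# He=1, like=2, keras=3, learning=4; row 0 is the all-zeros padding vector.
embedding_matrix = np.array([
    [0, 0, 0, 0],  # 0: padding
    [1, 1, 1, 1],  # 1: He
    [2, 2, 2, 2],  # 2: like
    [3, 3, 3, 3],  # 3: keras
    [5, 5, 5, 5],  # 4: learning
], dtype="float32")

ids = pad_ids([[1, 2, 3], [4], [2, 3]], maxlen=3)
X = embedding_matrix[ids]  # padded positions become zero vectors
print(ids.tolist())        # [[1, 2, 3], [0, 0, 4], [0, 2, 3]]
print(X.shape)             # (3, 3, 4)
```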

@liyi193328

@viksit Thanks.
But how do I change the word index to a word vector?
Actually I don't use word indexes; I use word vectors directly.
More specifically, take three sentences like:

[ [ He, like, keras], [learning], [like, keras] ]

so the 2D index array is [ [1,2,3], [4], [2,3] ] ---padding---> [ [1,2,3], [0,0,4], [0,2,3] ] (the index of each word is given)
the word vector(4 dim) each is:

He -> [1,1,1,1], like->[2,2,2,2], keras->[3,3,3,3], learning->[5,5,5,5]

then after padding, the 3D array shape is (3,3,4), like:

[
[ [1,1,1,1], [2,2,2,2], [3,3,3,3] ],
[ [0,0,0,0], [0,0,0,0], [5,5,5,5] ],
[ [0,0,0,0], [2,2,2,2], [3,3,3,3] ]
]

Is this specific example right?
Thanks.

@viksit

viksit commented Dec 29, 2015

That's correct.

@liyi193328

@viksit Thanks.
My mistake was initializing the array with np.empty; I should have used np.zeros. Everything works well now.

@taozhijiang

@dandxy89 @farizrahman4u @viksit

Do you mean that when the Embedding layer is initialized with word2vec weights learned on another corpus, the LSTM model converges quickly?

Here is my code; just 2-3 iterations give a high score.

https://github.com/taozhijiang/chinese_nlp/blob/master/DL_python/dl_segment_v2.py

@talentlei

Thanks for sharing. I benefited a lot.

@ngopee

ngopee commented Jan 31, 2016

Following up on @liyi193328's comment:

If this is my X_train:
[
[ [1,1,1,1], [2,2,2,2], [3,3,3,3] ],
[ [0,0,0,0], [0,0,0,0], [5,5,5,5] ],
[ [0,0,0,0], [2,2,2,2], [3,3,3,3] ]
]

How should I structure my Y_train, given that each word (word-level training) has its own tag (Y is also multi-class)? Is it like:

[
[1,2,3],
[0,0,3],
[0,2,1]
] ?

Because I am having an error:
"Exception: All input arrays and the target array must have the same number of samples."

Thank you very much!

@liyi193328

@ngopee
For multiple classes, if the label is class 2 (of 3 classes in total), it must be transformed into the 1D array [0,0,1].
Specifically, if a sentence has x tokens, every token has a label, and there are y classes, then that sentence's labels (in Y_train) form a 2D array of shape (x,y).
Hope this solves your problem.
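Stacked over all sentences, the per-token targets for a sequence output (return_sequences=True followed by a per-timestep softmax) are usually a single 3D array of shape (n_samples, maxlen, n_classes), not a list of separate 2D arrays. A NumPy sketch of building it from integer tags; the tag values and maxlen are made up, and tag 0 is assumed to mark padding:

```python
import numpy as np

n_classes = 4
# One integer tag per token, already padded to maxlen = 5 (tag 0 used for padding here).
tags = np.array([
    [0, 0, 3, 2, 3],
    [0, 2, 3, 3, 2],
])
Y = np.eye(n_classes, dtype="float32")[tags]  # shape (n_samples, maxlen, n_classes)
print(Y.shape)              # (2, 5, 4)
print(Y[0, 2].tolist())     # [0.0, 0.0, 0.0, 1.0]
```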

@ngopee

ngopee commented Feb 5, 2016

@liyi193328

Thank you very much for your reply!

Yes, I realised that later on but now having another issue which I'm not sure how to fix:

"if every token has a label and has y classes, then all the labels's(Y_train) shape is (x,y), 2D array". I'm not sure I follow this part. But below is what I have so far.

So, in my case, each token (word vector) has its own tag.

Here is a sample of my input:

X_train =
[
 [ 8496  1828  …5447]
 [ 9096  8895  …13890]
 [ 5775   115 … 15037]
 [ 6782  9918  …  5048]
]


Y_train=
[
array([[ 0.,  0.,  0.,  1.], [ 0.,  0.,  1.,  0.], …[ 0.,  0.,  0.,  1.]]), 
array([[ 0.,  0.,  1.,  0.], [ 0.,  0.,  0.,  1.],…[ 0.,  0.,  1.,  0.]]), 
array([[ 0.,  0.,  1.,  0.], [ 0.,  0.,  0.,  1.], …[ 0.,  0.,  1.,  0.]]), 
array([[ 0.,  1.,  0.,  0.], [ 0.,  1.,  0.,  0.], …[ 0.,  1.,  0.,  0.]])
]

I am getting this error:

AssertionError: Theano Assert failed!
Apply node that caused the error: Assert(Elemwise{Composite{(i0 - EQ(i1, i2))}}.0, Elemwise{eq,no_inplace}.0)
Inputs types: [TensorType(int8, matrix), TensorType(int8, scalar)]
Inputs shapes: [(1, 100), ()]
Inputs strides: [(100, 1), ()]
Inputs values: ['not shown', array(0, dtype=int8)]

Here is my code:

vocab_dim = 300
maxlen = 100
batch_size = 1
n_epoch = 2

print('Keras Model...')
model = Sequential()  # or Graph or whatever
model.add(Embedding(output_dim=vocab_dim,
                    input_dim=n_symbols + 1,
                    mask_zero=True,
                    weights=[embedding_weights])) 
model.add(LSTM(vocab_dim, return_sequences=True))
model.add(Dropout(0.3))
model.add(TimeDistributedDense(input_dim=vocab_dim, output_dim=1))

print('Compiling the Model...')
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              class_mode='categorical')

print("Train...")
model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=n_epoch,
          validation_data=(X_test, y_test), show_accuracy=True)

print("Evaluate...")
score, acc = model.evaluate(X_test, y_test,
                            batch_size=batch_size,
                            show_accuracy=True)
print('Test score:', score)
print('Test accuracy:', acc)

Thank you very much!

@liyi193328

@ngopee
model.add(TimeDistributedDense(input_dim=vocab_dim, output_dim=1)): output_dim == 1?
I think it should be the number of classes, though I have not read the whole context.

@ngopee

ngopee commented Feb 6, 2016

@liyi193328

Thank you very much for pointing that out. That would have been yet another mistake.

However this did not seem to fix the issue I previously had. Any insight on what I could be doing wrong?

Thanks!

@liyi193328

@ngopee
Sorry for the late reply! Your code sets up a many-to-one classification, but your actual goal is many-to-many, so the logic needs to change.

@MaratZakirov

I still don't understand: can I use the Embedding layer and NOT TUNE IT AT ALL, just as a memory-saving mechanism?

@prhbrt

prhbrt commented Jan 13, 2017

An Embedding Layer has weights, and they need to be tuned: https://github.com/fchollet/keras/blob/master/keras/layers/embeddings.py#L95

There are, however, pretrained word2vec models, which are already trained and hence need not be retrained if they fit your needs.

@Cospel

Cospel commented Feb 1, 2017

What should the padding vector be if I am using the pretrained word2vec vectors from Google?

Should I use a word like 'stop' and transform it to a vector with Google's word2vec model, or should I just use a vector of zeros?

@vijay120

What happened to the weights argument to the Embedding function? I am following the tutorial on the blog here: https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html where we pass in the pre-trained embedding matrix. However, in the latest Keras version, the weights argument has been deleted. The PR that deleted it is here: 023331e

@MaratZakirov

@prinsherbert This approach may lead to overfitting. For example, if you have 100000 words (a pretty small vocabulary), each with a vector size of 100, and they are trainable, you get 10 million free parameters "from nothing".
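The arithmetic behind that estimate, as a quick check (plain Python; the pruned vocabulary size of 20,000 is an arbitrary example):

```python
def embedding_params(vocab_size, dim):
    """Number of weights in an embedding matrix of shape (vocab_size, dim)."""
    return vocab_size * dim

print(embedding_params(100_000, 100))  # 10000000, i.e. 10 million
print(embedding_params(20_000, 100))   # 2000000 after pruning to the top 20k words
```

Setting trainable=False removes these weights from training altogether, while pruning the vocabulary shrinks them whether or not they are trained.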

@prhbrt

prhbrt commented Mar 22, 2017

@MaratZakirov Yes, it is wise to pretrain or just train with a large corpus such as Wikipedia to prevent overfitting. And yes, you need a high document to word ratio.

@MadhumitaSushil

From what I understand, the Embedding layer in Keras performs a lookup for a word index present in an input sequence, and replaces it with the corresponding vector through an embedding matrix.

However, what I am confused about is what happens when we want to test/apply a model on unseen data. For example, if a word in the test document is not present in the training vocabulary, we could compute its vector from character n-grams using a pretrained fastText model. However, this term would not be present in the word_index that was generated while training the model, and a lookup in the embedding matrix would fail.

One possible solution can be to create a word_index from the entire dataset, including the test data, as done for this example: https://github.com/fchollet/keras/blob/master/examples/pretrained_word_embeddings.py
However, I would like to avoid that so that the model is applicable to unknown data.

Any suggestion for workarounds for that?

@prhbrt

prhbrt commented Mar 23, 2017

I think unknown words are generally mapped to random vectors, but this is not what the Embedding layer does. In text processing, people often consider the vocabulary to be prior knowledge. And if your model is word-based, you are unlikely to learn about words not seen in the training data anyway.

Tweets are examples of data with many words that you likely won't encounter during training but that do appear in testing, because people make typos. More importantly, they also add (hash)tags, which are typically concatenations of words. If you search the literature, you'll notice many tweet classifiers use a character-level convolutional layer and then some classifier on top of that (like an LSTM).

So if you want to generalize to new words, consider character level features/classifiers.
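For a word-based model, a common workaround is to reserve one index for unknown words and give it a single vector (zeros, random, or an average of known vectors); every out-of-vocabulary token at test time is then mapped to that index, so the lookup never fails. A minimal sketch; the reserved indices and the averaging choice are assumptions, not something the Embedding layer does for you:

```python
import numpy as np

PAD, UNK = 0, 1  # reserved indices (assumption: real words start at 2)

def build_vocab(words):
    return {w: i + 2 for i, w in enumerate(words)}

def encode(tokens, index):
    """Map tokens to indices, sending unseen words to the shared UNK index."""
    return [index.get(t, UNK) for t in tokens]

index = build_vocab(["good", "movie", "bad"])
dim = 3
vectors = {"good": [1.0, 0.0, 0.0], "movie": [0.0, 1.0, 0.0], "bad": [-1.0, 0.0, 0.0]}

matrix = np.zeros((len(index) + 2, dim), dtype="float32")
for w, i in index.items():
    matrix[i] = vectors[w]
matrix[UNK] = matrix[2:].mean(axis=0)  # one choice: average of the known vectors

print(encode(["good", "moovie"], index))  # [2, 1]
```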

@BrianMiner

It seems like these responses ignore the main issue: was the weights argument removed in Keras 2.0?

@MadhumitaSushil

@BrianMiner Yes, it seems like embedding weights have indeed been removed. My guess is it can still be used via the option 'embeddings_initializer' and coding a custom initializer which returns embedding weights. @vijay120
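A custom initializer in Keras 2 is a callable that takes the variable's shape (and an optional dtype) and returns its initial value, so one that simply returns the pretrained matrix is short. A sketch; make_pretrained_initializer is a hypothetical name:

```python
import numpy as np

def make_pretrained_initializer(matrix):
    """Return a (shape, dtype=None) callable that hands back the pretrained matrix."""
    matrix = np.asarray(matrix, dtype="float32")

    def initializer(shape, dtype=None):
        # Fail loudly instead of silently using a matrix of the wrong size.
        if tuple(shape) != matrix.shape:
            raise ValueError("expected shape %s, got %s" % (matrix.shape, tuple(shape)))
        return matrix

    return initializer

pretrained = np.random.default_rng(0).normal(size=(10, 4))
init = make_pretrained_initializer(pretrained)
print(init((10, 4)).shape)  # (10, 4)
```

It would be passed as Embedding(10, 4, embeddings_initializer=init). The pretrained-embeddings example on keras.io uses the built-in keras.initializers.Constant(embedding_matrix) for the same purpose; behaviour may differ slightly between Keras versions.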

@prinsherbert Yes, I am aware of such techniques. I was only wondering about an elegant way to add embeddings of OOV terms, based on character n-gram embeddings from fastText, at test time, to avoid completely ignoring OOV terms in a word-based model.

@monod91

monod91 commented Jun 6, 2017

Hello @Madhumita-Git @prinsherbert

I also have a similar scenario, where I want to use a character-based model for another NLP task.
Basically, my input data is a 3D tensor containing n sentences, each containing m words, each represented as a vector of o characters. So the 1st dimension = batch, 2nd dimension = temporal dimension (max length of sentence), 3rd dimension = max characters per word.

So that my model starts with:
word_input = Input((self.max_length, self.max_word_length))
Now I would like to use characters embeddings on the character level (and on the top of this, using 1D convolution+MaxPooling to obtain a fixed-size vector representation of the word, in a similar way as in this paper: "Learning Character-level Representations for Part-of-Speech Tagging" http://proceedings.mlr.press/v32/santos14.pdf).

Any idea about how I could use an Embedding layer in such way in keras?
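Whatever the model looks like, the input has to be a padded integer tensor of exactly that shape. A NumPy sketch of building it; chars_to_tensor and the lowercase-only character index are hypothetical. On the model side, one option is to apply the character Embedding at each word position (for example by wrapping it in TimeDistributed) followed by the per-word Conv1D + MaxPooling; that part is not shown here:

```python
import numpy as np

def chars_to_tensor(sentences, char_index, max_length, max_word_length):
    """Build an int array of shape (n_sentences, max_length, max_word_length).

    0 pads both missing words and missing characters; characters not in
    char_index are dropped for brevity (a real model might map them to UNK).
    """
    out = np.zeros((len(sentences), max_length, max_word_length), dtype="int32")
    for s, sentence in enumerate(sentences):
        for w, word in enumerate(sentence[:max_length]):
            ids = [char_index[c] for c in word if c in char_index][:max_word_length]
            out[s, w, :len(ids)] = ids
    return out

char_index = {c: i + 1 for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
X = chars_to_tensor([["the", "cat"], ["hi"]], char_index, max_length=3, max_word_length=5)
print(X.shape)             # (2, 3, 5)
print(X[0, 0].tolist())    # [20, 8, 5, 0, 0]
```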

@Huzefa-Calcutta

If you are working on large data, it is recommended to directly use word vectors as an input to LSTM layer rather than having an embedding layer. This avoids matrix multiplication which takes a lot of time if there are more sequences.

@naisanza

naisanza commented Sep 2, 2017

@sergeyf

Hi! Is it just me, or is your embedding_weights numpy zeros array one row too tall?

You've already set n_symbols = len(index_dict) + 1 to account for the 0th index,

but embedding_weights = np.zeros((n_symbols+1,vocab_dim)) then has len(index_dict) + 2 rows. Why the extra +1?

vocab_dim = 300 # dimensionality of your word vectors
n_symbols = len(index_dict) + 1 # adding 1 to account for 0th index (for masking)
embedding_weights = np.zeros((n_symbols+1,vocab_dim))
for word,index in index_dict.items():
    embedding_weights[index,:] = word_vectors[word]

# assemble the model
model = Sequential() # or Graph or whatever
model.add(Embedding(output_dim=rnn_dim, input_dim=n_symbols + 1, mask_zero=True, weights=[embedding_weights])) # note you have to put embedding weights in a list by convention
model.add(LSTM(dense_dim, return_sequences=False))  
model.add(Dropout(0.5))
model.add(Dense(n_symbols, activation='softmax')) # for this is the architecture for predicting the next word, but insert your own here

@sergeyf

sergeyf commented Sep 3, 2017 via email

@JakSla

JakSla commented Sep 9, 2017

Hello!

Following this:
https://blog.keras.io/using-pre-trained-word-embeddings-in-a-keras-model.html
I tried to use pretrained word embeddings with my Embedding layer in Keras.
Yet I am getting this error:
ValueError: Layer weight shape (10000, 100) not compatible with provided weight shape (88585, 100)
at this line:
model.add(Embedding(max_features, 100, input_length=max_review_length,mask_zero=True, weights=[embedding_matrix]))

From what I see, Keras 2+ does not support embedding weights (right?).
I've tried older Keras 1.2 and 1.1.2 versions, but they still gave me the same error.

Can anyone advise whether I am doing something wrong?
Or what would be the proper way to use my own embeddings in Embedding layer?

Thanks!
Providing the code I am using below:

from keras.datasets import imdb
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Convolution1D, Flatten, Dropout
from keras.layers.embeddings import Embedding
from keras.preprocessing import sequence
from keras.callbacks import TensorBoard
from gensim.models import word2vec
import numpy as np
import os

import keras
#Using keras to load the dataset with the top_words
max_features = 10000 #max number of words to include, words are ranked by how often they occur (in training set)
max_review_length = 1600

(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features)
print 'loaded dataset...'
#Pad the sequence to the same length
X_train = sequence.pad_sequences(X_train, maxlen=max_review_length)
X_test = sequence.pad_sequences(X_test, maxlen=max_review_length)

index_dict = keras.datasets.imdb.get_word_index()

print 'loading glove...'
embeddings_index = {}
f = open(os.path.join('/home/ejaksla/PycharmProjects/MachineLearningPlayground/BachelorDegree/glove_word2vec/glove.6B/', 'glove.6B.100d.txt'))
for line in f:
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
f.close()

print 'creating embedding matrix...'
embedding_matrix = np.zeros((len(index_dict) + 1, 100))
for word, i in index_dict.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        # words not found in embedding index will be all-zeros.
        embedding_matrix[i] = embedding_vector
print('Found %s word vectors.' % len(embeddings_index))

print 'assembling model..'
# Using embedding from Keras
model = Sequential()
model.add(Embedding(max_features, 100, input_length=max_review_length,mask_zero=True, weights=[embedding_matrix]))
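The error comes from a size mismatch: embedding_matrix is built from the full IMDB word index (88,585 rows), while the layer was declared with max_features = 10000 rows. A sketch of a matrix that matches the layer, assuming the default keras.datasets.imdb.load_data settings (start_char=1, oov_char=2, index_from=3), under which a word ranked r in get_word_index() appears in the sequences as r + 3 and ids at or above the word limit are replaced by the OOV id. Toy dictionaries stand in for get_word_index() and the GloVe file:

```python
import numpy as np

def imdb_embedding_matrix(word_index, embeddings_index, max_features, dim, index_from=3):
    """Matrix with exactly max_features rows, aligned with the ids load_data produces."""
    matrix = np.zeros((max_features, dim), dtype="float32")
    for word, rank in word_index.items():
        i = rank + index_from       # shift: 0 = padding, 1 = start, 2 = out-of-vocabulary
        if i >= max_features:
            continue                # load_data mapped this word to the OOV id
        vector = embeddings_index.get(word)
        if vector is not None:      # words without a GloVe vector stay all-zeros
            matrix[i] = vector
    return matrix

word_index = {"the": 1, "and": 2, "film": 3, "rare": 50}
glove = {"the": np.ones(4), "film": np.full(4, 2.0), "rare": np.full(4, 9.0)}
m = imdb_embedding_matrix(word_index, glove, max_features=10, dim=4)
print(m.shape)  # (10, 4)
```

With a (max_features, 100) matrix, the weights=[embedding_matrix] call above has matching shapes. Separately, in Keras 2.x mask_zero=True is likely to be rejected by a following Convolution1D or Flatten layer, since those layers do not support masking.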

@DomHudson
Contributor

Is there any consensus on whether @sergeyf's approach still works? It does indeed appear that the weights argument has been removed, but it's still being used in the examples here.

@AllardJM

AllardJM commented Nov 13, 2017 via email

@sergeyf

sergeyf commented Nov 13, 2017

Hi everyone,

Looks like this is still a reference for some people. Here is what I do now with Keras 2.0.8:

import numpy as np
from keras.layers import Embedding

def set_embedding_layer_weights(embedding_layer, pretrained_embeddings):
    dense_dim = pretrained_embeddings.shape[1]
    weights = np.vstack((np.zeros(dense_dim), pretrained_embeddings))
    embedding_layer.set_weights([weights])

# load up your pretrained_embeddings here (should be an np.array)
d = pretrained_embeddings.shape[1]
n_vocab = pretrained_embeddings.shape[0] + 1 # +1 for the all-zeros mask row added above
embedding_layer = Embedding(output_dim=d, input_dim=n_vocab, trainable=True)
embedding_layer.build((None,)) # if you don't do this, the next step won't work
set_embedding_layer_weights(embedding_layer, pretrained_embeddings)

Note! This version assumes that the pretrained_embeddings array does not come with a mask first row, and explicitly makes an all-zeros row for it here: weights = np.vstack((np.zeros(dense_dim), pretrained_embeddings)). If you already have a special mask row, then feel free to just do embedding_layer.set_weights([pretrained_embeddings]).
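One consequence of prepending the mask row is that every word moves down by one row: the token id fed to the layer must be 1 + the word's row number in pretrained_embeddings, and input_dim must be pretrained_embeddings.shape[0] + 1. A small NumPy check with a made-up 3-word matrix:

```python
import numpy as np

pretrained = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])  # row j = word j (0-based)
weights = np.vstack((np.zeros(pretrained.shape[1]), pretrained))

word_ids = np.array([0, 2])   # ids in the pretrained model's own numbering
layer_ids = word_ids + 1      # shifted by one because row 0 is now the mask row
print(np.allclose(weights[layer_ids], pretrained[word_ids]))  # True
print(weights.shape)          # (4, 2)
```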

Hope that helps.

@DomHudson
Contributor

DomHudson commented Nov 15, 2017

Thanks for the reply both!

@AllardJM It doesn't error no. I've done a little investigation and with the following code I've come to the conclusion that it is being utilised and in fact, the original solution (using a weights kwarg) that @sergeyf posted does still work.

import numpy as np
from keras import initializers
from keras.layers import Embedding, LSTM, Dense
from keras.models import Sequential
weights = np.concatenate([
    np.zeros((1, 100)), # Masking row: all zeros.
    np.ones((1, 100)) # First word: all weights preset to 1.
]) 
print(weights)

array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 1.,  1.,  1., ...,  1.,  1.,  1.]])

layer = Embedding(
    output_dim = 100,
    input_dim = 2,
    mask_zero = True,
    weights = [weights],
)

model = Sequential([
    layer,
    LSTM(2, dropout = 0.2, activation = 'tanh'),
    Dense(1, activation = 'sigmoid')
])

model.compile(
    optimizer = 'adam',
    loss = 'binary_crossentropy',
    metrics = []
)
print(layer.get_weights())

[array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 1.,  1.,  1., ...,  1.,  1.,  1.]], dtype=float32)]


The more recent solution above also appears to work, but is probably less efficient, as it initialises the weights and then overwrites them. This can be seen below:

layer = Embedding(
    output_dim = 100,
    input_dim = 2,
    mask_zero = True
)
layer.build((None,))
print(layer.get_weights())

[array([[ -2.64064074e-02, -4.05902900e-02, -1.71032399e-02,
6.36395207e-03, 4.03554030e-02, -2.91514937e-02,
-3.05371974e-02, 1.60062015e-02, -4.58858572e-02,
-2.71607353e-03, -6.45029533e-04, -3.60430926e-02,
-4.47065122e-02, -4.46958952e-02, 8.49759020e-03,
-2.07597855e-02, -4.63474654e-02, -4.47412431e-02,
.....

layer.set_weights([weights])
print(layer.get_weights())

[array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 1.,  1.,  1., ...,  1.,  1.,  1.]], dtype=float32)]

@BrianMiner

BrianMiner commented Nov 15, 2017 via email

@aksg87

aksg87 commented Dec 18, 2017

Is there standard code or a function that takes a model built with gensim word2vec and converts it into the dictionary formats (i.e. index_dict and word_vectors from the first comment above)? Otherwise I will write my own code for this, but that seems much less efficient.

Thanks!


@DomHudson
Contributor

@aksg87 You could use the gensim.models.keyedvectors.KeyedVectors.get_keras_embedding method?

The KeyedVectors instance is accessible from a Word2Vec instance via the wv attribute, for example:

from gensim.models import Word2Vec

model = Word2Vec.load(fname)
embedding_layer = model.wv.get_keras_embedding(train_embeddings=True)

Source: https://github.com/RaRe-Technologies/gensim/blob/develop/gensim/models/keyedvectors.py#L1048
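To get the two dictionaries from the first comment instead of a ready-made layer, a generic sketch works with anything that lists words in a fixed order and supports lookup[word]. With gensim 4.x, a KeyedVectors object kv (e.g. model.wv) exposes that order as kv.index_to_key and supports kv[word]; gensim 3.x calls the list kv.index2word. to_index_and_vectors is a hypothetical helper, and a plain dict stands in for kv below:

```python
import numpy as np

def to_index_and_vectors(words, lookup):
    """Build the index_dict / word_vectors pair from the first comment.

    words  : iterable of words in a fixed order (e.g. most frequent first)
    lookup : mapping or object supporting lookup[word] -> vector
    Indices start at 1 so that 0 stays free for padding/masking.
    """
    index_dict, word_vectors = {}, {}
    for i, word in enumerate(words, start=1):
        index_dict[word] = i
        word_vectors[word] = np.asarray(lookup[word])
    return index_dict, word_vectors

fake_kv = {"king": [0.1, 0.2], "queen": [0.3, 0.4]}
index_dict, word_vectors = to_index_and_vectors(["king", "queen"], fake_kv)
print(index_dict)  # {'king': 1, 'queen': 2}
```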

@aksg87

aksg87 commented Dec 19, 2017

Thank you so much for your reply. I ended up finding some examples and wrote it out:

I'll have to try the version you provided.

Also, if you set train_embeddings=True, the weights in the layer will change from the word2vec output. Is this advisable in general?

The code I wrote to do this:

import numpy as np
from numpy import asarray, zeros
from tqdm import tqdm

# load the whole embedding into memory
embeddings_index = dict()
with open('vectors.txt') as f:
    for line in f:
        values = line.split()
        word = values[0]
        coefs = asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs
print('Loaded %s word vectors.' % len(embeddings_index))

dim_len = len(coefs)
print('Dimension of vector %s.' % dim_len)

# create a weight matrix for words in training docs
# (t is presumably the fitted Tokenizer and vocab_size its vocabulary size + 1)
embedding_matrix = zeros((vocab_size, dim_len))
for word, i in tqdm(t.word_index.items()):
    embedding_vector = embeddings_index.get(word)
    # skip vectors of the wrong length (202), presumably from malformed lines
    if embedding_vector is not None and np.shape(embedding_vector) != (202,):
        embedding_matrix[i] = embedding_vector
    if np.shape(embedding_vector) == (202,):
        print(i)
        print("embedding_vector", np.shape(embedding_vector))
        print("embedding_matrix", np.shape(embedding_matrix[i]))

@aksg87

aksg87 commented Dec 19, 2017

Another question I have is my final output is a softmax prediction on several classes (396 to be exact).

The output vector is messy (see below).

Is there a clean way to 1) convert this into the top 3 predicted labels and 2) write a custom accuracy function that checks how often the true label is in the softmax's top 3?

array([ 2.74735111e-22, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00, 3.84925198e-38, 0.00000000e+00,
1.72161353e-34, 1.86862336e-26, 6.87889553e-07,
1.09056833e-04, 1.17705227e-26, 6.17638065e-08,
6.54662412e-23, 3.28686365e-05, 4.67332768e-08,
0.00000000e+00, 5.22176857e-10, 4.09760102e-38,
0.00000000e+00, 5.86631461e-17, 1.14025260e-08,
4.42352757e-07, 8.37238900e-08, 0.00000000e+00,
1.48040133e-14, 3.42079135e-14, 2.47516301e-20,
...

@DomHudson
Contributor

DomHudson commented Dec 19, 2017

Also, if you set train_embeddings=True, the weights in the layer will change from the word2vec output. Is this advisable in general?

I don't think there's a 'correct' answer to this - it's up to you and the problem you're modelling. By having a trainable embeddings layer the weights will be tuned for the model's NLP task. This will give you more domain specific weights at the cost of increased training time.

It's quite common to train initial weights on a large corpus (or to use a pre-trained third party model) and then use that to seed your embedding layer. In this case you will likely find benefit if you do train the embeddings layer with the model. However, if you've trained your own Word2Vec model on exactly the domain you're modelling, you may find that the difference in results is negligible and that training the layer is not worth the longer training time.

Is there a clean way to convert this into the top 3 labels predicted

To do this you could use numpy's argpartition method.

>>> predictions = np.array([0.1, 0.3, 0.2, 0.4, 0.5])
>>> top_three_classes = np.argpartition(predictions, -3)[-3:]
>>> top_three_classes
array([1, 3, 4])

Write a custom accuracy function which checks how often the softmax predicts the top 3?

Yes this should be fairly straightforward utilising the above logic and a custom metric class or function.
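Keras 2 also ships keras.metrics.top_k_categorical_accuracy (k defaults to 5) for one-hot targets, which can be wrapped in a small function to use k=3 as a training metric. For evaluating predictions after the fact, a NumPy sketch built on the argpartition logic above, plus an argsort variant that returns the top labels in order (both helper names are hypothetical):

```python
import numpy as np

def top_k_accuracy(probs, y_true, k=3):
    """Fraction of rows whose true class is among the k highest scores.

    probs  : (n_samples, n_classes) softmax outputs
    y_true : (n_samples,) integer class labels
    """
    top_k = np.argpartition(probs, -k, axis=1)[:, -k:]  # unordered top-k per row
    hits = (top_k == y_true[:, None]).any(axis=1)
    return hits.mean()

def top_k_labels(probs, k=3):
    """Top-k class indices per row, highest score first."""
    return np.argsort(probs, axis=1)[:, ::-1][:, :k]

probs = np.array([
    [0.1, 0.3, 0.2, 0.4, 0.0],
    [0.5, 0.1, 0.15, 0.05, 0.2],
])
y_true = np.array([1, 3])
print(top_k_accuracy(probs, y_true, k=3))  # 0.5
print(top_k_labels(probs).tolist())        # [[3, 1, 2], [0, 4, 2]]
```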

@aksg87

aksg87 commented Dec 20, 2017

  1. Thanks so much for your reply! I discovered the np.argpartition function soon after my post and it worked perfectly.

To calculate accuracy, I created a few functions and used map to apply them to my predictions; they essentially tell me how often my model's 'Top 3' prediction contains the true answer. (At the very end I basically counted 'True' vs 'False' to arrive at a percentage. I thought Keras might have a way to override its accuracy function but didn't see one.)

  2. Now I am incorporating multiple inputs into the model (but will apply aggressive dropout to all of them, so the model should work even with only some of them). What I would love is for the model to assume a blank input, or perhaps a vector of zeros, if an input is not provided during model.predict, instead of throwing an error. Is there any way to do this, or should I hardcode a vector of zeros if no input is provided? Thanks again for all the awesome feedback.
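As far as I know, a multi-input Keras model expects every input at predict time, so the zeros have to be supplied by the caller. A sketch of a hypothetical wrapper that fills in any missing input with an all-zeros array of the right shape before calling model.predict(batch); the input names and shapes are made up. It works best if the model also saw all-zero versions of those inputs during training:

```python
import numpy as np

def fill_missing_inputs(provided, input_shapes, batch_size):
    """Return one array per model input, in a fixed order, using zeros for any
    input that was not provided.

    provided     : dict name -> array of shape (batch_size, *shape)
    input_shapes : dict name -> per-sample shape, in the model's input order
    """
    arrays = []
    for name, shape in input_shapes.items():
        if name in provided:
            arrays.append(np.asarray(provided[name]))
        else:
            arrays.append(np.zeros((batch_size,) + tuple(shape), dtype="float32"))
    return arrays

shapes = {"text": (5,), "metadata": (3,)}  # hypothetical input names and shapes
batch = fill_missing_inputs({"text": np.ones((2, 5))}, shapes, batch_size=2)
print([a.shape for a in batch])  # [(2, 5), (2, 3)]
```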

fchollet pushed a commit that referenced this issue Sep 22, 2023