[WIP] DL4J Embedding Layer - Word Vectors Initialization (#7173)
* Embedding layer W2V init - first pass

* fixes, tests

* Javadoc, polish, more tests

* EmbeddingSequenceLayer*
AlexDBlack committed Feb 15, 2019
1 parent 0675a93 commit 5da1e0d
Showing 10 changed files with 485 additions and 1 deletion.
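
Taken together, the changes below add an overloaded weightInit(...) to the embedding layer builders that accepts either an EmbeddingInitializer (which WordVectors/Word2Vec now implement) or a raw INDArray of vectors. A minimal usage sketch, mirroring the tests in this commit; `vec` is a hypothetical, previously trained Word2Vec instance, and imports are as in the test files further down:

// Assumes `vec` is a trained org.deeplearning4j.models.word2vec.Word2Vec instance.
int vectorSize = vec.lookupTable().layerSize();

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        // Word2Vec implements EmbeddingInitializer after this commit, so it can be passed
        // directly; the embedding layer's nIn/nOut are inferred from vocab and vector size.
        .layer(new EmbeddingLayer.Builder().weightInit(vec).build())
        .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                .nIn(vectorSize).nOut(4).build())
        .build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); // the embedding layer's weight matrix now holds the Word2Vec vectors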
@@ -16,7 +16,9 @@

package org.deeplearning4j.nn.layers.feedforward.embedding;

import lombok.EqualsAndHashCode;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@@ -29,8 +31,10 @@
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
@@ -516,4 +520,87 @@ public void testEmbeddingLayerWithMasking() {
}
}


@Test
public void testW2VInits(){
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);

for( int i=0; i<2; i++ ) {

INDArray vectors = Nd4j.linspace(1,15,15, DataType.FLOAT).reshape(5,3);

EmbeddingLayer el;
if(i == 0){
el = new EmbeddingLayer.Builder().weightInit(vectors).build();
} else {
el = new EmbeddingLayer.Builder().weightInit(new WordVectorsMockup()).build();
}

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345).list()
.layer(el)
.layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build())
.layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

INDArray w = net.getParam("0_W");
assertEquals(vectors, w);

TestUtils.testModelSerialization(net);

//Test same thing for embedding sequence layer:
EmbeddingSequenceLayer esl;
if(i == 0){
esl = new EmbeddingSequenceLayer.Builder().weightInit(vectors).build();
} else {
esl = new EmbeddingSequenceLayer.Builder().weightInit(new WordVectorsMockup()).build();
}

conf = new NeuralNetConfiguration.Builder()
.seed(12345).list()
.layer(esl)
.layer(new GlobalPoolingLayer())
.layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build())
.layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.build();

net = new MultiLayerNetwork(conf);
net.init();

w = net.getParam("0_W");
assertEquals(vectors, w);

TestUtils.testModelSerialization(net);
}
}

@EqualsAndHashCode
private static class WordVectorsMockup implements EmbeddingInitializer {

@Override
public void loadWeightsInto(INDArray array) {
INDArray vectors = Nd4j.linspace(1,15,15, DataType.FLOAT).reshape(5,3);
array.assign(vectors);
}

@Override
public long vocabSize() {
return 5;
}

@Override
public int vectorSize() {
return 3;
}

@Override
public boolean jsonSerializable() {
return true;
}
}
}
@@ -19,6 +19,7 @@
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;
Expand All @@ -32,7 +33,7 @@
*
* @author Adam Gibson
*/
public interface WordVectors extends Serializable {
public interface WordVectors extends Serializable, EmbeddingInitializer {

String getUNK();

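
The EmbeddingInitializer interface itself does not appear in this view of the diff, but its contract can be read off the @Override implementations below. An inferred sketch (method names and signatures are taken from those implementations; the Serializable bound and the comments are assumptions):

public interface EmbeddingInitializer extends Serializable {

    /** Copy the embedding weights into the provided [vocabSize, vectorSize] parameter array */
    void loadWeightsInto(INDArray array);

    /** Number of rows (words/tokens) in the embedding table */
    long vocabSize();

    /** Embedding dimension, i.e. number of columns per word/token */
    int vectorSize();

    /** Whether this initializer can be round-tripped via the layer's JSON configuration */
    boolean jsonSerializable();
}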
@@ -337,4 +337,24 @@ protected void update(Environment env, Event event) {
heartbeat.reportEvent(event, env, task);
}
}

@Override
public void loadWeightsInto(INDArray array) {
array.assign(lookupTable.getWeights());
}

@Override
public long vocabSize() {
return lookupTable.getWeights().size(0);
}

@Override
public int vectorSize() {
return lookupTable.layerSize();
}

@Override
public boolean jsonSerializable() {
return false;
}
}
@@ -336,6 +336,46 @@ public void setModelUtils(ModelUtils utils) {
// no-op
}

@Override
public void loadWeightsInto(INDArray array) {
int n = (int)vocabSize();
INDArray zero = null;
for( int i=0; i<n; i++ ){
INDArray arr = storage.get(i);
if(arr == null){ //TODO is this even possible?
if(zero == null)
zero = Nd4j.create(array.dataType(), 1, array.size(1));
arr = zero;
}
array.putRow(i, arr);
}
}

@Override
public long vocabSize() {
return storage.size();
}

@Override
public int vectorSize() {
INDArray arr = storage.get(0);
if(arr != null)
return (int)arr.length();

int vs = (int)vocabSize();
for( int i=1; i<vs; i++ ){
            arr = storage.get(i);
if(arr != null)
return (int)arr.length();
}
throw new UnsupportedOperationException("No vectors found");
}

@Override
public boolean jsonSerializable() {
return false;
}

public static class Builder {

private AbstractStorage<Integer> storage;
@@ -18,19 +18,32 @@

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.util.Collection;

@@ -125,4 +138,68 @@ public void testLabelAwareIterator_1() throws Exception {
public void testPlot() {
//word2vec.lookupTable().plotVocab();
}


@Test
public void testW2VEmbeddingLayerInit() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);

val inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();

val iter = new BasicLineIterator(inputFile);
val t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(1)
.epochs(1)
.layerSize(300)
                .limitVocabularySize(1) // Limit the vocab to 1 word (plus UNK, added via useUnknown below)
.windowSize(5)
.allowParallelTokenization(true)
.batchSize(512)
.learningRate(0.025)
.minLearningRate(0.0001)
.negativeSample(0.0)
.sampling(0.0)
.useAdaGrad(false)
.useHierarchicSoftmax(true)
.iterations(1)
.useUnknown(true) // Using UNK with limited vocab size causes the issue
.seed(42)
.iterate(iter)
.workers(4)
.tokenizerFactory(t).build();

vec.fit();

INDArray w = vec.lookupTable().getWeights();
System.out.println(w);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345).list()
.layer(new EmbeddingLayer.Builder().weightInit(vec).build())
.layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(w.size(1)).nOut(3).build())
.layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

INDArray w0 = net.getParam("0_W");
assertEquals(w, w0);



ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();

ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);

assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(net.params(), restored.params());
}
}
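
The same round-trip can be done against a file. A sketch, under the assumption (suggested by the test above and by jsonSerializable() returning false for the Word2Vec-backed initializer) that the pretrained vectors travel with the saved network parameters rather than with the layer's JSON configuration; the file name is illustrative:

// `net` is the network built in the test above; true = also save the updater state.
File f = new File("w2v-embedding-net.zip");
ModelSerializer.writeModel(net, f, true);

MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(f, true);
assertEquals(net.params(), restored.params()); // embedding weights survive the round trip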
@@ -24,6 +24,10 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -120,6 +124,36 @@ public Builder hasBias(boolean hasBias) {
return this;
}

@Override
public Builder weightInit(IWeightInit weightInit) {
if(weightInit instanceof WeightInitEmbedding){
long[] shape = ((WeightInitEmbedding) weightInit).shape();
nIn(shape[0]);
nOut(shape[1]);
}
return super.weightInit(weightInit);
}

/**
* Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec instance
*
* @param embeddingInitializer Source of the embedding layer weights
*/
public Builder weightInit(EmbeddingInitializer embeddingInitializer){
return weightInit(new WeightInitEmbedding(embeddingInitializer));
}

/**
* Initialize the embedding layer using values from the specified array. Note that the array should have shape
* [vocabSize, vectorSize]. After copying values from the array to initialize the network parameters, the input
* array will be discarded (so that, if necessary, it can be garbage collected)
*
* @param vectors Vectors to initialize the embedding layer with
*/
public Builder weightInit(INDArray vectors){
return weightInit(new ArrayEmbeddingInitializer(vectors));
}

@Override
@SuppressWarnings("unchecked")
public EmbeddingLayer build() {
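
A minimal sketch of the array-based overload added above; the values mirror the EmbeddingLayerTest earlier in this diff, and nIn/nOut inference happens via WeightInitEmbedding as in the weightInit(IWeightInit) override:

// A [vocabSize, vectorSize] matrix of pretrained vectors: 5 words x 3 dimensions here,
// the same values used in the unit test above.
INDArray vectors = Nd4j.linspace(1, 15, 15, DataType.FLOAT).reshape(5, 3);

// nIn (5) and nOut (3) are taken from the array's shape, so they need not be set on the
// builder; the array is copied into the layer parameters and can then be garbage collected.
EmbeddingLayer layer = new EmbeddingLayer.Builder()
        .weightInit(vectors)
        .build();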
@@ -24,6 +24,10 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -162,6 +166,36 @@ public Builder inferInputLength(boolean inferInputLength) {
return this;
}

@Override
public Builder weightInit(IWeightInit weightInit) {
if(weightInit instanceof WeightInitEmbedding){
long[] shape = ((WeightInitEmbedding) weightInit).shape();
nIn(shape[0]);
nOut(shape[1]);
}
return super.weightInit(weightInit);
}

/**
* Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec instance
*
* @param embeddingInitializer Source of the embedding layer weights
*/
public Builder weightInit(EmbeddingInitializer embeddingInitializer){
return weightInit(new WeightInitEmbedding(embeddingInitializer));
}

/**
* Initialize the embedding layer using values from the specified array. Note that the array should have shape
* [vocabSize, vectorSize]. After copying values from the array to initialize the network parameters, the input
* array will be discarded (so that, if necessary, it can be garbage collected)
*
* @param vectors Vectors to initialize the embedding layer with
*/
public Builder weightInit(INDArray vectors){
return weightInit(new ArrayEmbeddingInitializer(vectors));
}

@Override
@SuppressWarnings("unchecked")
public EmbeddingSequenceLayer build() {
