[WIP] DL4J Embedding Layer - Word Vectors Initialization #7173

Merged
merged 4 commits on Feb 15, 2019
@@ -16,7 +16,9 @@

package org.deeplearning4j.nn.layers.feedforward.embedding;

import lombok.EqualsAndHashCode;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@@ -29,8 +31,10 @@
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
@@ -516,4 +520,87 @@ public void testEmbeddingLayerWithMasking() {
}
}


@Test
public void testW2VInits(){
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);

for( int i=0; i<2; i++ ) {

INDArray vectors = Nd4j.linspace(1,15,15, DataType.FLOAT).reshape(5,3);

EmbeddingLayer el;
if(i == 0){
el = new EmbeddingLayer.Builder().weightInit(vectors).build();
} else {
el = new EmbeddingLayer.Builder().weightInit(new WordVectorsMockup()).build();
}

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345).list()
.layer(el)
.layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build())
.layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

INDArray w = net.getParam("0_W");
assertEquals(vectors, w);

TestUtils.testModelSerialization(net);

//Test same thing for embedding sequence layer:
EmbeddingSequenceLayer esl;
if(i == 0){
esl = new EmbeddingSequenceLayer.Builder().weightInit(vectors).build();
} else {
esl = new EmbeddingSequenceLayer.Builder().weightInit(new WordVectorsMockup()).build();
}

conf = new NeuralNetConfiguration.Builder()
.seed(12345).list()
.layer(esl)
.layer(new GlobalPoolingLayer())
.layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build())
.layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.build();

net = new MultiLayerNetwork(conf);
net.init();

w = net.getParam("0_W");
assertEquals(vectors, w);

TestUtils.testModelSerialization(net);
}
}

@EqualsAndHashCode
private static class WordVectorsMockup implements EmbeddingInitializer {

@Override
public void loadWeightsInto(INDArray array) {
INDArray vectors = Nd4j.linspace(1,15,15, DataType.FLOAT).reshape(5,3);
array.assign(vectors);
}

@Override
public long vocabSize() {
return 5;
}

@Override
public int vectorSize() {
return 3;
}

@Override
public boolean jsonSerializable() {
return true;
}
}
}
@@ -19,6 +19,7 @@
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;
@@ -32,7 +33,7 @@
*
* @author Adam Gibson
*/
public interface WordVectors extends Serializable {
public interface WordVectors extends Serializable, EmbeddingInitializer {

String getUNK();

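The EmbeddingInitializer interface that WordVectors now extends is not shown in this changeset. A minimal sketch of its contract, inferred from the overrides implemented throughout this diff (javadoc wording is an assumption):

package org.deeplearning4j.nn.weights.embeddings;

import org.nd4j.linalg.api.ndarray.INDArray;

public interface EmbeddingInitializer {

    /** Copy the source embeddings into the provided parameter array of shape [vocabSize, vectorSize] */
    void loadWeightsInto(INDArray array);

    /** Number of words (rows) in the embedding table */
    long vocabSize();

    /** Dimensionality (columns) of each word vector */
    int vectorSize();

    /** Whether this initializer can be stored as part of the network's JSON configuration */
    boolean jsonSerializable();
}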
@@ -337,4 +337,24 @@ protected void update(Environment env, Event event) {
heartbeat.reportEvent(event, env, task);
}
}

@Override
public void loadWeightsInto(INDArray array) {
array.assign(lookupTable.getWeights());
}

@Override
public long vocabSize() {
return lookupTable.getWeights().size(0);
}

@Override
public int vectorSize() {
return lookupTable.layerSize();
}

@Override
public boolean jsonSerializable() {
return false;
}
}
@@ -336,6 +336,46 @@ public void setModelUtils(ModelUtils utils) {
// no-op
}

@Override
public void loadWeightsInto(INDArray array) {
int n = (int)vocabSize();
INDArray zero = null;
for( int i=0; i<n; i++ ){
INDArray arr = storage.get(i);
if(arr == null){ //TODO is this even possible?
if(zero == null)
zero = Nd4j.create(array.dataType(), 1, array.size(1));
arr = zero;
}
array.putRow(i, arr);
}
}

@Override
public long vocabSize() {
return storage.size();
}

@Override
public int vectorSize() {
INDArray arr = storage.get(0);
if(arr != null)
return (int)arr.length();

int vs = (int)vocabSize();
for( int i=1; i<vs; i++ ){
arr = storage.get(i);   //Check each remaining index, not index 0 again
if(arr != null)
return (int)arr.length();
}
throw new UnsupportedOperationException("No vectors found");
}

@Override
public boolean jsonSerializable() {
return false;
}

public static class Builder {

private AbstractStorage<Integer> storage;
@@ -18,19 +18,32 @@

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.util.Collection;

@@ -125,4 +138,68 @@ public void testLabelAwareIterator_1() throws Exception {
public void testPlot() {
//word2vec.lookupTable().plotVocab();
}


@Test
public void testW2VEmbeddingLayerInit() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);

val inputFile = new ClassPathResource("/big/raw_sentences.txt").getFile();

val iter = new BasicLineIterator(inputFile);
val t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

Word2Vec vec = new Word2Vec.Builder()
.minWordFrequency(1)
.epochs(1)
.layerSize(300)
.limitVocabularySize(1) // Limit the vocab to a single word (plus UNK, enabled below)
.windowSize(5)
.allowParallelTokenization(true)
.batchSize(512)
.learningRate(0.025)
.minLearningRate(0.0001)
.negativeSample(0.0)
.sampling(0.0)
.useAdaGrad(false)
.useHierarchicSoftmax(true)
.iterations(1)
.useUnknown(true) // Using UNK with limited vocab size causes the issue
.seed(42)
.iterate(iter)
.workers(4)
.tokenizerFactory(t).build();

vec.fit();

INDArray w = vec.lookupTable().getWeights();
System.out.println(w);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(12345).list()
.layer(new EmbeddingLayer.Builder().weightInit(vec).build())
.layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(w.size(1)).nOut(3).build())
.layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

INDArray w0 = net.getParam("0_W");
assertEquals(w, w0);



ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();

ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);

assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(net.params(), restored.params());
}
}
@@ -24,6 +24,10 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -120,6 +124,36 @@ public Builder hasBias(boolean hasBias) {
return this;
}

@Override
public Builder weightInit(IWeightInit weightInit) {
if(weightInit instanceof WeightInitEmbedding){
long[] shape = ((WeightInitEmbedding) weightInit).shape();
nIn(shape[0]);
nOut(shape[1]);
}
return super.weightInit(weightInit);
}

/**
* Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec instance
*
* @param embeddingInitializer Source of the embedding layer weights
*/
public Builder weightInit(EmbeddingInitializer embeddingInitializer){
return weightInit(new WeightInitEmbedding(embeddingInitializer));
}

/**
* Initialize the embedding layer using values from the specified array. Note that the array should have shape
* [vocabSize, vectorSize]. After copying values from the array to initialize the network parameters, the input
* array will be discarded (so that, if necessary, it can be garbage collected)
*
* @param vectors Vectors to initialize the embedding layer with
*/
public Builder weightInit(INDArray vectors){
return weightInit(new ArrayEmbeddingInitializer(vectors));
}

@Override
@SuppressWarnings("unchecked")
public EmbeddingLayer build() {
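Together with the tests earlier in this diff, the new Builder overloads let an embedding layer be configured directly from pretrained vectors, with nIn/nOut inferred from the vocabulary size and vector size. A usage sketch, assuming a trained Word2Vec model vec as in testW2VEmbeddingLayerInit above (layer sizes are illustrative only):

// vec implements EmbeddingInitializer, so it can be passed straight to weightInit(...)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(new EmbeddingLayer.Builder().weightInit(vec).build())
        .layer(new DenseLayer.Builder().activation(Activation.TANH)
                .nIn(vec.vectorSize()).nOut(3).build())
        .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                .nIn(3).nOut(4).build())
        .build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); // the Word2Vec weights are copied into the embedding layer's "0_W" parameter here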
@@ -24,6 +24,10 @@
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;

@@ -162,6 +166,36 @@ public Builder inferInputLength(boolean inferInputLength) {
return this;
}

@Override
public Builder weightInit(IWeightInit weightInit) {
if(weightInit instanceof WeightInitEmbedding){
long[] shape = ((WeightInitEmbedding) weightInit).shape();
nIn(shape[0]);
nOut(shape[1]);
}
return super.weightInit(weightInit);
}

/**
* Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec instance
*
* @param embeddingInitializer Source of the embedding layer weights
*/
public Builder weightInit(EmbeddingInitializer embeddingInitializer){
return weightInit(new WeightInitEmbedding(embeddingInitializer));
}

/**
* Initialize the embedding layer using values from the specified array. Note that the array should have shape
* [vocabSize, vectorSize]. After copying values from the array to initialize the network parameters, the input
* array will be discarded (so that, if necessary, it can be garbage collected)
*
* @param vectors Vectors to initialize the embedding layer with
*/
public Builder weightInit(INDArray vectors){
return weightInit(new ArrayEmbeddingInitializer(vectors));
}

@Override
@SuppressWarnings("unchecked")
public EmbeddingSequenceLayer build() {
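The same pattern applies to EmbeddingSequenceLayer; a sketch mirroring testW2VInits above, where a GlobalPoolingLayer reduces over the sequence dimension before the feed-forward layers (vec is again any EmbeddingInitializer, such as a trained Word2Vec model):

MultiLayerConfiguration seqConf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(new EmbeddingSequenceLayer.Builder().weightInit(vec).build())
        .layer(new GlobalPoolingLayer())
        .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE)
                .nIn(vec.vectorSize()).nOut(4).build())
        .build();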