Added module to train fastText embeddings
Includes a test pipeline.
Showing 5 changed files with 229 additions and 1 deletion.
Empty file.
@@ -0,0 +1,71 @@
# This file is part of Pimlico
# Copyright (C) 2016 Mark Granroth-Wilding
# Licensed under the GNU GPL v3.0 - http://www.gnu.org/licenses/gpl-3.0.en.html
import os
from collections import Counter
from io import open

import fasttext
import numpy as np

from pimlico.core.modules.base import BaseModuleExecutor
from pimlico.utils.progress import get_progress_bar


class ModuleExecutor(BaseModuleExecutor):
    def execute(self):
        input_corpus = self.info.get_input("text")
        opts = self.info.options

        input_path = os.path.join(self.info.get_module_output_dir(absolute=True), "fasttext_input_data.txt")
        self.log.info("Preparing input data file for fastText: {}".format(input_path))
        pbar = get_progress_bar(len(input_corpus), title="Preparing data")

        # fastText needs to read its training data from a Unicode text file, so we write
        # the corpus out to disk, one sentence per line
        # We also keep word counts as we go, for writing the plain embeddings later
        word_counts = Counter()
        with open(input_path, "w", encoding="utf-8") as f:
            for doc_name, doc in pbar(input_corpus):
                for sentence in doc.sentences:
                    f.write(u"{}\n".format(u" ".join(sentence)))
                    word_counts.update(sentence)

        self.log.info("Training fastText embeddings")
        # Almost all options come straight from the module options
        model = fasttext.train_unsupervised(
            input_path,
            model=opts["model"],
            lr=opts["lr"],
            dim=opts["dim"],
            ws=opts["ws"],
            epoch=opts["epoch"],
            minCount=opts["min_count"],
            minn=opts["minn"], maxn=opts["maxn"],
            neg=opts["neg"],
            wordNgrams=opts["word_ngrams"],
            loss=opts["loss"],
            bucket=opts["bucket"],
            thread=self.processes,
            lrUpdateRate=opts["lr_update_rate"],
            t=opts["t"],
            verbose=opts["verbose"],
        )

        num_words = len(model.words)
        self.log.info("Training complete. Trained {:,d} vectors".format(num_words))

        self.log.info("Writing out fastText embeddings in native fastText format")
        with self.info.get_output_writer("model") as writer:
            writer.save_model(model)

        self.log.info("Writing out plain embeddings")
        with self.info.get_output_writer("embeddings") as writer:
            # Build a dictionary of word counts for all the words for which embeddings are stored
            model_word_counts = [(word, word_counts[word]) for word in model.words]
            writer.write_word_counts(model_word_counts)
            # Now get the embeddings in one big matrix
            vectors = np.zeros((len(model.words), model.dim), dtype=np.float32)
            for w, word in enumerate(model.words):
                vectors[w] = model[word]
            # Output the vectors
            writer.write_vectors(vectors)
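Outside Pimlico, the core of what this executor does reduces to a few calls to the official fasttext package. The following is a minimal sketch of that train-and-extract flow, assuming a hypothetical whitespace-tokenized file `corpus.txt` with one sentence per line (standing in for the file the executor writes out from the input corpus):

```python
import fasttext
import numpy as np

# Train skipgram embeddings on a one-sentence-per-line text file
# ("corpus.txt" is a hypothetical path)
model = fasttext.train_unsupervised("corpus.txt", model="skipgram", dim=100, minCount=5)

# Collect the in-vocabulary word vectors into one matrix, as the
# executor does before handing them to the embeddings writer
vectors = np.stack([model.get_word_vector(word) for word in model.words])
print(vectors.shape)  # (vocabulary size, 100)
```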
@@ -0,0 +1,110 @@
# This file is part of Pimlico
# Copyright (C) 2016 Mark Granroth-Wilding
# Licensed under the GNU GPL v3.0 - http://www.gnu.org/licenses/gpl-3.0.en.html

"""
Train fastText embeddings on a tokenized corpus.

Uses the `fastText Python package <https://fasttext.cc/docs/en/python-module.html>`_.

FastText embeddings store more than just a vector for each word, since they
also have sub-word representations. We therefore store a standard embeddings
output, with the word vectors in, and also a special fastText embeddings output.

"""
from pimlico.core.dependencies.python import PythonPackageOnPip, numpy_dependency
from pimlico.core.modules.base import BaseModuleInfo
from pimlico.core.modules.options import choose_from_list
from pimlico.datatypes import GroupedCorpus, Embeddings
from pimlico.datatypes.corpora.tokenized import TokenizedDocumentType
from pimlico.datatypes.embeddings import FastTextEmbeddings


class ModuleInfo(BaseModuleInfo):
    module_type_name = "fasttext"
    module_readable_name = "fastText embedding trainer"
    module_inputs = [("text", GroupedCorpus(TokenizedDocumentType()))]
    module_outputs = [("embeddings", Embeddings()), ("model", FastTextEmbeddings())]
    module_options = {
        "model": {
            "help": "unsupervised fastText model: cbow, skipgram. Default: skipgram",
            "type": choose_from_list(["skipgram", "cbow"]),
            "default": "skipgram",
        },
        "dim": {
            "help": "size of word vectors. Default: 100",
            "type": int,
            "default": 100,
        },
        "lr": {
            "help": "learning rate. Default: 0.05",
            "type": float,
            "default": 0.05,
        },
        "ws": {
            "help": "size of the context window. Default: 5",
            "type": int,
            "default": 5,
        },
        "epoch": {
            "help": "number of epochs. Default: 5",
            "type": int,
            "default": 5,
        },
        "min_count": {
            "help": "minimal number of word occurrences. Default: 5",
            "type": int,
            "default": 5,
        },
        "minn": {
            "help": "min length of char ngram. Default: 3",
            "type": int,
            "default": 3,
        },
        "maxn": {
            "help": "max length of char ngram. Default: 6",
            "type": int,
            "default": 6,
        },
        "neg": {
            "help": "number of negatives sampled. Default: 5",
            "type": int,
            "default": 5,
        },
        "word_ngrams": {
            "help": "max length of word ngram. Default: 1",
            "type": int,
            "default": 1,
        },
        "loss": {
            "help": "loss function: ns, hs, softmax, ova. Default: ns",
            "type": choose_from_list(["ns", "hs", "softmax", "ova"]),
            "default": "ns",
        },
        "bucket": {
            "help": "number of buckets. Default: 2,000,000",
            "type": int,
            "default": 2000000,
        },
        "lr_update_rate": {
            "help": "change the rate of updates for the learning rate. Default: 100",
            "type": int,
            "default": 100,
        },
        "t": {
            "help": "sampling threshold. Default: 0.0001",
            "type": float,
            "default": 0.0001,
        },
        "verbose": {
            "help": "verbosity level. Default: 2",
            "type": int,
            "default": 2,
        },
    }
    module_supports_python2 = False

    def get_software_dependencies(self):
        return super(ModuleInfo, self).get_software_dependencies() + [
            PythonPackageOnPip("fasttext"), numpy_dependency,
        ]
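The reason for the two outputs described in the docstring shows up directly in the fasttext API: a loaded native model can build vectors for out-of-vocabulary words from character n-grams, which a plain word-vector matrix cannot. A small illustration, assuming a hypothetical saved model file `model.bin` and made-up query words:

```python
import fasttext

# Load a native fastText model (hypothetical path)
model = fasttext.load_model("model.bin")

# In-vocabulary words are looked up directly
v = model.get_word_vector("parliament")

# Out-of-vocabulary words still get a vector, composed from their
# character n-gram (sub-word) vectors; the plain embeddings output
# cannot do this, which is why the native model is stored as well
oov = model.get_word_vector("parliamentarianisms")
print(v.shape, oov.shape)
```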
@@ -0,0 +1,17 @@
# Train fastText embeddings on a tiny corpus
[pipeline]
name=fasttext_train
release=latest

# Take tokenized text input from a prepared Pimlico dataset
[europarl]
type=pimlico.datatypes.corpora.GroupedCorpus
data_point_type=TokenizedDocumentType
dir=%(test_data_dir)s/datasets/corpora/tokenized

[fasttext]
type=pimlico.modules.embeddings.fasttext
# Set low, since we're training on a tiny corpus
min_count=1
# Very small vectors: usually this will be more like 100 or 200
dim=10
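To see why the test pipeline lowers `min_count` and `dim`, here is a rough equivalent of its settings run directly through the fasttext package on a made-up two-sentence corpus (the file name and sentences are illustrative, not from the test data):

```python
import fasttext

# Write a made-up, whitespace-tokenized corpus, one sentence per line
with open("tiny_corpus.txt", "w", encoding="utf-8") as f:
    f.write("the cat sat on the mat\n")
    f.write("the dog sat on the rug\n")

# min_count=1 keeps every word: with the default of 5, no word here
# occurs often enough and fastText would reject the empty vocabulary.
# dim=10 keeps the vectors tiny, as in the test pipeline
model = fasttext.train_unsupervised(
    "tiny_corpus.txt", minCount=1, dim=10,
    thread=1,  # a single thread is safer on a file this small
)
print(len(model.words), model.get_dimension())
```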