This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

scripts to download word vector models and reduce their size
Summary:
Pre-trained word vectors are conveniently available on fastText's website at https://fasttext.cc/docs/en/crawl-vectors.html.
However, these vectors have a fixed dimension of 300.

This commit adds a script to download word vectors, and a script to reduce their size to a custom dimension.

Reviewed By: piotr-bojanowski

Differential Revision: D18706087

fbshipit-source-id: d5370a62687b10387d9630c7128809bc4da30f0a
Celebio authored and facebook-github-bot committed Jan 3, 2020
1 parent da2745f commit 02c61ef
Showing 14 changed files with 432 additions and 9 deletions.
67 changes: 67 additions & 0 deletions docs/crawl-vectors.md
@@ -7,6 +7,73 @@ We distribute pre-trained word vectors for 157 languages, trained on [*Common Cr
These models were trained using CBOW with position-weights, in dimension 300, with character n-grams of length 5, a window of size 5 and 10 negatives.
We also distribute three new word analogy datasets, for French, Hindi and Polish.

### Download directly with command line or from python

To download from the command line or from Python code, you must first install the Python package as [described here](/docs/en/support.html#building-fasttext-python-module).

<!--DOCUSAURUS_CODE_TABS-->
<!--Command line-->
```bash
$ ./download_model.py en # English
Downloading https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz
(19.78%) [=========> ]
```
Once the download is finished, use the model as usual:
```bash
$ ./fasttext nn cc.en.300.bin 10
Query word?
```
<!--Python-->
```py
>>> import fasttext.util
>>> fasttext.util.download_model('en', if_exists='ignore') # English
>>> ft = fasttext.load_model('cc.en.300.bin')
```
<!--END_DOCUSAURUS_CODE_TABS-->
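
The `if_exists` argument decides what happens when `cc.<lang>.300.bin` is already on disk. Going by the `download_model` code in this commit, `'ignore'` reuses the existing file, `'strict'` (the default) prints a message and returns `None`, and `'overwrite'` downloads again. The sketch below restates that decision as a standalone function. The name `resolve_existing` and its return strings are illustrative assumptions, not part of fastText, and the `ValueError` for unknown modes is an addition (the commit's code does not validate the mode).

```python
def resolve_existing(file_exists, if_exists="strict"):
    """Return the action taken for a target file (hypothetical helper).

    Mirrors the three `if_exists` modes: 'ignore', 'strict', 'overwrite'.
    """
    if not file_exists:
        return "download"   # nothing on disk yet, always fetch
    if if_exists == "ignore":
        return "reuse"      # keep the existing file, skip the download
    if if_exists == "strict":
        return "refuse"     # report the conflict and do nothing
    if if_exists == "overwrite":
        return "download"   # fetch again and replace the file
    # Assumption: reject unknown modes instead of silently downloading.
    raise ValueError("unknown if_exists mode: %r" % (if_exists,))


for mode in ("ignore", "strict", "overwrite"):
    print(mode, "->", resolve_existing(True, mode))
```

In interactive sessions `'ignore'` is usually the convenient choice, since repeated calls then become no-ops once the file is present.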

### Adapt the dimension

The pre-trained word vectors we distribute have dimension 300. If you need a smaller size, you can use our dimension reducer.
To use this feature, you must first install the Python package as [described here](/docs/en/support.html#building-fasttext-python-module).

For example, in order to get vectors of dimension 100:
<!--DOCUSAURUS_CODE_TABS-->

<!--Command line-->
```bash
$ ./reduce_model.py cc.en.300.bin 100
Loading model
Reducing matrix dimensions
Saving model
cc.en.100.bin saved
```
Then you can use the `cc.en.100.bin` model file as usual.

<!--Python-->
```py
>>> import fasttext
>>> import fasttext.util
>>> ft = fasttext.load_model('cc.en.300.bin')
>>> ft.get_dimension()
300
>>> fasttext.util.reduce_model(ft, 100)
>>> ft.get_dimension()
100
```
Then you can use the `ft` model object as usual:
```py
>>> ft.get_word_vector('hello').shape
(100,)
>>> ft.get_nearest_neighbors('hello')
[(0.775576114654541, u'heyyyy'), (0.7686290144920349, u'hellow'), (0.7663413286209106, u'hello-'), (0.7579624056816101, u'heyyyyy'), (0.7495524287223816, u'hullo'), (0.7473770380020142, u'.hello'), (0.7407292127609253, u'Hiiiii'), (0.7402616739273071, u'hellooo'), (0.7399682402610779, u'hello.'), (0.7396857738494873, u'Heyyyyy')]
```
or save it for later use:
```py
>>> ft.save_model('cc.en.100.bin')
```
<!--END_DOCUSAURUS_CODE_TABS-->
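
According to its docstring, `reduce_model` computes a PCA of the input matrix and applies the resulting projection to both the input and the output matrix. The following is a minimal NumPy sketch of that idea on random data, independent of fastText. The function name `pca_project` is hypothetical. It also swaps in `np.linalg.eigh` with an explicit descending sort, whereas the commit's `_reduce_matrix` calls `np.linalg.eig` and takes the first `dim` columns; NumPy does not guarantee that `eig` returns eigenvalues in sorted order.

```python
import numpy as np


def pca_project(X, dim, fit_rows=100000):
    """Project the rows of X onto its top `dim` principal directions."""
    # Fit on at most `fit_rows` rows, as the commit does with 100000.
    sample = X[:fit_rows].astype(np.float64)
    centered = sample - sample.mean(axis=0)
    # Sample covariance of the columns, shape (n, n).
    cov = centered.T @ centered / (centered.shape[0] - 1)
    # eigh handles symmetric matrices and returns ascending eigenvalues,
    # so reorder to descending and keep the first `dim` eigenvectors.
    eigvals, eigvecs = np.linalg.eigh(cov)
    order = np.argsort(eigvals)[::-1][:dim]
    basis = eigvecs[:, order].astype(np.float32)
    # Like the commit, project the original (uncentered) matrix.
    return X.astype(np.float32) @ basis, basis


rng = np.random.default_rng(0)
vectors = rng.standard_normal((1000, 300)).astype(np.float32)
reduced, basis = pca_project(vectors, 100)
print(reduced.shape)  # (1000, 100)
```

Reusing the same `basis` for a second matrix, as `reduce_model` does for the output matrix, keeps input and output vectors in the same reduced space.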


### Format

The word vectors are available in both binary and text formats.
3 changes: 3 additions & 0 deletions docs/faqs.md
@@ -61,3 +61,6 @@ If you run fastText multiple times you'll obtain slightly different results each

## Why do I get a probability of 1.00001?
This is a known rounding issue. You can consider it as 1.0.

## How can I change the dimension of word vectors of a model file?
If you have already trained a model, or downloaded a pre-trained word vector model, you can adapt the dimension of the word vectors with the `reduce_model.py` script or by calling `fasttext.util.reduce_model` from Python, as [described here](/docs/en/crawl-vectors.html#adapt-the-dimension).
48 changes: 48 additions & 0 deletions download_model.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse

import fasttext.util


args = None


def command_download(lang_id, if_exists):
    """
    Download pre-trained common-crawl vectors from fastText's website
    https://fasttext.cc/docs/en/crawl-vectors.html
    """
    fasttext.util.download_model(lang_id, if_exists)


def main():
    global args

    parser = argparse.ArgumentParser(
        description='fastText helper tool to download pre-trained models.')
    parser.add_argument("language", type=str, default="en",
                        help="language identifier of the pre-trained vectors. For example `en` or `fr`.")
    parser.add_argument("--overwrite", action="store_true",
                        help="overwrite if file exists.")

    args = parser.parse_args()

    command_download(args.language, if_exists=(
        'overwrite' if args.overwrite else 'strict'))


if __name__ == '__main__':
    main()
3 changes: 2 additions & 1 deletion python/benchmarks/get_word_vector.py
@@ -42,7 +42,8 @@ def get_word_vector(data, model):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Simple benchmark for get_word_vector.')
+    parser = argparse.ArgumentParser(
+        description='Simple benchmark for get_word_vector.')
     parser.add_argument('model', help='A model file to use for benchmarking.')
     parser.add_argument('data', help='A data file to use for benchmarking.')
     args = parser.parse_args()
1 change: 1 addition & 0 deletions python/doc/examples/train_supervised.py
@@ -20,6 +20,7 @@ def print_results(N, p, r):
    print("P@{}\t{:.3f}".format(1, p))
    print("R@{}\t{:.3f}".format(1, r))


if __name__ == "__main__":
    train_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.train')
    valid_data = os.path.join(os.getenv("DATADIR", ''), 'cooking.valid')
12 changes: 10 additions & 2 deletions python/fasttext_module/fasttext/FastText.py
@@ -170,7 +170,7 @@ def check(entry):

     def get_input_matrix(self):
         """
-        Get a copy of the full input matrix of a Model. This only
+        Get a reference to the full input matrix of a Model. This only
         works if the model is not quantized.
         """
         if self.f.isQuant():
@@ -179,7 +179,7 @@ def get_input_matrix(self):

     def get_output_matrix(self):
         """
-        Get a copy of the full output matrix of a Model. This only
+        Get a reference to the full output matrix of a Model. This only
         works if the model is not quantized.
         """
         if self.f.isQuant():
@@ -292,6 +292,14 @@ def quantize(
qnorm
)

    def set_matrices(self, input_matrix, output_matrix):
        """
        Set input and output matrices. This function assumes you know what you
        are doing.
        """
        self.f.setMatrices(input_matrix.astype(np.float32),
                           output_matrix.astype(np.float32))

    @property
    def words(self):
        if self._words is None:
26 changes: 23 additions & 3 deletions python/fasttext_module/fasttext/pybind/fasttext_pybind.cc
@@ -180,14 +180,34 @@ PYBIND11_MODULE(fasttext_pybind, m) {
           [](fasttext::FastText& m) {
             std::shared_ptr<const fasttext::DenseMatrix> mm =
                 m.getInputMatrix();
-            return *mm.get();
-          })
+            return mm.get();
+          },
+          pybind11::return_value_policy::reference)
       .def(
           "getOutputMatrix",
           [](fasttext::FastText& m) {
             std::shared_ptr<const fasttext::DenseMatrix> mm =
                 m.getOutputMatrix();
-            return *mm.get();
+            return mm.get();
+          },
+          pybind11::return_value_policy::reference)
+      .def(
+          "setMatrices",
+          [](fasttext::FastText& m,
+             py::buffer inputMatrixBuffer,
+             py::buffer outputMatrixBuffer) {
+            py::buffer_info inputMatrixInfo = inputMatrixBuffer.request();
+            py::buffer_info outputMatrixInfo = outputMatrixBuffer.request();
+
+            m.setMatrices(
+                std::make_shared<fasttext::DenseMatrix>(
+                    inputMatrixInfo.shape[0],
+                    inputMatrixInfo.shape[1],
+                    static_cast<float*>(inputMatrixInfo.ptr)),
+                std::make_shared<fasttext::DenseMatrix>(
+                    outputMatrixInfo.shape[0],
+                    outputMatrixInfo.shape[1],
+                    static_cast<float*>(outputMatrixInfo.ptr)));
           })
       .def(
           "loadModel",
2 changes: 2 additions & 0 deletions python/fasttext_module/fasttext/util/__init__.py
@@ -11,3 +11,5 @@

from .util import test
from .util import find_nearest_neighbor
from .util import reduce_model
from .util import download_model
149 changes: 149 additions & 0 deletions python/fasttext_module/fasttext/util/util.py
@@ -1,3 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
@@ -18,6 +20,35 @@
from __future__ import unicode_literals

import numpy as np
import sys
import shutil
import os
import gzip

try:
    from urllib.request import urlopen
except ImportError:
    from urllib2 import urlopen


valid_lang_ids = {"af", "sq", "als", "am", "ar", "an", "hy", "as", "ast",
"az", "ba", "eu", "bar", "be", "bn", "bh", "bpy", "bs",
"br", "bg", "my", "ca", "ceb", "bcl", "ce", "zh", "cv",
"co", "hr", "cs", "da", "dv", "nl", "pa", "arz", "eml",
"en", "myv", "eo", "et", "hif", "fi", "fr", "gl", "ka",
"de", "gom", "el", "gu", "ht", "he", "mrj", "hi", "hu",
"is", "io", "ilo", "id", "ia", "ga", "it", "ja", "jv",
"kn", "pam", "kk", "km", "ky", "ko", "ku", "ckb", "la",
"lv", "li", "lt", "lmo", "nds", "lb", "mk", "mai", "mg",
"ms", "ml", "mt", "gv", "mr", "mzn", "mhr", "min", "xmf",
"mwl", "mn", "nah", "nap", "ne", "new", "frr", "nso",
"no", "nn", "oc", "or", "os", "pfl", "ps", "fa", "pms",
"pl", "pt", "qu", "ro", "rm", "ru", "sah", "sa", "sc",
"sco", "gd", "sr", "sh", "scn", "sd", "si", "sk", "sl",
"so", "azb", "es", "su", "sw", "sv", "tl", "tg", "ta",
"tt", "te", "th", "bo", "tr", "tk", "uk", "hsb", "ur",
"ug", "uz", "vec", "vi", "vo", "wa", "war", "cy", "vls",
"fy", "pnb", "yi", "yo", "diq", "zea"}


# TODO: Add example on reproducing model.test with util.test and model.get_line
@@ -58,3 +89,121 @@ def find_nearest_neighbor(query, vectors, ban_set, cossims=None):
        rank -= 1
        result_i = np.argpartition(cossims, rank)[rank]
    return result_i


def _reduce_matrix(X_orig, dim, eigv):
    """
    Reduces the dimension of a (m × n) matrix `X_orig`
    to a (m × dim) matrix `X_reduced`.
    It uses only the first 100000 rows of `X_orig` to do the mapping.
    Matrix types are all `np.float32` in order to avoid unnecessary copies.
    """
    if eigv is None:
        mapping_size = 100000
        X = X_orig[:mapping_size]
        X = X - X.mean(axis=0, dtype=np.float32)
        C = np.divide(np.matmul(X.T, X), X.shape[0] - 1, dtype=np.float32)
        _, U = np.linalg.eig(C)
        eigv = U[:, :dim]

    X_reduced = np.matmul(X_orig, eigv)

    return (X_reduced, eigv)


def reduce_model(ft_model, target_dim):
    """
    ft_model is an instance of `_FastText` class
    This function computes the PCA of the input and the output matrices
    and sets the reduced ones.
    """
    inp_reduced, proj = _reduce_matrix(
        ft_model.get_input_matrix(), target_dim, None)
    out_reduced, _ = _reduce_matrix(
        ft_model.get_output_matrix(), target_dim, proj)

    ft_model.set_matrices(inp_reduced, out_reduced)

    return ft_model


def _print_progress(downloaded_bytes, total_size):
    percent = float(downloaded_bytes) / total_size
    bar_size = 50
    bar = int(percent * bar_size)
    percent = round(percent * 100, 2)
    sys.stdout.write(" (%0.2f%%) [" % percent)
    sys.stdout.write("=" * bar)
    sys.stdout.write(">")
    sys.stdout.write(" " * (bar_size - bar))
    sys.stdout.write("]\r")
    sys.stdout.flush()

    if downloaded_bytes >= total_size:
        sys.stdout.write('\n')


def _download_file(url, write_file_name, chunk_size=2**13):
    print("Downloading %s" % url)
    response = urlopen(url)
    if hasattr(response, 'getheader'):
        file_size = int(response.getheader('Content-Length').strip())
    else:
        file_size = int(response.info().getheader('Content-Length').strip())
    downloaded = 0
    download_file_name = write_file_name + ".part"
    with open(download_file_name, 'wb') as f:
        while True:
            chunk = response.read(chunk_size)
            downloaded += len(chunk)
            if not chunk:
                break
            f.write(chunk)
            _print_progress(downloaded, file_size)

    os.rename(download_file_name, write_file_name)


def _download_gz_model(gz_file_name, if_exists):
    if os.path.isfile(gz_file_name):
        if if_exists == 'ignore':
            return True
        elif if_exists == 'strict':
            print("gzip File exists. Use --overwrite to download anyway.")
            return False
        elif if_exists == 'overwrite':
            pass

    url = "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/%s" % gz_file_name
    _download_file(url, gz_file_name)

    return True


def download_model(lang_id, if_exists='strict', dimension=None):
    """
    Download pre-trained common-crawl vectors from fastText's website
    https://fasttext.cc/docs/en/crawl-vectors.html
    """
    if lang_id not in valid_lang_ids:
        raise Exception("Invalid lang id. Please select among %s" %
                        repr(valid_lang_ids))

    file_name = "cc.%s.300.bin" % lang_id
    gz_file_name = "%s.gz" % file_name

    if os.path.isfile(file_name):
        if if_exists == 'ignore':
            return file_name
        elif if_exists == 'strict':
            print("File exists. Use --overwrite to download anyway.")
            return
        elif if_exists == 'overwrite':
            pass

    if _download_gz_model(gz_file_name, if_exists):
        with gzip.open(gz_file_name, 'rb') as f:
            with open(file_name, 'wb') as f_out:
                shutil.copyfileobj(f, f_out)

    return file_name
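
In this commit, `_download_file` writes to a `.part` file and renames it once the transfer completes, while `download_model` decompresses straight into the final `.bin` name. An interrupted decompression may therefore leave a truncated `.bin` that a later call with `if_exists='ignore'` would reuse. The sketch below applies the same `.part` idea to the decompression step. The helper `gunzip_atomically` is hypothetical and not part of fastText, and the demo uses a small local file rather than a real model.

```python
import gzip
import os
import shutil
import tempfile


def gunzip_atomically(gz_path, out_path):
    """Decompress gz_path to out_path via a temporary '.part' file."""
    part_path = out_path + ".part"
    with gzip.open(gz_path, "rb") as src, open(part_path, "wb") as dst:
        shutil.copyfileobj(src, dst)  # streams in chunks, low memory use
    # Only expose the final name once the whole file has been written.
    os.replace(part_path, out_path)
    return out_path


# Demo on a tiny stand-in file (not a real fastText model).
workdir = tempfile.mkdtemp()
gz_path = os.path.join(workdir, "demo.bin.gz")
with gzip.open(gz_path, "wb") as f:
    f.write(b"fake model bytes")

out_path = gunzip_atomically(gz_path, os.path.join(workdir, "demo.bin"))
with open(out_path, "rb") as f:
    print(f.read())  # b'fake model bytes'
```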
