Skip to content

Commit

Permalink
Speed up word2vec binary model loading (piskvorky#2642)
Browse files Browse the repository at this point in the history
  • Loading branch information
lopusz committed Nov 8, 2019
1 parent 44ea793 commit 5586c66
Showing 1 changed file with 67 additions and 40 deletions.
107 changes: 67 additions & 40 deletions gensim/models/utils_any2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import logging
from gensim import utils

from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, fromstring
from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, frombuffer

from six.moves import range
from six import iteritems, PY2
Expand Down Expand Up @@ -147,7 +147,7 @@ def _save_word2vec_format(fname, vocab, vectors, fvocab=None, binary=False, tota


def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
limit=None, datatype=REAL):
limit=None, datatype=REAL, binary_chunk_size=100 * 1024):
"""Load the input-hidden weight matrix from the original C word2vec-tool format.
Note that the information stored in the file is incomplete (the binary tree is missing),
Expand Down Expand Up @@ -176,14 +176,64 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
datatype : type, optional
(Experimental) Can coerce dimensions to a non-default float type (such as `np.float16`) to save memory.
Such types may result in much slower bulk operations or incompatibility with optimized routines.)
binary_chunk_size : int, optional
Size of chunk in which binary files are read. Used mostly for testing. Defalut value 100 kB.
Returns
-------
object
Returns the loaded model as an instance of :class:`cls`.
"""

def __add_word_to_result(result, counts, word, weights):
word_id = len(result.vocab)
if word in result.vocab:
logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname)
return
if counts is None:
# most common scenario: no vocab file given. just make up some bogus counts, in descending order
result.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id)
elif word in counts:
# use count from the vocab file
result.vocab[word] = Vocab(index=word_id, count=counts[word])
else:
# vocab file given, but word is missing -- set count to None (TODO: or raise?)
logger.warning("vocabulary file is incomplete: '%s' is missing", word)
result.vocab[word] = Vocab(index=word_id, count=None)
result.vectors[word_id] = weights
result.index2word.append(word)

def __remove_initial_new_line(s):
i = 0
while i < len(s) and s[i] == '\n':
i += 1
return s[i:]

def __add_words_from_binary_chunk_to_result(result, counts, max_words, chunk, vector_size, datatype):
start = 0
n = len(chunk)
processed_words = 0
n_bytes_per_vector = vector_size * dtype(REAL).itemsize

for _ in range(0, max_words):
i_space = chunk.find(b' ', start)
i_vector = i_space + 1
if i_space != -1 and (n - i_vector) >= n_bytes_per_vector:
word = chunk[start:i_space].decode("utf-8", errors=unicode_errors)
# Some binary files are reported to have obsolete new line in the beginning of word, remove it
word = __remove_initial_new_line(word)
vector = frombuffer(chunk, offset=i_vector, count=vector_size, dtype=REAL).astype(datatype)
__add_word_to_result(result, counts, word, vector)
start = i_vector + n_bytes_per_vector
processed_words += 1
else:
break

return processed_words, chunk[start:]

from gensim.models.keyedvectors import Vocab

counts = None
if fvocab is not None:
logger.info("loading word counts from %s", fvocab)
Expand All @@ -192,7 +242,6 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
for line in fin:
word, count = utils.to_unicode(line, errors=unicode_errors).strip().split()
counts[word] = int(count)

logger.info("loading projection weights from %s", fname)
with utils.open(fname, 'rb') as fin:
header = utils.to_unicode(fin.readline(), encoding=encoding)
Expand All @@ -202,43 +251,21 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
result = cls(vector_size)
result.vector_size = vector_size
result.vectors = zeros((vocab_size, vector_size), dtype=datatype)

def add_word(word, weights):
word_id = len(result.vocab)
if word in result.vocab:
logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname)
return
if counts is None:
# most common scenario: no vocab file given. just make up some bogus counts, in descending order
result.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id)
elif word in counts:
# use count from the vocab file
result.vocab[word] = Vocab(index=word_id, count=counts[word])
else:
# vocab file given, but word is missing -- set count to None (TODO: or raise?)
logger.warning("vocabulary file is incomplete: '%s' is missing", word)
result.vocab[word] = Vocab(index=word_id, count=None)
result.vectors[word_id] = weights
result.index2word.append(word)

if binary:
binary_len = dtype(REAL).itemsize * vector_size
for _ in range(vocab_size):
# mixed text and binary: read text first, then binary
word = []
while True:
ch = fin.read(1) # Python uses I/O buffering internally
if ch == b' ':
break
if ch == b'':
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
if ch != b'\n': # ignore newlines in front of words (some binary files have)
word.append(ch)
word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors)
with utils.ignore_deprecation_warning():
# TODO use frombuffer or something similar
weights = fromstring(fin.read(binary_len), dtype=REAL).astype(datatype)
add_word(word, weights)
chunk = b''
tot_processed_words = 0

while tot_processed_words < vocab_size:
new_chunk = fin.read(binary_chunk_size)
chunk += new_chunk
max_words = vocab_size - len(result.vocab)
processed_words, chunk = __add_words_from_binary_chunk_to_result(result, counts, max_words,
chunk, vector_size, datatype)
tot_processed_words += processed_words
if len(new_chunk) < binary_chunk_size:
break
if tot_processed_words != vocab_size:
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
else:
for line_no in range(vocab_size):
line = fin.readline()
Expand All @@ -248,7 +275,7 @@ def add_word(word, weights):
if len(parts) != vector_size + 1:
raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
word, weights = parts[0], [datatype(x) for x in parts[1:]]
add_word(word, weights)
__add_word_to_result(result, counts, word, weights)
if result.vectors.shape[0] != len(result.vocab):
logger.info(
"duplicate words detected, shrinking matrix size from %i to %i",
Expand Down

0 comments on commit 5586c66

Please sign in to comment.