From 4bca170bfca22d07f071f5401d551d829ae16dc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 2 Aug 2019 15:12:27 +0200 Subject: [PATCH] Add and implement the ReadFastText trait --- Cargo.toml | 1 + src/fasttext/io.rs | 327 ++++++++++++++++++++++++++++++++++++++++++++ src/fasttext/mod.rs | 3 + src/vocab.rs | 2 +- 4 files changed, 332 insertions(+), 1 deletion(-) create mode 100644 src/fasttext/io.rs diff --git a/Cargo.toml b/Cargo.toml index 3ca7692..737c028 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ ordered-float = "1" rand = "0.6" rand_xorshift = "0.1" reductive = "0.2" +serde = { version = "1", features = ["derive"] } toml = "0.5" [dev-dependencies] diff --git a/src/fasttext/io.rs b/src/fasttext/io.rs new file mode 100644 index 0000000..a90fddc --- /dev/null +++ b/src/fasttext/io.rs @@ -0,0 +1,327 @@ +use std::io::BufRead; + +use byteorder::{LittleEndian, ReadBytesExt}; +use ndarray::{s, Array2, ErrorKind as ShapeErrorKind, ShapeError}; +use serde::Serialize; +use toml::Value; + +use crate::embeddings::Embeddings; +use crate::io::{Error, ErrorKind, Result}; +use crate::metadata::Metadata; +use crate::norms::NdNorms; +use crate::storage::{NdArray, Storage, StorageViewMut}; +use crate::subword::BucketIndexer; +use crate::util::{l2_normalize_array, read_string}; +use crate::vocab::{FastTextSubwordVocab, Vocab}; + +use super::FastTextIndexer; + +const FASTTEXT_FILEFORMAT_MAGIC: u32 = 793_712_314; +const FASTTEXT_VERSION: u32 = 12; + +/// Read embeddings in the fastText format. +pub trait ReadFastText +where + Self: Sized, +{ + /// Read embeddings in the fastText format. + fn read_fasttext(reader: &mut impl BufRead) -> Result; +} + +impl ReadFastText for Embeddings { + fn read_fasttext(mut reader: &mut impl BufRead) -> Result { + let magic = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot fastText read magic", e))?; + if magic != FASTTEXT_FILEFORMAT_MAGIC { + return Err(ErrorKind::Format(format!( + "Expected {} as magic, got: {}", + FASTTEXT_FILEFORMAT_MAGIC, magic + )) + .into()); + } + + let version = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read fastText version", e))?; + if version > FASTTEXT_VERSION { + return Err(ErrorKind::Format(format!( + "Expected {} as version, got: {}", + FASTTEXT_VERSION, version + )) + .into()); + } + + let config = Config::read(&mut reader)?; + + let vocab = read_vocab(&config, &mut reader)?; + + let is_quantized = reader + .read_u8() + .map_err(|e| ErrorKind::io_error("Cannot read quantization information", e))?; + if is_quantized == 1 { + return Err( + ErrorKind::Format("Quantized fastText models are not supported".into()).into(), + ); + } + + // Read and prepare storage. + let mut storage = read_embeddings(&mut reader)?; + add_subword_embeddings(&vocab, &mut storage); + #[allow(clippy::deref_addrof)] + let norms = NdNorms(l2_normalize_array( + storage.view_mut().slice_mut(s![0..vocab.len(), ..]), + )); + + // Verify that vocab and storage shapes match. + if storage.shape().0 != vocab.len() + config.bucket as usize { + return Err(Error::Shape(ShapeError::from_kind( + ShapeErrorKind::IncompatibleShape, + ))); + } + + let metadata = Value::try_from(config).map_err(|e| { + ErrorKind::Format(format!("Cannot serialize model metadata to TOML: {}", e)) + })?; + + Ok(Embeddings::new( + Some(Metadata(metadata)), + vocab, + storage, + norms, + )) + } +} + +/// fastText model configuration. +#[derive(Copy, Clone, Debug, Serialize)] +struct Config { + pub dims: u32, + pub window_size: u32, + pub epoch: u32, + pub min_count: u32, + pub neg: u32, + pub word_ngrams: u32, + pub loss: Loss, + pub model: Model, + pub bucket: u32, + pub min_n: u32, + pub max_n: u32, + pub lr_update_rate: u32, + pub sampling_threshold: f64, +} + +impl Config { + /// Read fastText model configuration. + fn read(reader: &mut R) -> Result + where + R: BufRead, + { + let dims = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read number of dimensions", e))?; + let window_size = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read window size", e))?; + let epoch = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read number of epochs", e))?; + let min_count = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read minimum count", e))?; + let neg = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read negative samples", e))?; + let word_ngrams = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read word n-gram length", e))?; + let loss = Loss::read(reader)?; + let model = Model::read(reader)?; + let bucket = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read number of buckets", e))?; + let min_n = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read minimum subword length", e))?; + let max_n = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read maximum subword length", e))?; + let lr_update_rate = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read LR update rate", e))?; + let sampling_threshold = reader + .read_f64::() + .map_err(|e| ErrorKind::io_error("Cannot read sampling threshold", e))?; + + Ok(Config { + dims, + window_size, + epoch, + min_count, + neg, + word_ngrams, + loss, + model, + bucket, + min_n, + max_n, + lr_update_rate, + sampling_threshold, + }) + } +} + +/// fastText loss type. +#[derive(Copy, Clone, Debug, Serialize)] +enum Loss { + HierarchicalSoftmax, + NegativeSampling, + Softmax, +} + +impl Loss { + fn read(reader: &mut R) -> Result + where + R: BufRead, + { + let loss = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read loss type", e))?; + + use self::Loss::*; + match loss { + 1 => Ok(HierarchicalSoftmax), + 2 => Ok(NegativeSampling), + 3 => Ok(Softmax), + l => Err(ErrorKind::Format(format!("Unknown loss: {}", l)).into()), + } + } +} + +/// fastText model type. +#[derive(Copy, Clone, Debug, Serialize)] +enum Model { + CBOW, + SkipGram, + Supervised, +} + +impl Model { + fn read(reader: &mut R) -> Result + where + R: BufRead, + { + let model = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read model type", e))?; + + use self::Model::*; + match model { + 1 => Ok(CBOW), + 2 => Ok(SkipGram), + 3 => Ok(Supervised), + m => Err(ErrorKind::Format(format!("Unknown model: {}", m)).into()), + } + } +} + +/// Add subword embeddings to word embeddings. +/// +/// fastText stores word embeddings without subword embeddings. This method +/// adds the subword embeddings. +fn add_subword_embeddings(vocab: &FastTextSubwordVocab, embeds: &mut NdArray) { + for (idx, word) in vocab.words().iter().enumerate() { + if let Some(indices) = vocab.subword_indices(word) { + let n_embeds = indices.len() + 1; + + // Sum the embedding and its subword embeddings. + let mut embed = embeds.embedding(idx).into_owned(); + for subword_idx in indices { + embed += &embeds.embedding(subword_idx).as_view(); + } + + // Compute the average embedding. + embed /= n_embeds as f32; + + embeds.view_mut().row_mut(idx).assign(&embed); + } + } +} + +/// Read the embedding matrix. +fn read_embeddings(reader: &mut R) -> Result +where + R: BufRead, +{ + let m = reader + .read_u64::() + .map_err(|e| ErrorKind::io_error("Cannot read number of embedding matrix rows", e))?; + let n = reader + .read_u64::() + .map_err(|e| ErrorKind::io_error("Cannot read number of embedding matrix columns", e))?; + + let mut data = vec![0.0; (m * n) as usize]; + reader + .read_f32_into::(&mut data) + .map_err(|e| ErrorKind::io_error("Cannot read embeddings", e))?; + + let data = Array2::from_shape_vec((m as usize, n as usize), data).map_err(Error::Shape)?; + + Ok(NdArray(data)) +} + +/// Read the vocabulary. +fn read_vocab(config: &Config, reader: &mut R) -> Result +where + R: BufRead, +{ + let size = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read vocabulary size", e))?; + reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot read number of words", e))?; + + let n_labels = reader + .read_u32::() + .map_err(|e| ErrorKind::io_error("Cannot number of labels", e))?; + if n_labels > 0 { + return Err( + ErrorKind::Format("fastText prediction models are not supported".into()).into(), + ); + } + + reader + .read_u64::() + .map_err(|e| ErrorKind::io_error("Cannot read number of tokens", e))?; + + let prune_idx_size = reader + .read_i64::() + .map_err(|e| ErrorKind::io_error("Cannot read pruned vocabulary size", e))?; + if prune_idx_size > 0 { + return Err(ErrorKind::Format("Pruned vocabularies are not supported".into()).into()); + } + + let mut words = Vec::with_capacity(size as usize); + for _ in 0..size { + let word = read_string(reader, 0)?; + reader + .read_u64::() + .map_err(|e| ErrorKind::io_error("Cannot word frequency", e))?; + let entry_type = reader + .read_u8() + .map_err(|e| ErrorKind::io_error("Cannot read entry type", e))?; + if entry_type != 0 { + return Err(ErrorKind::Format("Non-word entry".into()).into()); + } + + words.push(word) + } + + Ok(FastTextSubwordVocab::new( + words, + config.min_n, + config.max_n, + FastTextIndexer::new(config.bucket as usize), + )) +} diff --git a/src/fasttext/mod.rs b/src/fasttext/mod.rs index e2554c0..f8be8a2 100644 --- a/src/fasttext/mod.rs +++ b/src/fasttext/mod.rs @@ -2,3 +2,6 @@ mod indexer; pub use self::indexer::FastTextIndexer; + +mod io; +pub use self::io::ReadFastText; diff --git a/src/vocab.rs b/src/vocab.rs index e048b0d..faeda0e 100644 --- a/src/vocab.rs +++ b/src/vocab.rs @@ -160,7 +160,7 @@ where /// /// Returns `None` when the model does not support subwords or /// when no subwords could be extracted. - fn subword_indices(&self, word: &str) -> Option> { + pub(crate) fn subword_indices(&self, word: &str) -> Option> { let indices = Self::bracket(word) .as_str() .subword_indices(self.min_n as usize, self.max_n as usize, &self.indexer)