Skip to content

Commit

Permalink
Add and implement the ReadFastText trait
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Aug 2, 2019
1 parent 57e52f2 commit 4bca170
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ordered-float = "1"
rand = "0.6"
rand_xorshift = "0.1"
reductive = "0.2"
serde = { version = "1", features = ["derive"] }
toml = "0.5"

[dev-dependencies]
Expand Down
327 changes: 327 additions & 0 deletions src/fasttext/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
use std::io::BufRead;

use byteorder::{LittleEndian, ReadBytesExt};
use ndarray::{s, Array2, ErrorKind as ShapeErrorKind, ShapeError};
use serde::Serialize;
use toml::Value;

use crate::embeddings::Embeddings;
use crate::io::{Error, ErrorKind, Result};
use crate::metadata::Metadata;
use crate::norms::NdNorms;
use crate::storage::{NdArray, Storage, StorageViewMut};
use crate::subword::BucketIndexer;
use crate::util::{l2_normalize_array, read_string};
use crate::vocab::{FastTextSubwordVocab, Vocab};

use super::FastTextIndexer;

/// Magic number that the reader expects as the first little-endian `u32`
/// of a fastText file; any other value is rejected as a format error.
const FASTTEXT_FILEFORMAT_MAGIC: u32 = 793_712_314;
/// Newest file format version accepted by the reader. Files that report a
/// larger version are rejected; smaller versions are accepted.
const FASTTEXT_VERSION: u32 = 12;

/// Construct embeddings from data stored in the fastText format.
///
/// Implementors are built by value from the reader, hence the `Sized`
/// requirement.
pub trait ReadFastText: Sized {
    /// Read embeddings in the fastText format from `reader`.
    ///
    /// An error is returned when reading fails or when the data is not
    /// something the implementation can load.
    fn read_fasttext(reader: &mut impl BufRead) -> Result<Self>;
}

impl ReadFastText for Embeddings<FastTextSubwordVocab, NdArray> {
fn read_fasttext(mut reader: &mut impl BufRead) -> Result<Self> {
let magic = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot fastText read magic", e))?;
if magic != FASTTEXT_FILEFORMAT_MAGIC {
return Err(ErrorKind::Format(format!(
"Expected {} as magic, got: {}",
FASTTEXT_FILEFORMAT_MAGIC, magic
))
.into());
}

let version = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read fastText version", e))?;
if version > FASTTEXT_VERSION {
return Err(ErrorKind::Format(format!(
"Expected {} as version, got: {}",
FASTTEXT_VERSION, version
))
.into());
}

let config = Config::read(&mut reader)?;

let vocab = read_vocab(&config, &mut reader)?;

let is_quantized = reader
.read_u8()
.map_err(|e| ErrorKind::io_error("Cannot read quantization information", e))?;
if is_quantized == 1 {
return Err(
ErrorKind::Format("Quantized fastText models are not supported".into()).into(),
);
}

// Read and prepare storage.
let mut storage = read_embeddings(&mut reader)?;
add_subword_embeddings(&vocab, &mut storage);
#[allow(clippy::deref_addrof)]
let norms = NdNorms(l2_normalize_array(
storage.view_mut().slice_mut(s![0..vocab.len(), ..]),
));

// Verify that vocab and storage shapes match.
if storage.shape().0 != vocab.len() + config.bucket as usize {
return Err(Error::Shape(ShapeError::from_kind(
ShapeErrorKind::IncompatibleShape,
)));
}

let metadata = Value::try_from(config).map_err(|e| {
ErrorKind::Format(format!("Cannot serialize model metadata to TOML: {}", e))
})?;

Ok(Embeddings::new(
Some(Metadata(metadata)),
vocab,
storage,
norms,
))
}
}

/// fastText model configuration.
///
/// The fields are declared in the order in which `Config::read` reads them.
/// The configuration is serialized (via `Serialize`) into the metadata of
/// the embeddings that are read.
#[derive(Copy, Clone, Debug, Serialize)]
struct Config {
/// Number of dimensions.
pub dims: u32,
/// Window size.
pub window_size: u32,
/// Number of epochs.
pub epoch: u32,
/// Minimum count.
pub min_count: u32,
/// Negative samples.
pub neg: u32,
/// Word n-gram length.
pub word_ngrams: u32,
/// Loss type.
pub loss: Loss,
/// Model type.
pub model: Model,
/// Number of buckets; used to construct the subword indexer and to
/// verify the number of rows of the embedding matrix.
pub bucket: u32,
/// Minimum subword length.
pub min_n: u32,
/// Maximum subword length.
pub max_n: u32,
/// LR update rate.
pub lr_update_rate: u32,
/// Sampling threshold.
pub sampling_threshold: f64,
}

impl Config {
/// Read fastText model configuration.
fn read<R>(reader: &mut R) -> Result<Config>
where
R: BufRead,
{
let dims = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read number of dimensions", e))?;
let window_size = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read window size", e))?;
let epoch = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read number of epochs", e))?;
let min_count = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read minimum count", e))?;
let neg = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read negative samples", e))?;
let word_ngrams = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read word n-gram length", e))?;
let loss = Loss::read(reader)?;
let model = Model::read(reader)?;
let bucket = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read number of buckets", e))?;
let min_n = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read minimum subword length", e))?;
let max_n = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read maximum subword length", e))?;
let lr_update_rate = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read LR update rate", e))?;
let sampling_threshold = reader
.read_f64::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read sampling threshold", e))?;

Ok(Config {
dims,
window_size,
epoch,
min_count,
neg,
word_ngrams,
loss,
model,
bucket,
min_n,
max_n,
lr_update_rate,
sampling_threshold,
})
}
}

/// fastText loss type.
///
/// The numeric codes mentioned below are the values that `Loss::read`
/// maps to each variant.
#[derive(Copy, Clone, Debug, Serialize)]
enum Loss {
/// Hierarchical softmax (stored as `1`).
HierarchicalSoftmax,
/// Negative sampling (stored as `2`).
NegativeSampling,
/// Softmax (stored as `3`).
Softmax,
}

impl Loss {
fn read<R>(reader: &mut R) -> Result<Loss>
where
R: BufRead,
{
let loss = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read loss type", e))?;

use self::Loss::*;
match loss {
1 => Ok(HierarchicalSoftmax),
2 => Ok(NegativeSampling),
3 => Ok(Softmax),
l => Err(ErrorKind::Format(format!("Unknown loss: {}", l)).into()),
}
}
}

/// fastText model type.
///
/// The numeric codes mentioned below are the values that `Model::read`
/// maps to each variant.
#[derive(Copy, Clone, Debug, Serialize)]
enum Model {
/// Continuous bag of words (stored as `1`).
CBOW,
/// Skip-gram (stored as `2`).
SkipGram,
/// Supervised (stored as `3`).
Supervised,
}

impl Model {
fn read<R>(reader: &mut R) -> Result<Model>
where
R: BufRead,
{
let model = reader
.read_u32::<LittleEndian>()
.map_err(|e| ErrorKind::io_error("Cannot read model type", e))?;

use self::Model::*;
match model {
1 => Ok(CBOW),
2 => Ok(SkipGram),
3 => Ok(Supervised),
m => Err(ErrorKind::Format(format!("Unknown model: {}", m)).into()),
}
}
}

/// Add subword embeddings to word embeddings.
///
/// fastText stores word embeddings without subword embeddings. This
/// function replaces the row of every word that has subword indices by the
/// average of that row and the rows of its subwords. Words for which the
/// vocabulary returns no subword indices are left untouched.
fn add_subword_embeddings(vocab: &FastTextSubwordVocab, embeds: &mut NdArray) {
    for (word_idx, word) in vocab.words().iter().enumerate() {
        let subword_indices = match vocab.subword_indices(word) {
            Some(indices) => indices,
            None => continue,
        };

        // The word's own embedding is part of the average as well.
        let n_summed = subword_indices.len() + 1;

        // Sum into an owned copy, since the matrix is borrowed for reading
        // while the subword rows are added.
        let mut averaged = embeds.embedding(word_idx).into_owned();
        for subword_idx in subword_indices {
            averaged += &embeds.embedding(subword_idx).as_view();
        }
        averaged /= n_summed as f32;

        embeds.view_mut().row_mut(word_idx).assign(&averaged);
    }
}

/// Read the embedding matrix.
///
/// The matrix is stored as two little-endian `u64` values (the number of
/// rows followed by the number of columns) and then `rows * columns`
/// little-endian `f32` values, which are used to fill a `rows x columns`
/// array.
///
/// # Errors
///
/// Returns an error when reading fails, when the shape stored in the data
/// cannot be represented in memory on this platform, or when the array
/// cannot be constructed from the shape and the data.
fn read_embeddings<R>(reader: &mut R) -> Result<NdArray>
where
    R: BufRead,
{
    let m = reader
        .read_u64::<LittleEndian>()
        .map_err(|e| ErrorKind::io_error("Cannot read number of embedding matrix rows", e))?;
    let n = reader
        .read_u64::<LittleEndian>()
        .map_err(|e| ErrorKind::io_error("Cannot read number of embedding matrix columns", e))?;

    // The shape comes straight from the file and cannot be trusted: the
    // product may overflow `u64`, and on targets where `usize` is narrower
    // than 64 bits the casts below would silently truncate. Reject such
    // shapes rather than panicking or allocating a buffer of the wrong size.
    let max_len = usize::max_value() as u64;
    let n_elems = match m.checked_mul(n) {
        Some(n_elems) if m <= max_len && n <= max_len && n_elems <= max_len => n_elems as usize,
        _ => {
            return Err(ErrorKind::Format(format!(
                "Embedding matrix shape is too large: {}x{}",
                m, n
            ))
            .into());
        }
    };

    let mut data = vec![0.0; n_elems];
    reader
        .read_f32_into::<LittleEndian>(&mut data)
        .map_err(|e| ErrorKind::io_error("Cannot read embeddings", e))?;

    let data = Array2::from_shape_vec((m as usize, n as usize), data).map_err(Error::Shape)?;

    Ok(NdArray(data))
}

/// Read the vocabulary.
///
/// The vocabulary consists of a header (vocabulary size, number of words,
/// number of labels, number of tokens and pruned vocabulary size) followed
/// by one entry per vocabulary item. Only the words themselves are retained;
/// the number of words, the number of tokens and the word frequencies are
/// read and discarded.
///
/// The subword vocabulary is constructed with the minimum/maximum subword
/// length and the number of buckets from `config`.
///
/// # Errors
///
/// Returns an error when reading fails, when the model contains labels,
/// when the pruned vocabulary size is larger than zero, or when an entry
/// has a non-zero entry type.
fn read_vocab<R>(config: &Config, reader: &mut R) -> Result<FastTextSubwordVocab>
where
    R: BufRead,
{
    let size = reader
        .read_u32::<LittleEndian>()
        .map_err(|e| ErrorKind::io_error("Cannot read vocabulary size", e))?;

    // The number of words is not used, but has to be consumed.
    reader
        .read_u32::<LittleEndian>()
        .map_err(|e| ErrorKind::io_error("Cannot read number of words", e))?;

    let n_labels = reader
        .read_u32::<LittleEndian>()
        .map_err(|e| ErrorKind::io_error("Cannot read number of labels", e))?;
    if n_labels > 0 {
        return Err(
            ErrorKind::Format("fastText prediction models are not supported".into()).into(),
        );
    }

    // The number of tokens is not used, but has to be consumed.
    reader
        .read_u64::<LittleEndian>()
        .map_err(|e| ErrorKind::io_error("Cannot read number of tokens", e))?;

    // NOTE(review): a size of 0 is accepted here -- confirm whether fastText
    // can write 0 for a pruned vocabulary.
    let prune_idx_size = reader
        .read_i64::<LittleEndian>()
        .map_err(|e| ErrorKind::io_error("Cannot read pruned vocabulary size", e))?;
    if prune_idx_size > 0 {
        return Err(ErrorKind::Format("Pruned vocabularies are not supported".into()).into());
    }

    let mut words = Vec::with_capacity(size as usize);
    for _ in 0..size {
        // Presumably reads a string that is terminated by a 0 byte -- see
        // `read_string` for the exact semantics.
        let word = read_string(reader, 0)?;

        // The frequency is not used, but has to be consumed.
        reader
            .read_u64::<LittleEndian>()
            .map_err(|e| ErrorKind::io_error("Cannot read word frequency", e))?;

        // Only entries of type 0 are accepted as words.
        let entry_type = reader
            .read_u8()
            .map_err(|e| ErrorKind::io_error("Cannot read entry type", e))?;
        if entry_type != 0 {
            return Err(ErrorKind::Format("Non-word entry".into()).into());
        }

        words.push(word);
    }

    Ok(FastTextSubwordVocab::new(
        words,
        config.min_n,
        config.max_n,
        FastTextIndexer::new(config.bucket as usize),
    ))
}
3 changes: 3 additions & 0 deletions src/fasttext/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@

mod indexer;
pub use self::indexer::FastTextIndexer;

mod io;
pub use self::io::ReadFastText;
2 changes: 1 addition & 1 deletion src/vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ where
///
/// Returns `None` when the model does not support subwords or
/// when no subwords could be extracted.
fn subword_indices(&self, word: &str) -> Option<Vec<usize>> {
pub(crate) fn subword_indices(&self, word: &str) -> Option<Vec<usize>> {
let indices = Self::bracket(word)
.as_str()
.subword_indices(self.min_n as usize, self.max_n as usize, &self.indexer)
Expand Down

0 comments on commit 4bca170

Please sign in to comment.