Skip to content

Commit

Permalink
Add FastTextSubwordVocab and I/O boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Aug 2, 2019
1 parent b0dea82 commit 57e52f2
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 3 deletions.
9 changes: 8 additions & 1 deletion src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use crate::storage::{
StorageView, StorageViewWrap, StorageWrap,
};
use crate::util::l2_normalize;
use crate::vocab::{FinalfusionSubwordVocab, SimpleVocab, Vocab, VocabWrap, WordIndex};
use crate::vocab::{
FastTextSubwordVocab, FinalfusionSubwordVocab, SimpleVocab, Vocab, VocabWrap, WordIndex,
};

/// Word embeddings.
///
Expand Down Expand Up @@ -239,6 +241,11 @@ impl_embeddings_from!(FinalfusionSubwordVocab, NdArray, StorageViewWrap);
impl_embeddings_from!(FinalfusionSubwordVocab, MmapArray, StorageWrap);
impl_embeddings_from!(FinalfusionSubwordVocab, MmapArray, StorageViewWrap);
impl_embeddings_from!(FinalfusionSubwordVocab, QuantizedArray, StorageWrap);
impl_embeddings_from!(FastTextSubwordVocab, NdArray, StorageWrap);
impl_embeddings_from!(FastTextSubwordVocab, NdArray, StorageViewWrap);
impl_embeddings_from!(FastTextSubwordVocab, MmapArray, StorageWrap);
impl_embeddings_from!(FastTextSubwordVocab, MmapArray, StorageViewWrap);
impl_embeddings_from!(FastTextSubwordVocab, QuantizedArray, StorageWrap);

impl<'a, V, S> IntoIterator for &'a Embeddings<V, S>
where
Expand Down
3 changes: 3 additions & 0 deletions src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ pub(crate) mod private {
QuantizedArray = 4,
Metadata = 5,
NdNorms = 6,
FastTextSubwordVocab = 7,
}

impl ChunkIdentifier {
Expand All @@ -198,6 +199,7 @@ pub(crate) mod private {
4 => Some(QuantizedArray),
5 => Some(Metadata),
6 => Some(NdNorms),
7 => Some(FastTextSubwordVocab),
_ => None,
}
}
Expand Down Expand Up @@ -233,6 +235,7 @@ pub(crate) mod private {
Header => write!(f, "Header"),
SimpleVocab => write!(f, "SimpleVocab"),
NdArray => write!(f, "NdArray"),
FastTextSubwordVocab => write!(f, "FastTextSubwordVocab"),
FinalfusionSubwordVocab => write!(f, "FinalfusionSubwordVocab"),
QuantizedArray => write!(f, "QuantizedArray"),
Metadata => write!(f, "Metadata"),
Expand Down
68 changes: 66 additions & 2 deletions src/vocab.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::mem::size_of;

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};

use crate::fasttext::FastTextIndexer;
use crate::io::private::{ChunkIdentifier, ReadChunk, WriteChunk};
use crate::io::{Error, ErrorKind, Result};
use crate::subword::{BucketIndexer, FinalfusionHashIndexer, Indexer, SubwordIndices};
Expand Down Expand Up @@ -110,6 +111,9 @@ impl WriteChunk for SimpleVocab {
}
}

/// fastText subword vocabulary.
///
/// This is a `SubwordVocab` whose subword indexer is `FastTextIndexer`.
pub type FastTextSubwordVocab = SubwordVocab<FastTextIndexer>;

/// Native finalfusion subword vocabulary.
///
/// This is a `SubwordVocab` whose subword indexer is `FinalfusionHashIndexer`.
pub type FinalfusionSubwordVocab = SubwordVocab<FinalfusionHashIndexer>;

Expand Down Expand Up @@ -171,6 +175,15 @@ where
}
}

impl ReadChunk for FastTextSubwordVocab {
    /// Read a fastText subword vocabulary chunk from `read`.
    ///
    /// The work is delegated to `read_bucketed_chunk`, which is passed the
    /// `FastTextSubwordVocab` chunk identifier.
    fn read_chunk<R: Read + Seek>(read: &mut R) -> Result<Self> {
        let chunk_id = ChunkIdentifier::FastTextSubwordVocab;
        Self::read_bucketed_chunk(read, chunk_id)
    }
}

impl ReadChunk for FinalfusionSubwordVocab {
fn read_chunk<R>(read: &mut R) -> Result<Self>
where
Expand All @@ -180,6 +193,19 @@ impl ReadChunk for FinalfusionSubwordVocab {
}
}

impl WriteChunk for FastTextSubwordVocab {
    /// The chunk identifier used for fastText subword vocabularies.
    fn chunk_identifier(&self) -> ChunkIdentifier {
        ChunkIdentifier::FastTextSubwordVocab
    }

    /// Write this vocabulary to `write`.
    ///
    /// The work is delegated to `write_bucketed_chunk`, which is passed the
    /// identifier reported by `chunk_identifier`.
    fn write_chunk<W: Write + Seek>(&self, write: &mut W) -> Result<()> {
        let chunk_id = self.chunk_identifier();
        self.write_bucketed_chunk(write, chunk_id)
    }
}

impl WriteChunk for FinalfusionSubwordVocab {
fn chunk_identifier(&self) -> ChunkIdentifier {
ChunkIdentifier::FinalfusionSubwordVocab
Expand Down Expand Up @@ -316,6 +342,7 @@ where
#[derive(Clone, Debug)]
pub enum VocabWrap {
/// Wrapped `SimpleVocab`.
SimpleVocab(SimpleVocab),
/// Wrapped `FastTextSubwordVocab`.
FastTextSubwordVocab(FastTextSubwordVocab),
/// Wrapped `FinalfusionSubwordVocab`.
FinalfusionSubwordVocab(FinalfusionSubwordVocab),
}

Expand All @@ -325,6 +352,12 @@ impl From<SimpleVocab> for VocabWrap {
}
}

impl From<FastTextSubwordVocab> for VocabWrap {
    /// Wrap a fastText subword vocabulary in the matching `VocabWrap` variant.
    fn from(vocab: FastTextSubwordVocab) -> Self {
        VocabWrap::FastTextSubwordVocab(vocab)
    }
}

impl From<FinalfusionSubwordVocab> for VocabWrap {
fn from(v: FinalfusionSubwordVocab) -> Self {
VocabWrap::FinalfusionSubwordVocab(v)
Expand Down Expand Up @@ -354,12 +387,16 @@ impl ReadChunk for VocabWrap {
ChunkIdentifier::SimpleVocab => {
SimpleVocab::read_chunk(read).map(VocabWrap::SimpleVocab)
}
ChunkIdentifier::FastTextSubwordVocab => {
SubwordVocab::read_chunk(read).map(VocabWrap::FastTextSubwordVocab)
}
ChunkIdentifier::FinalfusionSubwordVocab => {
SubwordVocab::read_chunk(read).map(VocabWrap::FinalfusionSubwordVocab)
}
_ => Err(ErrorKind::Format(format!(
"Invalid chunk identifier, expected one of: {} or {}, got: {}",
"Invalid chunk identifier, expected one of: {}, {} or {}, got: {}",
ChunkIdentifier::SimpleVocab,
ChunkIdentifier::FastTextSubwordVocab,
ChunkIdentifier::FinalfusionSubwordVocab,
chunk_id
))
Expand All @@ -372,6 +409,7 @@ impl WriteChunk for VocabWrap {
fn chunk_identifier(&self) -> ChunkIdentifier {
    // Report the identifier of whichever vocabulary is wrapped.
    match self {
        VocabWrap::SimpleVocab(vocab) => vocab.chunk_identifier(),
        VocabWrap::FastTextSubwordVocab(vocab) => vocab.chunk_identifier(),
        VocabWrap::FinalfusionSubwordVocab(vocab) => vocab.chunk_identifier(),
    }
}
Expand All @@ -382,6 +420,7 @@ impl WriteChunk for VocabWrap {
{
match self {
VocabWrap::SimpleVocab(inner) => inner.write_chunk(write),
VocabWrap::FastTextSubwordVocab(inner) => inner.write_chunk(write),
VocabWrap::FinalfusionSubwordVocab(inner) => inner.write_chunk(write),
}
}
Expand Down Expand Up @@ -441,6 +480,7 @@ impl Vocab for VocabWrap {
fn idx(&self, word: &str) -> Option<WordIndex> {
    // Forward the lookup to whichever vocabulary is wrapped.
    match self {
        VocabWrap::SimpleVocab(vocab) => vocab.idx(word),
        VocabWrap::FastTextSubwordVocab(vocab) => vocab.idx(word),
        VocabWrap::FinalfusionSubwordVocab(vocab) => vocab.idx(word),
    }
}
Expand All @@ -449,6 +489,7 @@ impl Vocab for VocabWrap {
fn len(&self) -> usize {
    // Forward to whichever vocabulary is wrapped.
    match self {
        VocabWrap::SimpleVocab(vocab) => vocab.len(),
        VocabWrap::FastTextSubwordVocab(vocab) => vocab.len(),
        VocabWrap::FinalfusionSubwordVocab(vocab) => vocab.len(),
    }
}
Expand All @@ -457,6 +498,7 @@ impl Vocab for VocabWrap {
fn words(&self) -> &[String] {
    // Forward to whichever vocabulary is wrapped.
    match self {
        VocabWrap::SimpleVocab(vocab) => vocab.words(),
        VocabWrap::FastTextSubwordVocab(vocab) => vocab.words(),
        VocabWrap::FinalfusionSubwordVocab(vocab) => vocab.words(),
    }
}
Expand All @@ -478,10 +520,22 @@ mod tests {

use byteorder::{LittleEndian, ReadBytesExt};

use super::{FinalfusionSubwordVocab, SimpleVocab, SubwordVocab};
use super::{FastTextSubwordVocab, FinalfusionSubwordVocab, SimpleVocab, SubwordVocab};
use crate::fasttext::FastTextIndexer;
use crate::io::private::{ReadChunk, WriteChunk};
use crate::subword::{BucketIndexer, FinalfusionHashIndexer};

/// Build the small fastText subword vocabulary used as a test fixture.
///
/// The fixture holds the words "this", "is", "a" and "test".
fn test_fasttext_subword_vocab() -> FastTextSubwordVocab {
    let words: Vec<String> = ["this", "is", "a", "test"]
        .iter()
        .map(|&word| word.to_owned())
        .collect();

    // NOTE(review): 3 and 6 are presumably the n-gram length bounds — confirm
    // against `SubwordVocab::new`.
    SubwordVocab::new(words, 3, 6, FastTextIndexer::new(20))
}

fn test_simple_vocab() -> SimpleVocab {
let words = vec![
"this".to_owned(),
Expand Down Expand Up @@ -512,6 +566,16 @@ mod tests {
read.read_u64::<LittleEndian>().unwrap()
}

/// A fastText subword vocabulary written as a chunk must read back unchanged.
#[test]
fn fasttext_subword_vocab_write_read_roundtrip() {
    let expected = test_fasttext_subword_vocab();

    // Serialize into an in-memory buffer, then rewind to the start.
    let mut buffer = Cursor::new(Vec::new());
    expected.write_chunk(&mut buffer).unwrap();
    buffer.seek(SeekFrom::Start(0)).unwrap();

    let actual = SubwordVocab::read_chunk(&mut buffer).unwrap();
    assert_eq!(actual, expected);
}

#[test]
fn simple_vocab_write_read_roundtrip() {
let check_vocab = test_simple_vocab();
Expand Down

0 comments on commit 57e52f2

Please sign in to comment.