Skip to content

Commit

Permalink
Actually write norms chunk, verify with unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
danieldk committed Jun 21, 2019
1 parent 6d8fc0d commit 46e5ff1
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ where

self.vocab.write_chunk(write)?;
self.storage.write_chunk(write)?;

if let Some(norms) = self.norms() {
norms.write_chunk(write)?;
}

Ok(())
}
}
Expand Down Expand Up @@ -489,11 +494,13 @@ mod tests {
use std::fs::File;
use std::io::{BufReader, Cursor, Seek, SeekFrom};

use ndarray::array;
use toml::toml;

use super::Embeddings;
use crate::io::{MmapEmbeddings, ReadEmbeddings, WriteEmbeddings};
use crate::metadata::Metadata;
use crate::norms::NdNorms;
use crate::storage::{MmapArray, NdArray, StorageView};
use crate::vocab::SimpleVocab;
use crate::word2vec::ReadWord2VecRaw;
Expand Down Expand Up @@ -525,6 +532,27 @@ mod tests {
assert_eq!(embeds.storage().view(), check_embeds.storage().view());
}

#[test]
fn norms() {
    // Two-word embeddings whose stored norms (2 and 3) are arbitrary values
    // rather than the lengths of the stored rows, so the comparison below
    // only passes if the norms themselves survive the round trip.
    let expected = Embeddings::new(
        None,
        SimpleVocab::new(vec!["norms".to_string(), "test".to_string()]),
        NdArray(array![[1f32], [-1f32]]),
        NdNorms(array![2f32, 3f32]),
    );

    // Serialize into an in-memory buffer, then rewind it for reading.
    let mut buffer = Cursor::new(Vec::new());
    expected.write_embeddings(&mut buffer).unwrap();
    buffer.seek(SeekFrom::Start(0)).unwrap();

    let roundtripped: Embeddings<SimpleVocab, NdArray> =
        Embeddings::read_embeddings(&mut buffer).unwrap();

    // Only the norms are compared here; both sides must actually have norms,
    // otherwise the `unwrap` calls panic and the test fails.
    let written_norms = &expected.norms().unwrap().0;
    let read_norms = &roundtripped.norms().unwrap().0;
    assert!(written_norms.all_close(read_norms, 1e-8));
}

#[test]
fn write_read_simple_roundtrip() {
let check_embeds = test_embeddings();
Expand Down

0 comments on commit 46e5ff1

Please sign in to comment.