Merge pull request #167 from huggingface/hotfix_models_save
Hotfix models save
n1t0 committed Feb 24, 2020
2 parents be08d95 + 2bae286 commit 2ae3062
Showing 4 changed files with 36 additions and 36 deletions.
28 changes: 4 additions & 24 deletions tokenizers/src/models/bpe/model.rs
@@ -1,7 +1,9 @@
-use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
+use super::{
+    super::OrderedVocabIter, Cache, Error, Pair, WithFirstLastIterator, Word,
+    DEFAULT_CACHE_CAPACITY,
+};
 use crate::tokenizer::{Model, Offsets, Result, Token};
 use rand::{thread_rng, Rng};
-use serde::{Serialize, Serializer};
 use serde_json::Value;
 use std::{
     collections::HashMap,
@@ -463,28 +465,6 @@ impl Model for BPE {
     }
 }
 
-/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
-/// of token ID, smallest to largest.
-struct OrderedVocabIter<'a> {
-    vocab_r: &'a HashMap<u32, String>,
-}
-
-impl<'a> OrderedVocabIter<'a> {
-    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
-        Self { vocab_r }
-    }
-}
-
-impl<'a> Serialize for OrderedVocabIter<'a> {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
-        serializer.collect_map(iter)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
25 changes: 25 additions & 0 deletions tokenizers/src/models/mod.rs
@@ -3,3 +3,28 @@
 pub mod bpe;
 pub mod wordlevel;
 pub mod wordpiece;
+
+use serde::{Serialize, Serializer};
+use std::collections::HashMap;
+
+/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
+/// of token ID, smallest to largest.
+struct OrderedVocabIter<'a> {
+    vocab_r: &'a HashMap<u32, String>,
+}
+
+impl<'a> OrderedVocabIter<'a> {
+    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+        Self { vocab_r }
+    }
+}
+
+impl<'a> Serialize for OrderedVocabIter<'a> {
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
+        serializer.collect_map(iter)
+    }
+}
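
The OrderedVocabIter added above is the piece this commit hoists into models/mod.rs so BPE and WordLevel can share it when saving a vocab. The following is a minimal, self-contained sketch of the same pattern; the sample vocab and the main driver are illustrative assumptions, not part of the commit, and it presumes serde and serde_json as dependencies:

use serde::{Serialize, Serializer};
use std::collections::HashMap;

// Same shape as the struct in the diff: borrows a reversed vocab (ID -> token).
struct OrderedVocabIter<'a> {
    vocab_r: &'a HashMap<u32, String>,
}

impl<'a> Serialize for OrderedVocabIter<'a> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Walk IDs 0..len in ascending order so the emitted JSON map is
        // sorted by ID, regardless of HashMap iteration order. Indexing
        // with &i assumes the IDs are contiguous starting at 0.
        let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
        serializer.collect_map(iter)
    }
}

fn main() -> serde_json::Result<()> {
    // Hypothetical three-token vocab, for illustration only.
    let vocab_r: HashMap<u32, String> = vec![
        (0u32, "[UNK]".to_string()),
        (1, "hello".to_string()),
        (2, "world".to_string()),
    ]
    .into_iter()
    .collect();

    let json = serde_json::to_string(&OrderedVocabIter { vocab_r: &vocab_r })?;
    println!("{}", json); // prints {"[UNK]":0,"hello":1,"world":2}
    Ok(())
}

Two properties worth noting: iterating IDs 0..len makes the output deterministic even though HashMap iteration order is not, and the lookup panics if any ID in that range is missing, so the reversed vocab must be contiguous.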
15 changes: 5 additions & 10 deletions tokenizers/src/models/wordlevel/mod.rs
@@ -1,3 +1,4 @@
+use super::OrderedVocabIter;
 use crate::tokenizer::{Model, Result, Token};
 use serde_json::Value;
 use std::collections::HashMap;
@@ -168,20 +169,14 @@ impl Model for WordLevel {
             None => "vocab.json".to_string(),
         };
 
-        // Write vocab.txt
+        // Write vocab.json
         let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
             .iter()
             .collect();
         let mut vocab_file = File::create(&vocab_path)?;
-        let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
-        vocab.sort_unstable_by_key(|k| *k.1);
-        vocab_file.write_all(
-            &vocab
-                .into_iter()
-                .map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
-                .flatten()
-                .collect::<Vec<_>>()[..],
-        )?;
+        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
+        let serialized = serde_json::to_string(&order_vocab_iter)?;
+        vocab_file.write_all(&serialized.as_bytes())?;
 
         Ok(vec![vocab_path])
     }
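
To make the WordLevel change concrete, take a hypothetical two-token vocab mapping "hello" to 0 and "world" to 1 (made-up tokens). The removed code wrote the file as a bare token-per-line list, sorted by ID:

hello
world

while the new code serializes the token -> ID map as JSON, ordered by ID:

{"hello":0,"world":1}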
4 changes: 2 additions & 2 deletions tokenizers/src/models/wordpiece/mod.rs
@@ -252,8 +252,8 @@ impl Model for WordPiece {
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
         let vocab_file_name = match name {
-            Some(name) => format!("{}-vocab.json", name),
-            None => "vocab.json".to_string(),
+            Some(name) => format!("{}-vocab.txt", name),
+            None => "vocab.txt".to_string(),
         };
 
         // Write vocab.txt
