Merge pull request #190 from huggingface/improve-models-builders
Improve models builders
n1t0 committed Mar 6, 2020
2 parents 6b3dfcf + 7d932a6 commit 6d6e72b
Showing 9 changed files with 120 additions and 76 deletions.
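In short: `BPE::from_files` and `WordPiece::from_files` now return a builder directly instead of a `Result`, and file reading plus validation are deferred to `build()`. A minimal sketch of the resulting call pattern, based on the updated doc example in `tokenizers/src/lib.rs` below (the paths are placeholders):

```rust
use tokenizers::models::bpe::BPE;

fn main() {
    // `from_files` no longer returns a `Result`: it only records the paths.
    // Any I/O or parse error surfaces when `build()` reads the files.
    let bpe = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt")
        .dropout(0.1)
        .unk_token("[UNK]".into())
        .build()
        .expect("failed to build BPE from files");
    let _ = bpe;
}
```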
4 changes: 2 additions & 2 deletions bindings/node/native/Cargo.lock

Some generated files are not rendered by default.

32 changes: 20 additions & 12 deletions bindings/node/native/src/models.rs
@@ -66,8 +66,7 @@ pub fn bpe_from_files(mut cx: FunctionContext) -> JsResult<JsModel> {
let options = cx.argument_opt(2);

let mut model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let mut builder = tk::models::bpe::BPE::from_files(&vocab, &merges)
.or_else(|e| cx.throw_error(format!("{}", e)))?;
let mut builder = tk::models::bpe::BPE::from_files(&vocab, &merges);

if let Some(options) = options {
if let Ok(options) = options.downcast::<JsObject>() {
@@ -133,33 +132,42 @@ pub fn bpe_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
/// wordpiece_from_files(vocab: String, options?: {
/// unkToken?: String = "[UNK]",
/// maxInputCharsPerWord?: number = 100,
/// continuingSubwordPrefix?: String = "##",
/// })
pub fn wordpiece_from_files(mut cx: FunctionContext) -> JsResult<JsModel> {
let vocab = cx.argument::<JsString>(0)?.value() as String;
let options = cx.argument_opt(1);

let mut unk_token = String::from("[UNK]");
let mut max_input_chars_per_word = Some(100);
let mut builder = tk::models::wordpiece::WordPiece::from_files(&vocab);

if let Some(options) = options {
if let Ok(options) = options.downcast::<JsObject>() {
if let Ok(unk) = options.get(&mut cx, "unkToken") {
if let Err(_) = unk.downcast::<JsUndefined>() {
unk_token = unk.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
if unk.downcast::<JsUndefined>().is_err() {
builder = builder
.unk_token(unk.downcast::<JsString>().or_throw(&mut cx)?.value() as String);
}
}
if let Ok(max) = options.get(&mut cx, "maxInputCharsPerWord") {
if let Err(_) = max.downcast::<JsUndefined>() {
max_input_chars_per_word =
Some(max.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize);
if max.downcast::<JsUndefined>().is_err() {
builder = builder.max_input_chars_per_word(
max.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize,
);
}
}
if let Ok(prefix) = options.get(&mut cx, "continuingSubwordPrefix") {
if prefix.downcast::<JsUndefined>().is_err() {
builder = builder.continuing_subword_prefix(
prefix.downcast::<JsString>().or_throw(&mut cx)?.value() as String,
);
}
}
}
}

let wordpiece =
tk::models::wordpiece::WordPiece::from_files(&vocab, unk_token, max_input_chars_per_word)
.or_else(|e| cx.throw_error(format!("{}", e)))?;
let wordpiece = builder
.build()
.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;

let mut model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
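One detail in the hunk above is worth a note: with `build()` now returning a `Result`, the neon binding has to turn a Rust error into a pending JavaScript exception. A fragment of that hunk, annotated to spell out the idiom:

```rust
// `cx.throw_error` registers the JS exception and returns `Err(Throw)`;
// `unwrap_err()` extracts the `Throw` so `map_err` can hand it to `?`,
// which propagates it out of a function returning `JsResult`.
let wordpiece = builder
    .build()
    .map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;
```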
25 changes: 12 additions & 13 deletions bindings/python/src/models.rs
@@ -51,10 +51,7 @@ impl BPE {
#[staticmethod]
#[args(kwargs = "**")]
fn from_files(vocab: &str, merges: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
let builder: PyResult<_> =
ToPyResult(tk::models::bpe::BPE::from_files(vocab, merges)).into();
let mut builder = builder?;

let mut builder = tk::models::bpe::BPE::from_files(vocab, merges);
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
@@ -115,25 +112,27 @@ impl WordPiece {
#[staticmethod]
#[args(kwargs = "**")]
fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
let mut unk_token = String::from("[UNK]");
let mut max_input_chars_per_word = Some(100);
let mut builder = tk::models::wordpiece::WordPiece::from_files(vocab);

if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"unk_token" => unk_token = val.extract()?,
"max_input_chars_per_word" => max_input_chars_per_word = Some(val.extract()?),
"unk_token" => {
builder = builder.unk_token(val.extract()?);
}
"max_input_chars_per_word" => {
builder = builder.max_input_chars_per_word(val.extract()?);
}
"continuing_subword_prefix" => {
builder = builder.continuing_subword_prefix(val.extract()?);
}
_ => println!("Ignored unknown kwargs option {}", key),
}
}
}

match tk::models::wordpiece::WordPiece::from_files(
vocab,
unk_token,
max_input_chars_per_word,
) {
match builder.build() {
Err(e) => {
println!("Errors: {:?}", e);
Err(exceptions::Exception::py_err(
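Both bindings now funnel their options through the core `WordPiece` builder. For reference, a minimal sketch of the equivalent direct Rust usage, assuming the core builder exposes the same setters the bindings call above (the vocab path is a placeholder):

```rust
use tokenizers::models::wordpiece::WordPiece;

fn main() {
    // Each setter consumes and returns the builder; `build()` reads the
    // vocab file and reports any error at that point.
    let wordpiece = WordPiece::from_files("./path/to/vocab.txt")
        .unk_token("[UNK]".into())
        .max_input_chars_per_word(100)
        .continuing_subword_prefix("##".into())
        .build()
        .expect("failed to build WordPiece from file");
    let _ = wordpiece;
}
```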
1 change: 1 addition & 0 deletions tokenizers/CHANGELOG.md
@@ -3,6 +3,7 @@
## Changes:
- Keep only one progress bar while reading files during training. This is better for use-cases with
  a high number of files as it avoids having too many progress bars on screen.
- Improve BPE and WordPiece builders.

# v0.8.0

2 changes: 0 additions & 2 deletions tokenizers/benches/bpe_benchmark.rs
@@ -69,7 +69,6 @@ fn iter_bench_encode_batch(

fn bench_gpt2(c: &mut Criterion) {
let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
.unwrap()
.build()
.unwrap();
let tokenizer = create_gpt2_tokenizer(bpe);
@@ -95,7 +94,6 @@ fn bench_gpt2(c: &mut Criterion) {
});

let bpe = BPE::from_files("benches/gpt2-vocab.json", "benches/gpt2-merges.txt")
.unwrap()
.cache_capacity(0)
.build()
.unwrap();
2 changes: 1 addition & 1 deletion tokenizers/src/cli.rs
@@ -16,7 +16,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
.value_of("merges")
.expect("Must give a merges.txt file");

let bpe = BPE::from_files(vocab, merges)?.build()?;
let bpe = BPE::from_files(vocab, merges).build()?;
let mut tokenizer = Tokenizer::new(Box::new(bpe));
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::new(true)));
tokenizer.with_decoder(Box::new(ByteLevel::new(false)));
2 changes: 1 addition & 1 deletion tokenizers/src/lib.rs
@@ -29,7 +29,7 @@
//! use tokenizers::models::bpe::BPE;
//!
//! fn main() -> Result<()> {
//! let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt")?;
//! let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
//! let bpe = bpe_builder
//! .dropout(0.1)
//! .unk_token("[UNK]".into())
74 changes: 48 additions & 26 deletions tokenizers/src/models/bpe/model.rs
@@ -12,9 +12,14 @@ use std::{
path::{Path, PathBuf},
};

type Vocab = HashMap<String, u32>;
type VocabR = HashMap<u32, String>;
type Merges = HashMap<Pair, (u32, u32)>;

struct Config {
vocab: HashMap<String, u32>,
merges: HashMap<Pair, (u32, u32)>,
files: Option<(String, String)>,
vocab: Vocab,
merges: Merges,
cache_capacity: usize,
dropout: Option<f32>,
unk_token: Option<String>,
@@ -31,6 +36,7 @@ impl Default for BpeBuilder {
fn default() -> Self {
Self {
config: Config {
files: None,
vocab: HashMap::new(),
merges: HashMap::new(),
cache_capacity: DEFAULT_CACHE_CAPACITY,
@@ -49,12 +55,14 @@ impl BpeBuilder {
Self::default()
}

/// Set the input files.
pub fn files(mut self, vocab: String, merges: String) -> Self {
self.config.files = Some((vocab, merges));
self
}

/// Set the vocab (token -> ID) and merges mappings.
pub fn vocab_and_merges(
mut self,
vocab: HashMap<String, u32>,
merges: HashMap<Pair, (u32, u32)>,
) -> Self {
pub fn vocab_and_merges(mut self, vocab: Vocab, merges: Merges) -> Self {
self.config.vocab = vocab;
self.config.merges = merges;
self
@@ -91,14 +99,21 @@ impl BpeBuilder {
}

/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(self) -> Result<BPE> {
pub fn build(mut self) -> Result<BPE> {
// Validate dropout.
if let Some(p) = self.config.dropout {
if p <= 0.0 || p > 1.0 {
return Err(Error::InvalidDropout.into());
}
}

// Read files if necessary
if let Some((vocab, merges)) = self.config.files {
let (v, m) = BPE::read_files(&vocab, &merges)?;
self.config.vocab = v;
self.config.merges = m;
}

let vocab_r = self
.config
.vocab
@@ -126,11 +141,11 @@ impl BpeBuilder {
/// A [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
pub struct BPE {
/// The vocabulary assigns a number to each token.
pub(crate) vocab: HashMap<String, u32>,
pub(crate) vocab: Vocab,
/// Reversed vocabulary, to rebuild sentences.
pub(crate) vocab_r: HashMap<u32, String>,
pub(crate) vocab_r: VocabR,
/// Contains the mapping between Pairs and their (rank, new_id).
pub(crate) merges: HashMap<Pair, (u32, u32)>,
pub(crate) merges: Merges,
/// Contains the cache for optimizing the encoding step.
cache: Option<Cache<String, Word>>,
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
@@ -175,15 +190,20 @@ impl BPE {
}

/// Create a new BPE model with the given vocab and merges.
pub fn new(vocab: HashMap<String, u32>, merges: HashMap<Pair, (u32, u32)>) -> Self {
pub fn new(vocab: Vocab, merges: Merges) -> Self {
Self::builder()
.vocab_and_merges(vocab, merges)
.build()
.unwrap()
}

/// Initialize a BPE model from vocab and merges file.
pub fn from_files(vocab: &str, merges: &str) -> Result<BpeBuilder> {
/// Initialize a `BpeBuilder` from vocab and merges files.
pub fn from_files(vocab: &str, merges: &str) -> BpeBuilder {
BPE::builder().files(vocab.to_owned(), merges.to_owned())
}

/// Read the given files to extract the vocab and merges.
pub fn read_files(vocab: &str, merges: &str) -> Result<(Vocab, Merges)> {
// Read vocab.json
let vocab_file = File::open(vocab)?;
let mut vocab_file = BufReader::new(vocab_file);
@@ -207,7 +227,7 @@ impl BPE {
// Read merges file
let merge_file = File::open(merges)?;
let merge_file = BufReader::new(merge_file);
let mut merges = HashMap::<Pair, (u32, u32)>::new();
let mut merges = HashMap::new();
for (rank, line) in merge_file.lines().enumerate() {
let line = line?;
if line.starts_with("#version") {
@@ -235,7 +255,7 @@ impl BPE {
merges.insert(pair, (rank as u32, *new_id));
}

Ok(BPE::builder().vocab_and_merges(vocab, merges))
Ok((vocab, merges))
}

/// Reset the cache.
@@ -245,7 +265,7 @@ impl BPE {
}
}

pub fn get_vocab(&self) -> &HashMap<String, u32> {
pub fn get_vocab(&self) -> &Vocab {
&self.vocab
}

@@ -433,7 +453,7 @@ mod tests {

#[test]
fn test_ordered_vocab_iter() {
let vocab_r: HashMap<u32, String> = [
let vocab_r: VocabR = [
(0, "a".into()),
(1, "b".into()),
(2, "c".into()),
@@ -453,7 +473,7 @@ mod tests {
//
// To test this, we'll build a simple model to tokenize the word 'unrelated'.
fn test_tokenize_with_and_without_dropout() {
let vocab: HashMap<String, u32> = [
let vocab: Vocab = [
("u".into(), 0),
("n".into(), 1),
("r".into(), 2),
@@ -474,7 +494,7 @@ mod tests {
.iter()
.cloned()
.collect();
let merges: HashMap<Pair, (u32, u32)> = [
let merges: Merges = [
((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're'
((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at'
((vocab["e"], vocab["d"]), (3u32, vocab["ed"])), // 'e-d' -> 'ed'
@@ -533,13 +553,11 @@ mod tests {
merges_file.write_all(b"#version: 0.2\na b").unwrap();

// Make sure we can instantiate a BPE model from the files.
let result = BPE::from_files(
let builder = BPE::from_files(
vocab_file.path().to_str().unwrap(),
merges_file.path().to_str().unwrap(),
);
assert!(result.is_ok());

let bpe = result.unwrap().build().unwrap();
let bpe = builder.build().unwrap();

// Check merges.
assert_eq!(bpe.merges.get(&(0u32, 1u32)).unwrap(), &(1u32, 3u32));
@@ -568,7 +586,9 @@ mod tests {
match BPE::from_files(
vocab_file.path().to_str().unwrap(),
merges_file.path().to_str().unwrap(),
) {
)
.build()
{
Ok(_) => unreachable!(),
Err(err) => match err.downcast_ref::<Error>() {
Some(Error::MergeTokenOutOfVocabulary(token)) => {
@@ -597,7 +617,9 @@ mod tests {
match BPE::from_files(
vocab_file.path().to_str().unwrap(),
merges_file.path().to_str().unwrap(),
) {
)
.build()
{
Ok(_) => unreachable!(),
Err(err) => match err.downcast_ref::<Error>() {
Some(Error::BadMerges(line)) => assert_eq!(*line, 3usize),
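The net effect on error handling, as the updated tests above show: a bad path or malformed file no longer fails at `from_files`, it fails at `build()`. A small illustrative sketch (paths are placeholders):

```rust
use tokenizers::models::bpe::BPE;

fn main() {
    // Constructing the builder never touches the filesystem...
    let builder = BPE::from_files("./does/not/exist-vocab.json", "./does/not/exist-merges.txt");
    // ...so any I/O or parse error only shows up here.
    match builder.build() {
        Ok(_) => println!("model ready"),
        Err(e) => eprintln!("build failed: {}", e),
    }
}
```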