From 657602d078068e31b56378181b6cd9b2eb9d76e1 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Wed, 15 Jan 2020 10:55:34 +0100 Subject: [PATCH] Remove kwargs mapping on Tokenizer decode/decode_batch as their is only one possible arg. This is suggested by the current issue https://github.com/huggingface/tokenizers/issues/54#issuecomment-574104841. kwargs cannot be called as positional argument, they have to be named one, replacing kwargs with the actual skip_special_tokens allows both (named and positional) syntax. Signed-off-by: Morgan Funtowicz --- bindings/python/src/tokenizer.rs | 38 ++++++++++++-------------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 4acb7f9de..d1726daca 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -4,6 +4,10 @@ use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; +use tk::tokenizer::{ + PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy, +}; + use super::decoders::Decoder; use super::encoding::Encoding; use super::error::{PyError, ToPyResult}; @@ -14,10 +18,6 @@ use super::processors::PostProcessor; use super::trainers::Trainer; use super::utils::Container; -use tk::tokenizer::{ - PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy, -}; - #[pyclass(dict)] pub struct Tokenizer { tokenizer: tk::tokenizer::Tokenizer, @@ -190,33 +190,23 @@ impl Tokenizer { } #[args(kwargs = "**")] - fn decode(&self, ids: Vec, kwargs: Option<&PyDict>) -> PyResult { - let mut skip_special_tokens = true; - - if let Some(kwargs) = kwargs { - if let Some(skip) = kwargs.get_item("skip_special_tokens") { - skip_special_tokens = skip.extract()?; - } - } - - ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into() + fn decode(&self, ids: Vec, skip_special_tokens: Option) -> PyResult { + ToPyResult(self.tokenizer.decode( + ids, + skip_special_tokens.unwrap_or(true), + )).into() } #[args(kwargs = "**")] fn decode_batch( &self, sentences: Vec>, - kwargs: Option<&PyDict>, + skip_special_tokens: Option, ) -> PyResult> { - let mut skip_special_tokens = true; - - if let Some(kwargs) = kwargs { - if let Some(skip) = kwargs.get_item("skip_special_tokens") { - skip_special_tokens = skip.extract()?; - } - } - - ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into() + ToPyResult(self.tokenizer.decode_batch( + sentences, + skip_special_tokens.unwrap_or(true), + )).into() } fn token_to_id(&self, token: &str) -> Option {