diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 4acb7f9de..d1726daca 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -4,6 +4,10 @@
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use tk::tokenizer::{
+    PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
+};
+
 use super::decoders::Decoder;
 use super::encoding::Encoding;
 use super::error::{PyError, ToPyResult};
@@ -14,10 +18,6 @@
 use super::processors::PostProcessor;
 use super::trainers::Trainer;
 use super::utils::Container;
-use tk::tokenizer::{
-    PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
-};
-
 #[pyclass(dict)]
 pub struct Tokenizer {
     tokenizer: tk::tokenizer::Tokenizer,
@@ -190,33 +190,23 @@ impl Tokenizer {
     }
 
     #[args(kwargs = "**")]
-    fn decode(&self, ids: Vec<u32>, kwargs: Option<&PyDict>) -> PyResult<String> {
-        let mut skip_special_tokens = true;
-
-        if let Some(kwargs) = kwargs {
-            if let Some(skip) = kwargs.get_item("skip_special_tokens") {
-                skip_special_tokens = skip.extract()?;
-            }
-        }
-
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+    fn decode(&self, ids: Vec<u32>, skip_special_tokens: Option<bool>) -> PyResult<String> {
+        ToPyResult(self.tokenizer.decode(
+            ids,
+            skip_special_tokens.unwrap_or(true),
+        )).into()
     }
 
     #[args(kwargs = "**")]
     fn decode_batch(
         &self,
         sentences: Vec<Vec<u32>>,
-        kwargs: Option<&PyDict>,
+        skip_special_tokens: Option<bool>,
     ) -> PyResult<Vec<String>> {
-        let mut skip_special_tokens = true;
-
-        if let Some(kwargs) = kwargs {
-            if let Some(skip) = kwargs.get_item("skip_special_tokens") {
-                skip_special_tokens = skip.extract()?;
-            }
-        }
-
-        ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode_batch(
+            sentences,
+            skip_special_tokens.unwrap_or(true),
+        )).into()
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
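For context (not part of the diff): the change replaces manual `**kwargs` parsing with an explicit `skip_special_tokens: Option<bool>` parameter, recovering the old default of `true` via `unwrap_or(true)`. Below is a minimal standalone Rust sketch of that pattern; the free-standing `decode` function here is hypothetical and only mirrors the shape of the binding, not the actual tokenizer logic.

```rust
// Hypothetical standalone sketch (not the bindings' real code): how an
// `Option<bool>` parameter combined with `unwrap_or(true)` reproduces the
// old `skip_special_tokens = true` default without inspecting a kwargs dict.
fn decode(ids: Vec<u32>, skip_special_tokens: Option<bool>) -> String {
    let skip = skip_special_tokens.unwrap_or(true);
    format!("decoding {} ids, skip_special_tokens = {}", ids.len(), skip)
}

fn main() {
    // Flag omitted: behaves like the old hard-coded default of `true`.
    println!("{}", decode(vec![101, 2023, 102], None));
    // Flag passed explicitly, as a Python caller's
    // `skip_special_tokens=False` would arrive on the Rust side.
    println!("{}", decode(vec![101, 2023, 102], Some(false)));
}
```

Compared with the removed `kwargs.get_item(...)` dance, the explicit parameter should also surface the flag in the generated Python signature (presumably callable as `tokenizer.decode(ids, skip_special_tokens=False)`) and lets pyo3 handle the type conversion instead of a hand-written `extract()`.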