Skip to content

Commit

Permalink
Remove kwargs mapping on Tokenizer decode/decode_batch as their is on…
Browse files Browse the repository at this point in the history
…ly one possible arg.

This is suggested by the current issue #54 (comment).

kwargs cannot be called as positional argument, they have to be named one, replacing kwargs with the actual skip_special_tokens
allows both (named and positional) syntax.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
  • Loading branch information
mfuntowicz committed Jan 15, 2020
1 parent a779714 commit 657602d
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions bindings/python/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;

use tk::tokenizer::{
PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
};

use super::decoders::Decoder;
use super::encoding::Encoding;
use super::error::{PyError, ToPyResult};
Expand All @@ -14,10 +18,6 @@ use super::processors::PostProcessor;
use super::trainers::Trainer;
use super::utils::Container;

use tk::tokenizer::{
PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
};

#[pyclass(dict)]
pub struct Tokenizer {
tokenizer: tk::tokenizer::Tokenizer,
Expand Down Expand Up @@ -190,33 +190,23 @@ impl Tokenizer {
}

#[args(kwargs = "**")]
fn decode(&self, ids: Vec<u32>, kwargs: Option<&PyDict>) -> PyResult<String> {
let mut skip_special_tokens = true;

if let Some(kwargs) = kwargs {
if let Some(skip) = kwargs.get_item("skip_special_tokens") {
skip_special_tokens = skip.extract()?;
}
}

ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
fn decode(&self, ids: Vec<u32>, skip_special_tokens: Option<bool>) -> PyResult<String> {
ToPyResult(self.tokenizer.decode(
ids,
skip_special_tokens.unwrap_or(true),
)).into()
}

#[args(kwargs = "**")]
fn decode_batch(
&self,
sentences: Vec<Vec<u32>>,
kwargs: Option<&PyDict>,
skip_special_tokens: Option<bool>,
) -> PyResult<Vec<String>> {
let mut skip_special_tokens = true;

if let Some(kwargs) = kwargs {
if let Some(skip) = kwargs.get_item("skip_special_tokens") {
skip_special_tokens = skip.extract()?;
}
}

ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into()
ToPyResult(self.tokenizer.decode_batch(
sentences,
skip_special_tokens.unwrap_or(true),
)).into()
}

fn token_to_id(&self, token: &str) -> Option<u32> {
Expand Down

0 comments on commit 657602d

Please sign in to comment.