Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyO3 0.21. #1494

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 7 additions & 7 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ name = "tokenizers"
crate-type = ["cdylib"]

[dependencies]
rayon = "1.8"
rayon = "1.10"
serde = { version = "1.0", features = [ "rc", "derive" ]}
serde_json = "1.0"
libc = "0.2"
env_logger = "0.10.0"
pyo3 = { version = "0.20" }
numpy = "0.20.0"
env_logger = "0.11"
pyo3 = { version = "0.21" }
numpy = "0.21"
ndarray = "0.15"
onig = { version = "6.4", default-features = false }
itertools = "0.11"
itertools = "0.12"

[dependencies.tokenizers]
version = "0.16.0-dev.0"
path = "../../tokenizers"

[dev-dependencies]
tempfile = "3.8"
pyo3 = { version = "0.20", features = ["auto-initialize"] }
tempfile = "3.10"
pyo3 = { version = "0.21", features = ["auto-initialize"] }

[features]
default = ["pyo3/extension-module"]
33 changes: 14 additions & 19 deletions bindings/python/src/decoders.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::sync::{Arc, RwLock};

use crate::pre_tokenizers::from_string;
use crate::utils::PyChar;
use crate::utils::PyPattern;
use pyo3::exceptions;
use pyo3::prelude::*;
Expand Down Expand Up @@ -85,7 +84,7 @@ impl PyDecoder {
e
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
Expand Down Expand Up @@ -161,7 +160,7 @@ pub struct PyByteLevelDec {}
impl PyByteLevelDec {
#[new]
#[pyo3(signature = (**_kwargs), text_signature = "(self)")]
fn new(_kwargs: Option<&PyDict>) -> (Self, PyDecoder) {
fn new(_kwargs: Option<&Bound<'_, PyDict>>) -> (Self, PyDecoder) {
(PyByteLevelDec {}, ByteLevel::default().into())
}
}
Expand Down Expand Up @@ -318,8 +317,8 @@ impl PyMetaspaceDec {
}

#[setter]
fn set_replacement(self_: PyRef<Self>, replacement: PyChar) {
setter!(self_, Metaspace, @set_replacement, replacement.0);
fn set_replacement(self_: PyRef<Self>, replacement: char) {
setter!(self_, Metaspace, @set_replacement, replacement);
}

#[getter]
Expand Down Expand Up @@ -352,16 +351,12 @@ impl PyMetaspaceDec {
}

#[new]
#[pyo3(signature = (replacement = PyChar('▁'), prepend_scheme = String::from("always"), split = true), text_signature = "(self, replacement = \"▁\", prepend_scheme = \"always\", split = True)")]
fn new(
replacement: PyChar,
prepend_scheme: String,
split: bool,
) -> PyResult<(Self, PyDecoder)> {
#[pyo3(signature = (replacement = '▁', prepend_scheme = String::from("always"), split = true), text_signature = "(self, replacement = \"▁\", prepend_scheme = \"always\", split = True)")]
fn new(replacement: char, prepend_scheme: String, split: bool) -> PyResult<(Self, PyDecoder)> {
let prepend_scheme = from_string(prepend_scheme)?;
Ok((
PyMetaspaceDec {},
Metaspace::new(replacement.0, prepend_scheme, split).into(),
Metaspace::new(replacement, prepend_scheme, split).into(),
))
}
}
Expand Down Expand Up @@ -463,7 +458,7 @@ pub struct PySequenceDecoder {}
impl PySequenceDecoder {
#[new]
#[pyo3(signature = (decoders_py), text_signature = "(self, decoders)")]
fn new(decoders_py: &PyList) -> PyResult<(Self, PyDecoder)> {
fn new(decoders_py: &Bound<'_, PyList>) -> PyResult<(Self, PyDecoder)> {
let mut decoders: Vec<DecoderWrapper> = Vec::with_capacity(decoders_py.len());
for decoder_py in decoders_py.iter() {
let decoder: PyRef<PyDecoder> = decoder_py.extract()?;
Expand All @@ -476,8 +471,8 @@ impl PySequenceDecoder {
Ok((PySequenceDecoder {}, Sequence::new(decoders).into()))
}

fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
PyTuple::new(py, [PyList::empty(py)])
fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
PyTuple::new_bound(py, [PyList::empty_bound(py)])
}
}

Expand All @@ -497,7 +492,7 @@ impl Decoder for CustomDecoder {
Python::with_gil(|py| {
let decoded = self
.inner
.call_method(py, "decode", (tokens,), None)?
.call_method_bound(py, "decode", (tokens,), None)?
.extract(py)?;
Ok(decoded)
})
Expand All @@ -507,7 +502,7 @@ impl Decoder for CustomDecoder {
Python::with_gil(|py| {
let decoded = self
.inner
.call_method(py, "decode_chain", (tokens,), None)?
.call_method_bound(py, "decode_chain", (tokens,), None)?
.extract(py)?;
Ok(decoded)
})
Expand Down Expand Up @@ -572,7 +567,7 @@ impl Decoder for PyDecoderWrapper {

/// Decoders Module
#[pymodule]
pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
pub fn decoders(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyReplaceDec>()?;
Expand Down Expand Up @@ -602,7 +597,7 @@ mod test {
Python::with_gil(|py| {
let py_dec = PyDecoder::new(Metaspace::default().into());
let py_meta = py_dec.get_as_subtype(py).unwrap();
assert_eq!("Metaspace", py_meta.as_ref(py).get_type().name().unwrap());
assert_eq!("Metaspace", py_meta.bind(py).get_type().qualname().unwrap());
})
}

Expand Down
8 changes: 4 additions & 4 deletions bindings/python/src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl PyEncoding {
e
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
Expand Down Expand Up @@ -391,10 +391,10 @@ impl PyEncoding {
#[pyo3(
text_signature = "(self, length, direction='right', pad_id=0, pad_type_id=0, pad_token='[PAD]')"
)]
fn pad(&mut self, length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
fn pad(&mut self, length: usize, kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<()> {
let mut pad_id = 0;
let mut pad_type_id = 0;
let mut pad_token = "[PAD]";
let mut pad_token = "[PAD]".to_string();
let mut direction = PaddingDirection::Right;

if let Some(kwargs) = kwargs {
Expand Down Expand Up @@ -422,7 +422,7 @@ impl PyEncoding {
}
}
self.encoding
.pad(length, pad_id, pad_type_id, pad_token, direction);
.pad(length, pad_id, pad_type_id, &pad_token, direction);
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions bindings/python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<T> ToPyResult<T> {
}

pub(crate) fn deprecation_warning(py: Python<'_>, version: &str, message: &str) -> PyResult<()> {
let deprecation_warning = py.import("builtins")?.getattr("DeprecationWarning")?;
let deprecation_warning = py.import_bound("builtins")?.getattr("DeprecationWarning")?;
let full_message = format!("Deprecated in {}: {}", version, message);
pyo3::PyErr::warn(py, deprecation_warning, &full_message, 0)
pyo3::PyErr::warn_bound(py, &deprecation_warning, &full_message, 0)
}
2 changes: 1 addition & 1 deletion bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ extern "C" fn child_after_fork() {

/// Tokenizers Module
#[pymodule]
pub fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
pub fn tokenizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
let _ = env_logger::try_init_from_env("TOKENIZERS_LOG");

// Register the fork callback
Expand Down
35 changes: 19 additions & 16 deletions bindings/python/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl PyModel {
e
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
Expand Down Expand Up @@ -260,7 +260,10 @@ impl PyModel {
pub struct PyBPE {}

impl PyBPE {
fn with_builder(mut builder: BpeBuilder, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
fn with_builder(
mut builder: BpeBuilder,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Self, PyModel)> {
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
Expand Down Expand Up @@ -321,14 +324,14 @@ macro_rules! setter {
}

#[derive(FromPyObject)]
enum PyVocab<'a> {
enum PyVocab {
Vocab(Vocab),
Filename(&'a str),
Filename(String),
}
#[derive(FromPyObject)]
enum PyMerges<'a> {
enum PyMerges {
Merges(Merges),
Filename(&'a str),
Filename(String),
}

#[pymethods]
Expand Down Expand Up @@ -417,7 +420,7 @@ impl PyBPE {
py: Python<'_>,
vocab: Option<PyVocab>,
merges: Option<PyMerges>,
kwargs: Option<&PyDict>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Self, PyModel)> {
if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
return Err(exceptions::PyValueError::new_err(
Expand Down Expand Up @@ -502,11 +505,11 @@ impl PyBPE {
#[pyo3(signature = (vocab, merges, **kwargs))]
#[pyo3(text_signature = "(cls, vocab, merges, **kwargs)")]
fn from_file(
_cls: &PyType,
_cls: &Bound<'_, PyType>,
py: Python,
vocab: &str,
merges: &str,
kwargs: Option<&PyDict>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<Self>> {
let (vocab, merges) = BPE::read_file(vocab, merges).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading BPE files: {}", e))
Expand Down Expand Up @@ -540,7 +543,7 @@ pub struct PyWordPiece {}
impl PyWordPiece {
fn with_builder(
mut builder: WordPieceBuilder,
kwargs: Option<&PyDict>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Self, PyModel)> {
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
Expand Down Expand Up @@ -612,7 +615,7 @@ impl PyWordPiece {
fn new(
py: Python<'_>,
vocab: Option<PyVocab>,
kwargs: Option<&PyDict>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<(Self, PyModel)> {
let mut builder = WordPiece::builder();

Expand Down Expand Up @@ -677,10 +680,10 @@ impl PyWordPiece {
#[pyo3(signature = (vocab, **kwargs))]
#[pyo3(text_signature = "(vocab, **kwargs)")]
fn from_file(
_cls: &PyType,
_cls: &Bound<'_, PyType>,
py: Python,
vocab: &str,
kwargs: Option<&PyDict>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<Self>> {
let vocab = WordPiece::read_file(vocab).map_err(|e| {
exceptions::PyException::new_err(format!("Error while reading WordPiece file: {}", e))
Expand Down Expand Up @@ -796,7 +799,7 @@ impl PyWordLevel {
#[pyo3(signature = (vocab, unk_token = None))]
#[pyo3(text_signature = "(vocab, unk_token)")]
fn from_file(
_cls: &PyType,
_cls: &Bound<'_, PyType>,
py: Python,
vocab: &str,
unk_token: Option<String>,
Expand Down Expand Up @@ -849,7 +852,7 @@ impl PyUnigram {

/// Models Module
#[pymodule]
pub fn models(_py: Python, m: &PyModule) -> PyResult<()> {
pub fn models(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyModel>()?;
m.add_class::<PyBPE>()?;
m.add_class::<PyWordPiece>()?;
Expand All @@ -870,7 +873,7 @@ mod test {
Python::with_gil(|py| {
let py_model = PyModel::from(BPE::default());
let py_bpe = py_model.get_as_subtype(py).unwrap();
assert_eq!("BPE", py_bpe.as_ref(py).get_type().name().unwrap());
assert_eq!("BPE", py_bpe.bind(py).get_type().qualname().unwrap());
})
}

Expand Down
20 changes: 10 additions & 10 deletions bindings/python/src/normalizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl PyNormalizer {
e
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
Ok(PyBytes::new_bound(py, data.as_bytes()).to_object(py))
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
Expand Down Expand Up @@ -345,7 +345,7 @@ pub struct PySequence {}
impl PySequence {
#[new]
#[pyo3(text_signature = None)]
fn new(normalizers: &PyList) -> PyResult<(Self, PyNormalizer)> {
fn new(normalizers: &Bound<'_, PyList>) -> PyResult<(Self, PyNormalizer)> {
let mut sequence = Vec::with_capacity(normalizers.len());
for n in normalizers.iter() {
let normalizer: PyRef<PyNormalizer> = n.extract()?;
Expand All @@ -360,8 +360,8 @@ impl PySequence {
))
}

fn __getnewargs__<'p>(&self, py: Python<'p>) -> &'p PyTuple {
PyTuple::new(py, [PyList::empty(py)])
fn __getnewargs__<'p>(&self, py: Python<'p>) -> Bound<'p, PyTuple> {
PyTuple::new_bound(py, [PyList::empty_bound(py)])
}

fn __len__(&self) -> usize {
Expand Down Expand Up @@ -467,11 +467,11 @@ pub struct PyPrecompiled {}
impl PyPrecompiled {
#[new]
#[pyo3(text_signature = "(self, precompiled_charsmap)")]
fn new(py_precompiled_charsmap: &PyBytes) -> PyResult<(Self, PyNormalizer)> {
let precompiled_charsmap: &[u8] = FromPyObject::extract(py_precompiled_charsmap)?;
fn new(precompiled_charsmap: Vec<u8>) -> PyResult<(Self, PyNormalizer)> {
// let precompiled_charsmap: Vec<u8> = FromPyObject::extract(py_precompiled_charsmap)?;
Ok((
PyPrecompiled {},
Precompiled::from(precompiled_charsmap)
Precompiled::from(&precompiled_charsmap)
.map_err(|e| {
exceptions::PyException::new_err(format!(
"Error while attempting to build Precompiled normalizer: {}",
Expand Down Expand Up @@ -512,7 +512,7 @@ impl tk::tokenizer::Normalizer for CustomNormalizer {
fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
Python::with_gil(|py| {
let normalized = PyNormalizedStringRefMut::new(normalized);
let py_normalized = self.inner.as_ref(py);
let py_normalized = self.inner.bind(py);
py_normalized.call_method("normalize", (normalized.get(),), None)?;
Ok(())
})
Expand Down Expand Up @@ -635,7 +635,7 @@ impl Normalizer for PyNormalizerWrapper {

/// Normalizers Module
#[pymodule]
pub fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyNormalizer>()?;
m.add_class::<PyBertNormalizer>()?;
m.add_class::<PyNFD>()?;
Expand Down Expand Up @@ -667,7 +667,7 @@ mod test {
Python::with_gil(|py| {
let py_norm = PyNormalizer::new(NFC.into());
let py_nfc = py_norm.get_as_subtype(py).unwrap();
assert_eq!("NFC", py_nfc.as_ref(py).get_type().name().unwrap());
assert_eq!("NFC", py_nfc.bind(py).get_type().qualname().unwrap());
})
}

Expand Down