Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ crate-type = ["cdylib"]

[dependencies]
biscuit-auth = { version = "6.0.0", features = ["pem"] }
pyo3 = { version = "0.24.1", features = ["extension-module", "chrono"] }
pyo3 = { version = "0.28.2", features = ["extension-module", "chrono"] }
hex = "0.4"
base64 = "0.13.0"
chrono = "0.4"
84 changes: 43 additions & 41 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use ::biscuit_auth::{
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::IntoPyObjectExt;

use pyo3::create_exception;

Expand Down Expand Up @@ -60,8 +61,7 @@ create_exception!(
BiscuitBlockError,
pyo3::exceptions::PyException
);

#[pyclass(eq, eq_int, name = "Algorithm")]
#[pyclass(eq, eq_int, name = "Algorithm", from_py_object)]
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum PyAlgorithm {
Ed25519,
Expand All @@ -86,12 +86,12 @@ impl From<PyAlgorithm> for builder::Algorithm {
}

struct PyKeyProvider {
py_value: PyObject,
py_value: Py<PyAny>,
}

impl RootKeyProvider for PyKeyProvider {
fn choose(&self, kid: Option<u32>) -> Result<PublicKey, error::Format> {
Python::with_gil(|py| {
Python::attach(|py| {
let bound = self.py_value.bind(py);
if bound.is_callable() {
let result = bound
Expand Down Expand Up @@ -322,7 +322,7 @@ impl PyBiscuit {
/// :return: the parsed and verified biscuit
/// :rtype: Biscuit
#[classmethod]
pub fn from_bytes(_: &Bound<PyType>, data: &[u8], root: PyObject) -> PyResult<PyBiscuit> {
pub fn from_bytes(_: &Bound<PyType>, data: &[u8], root: Py<PyAny>) -> PyResult<PyBiscuit> {
match Biscuit::from(data, PyKeyProvider { py_value: root }) {
Ok(biscuit) => Ok(PyBiscuit(biscuit)),
Err(error) => Err(BiscuitValidationError::new_err(error.to_string())),
Expand All @@ -340,7 +340,7 @@ impl PyBiscuit {
/// :return: the parsed and verified biscuit
/// :rtype: Biscuit
#[classmethod]
pub fn from_base64(_: &Bound<PyType>, data: &str, root: PyObject) -> PyResult<PyBiscuit> {
pub fn from_base64(_: &Bound<PyType>, data: &str, root: Py<PyAny>) -> PyResult<PyBiscuit> {
match Biscuit::from_base64(data, PyKeyProvider { py_value: root }) {
Ok(biscuit) => Ok(PyBiscuit(biscuit)),
Err(error) => Err(BiscuitValidationError::new_err(error.to_string())),
Expand Down Expand Up @@ -470,7 +470,7 @@ impl PyBiscuit {
#[pyclass(name = "AuthorizerBuilder")]
pub struct PyAuthorizerBuilder(Option<AuthorizerBuilder>);

#[pyclass(name = "AuthorizerLimits")]
#[pyclass(name = "AuthorizerLimits", from_py_object)]
#[derive(Clone)]
pub struct PyAuthorizerLimits {
#[pyo3(get, set)]
Expand Down Expand Up @@ -682,15 +682,15 @@ impl PyAuthorizerBuilder {
Ok(())
}

pub fn register_extern_func(&mut self, name: &str, func: PyObject) -> PyResult<()> {
pub fn register_extern_func(&mut self, name: &str, func: Py<PyAny>) -> PyResult<()> {
self.0 = Some(
self.0
.take()
.expect("builder already consumed")
.register_extern_func(
name.to_string(),
ExternFunc::new(Arc::new(move |left, right| {
Python::with_gil(|py| {
Python::attach(|py| {
let bound = func.bind(py);
if bound.is_callable() {
let left = term_to_py(&left).map_err(|e| e.to_string())?;
Expand All @@ -703,8 +703,8 @@ impl PyAuthorizerBuilder {
None => bound.call1((left,)).map_err(|e| e.to_string())?,
};
let py_result: PyTerm =
result.extract().map_err(|e| e.to_string())?;
Ok(py_result.to_term().map_err(|e| e.to_string())?)
result.extract().map_err(|e: PyErr| e.to_string())?;
Ok(py_result.to_term().map_err(|e: PyErr| e.to_string())?)
} else {
Err("expected a function".to_string())
}
Expand All @@ -715,15 +715,15 @@ impl PyAuthorizerBuilder {
Ok(())
}

pub fn register_extern_funcs(&mut self, funcs: HashMap<String, PyObject>) -> PyResult<()> {
pub fn register_extern_funcs(&mut self, funcs: HashMap<String, Py<PyAny>>) -> PyResult<()> {
for (name, func) in funcs {
self.register_extern_func(&name, func)?;
}

Ok(())
}

pub fn set_extern_funcs(&mut self, funcs: HashMap<String, PyObject>) -> PyResult<()> {
pub fn set_extern_funcs(&mut self, funcs: HashMap<String, Py<PyAny>>) -> PyResult<()> {
self.0 = Some(
self.0
.take()
Expand Down Expand Up @@ -920,7 +920,7 @@ impl PyAuthorizer {
/// :type parameters: dict, optional
/// :param scope_parameters: public keys for the public key parameters in the datalog snippet
/// :type scope_parameters: dict, optional
#[pyclass(name = "BlockBuilder")]
#[pyclass(name = "BlockBuilder", from_py_object)]
#[derive(Clone)]
pub struct PyBlockBuilder(Option<builder::BlockBuilder>);

Expand Down Expand Up @@ -1146,7 +1146,7 @@ impl Default for PyKeyPair {

/// ed25519 public key
#[derive(Clone)]
#[pyclass(name = "PublicKey")]
#[pyclass(name = "PublicKey", from_py_object)]
pub struct PyPublicKey(PublicKey);

#[pymethods]
Expand Down Expand Up @@ -1225,7 +1225,7 @@ impl PyPublicKey {
}

/// ed25519 private key
#[pyclass(name = "PrivateKey")]
#[pyclass(name = "PrivateKey", from_py_object)]
#[derive(Clone)]
pub struct PyPrivateKey(PrivateKey);

Expand Down Expand Up @@ -1314,26 +1314,24 @@ pub enum NestedPyTerm {
Bytes(Vec<u8>),
}

// TODO follow up-to-date conversion methods from pyo3
fn inner_term_to_py(t: &builder::Term, py: Python<'_>) -> PyResult<Py<PyAny>> {
match t {
builder::Term::Integer(i) => Ok((*i).into_pyobject(py).unwrap().into_any().unbind()),
builder::Term::Str(s) => Ok(s.into_pyobject(py).unwrap().into_any().unbind()),
builder::Term::Date(d) => Ok(Utc
.timestamp_opt(*d as i64, 0)
.unwrap()
.into_pyobject(py)
.unwrap()
.into_any()
.unbind()),
builder::Term::Bytes(bs) => Ok(bs.clone().into_pyobject(py).unwrap().into_any().unbind()),
builder::Term::Bool(b) => Ok(b.into_py(py)),
builder::Term::Integer(i) => (*i).into_py_any(py),
builder::Term::Str(s) => s.into_py_any(py),
builder::Term::Date(d) => {
Utc.timestamp_opt(*d as i64, 0)
.single()
.ok_or_else(|| DataLogError::new_err("Invalid timestamp".to_string()))?
.into_py_any(py)
}
builder::Term::Bytes(bs) => bs.clone().into_py_any(py),
builder::Term::Bool(b) => (*b).into_py_any(py),
_ => Err(DataLogError::new_err("Invalid term value".to_string())),
}
}

fn term_to_py(t: &builder::Term) -> PyResult<Py<PyAny>> {
Python::with_gil(|py| match t {
Python::attach(|py| match t {
builder::Term::Parameter(_) => Err(DataLogError::new_err("Invalid term value".to_string())),
builder::Term::Variable(_) => Err(DataLogError::new_err("Invalid term value".to_string())),
builder::Term::Set(_vs) => todo!(),
Expand All @@ -1345,7 +1343,7 @@ fn term_to_py(t: &builder::Term) -> PyResult<Py<PyAny>> {

/// Wrapper for a non-naïve python date
#[derive(FromPyObject)]
pub struct PyDate(Py<PyDateTime>);
pub struct PyDate(pub Py<PyDateTime>);

impl PartialEq for PyDate {
fn eq(&self, other: &Self) -> bool {
Expand Down Expand Up @@ -1384,19 +1382,23 @@ impl NestedPyTerm {
NestedPyTerm::Str(s) => Ok(builder::Term::Str(s.to_string())),
NestedPyTerm::Bytes(b) => Ok(b.clone().into()),
NestedPyTerm::Bool(b) => Ok((*b).into()),
NestedPyTerm::Date(PyDate(d)) => Python::with_gil(|py| {
let ts = d.extract::<DateTime<Utc>>(py)?.timestamp();
if ts < 0 {
return Err(PyValueError::new_err(
"Only positive timestamps are available".to_string(),
));
}
Ok(builder::Term::Date(ts as u64))
}),
NestedPyTerm::Date(PyDate(py_date)) => {
Python::attach(|py| {
let dt: chrono::DateTime<chrono::Utc> = py_date.extract(py)?;
let ts = dt.timestamp();
if ts < 0 {
return Err(PyValueError::new_err(
"Only positive timestamps supported",
));
}
Ok(builder::Term::Date(ts as u64))
})
}
}
}
}


impl PyTerm {
pub fn to_term(&self) -> PyResult<builder::Term> {
match self {
Expand Down Expand Up @@ -1460,7 +1462,7 @@ impl PyFact {

/// The fact terms
#[getter]
pub fn terms(&self) -> PyResult<Vec<PyObject>> {
pub fn terms(&self) -> PyResult<Vec<Py<PyAny>>> {
self.0.predicate.terms.iter().map(term_to_py).collect()
}

Expand Down Expand Up @@ -1681,7 +1683,7 @@ impl PyUnverifiedBiscuit {
.collect()
}

pub fn verify(&self, root: PyObject) -> PyResult<PyBiscuit> {
pub fn verify(&self, root: Py<PyAny>) -> PyResult<PyBiscuit> {
Ok(PyBiscuit(
self.0
.clone()
Expand Down