Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose attribute_filter option #11

Merged
merged 2 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nh3.pyi
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict, Optional, Set
from typing import Callable, Dict, Optional, Set

def clean(
html: str,
tags: Optional[Set[str]] = None,
attributes: Optional[Dict[str, Set[str]]] = None,
attribute_filter: Optional[Callable[[str, str, str]], Optional[str]] = None,
strip_comments: bool = True,
link_rel: Optional[str] = "noopener noreferrer",
) -> str: ...
Expand Down
57 changes: 54 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};

use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{PyString, PyTuple};

/// Clean HTML with a conservative set of defaults
#[pyfunction(signature = (
html,
tags = None,
attributes = None,
attribute_filter = None,
strip_comments = true,
link_rel = "noopener noreferrer",
))]
Expand All @@ -15,12 +19,20 @@ fn clean(
html: &str,
tags: Option<HashSet<&str>>,
attributes: Option<HashMap<&str, HashSet<&str>>>,
attribute_filter: Option<PyObject>,
strip_comments: bool,
link_rel: Option<&str>,
) -> String {
py.allow_threads(|| {
) -> PyResult<String> {
if let Some(callback) = attribute_filter.as_ref() {
if !callback.as_ref(py).is_callable() {
return Err(PyTypeError::new_err("attribute_filter must be callable"));
}
}

let cleaned = py.allow_threads(|| {
if tags.is_some()
|| attributes.is_some()
|| attribute_filter.is_some()
|| !strip_comments
|| link_rel != Some("noopener noreferrer")
{
Expand All @@ -34,13 +46,52 @@ fn clean(
}
cleaner.tag_attributes(attrs);
}
if let Some(callback) = attribute_filter {
cleaner.attribute_filter(move |element, attribute, value| {
Python::with_gil(|py| {
let res = callback.call(
py,
PyTuple::new(
py,
[
PyString::new(py, element),
PyString::new(py, attribute),
PyString::new(py, value),
],
),
None,
);
let err = match res {
Ok(val) => {
if val.is_none(py) {
return None;
} else if let Ok(s) = val.downcast::<PyString>(py) {
match s.to_str() {
Ok(s) => return Some(Cow::<str>::Owned(s.to_string())),
Err(err) => err,
}
} else {
PyTypeError::new_err(
"expected attribute_filter to return str or None",
)
}
}
Err(err) => err,
};
err.restore(py);
Some(value.into())
})
});
}
cleaner.strip_comments(strip_comments);
cleaner.link_rel(link_rel);
cleaner.clean(html).to_string()
} else {
ammonia::clean(html)
}
})
});

Ok(cleaned)
}

/// Turn an arbitrary string into unformatted HTML
Expand Down
23 changes: 23 additions & 0 deletions tests/test_nh3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import nh3
import pytest


def test_clean():
Expand All @@ -19,6 +20,28 @@ def test_clean():
)


def test_clean_with_attribute_filter():
html = "<a href=/><img alt=Home src=foo></a>"

def attribute_filter(element, attribute, value):
if element == "img" and attribute == "src":
return None
return value

assert (
nh3.clean(html, attribute_filter=attribute_filter, link_rel=None)
== '<a href="/"><img alt="Home"></a>'
)

with pytest.raises(TypeError):
nh3.clean(html, attribute_filter="not a callable")

with pytest.raises(SystemError):
# FIXME: attribute_filter may raise exception, but it's an infallible API
# which causes Python to raise SystemError instead of the intended TypeError
nh3.clean(html, attribute_filter=lambda _element, _attribute, _value: True)


def test_clean_text():
res = nh3.clean_text('Robert"); abuse();//')
assert res == "Robert&quot;);&#32;abuse();&#47;&#47;"