-
-
Notifications
You must be signed in to change notification settings - Fork 11
/
lib.rs
109 lines (98 loc) · 2.84 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use std::fs;
use std::io::{self, Write};
use crfs_rs::{Attribute, Model};
use ouroboros::self_referencing;
use pyo3::prelude::*;
/// A named attribute with a floating-point value, exposed to Python as `crfs.Attribute`.
///
/// Both fields are readable and writable from Python.
#[pyclass(module = "crfs", name = "Attribute")]
#[derive(FromPyObject)]
struct PyAttribute {
/// Attribute name
#[pyo3(get, set)]
name: String,
/// Value of the attribute
#[pyo3(get, set)]
value: f64,
}
#[pymethods]
impl PyAttribute {
/// Create an attribute from a name and a value.
///
/// `value` defaults to 1.0 when it is not supplied.
#[new]
#[args(name, value = "1.0")]
fn new(name: String, value: f64) -> Self {
Self { name, value }
}
}
/// The Python values accepted wherever an attribute is expected.
///
/// Variant order matters: the derived `FromPyObject` tries the variants from
/// top to bottom and keeps the first one that extracts successfully.
#[derive(FromPyObject)]
enum PyAttributeInput {
/// A value extracted as a [`PyAttribute`].
#[pyo3(transparent)]
Attr(PyAttribute),
/// A value providing `"name"` and `"value"` items (e.g. a dict).
Dict {
/// Attribute name
#[pyo3(item("name"))]
name: String,
/// Value of the attribute
#[pyo3(item("value"))]
value: f64,
},
/// A `(name, value)` pair.
Tuple(String, f64),
/// A bare name with no explicit value.
#[pyo3(transparent)]
NameOnly(String),
}
impl From<PyAttributeInput> for Attribute {
    /// Turn any accepted Python-side attribute form into an [`Attribute`].
    ///
    /// Forms that carry an explicit value keep it; a bare name is given the
    /// value 1.0.
    fn from(attr: PyAttributeInput) -> Self {
        // Normalise every variant to a (name, value) pair first so that the
        // attribute is constructed in exactly one place.
        let (name, value) = match attr {
            PyAttributeInput::Attr(PyAttribute { name, value })
            | PyAttributeInput::Dict { name, value }
            | PyAttributeInput::Tuple(name, value) => (name, value),
            PyAttributeInput::NameOnly(name) => (name, 1.0),
        };
        Attribute::new(name, value)
    }
}
// A model bundled with the byte buffer it borrows from, exposed to Python as
// `crfs.Model`. `#[self_referencing]` is what allows the `model` field to
// borrow from the `data` field of the same struct.
#[pyclass(module = "crfs", name = "Model")]
#[self_referencing]
struct PyModel {
// Raw model bytes, owned here so that `model` can borrow from them.
data: Vec<u8>,
// Model built from, and borrowing, `data`.
#[borrows(data)]
#[covariant]
model: Model<'this>,
}
#[pymethods]
impl PyModel {
    /// Create an instance of a model object from a model in memory
    ///
    /// `data` is moved into the new object and the model is built from it.
    /// If the model cannot be built, the error is raised as a Python
    /// exception.
    #[new]
    fn new_py(data: Vec<u8>) -> PyResult<Self> {
        // `try_new` stores `data` first and then runs the closure on a
        // reference to the stored bytes, so the model borrows from the buffer
        // that lives inside `Self`.
        let model = PyModel::try_new(data, |data| Model::new(data))?;
        Ok(model)
    }

    /// Create an instance of a model object from a local file
    ///
    /// The whole file at `path` is read into memory and then handled exactly
    /// like bytes passed to the constructor. Failure to read the file is
    /// raised as a Python exception.
    #[staticmethod]
    fn open(path: &str) -> PyResult<Self> {
        let data = fs::read(path)?;
        Self::new_py(data)
    }

    /// Predict the label sequence for the item sequence.
    ///
    /// `xseq` holds one list of attributes per item. Each attribute may be
    /// given in any form accepted by `PyAttributeInput`. The labels produced
    /// by the tagger are returned as strings; tagger errors are raised as
    /// Python exceptions.
    pub fn tag(&self, xseq: Vec<Vec<PyAttributeInput>>) -> PyResult<Vec<String>> {
        let mut tagger = self.borrow_model().tagger()?;
        // Convert the Python-side inputs into the attribute type the tagger
        // consumes.
        let xseq: Vec<Vec<Attribute>> = xseq
            .into_iter()
            .map(|xs| xs.into_iter().map(Into::into).collect())
            .collect();
        let labels = tagger.tag(&xseq)?;
        Ok(labels.iter().map(|l| l.to_string()).collect())
    }

    /// Print the model in human-readable format
    ///
    /// The dump is rendered into memory first, so nothing is printed if
    /// rendering fails. It is then written to the process-level standard
    /// output through Rust's `std::io::stdout`, not through Python's
    /// `sys.stdout` object.
    pub fn dump(&self) -> PyResult<()> {
        let mut out = Vec::new();
        self.borrow_model().dump(&mut out)?;
        let stdout = io::stdout();
        let mut handle = stdout.lock();
        handle.write_all(&out)?;
        // Rust's stdout is buffered; flush so the text has actually been
        // emitted before control returns to Python.
        handle.flush()?;
        Ok(())
    }
}
/// Initialise the `crfs` Python module: set `__version__` and register the
/// `Attribute` and `Model` classes.
#[pymodule]
fn crfs(_py: Python, m: &PyModule) -> PyResult<()> {
// The version string is taken from Cargo's package metadata at compile time.
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_class::<PyAttribute>()?;
m.add_class::<PyModel>()?;
Ok(())
}