clean up pos ids
bminixhofer committed Jan 11, 2021
1 parent 337907b commit 29c0ac1
Showing 9 changed files with 64 additions and 35 deletions.
35 changes: 26 additions & 9 deletions bindings/python/src/lib.rs
@@ -303,11 +303,12 @@ impl PyTagger {
 #[pyclass(name = "Token", module = "nlprule")]
 pub struct PyToken {
     token: OwnedToken,
+    tagger: Arc<Tagger>,
 }
 
-impl From<OwnedToken> for PyToken {
-    fn from(token: OwnedToken) -> Self {
-        PyToken { token }
+impl PyToken {
+    fn new(token: OwnedToken, tagger: Arc<Tagger>) -> Self {
+        PyToken { token, tagger }
     }
 }

@@ -324,12 +325,12 @@ impl PyToken {
     }
 
     #[getter]
-    fn data(&self) -> Vec<(&str, u16)> {
+    fn data(&self) -> Vec<(&str, &str)> {
         self.token
             .word
             .tags
             .iter()
-            .map(|x| (x.lemma.as_str(), x.pos_id))
+            .map(|x| (x.lemma.as_str(), self.tagger.id_to_tag(x.pos_id)))
             .collect()
     }

@@ -354,8 +355,14 @@ impl PyToken {
     }
 
     #[getter]
-    fn tags(&self) -> Vec<u16> {
-        let mut tags: Vec<_> = self.token.word.tags.iter().map(|x| x.pos_id).collect();
+    fn tags(&self) -> Vec<&str> {
+        let mut tags: Vec<_> = self
+            .token
+            .word
+            .tags
+            .iter()
+            .map(|x| self.tagger.id_to_tag(x.pos_id))
+            .collect();
         tags.sort_unstable();
         tags.dedup();
         tags
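The `data` and `tags` getters above stop returning raw `pos_id` numbers and resolve them through the tagger instead. A minimal sketch of that resolution, assuming a tagger that interns tag strings into a dense table indexed by `u16` (`MiniTagger` is a stand-in, not nlprule's actual `Tagger`):

```rust
struct MiniTagger {
    tag_store: Vec<String>,
}

impl MiniTagger {
    // mirrors the `id_to_tag` call used in the getters above
    fn id_to_tag(&self, id: u16) -> &str {
        &self.tag_store[id as usize]
    }
}

fn main() {
    let tagger = MiniTagger {
        tag_store: vec!["UNKNOWN".into(), "NN".into(), "NNS".into(), "VB".into()],
    };
    let pos_ids: Vec<u16> = vec![2, 1, 2];

    // what the new `tags` getter does: resolve ids to strings, then sort and dedup
    let mut tags: Vec<&str> = pos_ids.iter().map(|&id| tagger.id_to_tag(id)).collect();
    tags.sort_unstable();
    tags.dedup();
    assert_eq!(tags, vec!["NN", "NNS"]);
}
```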
@@ -498,7 +505,12 @@ impl PyTokenizer {
                     .disambiguate(self.tokenizer.tokenize(&sentence)),
             )
             .into_iter()
-            .map(|x| PyCell::new(py, PyToken::from(x.to_owned_token())))
+            .map(|x| {
+                PyCell::new(
+                    py,
+                    PyToken::new(x.to_owned_token(), self.tokenizer.tagger().clone()),
+                )
+            })
             .collect::<PyResult<Vec<_>>>()
         })
     }
@@ -520,7 +532,12 @@ impl PyTokenizer {
                         .disambiguate(self.tokenizer.tokenize(&sentence)),
                 )
                 .into_iter()
-                .map(|x| PyCell::new(py, PyToken::from(x.to_owned_token())))
+                .map(|x| {
+                    PyCell::new(
+                        py,
+                        PyToken::new(x.to_owned_token(), self.tokenizer.tagger().clone()),
+                    )
+                })
                 .collect::<PyResult<Vec<_>>>()?;
             output.extend(tokens);
         }
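Both tokenize paths now hand every `PyToken` its own handle to the shared tagger via `self.tokenizer.tagger().clone()`. Since the tagger sits behind an `Arc`, each clone is only a reference-count bump; a small sketch with stand-in types:

```rust
use std::sync::Arc;

struct Tagger; // stand-in; the real tagger holds the tag and word stores

struct Token {
    pos_ids: Vec<u16>,
    // each token keeps a cheap handle to the shared tagger so it can
    // resolve its numeric ids back to strings on demand
    tagger: Arc<Tagger>,
}

fn main() {
    let tagger = Arc::new(Tagger);
    let tokens: Vec<Token> = (0..3u16)
        .map(|i| Token {
            pos_ids: vec![i],
            tagger: Arc::clone(&tagger),
        })
        .collect();

    // cloning the Arc never copies the Tagger itself
    assert_eq!(Arc::strong_count(&tagger), 4); // the original handle plus one per token
    assert_eq!(tokens[0].pos_ids, vec![0]);
}
```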
4 changes: 4 additions & 0 deletions configs/de/tokenizer.json
@@ -6,5 +6,9 @@
     "ignore_ids": [
         "SUB_BEAMTE.1",
         "SUB_BEAMTE.2"
-    ]
+    ],
+    "extra_tags": [
+        "PKT",
+        "PRO:IND:DAT:SIN:NEU"
+    ]
 }
6 changes: 6 additions & 0 deletions configs/en/tokenizer.json
@@ -5,5 +5,11 @@
     "always_add_lower_tags": true,
     "ignore_ids": [
         "BEST_JJS"
-    ]
+    ],
+    "extra_tags": [
+        "PCT",
+        "ORD",
+        "SYM",
+        "RB_SENT"
+    ]
 }
24 changes: 15 additions & 9 deletions nlprule/src/bin/compile.rs
@@ -1,7 +1,7 @@
 use clap::Clap;
 use nlprule::{
-    rules::Rules,
-    tokenizer::{chunk::Chunker, tag::Tagger, Tokenizer},
+    rules::{Rules, RulesOptions},
+    tokenizer::{chunk::Chunker, tag::Tagger, Tokenizer, TokenizerOptions},
 };
 use std::{
     collections::HashSet,
@@ -50,7 +50,17 @@ fn main() {
             .collect()
     });
 
-    let tagger = Tagger::from_dumps(&opts.tag_paths, &opts.tag_remove_paths).unwrap();
+    let tokenizer_options: TokenizerOptions =
+        serde_json::from_str(&read_to_string(opts.tokenizer_config_path).unwrap()).unwrap();
+    let rules_options: RulesOptions =
+        serde_json::from_str(&read_to_string(opts.rules_config_path).unwrap()).unwrap();
+
+    let tagger = Tagger::from_dumps(
+        &opts.tag_paths,
+        &opts.tag_remove_paths,
+        &tokenizer_options.extra_tags,
+    )
+    .unwrap();
     let mut tokenizer = Tokenizer::from_xml(
         opts.disambiguation_path,
         Arc::new(tagger),
@@ -61,19 +71,15 @@ fn main() {
         } else {
             None
         },
-        serde_json::from_str(&read_to_string(opts.tokenizer_config_path).unwrap()).unwrap(),
+        tokenizer_options,
     )
     .unwrap();
     tokenizer.populate_cache(&common_words);
 
     let f = BufWriter::new(File::create(&opts.out_tokenizer_path).unwrap());
     bincode::serialize_into(f, &tokenizer).unwrap();
 
-    let mut rules = Rules::from_xml(
-        opts.grammar_path,
-        tokenizer.tagger(),
-        serde_json::from_str(&read_to_string(opts.rules_config_path).unwrap()).unwrap(),
-    );
+    let mut rules = Rules::from_xml(opts.grammar_path, tokenizer.tagger(), rules_options);
     rules.populate_cache(&common_words);
 
     let f = BufWriter::new(File::create(&opts.out_rules_path).unwrap());
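For context, the `Opts` struct itself is not part of this diff. Judging from the fields the code reads, it carries roughly the values in the sketch below; the field types are assumptions, and in the real binary the struct is derived with `clap::Clap` and filled from command-line arguments.

```rust
// Hypothetical sketch of the CLI options consumed above (not the actual definition).
#[allow(dead_code)]
struct Opts {
    tag_paths: Vec<String>,        // dictionary dumps fed to Tagger::from_dumps
    tag_remove_paths: Vec<String>, // dump entries to remove again
    tokenizer_config_path: String, // JSON deserialized into TokenizerOptions (now carries extra_tags)
    rules_config_path: String,     // JSON deserialized into RulesOptions
    disambiguation_path: String,   // disambiguation rule XML
    grammar_path: String,          // grammar rule XML
    out_tokenizer_path: String,    // bincode output path for the tokenizer
    out_rules_path: String,        // bincode output path for the rules
}
```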
1 change: 1 addition & 0 deletions nlprule/src/rule/disambiguation.rs
@@ -11,6 +11,7 @@ pub struct POSFilter {
 }
 
 impl POSFilter {
+    #[allow(dead_code)]
     pub fn new(matcher: PosMatcher) -> Self {
         POSFilter { matcher }
     }
3 changes: 1 addition & 2 deletions nlprule/src/rules.rs
@@ -1,7 +1,6 @@
 //! Sets of grammatical error correction rules.
 
 use crate::rule::{Cache, Rule};
-use crate::tokenizer::tag::Tagger;
 use crate::tokenizer::Tokenizer;
 use crate::types::*;
 use crate::utils::parallelism::MaybeParallelRefIterator;
@@ -48,7 +47,7 @@ impl Rules {
     #[cfg(feature = "compile")]
     pub fn from_xml<P: AsRef<std::path::Path>>(
         path: P,
-        tagger: &Tagger,
+        tagger: &crate::tokenizer::tag::Tagger,
         options: RulesOptions,
     ) -> Self {
         use log::warn;
4 changes: 4 additions & 0 deletions nlprule/src/tokenizer.rs
@@ -102,6 +102,9 @@ pub struct TokenizerOptions {
     /// Specific examples in the notation `{id}:{example_index}` which are known to fail.
     #[serde(default)]
     pub known_failures: Vec<String>,
+    /// Used part-of-speech tags which are not in the tagger dictionary.
+    #[serde(default)]
+    pub extra_tags: Vec<String>,
 }
 
 impl Default for TokenizerOptions {
@@ -114,6 +117,7 @@ impl Default for TokenizerOptions {
             ids: Vec::new(),
             ignore_ids: Vec::new(),
             known_failures: Vec::new(),
+            extra_tags: Vec::new(),
         }
     }
 }
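Because the new field is marked `#[serde(default)]`, tokenizer configs written before this commit (without an `extra_tags` key) still deserialize. A self-contained sketch with a trimmed-down stand-in for `TokenizerOptions` (assumes `serde` with the derive feature and `serde_json` as dependencies):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Options {
    #[serde(default)]
    ignore_ids: Vec<String>,
    // missing key -> empty vec, so old configs keep working
    #[serde(default)]
    extra_tags: Vec<String>,
}

fn main() {
    let new_config = r#"{ "ignore_ids": ["BEST_JJS"], "extra_tags": ["PCT", "ORD", "SYM", "RB_SENT"] }"#;
    let options: Options = serde_json::from_str(new_config).unwrap();
    assert_eq!(options.ignore_ids, vec!["BEST_JJS"]);
    assert_eq!(options.extra_tags, vec!["PCT", "ORD", "SYM", "RB_SENT"]);

    // a config predating this commit simply has no `extra_tags` key
    let old_config = r#"{ "ignore_ids": [] }"#;
    let options: Options = serde_json::from_str(old_config).unwrap();
    assert!(options.extra_tags.is_empty());
}
```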
9 changes: 2 additions & 7 deletions nlprule/src/tokenizer/chunk.rs
@@ -620,16 +620,11 @@ impl Chunker {
         }
 
         // the number is singular unless any token in the noun chunk has the part-of-speech tag `NNS` assigned
+        let nns_id = tokens[0].tagger.tag_to_id("NNS");
         if tokens
             .iter()
             .find(|token| token.char_span == char_span)
-            .map(|token| {
-                token
-                    .word
-                    .tags
-                    .iter()
-                    .any(|tag| tag.pos_id == token.tagger.tag_to_id("NNS"))
-            })
+            .map(|token| token.word.tags.iter().any(|tag| tag.pos_id == nns_id))
             .unwrap_or(false)
         {
             number = "plural";
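The change here is a small hoist: `tag_to_id("NNS")` used to run once per tag inside the inner closure and now runs once before the search. The same pattern, with a plain `HashMap` standing in for the tagger's tag table:

```rust
use std::collections::HashMap;

fn main() {
    // hypothetical tag -> id table standing in for the tagger used by the Chunker
    let tag_ids: HashMap<&str, u16> = HashMap::from([("NN", 0), ("NNS", 1), ("NNP", 2)]);
    let pos_ids: Vec<u16> = vec![0, 1, 2, 1];

    // before: one hash lookup per element inside the closure
    let plural_before = pos_ids.iter().any(|&id| id == tag_ids["NNS"]);

    // after: resolve the id once, then compare cheaply in the loop
    let nns_id = tag_ids["NNS"];
    let plural_after = pos_ids.iter().any(|&id| id == nns_id);

    assert_eq!(plural_before, plural_after);
}
```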
13 changes: 5 additions & 8 deletions nlprule/src/tokenizer/tag.rs
@@ -74,28 +74,25 @@ impl Tagger {
     /// separated by tabs, to be added to the tagger.
     /// * `remove_paths`: Paths to files where each line contains the word, lemma and tag, respectively,
     /// separated by tabs, to be removed from the tagger if present in the files from `paths`.
-    pub fn from_dumps<S1: AsRef<str>, S2: AsRef<str>>(
+    pub fn from_dumps<S1: AsRef<str>, S2: AsRef<str>, S3: AsRef<str>>(
         paths: &[S1],
         remove_paths: &[S2],
+        extra_tags: &[S3],
     ) -> std::io::Result<Self> {
         let mut tags = HashMap::new();
         let mut groups = HashMap::new();
 
         let mut tag_store = HashSet::new();
         let mut word_store = HashSet::new();
 
-        // special tags
+        // hardcoded special tags
         tag_store.insert("");
         tag_store.insert("SENT_START");
         tag_store.insert("SENT_END");
         tag_store.insert("UNKNOWN");
-        tag_store.insert("PCT");
-        tag_store.insert("ORD");
-        tag_store.insert("SYM");
-        tag_store.insert("RB_SENT");
 
-        tag_store.insert("PKT");
-        tag_store.insert("PRO:IND:DAT:SIN:NEU");
+        // add language specific special tags
+        tag_store.extend(extra_tags.iter().map(|x| x.as_ref()));
 
         let lines = Tagger::get_lines(paths, remove_paths)?;
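With this change the only hardcoded tags left in `from_dumps` are the language-independent ones; everything language specific comes in through `extra_tags` from the tokenizer config. A sketch of the tag-store construction pattern with stand-in inputs (the real `Tagger` also reads the dictionary dumps, and how it assigns ids is an assumption here):

```rust
use std::collections::{HashMap, HashSet};

fn build_tag_ids<S: AsRef<str>>(extra_tags: &[S], dump_tags: &[&str]) -> HashMap<String, u16> {
    let mut tag_store: HashSet<&str> = HashSet::new();

    // hardcoded special tags, as in the commit
    tag_store.extend(["", "SENT_START", "SENT_END", "UNKNOWN"]);
    // language specific special tags now come from the tokenizer config
    tag_store.extend(extra_tags.iter().map(|x| x.as_ref()));
    // tags seen in the dictionary dumps
    tag_store.extend(dump_tags.iter().copied());

    // freeze the store into a dense u16 id space (sorted here for determinism)
    let mut tags: Vec<&str> = tag_store.into_iter().collect();
    tags.sort_unstable();
    tags.into_iter()
        .enumerate()
        .map(|(id, tag)| (tag.to_string(), id as u16))
        .collect()
}

fn main() {
    let ids = build_tag_ids(&["PCT", "ORD", "SYM", "RB_SENT"], &["NN", "NNS"]);
    assert!(ids.contains_key("RB_SENT"));
    assert!(ids.contains_key("NNS"));
}
```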

