Skip to content

Commit

Permalink
Make mypy happy.
Browse files Browse the repository at this point in the history
  • Loading branch information
leobeeson committed Jun 1, 2023
1 parent e2f09e7 commit 2f85219
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.9
language_version: python3.10
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
Expand Down
28 changes: 17 additions & 11 deletions wordview/mwes/mwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pandas
import tqdm
from nltk import word_tokenize, RegexpParser
from nltk import RegexpParser, word_tokenize

from wordview import logger
from wordview.mwes.am import calculate_am
Expand Down Expand Up @@ -241,8 +241,8 @@ def extract_mwes_from_sent(self, tokens: list[str], mwe_type: str) -> Dict:
mwes_count_dic = Counter(mwes)
return mwes_count_dic

class HigherOrderMWEExtractor:

class HigherOrderMWEExtractor:
def __init__(self, tokens: list[str], pattern: str) -> None:
self.tokens = tokens
self.pattern = pattern
Expand All @@ -256,7 +256,7 @@ def _validate_input(self) -> None:
)
if len(self.tokens) == 0:
raise ValueError(
f'Input argument "tokens" must be a non-empty list of string.'
'Input argument "tokens" must be a non-empty list of string.'
)
if not isinstance(self.pattern, str):
raise TypeError(
Expand All @@ -265,7 +265,7 @@ def _validate_input(self) -> None:
)
if len(self.pattern) == 0:
raise ValueError(
f'Input argument "pattern" must be a non-zero length string.'
'Input argument "pattern" must be a non-zero length string.'
)

def extract_higher_order_mwes(self) -> dict:
Expand All @@ -277,7 +277,7 @@ def extract_higher_order_mwes(self) -> dict:
pattern (str): A string containing a user-defined pattern for nltk.RegexpParser.
Returns:
match_counter (dict[str, dict[str, int]]): A counter dictionary with count of matched strings, grouped by pattern label.
match_counter (dict[str, dict[str, int]]): A counter dictionary with count of matched strings, grouped by pattern label.
Each pattern label maps to an empty dict if no matches were found for it.
Examples of user-defined patterns:
Expand All @@ -296,19 +296,25 @@ def extract_higher_order_mwes(self) -> dict:
In this case, patterns of a clause are executed in order. An earlier
pattern may introduce a chunk boundary that prevents a later pattern from executing.
"""

tagged_tokens: list[tuple[str, str]] = get_pos_tags(self.tokens)
parser = RegexpParser(self.pattern)
parsed_tokens = parser.parse(tagged_tokens)

labels = [rule.split(":")[0].strip() for rule in self.pattern.split("\n") if rule]
labels: list[str] = [
rule.split(":")[0].strip() for rule in self.pattern.split("\n") if rule
]

matches = {label: [] for label in labels}
matches: dict[str, list[str]] = {label: [] for label in labels}

for subtree in parsed_tokens.subtrees():
label = subtree.label()
if label in matches:
matches[label].append(" ".join(word for (word, tag) in subtree.leaves()))
matches[label].append(
" ".join(word for (word, tag) in subtree.leaves())
)

matches = {label: dict(Counter(match_list)) for label, match_list in matches.items()}
return matches
matches_counter: dict[str, dict[str, int]] = {
label: dict(Counter(match_list)) for label, match_list in matches.items()
}
return matches_counter

0 comments on commit 2f85219

Please sign in to comment.