EduNLP/I2V/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@
# 2021/8/1 @ tongshiwei

from .i2v import I2V, get_pretrained_i2v
-from .i2v import D2V
+from .i2v import D2V, W2V
EduNLP/I2V/i2v.py (19 changes: 17 additions & 2 deletions)
@@ -7,7 +7,7 @@
from ..Tokenizer import Tokenizer, get_tokenizer
from EduNLP import logger

__all__ = ["I2V", "D2V", "get_pretrained_i2v"]
__all__ = ["I2V", "D2V", "W2V", "get_pretrained_i2v"]


class I2V(object):
@@ -89,18 +89,33 @@ class D2V(I2V):
    def infer_vector(self, items, tokenize=True, indexing=False, padding=False, key=lambda x: x, *args,
                     **kwargs) -> tuple:
        tokens = self.tokenize(items, return_token=True, key=key) if tokenize is True else items
+        tokens = [token for token in tokens]
        return self.t2v(tokens, *args, **kwargs), None

    @classmethod
    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
-        return cls("text", name, pretrained_t2v=True, model_dir=model_dir)
+        return cls("pure_text", name, pretrained_t2v=True, model_dir=model_dir)


+class W2V(I2V):
+    def infer_vector(self, items, tokenize=True, indexing=False, padding=False, key=lambda x: x, *args,
+                     **kwargs) -> tuple:
+        tokens = self.tokenize(items, return_token=True) if tokenize is True else items
+        tokens = [token for token in tokens]
+        return self.t2v(tokens, *args, **kwargs), self.t2v.infer_tokens(tokens, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+        return cls("pure_text", name, pretrained_t2v=True, model_dir=model_dir)


MODELS = {
"d2v_all_256": [D2V, "d2v_all_256"],
"d2v_sci_256": [D2V, "d2v_sci_256"],
"d2v_eng_256": [D2V, "d2v_eng_256"],
"d2v_lit_256": [D2V, "d2v_lit_256"],
"w2v_sci_300": [W2V, "w2v_sci_300"],
"w2v_lit_300": [W2V, "w2v_lit_300"],
}


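A minimal usage sketch of the new W2V wrapper (illustrative, not part of the diff): "w2v_lit_300" is one of the names registered in MODELS above, while the model directory is an assumption and the sample item is taken from the tokenizer docstrings changed later in this diff.

```python
from EduNLP.I2V import W2V

# Load a pretrained word2vec-backed model registered above; the files are
# fetched into model_dir on first use (the directory name is an assumption).
i2v = W2V.from_pretrained("w2v_lit_300", model_dir="./model_zoo")

items = ["已知集合$A=\\left\\{x \\mid x^{2}-3 x-4<0\\right\\}, \\quad B=\\{-4,1,3,5\\}, \\quad$ 则 $A \\cap B=$"]
# Unlike D2V, which returns (item_vectors, None), W2V.infer_vector also returns
# per-token vectors via t2v.infer_tokens.
item_vectors, token_vectors = i2v.infer_vector(items)
```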
EduNLP/SIF/tokenization/tokenization.py (15 changes: 11 additions & 4 deletions)
@@ -38,10 +38,14 @@ def __init__(self, segment_list: SegmentList, text_params=None, formula_params=N
"s": []
}
self.text_params = text_params if text_params is not None else {}
if formula_params is not None and "symbolize_figure_formula" in formula_params:
self.symbolize_figure_formula = formula_params.pop("symbolize_figure_formula")
else:
self.symbolize_figure_formula = False
self.symbolize_figure_formula = False
self.skip_figure_formula = False
if formula_params is not None:
if "symbolize_figure_formula" in formula_params:
self.symbolize_figure_formula = formula_params.pop("symbolize_figure_formula")
if "skip_figure_formula" in formula_params:
self.skip_figure_formula = formula_params.pop("skip_figure_formula")

self.formula_params = formula_params if formula_params is not None else {"method": "linear"}
self.formula_tokenize_method = self.formula_params.get("method")
self.figure_params = figure_params if figure_params is not None else {}
@@ -175,6 +179,9 @@ def append_formula(self, segment, symbol=False, init=True):
        if symbol is True:
            self._formula_tokens.append(len(self._tokens))
            self._tokens.append(segment)
+        elif self.skip_figure_formula and isinstance(segment, FigureFormulaSegment):
+            # skip the FigureFormulaSegment
+            pass
        elif self.symbolize_figure_formula and isinstance(segment, FigureFormulaSegment):
            self._formula_tokens.append(len(self._tokens))
            self._tokens.append(Symbol(FORMULA_SYMBOL))
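To make the two flags concrete, a short sketch of how they differ (the item is illustrative; seg, tokenize, and the "gmas" symbol set are the same calls the tokenizers in EduNLP/Tokenizer/tokenizer.py make):

```python
from EduNLP.SIF.segment import seg
from EduNLP.SIF.tokenization import tokenize

item = "有公式$\\FormFigureID{wrong1?}$,如图$\\FigureID{088f15ea-xxx}$,则$z=x+7 y$的最大值为$\\SIFBlank$"

# skip_figure_formula=True: figure-formula segments are dropped from the token stream.
skipped = tokenize(seg(item, symbol="gmas"),
                   formula_params={"method": "linear", "skip_figure_formula": True}).tokens
# symbolize_figure_formula=True: they are kept, but replaced by the [FORMULA] symbol.
symbolized = tokenize(seg(item, symbol="gmas"),
                      formula_params={"method": "linear", "symbolize_figure_formula": True}).tokens
```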
EduNLP/Tokenizer/tokenizer.py (39 changes: 36 additions & 3 deletions)
@@ -5,21 +5,26 @@
from ..SIF.segment import seg
from ..SIF.tokenization import tokenize

__all__ = ["TOKENIZER", "Tokenizer", "TextTokenizer", "get_tokenizer"]
__all__ = ["TOKENIZER", "Tokenizer", "PureTextTokenizer", "TextTokenizer", "get_tokenizer"]


class Tokenizer(object):
    def __call__(self, *args, **kwargs):
        raise NotImplementedError


-class TextTokenizer(Tokenizer):
+class PureTextTokenizer(Tokenizer):
r"""

Examples
--------
>>> tokenizer = PureTextTokenizer()
>>> items = ["有公式$\\FormFigureID{wrong1?}$,如图$\\FigureID{088f15ea-xxx}$,\
... 若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$"]
>>> tokens = tokenizer(items)
>>> next(tokens)[:10]
['公式', '如图', '[FIGURE]', 'x', ',', 'y', '约束条件', '公式', '[SEP]', 'z']
>>> items = ["已知集合$A=\\left\\{x \\mid x^{2}-3 x-4<0\\right\\}, \\quad B=\\{-4,1,3,5\\}, \\quad$ 则 $A \\cap B=$"]
>>> tokenizer = TextTokenizer()
>>> tokens = tokenizer(items)
>>> next(tokens) # doctest: +NORMALIZE_WHITESPACE
['已知', '集合', 'A', '=', '\\left', '\\{', 'x', '\\mid', 'x', '^', '{', '2', '}', '-', '3', 'x', '-', '4', '<',
@@ -40,6 +45,33 @@ def __init__(self, *args, **kwargs):
        self.tokenization_params = {
            "formula_params": {
                "method": "linear",
+                "skip_figure_formula": True
            }
        }

+    def __call__(self, items: Iterable, key=lambda x: x, *args, **kwargs):
+        for item in items:
+            yield tokenize(seg(key(item), symbol="gmas"), **self.tokenization_params).tokens


+class TextTokenizer(Tokenizer):
+    r"""
+
+    Examples
+    ----------
+    >>> tokenizer = TextTokenizer()
+    >>> items = ["有公式$\\FormFigureID{wrong1?}$,如图$\\FigureID{088f15ea-xxx}$,\
+    ... 若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$"]
+    >>> tokens = tokenizer(items)
+    >>> next(tokens)[:10]
+    ['公式', '[FORMULA]', '如图', '[FIGURE]', 'x', ',', 'y', '约束条件', '公式', '[FORMULA]']
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.tokenization_params = {
+            "formula_params": {
+                "method": "linear",
+                "symbolize_figure_formula": True
+            }
+        }

Expand All @@ -49,6 +81,7 @@ def __call__(self, items: Iterable, key=lambda x: x, *args, **kwargs):


TOKENIZER = {
"pure_text": PureTextTokenizer,
"text": TextTokenizer
}

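The new "pure_text" key is the tokenizer that D2V.from_pretrained and W2V.from_pretrained (EduNLP/I2V/i2v.py above) now request. A sketch, assuming get_tokenizer resolves names through this TOKENIZER registry (its definition is outside this diff); the item reuses the docstring example:

```python
from EduNLP.Tokenizer import get_tokenizer

items = ["有公式$\\FormFigureID{wrong1?}$,如图$\\FigureID{088f15ea-xxx}$,"
         "若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$"]

pure_text_tokens = next(get_tokenizer("pure_text")(items))  # figure formulas dropped
text_tokens = next(get_tokenizer("text")(items))            # figure formulas become [FORMULA]
```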
EduNLP/Vector/t2v.py (13 changes: 11 additions & 2 deletions)
@@ -31,6 +31,12 @@ def __init__(self, model: str, *args, **kwargs):
    def __call__(self, items, *args, **kwargs):
        return self.i2v.infer_vector(items, *args, **kwargs)

+    def infer_vector(self, items, *args, **kwargs):
+        return self.i2v.infer_vector(items, *args, **kwargs)
+
+    def infer_tokens(self, items, *args, **kwargs):
+        return self.i2v.infer_tokens(items, *args, **kwargs)

    @property
    def vector_size(self) -> int:
        return self.i2v.vector_size
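The two new methods expose both granularities on the T2V wrapper; W2V.infer_vector in EduNLP/I2V/i2v.py relies on infer_tokens. A sketch, assuming a local word2vec file (the path and the token lists are illustrative):

```python
from EduNLP.Vector.t2v import T2V

token_items = [["公式", "约束条件", "最大值"]]        # already-tokenized items
t2v = T2V("w2v", "path/to/general_literal_300.kv")  # hypothetical local path

item_vectors = t2v.infer_vector(token_items)   # same as t2v(token_items): one vector per item
token_vectors = t2v.infer_tokens(token_items)  # one vector per token
```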
@@ -41,6 +47,8 @@ def vector_size(self) -> int:
"d2v_sci_256": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/general_science_256.zip", "d2v"],
"d2v_eng_256": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/general_english_256.zip", "d2v"],
"d2v_lit_256": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/general_literal_256.zip", "d2v"],
"w2v_eng_300": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/w2v/general_english_300.zip", "w2v"],
"w2v_lit_300": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/w2v/general_literal_300.zip", "w2v"],
}


@@ -52,6 +60,7 @@ def get_pretrained_t2v(name, model_dir=MODEL_DIR):
        )
    url, model_name, *args = PRETRAINED_MODELS[name]
    model_path = get_data(url, model_dir)
-    if model_name in ["d2v"]:
-        model_path = path_append(model_path, os.path.basename(model_path) + ".bin", to_str=True)
+    if model_name in ["d2v", "w2v"]:
+        postfix = ".bin" if model_name == "d2v" else ".kv"
+        model_path = path_append(model_path, os.path.basename(model_path) + postfix, to_str=True)
    return T2V(model_name, model_path, *args)
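The effect of the new suffix logic, sketched with one of the w2v entries above (the cache directory is an assumption; the zip is downloaded on first call):

```python
from EduNLP.Vector.t2v import get_pretrained_t2v

# "w2v_lit_300" maps to .../w2v/general_literal_300.zip; after get_data unpacks it,
# the loader appends ".kv" for w2v models (d2v keeps ".bin") before building T2V("w2v", ...).
t2v = get_pretrained_t2v("w2v_lit_300", model_dir="./model_zoo")
print(t2v.vector_size)
```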