diff --git a/EduNLP/I2V/__init__.py b/EduNLP/I2V/__init__.py
index 7735252d..db5cb958 100644
--- a/EduNLP/I2V/__init__.py
+++ b/EduNLP/I2V/__init__.py
@@ -2,4 +2,4 @@
# 2021/8/1 @ tongshiwei
from .i2v import I2V, get_pretrained_i2v
-from .i2v import D2V
+from .i2v import D2V, W2V
diff --git a/EduNLP/I2V/i2v.py b/EduNLP/I2V/i2v.py
index c9975376..4254fe1e 100644
--- a/EduNLP/I2V/i2v.py
+++ b/EduNLP/I2V/i2v.py
@@ -7,7 +7,7 @@
from ..Tokenizer import Tokenizer, get_tokenizer
from EduNLP import logger
-__all__ = ["I2V", "D2V", "get_pretrained_i2v"]
+__all__ = ["I2V", "D2V", "W2V", "get_pretrained_i2v"]
class I2V(object):
@@ -89,11 +89,24 @@ class D2V(I2V):
def infer_vector(self, items, tokenize=True, indexing=False, padding=False, key=lambda x: x, *args,
**kwargs) -> tuple:
tokens = self.tokenize(items, return_token=True, key=key) if tokenize is True else items
+ tokens = [token for token in tokens]
return self.t2v(tokens, *args, **kwargs), None
@classmethod
def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
- return cls("text", name, pretrained_t2v=True, model_dir=model_dir)
+ return cls("pure_text", name, pretrained_t2v=True, model_dir=model_dir)
+
+
+class W2V(I2V):
+    """I2V variant whose ``infer_vector`` returns two results from ``self.t2v``.
+
+    The first element is the value of calling ``self.t2v`` on the token lists, the second is
+    the value of ``self.t2v.infer_tokens`` on the same token lists.
+    """
+
+    def infer_vector(self, items, tokenize=True, indexing=False, padding=False, key=lambda x: x, *args,
+                     **kwargs) -> tuple:
+        """Vectorize ``items``.
+
+        Parameters
+        ----------
+        items
+            Raw items when ``tokenize`` is True, otherwise items that are already tokenized.
+        tokenize
+            Only when this is exactly ``True`` is ``self.tokenize`` applied to ``items``.
+        indexing, padding
+            Accepted for signature compatibility; not used by this method.
+        key
+            Forwarded to ``self.tokenize``, as ``D2V.infer_vector`` does.
+        *args, **kwargs
+            Forwarded to both ``self.t2v`` and ``self.t2v.infer_tokens``.
+
+        Returns
+        -------
+        tuple
+            ``(self.t2v(tokens, ...), self.t2v.infer_tokens(tokens, ...))``.
+        """
+        # Forward ``key``: without it a caller-supplied accessor would be silently ignored.
+        tokens = self.tokenize(items, return_token=True, key=key) if tokenize is True else items
+        # Materialize once; the tokens are consumed twice below and may come from a one-shot iterable.
+        tokens = list(tokens)
+        return self.t2v(tokens, *args, **kwargs), self.t2v.infer_tokens(tokens, *args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+        """Construct the class for pretrained model ``name`` stored under ``model_dir``.
+
+        ``"pure_text"`` is passed as the first constructor argument (presumably the tokenizer
+        name -- confirm against ``I2V.__init__``). ``*args`` and ``**kwargs`` are accepted but
+        not forwarded.
+        """
+        return cls("pure_text", name, pretrained_t2v=True, model_dir=model_dir)
MODELS = {
@@ -101,6 +114,8 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
"d2v_sci_256": [D2V, "d2v_sci_256"],
"d2v_eng_256": [D2V, "d2v_eng_256"],
"d2v_lit_256": [D2V, "d2v_lit_256"],
+    "w2v_eng_300": [W2V, "w2v_eng_300"],  # w2v names mirror PRETRAINED_MODELS in Vector/t2v.py
+    "w2v_lit_300": [W2V, "w2v_lit_300"],
}
diff --git a/EduNLP/SIF/tokenization/tokenization.py b/EduNLP/SIF/tokenization/tokenization.py
index 33085cbf..299eaf62 100644
--- a/EduNLP/SIF/tokenization/tokenization.py
+++ b/EduNLP/SIF/tokenization/tokenization.py
@@ -38,10 +38,14 @@ def __init__(self, segment_list: SegmentList, text_params=None, formula_params=N
"s": []
}
self.text_params = text_params if text_params is not None else {}
- if formula_params is not None and "symbolize_figure_formula" in formula_params:
- self.symbolize_figure_formula = formula_params.pop("symbolize_figure_formula")
- else:
- self.symbolize_figure_formula = False
+        self.symbolize_figure_formula = False
+        self.skip_figure_formula = False
+        if formula_params is not None:
+            # Pop the flags from a shallow copy so a params dict reused across calls keeps them.
+            formula_params = dict(formula_params)
+            self.symbolize_figure_formula = formula_params.pop("symbolize_figure_formula", False)
+            self.skip_figure_formula = formula_params.pop("skip_figure_formula", False)
+
self.formula_params = formula_params if formula_params is not None else {"method": "linear"}
self.formula_tokenize_method = self.formula_params.get("method")
self.figure_params = figure_params if figure_params is not None else {}
@@ -175,6 +179,9 @@ def append_formula(self, segment, symbol=False, init=True):
if symbol is True:
self._formula_tokens.append(len(self._tokens))
self._tokens.append(segment)
+ elif self.skip_figure_formula and isinstance(segment, FigureFormulaSegment):
+ # skip the FigureFormulaSegment
+ pass
elif self.symbolize_figure_formula and isinstance(segment, FigureFormulaSegment):
self._formula_tokens.append(len(self._tokens))
self._tokens.append(Symbol(FORMULA_SYMBOL))
diff --git a/EduNLP/Tokenizer/tokenizer.py b/EduNLP/Tokenizer/tokenizer.py
index 33f4d05b..bb7b47e5 100644
--- a/EduNLP/Tokenizer/tokenizer.py
+++ b/EduNLP/Tokenizer/tokenizer.py
@@ -5,7 +5,7 @@
from ..SIF.segment import seg
from ..SIF.tokenization import tokenize
-__all__ = ["TOKENIZER", "Tokenizer", "TextTokenizer", "get_tokenizer"]
+__all__ = ["TOKENIZER", "Tokenizer", "PureTextTokenizer", "TextTokenizer", "get_tokenizer"]
class Tokenizer(object):
@@ -13,13 +13,18 @@ def __call__(self, *args, **kwargs):
raise NotImplementedError
-class TextTokenizer(Tokenizer):
+class PureTextTokenizer(Tokenizer):
r"""
Examples
--------
+ >>> tokenizer = PureTextTokenizer()
+ >>> items = ["有公式$\\FormFigureID{wrong1?}$,如图$\\FigureID{088f15ea-xxx}$,\
+ ... 若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$"]
+ >>> tokens = tokenizer(items)
+ >>> next(tokens)[:10]
+ ['公式', '如图', '[FIGURE]', 'x', ',', 'y', '约束条件', '公式', '[SEP]', 'z']
>>> items = ["已知集合$A=\\left\\{x \\mid x^{2}-3 x-4<0\\right\\}, \\quad B=\\{-4,1,3,5\\}, \\quad$ 则 $A \\cap B=$"]
- >>> tokenizer = TextTokenizer()
>>> tokens = tokenizer(items)
>>> next(tokens) # doctest: +NORMALIZE_WHITESPACE
['已知', '集合', 'A', '=', '\\left', '\\{', 'x', '\\mid', 'x', '^', '{', '2', '}', '-', '3', 'x', '-', '4', '<',
@@ -40,6 +45,33 @@ def __init__(self, *args, **kwargs):
self.tokenization_params = {
"formula_params": {
"method": "linear",
+ "skip_figure_formula": True
+ }
+ }
+
+    def __call__(self, items: Iterable, key=lambda x: x, *args, **kwargs):
+        """Lazily yield one token sequence per element of ``items``.
+
+        Each element is passed through ``key``, segmented with ``seg(..., symbol="gmas")`` and
+        tokenized with ``self.tokenization_params``; the ``tokens`` attribute of the result is
+        yielded. Extra ``*args`` / ``**kwargs`` are accepted but unused.
+        """
+        for raw_item in items:
+            segments = seg(key(raw_item), symbol="gmas")
+            token_list = tokenize(segments, **self.tokenization_params)
+            yield token_list.tokens
+
+
+class TextTokenizer(Tokenizer):
+ r"""
+
+ Examples
+ ----------
+ >>> tokenizer = TextTokenizer()
+ >>> items = ["有公式$\\FormFigureID{wrong1?}$,如图$\\FigureID{088f15ea-xxx}$,\
+ ... 若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$,则$z=x+7 y$的最大值为$\\SIFBlank$"]
+ >>> tokens = tokenizer(items)
+ >>> next(tokens)[:10]
+ ['公式', '[FORMULA]', '如图', '[FIGURE]', 'x', ',', 'y', '约束条件', '公式', '[FORMULA]']
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.tokenization_params = {
+ "formula_params": {
+ "method": "linear",
+ "symbolize_figure_formula": True
}
}
@@ -49,6 +81,7 @@ def __call__(self, items: Iterable, key=lambda x: x, *args, **kwargs):
TOKENIZER = {
+ "pure_text": PureTextTokenizer,
"text": TextTokenizer
}
diff --git a/EduNLP/Vector/t2v.py b/EduNLP/Vector/t2v.py
index 103fd3da..ec0887ef 100644
--- a/EduNLP/Vector/t2v.py
+++ b/EduNLP/Vector/t2v.py
@@ -31,6 +31,12 @@ def __init__(self, model: str, *args, **kwargs):
def __call__(self, items, *args, **kwargs):
return self.i2v.infer_vector(items, *args, **kwargs)
+    def infer_vector(self, items, *args, **kwargs):
+        """Return ``self.i2v.infer_vector(items, *args, **kwargs)`` unchanged."""
+        backend = self.i2v
+        return backend.infer_vector(items, *args, **kwargs)
+
+    def infer_tokens(self, items, *args, **kwargs):
+        """Return ``self.i2v.infer_tokens(items, *args, **kwargs)`` unchanged."""
+        backend = self.i2v
+        return backend.infer_tokens(items, *args, **kwargs)
+
@property
def vector_size(self) -> int:
return self.i2v.vector_size
@@ -41,6 +47,8 @@ def vector_size(self) -> int:
"d2v_sci_256": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/general_science_256.zip", "d2v"],
"d2v_eng_256": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/general_english_256.zip", "d2v"],
"d2v_lit_256": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/general_literal_256.zip", "d2v"],
+ "w2v_eng_300": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/w2v/general_english_300.zip", "w2v"],
+ "w2v_lit_300": ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/w2v/general_literal_300.zip", "w2v"],
}
@@ -52,6 +60,7 @@ def get_pretrained_t2v(name, model_dir=MODEL_DIR):
)
url, model_name, *args = PRETRAINED_MODELS[name]
model_path = get_data(url, model_dir)
- if model_name in ["d2v"]:
- model_path = path_append(model_path, os.path.basename(model_path) + ".bin", to_str=True)
+ if model_name in ["d2v", "w2v"]:
+ postfix = ".bin" if model_name == "d2v" else ".kv"
+ model_path = path_append(model_path, os.path.basename(model_path) + postfix, to_str=True)
return T2V(model_name, model_path, *args)
diff --git a/examples/i2v/get_pretrained_i2v_d2v_w2v.ipynb b/examples/i2v/get_pretrained_i2v_d2v_w2v.ipynb
new file mode 100644
index 00000000..11ee2705
--- /dev/null
+++ b/examples/i2v/get_pretrained_i2v_d2v_w2v.ipynb
@@ -0,0 +1,599 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a048a5ce",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "D:\\MySoftwares\\Anaconda\\envs\\data\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
+ " warnings.warn(msg)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from EduNLP import get_pretrained_i2v\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8aba69ce",
+ "metadata": {},
+ "source": [
+ "# 通过i2v从模型库中下载w2v模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "5ea68229",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EduNLP, INFO Use pretrained t2v model w2v_lit_300\n",
+ "downloader, INFO http://base.ustc.edu.cn/data/model_zoo/EduNLP/w2v/general_literal_300.zip is saved as ..\\..\\data\\w2v\\general_literal_300.zip\n",
+ "downloader, INFO file existed, skipped\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "item_vector : \n",
+ " (1, 300) [[ 2.61620767e-02 1.09007880e-01 -1.00791618e-01 1.27770424e-01\n",
+ " 1.67004377e-01 -6.44749403e-03 -1.52755976e-01 -4.27449286e-01\n",
+ " 2.11064816e-01 -1.04606055e-01 -1.99688882e-01 1.73974875e-02\n",
+ " 6.59792274e-02 1.11049023e-02 2.03040883e-01 2.70116031e-01\n",
+ " 1.63561106e-01 -1.74831152e-01 5.22611514e-02 2.82236159e-01\n",
+ " 1.29357621e-01 2.07482532e-01 -1.18266962e-01 2.01199263e-01\n",
+ " -1.38246596e-01 -1.75465241e-01 6.99117854e-02 -2.66740382e-01\n",
+ " 1.02460384e-01 -3.31712782e-01 1.85757071e-01 -1.71908706e-01\n",
+ " 1.48611993e-01 2.97288179e-01 -1.17031179e-01 2.51475275e-01\n",
+ " 6.24234136e-03 1.33234814e-01 -2.94128358e-01 1.02684975e-01\n",
+ " 4.81339619e-02 1.33045152e-01 -1.90128967e-01 -1.55469418e-01\n",
+ " 6.80758357e-02 -4.08965945e-02 2.50210136e-01 1.73031300e-01\n",
+ " -7.00023025e-02 1.28045375e-03 -1.74963474e-01 3.89595836e-01\n",
+ " -4.13189828e-02 6.89548859e-03 -1.33688137e-01 -1.75952073e-02\n",
+ " -3.76380421e-02 1.67898625e-01 -1.04088224e-01 -1.10742941e-01\n",
+ " -1.17674254e-01 -2.07567930e-01 -1.81596592e-01 -2.53004730e-01\n",
+ " -1.97571829e-01 -4.96768355e-02 -2.56916851e-01 -2.02153012e-01\n",
+ " 4.78913775e-03 2.37575963e-01 1.90504581e-01 -2.35189229e-01\n",
+ " -3.92443053e-02 1.87763408e-01 1.11054853e-01 -3.68193090e-01\n",
+ " 7.92984962e-02 1.95566360e-02 -4.63362485e-02 3.17584276e-02\n",
+ " 5.38006723e-02 1.15047775e-01 -8.04764479e-02 -9.76808183e-03\n",
+ " 1.05274422e-02 1.44821331e-02 -5.76615818e-02 1.13344096e-01\n",
+ " 1.80485293e-01 2.50715613e-01 1.79755121e-01 1.59545705e-01\n",
+ " 6.87500462e-02 -3.99000123e-02 -4.68821228e-01 1.90297976e-01\n",
+ " -9.93505865e-02 1.17708622e-02 -1.77581891e-01 6.20286576e-02\n",
+ " 8.83746371e-02 7.68374279e-02 -5.02046868e-02 5.31805046e-02\n",
+ " 6.29951209e-02 2.76096761e-02 7.76083618e-02 -1.37116909e-01\n",
+ " 2.75426358e-02 7.84613043e-02 -5.05000539e-02 -2.76279032e-01\n",
+ " 2.02276036e-01 -3.94708999e-02 1.65654942e-01 2.06172436e-01\n",
+ " -8.61438364e-02 4.55883682e-01 3.67088914e-02 7.92133063e-02\n",
+ " -1.26422614e-01 1.97405070e-01 2.18987226e-01 -1.36595652e-01\n",
+ " -3.40153992e-01 1.33978128e-01 -7.38921538e-02 -1.53101951e-01\n",
+ " -7.67037272e-02 -3.98226939e-02 -3.23286623e-01 -1.23587668e-01\n",
+ " 1.55313849e-01 3.70106027e-02 8.56588557e-02 3.60074081e-02\n",
+ " -2.06803635e-01 3.22534516e-02 2.93804020e-01 -3.61453265e-01\n",
+ " 2.71856822e-02 -3.70388150e-01 -1.68041319e-01 2.06136346e-01\n",
+ " 1.71704680e-01 1.76861808e-01 -1.38827398e-01 8.31167921e-02\n",
+ " -2.24039908e-02 1.14384510e-01 -8.55618566e-02 1.70442518e-02\n",
+ " -9.18549523e-02 2.31103808e-01 -1.71056390e-02 -3.33124027e-02\n",
+ " 1.27449661e-01 2.02255957e-02 3.37275257e-03 -2.43208960e-01\n",
+ " -3.07017386e-01 -7.37199932e-02 -1.02498792e-01 -9.41229537e-02\n",
+ " 3.78994077e-01 2.09260024e-02 7.92986900e-02 -3.67502600e-01\n",
+ " 1.87170021e-02 8.99736732e-02 6.94804117e-02 1.06999993e-01\n",
+ " -2.30305761e-01 -5.22091016e-02 -1.93038240e-01 3.37167606e-02\n",
+ " 8.46201032e-02 1.61094978e-01 -6.32047281e-02 -8.15278366e-02\n",
+ " 9.86129791e-02 -2.81494051e-01 -8.01050290e-02 -2.24916548e-01\n",
+ " -1.95268039e-02 5.38655296e-02 -3.17246258e-01 2.84989655e-01\n",
+ " -1.30800530e-01 -8.36309046e-05 -9.80387777e-02 3.64487469e-02\n",
+ " -1.36125535e-01 -7.85720721e-02 -1.11089824e-02 1.87894493e-01\n",
+ " -5.74574396e-02 5.71854822e-02 -1.58534497e-01 2.22718008e-02\n",
+ " 3.85424078e-01 -1.54734880e-01 2.32275367e-01 2.26884276e-01\n",
+ " 1.19488701e-01 -6.18596673e-02 1.13015689e-01 -2.62939390e-02\n",
+ " 2.72549391e-01 -2.44262803e-04 4.14920673e-02 -3.60048413e-01\n",
+ " 1.93277806e-01 1.84728224e-02 -1.26960844e-01 -8.71551409e-02\n",
+ " -4.52413782e-02 -1.83736868e-02 2.14002077e-02 -1.35971278e-01\n",
+ " 1.20486699e-01 1.55642152e-01 -2.24970624e-01 2.02683121e-01\n",
+ " -1.49234146e-01 3.20993572e-01 -1.12141436e-02 2.65024126e-01\n",
+ " -3.61632854e-01 2.91804850e-01 -1.99684903e-01 -9.88604203e-02\n",
+ " -6.80872947e-02 3.27597111e-02 1.66607589e-01 -1.57291532e-01\n",
+ " 6.91130608e-02 1.52290225e-01 3.43445912e-02 2.35028073e-01\n",
+ " 7.02244192e-02 -8.61506015e-02 3.41188818e-01 -2.52313670e-02\n",
+ " -1.81033477e-01 -4.42802086e-02 -2.48439307e-03 8.45776424e-02\n",
+ " 1.57222867e-01 -2.34099865e-01 1.50260121e-01 1.66123807e-02\n",
+ " 1.35237128e-01 4.52173725e-02 1.43081829e-01 -5.87690845e-02\n",
+ " 2.47190185e-02 3.74217749e-01 -6.21478930e-02 -1.62466895e-02\n",
+ " -1.51219010e-01 2.54060566e-01 1.41428277e-01 2.29582638e-01\n",
+ " -1.97292000e-01 1.90529935e-02 -7.42673278e-02 1.91508979e-01\n",
+ " 1.96774215e-01 -5.79532012e-02 -7.21698180e-02 -8.17084908e-02\n",
+ " 1.19057134e-01 -1.53076604e-01 1.11696444e-01 1.01398312e-01\n",
+ " 2.75634527e-01 4.27773222e-02 -3.31790358e-01 2.73687184e-01\n",
+ " 3.32023646e-03 -3.18278879e-01 -1.16946623e-01 1.01567604e-01\n",
+ " 1.61648035e-01 -4.60138805e-02 -6.93662390e-02 4.34832335e-01\n",
+ " -1.48224920e-01 -2.61741936e-01 3.49430107e-02 6.77287430e-02\n",
+ " 3.37100588e-02 2.77218409e-02 6.84884787e-02 -5.99829368e-02\n",
+ " 2.25368775e-02 -1.97441965e-01 -2.54651662e-02 2.84060061e-01]]\n",
+ "token_vector : \n",
+ " (1, 5, 300) [[memmap([ 2.96194404e-01, 1.00701705e-01, 1.62643775e-01,\n",
+ " 2.18375280e-01, 2.50812080e-02, 6.36205375e-02,\n",
+ " 1.29178956e-01, -3.73890221e-01, 3.24861646e-01,\n",
+ " -1.05276950e-01, -7.28812099e-01, -4.62196946e-01,\n",
+ " 1.91645876e-01, -6.72883034e-01, -3.12813371e-02,\n",
+ " 2.02840507e-01, 3.12911361e-01, -5.17836250e-02,\n",
+ " -8.50540847e-02, 5.10810733e-01, 9.77203175e-02,\n",
+ " -3.49134654e-02, -6.01979971e-01, 1.56438991e-01,\n",
+ " -4.25550461e-01, -5.04473131e-03, 5.26125729e-01,\n",
+ " -3.64249855e-01, -4.04438749e-02, -3.40956479e-01,\n",
+ " 1.94766652e-02, 2.60560811e-01, 4.53176767e-01,\n",
+ " 1.37372896e-01, 5.91255911e-02, 5.35910010e-01,\n",
+ " -5.05340874e-01, 2.36500159e-01, -3.59221488e-01,\n",
+ " 6.09920084e-01, 6.38440698e-02, 7.62094930e-02,\n",
+ " -9.64345783e-02, 2.99222887e-01, -4.47020829e-02,\n",
+ " -6.57168388e-01, 7.97871426e-02, 3.06699723e-02,\n",
+ " -3.15478623e-01, -3.21717054e-01, -3.79555196e-01,\n",
+ " 1.61668345e-01, -1.41245142e-01, -3.45727175e-01,\n",
+ " -9.20769870e-02, -2.64545262e-01, -8.11353512e-03,\n",
+ " 3.95217657e-01, -1.12047307e-01, 6.07191026e-02,\n",
+ " -1.71708107e-01, 1.44910008e-01, 1.66939795e-02,\n",
+ " 1.12748951e-01, -5.63685410e-02, -1.75804406e-01,\n",
+ " -2.77096987e-01, -4.07692373e-01, -1.77572191e-01,\n",
+ " 4.71544087e-01, 5.98176539e-01, -3.49848896e-01,\n",
+ " 1.25177085e-01, 1.55241042e-01, 1.66931391e-01,\n",
+ " -5.74054539e-01, 2.66740501e-01, -2.51754194e-01,\n",
+ " 1.58882499e-01, 5.84154308e-01, -3.20607632e-01,\n",
+ " 1.28604308e-01, 2.90958554e-01, -4.55124304e-02,\n",
+ " 2.96064973e-01, -4.07742590e-01, 4.90606546e-01,\n",
+ " 1.05036154e-01, 4.38867778e-01, 2.15489373e-01,\n",
+ " -8.53411183e-02, -1.10689588e-01, -2.16460526e-01,\n",
+ " -1.04986355e-01, -5.63118100e-01, 5.05482435e-01,\n",
+ " -5.54384172e-01, -1.60113692e-01, -4.50889975e-01,\n",
+ " 7.76149407e-02, 1.64326075e-02, 5.44685945e-02,\n",
+ " 3.42964113e-01, 7.07274303e-03, 3.77554417e-01,\n",
+ " 5.83334684e-01, 4.92167659e-02, -6.91621304e-01,\n",
+ " 1.40255168e-01, 2.59836018e-01, -3.71690452e-01,\n",
+ " -7.14120984e-01, -7.75761083e-02, -3.66082728e-01,\n",
+ " 1.06013134e-01, 4.81697828e-01, 3.06762140e-02,\n",
+ " 4.75153327e-01, 8.58582705e-02, 6.68947041e-01,\n",
+ " 1.12745710e-01, 7.29243636e-01, -3.80197376e-01,\n",
+ " -3.56931314e-02, -6.32867157e-01, -1.92625262e-02,\n",
+ " -1.64794222e-01, -3.29848707e-01, -5.99450730e-02,\n",
+ " -4.63650785e-02, 9.67195779e-02, -3.68604630e-01,\n",
+ " -4.51609157e-02, -3.14569265e-01, -2.84906272e-02,\n",
+ " -2.35163167e-01, -1.50432557e-01, 3.86221521e-02,\n",
+ " 2.78316617e-01, -4.60250676e-01, -3.87477517e-01,\n",
+ " -9.15909886e-01, -2.33895734e-01, 5.26256382e-01,\n",
+ " 6.34461343e-01, 6.92536831e-01, -2.79118001e-01,\n",
+ " 3.32599223e-01, -2.89448529e-01, 5.31987101e-02,\n",
+ " 1.11375339e-01, -4.38545674e-01, -3.26774865e-01,\n",
+ " 6.49300516e-02, 1.30401582e-01, 3.44999492e-01,\n",
+ " 3.68947685e-01, 2.34351233e-01, -1.06941594e-03,\n",
+ " 1.15981139e-01, -9.17514637e-02, -5.27929783e-01,\n",
+ " -5.06562227e-03, -7.60384127e-02, 1.47737056e-01,\n",
+ " -4.76852991e-02, 4.07397181e-01, -5.29512823e-01,\n",
+ " 7.61286169e-02, -1.10785462e-01, -2.00572740e-02,\n",
+ " -2.71321267e-01, -3.07893217e-01, -2.52417505e-01,\n",
+ " 5.85605055e-02, 8.53852481e-02, 3.23953509e-01,\n",
+ " -2.87146736e-02, 1.88025355e-01, -2.74873108e-01,\n",
+ " 6.14540577e-02, -1.73081174e-01, -2.23378509e-01,\n",
+ " -3.75009328e-01, 1.78364992e-01, -6.29342556e-01,\n",
+ " -3.91042441e-01, 7.45088905e-02, -5.45797467e-01,\n",
+ " 2.20331490e-01, 7.26294070e-02, 3.76544744e-01,\n",
+ " -2.94531047e-01, -2.82253653e-01, -5.42173861e-03,\n",
+ " 2.48213515e-01, 3.44242826e-02, -1.11324355e-01,\n",
+ " -1.02989756e-01, 2.07821682e-01, 3.42748880e-01,\n",
+ " -2.75986165e-01, 2.17512369e-01, 7.80139387e-01,\n",
+ " 2.08300203e-02, -5.74298143e-01, 1.93183735e-01,\n",
+ " 4.98499572e-02, 2.30552554e-01, 1.45267397e-01,\n",
+ " 4.03060645e-01, -6.18973076e-01, 2.22808436e-01,\n",
+ " -2.12161049e-01, -6.64360464e-01, -8.86285603e-02,\n",
+ " -2.60846853e-01, 3.35359931e-01, 1.28258839e-01,\n",
+ " -6.62824094e-01, 1.42240420e-01, -4.19811439e-03,\n",
+ " -2.65478313e-01, 3.23501050e-01, -4.91243511e-01,\n",
+ " 1.07489243e-01, -1.54155448e-01, 5.92532396e-01,\n",
+ " -1.95484996e-01, 5.40172398e-01, -1.65053040e-01,\n",
+ " -3.23891145e-04, 3.37871283e-01, 2.61556298e-01,\n",
+ " 1.12927049e-01, -2.20960543e-01, 1.60125479e-01,\n",
+ " 1.26879781e-01, 2.13119745e-01, 4.59046572e-01,\n",
+ " 1.33997157e-01, -1.54183894e-01, 6.40111387e-01,\n",
+ " -4.70493376e-01, -5.36377311e-01, -3.75289559e-01,\n",
+ " 1.87592462e-01, 3.55601102e-01, 6.20387085e-02,\n",
+ " -2.40593962e-02, 1.98887423e-01, -6.95033669e-02,\n",
+ " 1.56768542e-02, 2.65629649e-01, -8.66776258e-02,\n",
+ " -1.78666204e-01, -2.14785278e-01, 8.16439152e-01,\n",
+ " 3.90334278e-02, -8.54058266e-02, -1.37500688e-01,\n",
+ " 1.21018156e-01, 3.42288762e-01, 4.27406132e-01,\n",
+ " -8.53472233e-01, 2.43805587e-01, 3.50546800e-02,\n",
+ " 2.98663616e-01, -1.45337895e-01, -1.86531141e-01,\n",
+ " -3.85284901e-01, 2.70582736e-01, 2.59389400e-01,\n",
+ " -3.24853659e-01, 4.36976790e-01, 5.09902477e-01,\n",
+ " 6.68077826e-01, -1.33508623e-01, -3.25732917e-01,\n",
+ " 3.32559854e-01, -1.56141296e-01, -2.14704543e-01,\n",
+ " -1.04073279e-01, -7.48956129e-02, 7.27221727e-01,\n",
+ " 9.71424207e-02, -2.42689922e-01, 6.48995101e-01,\n",
+ " 1.15117133e-01, -3.32194477e-01, -4.44986552e-01,\n",
+ " -1.50502846e-01, 8.56051296e-02, 2.24367138e-02,\n",
+ " 5.09765148e-01, -5.63690662e-01, -2.64978737e-01,\n",
+ " -3.98635745e-01, 9.64068696e-02, 5.66759288e-01], dtype=float32), memmap([ 3.15287e-01, 1.77051e-01, 6.08810e-02, 1.79364e-01,\n",
+ " 3.29245e-01, -4.41467e-01, -9.25660e-02, -3.26262e-01,\n",
+ " 2.84319e-01, -2.16962e-01, -4.64039e-01, -5.69310e-02,\n",
+ " -3.75078e-01, -9.44900e-03, 4.41745e-01, 3.42132e-01,\n",
+ " 1.48341e-01, -6.26920e-02, -1.70950e-02, 2.52995e-01,\n",
+ " 2.15242e-01, 2.09728e-01, -3.90740e-01, 4.25264e-01,\n",
+ " -1.01685e-01, -2.81497e-01, -2.99113e-01, -5.25200e-01,\n",
+ " 5.58118e-01, -3.23426e-01, 1.34162e-01, -5.23788e-01,\n",
+ " 1.58793e-01, 3.54716e-01, -1.84164e-01, 1.04444e-01,\n",
+ " 1.18977e-01, 2.31664e-01, 1.36028e-01, 1.61303e-01,\n",
+ " -1.48298e-01, -5.86860e-02, 9.02600e-03, 9.46720e-02,\n",
+ " 6.31800e-03, 3.88390e-01, 1.55361e-01, 4.15720e-01,\n",
+ " 1.46882e-01, -4.31860e-02, -2.44843e-01, 2.61356e-01,\n",
+ " 3.92755e-01, 2.18738e-01, -6.17970e-01, -4.19673e-01,\n",
+ " -1.09865e-01, -9.17090e-02, -2.17481e-01, -1.70685e-01,\n",
+ " 1.07850e-01, -2.05242e-01, -2.19795e-01, -1.80055e-01,\n",
+ " -1.98632e-01, -1.60960e-02, -5.56096e-01, -5.85160e-02,\n",
+ " 7.84220e-02, 1.55170e-02, 2.23050e-02, -2.58786e-01,\n",
+ " -1.70680e-02, 6.30700e-02, 3.46040e-01, -4.39487e-01,\n",
+ " -4.08797e-01, 3.79016e-01, -5.29580e-02, 7.72520e-02,\n",
+ " -8.82410e-02, 1.35282e-01, -2.65969e-01, 2.77906e-01,\n",
+ " 7.22100e-02, -1.77479e-01, -5.27770e-01, -7.36720e-02,\n",
+ " 1.75278e-01, 1.09576e-01, 1.16799e-01, 2.08992e-01,\n",
+ " -3.30430e-01, -1.46533e-01, -8.96960e-01, 1.62308e-01,\n",
+ " 3.49921e-01, 3.72650e-02, -4.37638e-01, 7.00240e-02,\n",
+ " 8.03180e-02, 3.32803e-01, -3.21030e-02, -5.89190e-02,\n",
+ " -5.65890e-02, -2.64965e-01, 1.98342e-01, -1.27030e-02,\n",
+ " -2.91000e-03, 1.45464e-01, 4.61800e-02, -6.56810e-01,\n",
+ " 2.39402e-01, -1.53483e-01, 2.87671e-01, 1.03870e-02,\n",
+ " -2.53580e-01, 6.18041e-01, 6.12350e-02, -1.07438e-01,\n",
+ " -2.35784e-01, 1.71414e-01, 1.82341e-01, -1.92271e-01,\n",
+ " -4.52226e-01, -3.59716e-01, 4.73700e-03, 9.80390e-02,\n",
+ " 1.85067e-01, 8.80600e-03, -4.47691e-01, -2.70899e-01,\n",
+ " 2.81138e-01, -9.61080e-02, -4.48080e-02, -1.61216e-01,\n",
+ " 7.78320e-02, -1.61700e-03, 3.68315e-01, -5.19670e-01,\n",
+ " 8.82300e-03, -3.15155e-01, -2.02781e-01, 4.53902e-01,\n",
+ " -2.99813e-01, 1.84788e-01, -3.68420e-01, -3.58470e-02,\n",
+ " 7.50730e-02, 6.78120e-02, -3.30320e-02, -1.94240e-02,\n",
+ " -3.41162e-01, 2.97251e-01, -4.17041e-01, -2.18284e-01,\n",
+ " 4.44630e-02, -3.54110e-02, -3.09810e-02, -5.99018e-01,\n",
+ " -5.07050e-02, -3.02726e-01, -3.04077e-01, 3.80173e-01,\n",
+ " 4.11235e-01, 2.54120e-02, 2.13311e-01, -4.82600e-01,\n",
+ " 1.15049e-01, 2.54317e-01, 1.18104e-01, 4.42089e-01,\n",
+ " -1.87696e-01, -4.10800e-03, -9.71170e-02, -7.70940e-02,\n",
+ " -9.95350e-02, 3.45661e-01, 9.02660e-02, -2.73226e-01,\n",
+ " 2.79475e-01, -3.24840e-02, 7.30300e-02, -3.36870e-01,\n",
+ " -5.09357e-01, -1.34780e-01, -1.30971e-01, 2.60989e-01,\n",
+ " 6.71760e-01, -1.65672e-01, -1.30996e-01, 1.39132e-01,\n",
+ " -1.98931e-01, 2.36968e-01, 2.97339e-01, -5.79600e-02,\n",
+ " -1.18475e-01, 4.11962e-01, 4.48970e-02, 3.09170e-02,\n",
+ " 7.25566e-01, -4.23277e-01, 1.17551e-01, 2.90054e-01,\n",
+ " 1.56932e-01, -2.13589e-01, 3.76500e-03, -3.84753e-01,\n",
+ " 1.21688e-01, -5.94567e-01, 1.20173e-01, -1.92981e-01,\n",
+ " 3.16644e-01, 8.67220e-02, -2.36165e-01, -2.22114e-01,\n",
+ " -2.01835e-01, -1.27558e-01, 2.33966e-01, -1.30535e-01,\n",
+ " 3.43284e-01, 2.20745e-01, 2.50888e-01, 5.07130e-02,\n",
+ " -2.33000e-03, 5.73916e-01, -1.32728e-01, -2.22690e-02,\n",
+ " -1.77108e-01, 1.54599e-01, -1.90919e-01, -3.06968e-01,\n",
+ " -6.31542e-01, -2.19120e-02, 7.11860e-02, 1.87242e-01,\n",
+ " -5.13420e-02, -1.81382e-01, -8.44100e-02, 3.57852e-01,\n",
+ " 1.82538e-01, 1.08436e-01, 1.26448e-01, 1.45818e-01,\n",
+ " -1.81700e-01, 3.17561e-01, 1.58116e-01, 4.08697e-01,\n",
+ " 1.45106e-01, -2.15379e-01, 1.76390e-01, -2.38288e-01,\n",
+ " 8.57340e-02, -2.85810e-02, 2.71000e-04, 2.70229e-01,\n",
+ " -1.23640e-01, 2.40153e-01, -1.04036e-01, -5.58530e-02,\n",
+ " -3.87714e-01, 2.27796e-01, 3.40641e-01, 4.54995e-01,\n",
+ " -1.27151e-01, 9.56080e-02, -2.77914e-01, -2.38077e-01,\n",
+ " 5.10054e-01, -4.73260e-02, 2.12810e-02, -2.68700e-01,\n",
+ " -3.50190e-01, -2.06947e-01, 5.17880e-02, -1.82510e-01,\n",
+ " 3.10557e-01, 5.06110e-02, -5.40549e-01, 2.47812e-01,\n",
+ " -7.91090e-02, -2.40019e-01, -4.57742e-01, 4.65154e-01,\n",
+ " -1.25500e-01, -1.43546e-01, -1.17472e-01, 5.04478e-01,\n",
+ " -3.03387e-01, -4.90942e-01, 2.46055e-01, 3.90808e-01,\n",
+ " 1.05975e-01, -1.20549e-01, -3.29972e-01, 3.67824e-01,\n",
+ " 2.10258e-01, -2.14879e-01, 1.16083e-01, 2.06681e-01],\n",
+ " dtype=float32), memmap([ 0.14404905, 0.0936775 , -0.26291716, -0.12771836, 0.01545167,\n",
+ " 0.44028044, -0.61337113, -0.92851853, 0.12839419, -0.16835642,\n",
+ " 0.36100066, 0.21191075, -0.07703327, 0.7357299 , 0.36904803,\n",
+ " 0.70485693, 0.12831394, -0.4766474 , -0.08118793, 0.40624678,\n",
+ " 0.3513442 , 0.24623403, 0.43019554, 0.5177061 , -0.13603646,\n",
+ " -0.38863388, 0.01621091, 0.04971632, -0.34193242, -0.5217188 ,\n",
+ " 0.0033188 , -0.12346203, -0.06514671, 0.420109 , -0.6023694 ,\n",
+ " 0.34664539, 0.32674462, 0.15733871, -0.5390332 , 0.0500537 ,\n",
+ " 0.08395436, 0.03941365, -0.30773544, -0.5121797 , 0.01682046,\n",
+ " 0.00768686, 0.3433431 , 0.31499794, 0.02857671, 0.06469491,\n",
+ " 0.0216397 , 0.8923505 , 0.01500839, -0.3624238 , 0.15296638,\n",
+ " 0.13749424, -0.20227903, 0.36492267, 0.17421733, -0.669397 ,\n",
+ " -0.1707407 , -0.21751766, -0.08253048, -0.5727479 , -0.5022914 ,\n",
+ " 0.06078883, -0.28755787, -0.16986571, 0.6962558 , 0.16540614,\n",
+ " 0.02669252, -0.59127265, -0.2615447 , 0.40730473, 0.19505776,\n",
+ " -0.28997967, -0.06306509, 0.26059225, -0.0678969 , -0.40062913,\n",
+ " 0.37419748, 0.37631658, -0.56704307, -0.2506342 , -0.39875877,\n",
+ " 0.63701975, 0.14794521, 0.2918154 , 0.40491712, 0.41561818,\n",
+ " 0.48484847, 0.5031539 , 0.41204095, -0.12326399, -0.35311183,\n",
+ " 0.03818782, -0.20176806, 0.0031631 , 0.21188988, 0.00367524,\n",
+ " 0.6108648 , -0.25123575, -0.5373104 , 0.43367553, 0.03464083,\n",
+ " -0.46051365, 0.02321884, 0.10928306, -0.05194477, 0.22879702,\n",
+ " -0.19275598, 0.15480393, 0.5197272 , 0.00286526, -0.11272032,\n",
+ " -0.05369571, 0.1293415 , 0.9088507 , -0.18155459, 0.28983495,\n",
+ " -0.4458984 , -0.25753102, 0.6857928 , -0.3646571 , -0.48275086,\n",
+ " 0.5842543 , -0.1746529 , -0.05872086, -0.17744718, -0.1891076 ,\n",
+ " -0.41149625, -0.32803285, 0.31197363, 0.5518509 , 0.25632608,\n",
+ " 0.16129981, -0.36874938, 0.2077312 , 0.32983354, -0.13315345,\n",
+ " 0.4167703 , -0.46767968, -0.22729716, 0.19266164, 0.02987919,\n",
+ " 0.32493445, 0.25293395, 0.21078078, -0.09153535, 0.05758817,\n",
+ " -0.4803477 , 0.2874605 , 0.2396772 , 0.5231111 , -0.30249462,\n",
+ " -0.5583735 , 0.32819614, 0.03221998, -0.04401642, -0.2795492 ,\n",
+ " -1.0318391 , 0.4627127 , 0.07561864, -0.19804536, 0.7015638 ,\n",
+ " 0.19577141, 0.13986789, -0.7613535 , -0.06759125, 0.06247181,\n",
+ " 0.13955595, 0.3875517 , -0.20644519, 0.19915171, -0.19630045,\n",
+ " 0.02884698, -0.01836812, 0.2699957 , -0.10787025, 0.16178446,\n",
+ " 0.241514 , -0.62566847, 0.01112214, -0.2537296 , -0.00996091,\n",
+ " 0.5084142 , -0.02890078, 0.7239531 , -0.08677629, -0.06265563,\n",
+ " -0.17496312, 0.11926087, 0.12080985, -0.21921651, 0.20798117,\n",
+ " 0.146067 , -0.11120407, 0.04849955, -0.01842963, -0.18510175,\n",
+ " 0.46087536, -0.49844128, 0.47089085, 0.08805874, 0.04799706,\n",
+ " -0.04020219, 0.19448115, 0.28060904, 0.59498143, 0.6286977 ,\n",
+ " 0.25782728, -0.52900165, -0.19055991, 0.34723598, 0.3437882 ,\n",
+ " -0.26684594, -0.24123895, 0.03135664, -0.2108534 , -0.29566905,\n",
+ " -0.25923997, 0.37185898, -0.319579 , 0.08019554, -0.4479196 ,\n",
+ " 0.46615997, 0.2930632 , 0.5264077 , -0.86748123, 0.67958534,\n",
+ " -0.35869944, -0.01112498, -0.1485505 , 0.12948273, -0.02525989,\n",
+ " -0.2596611 , 0.11038531, 0.48764285, 0.14068526, -0.03798692,\n",
+ " -0.08467396, -0.59372187, 0.42302644, -0.04826857, -0.06486781,\n",
+ " -0.141831 , -0.19134766, 0.03321413, -0.16950962, -0.17763533,\n",
+ " -0.00633396, -0.10417579, 0.49993476, 0.07888182, 0.15845068,\n",
+ " 0.26553053, 0.093398 , 0.42140028, 0.17526335, 0.11634717,\n",
+ " -0.3166491 , 0.48706874, -0.12781279, 0.45372677, 0.11972222,\n",
+ " -0.02334037, -0.46765664, 0.65633947, 0.49807596, -0.04202469,\n",
+ " 0.3194007 , -0.09065161, 0.5444353 , -0.06324179, -0.19702166,\n",
+ " 0.42725286, 0.2660896 , 0.37858534, -0.74921227, 0.3308406 ,\n",
+ " 0.30193445, -0.8343235 , 0.26368877, -0.0231375 , -0.23532704,\n",
+ " -0.11724953, 0.44338655, 0.4932377 , 0.11040114, -0.40214172,\n",
+ " -0.11391973, 0.04885897, -0.07771134, -0.23785509, 0.17074856,\n",
+ " -0.40348914, 0.22249588, -0.12133241, -0.02102586, 0.37342763],\n",
+ " dtype=float32), memmap([ 0.001818, 0.066143, -0.010453, -0.059094, 0.109722, 0.089641,\n",
+ " -0.097714, -0.171619, 0.076991, -0.149663, 0.041489, 0.038116,\n",
+ " -0.067629, 0.204549, 0.088233, 0.106667, 0.006188, -0.102942,\n",
+ " 0.121798, 0.128443, 0.046703, 0.1317 , -0.096763, 0.054752,\n",
+ " -0.102518, -0.142806, -0.036076, -0.081603, 0.044383, -0.164296,\n",
+ " 0.070078, 0.055948, -0.004155, 0.149297, 0.056086, 0.134398,\n",
+ " 0.140546, 0.034062, -0.175468, -0.003967, 0.011093, 0.018521,\n",
+ " -0.126103, -0.111318, 0.091405, -0.037058, 0.063119, 0.144764,\n",
+ " 0.058746, 0.077937, -0.033163, 0.193789, -0.102106, 0.129214,\n",
+ " -0.072628, 0.059842, -0.103701, 0.023954, 0.057887, -0.046507,\n",
+ " -0.064417, -0.023924, -0.134178, -0.108881, -0.13709 , 0.01039 ,\n",
+ " -0.179701, -0.214993, 0.000531, 0.02743 , 0.130067, -0.077183,\n",
+ " -0.107584, 0.150011, 0.079931, -0.147898, 0.0969 , -0.068616,\n",
+ " 0.034206, -0.022598, 0.144717, 0.011096, -0.047544, -0.05113 ,\n",
+ " 0.103629, -0.023143, -0.051355, -0.035767, -0.007348, 0.044828,\n",
+ " 0.105631, 0.021768, 0.040458, -0.029844, -0.257838, 0.139309,\n",
+ " 0.068758, 0.118263, 0.050123, 0.083626, 0.013022, -0.084703,\n",
+ " -0.013072, 0.069715, 0.14884 , 0.017856, 0.056341, -0.004121,\n",
+ " 0.020378, -0.003951, 0.004127, -0.052104, -0.060354, -0.001392,\n",
+ " 0.107853, 0.174782, 0.040377, 0.003421, 0.116459, -0.02237 ,\n",
+ " -0.087313, 0.139657, 0.041649, -0.03943 , -0.020199, -0.016723,\n",
+ " -0.182761, -0.133106, -0.129202, 0.073727, 0.002839, 0.028117,\n",
+ " -0.129523, -0.036806, -0.019341, 0.023931, -0.132967, -0.059774,\n",
+ " 0.181932, -0.08003 , -0.043331, -0.117484, -0.048691, -0.006077,\n",
+ " 0.020172, 0.0353 , -0.036286, -0.044798, 0.039891, 0.06012 ,\n",
+ " -0.063556, -0.094195, 0.044249, 0.121303, -0.049312, 0.020298,\n",
+ " -0.002654, 0.034145, -0.013738, -0.232038, -0.054745, -0.136469,\n",
+ " -0.039832, 0.064544, 0.215784, 0.030318, 0.020144, -0.233173,\n",
+ " -0.015026, 0.057002, 0.053973, 0.157845, -0.04885 , -0.097335,\n",
+ " -0.185291, -0.070751, 0.108747, 0.097324, -0.009647, -0.067794,\n",
+ " -0.036606, -0.154182, 0.073696, -0.097431, -0.110284, 0.017217,\n",
+ " -0.100543, 0.137706, -0.181598, 0.02755 , 0.047599, -0.086254,\n",
+ " -0.088046, 0.01736 , -0.05082 , -0.027821, -0.066403, -0.001615,\n",
+ " -0.059845, 0.001858, 0.095231, -0.100169, 0.102903, -0.045637,\n",
+ " 0.082969, -0.046515, -0.073889, -0.050359, 0.1704 , 0.044194,\n",
+ " -0.068043, -0.10089 , 0.150201, -0.044282, -0.087403, -0.014644,\n",
+ " -0.026875, 0.05111 , 0.061729, -0.069272, -0.000424, 0.009669,\n",
+ " -0.14954 , 0.015503, -0.134024, 0.247477, -0.095109, -0.079825,\n",
+ " -0.09438 , 0.078368, 0.044009, 0.081116, 0.062616, -0.046864,\n",
+ " 0.117752, -0.071434, 0.061858, 0.117382, -0.062786, -0.092905,\n",
+ " 0.044746, -0.031947, 0.13324 , 0.125241, -0.124592, -0.059041,\n",
+ " -0.092122, -0.040809, 0.125956, -0.112321, 0.11526 , -0.059185,\n",
+ " 0.053145, -0.046326, 0.000989, -0.064012, -0.035191, 0.178805,\n",
+ " 0.04306 , -0.006983, 0.059768, 0.095112, 0.07 , 0.128332,\n",
+ " -0.008888, -0.082776, 0.021835, 0.014393, 0.151615, 0.023218,\n",
+ " -0.000706, -0.09446 , 0.079094, -0.071359, 0.006972, 0.066266,\n",
+ " 0.158137, 0.016611, 0.080212, 0.342714, 0.067966, -0.116047,\n",
+ " -0.076305, -0.092242, -0.001183, -0.02487 , 0.034171, 0.176997,\n",
+ " -0.052631, -0.124433, 0.140774, 0.036618, 0.058951, 0.179688,\n",
+ " 0.044621, -0.061049, 0.141201, -0.16003 , -0.057293, 0.065208],\n",
+ " dtype=float32), memmap([-0.6265381 , 0.10746624, -0.45411274, 0.4279252 , 0.35552195,\n",
+ " -0.18431246, -0.08930775, -0.3369567 , 0.24075818, 0.11722808,\n",
+ " -0.20808306, 0.35608864, 0.6579905 , -0.20242234, 0.14745978,\n",
+ " -0.00591643, 0.22205125, -0.18009076, 0.32284477, 0.11268537,\n",
+ " -0.0642214 , 0.4846641 , 0.06795264, -0.14816476, 0.07455695,\n",
+ " -0.05934457, 0.14241129, -0.4123653 , 0.29217723, -0.30816668,\n",
+ " 0.70174986, -0.52880234, 0.20039181, 0.4249459 , 0.0861659 ,\n",
+ " 0.13597897, -0.04971504, 0.00660922, -0.53294706, -0.30388495,\n",
+ " 0.23007637, 0.5897676 , -0.4293978 , -0.54774433, 0.27053782,\n",
+ " 0.09366655, 0.6094405 , -0.0409955 , -0.2687376 , 0.22867341,\n",
+ " -0.23889586, 0.43881533, -0.37100714, 0.39467642, -0.03873207,\n",
+ " 0.39890605, 0.23576836, 0.14710787, -0.42301714, 0.27215523,\n",
+ " -0.28935546, -0.736066 , -0.48817343, -0.51608884, -0.09347728,\n",
+ " -0.12766261, 0.01586757, -0.15969804, -0.57369095, 0.5079826 ,\n",
+ " 0.17528182, 0.1011444 , 0.06479808, 0.16319028, -0.2326859 ,\n",
+ " -0.38954633, 0.5047141 , -0.22145489, -0.30391484, -0.07938702,\n",
+ " 0.15893753, -0.07606 , 0.18721525, 0.02053021, -0.02050801,\n",
+ " 0.0437555 , -0.3477347 , 0.2793079 , -0.10928848, 0.46806648,\n",
+ " 0.2768383 , 0.17450425, 0.4381418 , 0.20512727, -0.2730783 ,\n",
+ " 0.10620265, -0.15927972, 0.0602769 , -0.26139432, 0.07520308,\n",
+ " -0.27876425, 0.33285427, -0.01150214, -0.18564174, -0.18947059,\n",
+ " 0.26233634, 0.06092324, -0.08642223, 0.03193478, -0.2378395 ,\n",
+ " 0.26163915, -0.11316426, 0.39018112, 0.32073796, 0.43945786,\n",
+ " 0.41769108, -0.37753388, 0.27395254, 0.10154679, -0.43290746,\n",
+ " 0.02413667, 0.20424177, 0.5653507 , -0.05092707, -0.11272695,\n",
+ " 0.48133785, 0.14801037, -0.34187314, -0.2019914 , -0.04617379,\n",
+ " -0.85680455, 0.3214811 , 0.3581415 , 0.08068537, 0.26460782,\n",
+ " 0.3911854 , -0.45970127, -0.0236951 , 0.310623 , -0.6141622 ,\n",
+ " 0.1411436 , -0.03571229, -0.12754174, -0.13606131, 0.47382388,\n",
+ " -0.35325027, -0.26324692, -0.04715102, 0.15399992, 0.33320367,\n",
+ " 0.03775111, 0.34992543, -0.07526408, 0.14892383, 0.55291784,\n",
+ " 0.24479802, -0.10170451, -0.16417724, 0.1066686 , -0.22142069,\n",
+ " -0.30604634, 0.13581215, -0.23913799, -0.641248 , 0.4186506 ,\n",
+ " -0.0991861 , -0.3842266 , 0.1691263 , -0.01497534, 0.18686303,\n",
+ " 0.0558264 , -0.18116452, -0.40064445, -0.10633671, -0.54504323,\n",
+ " 0.20219657, 0.10830315, 0.12120886, -0.47679773, 0.04646945,\n",
+ " -0.05277218, -0.42205456, -0.33499476, -0.06154288, 0.3536029 ,\n",
+ " 0.507819 , -0.934774 , 0.2277912 , -0.5115909 , -0.01997201,\n",
+ " -0.30446318, -0.36643985, -0.21992955, -0.1457182 , -0.50462335,\n",
+ " 0.6309729 , -0.02562943, -0.06159483, -0.65630513, 0.05586406,\n",
+ " 0.30269915, 0.52419907, 0.25251964, 0.02180627, 0.28871545,\n",
+ " 0.565306 , 0.24753755, -0.02681671, 0.24512488, -0.2248134 ,\n",
+ " -0.5055576 , -0.3583963 , 0.46729556, -0.08515082, 0.00933607,\n",
+ " 0.15645684, 0.50458896, -0.382137 , -0.10609938, 0.4784437 ,\n",
+ " 0.37657306, 0.18013594, -0.6411438 , 0.54350305, 0.3293464 ,\n",
+ " 0.20992567, 0.03285853, 0.30827454, -0.47371006, 0.00629947,\n",
+ " -0.32776198, -0.25700122, 0.03916875, -0.15846448, 0.5564328 ,\n",
+ " -0.42164397, 0.0645385 , 0.2109285 , -0.03488603, 0.4891337 ,\n",
+ " 0.07451486, 0.24066378, 0.3831182 , 0.12154611, 0.00236973,\n",
+ " 0.03719952, -0.07466076, -0.33381504, 0.62252325, -0.6411046 ,\n",
+ " 0.26709715, 0.55421406, 0.02169511, -0.04351762, 0.6423761 ,\n",
+ " -0.58692676, 0.40381336, 0.21429129, -0.46406025, -0.04933878,\n",
+ " 0.02600067, 0.33930784, 0.08202445, -0.31654668, -0.11667101,\n",
+ " -0.13803223, 0.31734437, 0.22622582, -0.03053601, -0.03710218,\n",
+ " -0.3155399 , -0.22531359, 0.06255695, -0.09898163, 0.25976712,\n",
+ " -0.31391978, -0.02468867, -0.09841212, -0.12366964, 0.11450931,\n",
+ " -0.11804897, -0.18630046, -0.21030162, 0.23295914, 0.44302842,\n",
+ " -0.04154629, -0.4642268 , 0.35045376, -0.61062485, 0.04100152,\n",
+ " 0.34679234, 0.01286158, -0.0042695 , 0.29488856, -0.0527203 ,\n",
+ " 0.3604901 , -0.19629179, -0.09233268, -0.26149684, 0.20822442],\n",
+ " dtype=float32)]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "i2v = get_pretrained_i2v(\"w2v_lit_300\", \"../../data/w2v\")\n",
+ "item_vector, token_vector = i2v([\"有学者认为:‘向西方学习’,必须适应和结合实际才有作用\"])\n",
+ "\n",
+ "print(\"item_vector : \\n\", np.array(item_vector).shape, item_vector)\n",
+ "print(\"token_vector : \\n\", np.array(token_vector).shape, token_vector)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "af9f0356",
+ "metadata": {},
+ "source": [
+ "# 通过i2v从模型库中下载d2v模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "a02cdbb7",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "EduNLP, INFO Use pretrained t2v model d2v_lit_256\n",
+ "downloader, INFO http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/general_literal_256.zip is saved as ..\\..\\data\\d2v\\general_literal_256.zip\n",
+ "downloader, INFO file existed, skipped\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "item_vector : \n",
+ " (1, 256) [array([ 0.05168616, -0.06012261, 0.03626082, 0.0233949 , 0.05300888,\n",
+ " -0.00135876, -0.0449402 , 0.11617043, 0.0536709 , -0.06298565,\n",
+ " -0.09526226, -0.00348714, -0.00342583, -0.1060208 , 0.09653808,\n",
+ " -0.01833516, -0.06531033, -0.00634937, 0.00376058, -0.05239441,\n",
+ " 0.00384591, 0.02987942, -0.04880023, -0.05722035, -0.16377625,\n",
+ " -0.0551496 , -0.03123471, 0.0167222 , -0.01512162, -0.06904794,\n",
+ " -0.05699088, 0.07458019, -0.04610786, -0.05459723, -0.01972048,\n",
+ " 0.01310308, -0.1661665 , 0.04316062, 0.0106346 , 0.00441967,\n",
+ " -0.11227966, -0.0216081 , -0.01657858, 0.0390131 , -0.09489819,\n",
+ " -0.04717609, 0.02318504, -0.01164363, -0.0863398 , 0.0356447 ,\n",
+ " -0.02488741, -0.02131925, -0.0375803 , 0.03112153, 0.11519039,\n",
+ " 0.05714484, 0.02241669, 0.01309644, -0.00934663, -0.00812965,\n",
+ " 0.00223179, -0.01906266, 0.07378528, 0.04675245, 0.03961517,\n",
+ " 0.03800032, -0.02641658, 0.0242183 , -0.03764332, 0.00098523,\n",
+ " -0.00980584, 0.01296336, -0.12580062, -0.02333812, -0.0174764 ,\n",
+ " 0.01273129, -0.02679108, 0.06894456, 0.03383744, -0.03417306,\n",
+ " 0.03598411, 0.05460283, 0.06407865, 0.11297213, -0.0056845 ,\n",
+ " 0.05433899, 0.02189578, -0.05511612, 0.02214252, 0.03282089,\n",
+ " -0.07074569, 0.00459485, -0.02246627, 0.03741897, -0.06186739,\n",
+ " 0.02809795, -0.01132116, 0.01077965, -0.02932515, 0.05372041,\n",
+ " -0.07060602, -0.03364111, 0.00287449, -0.03367238, -0.00707588,\n",
+ " -0.01266356, -0.0116119 , 0.03763589, -0.05030849, -0.0506805 ,\n",
+ " 0.01617473, 0.02902891, 0.02766665, -0.01659654, -0.09169317,\n",
+ " 0.04292378, 0.04663622, 0.02827189, 0.03266542, 0.01195693,\n",
+ " -0.05572838, -0.03722275, -0.02789672, 0.0252539 , 0.01657911,\n",
+ " 0.02054286, 0.02932693, -0.05625787, 0.02080808, -0.0690353 ,\n",
+ " 0.01416201, -0.11937889, 0.01444815, 0.05260929, 0.0005712 ,\n",
+ " -0.05261262, -0.01543314, 0.01705966, -0.04396763, 0.02431965,\n",
+ " 0.05881024, 0.03761204, 0.01830121, -0.00149444, 0.1358502 ,\n",
+ " -0.11587104, -0.02003725, 0.00385013, 0.01632271, -0.00488979,\n",
+ " 0.03184082, -0.0014026 , 0.06440724, -0.03781892, -0.09144403,\n",
+ " 0.0433217 , -0.04358204, 0.01135502, -0.09185286, 0.05404984,\n",
+ " -0.03470675, -0.0862116 , 0.04344686, 0.06999089, 0.04938227,\n",
+ " -0.01028743, 0.04629426, -0.06526747, -0.09721855, -0.03276761,\n",
+ " -0.01811158, -0.07921333, -0.03268831, -0.01052403, -0.05022546,\n",
+ " 0.02974997, -0.03412613, 0.04961331, 0.0138158 , 0.09043111,\n",
+ " 0.01316238, 0.00163702, 0.07805788, 0.0250666 , -0.00450815,\n",
+ " -0.00470929, 0.00449593, -0.14033723, -0.01469393, 0.0516893 ,\n",
+ " -0.08275685, -0.08630146, 0.0458499 , 0.01399075, 0.09536003,\n",
+ " 0.01121633, -0.05079496, -0.04410382, -0.07359479, -0.03120217,\n",
+ " 0.01741385, 0.0133559 , 0.0501571 , 0.03164428, -0.0443478 ,\n",
+ " 0.03276857, -0.0196498 , -0.00507194, 0.02113156, -0.02301252,\n",
+ " -0.030987 , 0.01116967, -0.00229194, -0.04506126, -0.0973313 ,\n",
+ " -0.00457067, 0.05663091, 0.07327795, 0.04432167, -0.01299081,\n",
+ " 0.06884655, 0.0184796 , 0.00599279, 0.02315673, -0.00633527,\n",
+ " -0.04402763, -0.00453509, -0.06445812, 0.02564598, 0.04351281,\n",
+ " -0.04120508, 0.00388152, -0.01782226, -0.02949523, -0.06305063,\n",
+ " 0.02963926, -0.16031711, 0.00996824, 0.05458128, 0.02867853,\n",
+ " 0.03086467, 0.09444657, 0.0420283 , -0.11675379, -0.0280523 ,\n",
+ " -0.00560202, -0.01304273, 0.00658127, -0.00189307, 0.01767397,\n",
+ " 0.06018311, -0.00552854, 0.00151099, -0.02198849, -0.03597561,\n",
+ " -0.06512164], dtype=float32)]\n",
+ "token_vector : None\n"
+ ]
+ }
+ ],
+ "source": [
+ "i2v = get_pretrained_i2v(\"d2v_lit_256\", \"../../data/d2v\")\n",
+ "item_vector, token_vector = i2v([\"有学者认为:‘向西方学习’,必须适应和结合实际才有作用\"])\n",
+ "\n",
+ "print(\"item_vector : \\n\", np.array(item_vector).shape, item_vector)\n",
+ "print(\"token_vector : \", token_vector)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Data",
+ "language": "python",
+ "name": "data"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/tokenizer/tokenizier.ipynb b/examples/tokenizer/tokenizier.ipynb
new file mode 100644
index 00000000..8dcec093
--- /dev/null
+++ b/examples/tokenizer/tokenizier.ipynb
@@ -0,0 +1,180 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "source": [
+ "from EduNLP.Tokenizer import PureTextTokenizer, TextTokenizer, get_tokenizer"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "D:\\MySoftwares\\Anaconda\\envs\\data\\lib\\site-packages\\gensim\\similarities\\__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
+ " warnings.warn(msg)\n"
+ ]
+ }
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# TextTokenizer and PureTextTokenizer\r\n",
+ "\r\n",
+ "- ‘text’ Tokenizer symbolizes the FormulaFigures as [FORMULA] and tokenizes latex Formulas as Text\r\n",
+ "- ‘pure_text’ Tokenizer ignores and skips the FormulaFigures and tokenizes latex Formulas as Text"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## TextTokenizer"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "source": [
+ "items = [{\r\n",
+ " \"stem\": \"已知集合$A=\\\\left\\\\{x \\\\mid x^{2}-3 x-4<0\\\\right\\\\}, \\\\quad B=\\\\{-4,1,3,5\\\\}, \\\\quad$ 则 $A \\\\cap B=$\",\r\n",
+ " \"options\": [\"1\", \"2\"]\r\n",
+ " }]\r\n",
+ "tokenizer = get_tokenizer(\"text\") # tokenizer = TextTokenizer()\r\n",
+ "tokens = tokenizer(items, key=lambda x: x[\"stem\"])\r\n",
+ "print(next(tokens))"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "['已知', '集合', 'A', '=', '\\\\left', '\\\\{', 'x', '\\\\mid', 'x', '^', '{', '2', '}', '-', '3', 'x', '-', '4', '<', '0', '\\\\right', '\\\\}', ',', '\\\\quad', 'B', '=', '\\\\{', '-', '4', ',', '1', ',', '3', ',', '5', '\\\\}', ',', '\\\\quad', 'A', '\\\\cap', 'B', '=']\n"
+ ]
+ }
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "source": [
+ "items = [\"有公式$\\\\FormFigureID{wrong1?}$,如图$\\\\FigureID{088f15ea-xxx}$,若$x,y$满足约束条件公式$\\\\FormFigureBase64{wrong2?}$,$\\\\SIFSep$,则$z=x+7 y$的最大值为$\\\\SIFBlank$\"]"
+ ],
+ "outputs": [],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "source": [
+ "\r\n",
+ "tokenizer = get_tokenizer(\"text\") # tokenizer = TextTokenizer()\r\n",
+ "tokens = [t for t in tokenizer(items)]\r\n",
+ "tokens"
+ ],
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[['公式',\n",
+ " '[FORMULA]',\n",
+ " '如图',\n",
+ " '[FIGURE]',\n",
+ " 'x',\n",
+ " ',',\n",
+ " 'y',\n",
+ " '约束条件',\n",
+ " '公式',\n",
+ " '[FORMULA]',\n",
+ " '[SEP]',\n",
+ " 'z',\n",
+ " '=',\n",
+ " 'x',\n",
+ " '+',\n",
+ " '7',\n",
+ " 'y',\n",
+ " '最大值',\n",
+ " '[MARK]']]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 4
+ }
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## PureTextTokenizer"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "source": [
+ "tokenizer = get_tokenizer(\"pure_text\") # tokenizer = PureTextTokenizer()\r\n",
+ "tokens = [t for t in tokenizer(items)]\r\n",
+ "tokens"
+ ],
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[['公式',\n",
+ " '如图',\n",
+ " '[FIGURE]',\n",
+ " 'x',\n",
+ " ',',\n",
+ " 'y',\n",
+ " '约束条件',\n",
+ " '公式',\n",
+ " '[SEP]',\n",
+ " 'z',\n",
+ " '=',\n",
+ " 'x',\n",
+ " '+',\n",
+ " '7',\n",
+ " 'y',\n",
+ " '最大值',\n",
+ " '[MARK]']]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 5
+ }
+ ],
+ "metadata": {}
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3.6.13 64-bit ('data': conda)"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.13"
+ },
+ "interpreter": {
+ "hash": "776957673adb719a00031a24ed5efd2fa5ce8a13405e5193f8d278edd3805d55"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
\ No newline at end of file
diff --git a/tests/test_i2v/test_pretrained.py b/tests/test_i2v/test_pretrained.py
index a07f469a..53825ba8 100644
--- a/tests/test_i2v/test_pretrained.py
+++ b/tests/test_i2v/test_pretrained.py
@@ -4,19 +4,25 @@
from EduNLP import get_pretrained_i2v
from EduNLP.Vector.t2v import PRETRAINED_MODELS
from EduNLP.I2V.i2v import MODELS
-from EduNLP.I2V import D2V
+from EduNLP.I2V import D2V, W2V
def test_pretrained_i2v(tmp_path):
- PRETRAINED_MODELS["test"] = ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/test_256.zip", "d2v"]
- MODELS["test"] = [D2V, "test"]
+ PRETRAINED_MODELS["test_d2v"] = ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/d2v/test_256.zip", "d2v"]
+ MODELS["test_d2v"] = [D2V, "test_d2v"]
d = tmp_path / "model"
d.mkdir()
- get_pretrained_i2v("test", d)
+ get_pretrained_i2v("test_d2v", d)
with pytest.raises(KeyError):
get_pretrained_i2v("error")
- get_pretrained_i2v("test", d)
+ PRETRAINED_MODELS["test_w2v"] = ["http://base.ustc.edu.cn/data/model_zoo/EduNLP/w2v/test_w2v_256.zip", "w2v"]
+ MODELS["test_w2v"] = [W2V, "test_w2v"]
+
+ get_pretrained_i2v("test_w2v", d)
+
+ with pytest.raises(KeyError):
+ get_pretrained_i2v("error")
diff --git a/tests/test_vec/test_vec.py b/tests/test_vec/test_vec.py
index 35394933..d97210d8 100644
--- a/tests/test_vec/test_vec.py
+++ b/tests/test_vec/test_vec.py
@@ -6,7 +6,8 @@
import pytest
from EduNLP.Pretrain import train_vector, GensimWordTokenizer
from EduNLP.Vector import W2V, D2V, RNNModel, T2V, Embedding
-from EduNLP.I2V import D2V as I_D2V
+from EduNLP.I2V import D2V as I_D2V, W2V as I_W2V
+from EduNLP.Tokenizer import get_tokenizer
@pytest.fixture(scope="module")
@@ -31,6 +32,16 @@ def stem_tokens(stem_data):
return _data
+@pytest.fixture(scope="module")
+def stem_text_tokens(stem_data):
+ _data = []
+ tokenizer = get_tokenizer("pure_text")
+ tokens = tokenizer(stem_data)
+ _data = [d for d in tokens]
+ assert _data
+ return _data
+
+
@pytest.fixture(scope="module")
def stem_data_general(data):
test_items = [
@@ -78,6 +89,8 @@ def test_w2v(stem_tokens, tmpdir, method, binary):
t2v = T2V("w2v", filepath=filepath, method=method, binary=binary)
assert len(t2v(stem_tokens[:1])[0]) == t2v.vector_size
+ assert len(t2v.infer_vector(stem_tokens[:1])[0]) == t2v.vector_size
+ assert len(t2v.infer_tokens(stem_tokens[:1])[0][0]) == t2v.vector_size
for _w2v in [[filepath, method, binary], dict(filepath=filepath, method=method, binary=binary)]:
embedding = Embedding(_w2v, device="cpu")
@@ -85,6 +98,33 @@ def test_w2v(stem_tokens, tmpdir, method, binary):
assert items.shape == (5, max(item_len), embedding.embedding_dim)
+def test_w2v_i2v(stem_text_tokens, tmpdir, stem_data):
+ method = "sg"
+ filepath_prefix = str(tmpdir.mkdir(method).join("stem_tf_"))
+ filepath = train_vector(
+ stem_text_tokens,
+ filepath_prefix,
+ 10,
+ method=method,
+ train_params=dict(min_count=0)
+ )
+
+ i2v = I_W2V("pure_text", "w2v", filepath)
+ i_vec, t_vec = i2v(stem_data[:1])
+ assert len(i_vec[0]) == i2v.vector_size
+ assert len(t_vec[0][0]) == i2v.vector_size
+
+ cfg_path = str(tmpdir / method / "i2v_config.json")
+ i2v.save(config_path=cfg_path)
+ i2v = I_W2V.load(cfg_path)
+
+ i_vec = i2v.infer_item_vector(stem_data[:1])
+ assert len(i_vec[0]) == i2v.vector_size
+
+ t_vec = i2v.infer_token_vector(stem_data[:1])
+ assert len(t_vec[0][0]) == i2v.vector_size
+
+
def test_embedding():
with pytest.raises(TypeError):
Embedding("error")
@@ -129,24 +169,24 @@ def test_rnn(stem_tokens, tmpdir):
assert torch.equal(item, item_vec1)
-def test_d2v(stem_tokens, tmpdir, stem_data):
+def test_d2v(stem_text_tokens, tmpdir, stem_data):
method = "d2v"
filepath_prefix = str(tmpdir.mkdir(method).join("stem_tf_"))
filepath = train_vector(
- stem_tokens,
+ stem_text_tokens,
filepath_prefix,
10,
method=method,
train_params=dict(min_count=0)
)
d2v = D2V(filepath)
- assert len(d2v(stem_tokens[0])) == 10
+ assert len(d2v(stem_text_tokens[0])) == 10
assert d2v.vector_size == 10
t2v = T2V("d2v", filepath)
- assert len(t2v(stem_tokens[:1])[0]) == t2v.vector_size
+ assert len(t2v(stem_text_tokens[:1])[0]) == t2v.vector_size
- i2v = I_D2V("text", "d2v", filepath)
+ i2v = I_D2V("pure_text", "d2v", filepath)
i_vec, t_vec = i2v(stem_data[:1])
assert len(i_vec[0]) == i2v.vector_size
assert t_vec is None