# 使用ELMo向量化容器
## 导入功能块

In [1]:
from EduNLP.Pretrain import ElmoTokenizer
from EduNLP.Vector import T2V, ElmoModel
import os

In [2]:
# Configure where the test data lives and where the trained model is stored.
BASE_DIR = "../.."

# Both locations are resolved relative to BASE_DIR.
data_dir = BASE_DIR + "/static/test_data"
output_dir = BASE_DIR + "/examples/test_model/elmo/elmo_768"

## 令牌化

In [5]:
# Load the tokenizer of the previously trained model from its vocabulary file.
tokenizer = ElmoTokenizer(os.path.join(output_dir, "vocab.txt"))

# Question texts to tokenize.
items = [
    "有公式$\\FormFigureID{wrong1?}$，如图$\\FigureID{088f15ea-xxx}$,\
    若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$，则$z=x+7 y$的最大值为$\\SIFBlank$",
    "已知圆$x^{2}+y^{2}-6 x=0$，过点(1,2)的直线被该圆所截得的弦的长度的最小值为"
]

# A single question can be tokenized on its own.
print(tokenizer(items[0], freeze_vocab=True))
print()

# A list of questions can be tokenized in one call as well.
print(tokenizer(items, freeze_vocab=True))
print()

token_items = tokenizer(items, pad_to_max_length=True)
# The tokenizer result is a mapping keyed by "seq_idx" and "seq_len" (see the
# printed output of the calls above). Tuple-unpacking a mapping yields its
# key names rather than its values, so the lengths are read by key.
lengths = token_items["seq_len"]

{'seq_idx': tensor([ 804,   19,    6,   69,   26,   66, 1381,  804,    9,  254,   27,   69,
          70,  246,   66,  239,    7]), 'seq_len': tensor(17)}

{'seq_idx': tensor([[ 804,   19,    6,   69,   26,   66, 1381,  804,    9,  254,   27,   69,
           70,  246,   66,  239,    7,    0,    0,    0,    0,    0,    0,    0,
            0],
        [  64,  477,   69,   96,   81,   55,   82,   70,   66,   96,   81,   55,
           82,   71,  467,   69,   27,   78,  844,   77,  477, 1312,  865,  519,
          118]]), 'seq_len': tensor([17, 25])}



## 向量化

In [6]:
# Load the pretrained ELMo model from the output directory.
t2v = ElmoModel(output_dir)

# Calling the model directly returns its raw output object
# (printed as an ElmoLMOutput in the recorded cell output).
i_vec = t2v(token_items)
print(i_vec, end="\n\n")

# Obtain the item-level (sentence) and token-level representations.
i_vec = t2v.infer_vector(token_items, lengths=lengths)
t_vec = t2v.infer_tokens(token_items, lengths=lengths)
print(i_vec.size(), t_vec.size(), sep="\n", end="\n\n")

[EduNLP, INFO] All the weights of ElmoLM were initialized from the model checkpoint at ../../examples/test_model/elmo/elmo_768.
If your task is similar to the task the model of the checkpoint was trained on, you can already use ElmoLM for predictions without further training.


ElmoLMOutput([('pred_forward', tensor([[[-307.3449, -307.3120, -307.3644,  ..., -310.1035, -307.8653,
          -305.8521],
         [-278.2352, -278.3191, -278.3070,  ..., -278.2227, -277.5887,
          -277.1101],
         [-363.3187, -363.3951, -363.4167,  ..., -365.0335, -361.3292,
          -363.0343],
         ...,
         [-283.5177, -283.5760, -283.6111,  ..., -284.5733, -282.6731,
          -283.2103],
         [-248.1853, -248.3669, -248.3075,  ..., -248.6257, -247.6015,
          -247.8452],
         [-241.4586, -241.4421, -241.4153,  ..., -240.6708, -240.4943,
          -240.6182]],

        [[-334.8899, -334.8294, -334.9643,  ..., -334.1731, -335.4581,
          -334.7304],
         [-355.3142, -355.3451, -355.4356,  ..., -356.5914, -352.9772,
          -354.9101],
         [-407.1169, -406.9889, -407.2259,  ..., -411.4367, -405.8418,
          -407.2929],
         ...,
         [-330.3282, -330.3368, -330.3389,  ..., -332.7447, -331.5250,
          -329.7366],
         

  (outputs.forward_output[torch.arange(len(items["seq_len"])), torch.tensor(items["seq_len"]) - 1],


torch.Size([2, 768])
torch.Size([2, 25, 768])

