In [1]:
import torch
import numpy as np
import os
import json
from EduNLP.Pretrain import BertTokenizer, finetune_bert
from EduNLP.Vector import T2V
from EduNLP.I2V import Bert, get_pretrained_i2v



# 训练自己的Bert模型
## 1. 数据

In [2]:
# Root of the local EduNLP checkout. Set the EDUNLP_BASE_DIR environment
# variable to point elsewhere instead of editing the notebook. The default is
# a raw string so its backslashes are never treated as escape sequences
# ("\W", "\E", "\w" are invalid escapes and warn on Python 3.12+).
BASE_DIR = os.environ.get("EDUNLP_BASE_DIR", r"E:\Workustc\EduNLP\workMaster\EduNLP")

# Test corpus location and the directory the fine-tuned BERT model is saved to.
data_dir = os.path.join(BASE_DIR, "tests", "test_vec", "test_data")
output_dir = os.path.join(BASE_DIR, "examples", "test_model", "data", "bert")

In [None]:
def load_raw_data(data_path=None):
    """Load items from a JSON Lines file (one JSON object per line).

    Parameters
    ----------
    data_path : str, optional
        File to read. Defaults to ``<data_dir>/OpenLUNA.json``.

    Returns
    -------
    list of dict
        One parsed object per non-blank line, in file order.
    """
    if data_path is None:
        data_path = os.path.join(data_dir, "OpenLUNA.json")
    with open(data_path, encoding="utf-8") as f:
        # Stream the file instead of readlines(), and skip blank lines
        # (e.g. a trailing newline at EOF) that json.loads would reject.
        return [json.loads(line) for line in f if line.strip()]


def stem_data(data, key="stem", tokenizer=None):
    """Tokenize one text field of every item with a BERT tokenizer.

    Parameters
    ----------
    data : iterable of dict
        Items that each contain ``key``.
    key : str, default "stem"
        Name of the field to tokenize.
    tokenizer : callable, optional
        Tokenizer to apply; a default ``BertTokenizer()`` is created if omitted.

    Returns
    -------
    list
        The tokenizer outputs that are not ``None``, in input order.

    Raises
    ------
    AssertionError
        If no item produced a tokenizer output.
    """
    if tokenizer is None:
        tokenizer = BertTokenizer()
    _data = []
    for e in data:
        d = tokenizer(e[key])
        if d is not None:
            _data.append(d)
    assert _data, "no item could be tokenized; check the input data"
    return _data


# The loader and its result have distinct names so this cell can be re-run
# safely (binding the list to the function's own name would shadow it).
raw_data = load_raw_data()
train_items = stem_data(raw_data)

## 2. 训练和评估

In [None]:
# Small demo settings (a single epoch, batch size 1); presumably chosen so the
# example finishes quickly on the test corpus — increase for real training.
train_params = dict(
    epochs=1,
    save_steps=100,
    batch_size=1,
    logging_steps=3,
)

# Fine-tune on the tokenized stems and write the model to output_dir.
finetune_bert(train_items, output_dir, train_params=train_params)

## 3. 使用模型

In [None]:
# Adjacent string literals are joined at compile time, so no source
# indentation leaks into the stem (a backslash continuation inside a quoted
# string would keep the next line's leading spaces as part of the text).
item = {
    'stem': '如图$\\FigureID{088f15ea-8b7c-11eb-897e-b46bfc50aa29}$, '
            '若$x,y$满足约束条件$\\SIFSep$，则$z=x+7 y$的最大值为$\\SIFBlank$'
}

# Load the model fine-tuned above; tokenizer_kwargs points the tokenizer at
# the same directory.
tokenizer_kwargs = {"pretrain_model": output_dir}
i2v = Bert('bert', 'bert', output_dir, tokenizer_kwargs=tokenizer_kwargs)

# i2v accepts either a single string or a list of strings.
i_vec_single, t_vec_single = i2v(item['stem'])
i_vec, t_vec = i2v([item['stem']])
# Item and token vectors can also be requested separately:
#   i_vec = i2v.infer_item_vector([item['stem']])
#   t_vec = i2v.infer_token_vector([item['stem']])

print(i_vec.shape)  # expected 2-D, e.g. (batch, hidden) — TODO confirm dims
print(t_vec.shape)  # expected 3-D, e.g. (batch, seq_len, hidden) — TODO confirm dims

In [None]:
tokenizer = BertTokenizer()
# The "wrong1?" / "wrong2?" payloads in \FormFigureID and \FormFigureBase64 are
# presumably deliberate, to show how malformed formula figures are handled.
# Adjacent literals are joined at compile time; the text is unchanged.
item = ("有公式$\\FormFigureID{wrong1?}$，如图$\\FigureID{088f15ea-xxx}$, "
        "若$x,y$满足约束条件公式$\\FormFigureBase64{wrong2?}$,$\\SIFSep$，则$z=x+7 y$的最大值为$\\SIFBlank$")
token_item = tokenizer(item)
print(token_item.input_ids[:10])
# recorded output: [101, 1062, 2466, 1963, 1745, 21129, 166, 117, 167, 5276]
print(tokenizer.tokenize(item)[:10])
# recorded output: ['公', '式', '如', '图', '[FIGURE]', 'x', ',', 'y', '约', '束']

# A list of strings is encoded as one batch.
items = [item, item]
token_items = tokenizer(items, return_tensors='pt')
print(token_items.input_ids.shape)
# recorded output: torch.Size([2, 27])
print(len(tokenizer.tokenize(items)))

In [None]:
# Same single item, once as a PyTorch tensor batch and once as plain lists.
encoded_pt = tokenizer(item, return_tensors='pt')
print(encoded_pt.input_ids.shape)

token_item = tokenizer(item)
print(token_item.input_ids)

# Tokenize once and reuse the result for both the count and the listing.
item_tokens = tokenizer.tokenize(item)
print(len(item_tokens))
print(item_tokens)

print(tokenizer.tokenize(items))

In [4]:
test_items = [
    {"content": "10 米 的 (2/5) = 多少 米 的 (1/2),有 公 式"},
    {"content": "10 米 的 (2/5) = 多少 米 的 (1/2),有 公 式 , 如 图 , 若 $x,y$ 满 足 约 束 条 件 公 式"},
]


def content_of(x):
    """Key function: pull the text out of a test item."""
    return x["content"]


first_item = test_items[0]
dim = 128  # vector size of the "disenq_pub_128" model (checked below)

pretrained_dir = f"{BASE_DIR}/examples/test_model/data/disenq"
i2v = get_pretrained_i2v("disenq_pub_128", model_dir=pretrained_dir)

# Default call: the item vector comes back as a pair; the last token-vector
# axis has the model's vector size.
i_vec, t_vec = i2v(first_item, key=content_of)
assert len(i_vec) == 2
assert t_vec.shape[2] == i2v.vector_size

# Token vectors and the two item-vector variants ("k" and "i") one at a time.
t_vec = i2v.infer_token_vector(first_item, key=content_of)
i_vec_k = i2v.infer_item_vector(first_item, key=content_of, vector_type="k")
i_vec_i = i2v.infer_item_vector(first_item, key=content_of, vector_type="i")
for vec in (i_vec_k, i_vec_i):
    assert vec.shape == torch.Size([1, dim])
# 11 tokens: matches 'content_len' in the log output below.
assert t_vec.shape == torch.Size([1, 11, dim])
assert i2v.vector_size == i_vec_k.shape[1]

# vector_type=None returns both item-vector variants together.
i_vec, t_vec = i2v.infer_vector(first_item, key=content_of, vector_type=None)
assert len(i_vec) == 2
assert all(v.shape == torch.Size([1, dim]) for v in i_vec)
assert t_vec.shape == torch.Size([1, 11, dim])

# Both items as one batch.
i_vec, t_vec = i2v(test_items, key=content_of)
assert t_vec.shape == torch.Size([2, 23, dim])

EduNLP, INFO model_dir: E:\Workustc\EduNLP\workMaster\EduNLP\examples\test_model\data\disenq\disenq_pub_128
EduNLP, INFO Use pretrained t2v model disenq_pub_128
downloader, INFO http://base.ustc.edu.cn/data/model_zoo/modelhub/disenq_public/1/disenq_pub_128.zip is saved as E:\Workustc\EduNLP\workMaster\EduNLP\examples\test_model\data\disenq\disenq_pub_128.zip
downloader, INFO file existed, skipped


token_items ['<num>', '米', '的', '<num>', '=', '多少', '米', '的', '(1/2),有', '公', '式']
seqs [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]]
ret {'content_idx': [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]], 'content_len': [11]}
token_items ['<num>', '米', '的', '<num>', '=', '多少', '米', '的', '(1/2),有', '公', '式']
seqs [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]]
ret {'content_idx': [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]], 'content_len': [11]}
token_items ['<num>', '米', '的', '<num>', '=', '多少', '米', '的', '(1/2),有', '公', '式']
seqs [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]]
ret {'content_idx': [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]], 'content_len': [11]}
token_items ['<num>', '米', '的', '<num>', '=', '多少', '米', '的', '(1/2),有', '公', '式']
seqs [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]]
ret {'content_idx': [[0, 2805, 2598, 0, 10, 1147, 2805, 2598, 1, 1, 1553]], 'content_len': [11]}
token_items ['<num>', '米', '的', '<num>',