## 一、数据集预处理

### 1、将训练集的mentions转换为text对应的标记

In [209]:
import pandas as pd

# 读取训练集
df = pd.read_excel("data/train.xlsx")

# entity的起始和后续标记
B_ENT = 3
I_ENT = 4

# 逐个样本处理
for index, row in df.iterrows():

    # mentions的列表
    mention_list = eval(row['mentions'])

    # text字符串的长度
    text_len = len(row['text'])

    # 创建与text等长的label
    label = [0]*text_len

    # 对mention列表的每一个mention做处理，其中每个mention是一个dict
    for mention in mention_list:

        # mention字典中的'offset'字段表示某个实体的开始和结束位置，左闭右开
        (start, end) = eval(mention['offset'])

        # entity的起始位置标记为B_ENT
        label[start] = B_ENT
        
        # entity的其余位置标记为I_ENT
        for num in range(start+1, end):
            label[num] = I_ENT

    # 存储label
    df.loc[index, "labels"] = str(label)
    
# 保存为新的CSV文件
df.to_csv("data/train_labeled.csv")


### 2、分句

#### 方案一：用句号划分

In [210]:
# df0 = pd.read_csv('data/train_labeled.csv')

# # 以句号为分隔
# char_sep = '。'
# li = []

# for index, row in df0.iterrows():
#     # indices是分隔的位置list，第一句从头开始
#     indices = [-1]

#     text = row['text']
#     text_len = len(text)

#     # 找到所有的句号
#     for i in range(len(text)):
#         if text[i] == char_sep:
#             indices.append(i)

#     # 如果结尾不是句号，添加最后一个字符为分隔
#     if indices[-1] != text_len - 1:
#         indices.append(text_len - 1)

#     # 逐句处理
#     for j in range(len(indices)-1):
#         begin = indices[j] + 1
#         end = indices[j+1] + 1
#         sent = text[begin: end]
#         label = eval(row['labels'])[begin: end]
#         length = end - begin
#         li.append([length, sent, label])
    
# df = pd.DataFrame(li, columns = ['length', 'text', 'labels'])

# df.to_csv('data/sep_by_。.csv')


### 方案二、用多种标点划分

In [211]:
df0 = pd.read_csv('data/train_labeled.csv')

max_size = 100

# 以句号为分隔
char_sep = '。'
sep_list1 = [',',' ',';']
sep_list2 = ['、',':']
li = []

for index, row in df0.iterrows():
    # indices是分隔的位置list，第一句从头开始
    indices = [-1]

    text = row['text']
    text_len = len(text)

    # 找到所有的句号
    for i in range(len(text)):
        if text[i] == char_sep:
            indices.append(i)
        else: 
            if i - indices[-1] >= max_size - 2:
                for j in range (i, indices[-1], -1):
                    if text[j] in sep_list1:
                        indices.append(j)
                        break
                    else:
                        if text[j] in sep_list2:
                            indices.append(j)
                            break
        

    # 如果结尾不是句号，添加最后一个字符为分隔
    if indices[-1] != text_len - 1:
        indices.append(text_len - 1)

    # 逐句处理
    for j in range(len(indices)-1):
        begin = indices[j] + 1
        end = indices[j+1] + 1
        sent = text[begin: end]
        label = eval(row['labels'])[begin: end]
        length = end - begin
        if length >= 10:
            li.append([length, sent, label])
    
df = pd.DataFrame(li, columns = ['length', 'text', 'labels'])

df = df.sort_values(by='length', ascending=False)

df.to_csv('data/sep_'+str(max_size)+'.csv')


#### 方案二、指定最大长度

In [212]:
# df = pd.read_csv('data/train_labeled.csv')
# df_s = pd.DataFrame(columns = ['segment_number','id','segment_len','text','mentions','labels'])
# max_size = 100
# #句段长度上限
# #print(df_s.shape)
# #header = 0 把第一行作为列名称，首位空缺设置为默认的0
# #print(df.loc[0])
# #显示列标
# #预期效果：每行是一个短于max_size样本片段，一个样本占有若干行（若干片段）。
# #预期效果：每列是一个片段的信息：
# #         列标：0,               1，        2，      3，      4，          5，          
# #              所属样本第几个片段 所属样本id 片段长度 片段文本  片段mentions  片段labels
# seg_num = 0
# #总片段下标

# for row in df.values:
#     #print(type(row),row.shape)
#     mention_list = eval(row[3])
#     #print(eval(row[5]))
#     #al_len = 0
#     al_seg_len = 0
#     segs = 0
#     #该文本已经切割的片段总长度和片段数目
    
#     df_s.loc[seg_num] = [segs,row[1],0,'',[],'']
#     #df_s = df_s.append(pd.DataFrame([[segs,row[1],0,'',[],'']],columns = df_s.columns))
#     #print(type(df_s.loc[seg_num]))
#     #初始化该文本第一个片段
#     for i in range(len(mention_list)):
#     # 对mention列表的每一个mention做处理，其中每个mention是一个dict
        
#         start, end = eval(mention_list[i]['offset'])
#         # mention字典中的'offset'字段表示某个实体的开始和结束位置，左闭右开
#         if end < al_seg_len + max_size:
#         #实体全部在长度范围内
#             df_s.iloc[seg_num,4].append(mention_list[i])
#             #print(mention_list[i])
#             if i == len(mention_list):
#             #最后一个mention
#                 df_s.iloc[seg_num,3]=row[2][al_seg_len:min(al_seg_len + max_size,len(row[2]))]
#                 df_s.iloc[seg_num,5]=str(eval(row[5])[al_seg_len:min(al_seg_len + max_size,len(eval(row[5])))])
#                 al_seg_len = min(al_seg_len + max_size,len(row[2]))

#                 df_s.iloc[seg_num,2]=len(df_s.iloc[seg_num,3])
#                 seg_num +=1   
#                 segs += 1
#         else:
#         #过长，该实体放入下一个片段,开始整理已经确定的片段
#             i -= 1
#             #下一次还要放入该mention
#             df_s.iloc[seg_num,3]=row[2][al_seg_len:min(al_seg_len + max_size,start)]
#             df_s.iloc[seg_num,5]=str(eval(row[5])[al_seg_len:min(al_seg_len + max_size,start)])
            
#             al_seg_len = min(al_seg_len + max_size,start)

#             df_s.iloc[seg_num,2]=len(df_s.iloc[seg_num,3])
#             seg_num +=1   
#             segs += 1
#             if al_seg_len < len(row[2]): df_s.loc[seg_num] = [segs,row[1],0,'',[],'']
#             #新片段
#     #最后处理尾部无mentions区域
#     while al_seg_len < len(row[2]):
#         df_s.iloc[seg_num,3]=row[2][al_seg_len:min(al_seg_len + max_size,len(row[2]))]
#         df_s.iloc[seg_num,5]=str(eval(row[5])[al_seg_len:min(al_seg_len + max_size,len(eval(row[5])))])
#         al_seg_len = min(al_seg_len + max_size,len(row[2]))
#         if len(eval(df_s.iloc[seg_num,5])) != len(df_s.iloc[seg_num,3]) or len(eval(df_s.iloc[seg_num,5]))>max_size: 
#            print(len(df_s.iloc[seg_num,3]),len(eval(df_s.iloc[seg_num,5])))
#         df_s.iloc[seg_num,2]=len(df_s.iloc[seg_num,3])
#         seg_num +=1   
#         segs += 1
#         if al_seg_len < len(row[2]): df_s.loc[seg_num] = [segs,row[1],0,'',[],'']
#     #print(df_s.loc[0])

# df_s.to_csv("data/train_separated.csv")

## 二、数据处理

### 1、加载分词器并测试

In [213]:
from transformers import AutoTokenizer

#加载分词器
tokenizer = AutoTokenizer.from_pretrained('hfl/rbt6')

# 测试文本
text = ["红土创新基金管理有限公司6月30日发布公告，红土创新盐田港REIT(180301)的基金经理新聘陈超。"
,"金能科技: 金能科技股份有限公司关于“金能转债”转股价格调整的提示性公告"]

# 将文本进行逐字分解
def split_sent(text):
    return [char for char in text]

def split(text_list):

    text_split = []
    for sent in text_list:
        sent_split = split_sent(sent)
        text_split.append(sent_split)
    return text_split


# 分解并编码
def split_and_encode(text_list):

    inputs = tokenizer.batch_encode_plus(
    split(text_list),
    truncation=True,
    padding=True,
    return_tensors='pt',
    is_split_into_words=True)

    return inputs

# 展示编码结果
inputs_test = split_and_encode(text)
for key,value in inputs_test.items():
    print(key ,": \n", value, "\n")


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


input_ids : 
 tensor([[ 101, 5273, 1759, 1158, 3173, 1825, 7032, 5052, 4415, 3300, 7361, 1062,
         1385,  127, 3299,  124,  121, 3189, 1355, 2357, 1062, 1440, 8024, 5273,
         1759, 1158, 3173, 4663, 4506, 3949,  160,  147,  151,  162,  113,  122,
          129,  121,  124,  121,  122,  114, 4638, 1825, 7032, 5307, 4415, 3173,
         5470, 7357, 6631,  511,  102],
        [ 101, 7032, 5543, 4906, 2825,  131, 7032, 5543, 4906, 2825, 5500,  819,
         3300, 7361, 1062, 1385, 1068,  754,  100, 7032, 5543, 6760,  965,  100,
         6760, 5500,  817, 3419, 6444, 3146, 4638, 2990, 4850, 2595, 1062, 1440,
          102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0]]) 

token_type_ids : 
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

### 2、定义新的Dataset类

In [214]:
import torch
import datasets

# Dataset类，用于处理训练集
class Dataset(torch.utils.data.Dataset):

    def __init__(self, path):
        
        # 用pandas读取csv，并将其转换为Dataset
        data_df = pd.read_csv(path)
        data_list = data_df.to_dict(orient="list")
        dataset = datasets.Dataset.from_dict(data_list)

        #过滤掉太长的句子
        def f(data):
            return data['length'] <= max_size - 2

        dataset = dataset.filter(f)

        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        token = split_sent(self.data[index]['text'])
        label = eval(self.data[index]['labels'])

        return token, label 

# 创建数据集实例
dataset = Dataset('data/sep_'+str(max_size)+'.csv')

# 使用数据集
token, label = dataset[0]

#len(dataset), list(token), label

Filter: 100%|██████████| 6434/6434 [00:00<00:00, 183786.81 examples/s]


### 3、定义数据整理函数

In [215]:
#数据整理函数
def collate_fn(data):
    
    # tokens是分解后的文本
    tokens = [row[0] for row in data]
    # labels是对应的标记
    labels = [row[1] for row in data]

    # inputs是tokens的编码
    inputs = tokenizer.batch_encode_plus(tokens,
                                         truncation=True,
                                         padding=True,
                                         return_tensors='pt',
                                         is_split_into_words=True)

    # 编码的长度（最长长度）
    length = inputs['input_ids'].shape[1]

    # 将[CLS]和[PAD]标记为3，并将标记长度与编码长度统一
    for i in range(len(labels)):
        labels[i] = [3] + labels[i]
        labels[i] += [3] * length
        labels[i] = labels[i][:length]

    return inputs, torch.LongTensor(labels)


### 4、定义数据加载器

In [216]:
batch_size = 20

#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=batch_size,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

#查看数据样例
for i, (inputs, labels) in enumerate(loader):
    break

print(len(loader))
print(len(labels))
print(tokenizer.decode(inputs['input_ids'][0]))
print(labels)

for k, v in inputs.items():
    print(k,'\t', v.shape)

321
20
[CLS] 融 资 余 额 占 比 前 十 的 个 股 平 均 流 通 市 值 为 4 1. 9 2 亿 元, 融 资 余 额 占 比 最 高 的 仁 东 控 股 最 新 流 通 市 值 为 2 8. 0 0 亿 元 。 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
tensor([[3, 0, 0,  ..., 3, 3, 3],
        [3, 0, 3,  ..., 3, 3, 3],
        [3, 0, 3,  ..., 3, 3, 3],
        ...,
        [3, 0, 0,  ..., 3, 3, 3],
        [3, 0, 0,  ..., 3, 3, 3],
        [3, 0, 0,  ..., 3, 3, 3]])
input_ids 	 torch.Size([20, 98])
token_type_ids 	 torch.Size([20, 98])
attention_mask 	 torch.Size([20, 98])


## 三、模型构建

### 1、加载预训练模型

In [217]:
from transformers import AutoModel

# 加载预训练模型
pretrained = AutoModel.from_pretrained('hfl/rbt6')

# 统计参数量
print(sum(i.numel() for i in pretrained.parameters()) / 10000, '万')

#模型试算
pretrained(**inputs).last_hidden_state.shape

torch.save(pretrained, 'model/pretrained.model')

5974.0416 万


### 2、定义下游主体模型

In [218]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")  # 使用可用的第一个 GPU
device1 = torch.device("cpu")

# pretrained = torch.load('model/pretrained.model')

#定义下游模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tuneing = False
        self.pretrained = None

        self.rnn = torch.nn.GRU(768, 768,batch_first=True)
        self.fc = torch.nn.Linear(768, 8)

    def forward(self, inputs):
        if self.tuneing:
            out = self.pretrained(**inputs).last_hidden_state
        else:
            with torch.no_grad():
                out = pretrained(**inputs).last_hidden_state

        out, _ = self.rnn(out)

        out = self.fc(out).softmax(dim=2)

        return out

    def fine_tuneing(self, tuneing):
        self.tuneing = tuneing
        if tuneing:
            for i in pretrained.parameters():
                i.requires_grad = True

            pretrained.train()
            self.pretrained = pretrained
        else:
            for i in pretrained.parameters():
                i.requires_grad_(False)

            pretrained.eval()
            self.pretrained = None

pretrained = pretrained.to(device)
#model = torch.load('model/old.model')
model = Model()
inputs = inputs.to(device)
model.to(device)



model(inputs).shape

torch.Size([20, 98, 8])

In [219]:
#对计算结果和label变形,并且移除pad
def reshape_and_remove_pad(outs, labels, attention_mask):
    #变形,便于计算loss
    #[b, lens, 8] -> [b*lens, 8]
    outs = outs.reshape(-1, 8)
    #[b, lens] -> [b*lens]
    labels = labels.reshape(-1)

    #忽略对pad的计算结果
    #[b, lens] -> [b*lens - pad]
    select = attention_mask.reshape(-1) == 1
    outs = outs[select]
    labels = labels[select]

    return outs, labels


reshape_and_remove_pad(torch.randn(2, 3, 8), torch.ones(2, 3),
                       torch.ones(2, 3))

(tensor([[-0.4701,  1.0503,  0.0942, -2.0883, -0.0691,  0.4383,  0.9786, -1.3162],
         [-0.2091, -0.2392, -0.2887, -0.6284, -0.4473, -0.1226,  0.3278, -1.1443],
         [-0.2640, -0.9409,  0.6567, -1.3145, -0.4140, -0.2321,  1.4232, -0.5155],
         [-1.6754, -0.4503,  0.5053, -0.2655,  0.9738,  0.7809, -0.3302,  0.2328],
         [ 0.4400, -0.2033,  0.6827,  2.2053, -1.7541,  0.8680,  0.0548,  0.2773],
         [-1.9146, -0.9129, -0.7194, -0.1552, -0.3077,  0.5339, -1.5717, -0.4352]]),
 tensor([1., 1., 1., 1., 1., 1.]))

In [220]:
#获取正确数量和总数
def get_correct_and_total_count(labels, outs):
    #[b*lens, 8] -> [b*lens]
    outs = outs.argmax(dim=1)
    correct = (outs == labels).sum().item()
    total = len(labels)

    #计算除了0以外元素的正确率,因为0太多了,包括的话,正确率很容易虚高
    select = (labels == 3) + (labels == 4)
    outs = outs[select]
    labels = labels[select]
    correct_content = (outs == labels).sum().item()
    total_content = len(labels)

    return correct, total, correct_content, total_content


get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))

(1, 16, 0, 0)

In [221]:
from torch.optim import AdamW


#训练
def train(epochs):
    lr = 2e-5 if model.tuneing else 5e-4

    #训练
    optimizer = AdamW(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        for step, (inputs, labels) in enumerate(loader):
            #模型计算
            #[b, lens] -> [b, lens, 8]
            inputs = inputs.to(device)
            outs = model(inputs)

            labels = labels.to(device)

            #对outs和label变形,并且移除pad
            #outs -> [b, lens, 8] -> [c, 8]
            #labels -> [b, lens] -> [c]
            outs, labels = reshape_and_remove_pad(outs, labels,
                                                  inputs['attention_mask'])

            #梯度下降
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 50 == 0:
                counts = get_correct_and_total_count(labels, outs)

                accuracy = counts[0] / counts[1]
                accuracy_content = counts[2] / counts[3]

                print(epoch, step, loss.item(), accuracy, accuracy_content)

        if accuracy_content > 0.8:
            break
        torch.save(model, 'model/model.model')

model.fine_tuneing(False)
print(sum(p.numel() for p in model.parameters()) / 10000)
train(1)

354.9704
0 0 2.0862789154052734 0.02663622526636225 0.08695652173913043
0 50 1.3383339643478394 0.936267071320182 0.19230769230769232
0 100 1.3648669719696045 0.9092284417549168 0.14285714285714285
0 150 1.3306382894515991 0.9434129089301503 0.23809523809523808
0 200 1.3988412618637085 0.8751962323390895 0.11173184357541899
0 250 1.3660998344421387 0.9079307201458523 0.1652892561983471
0 300 1.3919060230255127 0.882120253164557 0.11834319526627218


In [222]:
model.fine_tuneing(True)
print(sum(p.numel() for p in model.parameters()) / 10000)
train(20)

6329.012
0 0 1.3921314477920532 0.881896551724138 0.12738853503184713
0 50 1.349989891052246 0.9240196078431373 0.17699115044247787
0 100 1.392376184463501 0.8816326530612245 0.14705882352941177
0 150 1.3194451332092285 0.9551166965888689 0.45977011494252873
0 200 1.323422908782959 0.950834879406308 0.43010752688172044
0 250 1.331437110900879 0.942630185348632 0.39622641509433965
0 300 1.3493136167526245 0.9247038917089678 0.304
1 0 1.3315593004226685 0.94240317775571 0.422680412371134
1 50 1.3123054504394531 0.9624542124542125 0.5119047619047619
1 100 1.3029303550720215 0.9708654670094259 0.575
1 150 1.352297306060791 0.9213085764809903 0.37681159420289856
1 200 1.3433125019073486 0.9307992202729045 0.4017094017094017
1 250 1.3939942121505737 0.8797763280521901 0.3128491620111732
1 300 1.32818603515625 0.9459962756052142 0.45263157894736844
2 0 1.3529618978500366 0.920675105485232 0.36363636363636365
2 50 1.378406286239624 0.8951160928742994 0.3021978021978022
2 100 1.3361860513687134

In [None]:
#train(100)

In [227]:

#测试
def test():
    model_load = torch.load('model/model.model')
    model_load.eval()

    loader_test = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=128,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    correct = 0
    total = 0

    correct_content = 0
    total_content = 0

    for step, (inputs, labels) in enumerate(loader_test):
        if step == 5:
            break
        print(step)

        with torch.no_grad():
            #[b, lens] -> [b, lens, 8] -> [b, lens]
            inputs = inputs.to(device)
            model_load = model_load.to(device)
            outs = model_load(inputs)
            labels = labels.to(device)

        #对outs和label变形,并且移除pad
        #outs -> [b, lens, 8] -> [c, 8]
        #labels -> [b, lens] -> [c]
        outs, labels = reshape_and_remove_pad(outs, labels,
                                              inputs['attention_mask'])

        counts = get_correct_and_total_count(labels, outs)
        correct += counts[0]
        total += counts[1]
        correct_content += counts[2]
        total_content += counts[3]

    print(correct / total, correct_content / total_content)


test()

0
1
2
3
4
0.9824910354101299 0.9103922129974235


In [225]:
import torch
def predict_one():
    model_load = torch.load('model/model.model')
    model_load.to(device1)
    model_load.eval()

    text = ["红土创新基金管理有限公司6月30日发布公告，红土创新盐田港REIT(180301)的基金经理新聘陈超。"
    ,"金能科技: 金能科技股份有限公司关于“金能转债”转股价格调整的提示性公告",
    "大地熊成功收购技研株式会社安徽大地熊新材料股份有限公司2023年6月22日,大地熊日本株式会社与技研株式会社\\
        (英文名:p.m.giken inc,简称pm公司)原有股东和相关方签订收购合同》,大地熊日本株式会社全资收购pm公司\\
        并即时启动相关手续交接。 pm公司是专业从事磁钢整体磁气回路的解析与设计,注塑磁体的开发、制造与销售。\\
        此次收购是大地熊向注塑磁体领域的全新产业布局,将进一步推进大地熊海内外战略合作与业务拓展,丰富公司磁性材料及部件产品种类,提升公司综合实力和整体竞争力。",
    "黑猫股份:设立全资子公司投建年产16万吨碳材/橡胶复合母胶项目证券时报e公司讯,黑猫股份(002068)6月27日晚间\\
    公告,公司拟出资1亿元在辽宁省朝阳市高新技术产业开发区设立全资子公司辽宁黑猫复合新材料科技有限公司(简称辽\\
        宁黑猫),并且将以辽宁黑猫为项目主体,投资新建年产16万吨碳材/橡胶复合母胶项目,分三期进行建设,项目预计投资总额为6.88亿元。",
    "秋田满满、小鹿蓝蓝食安问题频发,投诉也没用随着我国新生代父母的科学喂养观念加强,我国婴童辅食行业消费规模\\
        持续上升。早在2019年我国婴幼儿辅食消费市场规模就已经达到404亿元,年复合增长率高达23%,预计未来我国婴幼儿辅食市场规模应在千亿以上。"]
    
    text_split = []
    for sent in text:
        sent_split = [char for char in sent]
        text_split.append(sent_split)

    #print(text_split)
    inputs = tokenizer.batch_encode_plus(
    text_split,
    truncation=True,
    padding=True,
    return_tensors='pt',
    is_split_into_words=True)


    with torch.no_grad():
        #[b, lens] -> [b, lens, 8] -> [b, lens]
        outs = model_load(inputs).argmax(dim=2)

    for i in range(len(text)):
        #移除pad
        select = inputs['attention_mask'][i] == 1
        input_id = inputs['input_ids'][i, select]
        out = outs[i, select]
        #label = labels[i, select]
        
        #输出原句子
        print(tokenizer.decode(input_id).replace(' ', ''), len(input_id))

        #输出tag
        for tag in [out]:
            s = ''
            for j in range(len(tag)):
                if tag[j] == 0:
                    s += '·'
                    continue
                s += tokenizer.decode(input_id[j])
                s += str(tag[j].item())

            print(s)
        print('==========================')
    
predict_one()

[CLS]红土创新基金管理有限公司6月30日发布公告，红土创新盐田港reit(180301)的基金经理新聘陈超。[SEP] 53
[CLS]3红3土4创4新4基4金4管4理4有4限4公4司4··········红3土4创4新4盐4田4港4r4e4i4t4··················[SEP]3
[CLS]金能科技:金能科技股份有限公司关于[UNK]金能转债[UNK]转股价格调整的提示性公告[SEP] 37
[CLS]3金3能4科4技4·金3能4科4技4股4份4有4限4公4司4····················[SEP]3
[CLS]大地熊成功收购技研株式会社安徽大地熊新材料股份有限公司2023年6月22日,大地熊日本株式会社与技研株式会社(英文名:p.m.gikeninc,简称pm公司)原有股东和相关方签订收购合同》,大地熊日本株式会社全资收购pm公司并即时启动相关手续交接。pm公司是专业从事磁钢整体磁气回路的解析与设计,注塑磁体的开发、制造与销售。此次收购是大地熊向注塑磁体领域的全新产业布局,将进一步推进大地熊海内外战略合作与业务拓展,丰富公司磁性材料及部件产品种类,提升公司综合实力和整体竞争力。[SEP] 240
[CLS]3大3地4熊4····技3研4株4式4会4社4安4徽4大4地4熊4新4材4料4股4份4有4限4公4司4···········大3地4熊4日4本4株4式4会4社4·技3研4株4式4会4社4·····p3.4m4.4g4i4k4e4n4i4n4c4···p3m4公4司4·················大3地4熊4日4本4株4式4会4社4····p3m4公4司4············p3m4公4司4·······································大3地4熊4·····················大3地4熊4············································[SEP]3
[CLS]黑猫股份:设立全资子公司投建年产16万吨碳材/橡胶复合母胶项目证券时报e公司讯,黑猫股份(002068)6月27日晚间公告,公司拟出资1亿元在辽宁省朝阳市高新技术产业开发区设立全资子公司辽宁黑猫复合新材料科技有限公司(简称辽宁黑猫),并且将以辽宁黑猫为项目主体,投资新建年产16万吨碳材

In [None]:
import torch
print(torch.cuda.is_available())

In [None]:
import pandas as pd
import math

df = pd.read_excel('data/train.xlsx')
kb = pd.read_excel('data/knowledge_base_new.xlsx')


def find(entity, offset):
    for index, row in kb.iterrows():
        if index not in index_list and (row['stockName'] == entity or row['companyName'] == entity):
            know = \
            {"companyId": row['companyId'],
            "companyName": row['companyName'],
            "mention": entity, 
            "offset": offset, 
            "stockId": int(row['stockId']) if not math.isnan(row['stockId']) else 'null', 
            "stockName": row['stockName'] if not math.isnan(row['stockId']) else 'null', 
            "type": row['type'][:2]}
            knowledge_list.append(know)
            index_list.append(index)
        

for index, row in df.iterrows():

    print(index)

    mention_list = eval(row['mentions'])
    entity_list = []
    index_list = []
    knowledge_list = []

    # 对mention列表的每一个mention做处理，其中每个mention是一个dict
    for mention in mention_list:

        entity = mention['mention']

        offset = mention['offset']

        if entity not in entity_list:

            find(entity, offset)

            entity_list.append(entity)

    # 存储结果
    df.loc[index, "result"] = str(knowledge_list)

df.to_csv('data/knowledge.csv')