In [1]:
from transformers import AutoTokenizer

# Hugging Face checkpoint shared by the tokenizer below and by the encoder
# that Model loads later in the notebook.
checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

# Load the vocabulary and tokenizer for that checkpoint.
token = AutoTokenizer.from_pretrained(checkpoint)

# Show the tokenizer configuration as the cell output.
token

PreTrainedTokenizerFast(name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', vocab_size=250002, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})

In [2]:
import torch
from datasets import load_dataset, concatenate_datasets
import random


# Dataset definition
class Dataset(torch.utils.data.Dataset):
    """SuperGLUE BoolQ examples, re-partitioned into custom train/test splits.

    Every split shipped with the dataset is concatenated and the result is
    divided again with a fixed seed, holding out 0.5% of the rows as 'test'.
    """

    def __init__(self, split):
        """Load BoolQ and keep only the requested partition.

        Args:
            split: key of the re-split dataset to keep; the notebook uses
                'train' and 'test'.
        """
        boolq = load_dataset(path='super_glue', name='boolq')

        # Merge all original splits, then cut the merged data again.
        merged = concatenate_datasets(list(boolq.values()))
        resplit = merged.train_test_split(test_size=0.005, seed=0)
        self.dataset = resplit[split]

    def __len__(self):
        """Return the number of examples in the selected partition."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return example ``i`` exactly as stored in the underlying dataset."""
        return self.dataset[i]


# Training partition: everything outside the 0.5% held out as the test split.
dataset = Dataset('train')

# Show the number of training examples and the first record.
len(dataset), dataset[0]

Found cached dataset super_glue (/root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-6a0e212c0c0fa0d4.arrow and /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-e52c39cad4c58186.arrow


(15862,
 {'question': 'is profit and loss statement and income statement same',
  'passage': "Income statement -- An income statement or profit and loss account (also referred to as a profit and loss statement (P&L), statement of profit or loss, revenue statement, statement of financial performance, earnings statement, operating statement, or statement of operations) is one of the financial statements of a company and shows the company's revenues and expenses during a particular period. It indicates how the revenues (money received from the sale of products and services before expenses are taken out, also known as the ``top line'') are transformed into the net income (the result after all revenues and expenses have been accounted for, also known as ``net profit'' or the ``bottom line''). The purpose of the income statement is to show managers and investors whether the company made or lost money during the period being reported.",
  'idx': 6874,
  'label': 1})

In [3]:
def collate_fn(data):
    """Tokenize the questions of a batch of examples into padded tensors.

    Args:
        data: list of example dicts, each holding a 'question' string.

    Returns:
        The tokenizer's batch encoding as PyTorch tensors. Sequences are
        padded to the longest one in the batch and truncated to at most
        500 tokens.
    """
    question = [i['question'] for i in data]

    # Encode. Calling the tokenizer directly is the supported entry point;
    # `batch_encode_plus` is deprecated in favour of `__call__` and takes the
    # same arguments.
    return token(question,
                 truncation=True,
                 padding=True,
                 max_length=500,
                 return_tensors='pt')


# Data loader over the training questions. The order is fixed (no shuffling)
# and no batch is dropped, so batch rows line up with dataset indices.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
)

# Pull only the first batch so it can be inspected below.
i, data = next(enumerate(loader))

len(loader), data

(1983,
 {'input_ids': tensor([[     0,     83,  18348,    136,  86669,  63805,    136,  91763,  63805,
            5701,      2,      1,      1,      1,      1],
         [     0,     83,     70,  56877,    289, 106820,     70,   5701,    237,
              70, 106820,    111,  16095,   1760,      2],
         [     0,     83,   2565,   4299,     70,   5701,  13580,    237,      6,
          138410,   2565,      2,      1,      1,      1],
         [     0,   1556,     10,  46667,  17669,   2809, 144888,    297,     23,
              70,    653,    402,      2,      1,      1],
         [     0,     83,   2685,     10,  25550,  29398,     98,   2054,   2069,
           90695,      2,      1,      1,      1,      1],
         [     0,  14602,     70, 192182,  15889,  12936,   7701,    765,   2499,
           34153,      2,      1,      1,      1,      1],
         [     0,     83,     70,   4842,  94897,     10,    263,     71,    111,
             479,     53,      2,      1,      1,  

In [4]:
from transformers import AutoModel


# Model definition
class Model(torch.nn.Module):
    """Frozen pretrained encoder plus a small classifier over sentence pairs.

    Each sentence is encoded separately and mean-pooled into one feature
    vector; the two vectors are concatenated and classified into 2 classes.
    """

    def __init__(self):
        super().__init__()
        # Load the pretrained encoder.
        self.pretrained = AutoModel.from_pretrained(checkpoint)

        # The encoder is not trained, so it needs no gradients.
        for param in self.pretrained.parameters():
            param.requires_grad_(False)

        # Classifier head. Its 768 inputs are two concatenated 384-dim
        # sentence features; it produces 2 logits.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(768, 768),
            torch.nn.ReLU(),
            torch.nn.Linear(768, 2),
        )

    def get_feature(self, data):
        """Return the masked mean of the encoder's token states.

        Args:
            data: mapping of encoder inputs; it is unpacked into the encoder
                and must contain a 0/1 'attention_mask' of shape [b, L].

        Returns:
            Tensor of shape [b, 384]: the average hidden state over the
            non-padding tokens of each sequence. A row whose mask is all
            zeros yields a zero vector.
        """
        with torch.no_grad():
            #[b, L, 384]
            hidden = self.pretrained(**data)['last_hidden_state']

        #[b, L]
        attention_mask = data['attention_mask']

        # Zero the features at padding positions. Done out of place so the
        # encoder's output tensor is not modified.
        #[b, L, 384] * [b, L, 1] -> [b, L, 384]
        hidden = hidden * attention_mask.unsqueeze(dim=2)

        # Sum the features of all tokens.
        #[b, L, 384] -> [b, 384]
        summed = hidden.sum(dim=1)

        # Number of real tokens per sequence.
        #[b, L] -> [b, 1]
        token_count = attention_mask.sum(dim=1, keepdim=True)

        # Divide the sum by the sentence length. The count is clamped to at
        # least 1 so an all-padding row divides 0 by 1; an integer bound is
        # used because a float bound on an integer tensor only guards against
        # zero where clamp performs type promotion.
        #[b, 384] / [b, 1] -> [b, 384]
        return summed / token_count.clamp(min=1)

    def forward(self, data1, data2):
        """Classify a pair of encoded sentences.

        Args:
            data1: encoder inputs for the first sentences (see get_feature).
            data2: encoder inputs for the second sentences.

        Returns:
            Tensor of shape [b, 2] with the classifier logits.
        """
        feature1 = self.get_feature(data1)
        feature2 = self.get_feature(data2)

        #[b, 384] + [b, 384] -> [b, 768]
        feature = torch.cat([feature1, feature2], dim=1)

        return self.fc(feature)


# NOTE(review): this unpickles a whole Module, which can execute arbitrary
# code, so only load checkpoint files from a trusted source. Recent PyTorch
# releases default torch.load to weights_only=True, which rejects pickled
# modules; a trusted file then needs weights_only=False.
# map_location='cpu' lets the checkpoint load on machines without the device
# it was saved from; build_features() moves the model to the GPU itself when
# one is available.
model = torch.load('models/2.求相似度_分别计算法.model', map_location='cpu')

# Switch to inference mode and display the architecture.
model.eval()

Model(
  (pretrained): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(250037, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)

In [5]:
# Build the knowledge matrix
def build_features():
    """Encode every batch of the global ``loader`` and save the features.

    The global ``model`` is moved to the GPU when one is available and back
    to the CPU afterwards. The per-batch outputs of ``model.get_feature`` are
    concatenated in loader order and written to 'models/features.pt'. The
    batch index is printed every 50 batches as a progress indicator.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(target)

    batch_features = []
    for step, batch in enumerate(loader):
        # Move every tensor of the batch onto the model's device.
        on_device = {name: tensor.to(target) for name, tensor in batch.items()}
        batch_features.append(model.get_feature(on_device))

        if step % 50 == 0:
            print(step)

    model.cpu()

    # One row per example, in loader order.
    stacked = torch.cat(batch_features)
    torch.save(stacked.cpu(), 'models/features.pt')


# Encode the whole training split once and cache it in models/features.pt.
build_features()

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950


In [7]:
# Test
def test(threshold=0.99):
    """Find the best-matching training question for every test question.

    Each test question's feature is paired with every row of the global
    ``features`` matrix and the pairs are scored by ``model.fc``. When the
    highest class-1 probability exceeds ``threshold``, the test index, the
    score, the test question and the matched training question are printed.

    Args:
        threshold: minimum class-1 probability for a match to be printed.
            Defaults to 0.99, the value that was previously hard-coded.
    """
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=1,
                                              collate_fn=collate_fn,
                                              shuffle=False,
                                              drop_last=False)

    # Keep the inputs and the knowledge matrix on whatever device the model
    # currently lives on, so this works whether or not the model was moved.
    device = next(model.parameters()).device
    knowledge = features.to(device)

    # Pure inference: model.fc is not frozen in Model.__init__, so without
    # no_grad an autograd graph over every pair would be recorded per query.
    with torch.no_grad():
        for i, data in enumerate(loader_test):
            inputs = {k: v.to(device) for k, v in data.items()}

            #[1, 384]
            feature = model.get_feature(inputs)

            # Broadcast view of the query against every stored row.
            #[1, 384] -> [N, 384]
            feature = feature.expand(knowledge.shape[0], -1)

            #[N, 384] + [N, 384] -> [N, 768]
            feature = torch.cat([knowledge, feature], dim=1)

            # Probability of class 1 for every pair.
            #[N, 2] -> [N]
            score = model.fc(feature).softmax(dim=1)[:, 1]

            argmax = score.argmax().item()
            best = score[argmax].item()

            if best > threshold:
                print(i)
                print(best)
                print(token.decode(data['input_ids'][0], skip_special_tokens=True))
                # Row k of the feature matrix was built from dataset[k]
                # (the training loader does not shuffle).
                print(dataset[argmax]['question'])


# Load the cached feature matrix written by build_features().
features = torch.load('models/features.pt')

# Search the training questions for a match to each test question.
test()

Found cached dataset super_glue (/root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-6a0e212c0c0fa0d4.arrow and /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-e52c39cad4c58186.arrow


0
0.9841532111167908
do all bacteria have peptidoglycan in their cell walls
can there be two foreign keys in a table
1
0.9935773015022278
curium-242 was synthesized by bombarding an isotope with alpha particles
can a person from california buy a gun in arizona
2
0.9999700784683228
can you block the plate in college baseball
do you have to get out of the way of a pitch
3
0.9998216032981873
does the color of a flame indicate its temperature
can you have a turbo and a supercharger at the same time
4
0.9998830556869507
is it illegal to wear clothes and return them
is hand baggage the same as carry on
5
0.9999994039535522
is stock price the same as share price
can you do a like kind exchange on stock
6
0.9989058971405029
did candace cameron win dancing with the stars
did a magician ever win americas got talent
7
0.9999935626983643
is the caribbean sea part of the atlantic
is the caribbean part of the atlantic ocean
8
0.9999872446060181
were the mamma mia songs written for the movie
was pret

78
0.9999995231628418
can you buy beer on sunday in new york
can you buy beer in new york state on sunday
79
0.9999910593032837
are the oscars the same as the academy awards
are oscars and academy awards the same thing
