In [1]:
from transformers import AutoTokenizer

checkpoint = 'bert-base-uncased'
checkpoint = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

#加载字典和分词工具
token = AutoTokenizer.from_pretrained(checkpoint)

token

PreTrainedTokenizerFast(name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', vocab_size=250002, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})

In [2]:
import torch
from datasets import load_dataset, concatenate_datasets
import random


#定义数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self, split):
        dataset = load_dataset(path='super_glue', name='boolq')

        #重新切分数据集
        dataset = concatenate_datasets([dataset[i] for i in dataset.keys()])
        self.dataset = dataset.train_test_split(test_size=0.005, seed=0)[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i]


dataset = Dataset('train')

len(dataset), dataset[0]

Found cached dataset super_glue (/root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-6a0e212c0c0fa0d4.arrow and /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-e52c39cad4c58186.arrow


(15862,
 {'question': 'is profit and loss statement and income statement same',
  'passage': "Income statement -- An income statement or profit and loss account (also referred to as a profit and loss statement (P&L), statement of profit or loss, revenue statement, statement of financial performance, earnings statement, operating statement, or statement of operations) is one of the financial statements of a company and shows the company's revenues and expenses during a particular period. It indicates how the revenues (money received from the sale of products and services before expenses are taken out, also known as the ``top line'') are transformed into the net income (the result after all revenues and expenses have been accounted for, also known as ``net profit'' or the ``bottom line''). The purpose of the income statement is to show managers and investors whether the company made or lost money during the period being reported.",
  'idx': 6874,
  'label': 1})

In [3]:
def collate_fn(data):
    question = [i['question'] for i in data]

    #编码
    return token.batch_encode_plus(question,
                                   truncation=True,
                                   padding=True,
                                   max_length=500,
                                   return_tensors='pt')


#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=False,
                                     drop_last=False)

for i, data in enumerate(loader):
    break

len(loader), data

(1983,
 {'input_ids': tensor([[     0,     83,  18348,    136,  86669,  63805,    136,  91763,  63805,
            5701,      2,      1,      1,      1,      1],
         [     0,     83,     70,  56877,    289, 106820,     70,   5701,    237,
              70, 106820,    111,  16095,   1760,      2],
         [     0,     83,   2565,   4299,     70,   5701,  13580,    237,      6,
          138410,   2565,      2,      1,      1,      1],
         [     0,   1556,     10,  46667,  17669,   2809, 144888,    297,     23,
              70,    653,    402,      2,      1,      1],
         [     0,     83,   2685,     10,  25550,  29398,     98,   2054,   2069,
           90695,      2,      1,      1,      1,      1],
         [     0,  14602,     70, 192182,  15889,  12936,   7701,    765,   2499,
           34153,      2,      1,      1,      1,      1],
         [     0,     83,     70,   4842,  94897,     10,    263,     71,    111,
             479,     53,      2,      1,      1,  

In [4]:
from transformers import AutoModel

pretrained = AutoModel.from_pretrained(checkpoint)

#不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)
    
pretrained.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(250037, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
         

In [5]:
def get_feature(data):
    with torch.no_grad():
        #[b, L, 384]
        feature = pretrained(**data)['last_hidden_state']

    #[b, L]
    attention_mask = data['attention_mask']

    #pad位置的feature是0
    #[b, L, 384] * [b, L, 1] -> [b, L, 384]
    feature *= attention_mask.unsqueeze(dim=2)

    #所有词的feature求和
    #[b, L, 384] -> [b, 384]
    feature = feature.sum(dim=1)

    #求和后的feature除以句子的长度
    #[b, L] -> [b, 1]
    attention_mask = attention_mask.sum(dim=1, keepdim=True)

    #[b, 384] / [b, 1] -> [b, 384]
    feature /= attention_mask.clamp(min=1e-8)

    return feature


get_feature(data).shape

torch.Size([8, 384])

In [6]:
#构建知识矩阵
def build_features():
    global pretrained
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pretrained.to(device)

    features = []
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)

        features.append(get_feature(data))

        if i % 50 == 0:
            print(i)

    pretrained.cpu()

    features = torch.cat(features)

    torch.save(features.cpu(), 'models/features.pt')


build_features()

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850
1900
1950


In [7]:
def test_cosin():
    a = torch.FloatTensor([3, 3])
    b = torch.FloatTensor([9, 0])

    print(a, b)

    #内置函数算cosin
    print(torch.nn.functional.cosine_similarity(a, b, dim=0).item())

    import math

    #2*pi是一个圆周,这里是1/8个圆,所以是(1/4)*pi
    print(math.cos((1 / 4) * math.pi))

    #等价
    print(a.pow(2).sum().sqrt(), a.norm(2))

    #另一种计算cosin的公式
    print(a.matmul(b) / a.norm(2) / b.norm(2))


test_cosin()

tensor([3., 3.]) tensor([9., 0.])
0.7071068286895752
0.7071067811865476
tensor(4.2426) tensor(4.2426)
tensor(0.7071)


In [8]:
def test_cosin():
    #先定义两个矩阵,a是知识库,b是新问题
    a = torch.randn(5, 12)
    b = torch.randn(1, 12)

    #以循环的方式分别求cosin
    for i in range(a.shape[0]):
        cos = b[0].matmul(a[i].T) / b[0].norm(p=2) / a[i].norm(p=2)
        print(i, cos.item())

    #矩阵方式计算
    cosin = b.matmul(a.T) / b.norm(p=2, dim=1, keepdim=True) / a.norm(
        p=2, dim=1, keepdim=True).T
    print(cosin)

    #用内置函数算
    for i in range(a.shape[0]):
        cos = torch.nn.functional.cosine_similarity(b[0], a[i], dim=0)
        print(i, cos.item())

    print(torch.nn.functional.cosine_similarity(b, a, dim=1))


test_cosin()

0 -0.07690700888633728
1 0.39602503180503845
2 0.39414387941360474
3 -0.5003533363342285
4 -0.11575382947921753
tensor([[-0.0769,  0.3960,  0.3941, -0.5004, -0.1158]])
0 -0.07690700888633728
1 0.3960250914096832
2 0.39414387941360474
3 -0.5003533363342285
4 -0.11575383692979813
tensor([-0.0769,  0.3960,  0.3941, -0.5004, -0.1158])


  cos = b[0].matmul(a[i].T) / b[0].norm(p=2) / a[i].norm(p=2)


In [9]:
#测试
def test():
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=1,
                                              collate_fn=collate_fn,
                                              shuffle=False,
                                              drop_last=False)

    for i, data in enumerate(loader_test):
        feature = get_feature(data)

        score = torch.nn.functional.cosine_similarity(feature, features, dim=1)

        argmax = score.argmax().item()

        if score[argmax].item() > 0.9:
            print(i)
            print(score[argmax].item())
            print(token.decode(data['input_ids'][0], skip_special_tokens=True))
            print(dataset[argmax]['question'])


features = torch.load('models/features.pt')

test()

Found cached dataset super_glue (/root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached split indices for dataset at /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-6a0e212c0c0fa0d4.arrow and /root/.cache/huggingface/datasets/super_glue/boolq/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed/cache-e52c39cad4c58186.arrow


7
0.991922914981842
is the caribbean sea part of the atlantic
is the caribbean part of the atlantic ocean
11
0.9634772539138794
is the movie mine based on a true story
is mine the movie based on a true story
15
0.9576605558395386
do the jets and giants share a stadium
do the giants and jets share metlife stadium
17
0.9971187114715576
is the cerebral cortex part of the limbic system
is the limbic system part of the cerebral cortex
18
0.9865292906761169
is it possible for twins to have different fathers
is it possible for twins to have two different fathers
19
0.9255623817443848
does birth certificate count as a form of id
is a birth certificate a valid form of id
24
0.9181225299835205
has anyone ever found a four leaf clover
can you actually find a four leaf clover
28
0.9733133316040039
did the usa soccer team qualify for the world cup
did the united states soccer team qualify for the world cup
45
0.9900728464126587
is there going to be a season 2 for iron fist
is there going to be an i