/
bert.py
56 lines (45 loc) · 2.51 KB
/
bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import List
import numpy as np
import torch.nn
import torch
from nboost.models.base import BaseModel
class PtBertModel(BaseModel):
    """Rerank candidate passages with a PyTorch sequence-classification model.

    The model and tokenizer are both loaded from ``self.model_dir``.
    ``self.model_dir``, ``self.max_seq_len`` and ``self.logger`` are not set
    in this class; presumably they are provided by ``BaseModel`` -- confirm
    against the base class.
    """

    def __init__(self, *args, **kwargs):
        """Load the model and tokenizer and move the model to the device.

        All positional and keyword arguments are forwarded unchanged to
        ``BaseModel.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.logger.info('Loading from checkpoint %s', self.model_dir)

        # Use CUDA whenever torch reports it as available, otherwise the CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        if self.device.type == "cpu":
            self.logger.info("RUNNING ON CPU")
        else:
            self.logger.info("RUNNING ON CUDA")
            torch.cuda.synchronize(self.device)

        self.rerank_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

        self.rerank_model.to(self.device, non_blocking=True)

    def rank(self, query: str, choices: List[str]):
        """Return the indices of ``choices`` ordered from best to worst.

        Args:
            query: The search query.
            choices: Candidate texts to score against ``query``.

        Returns:
            A list of numpy integer indices into ``choices``, sorted by
            descending score. Empty when ``choices`` is empty.
        """
        # Nothing to score; encode() would otherwise fail on max() of an
        # empty sequence.
        if not choices:
            return []

        input_ids, attention_mask, token_type_ids = self.encode(query, choices)

        with torch.no_grad():
            logits = self.rerank_model(input_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids)[0]

        scores = logits.detach().cpu().numpy()

        # Pick the score column while the array is still (batch, labels).
        # Squeezing first would collapse a single-choice batch (1, 2) into
        # (2,), making the two label logits look like two choices.
        if scores.ndim == 2 and scores.shape[1] == 2:
            # Two-label head: column 1 is used as the relevance score.
            scores = scores[:, 1]
        elif scores.ndim == 2 and scores.shape[1] == 1:
            # Single-logit head: drop the label axis only, so a batch of
            # one still yields a 1-D array of length one.
            scores = scores[:, 0]
        # NOTE(review): heads with any other label count are passed to
        # argsort unchanged, as before; confirm whether they need support.

        return list(np.argsort(scores)[::-1])

    def encode(self, query: str, choices: List[str]):
        """Tokenize ``(query, choice)`` pairs into padded tensors.

        The query and every choice are lower-cased before tokenization.
        Sequences are cut to the length of the longest encoded pair, capped
        at ``self.max_seq_len``, and right-padded with zeros to that length.

        Args:
            query: The search query.
            choices: Candidate texts to pair with ``query``.

        Returns:
            A tuple ``(input_ids, attention_mask, token_type_ids)`` of
            tensors on ``self.device``, one row per choice.

        Raises:
            ValueError: If ``choices`` is empty (``max()`` of an empty
                sequence).
        """
        lowered_query = query.lower()
        inputs = [
            self.tokenizer.encode_plus(lowered_query,
                                       choice.lower(),
                                       add_special_tokens=True)
            for choice in choices
        ]

        max_len = min(max(len(t['input_ids']) for t in inputs),
                      self.max_seq_len)

        def pad(values):
            # Truncate to max_len, then right-pad with zeros.
            # NOTE(review): plain slicing may cut off trailing special
            # tokens on over-long pairs -- confirm this is acceptable.
            kept = values[:max_len]
            return kept + [0] * (max_len - len(kept))

        input_ids = [pad(t['input_ids']) for t in inputs]
        # 1 marks a real token, 0 marks padding.
        attention_mask = [pad([1] * len(t['input_ids'])) for t in inputs]
        token_type_ids = [pad(t['token_type_ids']) for t in inputs]

        input_ids = torch.tensor(input_ids).to(
            self.device, non_blocking=True)
        attention_mask = torch.tensor(attention_mask).to(
            self.device, non_blocking=True)
        token_type_ids = torch.tensor(token_type_ids).to(
            self.device, non_blocking=True)

        return input_ids, attention_mask, token_type_ids