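"""DistilBERT-based pointwise reranker.

Wraps a distilbert-base-uncased sequence-classification head configured for
regression (num_labels = 1) and exposes async train / rank / encode methods
on top of BaseModel.
"""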
import numpy as np
import torch

from .base import BaseModel

class DBERTRank(BaseModel):
    model_name = 'distilbert-base-uncased'
    max_grad_norm = 1.0  # gradient-clipping threshold applied in train()

    def __init__(self, *args, **kwargs):
        # Imported lazily so the module can be imported without transformers installed.
        # get_constant_schedule is the replacement for the ConstantLRSchedule class,
        # which was removed in transformers 3.x.
        from transformers import (AutoConfig,
                                  AutoModelForSequenceClassification,
                                  AutoTokenizer,
                                  AdamW,
                                  get_constant_schedule)
        super().__init__(*args, **kwargs)
        model_config = AutoConfig.from_pretrained(self.model_name, cache_dir=self.data_dir)
        model_config.num_labels = 1  # single-logit head: score pairs by regression
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Compare .type; equality between a torch.device and a plain string is
        # version-dependent, so the original `self.device == "cpu"` check could never fire.
        if self.device.type == "cpu":
            self.logger.info("RUNNING ON CPU")
        self.rerank_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            config=model_config,
            cache_dir=self.data_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.data_dir)
        self.rerank_model.to(self.device)
        # transformers' AdamW variant supports correct_bias=False (BERT-style Adam).
        self.optimizer = AdamW(self.rerank_model.parameters(), lr=self.lr, correct_bias=False)
        self.scheduler = get_constant_schedule(self.optimizer)

    async def train(self, query, candidates, labels):
        """Run one gradient step on a (query, candidates, labels) batch."""
        input_ids, attention_mask = await self.encode(query, candidates)
        labels = torch.tensor(labels, dtype=torch.float).to(self.device, non_blocking=True)
        # With num_labels == 1 the model computes an MSE regression loss internally.
        loss = self.rerank_model(input_ids, labels=labels, attention_mask=attention_mask)[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.rerank_model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.scheduler.step()
        self.rerank_model.zero_grad()

    async def rank(self, query, candidates):
        """Score candidates against the query; return indices sorted best-first."""
        input_ids, attention_mask = await self.encode(query, candidates)
        with torch.no_grad():
            logits = self.rerank_model(input_ids, attention_mask=attention_mask)[0]
        # squeeze() yields a 0-d array for a single candidate; atleast_1d keeps
        # argsort well-defined in that case.
        scores = np.atleast_1d(np.squeeze(logits.detach().cpu().numpy()))
        return list(np.argsort(scores)[::-1])

    async def encode(self, query, candidates):
        """Tokenize (query, candidate) pairs and pad them to a common length."""
        inputs = [self.tokenizer.encode_plus(
            query, candidate, add_special_tokens=True,
            truncation=True, max_length=self.tokenizer.model_max_length  # avoid exceeding DistilBERT's 512-token limit
        ) for candidate in candidates]
        max_len = max(len(t['input_ids']) for t in inputs)
        # Right-pad with the tokenizer's pad id (0 for DistilBERT, matching the
        # original hard-coded value) and mask the padding out of attention.
        pad_id = self.tokenizer.pad_token_id
        input_ids = [t['input_ids'] + [pad_id] * (max_len - len(t['input_ids'])) for t in inputs]
        attention_mask = [[1] * len(t['input_ids']) + [0] * (max_len - len(t['input_ids'])) for t in inputs]
        input_ids = torch.tensor(input_ids).to(self.device, non_blocking=True)
        attention_mask = torch.tensor(attention_mask).to(self.device, non_blocking=True)
        return input_ids, attention_mask
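
# Illustrative usage sketch, kept as a comment because this module uses a
# relative import and cannot run as a standalone script. It assumes BaseModel
# (defined in .base, not shown here) accepts `lr` and `data_dir` keyword
# arguments and provides `self.logger`; those are assumptions, not facts from
# this file.
#
#     import asyncio
#
#     async def demo():
#         model = DBERTRank(lr=3e-6, data_dir='./cache')
#         query = 'what is gradient clipping?'
#         docs = ['Gradient clipping caps the norm of the gradient vector.',
#                 'The Eiffel Tower is in Paris.']
#         order = await model.rank(query, docs)  # candidate indices, best match first
#         await model.train(query, docs, labels=[1.0, 0.0])
#
#     asyncio.run(demo())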