# transformer.py
import torch
import torch.nn.functional as F
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper, top_k, top_p
from x_transformers import TransformerWrapper, Decoder


class CustomARWrapper(AutoregressiveWrapper):
    def __init__(self, *args, **kwargs):
        super(CustomARWrapper, self).__init__(*args, **kwargs)

    @torch.no_grad()
    def generate(self, start_tokens, seq_len=256, eos_token=None, temperature=1., filter_logits_fn=top_k, filter_thres=0.9, **kwargs):
        device = start_tokens.device
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        # Promote a single sequence to a batch of one.
        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        mask = kwargs.pop('mask', None)
        if mask is None:
            mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        for _ in range(seq_len):
            # Condition on at most the last max_seq_len tokens.
            x = out[:, -self.max_seq_len:]
            mask = mask[:, -self.max_seq_len:]

            # Only the logits at the final position matter for the next token.
            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]

            if filter_logits_fn in {top_k, top_p}:
                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
                probs = F.softmax(filtered_logits / temperature, dim=-1)
            else:
                # Fall back to unfiltered sampling so probs is always defined.
                probs = F.softmax(logits / temperature, dim=-1)

            # Sample the next token, then extend the sequence and its mask.
            sample = torch.multinomial(probs, 1)
            out = torch.cat((out, sample), dim=-1)
            mask = F.pad(mask, (0, 1), value=True)

            # Stop early once every sequence in the batch contains eos_token.
            if eos_token is not None and (torch.cumsum(out == eos_token, 1)[:, -1] >= 1).all():
                break

        # Strip the prompt, returning only the newly generated tokens.
        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out


def get_decoder(args):
    return CustomARWrapper(
        TransformerWrapper(
            num_tokens=args.num_tokens,
            max_seq_len=args.max_seq_len,
            attn_layers=Decoder(
                dim=args.dim,
                depth=args.num_layers,
                heads=args.heads,
                **args.decoder_args)),
        pad_value=args.pad_token)
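

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of wiring get_decoder() to generate tokens. The
# hyperparameters and the BOS/EOS token ids below are assumptions for
# demonstration; the real values come from the project's configuration.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        num_tokens=8000,   # vocabulary size (assumed)
        max_seq_len=512,   # decoder context length (assumed)
        dim=256,           # model width (assumed)
        num_layers=4,      # decoder depth (assumed)
        heads=8,           # attention heads (assumed)
        decoder_args={},   # extra Decoder kwargs, empty here
        pad_token=0,       # padding id (assumed)
    )
    model = get_decoder(args)

    # Sample up to 64 tokens from a single start token (id 1 assumed to be
    # BOS), stopping early if id 2 (assumed EOS) appears in every sequence.
    start = torch.LongTensor([1])
    tokens = model.generate(start, seq_len=64, eos_token=2, temperature=0.5)
    print(tokens.shape)  # 1-D tensor of at most 64 generated token ids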