# transformer.py
import math
import torch
import ppgs
###############################################################################
# Transformer model
###############################################################################
class Transformer(torch.nn.Module):
    """Transformer encoder over framewise features.

    Projects input frames to a hidden width with a 1-D convolution, adds
    sinusoidal positional encodings, runs a stack of transformer encoder
    layers, then projects back to the output width with a second 1-D
    convolution. Padded frames are zeroed and excluded from attention.
    """

    def __init__(
        self,
        num_layers=ppgs.NUM_HIDDEN_LAYERS,
        channels=ppgs.HIDDEN_CHANNELS):
        super().__init__()
        self.position = PositionalEncoding(channels)
        # Use `channels` (not the global constant) for the hidden width so a
        # non-default value actually propagates to the convolutions; the
        # original hard-coded ppgs.HIDDEN_CHANNELS here, which broke any
        # caller that passed a custom `channels`. Defaults are unchanged.
        self.input_layer = torch.nn.Conv1d(
            ppgs.INPUT_CHANNELS,
            channels,
            kernel_size=ppgs.KERNEL_SIZE,
            padding='same')
        self.model = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(channels, ppgs.ATTENTION_HEADS),
            num_layers)
        self.output_layer = torch.nn.Conv1d(
            channels,
            ppgs.OUTPUT_CHANNELS,
            kernel_size=ppgs.KERNEL_SIZE,
            padding='same')

    def forward(self, x, lengths):
        """Encode a padded batch.

        Arguments
            x - input features, shape (batch, ppgs.INPUT_CHANNELS, frames)
            lengths - valid frame count per item, shape (batch,)

        Returns
            Output features, shape (batch, ppgs.OUTPUT_CHANNELS, frames),
            with padded frames zeroed.
        """
        # mask is True on valid frames; shape (batch, 1, frames)
        mask = mask_from_lengths(lengths).unsqueeze(1)
        x = self.input_layer(x) * mask
        # TransformerEncoder expects (frames, batch, channels);
        # key_padding_mask is True on PADDED positions, hence the inversion
        x = self.model(
            self.position(x.permute(2, 0, 1)),
            src_key_padding_mask=~mask.squeeze(1)
        ).permute(1, 2, 0)
        return self.output_layer(x) * mask
###############################################################################
# Utilities
###############################################################################
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal positional encoding with dropout.

    Adds a fixed, precomputed table of sine/cosine positional signals to the
    input and applies dropout. Input layout is (time, batch, channels).
    """

    def __init__(self, channels, dropout=.1, max_len=5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        # Precompute the (max_len, 1, channels) table once; it is registered
        # as a buffer so it moves with the module but is not a parameter
        positions = torch.arange(max_len).unsqueeze(1)
        rates = torch.exp(
            -math.log(10000.0) / channels * torch.arange(0, channels, 2))
        angles = positions * rates
        table = torch.zeros(max_len, 1, channels)
        table[:, 0, 0::2] = torch.sin(angles)
        table[:, 0, 1::2] = torch.cos(angles)
        self.register_buffer('encoding', table)

    def forward(self, x):
        """Add positional encodings to x of shape (time, batch, channels)."""
        length = x.size(0)
        if length > self.encoding.size(0):
            raise ValueError('size is too large')
        return self.dropout(x + self.encoding[:length])
def mask_from_lengths(lengths, padding=0):
    """Create boolean mask from sequence lengths and offset to start"""
    offset = 2 * padding
    # Column positions 0 .. max_length + offset - 1, shape (1, frames)
    positions = torch.arange(
        lengths.max() + offset,
        dtype=lengths.dtype,
        device=lengths.device).unsqueeze(0)
    # True where (position - offset) is inside each sequence; broadcasting
    # over (batch, 1) lengths yields a (batch, frames) boolean mask
    return positions - offset < lengths.unsqueeze(1)