TimeTransformer.py
import math
import torch
import torch.nn as nn

# TimePositionalEncoding is defined elsewhere in this repository.


class TimeTransformer(nn.Module):
    # Constructor
    def __init__(self, num_tokens, dim_model, num_heads,
                 num_encoder_layers, num_decoder_layers, dropout_p):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model

        # LAYERS
        self.positional_encoder = TimePositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, batch_size=16
        )
        self.embedding = nn.Embedding(num_tokens, dim_model)
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )
        self.out = nn.Linear(dim_model, num_tokens)
        self.softmax = nn.Softmax(dim=-1)  # unused in forward(); raw logits are returned

    def forward(self, src, src_time, tgt, tgt_time, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        # Embed and scale tokens; src/tgt: (batch_size, seq_len, dim_model), e.g. (16, 10, 8)
        src = self.embedding(src) * math.sqrt(self.dim_model)
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model)

        # Add the time-aware positional encodings
        src = self.positional_encoder(src, src_time)
        tgt = self.positional_encoder(tgt, tgt_time)

        # nn.Transformer expects (seq_len, batch_size, dim_model) by default
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        transformer_out = self.transformer(
            src, tgt, tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask,
        )
        out = self.out(transformer_out)
        return out

    def get_tgt_mask(self, size) -> torch.Tensor:
        # Causal mask: future positions (above the diagonal) are blocked with -inf
        mask = torch.tril(torch.ones(size, size) == 1)  # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float("-inf"))  # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0))  # Convert ones to 0
        return mask

    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        # True wherever the sequence holds the padding token
        return matrix == pad_token
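

# --- Usage sketch (not part of the original file) ---
# A minimal example of how this model might be driven, assuming
# TimePositionalEncoding (defined elsewhere in the repo) accepts the
# (x, time) pair used in forward() above. The vocabulary size, sequence
# length, pad token index, and random inputs below are illustrative only.
if __name__ == "__main__":
    num_tokens = 100  # assumed vocabulary size; index 0 treated as PAD here
    pad_token = 0
    model = TimeTransformer(
        num_tokens=num_tokens,
        dim_model=8,
        num_heads=2,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dropout_p=0.1,
    )

    batch_size, seq_len = 16, 10  # batch_size matches the hard-coded value in __init__
    src = torch.randint(1, num_tokens, (batch_size, seq_len))
    tgt = torch.randint(1, num_tokens, (batch_size, seq_len))
    src_time = torch.rand(batch_size, seq_len)  # assumed shape of the time inputs
    tgt_time = torch.rand(batch_size, seq_len)

    tgt_mask = model.get_tgt_mask(seq_len)
    src_pad_mask = model.create_pad_mask(src, pad_token)
    tgt_pad_mask = model.create_pad_mask(tgt, pad_token)

    out = model(src, src_time, tgt, tgt_time,
                tgt_mask=tgt_mask,
                src_pad_mask=src_pad_mask,
                tgt_pad_mask=tgt_pad_mask)
    print(out.shape)  # (seq_len, batch_size, num_tokens)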