keeping transformer architecture for now
popcornell committed Apr 23, 2021
1 parent 829a983 commit bbe18b5
Showing 1 changed file with 18 additions and 57 deletions.
75 changes: 18 additions & 57 deletions speechbrain/lobes/models/transformer/Transformer.py
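This commit drops the attention_type switch, and with it the RelPosMultiHeadAttention branches, keeping only the regular multi-head attention path. A minimal sketch of the kept construction follows; the constructor arguments are taken from the added diff lines, while the value= keyword, the example dimensions, and the batch-first layout are assumptions for illustration, not part of this diff.

import torch
import speechbrain as sb

# Sketch only: mirrors the attention path kept by this commit.
# Constructor arguments (nhead, d_model, dropout, kdim, vdim) come from the
# added diff lines; the value= keyword and the shapes below are assumed.
d_model, nhead = 512, 8
self_att = sb.nnet.attention.MultiheadAttention(
    nhead=nhead, d_model=d_model, dropout=0.1, kdim=None, vdim=None,
)

x = torch.rand(2, 60, d_model)  # (batch, time, d_model), assumed batch-first
out, attn_weights = self_att(query=x, key=x, value=x)
print(out.shape)  # expected to keep the (batch, time, d_model) layout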
@@ -1,4 +1,4 @@
"""Transformer implementaion in the SpeechBrain sytle.
"""Transformer implementaion in the SpeechBrain style.
Authors
* Jianyuan Zhong 2020
@@ -62,7 +62,6 @@ def __init__(
bias: Optional[bool] = True,
encoder_module: Optional[str] = "transformer",
conformer_activation: Optional[nn.Module] = Swish,
attention_type: Optional[str] = "regularMHA",
):
super().__init__()

@@ -86,7 +85,6 @@ def __init__(
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
attention_type=attention_type,
)
elif encoder_module == "conformer":
self.encoder = ConformerEncoder(
@@ -98,7 +96,6 @@ def __init__(
activation=conformer_activation,
kernel_size=kernel_size,
bias=bias,
attention_type=attention_type,
)
assert (
normalize_before
@@ -121,7 +118,6 @@ def __init__(
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
attention_type=attention_type,
)

def forward(self, **kwags):
@@ -215,24 +211,12 @@ def __init__(
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
attention_type="regularMHA",
):
super().__init__()

self.attention_type = attention_type
if attention_type == "regularMHA":
self.self_att = sb.nnet.attention.MultiheadAttention(
nhead=nhead,
d_model=d_model,
dropout=dropout,
kdim=kdim,
vdim=vdim,
)
elif attention_type == "relposMHA":
self.self_att = sb.nnet.attention.RelPosMultiHeadAttention(
num_heads=nhead, embed_dim=d_model, dropout=dropout
)

self.self_att = sb.nnet.attention.MultiheadAttention(
nhead=nhead, d_model=d_model, dropout=dropout, kdim=kdim, vdim=vdim,
)
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
@@ -342,7 +326,6 @@ def __init__(
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
attention_type="regularMHA",
):
super().__init__()

@@ -366,7 +349,6 @@ def __init__(
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
attention_type=attention_type,
)
for i in range(num_layers)
]
@@ -441,35 +423,14 @@ def __init__(
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
attention_type="regularMHA",
):
super().__init__()

self.attention_type = attention_type
if attention_type == "regularMHA":
self.self_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
)
self.mutihead_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead,
d_model=d_model,
kdim=kdim,
vdim=vdim,
dropout=dropout,
)
elif attention_type == "relposMHA":
self.self_attn = sb.nnet.attention.RelPosMultiHeadAttention(
num_heads=nhead, embed_dim=d_model, dropout=dropout
)
self.mutihead_attn = sb.nnet.attention.RelPosMultiHeadAttention(
num_heads=nhead, embed_dim=d_model, dropout=dropout
)
# self.pos_enc = RelPosMHAPositional(d_model)

self.self_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead, d_model=d_model, kdim=kdim, vdim=vdim, dropout=dropout,
)
self.mutihead_attn = sb.nnet.attention.MultiheadAttention(
nhead=nhead, d_model=d_model, kdim=kdim, vdim=vdim, dropout=dropout,
)
self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
d_ffn=d_ffn,
input_size=d_model,
@@ -517,6 +478,7 @@ def forward(
else:
tgt1 = tgt

# self-attention over the target sequence
tgt2, self_attn = self.self_attn(
query=tgt1,
key=tgt1,
@@ -535,6 +497,7 @@ def forward(
else:
tgt1 = tgt

# multi-head attention over the target sequence and encoder states
tgt2, multihead_attention = self.mutihead_attn(
query=tgt1,
key=memory,
@@ -602,7 +565,6 @@ def __init__(
dropout=0.1,
activation=nn.ReLU,
normalize_before=False,
attention_type="regularMHA",
):
super().__init__()
self.layers = torch.nn.ModuleList(
@@ -616,7 +578,6 @@ def __init__(
dropout=dropout,
activation=activation,
normalize_before=normalize_before,
attention_type=attention_type,
)
for _ in range(num_layers)
]
@@ -743,9 +704,9 @@ def get_lookahead_mask(padded_input):
-------
>>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
>>> get_lookahead_mask(a)
tensor([[False, True, True],
[False, False, True],
[False, False, False]])
tensor([[0., -inf, -inf],
[0., 0., -inf],
[0., 0., 0.]])
"""
seq_len = padded_input.shape[1]
mask = (
torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device))
== 1
).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0.0, True)
.masked_fill(mask == 1.0, False)
).bool()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
return mask.detach().to(padded_input.device)
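
For reference, a small standalone check of the reverted mask convention. The helper below is a sketch that reproduces the 0.0 / -inf layout shown in the added lines using torch.tril; the name lookahead_mask_sketch is hypothetical and this is not the library function itself.

import torch

# Sketch: rebuild the look-ahead mask with the float convention this commit
# keeps: 0.0 where a step may attend, -inf on future positions.
def lookahead_mask_sketch(padded_input: torch.Tensor) -> torch.Tensor:
    seq_len = padded_input.shape[1]
    # True on and below the diagonal: positions each step may attend to.
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # 0.0 where allowed, -inf where blocked, matching the docstring example.
    mask = torch.zeros(seq_len, seq_len).masked_fill(~allowed, float("-inf"))
    return mask.to(padded_input.device)

a = torch.LongTensor([[1, 1, 0], [2, 3, 0], [4, 5, 0]])
print(lookahead_mask_sketch(a))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])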
