Commit
deploy a reasonable solution for encoder auxiliary loss
lucidrains committed Jul 19, 2020
1 parent 460863b commit 93ec372
Showing 4 changed files with 21 additions and 8 deletions.
README.md (4 changes: 3 additions & 1 deletion)
@@ -98,15 +98,17 @@ model = RoutingTransformerEncDec(
dec_heads = 8,
dec_max_seq_len = 4096,
dec_window_size = 128,
dec_reversible = True
).cuda()

src = torch.randint(0, 20000, (1, 4096)).cuda()
tgt = torch.randint(0, 20000, (1, 4096)).cuda()
src_mask = torch.ones_like(src).bool().cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

loss = model(src, tgt, enc_input_mask = src_mask, dec_input_mask = tgt_mask, return_loss = True, randomly_truncate_sequence = True)
loss, aux_loss = model(src, tgt, enc_input_mask = src_mask, dec_input_mask = tgt_mask, return_loss = True, randomly_truncate_sequence = True)
loss.backward()
aux_loss.backward()

# do your training, then to sample up to 2048 tokens based on the source sequence
src = torch.randint(0, 20000, (1, 4096)).cuda()
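To round out the usage above, a minimal sketch of the sampling step the README comment refers to, based on the generate signature visible in encoder_decoder.py below; the start token used here is an assumption for illustration, not part of this commit.

start = torch.ones((1, 1)).long().cuda()                  # assumed <bos>-style start token
sampled = model.generate(src, start, max_seq_len = 2048)  # up to 2048 tokens conditioned on src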
examples/toy_tasks/enc_dec_copy_task.py (2 changes: 1 addition & 1 deletion)
@@ -51,7 +51,7 @@ def cycle():
model.train()

src, tgt, src_mask, tgt_mask = next(cycle())
loss = model(src, tgt, enc_input_mask=src_mask, dec_input_mask=tgt_mask, return_loss = True, randomly_truncate_sequence = True)
loss, _ = model(src, tgt, enc_input_mask=src_mask, dec_input_mask=tgt_mask, return_loss = True, randomly_truncate_sequence = True)
loss.backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
routing_transformer/encoder_decoder.py (21 changes: 16 additions & 5 deletions)
@@ -67,8 +67,14 @@ def __init__(self, dim, ignore_index = None, pad_value = 0, **kwargs):
self.enc = enc
self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

# there is an outstanding bug where the network breaks when the decoder is reversible and the encoder auxiliary loss is added to the total loss
# user will have to manually call backwards on encoder auxiliary loss if the decoder reversibility is turned on
# should place a bug bounty on this
self.dec_reversible = dec_kwargs.pop('reversible', False)

# display a warning message
if self.dec_reversible:
print('Warning! Due to an issue with reversible nets and encoder auxiliary losses, you must explicitly call backwards on the encoder auxiliary loss, which is supplied as the second element of the returned tuple on forward')

update_kmeans_on_backwards(self)

@torch.no_grad()
@@ -81,7 +87,12 @@ def generate(self, seq_in, seq_out_start, max_seq_len = None, **kwargs):
def forward(self, seq_in, seq_out, return_loss = False, randomly_truncate_sequence = False, **kwargs):
enc_kwargs, dec_kwargs, kwargs = extract_and_set_enc_dec_kwargs(kwargs)
context, enc_aux_loss = self.enc(seq_in, **enc_kwargs)
loss = self.dec(seq_out, return_loss = return_loss, randomly_truncate_sequence = randomly_truncate_sequence, context = context, **dec_kwargs)
if not self.dec_reversible:
loss = loss + enc_aux_loss
return loss
loss = self.dec(seq_out, return_loss = return_loss, randomly_truncate_sequence = randomly_truncate_sequence, context = context, aux_loss = enc_aux_loss, **dec_kwargs)

# if decoder reversibility turned on, user must manually call backward on encoder auxiliary losses
if self.dec_reversible:
return loss, enc_aux_loss

aux_loss = torch.tensor(0., requires_grad = True)
loss = loss + enc_aux_loss
return loss, aux_loss
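In both branches, forward now returns a two-element tuple. When the decoder is not reversible, the encoder auxiliary loss is folded into the main loss and the second element is only a zero placeholder; when it is reversible, the auxiliary loss comes back separately and the caller must backpropagate it. A minimal sketch of the resulting caller-side pattern, assuming model, src, tgt and the masks are set up as in the README example above:

loss, aux_loss = model(src, tgt, enc_input_mask = src_mask, dec_input_mask = tgt_mask, return_loss = True)
loss.backward()
# required when the decoder is reversible; with a non-reversible decoder, aux_loss is a
# zero scalar not connected to the model, so this call is harmless
aux_loss.backward()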
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'routing_transformer',
packages = find_packages(exclude=['examples']),
version = '0.10.1',
version = '0.11.0',
license='MIT',
description = 'Routing Transformer (Pytorch)',
author = 'Phil Wang, Aran Komatsuzaki',
