Encoder-decoder fails at KMeans attention #4

Closed
tomweingarten opened this issue Jun 22, 2020 · 16 comments

@tomweingarten
Contributor

I haven't been able to dig into the root cause here yet, but I'm getting the following error when trying to run an encoder-decoder:

 File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/encoder_decoder.py", line 77, in generate
    return self.dec.generate(seq_out_start, max_seq_len, context = context, **{**dec_kwargs, **kwargs})
  File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autoregressive_wrapper.py", line 71, in generate
    logits, _ = self.net(x, input_mask=input_mask, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autopadder.py", line 33, in forward
    return self.net(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 614, in forward
    x, loss = self.routing_transformer(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 592, in forward
    x, loss = self.layers(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 200, in forward
    out, f_loss, g_loss =  _ReversibleFunction.apply(x, blocks, args)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 137, in forward
    x, f_loss, g_loss = block(x, **kwarg)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 80, in forward
    f_out, f_loss = cast_return(self.f(x2, record_rng=self.training, **f_args), requires_grad = False)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 53, in forward
    return self.net(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 121, in forward
    return self.fn(x, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 524, in forward
    global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 390, in forward
    dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 339, in forward
    self.init(x)
  File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 325, in init
    self.means.data.copy_(means)
RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 1

Here are my model params:

model = RoutingTransformerEncDec(
    enc_num_tokens=7000,
    dec_num_tokens=7000,
    dim=512,
    enc_ff_mult=4,
    dec_ff_mult=4,
    enc_depth=16,
    dec_depth=16,
    enc_heads=8,
    dec_heads=8,
    enc_max_seq_len=8192,
    dec_max_seq_len=8192,
    enc_window_size=128,
    dec_window_size=128,
    enc_causal=False,
    #dec_causal=True,  # decoder is always set to causal,
    enc_ff_dropout=0.05,
    dec_ff_dropout=0.05,
    enc_reversible=True,
    dec_reversible=True,
)
@lucidrains
Owner

@tomweingarten ohh, that's weird, I tried it on some random input tensors and it worked

what is the shape of your source and target tensors?
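
For reference, the source and target are expected to be LongTensors of token IDs shaped (batch, seq_len). Below is a minimal sketch using the config above; the forward call and its return convention are assumptions based on the repo's toy enc/dec copy-task example and may differ between versions:

import torch

# illustrative shapes only: (batch, seq_len) token IDs,
# capped by enc_max_seq_len / dec_max_seq_len
src = torch.randint(0, 7000, (1, 8192))
tgt = torch.randint(0, 7000, (1, 8192))

# training-style forward pass; exact signature/return value may vary by version
loss = model(src, tgt, return_loss=True)
loss.backward()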

@lucidrains
Owner

I have a working script for encoder / decoder here that may help https://github.com/lucidrains/routing-transformer/blob/master/examples/toy_tasks/enc_dec_copy_task.py

@lucidrains
Owner

@tomweingarten feel free to send me your full script if you have trouble debugging it. from the trace, it looks to be related to an interaction between the kmeans and the reversible network

also, how is this architecture working out for you? what kind of results are you seeing on your end?

@tomweingarten
Contributor Author

After some more testing, it looks like this only happens if I run generate() before the first call of the model. Something seems to go wrong with initializing kmeans under those circumstances. I'd like to try this on your script as well to verify it isn't my script but haven't gotten to do that yet.
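
For the record, a minimal sketch of the failing pattern; the generate() signature is inferred from the traceback above and the names here are illustrative:

import torch

src = torch.randint(0, 7000, (1, 8192))
start_tokens = torch.zeros(1, 1).long()  # hypothetical start-of-sequence token

# calling generate() on a freshly constructed model, before any training step,
# is what triggers the kmeans size-mismatch error shown in the traceback
out = model.generate(src, start_tokens, 8192)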

@lucidrains
Owner

@tomweingarten oh! Yes, that makes sense, because the means are initialized on the first backward pass during training. I'll put in some better asserts!

@lucidrains
Owner

you can't evaluate if there are no means from which to do the clustering!

@lucidrains
Owner

@tomweingarten ok, i've put in a fix in 0.7.2 that will make it stop erroring, but it is still best to train before evaluating in encoder / decoder, or the decoder means will be initialized with only a few samples
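
In practice that means running at least one training step before generating, so the kmeans means are initialized from real data. A minimal sketch of the ordering, with the same assumed forward/generate calls as above:

# the kmeans means are initialized during the first backward pass, so do at
# least one (ideally many) forward/backward passes before calling generate()
loss = model(src, tgt, return_loss=True)
loss.backward()

# only then generate
out = model.generate(src, start_tokens, 8192)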

@lucidrains
Owner

@tomweingarten thanks for reporting this btw

@tomweingarten
Contributor Author

Thanks! You're right that it is silly to run the generate() method before fitting. I do it just as a last check to make sure I haven't done anything weird like accidentally load a checkpoint when I shouldn't have. Thanks for the fix!

@lucidrains
Owner

@tomweingarten no problem! are you seeing good results? anything interesting you found exploring the hyperparameter space?

@tomweingarten
Contributor Author

I've been struggling to get the encoder-decoder model to converge. I can get the loss down to about 2 for my model; based on the Reformer (trax implementation) and TransformerXL (huggingface implementation), I'm expecting something like 0.5. With routing-transformer the loss plateaus at that level for some time, then explodes. Have you experienced something like this before?

Here are my latest parameters:

# Set up Encoder-Decoder model

model = RoutingTransformerEncDec(
    enc_num_tokens=7500,
    dec_num_tokens=7500,
    dim=256,
    enc_ff_mult=4,
    dec_ff_mult=4,
    enc_depth=16,
    dec_depth=16,
    enc_heads=8,
    dec_heads=8,
    enc_max_seq_len=512,
    dec_max_seq_len=4096,
    enc_window_size=128,
    dec_window_size=128,
    enc_causal=False,
    #dec_causal=True,  # decoder is always set to causal
    enc_attn_dropout=0.1,
    dec_attn_dropout=0.1,
    enc_ff_dropout=0.1,
    dec_ff_dropout=0.1,
    enc_reversible=True,
    dec_reversible=True,
    enc_ff_chunks=10,  # feed-forward chunking, from the Reformer paper
    dec_ff_chunks=10,  # feed-forward chunking, from the Reformer paper
    enc_ff_glu=True,  # use GLU variant in feed-forward
    dec_ff_glu=True,  # use GLU variant in feed-forward
    enc_pkm_layers=(4*n_layers//8, 5*n_layers//8 + 1),  # layers that use product key memory; the paper shows 1 or 2 modules near the middle of the transformer is best
    enc_pkm_num_keys=128,  # defaults to 128, but can be increased to 256 or 512 as memory allows
    dec_pkm_layers=(4*n_layers//8, 5*n_layers//8 + 1),  # layers that use product key memory; the paper shows 1 or 2 modules near the middle of the transformer is best
    dec_pkm_num_keys=128,  # defaults to 128, but can be increased to 256 or 512 as memory allows
)

@lucidrains
Owner

lucidrains commented Jun 26, 2020

@tomweingarten turns out there was an issue with un-shared QK and causal networks, which has been fixed in the latest minor version bump

your settings look fine, except it is best to mix local attention heads with the global kmeans attention heads. both Aurko and Aran, researchers who work on this variant of sparse attention, advise using all local attention except for the last couple of layers, as in https://github.com/lucidrains/routing-transformer/blob/master/examples/enwik8_simple/train.py#L47

this paper has more experimental results on how best to distribute the local inductive bias: https://arxiv.org/abs/2004.05150
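
A sketch of that head mix, assuming the enc_/dec_ prefixed form of the n_local_attn_heads option used in the linked enwik8 script, and assuming it accepts a per-layer tuple (both worth checking against the current README):

from routing_transformer import RoutingTransformerEncDec

# all-local attention except the last two layers, which keep some global kmeans heads
local_heads = (8,) * 14 + (4, 4)  # 16 layers, 8 heads each; values are illustrative

model = RoutingTransformerEncDec(
    enc_num_tokens=7500,
    dec_num_tokens=7500,
    dim=256,
    enc_depth=16,
    dec_depth=16,
    enc_heads=8,
    dec_heads=8,
    enc_max_seq_len=512,
    dec_max_seq_len=4096,
    enc_window_size=128,
    dec_window_size=128,
    enc_n_local_attn_heads=local_heads,  # assumed kwarg names
    dec_n_local_attn_heads=local_heads,
)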

@lucidrains
Owner

@tomweingarten how high is your learning rate? do you have gradient clipping turned on?
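
For reference, a minimal sketch of gradient clipping in a plain PyTorch training step; the learning rate and clip value are illustrative, not recommendations from this thread:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # illustrative learning rate

# ... forward pass and loss.backward() as above ...

# clip gradients to a maximum global norm before stepping, which helps tame loss spikes
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
optimizer.step()
optimizer.zero_grad()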

@tomweingarten
Contributor Author

@lucidrains Quick update: Running with the new version fixed my training loss problem! Unfortunately I'm seeing some weird results for predictions that I can't quite explain yet, but it's going to take me a bit longer to dig into why that is. I'm also going to play around with mixed attention head locality, thanks for the tip!

@lucidrains
Owner

@tomweingarten awesome! glad to hear it is converging!
