
Understanding adaptive-span loss #13

Closed
prajjwal1 opened this issue Jan 21, 2020 · 7 comments

Comments

@prajjwal1

prajjwal1 commented Jan 21, 2020

Hi,

Sorry to bother you. I have gone through the paper several times, and I've also looked at the code many times. I just had one query about the adaptive-span loss. Here's what I interpreted:
The parameter self.current_val = nn.Parameter(torch.zeros(*shape) + init_val) is responsible for computing the loss, the mask and the span.
In this case, this parameter will be initialized with zeros, since init_val is kept at 0 in your config (so the mean of all the values of the parameter will be 0).

My question is: how does this parameter get updated?

When I call adaptive_span.get_loss(), it in turn computes
self._loss_coeff * self._max_span * self._mask.current_val.mean(), which will also return 0.
When I call adaptive_span.clamp_param(), nothing happens, since all the values inside the parameter were initialized to 0. These are the only two function calls happening inside the train method.
Can you please point out what I am missing?

@tesatory
Contributor

It's a model parameter, so it will be updated by optimizer.step() like any other parameter.
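A minimal sketch of that mechanism (SpanMask, the shapes, the coefficient and the toy LM loss below are illustrative stand-ins, not the repository's code): the span-loss term is added to the LM loss, so current_val receives a gradient, optimizer.step() moves it away from its zero initialization, and clamp_param() then keeps it in range.

import torch
import torch.nn as nn

class SpanMask(nn.Module):
    # Simplified stand-in for the adaptive-span mask parameter.
    def __init__(self, max_span, loss_coeff, init_val=0.0, shape=(12, 1, 1)):
        super().__init__()
        self._max_span = max_span
        self._loss_coeff = loss_coeff
        # Same initialization as quoted above: zeros + init_val.
        self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)

    def get_loss(self):
        # The loss term quoted above: loss_coeff * max_span * current_val.mean().
        return self._loss_coeff * self._max_span * self.current_val.mean()

    def clamp_param(self):
        self.current_val.data.clamp_(0, 1)

mask = SpanMask(max_span=1024, loss_coeff=2e-6)
optimizer = torch.optim.SGD(mask.parameters(), lr=0.1)

# One training step. The toy "LM loss" pulls current_val up (a longer span
# helps the LM), while get_loss() pushes it down.
lm_loss = ((1.0 - mask.current_val) * 3.0).sum()
total_loss = lm_loss + mask.get_loss()
total_loss.backward()
optimizer.step()      # updates current_val like any other parameter
mask.clamp_param()    # then clamp it back into [0, 1]
print(mask.current_val.mean().item())  # no longer 0 after a single update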

@prajjwal1 changed the title from "Understanding loss term from adaptive-span" to "Understanding adaptive-span loss" on Jan 21, 2020
@prajjwal1
Author

prajjwal1 commented Jan 22, 2020

Thanks for your reply. I wanted to ask:

  1. Do you think adaptive span takes longer to converge compared to standard attention? In my case I'm seeing improvements, but they are quite small. Could this be due to trim_memory? Did you try this on tasks other than char-level LM?
  2. In your experiments, did the adaptive-span loss become non-zero at any point? Although current_val is a parameter and is constantly being updated, the loss stays at a constant 0.
    Thanks for your support.

@tesatory
Contributor

  1. Not sure what "converge" means here. If you mean the span is not growing large enough, you might want to reduce the loss coefficient associated with it. trim_memory shouldn't affect learning. Yes, we used it on word-level LM without a problem.

  2. The loss can be zero if it has too large a weight compared to the LM loss. Try setting --adapt-span-loss to 0.
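To make the second point concrete, a quick back-of-the-envelope comparison using the loss formula quoted earlier (the LM-gradient value of -3.0 and both coefficients are made-up stand-ins):

# Per-element gradient on current_val coming from
# loss_coeff * max_span * current_val.mean(), for N = 12 entries.
max_span, n = 1024, 12
lm_grad = -3.0  # stand-in for the LM-loss pull toward a longer span
for loss_coeff in (2e-6, 0.1):  # hypothetical small vs. deliberately large coefficient
    penalty_grad = loss_coeff * max_span / n  # pushes the span down
    print(loss_coeff, lm_grad + penalty_grad)
# Small coefficient: the net gradient is negative, so current_val (the span) grows.
# Large coefficient: the net gradient is positive, current_val is pushed below 0,
# clamp_param() pins it at 0, and the span loss then reads exactly 0.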

@prajjwal1
Author

Hi,
Thanks for replying. What did you use to calculate FLOPS?

@tesatory
Contributor

We just counted all the flops in the model. For example, a linear layer has d_in x d_out flops.
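A rough sketch of that kind of count (illustrative only; the toy model and its layer sizes are hypothetical, and the count is per token, so batching is ignored):

import torch.nn as nn

def count_linear_flops(model: nn.Module) -> int:
    # Multiply-adds of all nn.Linear layers for a single token/step.
    flops = 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            flops += module.in_features * module.out_features
    return flops

toy = nn.Sequential(nn.Linear(768, 3072), nn.ReLU(), nn.Linear(3072, 768))
print(count_linear_flops(toy))  # 768*3072 + 3072*768 = 4,718,592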

@prajjwal1
Author

prajjwal1 commented Jan 26, 2020

Thanks for your reply.

  1. In the case where trim_len < 0, trim_memory will perform padding on the input tensor, as specified here. So in my case trim_len < 0, since 1024 is big; here's what happens:
# query.shape -> [128,36,768]
# key.shape -> [128,20,768]
# value.shape -> [128,20,768]
k,v,k_pe = adaptive.trim_memory(q,k,v,k_pe)
# k.shape -> [128,1060,768]
# v.shape -> [128,1060,768]
# k_pe.shape -> [1,64,768] 

So in this case, I don't think memory consumption is being reduced, since the dimensions have now grown many-fold and more FLOPS are required. Am I right, or am I missing something? For now, I've removed this operation (a rough sketch of this pad-or-trim behaviour follows at the end of this comment).

  2. Using the masking function as specified in the paper, my FLOPS have stayed the same:
macs: 12.074G 
params: 237.558M

These results were noted during inference. Did you measure FLOPS (as reported in the paper) during training, since the spans only change during training? My spans do change after some training, but the FLOPS stay the same. Is it because the trimming operation alone is responsible for reducing FLOPS?
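For reference, a rough sketch of the pad-or-trim behaviour described in point 1 (a simplified approximation, not the repository's actual trim_memory; the shapes are taken from the numbers quoted above):

import torch
import torch.nn.functional as F

def pad_or_trim_cache(key, value, max_span, block_size):
    # Keep only the last max_span + block_size cached positions;
    # left-pad with zeros if the cache is shorter than that.
    target_len = max_span + block_size
    cache_len = key.size(1)
    if cache_len > target_len:
        key, value = key[:, -target_len:], value[:, -target_len:]
    elif cache_len < target_len:
        pad = target_len - cache_len
        # F.pad fills dims from the last one backwards: (left, right) pairs.
        key = F.pad(key, (0, 0, pad, 0))
        value = F.pad(value, (0, 0, pad, 0))
    return key, value

k = torch.randn(128, 20, 768)
v = torch.randn(128, 20, 768)
k, v = pad_or_trim_cache(k, v, max_span=1024, block_size=36)
print(k.shape)  # torch.Size([128, 1060, 768]) -- padded up, matching the shapes above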

@tesatory
Contributor

As noted in the paper, FLOPS is the number of FLOPS necessary for computing a one-step prediction. So it's not the training-time FLOPS, where a batch of samples is processed together.
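A back-of-the-envelope version of that per-step count (illustrative; the head count and head dimension are hypothetical): for a single predicted token, each attention head does roughly span * d_head multiply-adds for the attention scores and the same again for the weighted sum of values, so shrinking the span shrinks the count regardless of how a training batch is masked.

def attention_flops_per_step(span, n_heads=12, d_head=64):
    # Rough per-token attention multiply-adds: scores plus weighted value sum, per head.
    return n_heads * 2 * span * d_head

print(attention_flops_per_step(span=1024))  # 1,572,864
print(attention_flops_per_step(span=128))   # 196,608 -- 8x fewer with a shorter span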
