
Add option for custom loss aggregation of heads #220

Merged
merged 3 commits on Jan 31, 2020

Conversation

@tholor (Member) commented on Jan 29, 2020

Context:
Each prediction head produces a loss that is then aggregated into a single model loss. So far we have used sum() as the aggregation function.

Problem:
This is not optimal in the case of multiple prediction heads (PHs), as it changes the scale of the total loss and therefore might require adjusting the learning rate.
It also doesn't allow weighting the different prediction heads, e.g. when one task is more important or simply on a different loss scale.

Solution:
I suggest making this more flexible and letting the user define a custom strategy via
loss_aggregation_fn, which can be passed when initializing the AdaptiveModel.

The default loss_aggregation_fn will be sum(), but you can configure any function that takes a list of Tensors ([Tensor, Tensor, ...]) and returns a single Tensor.

Example:

    import torch

    # average the per-head losses instead of summing them
    loss_fn = lambda x: torch.mean(torch.stack(x), dim=0)

    model = AdaptiveModel(
        language_model=language_model,
        prediction_heads=[lm_prediction_head, next_sentence_head],
        embeds_dropout_prob=0.1,
        lm_output_types=["per_token", "per_sequence"],
        device=device,
        loss_aggregation_fn=loss_fn
    )
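Another aggregation that fits this signature is a weighted sum, which addresses the "different loss scale" problem mentioned above. The sketch below is not from the PR; the `head_weights` values are made up for illustration and would need tuning for a real task:

```python
import torch

# Hypothetical per-head weights; in practice, tune these so the
# heads contribute to the total loss on comparable scales.
head_weights = [1.0, 0.5]

def weighted_loss_fn(losses):
    # losses: list of scalar loss tensors, one per prediction head
    weighted = [w * l for w, l in zip(head_weights, losses)]
    return torch.sum(torch.stack(weighted))
```

This function could then be passed as loss_aggregation_fn exactly like the mean example above.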

Related to discussion in #182

What do you think @johann-petrak ?
Would that also cover your use cases?

@johann-petrak (Contributor)

I think this needs a bit more thinking through. I am not too sure myself, since what I mentioned in #182 is based on what I am planning to do, not what I already have experience with.

My main concern is that the method for combining the losses may depend on the context. The context could be global (in which case your solution would be sufficient, because the lambda could be a closure over the global data), but it could also be the concrete instances or the concrete batch.

For example, how we want to combine the losses could depend on the targets for each head.
A completely different approach could depend on the training step, e.g. when we want to boost the importance of each head in a round-robin fashion every kth batch.
I have no idea which of these things are really useful or required, just thinking they may turn out to be important.

So ideally we would pass a couple of additional parameters to the function, giving the user access to all of these things if needed; otherwise they can simply ignore those parameters.
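Such an extended signature could look like the sketch below. It assumes the trainer passes the current step (and optionally the batch) through to the aggregation function; the schedule and the step threshold are made up purely to illustrate ignoring or using the extra parameters:

```python
import torch

def loss_aggregation_fn(losses, global_step=None, batch=None):
    # Hypothetical schedule: emphasize the first head early in training,
    # then fall back to a plain sum of all head losses.
    if global_step is not None and global_step < 1000:
        return losses[0] + 0.1 * losses[1]
    return torch.sum(torch.stack(losses))
```

A function that doesn't care about the context can simply leave global_step and batch at their defaults.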

@johann-petrak (Contributor)

That one looks good, it covers everything I can think of right now! :)

@tholor (Member, Author) commented on Jan 31, 2020

Great :)
It's a bit of a trade-off here between giving flexibility and keeping the code clean and easy to understand. For now, this is the best compromise I could come up with. Let's merge it and reflect on it once we have a clearer understanding of the more complex use cases (e.g. loss depending on batches).

@tholor added the labels enhancement (New feature or request), part: model, and part: trainer on Jan 31, 2020
@tholor (Member, Author) commented on Jan 31, 2020

@AndriyMulyar, this PR might be helpful for your work on MTL. While it doesn't implement a round-robin dataloader, you could switch the loss of individual prediction heads on/off, or weight them, as you iterate over batches (batch 1 => only the loss of head 1, batch 2 => only the loss of head 2, ...).

Code sketch:

    def loss_aggregation_fn(tensors, global_step, batch=None):
        # alternate between the heads: odd steps use head 1, even steps head 2
        if global_step % 2:
            return tensors[0]
        else:
            return tensors[1]
    ...

    model = AdaptiveModel(
        language_model=language_model,
        prediction_heads=[head1, head2],
        embeds_dropout_prob=0.1,
        lm_output_types=["per_sequence", "per_sequence"],
        device=device,
        loss_aggregation_fn=loss_aggregation_fn
    )
