
Use AMP instead of naive fp16 #133

Merged · 15 commits · Dec 20, 2019

Conversation

@BramVanroy (Contributor)

closes #126

@BramVanroy (Contributor · Author)

I tried running the tests with AMP enabled but ran into some issues that would, I think, also occur when naive fp16 is used.

In particular, I run into an issue with FusedAdam:

AttributeError: 'FusedAdam' object has no attribute 'get_lr'

Since .get_lr seems quite important for the rest of the functionality, I am open to comments about how you would like to deal with this.

@Timoeller (Contributor)

Hey @BramVanroy
Thanks for this thorough PR, we want to proceed with it!
Just FYI: @tholor will take care of the get_lr issue tomorrow; today he is with a client and possibly won't be able to commit time.

@BramVanroy (Contributor · Author)

Sure, let me know if I can contribute further!

@tholor (Member)

tholor commented Nov 5, 2019

Hey @BramVanroy,
Thanks for your work! Looks pretty solid already!

I investigated options to deal with .get_lr() when we use the amp optimizer. I see two solutions:

  1. Quick fix: using optimizer.defaults["lr"] to get the maximum LR and using warmup_linear.get_lr(self.global_step) to get the proportion of the LR in the current step
  2. Cleaner way: refactoring optimization.py to use the standard PyTorch LR schedulers (torch.optim.lr_scheduler.LambdaLR), as the transformers repo does these days. We could then probably pass the schedule to the Trainer or even call initialize_optimizer() from within the trainer.

I am happy to go for the quick fix and tackle the refactoring in a later PR. What do you think?
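A minimal sketch of what option 1 could look like (illustrative only; the warmup_linear function below is a stand-in, not FARM's actual schedule object):

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=2e-5)

    def warmup_linear(step, warmup=0.1, t_total=1000):
        # proportion of the max LR at `step`: linear warmup, then linear decay
        x = step / t_total
        return x / warmup if x < warmup else max(0.0, (1.0 - x) / (1.0 - warmup))

    global_step = 250
    max_lr = optimizer.defaults["lr"]                    # maximum LR as configured
    lr_this_step = max_lr * warmup_linear(global_step)   # LR actually in effect at this step
    print(lr_this_step)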

I have implemented the quick fix in 7419b37. From a few first tests it seems to work in amp modes "O1" and "O2". However, there's an issue in "O2" if the prediction head balances classes via class_weights in the loss function:

if class_weights:
    logger.info(f"Using class weights for task '{self.task_name}': {self.class_weights}")
    balanced_weights = nn.Parameter(torch.tensor(class_weights), requires_grad=False)
else:
    balanced_weights = None
self.loss_fct = CrossEntropyLoss(
    weight=balanced_weights,
    reduction=loss_reduction,
    ignore_index=loss_ignore_index,
)

Do you have an idea how to fix this?

The related error message:

Train epoch 1/1:   0%|                                  | 0/134 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/ubuntu/pycharm/forks-FARM/FARM/examples/doc_classification.py", line 98, in <module>
    model = trainer.train(model)
  File "/home/ubuntu/pycharm/forks-FARM/FARM/farm/train.py", line 158, in train
    per_sample_loss = model.logits_to_loss(logits=logits, **batch)
  File "/home/ubuntu/pycharm/forks-FARM/FARM/farm/modeling/adaptive_model.py", line 129, in logits_to_loss
    all_losses = self.logits_to_loss_per_head(logits, **kwargs)
  File "/home/ubuntu/pycharm/forks-FARM/FARM/farm/modeling/adaptive_model.py", line 116, in logits_to_loss_per_head
    all_losses.append(head.logits_to_loss(logits=logits_for_one_head, **kwargs))
  File "/home/ubuntu/pycharm/forks-FARM/FARM/farm/modeling/prediction_head.py", line 258, in logits_to_loss
    return self.loss_fct(logits, label_ids.view(-1))
  File "/home/ubuntu/miniconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 942, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/ubuntu/miniconda3/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py", line 2056, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/ubuntu/miniconda3/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py", line 1871, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #3 'weight'

@tholor added the labels enhancement, part: optimizer and part: trainer on Nov 5, 2019
@BramVanroy (Contributor · Author)

BramVanroy commented Nov 5, 2019

I agree that a temporary fix is suitable, but it might be easier in the long run to indeed implement PyTorch's schedulers. This should make maintenance easier (and the code base potentially smaller). I am not sure how much work that would be.

EDIT: can't we just import the custom schedulers from the transformers library (since we use that anyway)? Then by using setattr() we can set the get_lr methods. I am willing to do this.

Unless the issue of O2 is a reflection of some deeper problem (which I don't see at first sight), I would propose to ignore it. When you look at the issues over at the apex repository, you'll see that many problems occur when using O2 and not O1. mcarilli has said that O2 "really only exists to support some internal use cases, otherwise I'd remove it entirely at this point (because the fused optimizers now work with O1 as well). O1 is always the recommended method for training, and switching to O1 should solve the issue.".

@tholor (Member)

tholor commented Nov 5, 2019

Unless the issue of O2 is a reflection of some deeper problem (which I don't see at first sight), I would propose to ignore it.

Okay, let's do that!

EDIT: can't we just import the custom schedulers from the transformers library (since we use that anyway)? Then by using setattr() we can set the get_lr methods. I am willing to do this.

Importing from transformers is a good idea. We prefer having synergies wherever it's reasonable, so that we can fix bugs together instead of having two communities working in isolation on similar things.

If we move towards PyTorch's schedulers, importing from transformers is a crystal clear option for me. For the current fix, though, I am not sure if I got you right. Are you proposing to use their schedulers, but keep the rest as it is (BertAdam, calls within Trainer ...)? That would help us get rid of _LRSchedule, but I don't yet see how it helps with accessing optimizer.get_lr() in the Trainer. Probably I am missing a part of your plan :)

@BramVanroy (Contributor · Author)

BramVanroy commented Nov 5, 2019

I hadn't looked at the transformers schedulers in detail but at first glance they don't seem to implement the get_lr method which is proposed to be overloaded in the PyTorch abstract class. (It's not immediately clear to me what the difference is between get_lr and get_last_lr.) Perhaps in transformers they call it lr_lambda?

But to better understand what you mean: do you want to get the lr directly from the optimizer as well as directly from the scheduler?

I think it's useful to first rely on everything that PyTorch provides out of the box; if that is not available, see what transformers has to offer; and only if that leaves you empty-handed, implement your own - exactly for the reason that you mention. When available, Apex can help as well (fused optimizers, for example). So ideally, I would think that doing away with BertAdam and leaving more room for customization is a good way to go. But of course that's just what I would suggest as a user - as maintainers you may have other ideas.

@tholor (Member)

tholor commented Nov 5, 2019

they don't seem to implement the get_lr method [...]

It seems a bit tricky, but get_lr() is actually provided by the parent class LambdaLR.

If you are interested, that's how it works: In transformers they implement a lr_lambda(), which gets passed to the constructor of LambdaLR(), where it sets an attribute self.lr_lambdas. This attribute is then accessed by the get_lr() method later on.
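A minimal, self-contained sketch of that mechanism (the toy model and warmup values are made up for illustration):

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import LambdaLR

    model = torch.nn.Linear(10, 2)
    optimizer = SGD(model.parameters(), lr=0.1)

    def lr_lambda(step, warmup_steps=100, total_steps=1000):
        # linear warmup followed by linear decay, in the spirit of transformers' WarmupLinearSchedule
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    scheduler = LambdaLR(optimizer, lr_lambda)  # stores lr_lambda in scheduler.lr_lambdas
    print(scheduler.get_lr())                   # base LR scaled by lr_lambda(scheduler.last_epoch)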

But to better understand what you mean: do you want to get the lr directly from the optimizer as well as directly from the scheduler?

No, having access through one of the two objects should be enough.

From my perspective, however, we have one major flaw in the current FARM optimization style: the BertAdam optimizer updates the LR internally (with the help of the schedules), while the FusedAdam optimizer relies on "external" updates of the LR within Trainer.backward_propagate().

I think it would be cleaner to have one standard way of updating the LR. After a quick sketch I believe that standardizing this is just as much effort as going the full way towards the PyTorch Scheduler as discussed above (incl. scheduler.step() etc. in the train loop).

That's why I now sketched this refactoring in e083e5f. What do you think about this approach? Do you see any potential flaws or improvements?

Sorry for messing up your branch with my commits ;). Feel free to roll back if you don't wanna go in this direction.

Review threads: farm/train.py (resolved), farm/modeling/optimization.py (outdated, resolved)
@BramVanroy (Contributor · Author)

BramVanroy commented Nov 6, 2019

It seems a bit tricky, but get_lr() is actually provided by the parent class LambdaLR .

If you are interested, that's how it works: In transformers they implement a lr_lambda(), which gets passed to the constructor of LambdaLR(), where it sets an attribute self.lr_lambdas. This attribute is then accessed by the get_lr() method later on.

Ah, got it. Thanks.

I agree that one entry point would be the best way to approach this. I am still not clear about how you would allow freedom in the optimizers and schedulers, though, or if you even want that. If you don't, then we probably shouldn't load all schedulers from the transformers library.

from transformers.optimization import (
    ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
    WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule, AdamW)

If you do want to give users a lot of freedom, you can pass options to the initialize_optimizer function for the optimizer and the scheduler separately. A basic example of what it could look like:

    from importlib import import_module
    import logging

    def _get_optim(opts, model):
        """ Get the optimizer based on current config. Tries to import `name` from
            torch.optim, transformers.optimization, or apex.optimizers. """
        optim_name = opts.pop('name')

        try:
            optim_constructor = getattr(import_module('torch.optim'), optim_name)
        except AttributeError:
            try:
                optim_constructor = getattr(import_module('transformers.optimization'), optim_name)
            except AttributeError:
                try:
                    optim_constructor = getattr(import_module('apex.optimizers'), optim_name)
                except (AttributeError, ImportError):
                    raise AttributeError(f"Optimizer '{optim_name}' not found in 'torch', 'transformers', or 'apex'")

        return optim_constructor([p for p in model.parameters() if p.requires_grad], **opts)

    def _get_scheduler(opts, optimizer):
        """ Get the scheduler based on current config. Tries to import name from
            torch.optim.lr_scheduler or transformers.optimization. """
        sched_name = opts.pop('name', None)

        if sched_name:
            try:
                sched_constructor = getattr(import_module('torch.optim.lr_scheduler'), sched_name)
            except AttributeError:
                try:
                    sched_constructor = getattr(import_module('transformers.optimization'), sched_name)
                except AttributeError:
                    raise AttributeError(f"Scheduler '{sched_name}' not found in 'torch' or 'transformers'")
            logging.info(f"Using scheduler '{sched_name}'")
            return sched_constructor(optimizer, **opts)
        else:
            return None

Example input for opts for the optimizer: {'name': 'AdamW', 'lr': 1e-03, 'weight_decay': 2e-04, 'eps': 1e-07}
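A hypothetical way to wire the two helpers together (the scheduler opts are illustrative and assume the transformers 2.x WarmupLinearSchedule signature with warmup_steps and t_total):

    import torch

    model = torch.nn.Linear(768, 2)  # stand-in model for illustration

    optim_opts = {'name': 'AdamW', 'lr': 1e-03, 'weight_decay': 2e-04, 'eps': 1e-07}
    sched_opts = {'name': 'WarmupLinearSchedule', 'warmup_steps': 100, 't_total': 1000}

    optimizer = _get_optim(optim_opts, model)          # resolves to torch.optim.AdamW first
    scheduler = _get_scheduler(sched_opts, optimizer)  # falls through to transformers.optimization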

@tholor (Member)

tholor commented Nov 6, 2019

I am still not clear about how you would allow freedom in the optimizers and schedulers, though, or if you even want that.

I believe some freedom would be great here! As some users might be overwhelmed, I would try to set decent defaults (e.g. AdamW and linear warmup) and avoid any mandatory opts other than LR.

[...] you can pass options to the initialize_optimizer function for the optimizer and the scheduler separately

I really like the sketched code! The two functions could be called from init_optimizer.

Do you want to go on and implement this (+ your suggested changes from the review)?

Happy to support at any point.

@BramVanroy (Contributor · Author)

Do you want to go on and implement this (+ your suggested changes from the review)?

Sure. I think I can do it tomorrow!

Important to know is which Python and PyTorch versions you support. Python matters because I like using f-strings, which are only supported from 3.6 up (if lower versions need to be supported, I can just use .format, that's fine). PyTorch matters because of where the optimizer step should be implemented relative to the scheduler step: transformers officially supports Python 3.5 and up and PyTorch 1.0.0 and up, but their examples do the scheduler step after the optimizer step (as it should be done since PyTorch 1.1). I don't really mind supporting Python 3.5, but it would be nice if the minimum requirement for torch were 1.1.
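For context, the step ordering in question (a runnable toy example, not FARM code; the model and schedule are made up):

    import torch
    from torch.optim import SGD
    from torch.optim.lr_scheduler import LambdaLR

    model = torch.nn.Linear(4, 1)
    optimizer = SGD(model.parameters(), lr=0.1)
    scheduler = LambdaLR(optimizer, lambda step: 0.95 ** step)

    for _ in range(3):  # dummy training steps
        loss = model(torch.randn(8, 4)).sum()
        loss.backward()
        optimizer.step()      # update parameters first ...
        scheduler.step()      # ... then update the LR (PyTorch >= 1.1 ordering)
        optimizer.zero_grad()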

@tholor (Member)

tholor commented Nov 6, 2019

Great, thanks for your effort! Very much appreciated 👍

Python >= 3.6 is totally fine (we also like f-strings).
PyTorch >= 1.1 is also fine. We so far still support 1.0.1, but already have some mid-term plans to make use of the named tensors in PyTorch 1.3.0. It might now be a good opportunity to step up to 1.1.

- Any optimizer present in torch.optim, transformers.optimization, and apex.optimizers can be used
- Any scheduler in torch.optim.lr_scheduler and transformers.optimization can be used
- easy implementation of amp by only requiring you to pass the opt level to 'use_cuda' in initialize_optimizer.
- added test for CUDA+AMP

BUG: the added test will fail on PyTorch 1.3 due to a critical bug (pytorch/pytorch#28623)
@BramVanroy (Contributor · Author)

I have implemented the changes that we discussed. I still see one problem, though, and it concerns DDP and AMP. When using AMP and DDP in conjunction, amp.initialize should be called after the model and optimizer are created but before the model is wrapped in DDP. Currently, that is not the case. I would therefore move the DDP setup away from the training loop and into the optimisation file.

@tholor (Member)

tholor commented Nov 7, 2019

Sure, makes sense 👍

@BramVanroy (Contributor · Author)

BramVanroy commented Nov 7, 2019

Sure, makes sense 👍

I think I was mistaken. The current implementation does follow the correct order, if I'm right. However, it might be cleaner to do the DDP stuff in a separate method or as part of the optimizer.py file. Perhaps something like this:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

try:
    from apex import amp
    try:
        from apex.parallel import convert_syncbn_model
        APEX_PARALLEL_AVAILABLE = True
    except (AttributeError, ImportError):
        APEX_PARALLEL_AVAILABLE = False
    AMP_AVAILABLE = True
except ImportError:
    AMP_AVAILABLE = False


def _optimize_model(model, device, local_rank, distributed, optimizer=None, use_amp=None):
    model, optimizer = _init_amp(model, device, optimizer, use_amp)

    if distributed:
        if APEX_PARALLEL_AVAILABLE:
            model = convert_syncbn_model(model)

        n = torch.cuda.device_count() // dist.get_world_size()
        device_ids = list(range(local_rank * n, (local_rank + 1) * n))
        # for some models DistributedDataParallel might complain about parameters
        # not contributing to the loss; find_unused_parameters remedies that.
        model = DistributedDataParallel(model,
                                        device_ids=device_ids,
                                        output_device=device_ids[0],
                                        find_unused_parameters=True)

    return model, optimizer

def _init_amp(model, device, optimizer=None, use_amp=None):
    model = model.to(device)
    if use_amp and optimizer:
        model, optimizer = amp.initialize(model, optimizer, opt_level=use_amp)

    return model, optimizer

_optimize_model can then be called from initialize_optimizer.
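A hypothetical call site (argument values are illustrative, not FARM's actual defaults):

    # assumes `model` and `optimizer` have already been created as usual
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, optimizer = _optimize_model(model, device, local_rank=-1, distributed=False,
                                       optimizer=optimizer,
                                       use_amp="O1" if AMP_AVAILABLE else None)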

@tholor (Member)

tholor commented Nov 7, 2019

Yes, I agree that it will be in any case cleaner to move DDP into a separate function.

@BramVanroy (Contributor · Author)

BramVanroy commented Nov 7, 2019

I'm not sure if you still want to support DataParallel. From my tests (and from looking at other libraries), it's seldom used and quite a bit slower than DDP. DDP is a real distributed approach (multiple processes). Using DDP through the torch.distributed.launch script seems like a good way forward.
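For reference, on a single machine this would look something like python -m torch.distributed.launch --nproc_per_node=<num_gpus> examples/doc_classification.py (a hypothetical invocation; the script would also need to accept the --local_rank argument that the launcher passes in).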

@tholor (Member)

tholor commented Nov 7, 2019

I haven't used / benchmarked torch.distributed.launch on a single multi-GPU machine yet. It sounds like an interesting alternative. If it's really faster than DataParallel, it might be worth the different style of launching a script.

As many users I know currently rely on DataParallel, I would be in favor of keeping it for now and getting rid of it in a separate PR after some benchmarking.

@BramVanroy (Contributor · Author)

I haven't used / benchmarked torch.distributed.launch on a single multi-GPU machine yet. It sounds like an interesting alternative. If it's really faster than DataParallel, it might be worth the different style of launching a script.

As many users I know currently rely on DataParallel, I would be in favor of keeping it for now and getting rid of it in a separate PR after some benchmarking.

I don't have any numbers, but it's not surprising that DDP performs better than DP in almost every case. From the documentation:

In the single-machine synchronous case, torch.distributed or the torch.nn.parallel.DistributedDataParallel() wrapper may still have advantages over other approaches to data-parallelism, including torch.nn.DataParallel():

Each process maintains its own optimizer and performs a complete optimization step with each iteration. While this may appear redundant, since the gradients have already been gathered together and averaged across processes and are thus the same for every process, this means that no parameter broadcast step is needed, reducing time spent transferring tensors between nodes.

Each process contains an independent Python interpreter, eliminating the extra interpreter overhead and “GIL-thrashing” that comes from driving several execution threads, model replicas, or GPUs from a single Python process. This is especially important for models that make heavy use of the Python runtime, including models with recurrent layers or many small components.

I currently don't have a lot of time to work on this, so feel free to edit and pull as you wish. As some side work, I am working on something similar to what FARM does (modular creation and fine-tuning of transformers for downstream tasks), but it's less beginner-friendly - as my code proposals above perhaps show. It's meant for research purposes, so I'm not sure when or even if I'll open it up.

@tholor (Member)

tholor commented Nov 14, 2019

Ok, got it. I can take over from here and finish the last steps so that we can merge it into master.

Thanks a lot for your work! It will help a lot to accelerate training with FARM.

If you have other suggestions / ideas for FARM: always happy to discuss :)

from torch.nn import DataParallel

# Used indirectly in _get_optim() to avoid name collision with torch's AdamW
from transformers.optimization import AdamW as TransformersAdamW
@BramVanroy (Contributor · Author) commented Dec 4, 2019

Any reason you don't default to torch's AdamW when available and otherwise use the transformers implementation? They do have slightly different signatures, though. (Sorry for lurking.) Also, since the minimum requirement now is torch 1.2, AdamW should always be available from torch anyway.

Member

Sorry for this late reply! I didn't see your comment here and I was busy with other topics the last weeks. I haven't investigated the differences between AdamW in transformers and PyTorch in detail. One difference, however, is the correct_bias option. This is only implemented in transformers and should be set to False for the original BERT implementation. I stumbled across it when trying to replicate some of our earlier experiments on CoNLL using this branch here. Interestingly, it had quite a big effect on the early steps of training. So for the sake of reproducibility / backwards compatibility, I would suggest using AdamW from transformers as the default in FARM.
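A minimal sketch of the default described above (the learning rate and the toy model are illustrative, not FARM's actual settings):

    import torch
    from transformers.optimization import AdamW as TransformersAdamW

    model = torch.nn.Linear(768, 2)  # stand-in for a prediction head
    # correct_bias=False reproduces the original BERT/TensorFlow Adam behaviour (no bias correction)
    optimizer = TransformersAdamW(model.parameters(), lr=2e-5, correct_bias=False)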

@tholor (Member)

tholor commented Dec 13, 2019

Sorry for the delay here. Totally my bad. I still have this branch on my todo list and hope to find some time next week.
It's not forgotten and will definitely find its way into FARM :)

result = model.extract_vectors(dicts=basic_texts)
assert result[0]["context"] == ['Farmer', "'", 's', 'life', 'is', 'great', '.']
assert result[0]["vec"].shape == (768,)
# TODO check why results vary across runs with same seed
@BramVanroy (Contributor · Author)

Did you set all the seeds? This is quite tricky. What works best for me is the following.

    import os
    import random
    from typing import Optional

    import numpy as np
    import torch

    def set_seed(seed: Optional[int]):
        """ Set all seeds to make results reproducible (deterministic mode).
            When seed is None, disables deterministic mode. """
        if seed is not None:
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            np.random.seed(seed)
            random.seed(seed)
            os.environ['PYTHONHASHSEED'] = str(seed)

@BramVanroy (Contributor · Author)

@tholor Tagging you so you get a notification. ;-)

Member

Nice, thanks for the hint! We don't use the torch.backends.cudnn flags + PYTHONHASHSEED yet. I will add them here and see if this makes a difference.

@tholor (Member)

tholor commented Dec 20, 2019

I did a few quick performance checks comparing amp optimization levels none, O0, O1.

  1. NER on conll2003:
  • same train, dev & test performance
  • O1 makes the full run (incl. preprocessing) about 35% faster (7.3 min vs. 11.4 min)
  2. Text Classification on GermEval2018-Fine:
  • slightly worse performance with O1 (test F1: 41.7 vs. 43.9)
  • 38% faster with O1

Seems all good to me. The speed improvement will of course be even bigger on longer training jobs (where preprocessing accounts for less of the total run time). Also, we should be able to fit larger batch sizes into memory, which will further speed things up :)

@tholor merged commit e632371 into deepset-ai:master on Dec 20, 2019
@Timoeller (Contributor)

This looks like quite an improvement for making FARM run faster on GPUs. Really looking forward to doing more performance benchmarks with mixed or lower precision, for training but especially for inference on large datasets.
Nice work, guys!

@BramVanroy (Contributor · Author)

BramVanroy commented Dec 20, 2019

@Timoeller It really is a big difference in speed (especially if you have tensor cores). I'm a bit disappointed to see an F1 difference of more than two points, though. That's quite a lot. I am wondering whether this might be a non-determinism issue, i.e. caused by randomness.

That being said, it looks like AMP is going upstream so that's nice. It'll just work out of the box! NVIDIA/apex#381 (comment)
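For reference, the native API that later shipped upstream looks roughly like this (torch.cuda.amp, available from PyTorch 1.6; shown only to illustrate the "out of the box" usage, not part of this PR):

    import torch

    model = torch.nn.Linear(16, 2).cuda()          # requires a CUDA device
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    for _ in range(3):                             # dummy training steps
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = model(torch.randn(8, 16, device="cuda")).sum()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()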

Thanks for the additional work and tests @tholor. Glad I could lend an initial hand.

@BramVanroy deleted the amp branch on December 20, 2019, 17:23
maknotavailable added a commit to maknotavailable/FARM that referenced this pull request Dec 21, 2019
Was missing the full integration of the updated initialize_optimizer() here.
Labels: enhancement (New feature or request), part: optimizer (Optimizer), part: trainer (Trainer)

Successfully merging this pull request may close: Apex.amp support (#126)

3 participants