
adding te Linear for fp8 support #271

Closed · wants to merge 17 commits

Conversation

@vchiley (Contributor) commented Jun 2, 2023

Ran composer train/train.py train/yamls/pretrain/mpt-3b.yaml, also with model.fc_type=te and precision=amp_fp8.
Results:

torch: throughput/device/tokens_per_sec: 23.7k
te: throughput/device/tokens_per_sec: 23.7k
te with fp8: throughput/device/tokens_per_sec: 29.4k

Note: there does seem to be an error when activation checkpointing is enabled with activation_checkpointing_reentrant: false. If we set activation_checkpointing_reentrant: true, activation checkpointing works fine without amp_fp8; with amp_fp8 the issue still persists.
(Previously, circa summer 2022, activation_checkpointing_reentrant: true caused some difficulties, which is why we set it to false; not sure whether that is still necessary...)
The activation-checkpointing errors are posted in the comments below.
(The error with amp_fp8 might be an issue with Composer's fp8 implementation.)
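
For orientation, here is a minimal sketch of the kind of fc_type dispatch this PR adds; the helper name and wiring are illustrative, not the exact llm-foundry implementation.

```python
# Illustrative sketch only: dispatch between torch.nn.Linear and
# Transformer Engine's Linear based on an fc_type config value.
import torch.nn as nn

try:
    import transformer_engine.pytorch as te  # NVIDIA Transformer Engine
except ImportError:
    te = None


def build_fc(fc_type: str, in_features: int, out_features: int,
             bias: bool = True) -> nn.Module:
    """Hypothetical helper mirroring model.fc_type=torch|te."""
    if fc_type == 'torch':
        return nn.Linear(in_features, out_features, bias=bias)
    if fc_type == 'te':
        if te is None:
            raise ImportError('fc_type=te requires transformer_engine')
        # te.Linear is a drop-in replacement that can run in FP8 when the
        # forward pass is wrapped in te.fp8_autocast (precision=amp_fp8).
        return te.Linear(in_features, out_features, bias=bias)
    raise ValueError(f'unknown fc_type: {fc_type}')
```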

@vchiley (Contributor, Author) commented Jun 2, 2023

Activation checkpointing error with activation_checkpointing_reentrant: false, without amp_fp8:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/llm-foundry/scripts/train/train.py:254 in <module>                                          │
│                                                                                                  │
│   251 │   │   yaml_cfg = om.load(f)                                                              │
│   252 │   cli_cfg = om.from_cli(args_list)                                                       │
│   253 │   cfg = om.merge(yaml_cfg, cli_cfg)                                                      │
│ ❱ 254 │   main(cfg)                                                                              │
│   255                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/scripts/train/train.py:243 in main                                              │
│                                                                                                  │
│   240 │   │   trainer.eval()                                                                     │
│   241 │                                                                                          │
│   242 │   print('Starting training...')                                                          │
│ ❱ 243 │   trainer.fit()                                                                          │
│   244 │                                                                                          │
│   245 │   print('Done.')                                                                         │
│   246                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1766 in fit       │
│                                                                                                  │
│   1763 │   │   │   self.state.scaler = ClosureGradScaler() if self._use_closures() else GradSca  │
│   1764 │   │                                                                                     │
│   1765 │   │   self.first_batch_complete = False                                                 │
│ ❱ 1766 │   │   self._train_loop()                                                                │
│   1767 │                                                                                         │
│   1768 │   def close(self):                                                                      │
│   1769 │   │   """Shutdown the trainer.                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1940 in           │
│ _train_loop                                                                                      │
│                                                                                                  │
│   1937 │   │   │   │   │   │   self.logger.log_metrics({'time/token': self.state.timestamp.toke  │
│   1938 │   │   │   │   │   │   self.logger.log_metrics({'time/token_in_epoch': self.state.times  │
│   1939 │   │   │   │   │                                                                         │
│ ❱ 1940 │   │   │   │   │   total_loss_dict = self._train_batch(use_grad_scaling)                 │
│   1941 │   │   │   │   │                                                                         │
│   1942 │   │   │   │   │   if use_grad_scaling:                                                  │
│   1943 │   │   │   │   │   │   self.state.scaler.update()                                        │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in           │
│ _train_batch                                                                                     │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py:69 in wrapper                         │
│                                                                                                  │
│     66 │   │   │   │   instance = instance_ref()                                                 │
│     67 │   │   │   │   instance._step_count += 1                                                 │
│     68 │   │   │   │   wrapped = func.__get__(instance, cls)                                     │
│ ❱   69 │   │   │   │   return wrapped(*args, **kwargs)                                           │
│     70 │   │   │                                                                                 │
│     71 │   │   │   # Note that the returned function here is no longer a bound method,           │
│     72 │   │   │   # so attributes like `__func__` and `__self__` no longer exist.               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/optimizer.py:280 in wrapper                           │
│                                                                                                  │
│   277 │   │   │   │   │   │   │   raise RuntimeError(f"{func} must return None or a tuple of (   │
│   278 │   │   │   │   │   │   │   │   │   │   │      f"but got {result}.")                       │
│   279 │   │   │   │                                                                              │
│ ❱ 280 │   │   │   │   out = func(*args, **kwargs)                                                │
│   281 │   │   │   │   self._optimizer_step_code()                                                │
│   282 │   │   │   │                                                                              │
│   283 │   │   │   │   # call optimizer step post hooks                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/utils/_contextlib.py:115 in decorate_context                │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py:288  │
│ in step                                                                                          │
│                                                                                                  │
│   285 │   │   loss = None                                                                        │
│   286 │   │   if closure is not None:                                                            │
│   287 │   │   │   with torch.enable_grad():                                                      │
│ ❱ 288 │   │   │   │   loss = closure()                                                           │
│   289 │   │                                                                                      │
│   290 │   │   for group in self.param_groups:                                                    │
│   291 │   │   │   params_with_grad = []                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in <lambda>  │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2222 in           │
│ _train_microbatches                                                                              │
│                                                                                                  │
│   2219 │   │   │                                                                                 │
│   2220 │   │   │   for microbatch_idx, self.state.batch in enumerate(microbatches):              │
│   2221 │   │   │   │   is_final_microbatch = microbatch_idx + 1 == len(microbatches)             │
│ ❱ 2222 │   │   │   │   microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_  │
│   2223 │   │   │   │                                                                             │
│   2224 │   │   │   │   # Aggregate each loss in microbatch_loss_dict into total_loss_dict        │
│   2225 │   │   │   │   for k, microbatch_loss in microbatch_loss_dict.items():                   │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2349 in           │
│ _train_microbatch                                                                                │
│                                                                                                  │
│   2346 │   │   │   else:                                                                         │
│   2347 │   │   │   │   # Scale loss based on the number of samples in the microbatch to maintai  │
│   2348 │   │   │   │   microbatch_loss.mul_(microbatch_num_samples / current_batch_size)         │
│ ❱ 2349 │   │   │   │   microbatch_loss.backward(create_graph=self._backwards_create_graph)       │
│   2350 │   │   │                                                                                 │
│   2351 │   │   │   self.engine.run_event(Event.AFTER_BACKWARD)                                   │
│   2352                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/_tensor.py:487 in backward                                  │
│                                                                                                  │
│    484 │   │   │   │   create_graph=create_graph,                                                │
│    485 │   │   │   │   inputs=inputs,                                                            │
│    486 │   │   │   )                                                                             │
│ ❱  487 │   │   torch.autograd.backward(                                                          │
│    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                     │
│    489 │   │   )                                                                                 │
│    490                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/__init__.py:200 in backward                        │
│                                                                                                  │
│   197 │   # The reason we repeat same the comment below is that                                  │
│   198 │   # some Python versions print out the first line of a multi-line function               │
│   199 │   # calls in the traceback and some print out the last line                              │
│ ❱ 200 │   Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the bac   │
│   201 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                        │
│   202 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to ru   │
│   203                                                                                            │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/function.py:274 in apply                           │
│                                                                                                  │
│   271 │   │   │   │   │   │   │      "Function is not allowed. You should only implement one "   │
│   272 │   │   │   │   │   │   │      "of them.")                                                 │
│   273 │   │   user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn                    │
│ ❱ 274 │   │   return user_fn(self, *args)                                                        │
│   275 │                                                                                          │
│   276 │   def apply_jvp(self, *args):                                                            │
│   277 │   │   # _forward_cls is defined by derived class                                         │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:1865 in  │
│ backward                                                                                         │
│                                                                                                  │
│   1862 │   │   │   │   weight,                                                                   │
│   1863 │   │   │   │   weight_t_fp8,                                                             │
│   1864 │   │   │   │   fwd_scale_inverses,                                                       │
│ ❱ 1865 │   │   │   ) = ctx.saved_tensors                                                         │
│   1866 │   │   │                                                                                 │
│   1867 │   │   │   if ctx.ub_split_ag:                                                           │
│   1868 │   │   │   │   tp_world_size = get_distributed_world_size(ctx.tp_group)                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: !grad_accumulator_.expired() INTERNAL ASSERT FAILED at "../torch/csrc/autograd/saved_variable.cpp":226, please report a bug to PyTorch. No 
grad accumulator for a saved leaf
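
For context on the failing path: my understanding is that activation_checkpointing_reentrant toggles between torch's two checkpointing implementations, roughly the use_reentrant flag below, and the saved-leaf assert comes from the non-reentrant (saved-tensor-hooks) path interacting with the tensors TE saves for backward. A standalone sketch of the two modes (illustrative, not the Composer wiring):

```python
# Standalone illustration of the two activation-checkpointing modes.
import torch
from torch.utils.checkpoint import checkpoint

linear = torch.nn.Linear(16, 16)


def block(inp):
    return torch.relu(linear(inp))


# activation_checkpointing_reentrant: true -> legacy autograd.Function path
x1 = torch.randn(4, 16, requires_grad=True)
y_reentrant = checkpoint(block, x1, use_reentrant=True)
y_reentrant.sum().backward()

# activation_checkpointing_reentrant: false -> non-reentrant path built on
# saved_tensors_hooks; this appears to be the path that trips the
# "No grad accumulator for a saved leaf" assert with the TE module above.
x2 = torch.randn(4, 16, requires_grad=True)
y_non_reentrant = checkpoint(block, x2, use_reentrant=False)
y_non_reentrant.sum().backward()
```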

@vchiley (Contributor, Author) commented Jun 2, 2023

Activation checkpointing error with amp_fp8 (with activation_checkpointing_reentrant: false or true):

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/llm-foundry/scripts/train/train.py:260 in <module>                                          │
│                                                                                                  │
│   257 │   │   yaml_cfg = om.load(f)                                                              │
│   258 │   cli_cfg = om.from_cli(args_list)                                                       │
│   259 │   cfg = om.merge(yaml_cfg, cli_cfg)                                                      │
│ ❱ 260 │   main(cfg)                                                                              │
│   261                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/scripts/train/train.py:249 in main                                              │
│                                                                                                  │
│   246 │   │   trainer.eval()                                                                     │
│   247 │                                                                                          │
│   248 │   print('Starting training...')                                                          │
│ ❱ 249 │   trainer.fit()                                                                          │
│   250 │                                                                                          │
│   251 │   print('Done.')                                                                         │
│   252                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1766 in fit       │
│                                                                                                  │
│   1763 │   │   │   self.state.scaler = ClosureGradScaler() if self._use_closures() else GradSca  │
│   1764 │   │                                                                                     │
│   1765 │   │   self.first_batch_complete = False                                                 │
│ ❱ 1766 │   │   self._train_loop()                                                                │
│   1767 │                                                                                         │
│   1768 │   def close(self):                                                                      │
│   1769 │   │   """Shutdown the trainer.                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:1940 in           │
│ _train_loop                                                                                      │
│                                                                                                  │
│   1937 │   │   │   │   │   │   self.logger.log_metrics({'time/token': self.state.timestamp.toke  │
│   1938 │   │   │   │   │   │   self.logger.log_metrics({'time/token_in_epoch': self.state.times  │
│   1939 │   │   │   │   │                                                                         │
│ ❱ 1940 │   │   │   │   │   total_loss_dict = self._train_batch(use_grad_scaling)                 │
│   1941 │   │   │   │   │                                                                         │
│   1942 │   │   │   │   │   if use_grad_scaling:                                                  │
│   1943 │   │   │   │   │   │   self.state.scaler.update()                                        │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in           │
│ _train_batch                                                                                     │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/lr_scheduler.py:69 in wrapper                         │
│                                                                                                  │
│     66 │   │   │   │   instance = instance_ref()                                                 │
│     67 │   │   │   │   instance._step_count += 1                                                 │
│     68 │   │   │   │   wrapped = func.__get__(instance, cls)                                     │
│ ❱   69 │   │   │   │   return wrapped(*args, **kwargs)                                           │
│     70 │   │   │                                                                                 │
│     71 │   │   │   # Note that the returned function here is no longer a bound method,           │
│     72 │   │   │   # so attributes like `__func__` and `__self__` no longer exist.               │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/optim/optimizer.py:280 in wrapper                           │
│                                                                                                  │
│   277 │   │   │   │   │   │   │   raise RuntimeError(f"{func} must return None or a tuple of (   │
│   278 │   │   │   │   │   │   │   │   │   │   │      f"but got {result}.")                       │
│   279 │   │   │   │                                                                              │
│ ❱ 280 │   │   │   │   out = func(*args, **kwargs)                                                │
│   281 │   │   │   │   self._optimizer_step_code()                                                │
│   282 │   │   │   │                                                                              │
│   283 │   │   │   │   # call optimizer step post hooks                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/utils/_contextlib.py:115 in decorate_context                │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/optim/decoupled_weight_decay.py:288  │
│ in step                                                                                          │
│                                                                                                  │
│   285 │   │   loss = None                                                                        │
│   286 │   │   if closure is not None:                                                            │
│   287 │   │   │   with torch.enable_grad():                                                      │
│ ❱ 288 │   │   │   │   loss = closure()                                                           │
│   289 │   │                                                                                      │
│   290 │   │   for group in self.param_groups:                                                    │
│   291 │   │   │   params_with_grad = []                                                          │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2124 in <lambda>  │
│                                                                                                  │
│   2121 │   │   │   │   │   │   │   │   │   │   │   │      closure=lambda loss_dict=total_loss_d  │
│   2122 │   │   │   │   │   │   │   │   │   │   │   │      _train_microbatches(microbatches, los  │
│   2123 │   │   │   │   │   │   else:                                                             │
│ ❱ 2124 │   │   │   │   │   │   │   optimizer.step(closure=lambda **kwargs: self._train_microbat  │
│   2125 │   │   │   │   │   │   │   │   microbatches, total_loss_dict, **kwargs).item())          │
│   2126 │   │   │   │   else:                                                                     │
│   2127 │   │   │   │   │   self._train_microbatches(microbatches, total_loss_dict)               │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2222 in           │
│ _train_microbatches                                                                              │
│                                                                                                  │
│   2219 │   │   │                                                                                 │
│   2220 │   │   │   for microbatch_idx, self.state.batch in enumerate(microbatches):              │
│   2221 │   │   │   │   is_final_microbatch = microbatch_idx + 1 == len(microbatches)             │
│ ❱ 2222 │   │   │   │   microbatch_loss_dict = self._train_microbatch(use_grad_scaling, current_  │
│   2223 │   │   │   │                                                                             │
│   2224 │   │   │   │   # Aggregate each loss in microbatch_loss_dict into total_loss_dict        │
│   2225 │   │   │   │   for k, microbatch_loss in microbatch_loss_dict.items():                   │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/composer/trainer/trainer.py:2349 in           │
│ _train_microbatch                                                                                │
│                                                                                                  │
│   2346 │   │   │   else:                                                                         │
│   2347 │   │   │   │   # Scale loss based on the number of samples in the microbatch to maintai  │
│   2348 │   │   │   │   microbatch_loss.mul_(microbatch_num_samples / current_batch_size)         │
│ ❱ 2349 │   │   │   │   microbatch_loss.backward(create_graph=self._backwards_create_graph)       │
│   2350 │   │   │                                                                                 │
│   2351 │   │   │   self.engine.run_event(Event.AFTER_BACKWARD)                                   │
│   2352                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/_tensor.py:487 in backward                                  │
│                                                                                                  │
│    484 │   │   │   │   create_graph=create_graph,                                                │
│    485 │   │   │   │   inputs=inputs,                                                            │
│    486 │   │   │   )                                                                             │
│ ❱  487 │   │   torch.autograd.backward(                                                          │
│    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                     │
│    489 │   │   )                                                                                 │
│    490                                                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/__init__.py:200 in backward                        │
│                                                                                                  │
│   197 │   # The reason we repeat same the comment below is that                                  │
│   198 │   # some Python versions print out the first line of a multi-line function               │
│   199 │   # calls in the traceback and some print out the last line                              │
│ ❱ 200 │   Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the bac   │
│   201 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                        │
│   202 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to ru   │
│   203                                                                                            │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/autograd/function.py:274 in apply                           │
│                                                                                                  │
│   271 │   │   │   │   │   │   │      "Function is not allowed. You should only implement one "   │
│   272 │   │   │   │   │   │   │      "of them.")                                                 │
│   273 │   │   user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn                    │
│ ❱ 274 │   │   return user_fn(self, *args)                                                        │
│   275 │                                                                                          │
│   276 │   def apply_jvp(self, *args):                                                            │
│   277 │   │   # _forward_cls is defined by derived class                                         │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:1865 in  │
│ backward                                                                                         │
│                                                                                                  │
│   1862 │   │   │   │   weight,                                                                   │
│   1863 │   │   │   │   weight_t_fp8,                                                             │
│   1864 │   │   │   │   fwd_scale_inverses,                                                       │
│ ❱ 1865 │   │   │   ) = ctx.saved_tensors                                                         │
│   1866 │   │   │                                                                                 │
│   1867 │   │   │   if ctx.ub_split_ag:                                                           │
│   1868 │   │   │   │   tp_world_size = get_distributed_world_size(ctx.tp_group)                  │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/utils/checkpoint.py:420 in unpack                           │
│                                                                                                  │
│   417 │   │   │   │   │    torch.cuda.amp.autocast(**gpu_autocast_kwargs), \                     │
│   418 │   │   │   │   │    torch.cpu.amp.autocast(**cpu_autocast_kwargs), \                      │
│   419 │   │   │   │   │    torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):   │
│ ❱ 420 │   │   │   │   │   _unused = function(*args, **kwargs)                                    │
│   421 │   │                                                                                      │
│   422 │   │   if x not in storage:                                                               │
│   423 │   │   │   raise RuntimeError(                                                            │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl                     │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /mnt/llm-foundry/llmfoundry/models/layers/blocks.py:110 in forward                               │
│                                                                                                  │
│   107 │   │   is_causal: bool = True,                                                            │
│   108 │   ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:                               │
│   109 │   │   a = self.norm_1(x)                                                                 │
│ ❱ 110 │   │   b, attn_weights, past_key_value = self.attn(                                       │
│   111 │   │   │   a,                                                                             │
│   112 │   │   │   past_key_value=past_key_value,                                                 │
│   113 │   │   │   attn_bias=attn_bias,                                                           │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl                     │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /mnt/llm-foundry/llmfoundry/models/layers/attention.py:423 in forward                            │
│                                                                                                  │
│   420 │   │   is_causal=True,                                                                    │
│   421 │   │   needs_weights=False,                                                               │
│   422 │   ):                                                                                     │
│ ❱ 423 │   │   qkv = self.Wqkv(x)                                                                 │
│   424 │   │                                                                                      │
│   425 │   │   if self.clip_qkv:                                                                  │
│   426 │   │   │   qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)                              │
│                                                                                                  │
│ /usr/lib/python3/dist-packages/torch/nn/modules/module.py:1501 in _call_impl                     │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:2267 in  │
│ forward                                                                                          │
│                                                                                                  │
│   2264 │   │   │   │   │   │   │      produced)                                                  │
│   2265 │   │   """                                                                               │
│   2266 │   │                                                                                     │
│ ❱ 2267 │   │   with self.prepare_forward(inp, is_first_microbatch) as inp:                       │
│   2268 │   │   │   bias_tensor = (                                                               │
│   2269 │   │   │   │   bias if bias is not None                                                  │
│   2270 │   │   │   │   else self.bias if self.parameters_split is None                           │
│                                                                                                  │
│ /usr/lib/python3.10/contextlib.py:135 in __enter__                                               │
│                                                                                                  │
│   132 │   │   # they are only needed for recreation, which is not possible anymore               │
│   133 │   │   del self.args, self.kwds, self.func                                                │
│   134 │   │   try:                                                                               │
│ ❱ 135 │   │   │   return next(self.gen)                                                          │
│   136 │   │   except StopIteration:                                                              │
│   137 │   │   │   raise RuntimeError("generator didn't yield") from None                         │
│   138                                                                                            │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/module.py:632 in   │
│ prepare_forward                                                                                  │
│                                                                                                  │
│    629 │   │   │   │   │   self.fp8_meta["autocast_id_fwd_stack"].append(                        │
│    630 │   │   │   │   │   │   self.fp8_meta["autocast_id_fwd"]                                  │
│    631 │   │   │   │   │   )                                                                     │
│ ❱  632 │   │   │   │   │   add_amax_to_global_buffer(self.fp8_meta, forward=True)                │
│    633 │   │   │   │   self.fp8_meta["update_amax_and_scale_fwd"] = True                         │
│    634 │   │   │   else:                                                                         │
│    635 │   │   │   │   self.fp8_meta["update_amax_and_scale_fwd"] = False                        │
│                                                                                                  │
│ /mnt/llm-foundry/venv/lib/python3.10/site-packages/transformer_engine/pytorch/fp8.py:134 in      │
│ add_amax_to_global_buffer                                                                        │
│                                                                                                  │
│   131 │   │   fp8_meta[buffer_position_key] = len(_global_fp8_buffer[buffer_key]) - 1            │
│   132 │                                                                                          │
│   133 │   # Catch incorrect fp8_autocast usage.                                                  │
│ ❱ 134 │   assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \     │
│   135 │   │   "Same module is being invoked more than once inside an `fp8_autocast` region whe   │
│   136 │   │   "FP8 with amax reduction. This behavior is currently unsupported. For more detai   │
│   137 │   │   "correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93."   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError: Same module is being invoked more than once inside an `fp8_autocast` region when using FP8 with amax reduction. This behavior is currently 
unsupported. For more details and correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93.
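
For reference, a minimal standalone example of the fp8_autocast pattern this assertion guards, based on the TE docs and the PR linked in the message; the recipe arguments are assumptions and this is single-GPU only (no cross-rank amax reduction):

```python
# Illustrative FP8 usage with Transformer Engine (single GPU);
# arguments are examples, not a recommendation.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
layer = te.Linear(64, 64)               # parameters live on the GPU
x = torch.randn(8, 64, device='cuda')

for _ in range(2):
    # One autocast region per forward pass. Invoking the same module more
    # than once inside a single region, with amax reduction enabled, is what
    # the assertion above rejects (see NVIDIA/TransformerEngine#93) --
    # presumably what the activation-checkpoint recompute triggers here.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)
    y.sum().backward()
```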

@abhi-mosaic (Member) commented Jun 2, 2023

Re: activation_checkpointing_reentrant: false, I believe this is the new style of activation checkpointing that torch recommends and may even become the default going forward. It also enabled us to do layer freezing, if I remember correctly. But it's not necessary, and if there's some torch2/fp8 bug that requires activation_checkpointing_reentrant: true, I think that's fine.

Edit: I re-read your comments, and it seems we are still blocked on activation checkpointing + amp_fp8. Will read the logs carefully now...

@abhi-mosaic (Member) commented Jun 2, 2023

The bug referenced in NVIDIA/TransformerEngine#93 appears to have been fixed by NVIDIA/TransformerEngine#187, which is available on main but not in any release yet.

Could you install TE @ main and try? 🙏

@vchiley (Contributor, Author) commented Jun 2, 2023

Note: transformer_engine has its own activation-checkpointing util, which might need to be integrated into Composer for fp8 to work with activation checkpointing (unclear).

@vchiley (Contributor, Author) commented Jun 2, 2023

TE @ main requires flash-attn==1.0.6, and flash-attn==1.0.6 has this issue.

Solution: add --no-build-isolation to the pip install.

@vchiley (Contributor, Author) commented Jun 2, 2023

Installing TE from main (as suggested by @abhi-mosaic) makes our integrated activation checkpointing work; there is no need to integrate TE's own checkpoint util.
activation_checkpointing_reentrant: false is still broken.
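
For completeness, this is where the toggle lives on the Composer side as I understand it; key names follow Composer's FSDP config, so double-check them against the installed version:

```python
# Sketch of the relevant FSDP config keys; values are an example, not a
# recommendation. This dict is what the train script passes to
# composer.Trainer(..., fsdp_config=...).
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'activation_checkpointing': True,
    # True  -> legacy reentrant checkpointing (works with TE @ main)
    # False -> non-reentrant path (still hits the saved-leaf assert above)
    'activation_checkpointing_reentrant': True,
}
```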

bmosaicml pushed a commit that referenced this pull request on Jun 6, 2023: "Add a callback that logs generations to wandb at eval end (#265)"
@vchiley (Contributor, Author) commented Jun 16, 2023

CE per param:
[Screenshot 2023-06-16 at 2.29.15 PM]

TFLOPS per param:
[Screenshot 2023-06-16 at 2.29.30 PM]

Note: the 3B model uses activation checkpointing, so its model TFLOPS is multiplied by 0.75.

Slightly more here.
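
Quick arithmetic behind the 0.75 factor (my reading of the note above): a backward pass costs roughly 2x a forward pass, and full activation checkpointing recomputes the forward during the backward, so the GPU executes about 4 units of work for every 3 that count as model FLOPs.

```python
fwd, bwd, recomputed_fwd = 1, 2, 1           # rough relative FLOP costs
model_work = fwd + bwd                       # 3: counts toward model TFLOPS
hardware_work = fwd + bwd + recomputed_fwd   # 4: what the GPU actually runs
print(model_work / hardware_work)            # 0.75
```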

@vchiley self-assigned this on Jun 16, 2023
@vchiley requested a review from abhi-mosaic on Jun 16, 2023
@vchiley marked this pull request as ready for review on Jun 16, 2023
@vchiley mentioned this pull request on Jun 16, 2023