From 0518be721d7c1176ee0e164ec89988889419ec17 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sat, 10 Feb 2024 15:18:32 -0500 Subject: [PATCH 01/20] Adding times --- src/refiners/training_utils/trainer.py | 36 ++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 433d10b46..49d5641d5 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -95,6 +95,22 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> Any: return decorator +# Ported from open-muse +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.avg: float = 0 + self.sum: float = 0 + self.count: int = 0 + + def update(self, val: float): + self.sum += val + self.count += 1 + self.avg = self.sum / self.count class WarmupScheduler(LRScheduler): _step_count: int # defined by LRScheduler @@ -278,6 +294,10 @@ def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = N gradient_accumulation=config.training.gradient_accumulation, lr_scheduler_interval=config.scheduler.update_interval, ) + self.batch_time_m = AverageMeter() + self.forward_time_m = AverageMeter() + self.backprop_time_m = AverageMeter() + self.data_time_m = AverageMeter() self.callbacks = callbacks or [] self.callbacks += self.default_callbacks() self._call_callbacks(event_name="on_init_begin") @@ -539,22 +559,34 @@ def backward(self) -> None: if self.clock.is_evaluation_step: self.evaluate() - def step(self, batch: Batch) -> None: + def step(self, batch: Batch) -> tuple[float, float]: """Perform a single training step.""" + start = time.time() self._call_callbacks(event_name="on_compute_loss_begin") loss = self.compute_loss(batch=batch) self.loss = loss + forward_time = time.time()-start + self.forward_time_m.update(forward_time) + start = time.time() self._call_callbacks(event_name="on_compute_loss_end") self.backward() + backward_time = time.time()-start + self.backprop_time_m.update(backward_time) + return forward_time, backward_time def epoch(self) -> None: """Perform a single epoch.""" + start = time.time() for batch in self.dataloader: if self.clock.done: break self._call_callbacks(event_name="on_batch_begin") - self.step(batch=batch) + data_time = time.time()-start + self.data_time_m.update(data_time) + forward_time, backward_time = self.step(batch=batch) self._call_callbacks(event_name="on_batch_end") + batch_time = data_time+forward_time+backward_time + self.batch_time_m.update(batch_time) @staticmethod def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int: From f38db1cc53680859045ab82adbdc2ea6e279e3cb Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sat, 10 Feb 2024 15:23:30 -0500 Subject: [PATCH 02/20] Ported from dino ip --- src/refiners/training_utils/callback.py | 17 +++++++++++++++++ src/refiners/training_utils/trainer.py | 2 ++ 2 files changed, 19 insertions(+) diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index 2d7c0b27b..396df241b 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -150,6 +150,23 @@ def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: logger.info("Evaluation ended.") +class MonitorTime(Callback["Trainer[BaseConfig, Any]"]): + def on_batch_end(self, trainer: 
"Trainer[BaseConfig, Any]") -> None: + batch_time, forward_time, backprop_time, data_time = ( + trainer.batch_time_m.avg, + trainer.forward_time_m.avg, + trainer.backprop_time_m.avg, + trainer.data_time_m.avg, + ) + if trainer.clock.is_evaluation_step: + trainer.log( + data={ + "batch_time": batch_time, + "forward_time": forward_time, + "backprop_time": backprop_time, + "data_time": data_time, + } + ) class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]): def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 49d5641d5..28c7454d4 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -33,6 +33,7 @@ ClockCallback, GradientNormClipping, GradientValueClipping, + MonitorTime, ) from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue from refiners.training_utils.dropout import DropoutCallback @@ -311,6 +312,7 @@ def default_callbacks(self) -> list[Callback[Any]]: GradientValueClipping(), GradientNormClipping(), DropoutCallback(), + MonitorTime() ] # look for any Callback that might be a property of the Trainer From 37f593a381331e65e6e995003665faa3b9bb9d19 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Sat, 10 Feb 2024 15:30:36 -0500 Subject: [PATCH 03/20] Style fix --- src/refiners/fluxion/utils.py | 3 +-- .../latent_diffusion/stable_diffusion_1/unet.py | 3 ++- src/refiners/training_utils/callback.py | 2 ++ src/refiners/training_utils/trainer.py | 12 +++++++----- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 7184a4af1..6999acc7a 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -229,8 +229,7 @@ def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, tensors = torch.load(path, map_location=device, weights_only=True) # type: ignore assert isinstance(tensors, dict) and all( - isinstance(key, str) and isinstance(value, Tensor) - for key, value in tensors.items() # type: ignore + isinstance(key, str) and isinstance(value, Tensor) for key, value in tensors.items() # type: ignore ), "Invalid tensor file, expected a dict[str, Tensor]" return cast(dict[str, Tensor], tensors) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py index 25e750219..217eebeb8 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py @@ -165,7 +165,8 @@ def __init__(self, device: Device | str | None = None, dtype: DType | None = Non class SD1UNet(fl.Chain): """Stable Diffusion 1.5 U-Net. - See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details.""" + See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details. 
+ """ def __init__( self, diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index 396df241b..e47eb6911 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -150,6 +150,7 @@ def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: logger.info("Evaluation ended.") + class MonitorTime(Callback["Trainer[BaseConfig, Any]"]): def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: batch_time, forward_time, backprop_time, data_time = ( @@ -168,6 +169,7 @@ def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: } ) + class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]): def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: clip_norm = trainer.config.training.clip_grad_norm diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 28c7454d4..805a40427 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -96,6 +96,7 @@ def inner_wrapper(*args: Any, **kwargs: Any) -> Any: return decorator + # Ported from open-muse class AverageMeter(object): """Computes and stores the average and current value""" @@ -113,6 +114,7 @@ def update(self, val: float): self.count += 1 self.avg = self.sum / self.count + class WarmupScheduler(LRScheduler): _step_count: int # defined by LRScheduler @@ -312,7 +314,7 @@ def default_callbacks(self) -> list[Callback[Any]]: GradientValueClipping(), GradientNormClipping(), DropoutCallback(), - MonitorTime() + MonitorTime(), ] # look for any Callback that might be a property of the Trainer @@ -567,12 +569,12 @@ def step(self, batch: Batch) -> tuple[float, float]: self._call_callbacks(event_name="on_compute_loss_begin") loss = self.compute_loss(batch=batch) self.loss = loss - forward_time = time.time()-start + forward_time = time.time() - start self.forward_time_m.update(forward_time) start = time.time() self._call_callbacks(event_name="on_compute_loss_end") self.backward() - backward_time = time.time()-start + backward_time = time.time() - start self.backprop_time_m.update(backward_time) return forward_time, backward_time @@ -583,11 +585,11 @@ def epoch(self) -> None: if self.clock.done: break self._call_callbacks(event_name="on_batch_begin") - data_time = time.time()-start + data_time = time.time() - start self.data_time_m.update(data_time) forward_time, backward_time = self.step(batch=batch) self._call_callbacks(event_name="on_batch_end") - batch_time = data_time+forward_time+backward_time + batch_time = data_time + forward_time + backward_time self.batch_time_m.update(batch_time) @staticmethod From f437454ed2ff1b2ec148822f430208aa02dbdf8f Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 16 Feb 2024 12:14:59 -0500 Subject: [PATCH 04/20] Fixed to wandb folder --- src/refiners/training_utils/callback.py | 20 -------------------- src/refiners/training_utils/wandb.py | 18 +++++++++++++++++- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index e47eb6911..8e80c520b 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -150,26 +150,6 @@ def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: logger.info("Evaluation ended.") - 
-class MonitorTime(Callback["Trainer[BaseConfig, Any]"]): - def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - batch_time, forward_time, backprop_time, data_time = ( - trainer.batch_time_m.avg, - trainer.forward_time_m.avg, - trainer.backprop_time_m.avg, - trainer.data_time_m.avg, - ) - if trainer.clock.is_evaluation_step: - trainer.log( - data={ - "batch_time": batch_time, - "forward_time": forward_time, - "backprop_time": backprop_time, - "data_time": data_time, - } - ) - - class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]): def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: clip_norm = trainer.config.training.clip_grad_norm diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index 6422803ce..448594822 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -103,7 +103,23 @@ def on_init_begin(self, trainer: "TrainerWithWandb") -> None: def on_train_begin(self, trainer: "TrainerWithWandb") -> None: self.epoch_losses = [] self.iteration_losses = [] - + def on_batch_end(self, trainer: "TrainerWithWandb") -> None: + batch_time, forward_time, backprop_time, data_time = ( + trainer.batch_time_m.avg, + trainer.forward_time_m.avg, + trainer.backprop_time_m.avg, + trainer.data_time_m.avg, + ) + if trainer.clock.is_evaluation_step: + effective_batch_size = trainer.clock.batch_size * trainer.clock.num_step_per_iteration + trainer.wandb_log( + data={ + "batch_time": batch_time / effective_batch_size, + "forward_time": forward_time / effective_batch_size, + "backprop_time": backprop_time / effective_batch_size, + "data_time": data_time / effective_batch_size, + } + ) def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None: loss_value = trainer.loss.detach().cpu().item() self.epoch_losses.append(loss_value) From e158d1b65748592a4d0a5ef0cd3713ddb3f89629 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 16 Feb 2024 12:15:37 -0500 Subject: [PATCH 05/20] Style fix --- src/refiners/training_utils/callback.py | 1 + src/refiners/training_utils/wandb.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/refiners/training_utils/callback.py b/src/refiners/training_utils/callback.py index 8e80c520b..2d7c0b27b 100644 --- a/src/refiners/training_utils/callback.py +++ b/src/refiners/training_utils/callback.py @@ -150,6 +150,7 @@ def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: logger.info("Evaluation ended.") + class GradientNormClipping(Callback["Trainer[BaseConfig, Any]"]): def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: clip_norm = trainer.config.training.clip_grad_norm diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index 448594822..9441ccfd8 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -103,6 +103,7 @@ def on_init_begin(self, trainer: "TrainerWithWandb") -> None: def on_train_begin(self, trainer: "TrainerWithWandb") -> None: self.epoch_losses = [] self.iteration_losses = [] + def on_batch_end(self, trainer: "TrainerWithWandb") -> None: batch_time, forward_time, backprop_time, data_time = ( trainer.batch_time_m.avg, @@ -120,6 +121,7 @@ def on_batch_end(self, trainer: "TrainerWithWandb") -> None: "data_time": data_time / effective_batch_size, } ) + def on_compute_loss_end(self, trainer: "TrainerWithWandb") -> None: loss_value = 
trainer.loss.detach().cpu().item() self.epoch_losses.append(loss_value) From 7a48c6dd9c7f2513a7cde06a8cba53d00fbb8477 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 16 Feb 2024 12:16:41 -0500 Subject: [PATCH 06/20] Removing monitor time callback --- src/refiners/training_utils/trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 805a40427..1534bf3ba 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -33,7 +33,6 @@ ClockCallback, GradientNormClipping, GradientValueClipping, - MonitorTime, ) from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue from refiners.training_utils.dropout import DropoutCallback @@ -314,7 +313,6 @@ def default_callbacks(self) -> list[Callback[Any]]: GradientValueClipping(), GradientNormClipping(), DropoutCallback(), - MonitorTime(), ] # look for any Callback that might be a property of the Trainer From f1d14918b8aca3977f80768ccbeec7a4b16b4f57 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 16 Feb 2024 12:21:05 -0500 Subject: [PATCH 07/20] Fix import error --- src/refiners/training_utils/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 081ced0df..e859fb519 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Generic, Literal, TypeVar, cast import torch +import time from loguru import logger from torch import Tensor, device as Device, dtype as DType, nn from torch.autograd import backward From f15bdb3cc07adf86002cae64cea4b49ae5218cbe Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 16 Feb 2024 15:00:06 -0500 Subject: [PATCH 08/20] Fixed 2 minor bugs --- src/refiners/training_utils/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index e859fb519..aabbe8d7d 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -364,7 +364,7 @@ def compute_loss(self, batch: Batch) -> Tensor: def compute_evaluation(self) -> None: pass - def backward(self) -> None: + def backward(self, start: float) -> float: """Backward pass on the loss.""" self._call_callbacks(event_name="on_backward_begin") scaled_loss = self.loss / self.clock.num_step_per_iteration @@ -379,8 +379,10 @@ def backward(self) -> None: self._call_callbacks(event_name="on_lr_scheduler_step_begin") self.lr_scheduler.step() self._call_callbacks(event_name="on_lr_scheduler_step_end") + backward_time = time.time()-start if self.clock.is_evaluation_step: self.evaluate() + return backward_time def step(self, batch: Batch) -> tuple[float, float]: """Perform a single training step.""" @@ -392,8 +394,7 @@ def step(self, batch: Batch) -> tuple[float, float]: self.forward_time_m.update(forward_time) start = time.time() self._call_callbacks(event_name="on_compute_loss_end") - self.backward() - backward_time = time.time() - start + backward_time = self.backward(start) self.backprop_time_m.update(backward_time) return forward_time, backward_time @@ -410,6 +411,7 @@ def epoch(self) -> None: self._call_callbacks(event_name="on_batch_end") batch_time = data_time + forward_time + backward_time self.batch_time_m.update(batch_time) + start = time.time() @staticmethod def get_training_seed(instance: 
"Trainer[BaseConfig, Any]") -> int: From 8ab7629fa3763b1075fff104e7ef3f874974c4e1 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 20 Feb 2024 22:04:20 -0500 Subject: [PATCH 09/20] Remove unnecessary diffs --- src/refiners/fluxion/utils.py | 3 ++- .../foundationals/latent_diffusion/stable_diffusion_1/unet.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 6999acc7a..7184a4af1 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -229,7 +229,8 @@ def load_tensors(path: Path | str, /, device: Device | str = "cpu") -> dict[str, tensors = torch.load(path, map_location=device, weights_only=True) # type: ignore assert isinstance(tensors, dict) and all( - isinstance(key, str) and isinstance(value, Tensor) for key, value in tensors.items() # type: ignore + isinstance(key, str) and isinstance(value, Tensor) + for key, value in tensors.items() # type: ignore ), "Invalid tensor file, expected a dict[str, Tensor]" return cast(dict[str, Tensor], tensors) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py index 217eebeb8..25e750219 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py @@ -165,8 +165,7 @@ def __init__(self, device: Device | str | None = None, dtype: DType | None = Non class SD1UNet(fl.Chain): """Stable Diffusion 1.5 U-Net. - See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details. - """ + See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details.""" def __init__( self, From cf8e7579b708f16eee06eac0dfb25b831388a53d Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 20 Feb 2024 22:31:53 -0500 Subject: [PATCH 10/20] Cleaned up --- src/refiners/training_utils/config.py | 2 + src/refiners/training_utils/trainer.py | 52 +++++--------------------- src/refiners/training_utils/wandb.py | 8 ++-- 3 files changed, 16 insertions(+), 46 deletions(-) diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 4098ea01f..fcddd7e66 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -13,6 +13,7 @@ from refiners.training_utils.clock import ClockConfig from refiners.training_utils.common import TimeUnit, TimeValue, parse_number_unit_field from refiners.training_utils.gradient_clipping import GradientClippingConfig +from refiners.training_utils.callback import CallbackConfig # PyTorch optimizer parameters type # TODO: replace with `from torch.optim.optimizer import ParamsT` when PyTorch 2.2+ is enforced @@ -167,6 +168,7 @@ class BaseConfig(BaseModel): optimizer: OptimizerConfig lr_scheduler: LRSchedulerConfig clock: ClockConfig = ClockConfig() + timer: CallbackConfig gradient_clipping: GradientClippingConfig = GradientClippingConfig() model_config = ConfigDict(extra="forbid") diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index aabbe8d7d..0fb6c4769 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -31,6 +31,7 @@ CallbackConfig, ) from refiners.training_utils.clock import ClockConfig, TrainingClock +from refiners.training_utils.timer 
import TrainingTimer from refiners.training_utils.common import ( compute_grad_norm, count_learnable_parameters, @@ -41,24 +42,6 @@ from refiners.training_utils.gradient_clipping import GradientClipping, GradientClippingConfig -# Ported from open-muse -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.avg: float = 0 - self.sum: float = 0 - self.count: int = 0 - - def update(self, val: float): - self.sum += val - self.count += 1 - self.avg = self.sum / self.count - - class WarmupScheduler(LRScheduler): _step_count: int # defined by LRScheduler @@ -160,10 +143,6 @@ def __init__(self, config: ConfigType) -> None: self._models: ModelRegistry = {} self._callbacks: CallbackRegistry = {} self.config = config - self.batch_time_m = AverageMeter() - self.forward_time_m = AverageMeter() - self.backprop_time_m = AverageMeter() - self.data_time_m = AverageMeter() self._load_callbacks() self._call_callbacks(event_name="on_init_begin") self._load_models() @@ -181,6 +160,10 @@ def clock(self, config: ClockConfig) -> TrainingClock: verbose=config.verbose, ) + @register_callback() + def timer(self, config: CallbackConfig) -> TrainingTimer: + return TrainingTimer() + @register_callback() def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping: return GradientClipping(config) @@ -358,13 +341,12 @@ def dataloader(self) -> DataLoader[Any]: ) @abstractmethod - def compute_loss(self, batch: Batch) -> Tensor: - ... + def compute_loss(self, batch: Batch) -> Tensor: ... def compute_evaluation(self) -> None: pass - def backward(self, start: float) -> float: + def backward(self) -> None: """Backward pass on the loss.""" self._call_callbacks(event_name="on_backward_begin") scaled_loss = self.loss / self.clock.num_step_per_iteration @@ -379,39 +361,25 @@ def backward(self, start: float) -> float: self._call_callbacks(event_name="on_lr_scheduler_step_begin") self.lr_scheduler.step() self._call_callbacks(event_name="on_lr_scheduler_step_end") - backward_time = time.time()-start if self.clock.is_evaluation_step: self.evaluate() - return backward_time - def step(self, batch: Batch) -> tuple[float, float]: + def step(self, batch: Batch) -> None: """Perform a single training step.""" - start = time.time() self._call_callbacks(event_name="on_compute_loss_begin") loss = self.compute_loss(batch=batch) self.loss = loss - forward_time = time.time() - start - self.forward_time_m.update(forward_time) - start = time.time() self._call_callbacks(event_name="on_compute_loss_end") - backward_time = self.backward(start) - self.backprop_time_m.update(backward_time) - return forward_time, backward_time + self.backward() def epoch(self) -> None: """Perform a single epoch.""" - start = time.time() for batch in self.dataloader: if self.clock.done: break self._call_callbacks(event_name="on_batch_begin") - data_time = time.time() - start - self.data_time_m.update(data_time) - forward_time, backward_time = self.step(batch=batch) + self.step(batch=batch) self._call_callbacks(event_name="on_batch_end") - batch_time = data_time + forward_time + backward_time - self.batch_time_m.update(batch_time) - start = time.time() @staticmethod def get_training_seed(instance: "Trainer[BaseConfig, Any]") -> int: diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index 8479f02aa..85cd473e2 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -107,10 +107,10 @@ 
def on_train_begin(self, trainer: "TrainerWithWandb") -> None: def on_batch_end(self, trainer: "TrainerWithWandb") -> None: batch_time, forward_time, backprop_time, data_time = ( - trainer.batch_time_m.avg, - trainer.forward_time_m.avg, - trainer.backprop_time_m.avg, - trainer.data_time_m.avg, + trainer.timer.batch_time_meter.avg, + trainer.timer.forward_time_meter.avg, + trainer.timer.backprop_time_meter.avg, + trainer.timer.data_time_meter.avg, ) if trainer.clock.is_evaluation_step: effective_batch_size = trainer.clock.batch_size * trainer.clock.num_step_per_iteration From c7beb2df8048f2f3a79e37ae18425253a3ab5aa1 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 20 Feb 2024 22:32:57 -0500 Subject: [PATCH 11/20] Remove diff --- src/refiners/training_utils/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 0fb6c4769..b43242f4d 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -4,7 +4,6 @@ from typing import Any, Callable, Generic, Literal, TypeVar, cast import torch -import time from loguru import logger from torch import Tensor, device as Device, dtype as DType, nn from torch.autograd import backward @@ -341,7 +340,8 @@ def dataloader(self) -> DataLoader[Any]: ) @abstractmethod - def compute_loss(self, batch: Batch) -> Tensor: ... + def compute_loss(self, batch: Batch) -> Tensor: + ... def compute_evaluation(self) -> None: pass From e65437720879f9077c6241d1abbdf79e78d91985 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 20 Feb 2024 22:34:34 -0500 Subject: [PATCH 12/20] Adding timer.py --- src/refiners/training_utils/timer.py | 83 ++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 src/refiners/training_utils/timer.py diff --git a/src/refiners/training_utils/timer.py b/src/refiners/training_utils/timer.py new file mode 100644 index 000000000..59ea2ed9b --- /dev/null +++ b/src/refiners/training_utils/timer.py @@ -0,0 +1,83 @@ +import time +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from refiners.training_utils.callback import Callback +from refiners.training_utils.common import TimeUnit, TimeValue +from refiners.training_utils.config import BaseConfig +from refiners.training_utils.trainer import Trainer + +if TYPE_CHECKING: + from refiners.training_utils.config import BaseConfig + from refiners.training_utils.trainer import Trainer + + +from loguru import logger + + +# Ported from open-muse +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val: float = 0 + self.avg: float = 0 + self.sum: float = 0 + self.count: int = 0 + + def update(self, val: float): + self.val = val + self.sum += val + self.count += 1 + self.avg = self.sum / self.count + + +class TrainingTimer(Callback["Trainer[BaseConfig, Any]"]): + def __init__( + self, + ) -> None: + self.start_time: float = 0 + self.batch_time_meter = AverageMeter() + self.forward_time_meter = AverageMeter() + self.backprop_time_meter = AverageMeter() + self.data_time_meter = AverageMeter() + + def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: + self.start_time = time.time() + + def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + trainer.clock.epoch += 1 + trainer.clock.num_batches_processed = 0 + + def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> 
None: + self.data_time_meter.update(time.time() - self.start_time) + + def on_compute_loss_begin(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.start_time = time.time() + + def on_compute_loss_end(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.forward_time_meter.update(time.time() - self.start_time) + + def on_backward_begin(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.start_time = time.time() + + def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + if (not trainer.clock.is_optimizer_step) and (not trainer.clock.is_lr_scheduler_step): + self.backprop_time_meter.update(time.time() - self.start_time) + + def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + if not trainer.clock.is_lr_scheduler_step: + self.backprop_time_meter.update(time.time() - self.start_time) + + def on_lr_scheduler_step_end(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.backprop_time_meter.update(time.time() - self.start_time) + + def on_batch_end(self, trainer: Trainer[BaseConfig, Any]) -> None: + data_time = self.data_time_meter.val + forward_time = self.forward_time_meter.val + backprop_time = self.backprop_time_meter.val + self.batch_time_meter.update(data_time + forward_time + backprop_time) + self.start_time = time.time() From 0f2119eef39a6a754055f71a7ee2c519fa36525f Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 20 Feb 2024 22:35:06 -0500 Subject: [PATCH 13/20] Remove unnecessary imports --- src/refiners/training_utils/timer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/refiners/training_utils/timer.py b/src/refiners/training_utils/timer.py index 59ea2ed9b..d9dfc0409 100644 --- a/src/refiners/training_utils/timer.py +++ b/src/refiners/training_utils/timer.py @@ -1,9 +1,7 @@ import time -from functools import cached_property from typing import TYPE_CHECKING, Any from refiners.training_utils.callback import Callback -from refiners.training_utils.common import TimeUnit, TimeValue from refiners.training_utils.config import BaseConfig from refiners.training_utils.trainer import Trainer From e8fced0fbe9c01bdc4d5517540b145e4aba1712d Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 20 Feb 2024 22:36:09 -0500 Subject: [PATCH 14/20] Basic checks --- src/refiners/training_utils/timer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/refiners/training_utils/timer.py b/src/refiners/training_utils/timer.py index d9dfc0409..605ff0a05 100644 --- a/src/refiners/training_utils/timer.py +++ b/src/refiners/training_utils/timer.py @@ -10,9 +10,6 @@ from refiners.training_utils.trainer import Trainer -from loguru import logger - - # Ported from open-muse class AverageMeter(object): """Computes and stores the average and current value""" From 4ecbb31e8425868928cee800dafbe437a4efd77c Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Tue, 20 Feb 2024 22:36:53 -0500 Subject: [PATCH 15/20] Remove unnecessary func --- src/refiners/training_utils/timer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/refiners/training_utils/timer.py b/src/refiners/training_utils/timer.py index 605ff0a05..57b35cc3f 100644 --- a/src/refiners/training_utils/timer.py +++ b/src/refiners/training_utils/timer.py @@ -43,10 +43,6 @@ def __init__( def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.start_time = time.time() - def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.epoch += 1 - trainer.clock.num_batches_processed = 0 - def on_batch_begin(self, trainer: 
"Trainer[BaseConfig, Any]") -> None: self.data_time_meter.update(time.time() - self.start_time) From aa7d70d90380aec177ba5c5fbbf6553b15bf300b Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 23 Feb 2024 10:34:19 -0500 Subject: [PATCH 16/20] Merged timer with clock --- src/refiners/training_utils/clock.py | 57 ++++++++++++++++++++++++-- src/refiners/training_utils/trainer.py | 5 --- src/refiners/training_utils/wandb.py | 8 ++-- 3 files changed, 57 insertions(+), 13 deletions(-) diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index 8eb3936cf..be8821067 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -13,6 +13,25 @@ from loguru import logger from torch import Tensor +# Ported from open-muse +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val: float = 0 + self.avg: float = 0 + self.sum: float = 0 + self.count: int = 0 + + def update(self, val: float): + self.val = val + self.sum += val + self.count += 1 + self.avg = self.sum / self.count + class ClockConfig(CallbackConfig): verbose: bool = True @@ -45,6 +64,11 @@ def __init__( self.num_batches_processed = 0 self.num_minibatches_processed = 0 self.loss: Tensor | None = None + self.meter_start_time: float = 0 + self.batch_time_meter = AverageMeter() + self.forward_time_meter = AverageMeter() + self.backprop_time_meter = AverageMeter() + self.data_time_meter = AverageMeter() @cached_property def unit_to_steps(self) -> dict[TimeUnit, int]: @@ -168,26 +192,51 @@ def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log(f"Epoch {trainer.clock.epoch} started.") - - def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.epoch += 1 - trainer.clock.num_batches_processed = 0 + self.meter_start_time = time.time() def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log(f"Step {trainer.clock.step} started.") + self.data_time_meter.update(time.time() - self.meter_start_time) + + def on_compute_loss_begin(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.meter_start_time = time.time() + + def on_compute_loss_end(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.forward_time_meter.update(time.time() - self.meter_start_time) + + def on_backward_begin(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.meter_start_time = time.time() def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: trainer.clock.step += 1 trainer.clock.num_batches_processed += 1 trainer.clock.num_minibatches_processed += 1 + if (not trainer.clock.is_optimizer_step) and (not trainer.clock.is_lr_scheduler_step): + self.backprop_time_meter.update(time.time() - self.meter_start_time) def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log(f"Iteration {trainer.clock.iteration} ended.") trainer.clock.iteration += 1 trainer.clock.num_minibatches_processed = 0 + if not trainer.clock.is_lr_scheduler_step: + self.backprop_time_meter.update(time.time() - self.meter_start_time) + + def on_lr_scheduler_step_end(self, trainer: Trainer[BaseConfig, Any]) -> None: + self.backprop_time_meter.update(time.time() - self.meter_start_time) + + def on_batch_end(self, trainer: Trainer[BaseConfig, Any]) -> None: + data_time = self.data_time_meter.val + forward_time = self.forward_time_meter.val + backprop_time 
= self.backprop_time_meter.val + self.batch_time_meter.update(data_time + forward_time + backprop_time) + self.meter_start_time = time.time() def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log("Evaluation started.") def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log("Evaluation ended.") + + def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + trainer.clock.epoch += 1 + trainer.clock.num_batches_processed = 0 diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index b43242f4d..c68a08e88 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -30,7 +30,6 @@ CallbackConfig, ) from refiners.training_utils.clock import ClockConfig, TrainingClock -from refiners.training_utils.timer import TrainingTimer from refiners.training_utils.common import ( compute_grad_norm, count_learnable_parameters, @@ -159,10 +158,6 @@ def clock(self, config: ClockConfig) -> TrainingClock: verbose=config.verbose, ) - @register_callback() - def timer(self, config: CallbackConfig) -> TrainingTimer: - return TrainingTimer() - @register_callback() def gradient_clipping(self, config: GradientClippingConfig) -> GradientClipping: return GradientClipping(config) diff --git a/src/refiners/training_utils/wandb.py b/src/refiners/training_utils/wandb.py index 85cd473e2..b4b7103f2 100644 --- a/src/refiners/training_utils/wandb.py +++ b/src/refiners/training_utils/wandb.py @@ -107,10 +107,10 @@ def on_train_begin(self, trainer: "TrainerWithWandb") -> None: def on_batch_end(self, trainer: "TrainerWithWandb") -> None: batch_time, forward_time, backprop_time, data_time = ( - trainer.timer.batch_time_meter.avg, - trainer.timer.forward_time_meter.avg, - trainer.timer.backprop_time_meter.avg, - trainer.timer.data_time_meter.avg, + trainer.clock.batch_time_meter.avg, + trainer.clock.forward_time_meter.avg, + trainer.clock.backprop_time_meter.avg, + trainer.clock.data_time_meter.avg, ) if trainer.clock.is_evaluation_step: effective_batch_size = trainer.clock.batch_size * trainer.clock.num_step_per_iteration From 2a5b18aa4615bc37359beea4ab987b9995841be2 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 23 Feb 2024 10:35:25 -0500 Subject: [PATCH 17/20] FIxed style --- src/refiners/training_utils/clock.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index be8821067..92bba3588 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -13,6 +13,7 @@ from loguru import logger from torch import Tensor + # Ported from open-muse class AverageMeter(object): """Computes and stores the average and current value""" From ea16951209f66b6253cb18e13802436e7f18c6f8 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 23 Feb 2024 10:36:43 -0500 Subject: [PATCH 18/20] Remove diff --- src/refiners/training_utils/clock.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index 92bba3588..052704d6a 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -195,6 +195,10 @@ def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log(f"Epoch {trainer.clock.epoch} started.") self.meter_start_time = time.time() + def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: + trainer.clock.epoch += 1 + 
trainer.clock.num_batches_processed = 0 + def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log(f"Step {trainer.clock.step} started.") self.data_time_meter.update(time.time() - self.meter_start_time) @@ -238,6 +242,3 @@ def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log("Evaluation ended.") - def on_epoch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - trainer.clock.epoch += 1 - trainer.clock.num_batches_processed = 0 From 3454aedeb352b7750ef2d2232e8aa1a567860848 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 23 Feb 2024 10:38:38 -0500 Subject: [PATCH 19/20] Remove diffs --- src/refiners/training_utils/clock.py | 1 - src/refiners/training_utils/config.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/refiners/training_utils/clock.py b/src/refiners/training_utils/clock.py index 052704d6a..883f8eb7c 100644 --- a/src/refiners/training_utils/clock.py +++ b/src/refiners/training_utils/clock.py @@ -241,4 +241,3 @@ def on_evaluate_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: def on_evaluate_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: self.log("Evaluation ended.") - diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index fcddd7e66..5eae26244 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -168,7 +168,6 @@ class BaseConfig(BaseModel): optimizer: OptimizerConfig lr_scheduler: LRSchedulerConfig clock: ClockConfig = ClockConfig() - timer: CallbackConfig gradient_clipping: GradientClippingConfig = GradientClippingConfig() model_config = ConfigDict(extra="forbid") From c0b0747cfffcf0b75c417e35a5d43e0570871df6 Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 23 Feb 2024 10:39:15 -0500 Subject: [PATCH 20/20] Remove timer --- src/refiners/training_utils/timer.py | 74 ---------------------------- 1 file changed, 74 deletions(-) delete mode 100644 src/refiners/training_utils/timer.py diff --git a/src/refiners/training_utils/timer.py b/src/refiners/training_utils/timer.py deleted file mode 100644 index 57b35cc3f..000000000 --- a/src/refiners/training_utils/timer.py +++ /dev/null @@ -1,74 +0,0 @@ -import time -from typing import TYPE_CHECKING, Any - -from refiners.training_utils.callback import Callback -from refiners.training_utils.config import BaseConfig -from refiners.training_utils.trainer import Trainer - -if TYPE_CHECKING: - from refiners.training_utils.config import BaseConfig - from refiners.training_utils.trainer import Trainer - - -# Ported from open-muse -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val: float = 0 - self.avg: float = 0 - self.sum: float = 0 - self.count: int = 0 - - def update(self, val: float): - self.val = val - self.sum += val - self.count += 1 - self.avg = self.sum / self.count - - -class TrainingTimer(Callback["Trainer[BaseConfig, Any]"]): - def __init__( - self, - ) -> None: - self.start_time: float = 0 - self.batch_time_meter = AverageMeter() - self.forward_time_meter = AverageMeter() - self.backprop_time_meter = AverageMeter() - self.data_time_meter = AverageMeter() - - def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.start_time = time.time() - - def on_batch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None: - self.data_time_meter.update(time.time() - 
self.start_time) - - def on_compute_loss_begin(self, trainer: Trainer[BaseConfig, Any]) -> None: - self.start_time = time.time() - - def on_compute_loss_end(self, trainer: Trainer[BaseConfig, Any]) -> None: - self.forward_time_meter.update(time.time() - self.start_time) - - def on_backward_begin(self, trainer: Trainer[BaseConfig, Any]) -> None: - self.start_time = time.time() - - def on_backward_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - if (not trainer.clock.is_optimizer_step) and (not trainer.clock.is_lr_scheduler_step): - self.backprop_time_meter.update(time.time() - self.start_time) - - def on_optimizer_step_end(self, trainer: "Trainer[BaseConfig, Any]") -> None: - if not trainer.clock.is_lr_scheduler_step: - self.backprop_time_meter.update(time.time() - self.start_time) - - def on_lr_scheduler_step_end(self, trainer: Trainer[BaseConfig, Any]) -> None: - self.backprop_time_meter.update(time.time() - self.start_time) - - def on_batch_end(self, trainer: Trainer[BaseConfig, Any]) -> None: - data_time = self.data_time_meter.val - forward_time = self.forward_time_meter.val - backprop_time = self.backprop_time_meter.val - self.batch_time_meter.update(data_time + forward_time + backprop_time) - self.start_time = time.time()
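Not part of the patch series itself — after patch 20/20 the timing meters live on TrainingClock and the wandb mixin logs their averages, normalized by the effective batch size, at each evaluation step. Below is a minimal sketch of a standalone callback consuming those meters under that assumption: the TimingLogger name is hypothetical, while the clock attributes (data_time_meter, forward_time_meter, backprop_time_meter, batch_time_meter), is_evaluation_step, batch_size, and num_step_per_iteration all come from the diffs above.

from typing import Any

from loguru import logger

from refiners.training_utils.callback import Callback
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer


class TimingLogger(Callback["Trainer[BaseConfig, Any]"]):
    """Sketch: log per-sample timing averages at every evaluation step."""

    def on_batch_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        clock = trainer.clock
        if not clock.is_evaluation_step:
            return
        # Each meter is an AverageMeter, so `.avg` is the running mean (sum / count)
        # of all values recorded since training started.
        effective_batch_size = clock.batch_size * clock.num_step_per_iteration
        logger.info(
            f"data={clock.data_time_meter.avg / effective_batch_size:.4f}s "
            f"forward={clock.forward_time_meter.avg / effective_batch_size:.4f}s "
            f"backprop={clock.backprop_time_meter.avg / effective_batch_size:.4f}s "
            f"batch={clock.batch_time_meter.avg / effective_batch_size:.4f}s per sample"
        )

Wiring such a callback in would follow the same pattern the series uses for its own callbacks, e.g. a @register_callback() property on the trainer, as shown in patch 10.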