diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index bdbf668e1e30..9a4e4fc3504e 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -40,7 +40,6 @@ jobs:
         run: python3 -m pip install -r benchmark_v2/requirements.txt kernels
       - name: Reinstall transformers in edit mode (remove the one installed during docker image build)
-        working-directory: /transformers
         run: python3 -m pip uninstall -y transformers && python3 -m pip install -e ".[torch]"
       - name: Run benchmark
diff --git a/benchmark_v2/framework/benchmark_runner.py b/benchmark_v2/framework/benchmark_runner.py
index 47a60b4e0a88..69fa2b51b576 100644
--- a/benchmark_v2/framework/benchmark_runner.py
+++ b/benchmark_v2/framework/benchmark_runner.py
@@ -117,8 +117,6 @@ def flush_memory():
     # Clear CUDA cache
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
     gc.collect()
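Dropping the two `reset_*` calls from `flush_memory()` leaves the helper to release cached memory only; it no longer wipes the CUDA peak-memory counters as a side effect, presumably so that peak readings taken elsewhere in the benchmark survive a flush. A minimal sketch of that pattern, using standard `torch.cuda` APIs and a hypothetical `measure_peak` helper, not the repo's exact benchmark code:

```python
import gc

import torch


def flush_memory() -> None:
    """Release cached CUDA memory without touching the peak-memory counters."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # give unoccupied cached blocks back to the device
        torch.cuda.synchronize()  # wait for pending kernels before any measurement
    gc.collect()


def measure_peak(fn) -> int:
    """Run `fn` and return the peak CUDA memory allocated during the call, in bytes."""
    if not torch.cuda.is_available():
        fn()
        return 0
    torch.cuda.reset_peak_memory_stats()  # reset once, at the start of the measured region
    fn()
    torch.cuda.synchronize()
    # a flush_memory() call inside `fn` no longer clobbers this reading
    return torch.cuda.max_memory_allocated()
```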
diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py
index 45841ee4e197..780da4ce9b15 100644
--- a/src/transformers/generation/continuous_batching/cache.py
+++ b/src/transformers/generation/continuous_batching/cache.py
@@ -189,7 +189,9 @@ def __init__(
         num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
             num_blocks=getattr(generation_config, "num_blocks", None),
             max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
-            max_memory_percent=getattr(generation_config, "max_memory", 0.9),
+            max_memory_percent=getattr(
+                generation_config, "max_memory", 0.8
+            ),  # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
             cache_dtype=self.dtype,
         )
@@ -414,7 +416,7 @@ def infer_num_blocks_and_max_batch_tokens(
         self,
         num_blocks: Optional[int] = None,
         max_batch_tokens: Optional[int] = None,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float = 0.8,  # FIXME: it seems we overcommit memory, was changed from 0.9 which caused OOMs in our benchmarking CI
         cache_dtype: torch.dtype = torch.float16,
     ) -> tuple[int, int]:
         """Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
@@ -454,7 +456,7 @@
     def compute_num_blocks_and_max_batch_tokens(
         self,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
         m: float = 0.01,
     ) -> tuple[int, int]:
@@ -503,7 +505,7 @@ def compute_num_blocks_and_max_batch_tokens(
     def compute_max_batch_tokens(
         self,
         num_blocks: int,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
     ) -> int:
         """Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
@@ -531,7 +533,7 @@
     def compute_num_blocks(
         self,
         max_batch_tokens: int,
-        max_memory_percent: float = 0.9,
+        max_memory_percent: float,
         cache_dtype: torch.dtype = torch.float16,
     ) -> int:
         """Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:
diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py
index e6bbfd9ad771..407a66f775d7 100644
--- a/src/transformers/generation/continuous_batching/continuous_api.py
+++ b/src/transformers/generation/continuous_batching/continuous_api.py
@@ -826,6 +826,8 @@ def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
         if block:
             self.join(stop_trigger_time, timeout)

+        self.batch_processor = None
+
     def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
         """Wait for the background thread to finish.
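Taken together, the `cache.py` hunks lower the default memory budget from 90% to 80% and remove the defaults from the three `compute_*` helpers so callers must pass `max_memory_percent` explicitly, while `stop()` now drops its `batch_processor` reference so the associated buffers can be garbage-collected. As a rough illustration of what such a budget controls, here is a simplified, hypothetical block-count estimate; the names and formula below are illustrative, not the library's actual cache arithmetic:

```python
import torch


def estimate_num_blocks(
    max_memory_percent: float,  # always explicit, mirroring the removed defaults
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    block_size: int = 32,
    cache_dtype: torch.dtype = torch.float16,
) -> int:
    """Rough KV-cache block count for a given fraction of the currently free GPU memory."""
    free_bytes, _total_bytes = torch.cuda.mem_get_info()  # (free, total) in bytes
    budget = int(free_bytes * max_memory_percent)  # a fraction below 1.0 keeps headroom against OOM
    dtype_bytes = torch.tensor([], dtype=cache_dtype).element_size()
    # one block stores `block_size` tokens of keys and values for every layer
    bytes_per_block = 2 * num_layers * num_kv_heads * head_dim * block_size * dtype_bytes
    return budget // bytes_per_block


# Example: estimate_num_blocks(max_memory_percent=0.8, num_layers=32, num_kv_heads=8, head_dim=128)
```

Under this kind of arithmetic, moving the budget from 0.9 to 0.8 only shrinks the budget term, trading some KV-cache capacity for the extra headroom the FIXME comments ask for.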