diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index 8d6b0b5c..beabea00 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -134,7 +134,7 @@ struct RuntimeConfig { // These defaults are overridden by InferenceArgs::CopyTo(*this): // Max tokens per batch during prefill. - size_t prefill_tbatch_size = 256; + size_t prefill_tbatch_size = kMaxBatchSize; // Max queries per batch (one token from each) during decode. size_t decode_qbatch_size = 16; @@ -225,7 +225,7 @@ struct InferenceArgs : public ArgsBase { visitor(max_generated_tokens, "max_generated_tokens", size_t{4096}, "Maximum number of tokens to generate."); - visitor(prefill_tbatch_size, "prefill_tbatch", size_t{256}, + visitor(prefill_tbatch_size, "prefill_tbatch", size_t{kMaxBatchSize}, "Prefill: max tokens per batch."); visitor(decode_qbatch_size, "decode_qbatch", size_t{16}, "Decode: max queries per batch."); diff --git a/ops/matmul.h b/ops/matmul.h index 0f3d2866..8d75f2bb 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -54,12 +54,12 @@ HWY_INLINE_VAR constexpr size_t kNR = 4; HWY_INLINE_VAR constexpr size_t kMaxMR = 4; // For `MMTilesC`. -HWY_INLINE_VAR constexpr size_t kMaxMC = 512; -HWY_INLINE_VAR constexpr size_t kMaxNC = 16384; +HWY_INLINE_VAR constexpr size_t kMaxMC = 256; +HWY_INLINE_VAR constexpr size_t kMaxNC = 6 * 1024; // Upper bound for per-worker B storage on the stack. Chosen such that one row // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. -HWY_INLINE_VAR constexpr size_t kMaxKC = 8 * 1024; +HWY_INLINE_VAR constexpr size_t kMaxKC = 6 * 1024; // Policy classes for parallelism, implementing some of `Parallelism`.