Skip to content

Fix gpt-j-6b RTN RuntimeError#1848

Merged
XuehaoSun merged 7 commits into
mainfrom
lvl/fix_gptj_dtype_issue
May 25, 2026
Merged

Fix gpt-j-6b RTN RuntimeError#1848
XuehaoSun merged 7 commits into
mainfrom
lvl/fix_gptj_dtype_issue

Conversation

@lvliang-intel
Copy link
Copy Markdown
Contributor

Description

Fix gpt-j-6b RTN RuntimeError

root cause:
GPT-J calls block(hidden_states, position_ids, attention_mask), position_ids is a positional arg, not a keyword arg. The original hook only captured keyword args, so position_ids was never stored in state.inputs. On replay, position_ids became None or a Python list instead of a torch.Tensor, and torch.gather requires int32/int64, caused the error.

Solution:
The hook wasn't capturing positional args like GPT-J's position_ids, so it was missing on replay. The fix wraps the hook to intercept positional args as keywords, unwraps single-element containers from batch_size=1, and invokes blocks via block(**input_others) using real parameter names.

Type of Change

Bug fix

Related Issues

#1838

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.
  • The CUDA CI has passed. You can trigger it by commenting /azp run Unit-Test-CUDA-AutoRound.

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
Copilot AI review requested due to automatic review settings May 24, 2026 11:42
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes the GPT-J-6B RTN runtime error caused by missing / incorrectly-typed position_ids during cached-input replay by improving how block inputs are captured, normalized, and replayed.

Changes:

  • Wraps calibration block-forward hooks to convert positional args into named kwargs for caching.
  • Adds input preprocessing to unwrap single-element containers and preserve integer tensor dtypes.
  • Updates block_forward to replay blocks using keyword arguments and normalizes position_ids when needed.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

File Description
auto_round/compressors/utils.py Reworks block replay to map positional inputs to real parameter names and call blocks via kwargs; adds position_ids normalization.
auto_round/compressors/data_driven.py Adjusts cached-input dtype/device preparation, including unwrapping single-element containers and list handling.
auto_round/calibration/inputs.py Adds shared preprocessing helper to unwrap single-element kwargs and preserves int index tensors from dtype casting.
auto_round/calibration/hooks.py Wraps block-forward capture hook with positional→kwargs conversion and updates replay invocation.

Comment thread auto_round/calibration/hooks.py Outdated
kwargs.pop("hidden_states")
return m.orig_forward(hidden_states, *positional_inputs, **kwargs)
kwargs.pop("hidden_states", None)
return m.orig_forward(hidden_states=hidden_states, *positional_inputs, **kwargs)
Comment on lines +180 to +197
# Use the block's actual parameter name for the first positional argument.
import inspect as _inspect

param_names = [p for p in _inspect.signature(block.forward).parameters.keys() if p != "self"]
block_input_kwarg = param_names[0] if param_names else "hidden_states"
if block_input_kwarg not in input_others:
input_others[block_input_kwarg] = input_ids

# Convert positional inputs to keyword args for any remaining positional parameters.
positional_inputs = input_tuple or ()
if positional_inputs:
for i, val in enumerate(positional_inputs):
param_idx = i + 1 # hidden_states is params[0]
if param_idx < len(param_names):
param_name = param_names[param_idx]
if param_name not in input_others:
input_others[param_name] = val
positional_inputs = ()
Comment on lines +977 to +979
to_dtype(v, tmp_dtype)
for v in val
if not (isinstance(v, torch.Tensor) and v.dtype in (torch.int32, torch.int64))
Comment on lines +199 to +219
# Guard: ensure position_ids is a tensor, not a list or None.
if "position_ids" in input_others:
pid = input_others["position_ids"]
if isinstance(pid, list):
if len(pid) == 1:
input_others["position_ids"] = pid[0]
elif len(pid) == 0:
# Generate position_ids from hidden_states shape when it's empty.
input_others["position_ids"] = (
torch.arange(input_ids.shape[1], device=input_ids.device, dtype=torch.long)
.unsqueeze(0)
.expand(input_ids.shape[0], -1)
)
elif pid is None:
# Generate position_ids from hidden_states shape when it's None.
input_others["position_ids"] = (
torch.arange(input_ids.shape[1], device=input_ids.device, dtype=torch.long)
.unsqueeze(0)
.expand(input_ids.shape[0], -1)
)

@lvliang-intel
Copy link
Copy Markdown
Contributor Author

CUDA_VISIBLE_DEVICES=4 auto-round --model_name /mnt/disk3/lvl/gpt-j-6b/ --bits 4 --iters 0 --tasks lambada_openai
2026-05-24 22:46:36 INFO main.py L652: start to quantize /mnt/disk3/lvl/gpt-j-6b
2026-05-24 22:46:36 INFO config.py L45: enable_opt_rtn is turned on, set --disable_opt_rtn for higher speed at the cost of accuracy.
2026-05-24 22:46:36 WARNING logging.py L340: Using LLM mode (new architecture).
Loading weights: 100%|███████████████████████████████████████████| 285/285 [00:00<00:00, 32852.65it/s]
[transformers] GPTJForCausalLM LOAD REPORT from: /mnt/disk3/lvl/gpt-j-6b
Key | Status | |
----------------------------------------+------------+--+-
transformer.h.{0...27}.attn.bias | UNEXPECTED | |
transformer.h.{0...27}.attn.masked_bias | UNEXPECTED | |

Notes:

  • UNEXPECTED: can be ignored when loading from different task/architecture; not ok if you expect identical arch.
    2026-05-24 22:46:58 WARNING logging.py L340: some layers are skipped quantization (shape not divisible by 32):
    [transformers] loss_type=None was set in the config but it is unrecognized. Using the default loss: ForCausalLMLoss.
    2026-05-24 22:46:58 INFO base.py L667: 'enable_torch_compile' is set to False by default. Enabling it can reduce tuning cost by 20%, but it might throw an exception.
    2026-05-24 22:46:59 INFO data_driven.py L1088: start to compute imatrix
    2026-05-24 22:46:59 INFO calib_dataset.py L977: Preprocessing calibration dataset in a subprocess to avoid memory leaks...
    2026-05-24 22:47:19 INFO calib_dataset.py L977: Preprocessing calibration dataset in a subprocess to avoid memory leaks...
    Quantizing transformer.h.0: 0%| | 0/28 [00:00<?, ?it/s]2026-05-24 22:47:38 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 13.86GB
    Quantizing transformer.h.1: 4%|█▎ | 1/28 [00:05<02:34, 5.71s/it]2026-05-24 22:47:43 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.2: 7%|██▋ | 2/28 [00:10<02:15, 5.21s/it]2026-05-24 22:47:48 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.3: 11%|████ | 3/28 [00:15<02:05, 5.01s/it]2026-05-24 22:47:53 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.4: 14%|█████▍ | 4/28 [00:20<01:58, 4.94s/it]2026-05-24 22:47:58 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.5: 18%|██████▊ | 5/28 [00:25<01:55, 5.01s/it]2026-05-24 22:48:04 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.6: 21%|████████▏ | 6/28 [00:31<01:58, 5.39s/it]2026-05-24 22:48:11 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.7: 25%|█████████▌ | 7/28 [00:38<02:01, 5.77s/it]2026-05-24 22:48:17 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.8: 29%|██████████▊ | 8/28 [00:44<01:59, 6.00s/it]2026-05-24 22:48:23 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.9: 32%|████████████▏ | 9/28 [00:50<01:52, 5.92s/it]2026-05-24 22:48:29 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.10: 36%|████████████▊ | 10/28 [00:55<01:45, 5.85s/it]2026-05-24 22:48:35 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.11: 39%|██████████████▏ | 11/28 [01:02<01:43, 6.11s/it]2026-05-24 22:48:41 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.12: 43%|███████████████▍ | 12/28 [01:08<01:34, 5.94s/it]2026-05-24 22:48:47 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.13: 46%|████████████████▋ | 13/28 [01:13<01:28, 5.87s/it]2026-05-24 22:48:52 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.14: 50%|██████████████████ | 14/28 [01:19<01:21, 5.82s/it]2026-05-24 22:48:58 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.15: 54%|███████████████████▎ | 15/28 [01:25<01:15, 5.83s/it]2026-05-24 22:49:04 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.16: 57%|████████████████████▌ | 16/28 [01:31<01:10, 5.84s/it]2026-05-24 22:49:10 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.17: 61%|█████████████████████▊ | 17/28 [01:37<01:04, 5.82s/it]2026-05-24 22:49:16 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.18: 64%|███████████████████████▏ | 18/28 [01:42<00:58, 5.81s/it]2026-05-24 22:49:20 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.19: 68%|████████████████████████▍ | 19/28 [01:47<00:49, 5.51s/it]2026-05-24 22:49:26 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.20: 71%|█████████████████████████▋ | 20/28 [01:52<00:43, 5.45s/it]2026-05-24 22:49:31 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.21: 75%|███████████████████████████ | 21/28 [01:58<00:38, 5.51s/it]2026-05-24 22:49:36 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.22: 79%|████████████████████████████▎ | 22/28 [02:03<00:31, 5.31s/it]2026-05-24 22:49:41 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.23: 82%|█████████████████████████████▌ | 23/28 [02:08<00:25, 5.15s/it]2026-05-24 22:49:46 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.24: 86%|██████████████████████████████▊ | 24/28 [02:13<00:20, 5.05s/it]2026-05-24 22:49:51 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.25: 89%|████████████████████████████████▏ | 25/28 [02:17<00:14, 4.98s/it]2026-05-24 22:49:55 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.26: 93%|█████████████████████████████████▍ | 26/28 [02:22<00:09, 4.93s/it]2026-05-24 22:50:00 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.27: 96%|██████████████████████████████████▋ | 27/28 [02:27<00:04, 4.89s/it]2026-05-24 22:50:05 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    Quantizing transformer.h.27: 100%|████████████████████████████████████| 28/28 [02:32<00:00, 5.44s/it]
    2026-05-24 22:50:07 INFO shard_writer.py L324: model has been saved to ./tmp_autoround/gpt-j-6b-w4g128/
    2026-05-24 22:50:07 INFO device.py L1840: 'peak_ram': 15.33GB, 'peak_vram': 15.86GB
    2026-05-24 22:50:07 INFO evaluation.py L457: Using lm-eval version 0.4.11.dev0
    Detected kernel version 5.4.292, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
    /mnt/disk1/lvl/conda_envs/artest-main/lib/python3.11/site-packages/transformers/quantizers/auto.py:262: UserWarning: You passed quantization_config or equivalent parameters to from_pretrained but the model you're loading already has a quantization_config attribute. The quantization_config from the model will be used.However, loading attributes (e.g. ['backend']) will be overwritten with the one you passed to from_pretrained. The rest will be ignored.
    warnings.warn(warning_msg)
    2026-05-24 22:50:07 WARNING backend.py L1176: Better backend is found, please install all the following requirements to enable it.
    2026-05-24 22:50:07 WARNING backend.py L1176: pip install -v "gptqmodel>=2.0" --no-build-isolation
    Loading weights: 100%|████████████████████████████████████████████| 621/621 [00:00<00:00, 1105.10it/s]
    2026-05-24 22:50:10 WARNING convert_model.py L768: Forced model to torch.float16
    100%|████████████████████████████████████████████████████████████| 5153/5153 [00:10<00:00, 502.29it/s]
    Running loglikelihood requests: 0%| | 0/5153 [00:00<?, ?it/s]Passed argument batch_size = auto:1. Detecting largest batch size
    Determined largest batch size: 64
    Running loglikelihood requests: 100%|████████████████████████████| 5153/5153 [00:45<00:00, 114.16it/s]
    bootstrapping for stddev: perplexity
    100%|███████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 51.96it/s]
    | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
    |--------------|------:|------|-----:|----------|---|-----:|---|-----:|
    |lambada_openai| 1|none | 0|acc |↑ |0.6907|± |0.0064|
    | | |none | 0|perplexity|↓ |3.9934|± |0.0868|

evaluation running time=95s

@lvliang-intel
Copy link
Copy Markdown
Contributor Author

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines successfully started running 1 pipeline(s).

Signed-off-by: lvliang-intel <liang1.lv@intel.com>
@lvliang-intel
Copy link
Copy Markdown
Contributor Author

/azp run Unit-Test-CUDA-AutoRound

@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines successfully started running 1 pipeline(s).

@XuehaoSun XuehaoSun merged commit 28f4238 into main May 25, 2026
44 of 46 checks passed
@XuehaoSun XuehaoSun deleted the lvl/fix_gptj_dtype_issue branch May 25, 2026 06:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants