Fix gpt-j-6b RTN RuntimeError#1848
Conversation
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
…ix_gptj_dtype_issue
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Pull request overview
Fixes the GPT-J-6B RTN runtime error caused by missing / incorrectly-typed position_ids during cached-input replay by improving how block inputs are captured, normalized, and replayed.
Changes:
- Wraps calibration block-forward hooks to convert positional args into named kwargs for caching.
- Adds input preprocessing to unwrap single-element containers and preserve integer tensor dtypes.
- Updates
block_forwardto replay blocks using keyword arguments and normalizesposition_idswhen needed.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| auto_round/compressors/utils.py | Reworks block replay to map positional inputs to real parameter names and call blocks via kwargs; adds position_ids normalization. |
| auto_round/compressors/data_driven.py | Adjusts cached-input dtype/device preparation, including unwrapping single-element containers and list handling. |
| auto_round/calibration/inputs.py | Adds shared preprocessing helper to unwrap single-element kwargs and preserves int index tensors from dtype casting. |
| auto_round/calibration/hooks.py | Wraps block-forward capture hook with positional→kwargs conversion and updates replay invocation. |
| kwargs.pop("hidden_states") | ||
| return m.orig_forward(hidden_states, *positional_inputs, **kwargs) | ||
| kwargs.pop("hidden_states", None) | ||
| return m.orig_forward(hidden_states=hidden_states, *positional_inputs, **kwargs) |
| # Use the block's actual parameter name for the first positional argument. | ||
| import inspect as _inspect | ||
|
|
||
| param_names = [p for p in _inspect.signature(block.forward).parameters.keys() if p != "self"] | ||
| block_input_kwarg = param_names[0] if param_names else "hidden_states" | ||
| if block_input_kwarg not in input_others: | ||
| input_others[block_input_kwarg] = input_ids | ||
|
|
||
| # Convert positional inputs to keyword args for any remaining positional parameters. | ||
| positional_inputs = input_tuple or () | ||
| if positional_inputs: | ||
| for i, val in enumerate(positional_inputs): | ||
| param_idx = i + 1 # hidden_states is params[0] | ||
| if param_idx < len(param_names): | ||
| param_name = param_names[param_idx] | ||
| if param_name not in input_others: | ||
| input_others[param_name] = val | ||
| positional_inputs = () |
| to_dtype(v, tmp_dtype) | ||
| for v in val | ||
| if not (isinstance(v, torch.Tensor) and v.dtype in (torch.int32, torch.int64)) |
| # Guard: ensure position_ids is a tensor, not a list or None. | ||
| if "position_ids" in input_others: | ||
| pid = input_others["position_ids"] | ||
| if isinstance(pid, list): | ||
| if len(pid) == 1: | ||
| input_others["position_ids"] = pid[0] | ||
| elif len(pid) == 0: | ||
| # Generate position_ids from hidden_states shape when it's empty. | ||
| input_others["position_ids"] = ( | ||
| torch.arange(input_ids.shape[1], device=input_ids.device, dtype=torch.long) | ||
| .unsqueeze(0) | ||
| .expand(input_ids.shape[0], -1) | ||
| ) | ||
| elif pid is None: | ||
| # Generate position_ids from hidden_states shape when it's None. | ||
| input_others["position_ids"] = ( | ||
| torch.arange(input_ids.shape[1], device=input_ids.device, dtype=torch.long) | ||
| .unsqueeze(0) | ||
| .expand(input_ids.shape[0], -1) | ||
| ) | ||
|
|
…ix_gptj_dtype_issue
…uto-round into lvl/fix_gptj_dtype_issue
|
CUDA_VISIBLE_DEVICES=4 auto-round --model_name /mnt/disk3/lvl/gpt-j-6b/ --bits 4 --iters 0 --tasks lambada_openai Notes:
evaluation running time=95s |
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines successfully started running 1 pipeline(s). |
Signed-off-by: lvliang-intel <liang1.lv@intel.com>
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines successfully started running 1 pipeline(s). |
Description
Fix gpt-j-6b RTN RuntimeError
root cause:
GPT-J calls block(hidden_states, position_ids, attention_mask), position_ids is a positional arg, not a keyword arg. The original hook only captured keyword args, so position_ids was never stored in state.inputs. On replay, position_ids became None or a Python list instead of a torch.Tensor, and torch.gather requires int32/int64, caused the error.
Solution:
The hook wasn't capturing positional args like GPT-J's position_ids, so it was missing on replay. The fix wraps the hook to intercept positional args as keywords, unwraps single-element containers from batch_size=1, and invokes blocks via block(**input_others) using real parameter names.
Type of Change
Bug fix
Related Issues
#1838
Checklist Before Submitting
/azp run Unit-Test-CUDA-AutoRound.