torch.compile fails with view_as_float=True #18

Closed
mobicham opened this issue Mar 1, 2024 · 1 comment
Labels: bug (Something isn't working)

mobicham commented Mar 1, 2024

It seems torch.compile doesn't like dtype views (Tensor.view(dtype)). This causes the PYTORCH_COMPILE backend and model = torch.compile(model) to break when view_as_float is set to True:

BackendCompilerFailed: backend='inductor' raised:
LoweringException: NotImplementedError: bitcast torch.float16 to different bitwidth type torch.uint8 is not supported yet.

Wrapping the view with torch.jit.ignore doesn't work in this case.
Minimal code to reproduce the issue:

import torch
from hqq.core.quantize import *

HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)

#######################################################################################
#Set up a single 4-bit HQQLinear layer with float-view bitpacking enabled
batch_size    = 1
context_size  = 512
compute_dtype = torch.float16
linear_layer  = torch.nn.Linear(4096, 4096)

quant_config  = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, offload_meta=False, view_as_float=True)
hqq_linear    = HQQLinear(linear_layer, quant_config, compute_dtype=compute_dtype, del_orig=False)

#The ATen call is wrapped in torch.jit.ignore, but torch.compile still traces and lowers the view()
@torch.jit.ignore
def dequantize_Wq_aten(W_q, meta):
	if meta['view_as_float']: W_q = W_q.view(meta['unpack_view_dtype']) #bitcast the float storage back to the packed integer dtype
	return hqq_aten.dequantize(W_q, meta['scale'], meta['zero'], meta['shape'], meta['group_size'] if (meta['group_size']) else -1, meta['nbits'], meta['axis'], meta['packing'])

@torch.compile()
def dequantize(hqq_layer):
	return dequantize_Wq_aten(hqq_layer.W_q, hqq_layer.meta)

######################################################################################

#This works: W_q is stored as the packed integer dtype, so the view() inside is a same-dtype no-op
hqq_linear.W_q.data = hqq_linear.W_q.data.view(hqq_linear.meta['unpack_view_dtype'])
W_r = dequantize(hqq_linear)

#This breaks: W_q is stored as float16 and inductor cannot bitcast float16 -> uint8
hqq_linear.W_q.data = hqq_linear.W_q.data.view(compute_dtype)
W_r = dequantize(hqq_linear)

A workaround would be to move the view call outside dequantize, but that would make the code more complicated and would require another call to revert back to float bitpacking.
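Something along these lines, as a rough sketch only (it reuses the dequantize and compute_dtype names from the repro above; the function name is made up and this is not how the library handles it):

#Sketch of the workaround: do the bitcast eagerly, outside the compiled function
def dequantize_with_outside_view(hqq_layer):
	meta = hqq_layer.meta
	#Revert the float view to the packed integer dtype before entering compiled code
	hqq_layer.W_q.data = hqq_layer.W_q.data.view(meta['unpack_view_dtype'])
	try:
		W_r = dequantize(hqq_layer) #the view() traced inside is now a same-dtype no-op
	finally:
		#Extra call to restore float bitpacking for the rest of the pipeline
		hqq_layer.W_q.data = hqq_layer.W_q.data.view(compute_dtype)
	return W_r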

This is mainly a PyTorch bug, so I created the issue there as well: pytorch/pytorch#120998

@KeremTurgutlu fyi

mobicham added the bug label Mar 1, 2024

mobicham commented Mar 4, 2024

Note: it looks like this only happens between dtypes that don't have the same bitwidth (uint8 <-> float16). If you force a float32 view with 3-bit quantization, which uses int32 bitpacking, it works fine.
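For reference, the observation can be checked standalone with plain torch (a minimal sketch, no hqq involved; exact behavior depends on the PyTorch/inductor version):

import torch

@torch.compile()
def view_same_bitwidth(x):
	return x.view(torch.float32) #int32 -> float32: same 32-bit width

@torch.compile()
def view_diff_bitwidth(x):
	return x.view(torch.float16) #uint8 -> float16: different bitwidths

view_same_bitwidth(torch.zeros(64, dtype=torch.int32))   #compiles and runs fine
#view_diff_bitwidth(torch.zeros(64, dtype=torch.uint8))  #raises the LoweringException above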
