GGUF CLIP/Text Encoders don't work on Intel Arc. #50

Closed
simonlui opened this issue Aug 20, 2024 · 13 comments

Comments

@simonlui

simonlui commented Aug 20, 2024

So I hate to rain on the newly committed code, but the current implementation fails when loading GGUF CLIP/text encoders on Intel Arc with ComfyUI commit 5a69f84. If I load with --gpu-only, I get the following backtrace.

got prompt
Using pytorch attention in VAE
Using pytorch attention in VAE
/ComfyUI/custom_nodes/ComfyUI-GGUF/nodes.py:35: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/pytorch/torch/csrc/utils/tensor_numpy.cpp:206.)
  torch.from_numpy(tensor.data), # mmap

ggml_sd_loader:
 0                             471
 14                            304
 1                               5


model weight dtype torch.float8_e4m3fn, manual cast: torch.bfloat16
model_type FLUX

ggml_sd_loader:
 13                            144
 0                              50
 14                             25


/ComfyUI/custom_nodes/ComfyUI-GGUF/dequant.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  data = torch.tensor(tensor.data)
ignoring 'copy_' on tensor
Requested to load FluxClipModel_
Loading 1 new model
loaded completely 0.0 161.97479629516602 True
/deps/venv/lib/python3.11/site-packages/intel_extension_for_pytorch/frontend.py:465: UserWarning: Conv BatchNorm folding failed during the optimize process.
  warnings.warn(
/deps/venv/lib/python3.11/site-packages/intel_extension_for_pytorch/frontend.py:472: UserWarning: Linear BatchNorm folding failed during the optimize process.
  warnings.warn(
!!! Exception during processing !!! 'GGMLTensor' object has no attribute 'tensor_shape'
Traceback (most recent call last):
  File "/ComfyUI/execution.py", line 316, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/execution.py", line 191, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/execution.py", line 168, in _map_node_over_list
    process_inputs(input_dict, i)
  File "/ComfyUI/execution.py", line 157, in process_inputs
    results.append(getattr(obj, func)(**inputs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/custom_nodes/ComfyUI-GGUF/nodes.py", line 224, in load_clip
    return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/custom_nodes/ComfyUI-GGUF/nodes.py", line 193, in load_patcher
    if clip.cond_stage_model.clip_l.transformer.text_projection.weight.tensor_shape == None:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'GGMLTensor' object has no attribute 'tensor_shape'

Prompt executed in 3.52 seconds
Exception in thread Thread-6 (prompt_worker):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "/ComfyUI/main.py", line 152, in prompt_worker
    comfy.model_management.cleanup_models()
  File "/ComfyUI/comfy/model_management.py", line 573, in cleanup_models
    x.model_unload()
  File "/ComfyUI/comfy/model_management.py", line 341, in model_unload
    self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
  File "/ComfyUI/comfy/model_patcher.py", line 620, in unpatch_model
    self.model.to(device_to)
  File "/deps/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/deps/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/deps/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  [Previous line repeated 5 more times]
  File "/ComfyUI/custom_nodes/ComfyUI-GGUF/ops.py", line 86, in _apply
    self.weight = fn(self.weight)
                  ^^^^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1158, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/custom_nodes/ComfyUI-GGUF/ops.py", line 24, in to
    new.tensor_type = self.tensor_type
                      ^^^^^^^^^^^^^^^^
AttributeError: 'GGMLTensor' object has no attribute 'tensor_type'

With --lowvram, I get this backtrace instead.

got prompt
Using pytorch attention in VAE
Using pytorch attention in VAE
/ComfyUI/custom_nodes/ComfyUI-GGUF/nodes.py:35: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /build/pytorch/torch/csrc/utils/tensor_numpy.cpp:206.)
  torch.from_numpy(tensor.data), # mmap

ggml_sd_loader:
 0                             471
 14                            304
 1                               5


model weight dtype torch.float8_e4m3fn, manual cast: torch.bfloat16
model_type FLUX

ggml_sd_loader:
 13                            144
 0                              50
 14                             25


/ComfyUI/custom_nodes/ComfyUI-GGUF/dequant.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  data = torch.tensor(tensor.data)
ignoring 'copy_' on tensor
Requested to load FluxClipModel_
Loading 1 new model
loaded completely 0.0 161.97479629516602 True
!!! Exception during processing !!! The default implementation of __deepcopy__() for non-wrapper subclasses only works for subclass types that implement new_empty() and for which that function returns another instance of the same subclass. You should either properly implement new_empty() for your subclass or override __deepcopy__() if it is intended behavior for new_empty() to return an instance of a different type.
Traceback (most recent call last):
  File "/ComfyUI/execution.py", line 316, in execute
    output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/execution.py", line 191, in get_output_data
    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/execution.py", line 168, in _map_node_over_list
    process_inputs(input_dict, i)
  File "/ComfyUI/execution.py", line 157, in process_inputs
    results.append(getattr(obj, func)(**inputs))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/custom_nodes/ComfyUI-GGUF/nodes.py", line 224, in load_clip
    return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/custom_nodes/ComfyUI-GGUF/nodes.py", line 183, in load_patcher
    clip = comfy.sd.load_text_encoder_state_dicts(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/comfy/sd.py", line 475, in load_text_encoder_state_dicts
    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/comfy/sd.py", line 94, in __init__
    model_management.load_models_gpu([self.patcher], force_full_load=True)
  File "/ComfyUI/comfy/model_management.py", line 535, in load_models_gpu
    cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ComfyUI/comfy/model_management.py", line 325, in model_load
    self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/intel_extension_for_pytorch/frontend.py", line 451, in optimize
    optimized_model, optimized_optimizer = _copy_model_and_optimizer(
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/intel_extension_for_pytorch/frontend.py", line 42, in _copy_model_and_optimizer
    new_model = copy.deepcopy(model)
                ^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 297, in _reconstruct
    value = deepcopy(value, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 153, in deepcopy
    y = copier(memo)
        ^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/torch/_tensor.py", line 84, in __deepcopy__
    return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/torch/overrides.py", line 1577, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/torch/_tensor.py", line 1386, in __torch_function__
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/deps/venv/lib/python3.11/site-packages/torch/_tensor.py", line 174, in __deepcopy__
    raise RuntimeError(
RuntimeError: The default implementation of __deepcopy__() for non-wrapper subclasses only works for subclass types that implement new_empty() and for which that function returns another instance of the same subclass. You should either properly implement new_empty() for your subclass or override __deepcopy__() if it is intended behavior for new_empty() to return an instance of a different type.

Prompt executed in 3.28 seconds

Understandably, this might not be a high priority. But just putting it out there for people who may run into the same issue I did.

@city96
Owner

city96 commented Aug 20, 2024

I've seen reports of this happening with that "ViT-L-14-BEST-smooth-GmP-ft" checkpoint. Could you test the default clip-l.safetensors one to see if it still happens? There should also be a CLI flag to disable ipex optimizations (--disable-ipex-optimize iirc) which might help?

@simonlui
Author

Using the regular clip-l didn't change anything; I still run into the problem with the missing GGMLTensor object attribute. As for disabling optimizations, I wrote the code for that flag and for the default ipex.optimize call. I reported this because IPEX optimizations are on by default for a reason: the call does things pretty safely, so if something breaks under IPEX optimizations, it usually means a code issue on the client side rather than in the runtime. In any case, skipping ipex.optimize bypasses the deepcopy issue, but it then immediately hits the aforementioned GGML issue instead.

@city96
Owner

city96 commented Aug 20, 2024

Got it, thanks for testing, will investigate and try to fix.

@edwios

edwios commented Aug 21, 2024

I have the exact same error running on Apple Silicon M1 Max 64GB when using the ViT-L-14-BEST-smooth-GmP-ft.safetensor. Switching away from it solved the problem. I tried both clip-l.safetensor and clip-vit-large-patch14-336.bin, and neither has any issue.

@simonlui
Author

simonlui commented Aug 23, 2024

@city96 I have done some deep diving and I think I have everything working now. I had to add the following patch, which adds a __deepcopy__ implementation to GGMLTensor.

diff --git a/ops.py b/ops.py
index 6bc1e2a..7b13806 100644
--- a/ops.py
+++ b/ops.py
@@ -39,6 +39,9 @@ class GGMLTensor(torch.Tensor):
         except Exception as e:
             print(f"ignoring 'copy_' on tensor")
 
+    def __deepcopy__(self, memo):
+        return self.detach().clone()
+
     @property
     def shape(self):
         if not hasattr(self, "tensor_shape"):

This and the changes I have in comfyanonymous/ComfyUI#4562 should allow for XPU to now work correctly with the GGUF Text Encoders/CLIP loaders you wrote. I believe you should be able to make this change independently of that pull request. Let me know if you have any concerns.
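For readers unfamiliar with the hook this patch relies on: `copy.deepcopy` calls an object's `__deepcopy__(memo)` method when one is defined, instead of its default reconstruction logic. The sketch below shows that mechanism on plain-Python stand-ins. `FakeTensor` and `FakeModule` are hypothetical names invented for illustration; they are not part of ComfyUI-GGUF, PyTorch, or IPEX, and the sketch does not reproduce the XPU failure itself.

```python
import copy


class FakeTensor:
    """Hypothetical stand-in for a tensor subclass; not the real GGMLTensor."""

    def __init__(self, values):
        self.values = list(values)

    def clone(self):
        # Return a new object that owns its own copy of the data.
        return FakeTensor(self.values)

    def __deepcopy__(self, memo):
        # copy.deepcopy() calls this hook instead of its default logic.
        return self.clone()


class FakeModule:
    """Hypothetical container that references the same tensor twice."""

    def __init__(self, weight):
        self.weight = weight
        self.alias = weight


module = FakeModule(FakeTensor([1.0, 2.0]))
copied = copy.deepcopy(module)

# The copied module holds a new tensor object with the same contents.
print(copied.weight is module.weight)
print(copied.weight.values == module.weight.values)
```

Because `copy.deepcopy` memoizes the object returned by the hook, `copied.weight` and `copied.alias` still point at one shared copy rather than two separate ones.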

Edit: Visual snippet of a modified ComfyUI Examples workflow I tested to make sure this worked.

[image: screenshot of the tested workflow]

@city96
Owner

city96 commented Aug 23, 2024

@simonlui .detach() and .clone() don't do anything though, since the tensor is supposed to be immutable anyway - can you check if doing just return self would work? See below, both of those just return self (could also implement actual deepcopy the way copy is implemented, which may be safer)

[image]

@simonlui
Author

@city96 I tested it out and it does work too.
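As a minimal sketch of what `return self` implies in general Python terms: the object handed back by `copy.deepcopy` is then the original itself, so any later mutation through the "copy" is visible on the original. `SharedOnCopy` is a hypothetical class used only for illustration. Whether this aliasing matters for GGMLTensor depends on whether the tensor or its attached attributes are mutated after copying, which this thread does not settle.

```python
import copy


class SharedOnCopy:
    """Hypothetical class whose __deepcopy__ hands back the same object."""

    def __init__(self):
        self.patches = []

    def __deepcopy__(self, memo):
        # No new object is created; callers receive the original.
        return self


original = SharedOnCopy()
duplicate = copy.deepcopy(original)

# Mutating the "copy" also changes the original, since they are one object.
duplicate.patches.append("example-patch")
print(duplicate is original)
print(original.patches)
```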

@city96
Owner

city96 commented Aug 24, 2024

I don't have intel so I can't test, but what about any of these:
(I have a feeling that just returning self for deepcopy would break some stuff down the line.)

    def __deepcopy__(self, *args, **kwargs):
        return super().__deepcopy__(*args, **kwargs)

    def __deepcopy__(self, *args, **kwargs):
        return self.data.__deepcopy__(*args, **kwargs)

    def __deepcopy__(self, *args, **kwargs):
        new = super().__deepcopy__(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new
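The last variant above copies the object and then re-attaches the custom attributes, using `getattr` with a default so a tensor that never had them set does not raise `AttributeError`. Here is a rough, torch-free sketch of that pattern. `MetaTensor` is a hypothetical stand-in, the `tensor_type` value is arbitrary, and the shape fallback is simplified to the length of a list; none of this is the actual ComfyUI-GGUF code.

```python
import copy


class MetaTensor:
    """Hypothetical stand-in carrying the three attributes named above."""

    def __init__(self, data):
        self.data = list(data)

    def __deepcopy__(self, memo):
        new = MetaTensor(copy.deepcopy(self.data, memo))
        # getattr with a default avoids AttributeError when the source
        # object never had the metadata attached.
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", (len(new.data),))
        # .copy() gives the new object its own list of patches.
        new.patches = getattr(self, "patches", []).copy()
        return new


bare = MetaTensor([1, 2, 3])   # no metadata attached
tagged = MetaTensor([4, 5])
tagged.tensor_type = 14        # arbitrary illustrative value
tagged.patches = ["p0"]

bare_copy = copy.deepcopy(bare)
tagged_copy = copy.deepcopy(tagged)
tagged_copy.patches.append("p1")  # leaves tagged.patches untouched
```

Copying the `patches` list, rather than sharing it, is what keeps later edits on the copy from leaking back into the original.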

@simonlui
Author

All of these implementations work, but I am guessing you want to use the last one since it aligns with changes you already made, which is fine by me.

city96 added a commit that referenced this issue Aug 24, 2024
Deep dive & fix by @simonlui in #50
@city96
Owner

city96 commented Aug 24, 2024

@simonlui Thanks for the research/testing. Added - hopefully it's fine.

@simonlui
Author

Yep, did a final test, everything checks out. Thanks for making the change.

@zer0int

zer0int commented Aug 24, 2024

Hey everyone / @city96,

I'm the author of ViT-L-14-BEST-smooth-GmP-ft.safetensor @ https://huggingface.co/zer0int/CLIP-GmP-ViT-L-14/tree/main and code @ https://github.com/zer0int/CLIP-fine-tune that was used to fine-tune the model.

I am assuming the issue may have been due to 1. Fine-tuning based on original OpenAI/CLIP code and 2. Then just converting the .pt model to .safetensors, without converting to the syntax HF uses for the model, and 3. without "detaching" the vision transformer (it's a full text-vision transformer model .safetensors).

For my previous model, I supplied a text encoder only HF format version, but it didn't seem like there was a particular interest in that - so I omitted the potential "choice confusion" for my latest model.

It seems there's demand for that after all, judging by this thread. Alas:

If you could either

  1. Point me to a script for the "absolutely right way to convert original OpenAI/CLIP to HF that works for everything" or
  2. Confirm the above older model "TE only, HF format" works correctly (I just reasoned it together based on what a HF CLIP-L looks like, alas can't guarantee it's 100% conforming),

then I would be happy to upload a 'proper' HF model / proper text encoder for this model, and in the future.

Kind regards!

@city96
Owner

city96 commented Aug 24, 2024

@zer0int I think comfy handles the logic for loading/key conversions internally (it uses a custom implementation for CLIP instead of importing from transformers/open-clip), and since we're not quantizing/touching clip-l, it does work correctly on the latest commit afaik. The issue was that metadata was stripped from our custom quantized tensors during the conversion (even for FP16), which caused problems even with non-quantized models, but I've added a fallback to just treat them like normal tensors in that case.

The intel issue was related to something else.
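To illustrate the kind of fallback described above, here is a minimal sketch, assuming the fallback amounts to checking for the custom attribute and deferring to the ordinary behaviour when it is absent. The `hasattr(self, "tensor_shape")` check mirrors the context lines of the diff earlier in the thread. `PlainTensor` and `QuantTensor` are hypothetical stand-ins, not the real classes.

```python
class PlainTensor:
    """Hypothetical base class reporting the real shape of the stored data."""

    def __init__(self, rows, cols):
        self._rows = rows
        self._cols = cols

    @property
    def shape(self):
        return (self._rows, self._cols)


class QuantTensor(PlainTensor):
    """Hypothetical subclass that may carry a logical shape as metadata."""

    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            # Metadata was never attached or got stripped:
            # behave like an ordinary tensor.
            return super().shape
        return self.tensor_shape


stripped = QuantTensor(2, 3)       # no tensor_shape set
annotated = QuantTensor(2, 3)
annotated.tensor_shape = (4, 8)    # illustrative logical shape

print(stripped.shape)
print(annotated.shape)
```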
