save_for_backward can only save variables, but argument 5 is of type bool #4

Open · abalikhan opened this issue Oct 26, 2022 · 1 comment


abalikhan commented Oct 26, 2022

Hi,

Thank you for your incredible work. I was trying to test your method, specifically for cross-attention, but it seems I get the error "save_for_backward can only save variables, but argument 5 is of type bool". I am not sure what I am doing wrong; I tried your own examples too and get the same error.

Can you please help me out?

Code:

import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()
out = cross_attn(x, context = context, mask = mask)  # (1, 65536, 512)

ERROR:

File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/main.py", line 45, in
cli.main()
File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
run()
File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
runpy.run_path(target_as_str, run_name=compat.force_str("__main__"))
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 265, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 97, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/data/stars/user/abali/Phd_work/ISBI2023/X3D-Multigrid/CrossAttn_X3d_v2.py", line 872, in
out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) print(out)
File "/home/abali/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 215, in forward
out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 127, in memory_efficient_attention
exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
File "/home/abali/.local/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
TypeError: save_for_backward can only save variables, but argument 5 is of type bool
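
For context, the failing frame is torch.utils.checkpoint rather than the attention code itself. On older PyTorch releases (roughly pre-1.6, an assumption based on the traceback's checkpoint.py source), CheckpointFunction passed every checkpointed argument, including plain Python bools such as the causal flag, straight to ctx.save_for_backward, which only accepts tensors. A minimal sketch that reproduces the same TypeError under that assumption (summarize here is a hypothetical stand-in for the library's summarize_qkv_fn):

import torch
from torch.utils.checkpoint import checkpoint

def summarize(q, k, causal):
    # hypothetical stand-in for summarize_qkv_fn; note `causal` is a plain bool
    return q @ k.transpose(-1, -2)

q = torch.randn(1, 8, 64, requires_grad=True)
k = torch.randn(1, 8, 64, requires_grad=True)

# On affected versions, CheckpointFunction.forward calls
# ctx.save_for_backward(q, k, True), which raises:
#   TypeError: save_for_backward can only save variables, but argument N is of type bool
out = checkpoint(summarize, q, k, True)

Recent PyTorch releases store non-tensor checkpoint arguments separately from the saved tensors, so the same call succeeds there.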

@lucidrains (Owner) commented

@aliabid2243 Hi Abid! Which version of PyTorch are you on?
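
For reference, the installed version, and which install is actually being imported (the traceback above loads torch from ~/.local/lib/python3.8/site-packages rather than the py38_ydp5 conda env), can be checked with:

python -c "import torch; print(torch.__version__, torch.__file__)"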
