Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create BlockSparse Tensor #202

Merged
merged 8 commits into from
Feb 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
173 changes: 173 additions & 0 deletions tests/test_sparse_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

# needed to register custom ops
import xformers # noqa: F401
from xformers.ops import masked_matmul
from xformers.sparse import BlockSparseTensor

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the tests should cover fp16 also, most people will use that in fp16 and for some gpus it will follow a different code path (v100 for instance)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried running on fp16 but I got consistent segfaults. This might be due to my version of triton, or something else I'm doing wrong. Anyway, I'll let the traceback here if it can be of use.

cc @ptillet

Traceback of the segfault
Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007ffef65bb34f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
(gdb) bt
#0  0x00007ffef65bb34f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#1  0x00007ffef65c375f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#2  0x00007ffef650b1e4 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#3  0x00007ffef650b34f in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#4  0x00007ffef64e503b in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#5  0x00007ffef64e5bca in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#6  0x00007ffef66b2f83 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#7  0x00007ffef66b3027 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#8  0x00007ffef63a9bf4 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#9  0x00007ffef63b2578 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#10 0x00007ffef63b67c2 in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#11 0x00007ffef63b7c2c in ?? () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#12 0x00007ffef63ab05c in __cuda_CallJitEntryPoint () from /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1
#13 0x00007fff37951942 in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#14 0x00007fff379a010d in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#15 0x00007fff37733d7a in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#16 0x00007fff376e248e in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#17 0x00007fff377a617c in ?? () from /lib/x86_64-linux-gnu/libcuda.so
#18 0x00007fff0d671216 in triton::driver::dispatch::cuModuleLoadData(CUmod_st**, void const*) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#19 0x00007fff0d6ae865 in cu_load_binary(std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long) ()
   from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#20 0x00007fff0d6b21f5 in pybind11::cpp_function::initialize<init_triton_codegen(pybind11::module&&)::{lambda(backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned lo
ng)#2}, std::tuple<unsigned long, unsigned long>, backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long, pybind11::name, pybind11::scope, pybind11::sibling, pybi
nd11::return_value_policy>(init_triton_codegen(pybind11::module&&)::{lambda(backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long)#2}&&, std::tuple<unsigned long
, unsigned long> (*)(backend_t, std::string const&, std::map<std::string, pybind11::object, std::less<std::string>, std::allocator<std::pair<std::string const, pybind11::object> > >&, unsigned long, unsigned long), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::r
eturn_value_policy const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#21 0x00007fff0d6ab6b2 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/triton/_C/libtriton.so
#22 0x00005555556a8348 in cfunction_call_varargs (kwargs=<optimized out>, args=<optimized out>, func=0x7fff3751ed60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:743
#23 PyCFunction_Call (func=0x7fff3751ed60, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:773
#24 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7fff3751ed60, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#25 0x0000555555723666 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff29043b40, callable=0x7fff3751ed60) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#26 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#27 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#28 0x00005555556eee3f in function_code_fastcall (globals=<optimized out>, nargs=3, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#29 _PyFunction_Vectorcall (kwnames=0x0, nargsf=<optimized out>, stack=0x7fffffffa300, func=0x7fff28e98040) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#30 _PyObject_FastCallDict (kwargs=<optimized out>, nargsf=<optimized out>, args=0x7fffffffa300, callable=0x7fff28e98040) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:96
#31 _PyObject_Call_Prepend (callable=0x7fff28e98040, obj=<optimized out>, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#32 0x00005555556eef9a in slot_tp_init (self=0x7fff28f9b910, args=0x7fff28e14440, kwds=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6790
#33 0x0000555555697d2e in type_call (kwds=0x0, args=0x7fff28e14440, type=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:994
#34 _PyObject_MakeTpCall (callable=0x555558cf9240, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#35 0x000055555571f545 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x5555f36a2098, callable=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#36 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#37 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#38 0x00005555556ed821 in PyEval_EvalFrameEx (throwflag=0, f=0x5555f36a1e10) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#39 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>, kwargs=0x7fff28f84988, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x7fff28f009c0, closure=0x0,
    name=0x7ffff78cc1f0, qualname=0x7fff29b4b8b0) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#40 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff28f848f0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#41 0x00005555556eec71 in _PyObject_FastCallDict (kwargs=0x7fff29bf0ac0, nargsf=19, args=0x7fff21cbbe90, callable=0x7fff28e985e0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:104
#42 _PyObject_Call_Prepend (callable=0x7fff28e985e0, obj=<optimized out>, args=<optimized out>, kwargs=0x7fff29bf0ac0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#43 0x00005555556eef0a in slot_tp_call (self=0x7fff28f9b550, args=0x7fff21ceb040, kwds=0x7fff29bf0ac0) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6556
#44 0x00005555556985fb in PyObject_Call (callable=0x7fff28f9b550, args=0x7fff21ceb040, kwargs=0x7fff29bf0ac0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:246
#45 0x00005555557210b6 in do_call_core (kwdict=0x7fff29bf0ac0, callargs=0x7fff21ceb040, func=0x7fff28f9b550, tstate=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:5010
#46 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3559
#47 0x00005555556ed821 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff29c6d800) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#48 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=<optimized out>, kwargs=0x7fff29083740, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=0x7fff28f3c8c0,
    name=0x7ffff76b36b0, qualname=0x7fff28e31490) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#49 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff290836b0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#50 0x0000555555698693 in PyVectorcall_Call (kwargs=<optimized out>, tuple=<optimized out>, callable=0x7fff28f9ae50) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:200
#51 PyObject_Call (callable=0x7fff28f9ae50, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:228
#52 0x00005555557210b6 in do_call_core (kwdict=0x7fff29bf0400, callargs=0x7fff2900ab80, func=0x7fff28f9ae50, tstate=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:5010
#53 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3559
#54 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff29b74440) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#55 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x7fff21cda598, kwargs=0x7fff29083678, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x7ffff78cc1f0,
    qualname=0x7fff28e1b760) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#56 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff290835e0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#57 0x00005555556eec71 in _PyObject_FastCallDict (kwargs=0x7fff3751be80, nargsf=19, args=0x7fff21cbbdf0, callable=0x7fff28e98700) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:104
#58 _PyObject_Call_Prepend (callable=0x7fff28e98700, obj=<optimized out>, args=<optimized out>, kwargs=0x7fff3751be80) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#59 0x00005555556eef0a in slot_tp_call (self=0x7fff28f9b760, args=0x7fff2900a700, kwds=0x7fff3751be80) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6556
#60 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7fff28f9b760, args=<optimized out>, nargs=<optimized out>, keywords=0x7fff28e79160) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#61 0x00005555557202ab in _PyObject_Vectorcall (kwnames=0x7fff28e79160, nargsf=<optimized out>, args=<optimized out>, callable=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#62 call_function (kwnames=0x7fff28e79160, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#63 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3515
#64 0x00005555556edfcb in function_code_fastcall (globals=<optimized out>, nargs=11, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#65 _PyFunction_Vectorcall (func=<optimized out>, stack=0x5555e73fccb0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
--Type <RET> for more, q to quit, c to continue without paging--
#66 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x5555e73fccb0, callable=0x7fff28e7f4c0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#67 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#68 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#69 0x00005555556edfcb in function_code_fastcall (globals=<optimized out>, nargs=21, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#70 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff29083538, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#71 0x00005555556f3a22 in PyVectorcall_Call (kwargs=0x0, tuple=<optimized out>, callable=0x7fff28ff4a60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:200
#72 PyObject_Call (kwargs=0x0, args=<optimized out>, callable=0x7fff28ff4a60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:228
#73 PyEval_CallObjectWithKeywords (kwargs=0x0, args=<optimized out>, callable=0x7fff28ff4a60) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:810
#74 PyObject_CallObject (callable=0x7fff28ff4a60, args=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:818
#75 0x00007ffff5c9b608 in THPFunction_apply(_object*, _object*) () from /private/home/fmassa/.conda/envs/xformers/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#76 0x00005555556a83d0 in cfunction_call_varargs (kwargs=<optimized out>, args=<optimized out>, func=0x7fff290a5ea0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:758
#77 PyCFunction_Call (func=0x7fff290a5ea0, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:773
#78 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7fff290a5ea0, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#79 0x0000555555723666 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x55555832ddb0, callable=0x7fff290a5ea0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#80 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#81 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#82 0x00005555556eee3f in function_code_fastcall (globals=<optimized out>, nargs=3, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#83 _PyFunction_Vectorcall (kwnames=0x0, nargsf=<optimized out>, stack=0x7fffffffb670, func=0x7fff28ff4ca0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#84 _PyObject_FastCallDict (kwargs=<optimized out>, nargsf=<optimized out>, args=0x7fffffffb670, callable=0x7fff28ff4ca0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:96
#85 _PyObject_Call_Prepend (callable=0x7fff28ff4ca0, obj=<optimized out>, args=<optimized out>, kwargs=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:888
#86 0x00005555556eef0a in slot_tp_call (self=0x7ffff7820490, args=0x7ffff772e8c0, kwds=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/typeobject.c:6556
#87 0x0000555555697dbc in _PyObject_MakeTpCall (callable=0x7ffff7820490, args=<optimized out>, nargs=<optimized out>, keywords=0x0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:159
#88 0x0000555555723666 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff29084580, callable=0x7ffff7820490) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:125
#89 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#90 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#91 0x00005555556ee36b in function_code_fastcall (globals=<optimized out>, nargs=4, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#92 _PyFunction_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, stack=0x555558b50d48, func=0x7fff28fe21f0) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#93 _PyObject_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, args=0x555558b50d48, callable=0x7fff28fe21f0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#94 method_vectorcall (method=<optimized out>, args=0x555558b50d50, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/classobject.c:60
#95 0x0000555555657a61 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x555558b50d50, callable=0x7fff2a6deb40) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#96 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#97 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#98 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x555558b50b90) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#99 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x7fff2a043428, kwcount=<optimized out>, kwstep=1, defs=0x7fff28f1e258, defcount=2, kwdefs=0x0, closure=0x0, name=0x7fff374d4b70,
    qualname=0x7fff28fc5510) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#100 0x00005555556ee480 in _PyFunction_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, stack=0x7fff2a043400, func=0x7fff28fe2550) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#101 _PyObject_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, args=0x7fff2a043400, callable=0x7fff28fe2550) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#102 method_vectorcall (method=<optimized out>, args=0x7fff2a043408, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/classobject.c:60
#103 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff2a043408, callable=0x7fff29df0540) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#104 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#105 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#106 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff2a043240) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#107 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x7fff36373c00, kwcount=<optimized out>, kwstep=1, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x7ffff77364e0,
    qualname=0x7ffff77364e0) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#108 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff36373bd8, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#109 0x0000555555657a61 in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff36373bd8, callable=0x7fff3743b040) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#110 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#111 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3469
#112 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7fff36373a40) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#113 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x7fff2a0ba1f8, kwcount=<optimized out>, kwstep=1, defs=0x7fff28fc85f8, defcount=1, kwdefs=0x0, closure=0x0,
    name=0x7ffff77322f0, qualname=0x7ffff77322f0) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#114 0x00005555556ee0a3 in _PyFunction_Vectorcall (func=<optimized out>, stack=0x7fff2a0ba1e0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:436
#115 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fff2a0ba1e0, callable=0x7fff28fc7790) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#116 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#117 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#118 0x00005555556edfcb in function_code_fastcall (globals=<optimized out>, nargs=4, args=<optimized out>, co=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:284
#119 _PyFunction_Vectorcall (func=<optimized out>, stack=0x7ffff78185b0, nargsf=<optimized out>, kwnames=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Objects/call.c:411
#120 0x00005555556575db in _PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7ffff78185b0, callable=0x7fff29ec93a0) at /tmp/build/80754af9/python_1618343417471/work/Include/cpython/abstract.h:127
#121 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, tstate=0x5555558f4850) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4963
#122 _PyEval_EvalFrameDefault (f=<optimized out>, throwflag=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:3500
#123 0x00005555556ed270 in PyEval_EvalFrameEx (throwflag=0, f=0x7ffff7818440) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:741
#124 _PyEval_EvalCodeWithName (_co=<optimized out>, globals=<optimized out>, locals=<optimized out>, args=<optimized out>, argcount=<optimized out>, kwnames=0x0, kwargs=0x0, kwcount=<optimized out>, kwstep=2, defs=0x0, defcount=0, kwdefs=0x0, closure=0x0, name=0x0, qualname=0x0)
    at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4298
#125 0x0000555555782543 in PyEval_EvalCodeEx () at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:4327
#126 PyEval_EvalCode (co=<optimized out>, globals=<optimized out>, locals=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/ceval.c:718
#127 0x00005555557825e4 in run_eval_code_obj (co=0x7ffff77c4c90, globals=0x7ffff787a9c0, locals=0x7ffff787a9c0) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:1165
#128 0x00005555557a8854 in run_mod (mod=<optimized out>, filename=<optimized out>, globals=0x7ffff787a9c0, locals=0x7ffff787a9c0, flags=<optimized out>, arena=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:1187
#129 0x0000555555669390 in pyrun_file (fp=0x5555558f0340, filename=0x7ffff7848eb0, start=<optimized out>, globals=0x7ffff787a9c0, locals=0x7ffff787a9c0, closeit=1, flags=0x7fffffffc5b8) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:1084
#130 0x000055555566c0d2 in pyrun_simple_file (flags=0x7fffffffc5b8, closeit=1, filename=0x7ffff7848eb0, fp=0x5555558f0340) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:439
#131 PyRun_SimpleFileExFlags (fp=0x5555558f0340, filename=<optimized out>, closeit=1, flags=0x7fffffffc5b8) at /tmp/build/80754af9/python_1618343417471/work/Python/pythonrun.c:472
#132 0x000055555566cbf0 in pymain_run_file (cf=0x7fffffffc5b8, config=0x5555558f39b0) at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:391
#133 pymain_run_python (exitcode=0x7fffffffc5b0) at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:616
#134 Py_RunMain () at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:695
#135 0x00005555557aba09 in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at /tmp/build/80754af9/python_1618343417471/work/Modules/main.c:1141
--Type <RET> for more, q to quit, c to continue without paging--
#136 0x00007ffff7db40b3 in __libc_start_main (main=0x55555566d460 <main>, argc=2, argv=0x7fffffffc7b8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffc7a8) at ../csu/libc-start.c:308
#137 0x000055555573afe5 in _start () at ../sysdeps/x86_64/elf/start.S:103



def _create_tensor(device, BLOCK=32, Z=8, C=2, H=64, W=64, dtype=torch.float32):
layout = torch.randint(2, (C, H // BLOCK, W // BLOCK))
layout[:, :, 0] = 1
layout[:, 0, :] = 1
values = torch.randn(Z, layout.sum(), BLOCK, BLOCK, device=device).to(dtype)

mask = (
layout[None, :, :, None, :, None]
.repeat(Z, 1, 1, BLOCK, 1, BLOCK)
.reshape(Z, C, H, W)
)

return BlockSparseTensor(values, layout), mask.bool()


@pytest.mark.parametrize("device", _devices)
def test_masked_matmul(device):
BLOCK = 16
N, C, H, W, L = 8, 2, 64, 64, 32
mask_block, _ = _create_tensor(device, BLOCK, N, C, H, W, dtype=torch.bool)
mask = mask_block.to_dense()

a = torch.randn(N, C, H, L, device=device)
b = torch.randn(N, C, W, L, device=device)

aa = a.clone()
bb = b.clone()

a.requires_grad_(True)
b.requires_grad_(True)
aa.requires_grad_(True)
bb.requires_grad_(True)

bt = b.transpose(-2, -1)
bbt = bb.transpose(-2, -1)

# res_gt = masked_matmul(a, b, mask)
res_gt = a @ bt
# res_gt[~mask] = 0
res_gt = torch.where(mask, res_gt, torch.zeros_like(res_gt))
res = masked_matmul(aa, bbt, mask_block)

res_dense = res.to_dense()
# res_dense[~mask] = float('-inf')

assert res.dtype == res_gt.dtype
assert torch.allclose(res_dense, res_gt)

# try to workaround non-contiguous issues with triton for now
res_gt.backward(torch.ones_like(res_gt))
res._blocksparse_values.backward(torch.ones_like(res._blocksparse_values))
# TODO: this is not passing!!!
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is that only when a row is [0], or do you have other issues ?

Copy link
Contributor Author

@fmassa fmassa Feb 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The failures here are due to triton-lang/triton#419

@ptillet Are we planning on releasing a new version of 1.1.x sometime soon?

# assert torch.allclose(a.grad, aa.grad, atol=1e-7)
# assert torch.allclose(b.grad, bb.grad, atol=1e-7)


@pytest.mark.parametrize("device", _devices)
def test_bmm(device):
BLOCK = 16
N, C, H, W, L = 8, 2, 64, 64, 32
a_block, mask = _create_tensor(device, BLOCK, N, C, H, W)
a = a_block.to_dense()

a_block.requires_grad_(True)
a.requires_grad_(True)

b = torch.randn(N, C, W, L, device=device)
b2 = b.clone()

b.requires_grad_(True)
b2.requires_grad_(True)

res_gt = a @ b
res = a_block @ b2

assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt, atol=1e-5)

res_gt.sum().backward()
res.sum().backward()

a_grad = a.grad.clone().detach()
a_grad[~mask] = 0

assert torch.allclose(b.grad, b2.grad, atol=1e-5)
assert torch.allclose(a_grad, a_block.grad.to_dense(), atol=1e-7)


@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax(device):
a_block, mask = _create_tensor(device)
a = a_block.to_dense()
a[~mask] = float("-inf")

res_gt = torch.softmax(a, dim=-1)
res_block = torch.softmax(a_block, dim=-1)

res = res_block.to_dense()

assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt)


@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_backward(device):
a_block, mask = _create_tensor(device)
a = a_block.to_dense()
a_block.requires_grad_(True)

a[~mask] = float("-inf")
a.requires_grad_(True)

res_gt = torch.softmax(a, dim=-1)
res_block = torch.softmax(a_block, dim=-1)

# WARNING: gradients are modified in-place!
res_block._blocksparse_values.backward(
torch.ones_like(res_block._blocksparse_values)
)
res_gt.backward(torch.ones_like(res_gt))

assert torch.allclose(a.grad, a_block.grad.to_dense(), atol=2e-7)


@pytest.mark.parametrize("device", _devices)
def test_deepcopy(device):
import copy

a_block, mask = _create_tensor(device)

b_block = copy.deepcopy(a_block)
assert torch.equal(a_block, b_block)


@pytest.mark.parametrize("device", _devices)
def test_module_buffer(device):
a_block, _ = _create_tensor(device)
b_block, _ = _create_tensor(device)

module = torch.nn.Module()
# test that register_buffer works
module.register_buffer("a_block", a_block)

assert module.a_block is a_block

module.to(device)
assert module.a_block.device == torch.device(device)

state_dict = module.state_dict()
assert "a_block" in state_dict
assert torch.equal(a_block.to(device), state_dict["a_block"])

module.load_state_dict(state_dict)

module.load_state_dict({"a_block": b_block})
assert torch.equal(module.a_block, b_block.to(device))
1 change: 1 addition & 0 deletions xformers/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
# LICENSE file in the root directory of this source tree.


from .blocksparse_tensor import BlockSparseTensor # noqa: F401
from .csr_tensor import SparseCSRTensor # noqa: F401