Update implementation of scaled_index_add and index_select_cat #195

Draft · wants to merge 1 commit into base: main
22 changes: 20 additions & 2 deletions dinov2/layers/block.py
@@ -27,8 +27,26 @@
 XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
     if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add, index_select_cat
+        from xformers.ops import scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat

Review comment on the new import line:

Suggested change
-        from xformers.ops import scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
+        from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat

(The rewritten import drops fmha, which block.py still uses elsewhere; the suggestion restores it.)


+        def scaled_index_add(input, index, source, scaling, alpha):
+            # Use the fused xFormers kernel only for float16 tensors whose
+            # embedding dimension is a multiple of 256; otherwise fall back
+            # to plain PyTorch, which computes the same result out-of-place.
+            is_proper_embed_dim = input.shape[-1] % 256 == 0
+            is_float16 = input.dtype == torch.half
+            if is_proper_embed_dim and is_float16:
+                return _scaled_index_add(input, index, source, scaling, alpha)
+            else:
+                return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
+
+        def index_select_cat(sources, indices):
+            # Same gating as above, applied to every tensor in the list; the
+            # fallback row-indexes each source, flattens, and concatenates.
+            is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
+            is_float16 = all(s.dtype == torch.half for s in sources)
+            if is_proper_embed_dim and is_float16:
+                return _index_select_cat(sources, indices)
+            else:
+                return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)


         XFORMERS_AVAILABLE = True
         warnings.warn("xFormers is available (Block)")
     else:
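
For reference (this is not part of the PR), a minimal sketch of what the fallback branch of the new scaled_index_add computes. The shapes and values below are made up for illustration, and float32 tensors are used so that the plain PyTorch path would be taken:

import torch

# Fallback path: out = torch.index_add(x, 0, idx, scaling * src, alpha=alpha),
# i.e. out[idx[i]] = x[idx[i]] + alpha * scaling * src[i], out-of-place.
x = torch.randn(4, 3, 64)        # embed dim 64 is not a multiple of 256
src = torch.randn(2, 3, 64)
idx = torch.tensor([0, 2])
scaling = torch.randn(64)        # broadcasts over the last dimension
alpha = 0.5

out = torch.index_add(x, dim=0, source=scaling * src, index=idx, alpha=alpha)

# Equivalent explicit loop over the selected rows:
expected = x.clone()
for i, j in enumerate(idx):
    expected[j] += alpha * scaling * src[i]
assert torch.allclose(out, expected, atol=1e-6)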
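
Likewise, a small sketch of the fallback branch of index_select_cat: each source tensor is row-indexed, flattened, and the results are concatenated into one 1-D tensor. Again, the tensors here are illustrative float32 inputs that would route to the PyTorch path:

import torch

sources = [torch.randn(5, 64), torch.randn(7, 64)]    # embed dim 64: fallback
indices = [torch.tensor([0, 3]), torch.tensor([1, 2, 6])]

out = torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)

# Each (source, index) pair contributes len(index) * embed_dim elements.
assert out.shape == (2 * 64 + 3 * 64,)
assert torch.equal(out[:64], sources[0][0])           # first selected row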