
[BUG] VisionMambaBlock example #204

Open · MelihDarcanxyz opened this issue Apr 29, 2024 · 3 comments
Labels: bug (Something isn't working)

MelihDarcanxyz commented Apr 29, 2024

Describe the bug
VisionMambaBlock example doesn't work. I'm getting:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
test.ipynb Cell 2 line 3
      1 block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=512, d_state=256)
      2 x = torch.randn(1, 32, 256)
----> 3 out = block(x)
      4 out.shape

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File .venv/lib/python3.10/site-packages/zeta/nn/modules/vision_mamba.py:71, in VisionMambaBlock.forward(self, x)
     69 forward_conv_output = self.forward_conv1d(x1_rearranged)
     70 forward_conv_output = rearrange(forward_conv_output, "b d s -> b s d")
---> 71 x1_ssm = self.ssm(forward_conv_output)
     73 # backward conv x2
     74 x2_rearranged = rearrange(x1, "b s d -> b d s")

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File .venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File .venv/lib/python3.10/site-packages/zeta/nn/modules/ssm.py:147, in SSM.forward(self, x, pscan)
    145 # Assuming selective_scan and selective_scan_seq are defined functions
    146 if pscan:
--> 147     y = selective_scan(x, delta, A, B, C, D)
    148 else:
    149     y = selective_scan_seq(x, delta, A, B, C, D)

File .venv/lib/python3.10/site-packages/zeta/nn/modules/ssm.py:29, in selective_scan(x, delta, A, B, C, D)
     26 deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, ED, N)
     27 deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)
---> 29 BX = deltaB * x.unsqueeze(-1)  # (B, L, ED, N)
     31 hs = pscan(deltaA, BX)
     33 y = (
     34     hs @ C.unsqueeze(-1)
     35 ).squeeze()  # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 2
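
The mismatch is between dim_inner = 512 (dimension 2 of deltaB) and dim = 256 (dimension 2 of the tensor the SSM receives, which is never projected up to dim_inner). A minimal standalone sketch, with shapes inferred from the traceback and the comments in selective_scan (so the exact internals of SSM may differ), reproduces the same broadcast failure:

import torch

# Shapes inferred from the selective_scan comments: (B, L, ED, N) with
# ED = dim_inner = 512 and N = d_state = 256, while x is still at dim = 256.
deltaB = torch.randn(1, 32, 512, 256)  # (B, L, ED, N)
x = torch.randn(1, 32, 256)            # (B, L, dim) -- never projected to ED

BX = deltaB * x.unsqueeze(-1)          # RuntimeError: size 512 vs 256 at dim 2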

To Reproduce
Steps to reproduce the behavior:

  1. Run the example:

import torch
from zeta.nn.modules.vision_mamba import VisionMambaBlock  # module path as shown in the traceback

block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=512, d_state=256)
x = torch.randn(1, 32, 256)
out = block(x)
out.shape

Expected behavior

torch.Size([1, 32, 256])
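
As a possible interim workaround (untested, and only a guess based on the shape mismatch above), constructing the block with dim_inner equal to dim should keep deltaB and x the same size at dimension 2, since the SSM input is never projected up to dim_inner:

import torch
from zeta.nn.modules.vision_mamba import VisionMambaBlock  # import path taken from the traceback

# Untested assumption: with dim_inner == dim, the broadcast in selective_scan lines up.
block = VisionMambaBlock(dim=256, heads=8, dt_rank=32, dim_inner=256, d_state=256)
x = torch.randn(1, 32, 256)
out = block(x)
print(out.shape)  # hoped-for result: torch.Size([1, 32, 256])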

MelihDarcanxyz added the bug label on Apr 29, 2024

Hello there, thank you for opening an Issue! 🙏🏻 The team was notified and they will get back to you asap.

@kyegomez (Owner) commented:

@MelihDarcanxyz the updated model is here, and it is functional: https://github.com/kyegomez/VisionMamba

I need to update the implementation here with that new implementation.

MelihDarcanxyz (Contributor, Author) commented Apr 29, 2024

Hi @kyegomez, I saw that, but it has a parameter named num_classes, so I assumed it was only suitable for classification, whereas this implementation makes no such assumption. I'm looking for something more general. Is it a general-purpose block? Sorry, I just couldn't tell, so I'm asking.

EDIT: I tried it, but it didn't work either; I got the same problem reported in kyegomez/VisionMamba#4.
