Monarch Projection Step for rectangular blocks and where the number of blocks != sqrt(input dimension) #10
Comments
Maybe the function you're looking for is block_diag_butterfly_project_einsum_rank.
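I haven't double-checked the exact signature of block_diag_butterfly_project_einsum_rank, but conceptually a rank-constrained projection comes down to a truncated SVD applied per block of the (suitably permuted) dense matrix. A minimal NumPy sketch of that core step (names and shapes here are illustrative, not the fly API):

```python
import numpy as np

def best_rank_r(block, rank):
    # Eckart-Young: the truncated SVD is the best rank-`rank` approximation
    # of a block in Frobenius norm. A rank-constrained Monarch projection
    # applies this independently to each block of the permuted dense matrix.
    u, sig, vt = np.linalg.svd(block, full_matrices=False)
    left = u[:, :rank] * np.sqrt(sig[:rank])            # absorbed into the second factor
    right = np.sqrt(sig[:rank])[:, None] * vt[:rank]    # absorbed into the first factor
    return left, right

rng = np.random.default_rng(0)
B = rng.standard_normal((6, 4))
L1, R1 = best_rank_r(B, 1)     # rank-1: lossy for a random block
L4, R4 = best_rank_r(B, 4)     # full rank: reconstruction is exact
err1 = np.linalg.norm(B - L1 @ R1)
err4 = np.linalg.norm(B - L4 @ R4)
```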
Thanks! Just to make sure: the forward function called in monarch_linear.py is equivalent to the one at fly/src/ops/blockdiag_butterfly_einsum.py line 89 (commit cd624cf), and we can just use the projection function you suggested?
Another related question: I'm seeing relatively high projection error for arbitrary weight matrices, e.g. if I generate a standard normal matrix. I'm wondering whether I'm using the suggested function correctly. I've shown this in the colab here: https://colab.research.google.com/drive/18uQy0nWP-oH0bXcViwipzxsA-5MpfMpk?usp=sharing
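High error on a standard normal matrix is expected rather than a usage bug: a Monarch factorization has far fewer parameters than a dense matrix, and the projection can give each block only a low-rank (rank-1, in the basic algorithm) approximation, while a random Gaussian block is full rank. A quick sketch of the effect (a plain block partition is used here instead of fly's exact permutation, which for an i.i.d. Gaussian matrix does not change the per-block statistics; shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, nblocks, bs = 16, 4, 4
M = rng.standard_normal((n, n))

# Rank-1-truncate each 4x4 block of a 16x16 Gaussian matrix: this is the
# best any rank-1-per-block projection can do for that partition.
err_sq = 0.0
for i in range(nblocks):
    for j in range(nblocks):
        B = M[i * bs:(i + 1) * bs, j * bs:(j + 1) * bs]
        sig = np.linalg.svd(B, compute_uv=False)
        err_sq += np.sum(sig[1:] ** 2)   # energy lost by keeping only sigma_1

rel_err = np.sqrt(err_sq) / np.linalg.norm(M)
# rel_err is substantial for a Gaussian M: the two block-diagonal factors
# hold 2 * nblocks * bs^2 = 128 parameters vs n^2 = 256 for the dense matrix.
```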
Hi, I had a question regarding the projection step from dense to Monarch butterfly matrices in the general case, where the matrix is rectangular or nblocks != sqrt(m). I'm having trouble finding the code for this case and extending the existing implementations to cover it.
I found multiple projection files/functions: blockdiag_butterfly_projection.py and blockdiag_butterfly_einsum.py.
As an example, if I use code roughly as follows:
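A self-contained NumPy sketch of the kind of setup described (two block-diagonal factors with a transpose of the block grid between them; the shapes are hypothetical, chosen so that nblocks = 2 != sqrt(8), and the exact rearrange conventions in fly may differ):

```python
import numpy as np

rng = np.random.default_rng(0)
k, q, p = 2, 4, 4            # first factor: 2 blocks of shape 4x4 (input dim k*p = 8)
l, s = q, 2                  # second factor: l = q = 4 blocks of shape 2x2 (output dim l*s = 8)
w1 = rng.standard_normal((k, q, p))
w2 = rng.standard_normal((l, s, k))

def monarch_multiply(x, w1, w2):
    b = x.shape[0]
    k, q, p = w1.shape
    out1 = np.einsum('kqp,bkp->bqk', w1, x.reshape(b, k, p))  # block-diag, then transpose the (k, q) grid
    out2 = np.einsum('lsk,blk->bls', w2, out1)                # second block-diag (l = q blocks)
    return out2.transpose(0, 2, 1).reshape(b, -1)             # interleaved output ordering

# Densify by pushing the identity through: column j of M is the image of e_j.
M = monarch_multiply(np.eye(k * p), w1, w2).T
```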
The resulting shape of the output will be 8x8, which is essentially the full matrix used for the transformation. However, if I use the projection function from blockdiag_butterfly_projection.py (which is meant for square matrices) to try to recover the butterfly matrices from this matrix, I run into the issue that it expects the matrix to decompose as follows:
M_permuted_batched = rearrange(M, '(p k) (r s) -> k r p s', k=sizes[1], r=sizes[0])
, while in our case r = 4 and s = 4, making it incompatible with the matrix dimensions. Meanwhile, the einsum functions in blockdiag_butterfly_einsum.py gave different results from the original
blockdiag_butterfly_multiply
(comparing the forward multiplication step, not the projection step; see this colab). In the paper, I did see the original derivation for Algorithm 1,
but I was unclear on how to actually perform the decomposition step when we can't decompose the tensor into an m x m x m x m shape.
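For the general rectangular case, one workable convention (an assumption on my part, not necessarily the exact one fly uses) is: the first factor has k blocks of shape q x p, the intermediate (k, q) grid is transposed, and the second factor has l = q blocks of shape s x k. Under that convention, every (j, i) block of the dense matrix, taken with rows strided by l and columns in contiguous groups of p, is exactly rank 1, so the projection is a per-block rank-1 SVD and never needs the m x m x m x m reshape:

```python
import numpy as np

def monarch_multiply(x, w1, w2):
    # w1: (k, q, p) block-diagonal; w2: (l = q, s, k) block-diagonal,
    # with a transpose of the (k, q) grid in between.
    b = x.shape[0]
    k, q, p = w1.shape
    out1 = np.einsum('kqp,bkp->bqk', w1, x.reshape(b, k, p))
    out2 = np.einsum('lsk,blk->bls', w2, out1)
    return out2.transpose(0, 2, 1).reshape(b, -1)

def densify(w1, w2):
    k, _, p = w1.shape
    return monarch_multiply(np.eye(k * p), w1, w2).T

def monarch_project(M, k, l):
    # Recover (w1, w2) from a dense M of shape (l*s, k*p). Each block
    # M[j::l, i*p:(i+1)*p] equals the rank-1 outer product of
    # w2[j, :, i] and w1[i, j, :], so a rank-1 SVD splits it.
    s, p = M.shape[0] // l, M.shape[1] // k
    w1 = np.zeros((k, l, p))
    w2 = np.zeros((l, s, k))
    for j in range(l):
        for i in range(k):
            B = M[j::l, i * p:(i + 1) * p]           # s x p block
            u, sig, vt = np.linalg.svd(B)
            w2[j, :, i] = np.sqrt(sig[0]) * u[:, 0]
            w1[i, j, :] = np.sqrt(sig[0]) * vt[0]
    return w1, w2

# Round trip on a rectangular example: input dim 12, output dim 6,
# with k = 3 blocks in the first factor and l = q = 2 in the second.
rng = np.random.default_rng(0)
k, q, p, s = 3, 2, 4, 3
w1 = rng.standard_normal((k, q, p))
w2 = rng.standard_normal((q, s, k))
M = densify(w1, w2)                      # shape (6, 12)
w1_hat, w2_hat = monarch_project(M, k=k, l=q)
```

The SVD's sign ambiguity flips u and v together, so comparing densified matrices (rather than the factors themselves) confirms an exact round trip.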