
Monarch Projection Step for rectangular blocks and where the number of blocks != sqrt(input dimension) #10

Open
sashaDoubov opened this issue Nov 23, 2022 · 3 comments

@sashaDoubov
sashaDoubov commented Nov 23, 2022

Hi, I have a question about the projection step from dense -> Monarch butterfly matrices in the general case, i.e. rectangular blocks or nblocks != sqrt(m). I'm having trouble finding the code for this case and extending the existing implementations to cover it.

I found multiple projection files/functions: blockdiag_butterfly_projection.py and blockdiag_butterfly_einsum.py.

As an example, if I use code roughly as follows:

import torch

from src.models.layers.blockdiag_butterfly_multiply import blockdiag_butterfly_multiply

x = torch.eye(8)
nblocks = 2
bfly_dim = 4

# random block-diagonal butterfly factors, each with nblocks blocks of size bfly_dim x bfly_dim
w1_bfly = torch.randn(nblocks, bfly_dim, bfly_dim)
w2_bfly = torch.randn(nblocks, bfly_dim, bfly_dim)

bfly_matrix = blockdiag_butterfly_multiply(x, w1_bfly, w2_bfly)

print(bfly_matrix.shape)  # torch.Size([8, 8])

The resulting output has shape 8x8; since x is the identity, this is the full dense matrix realized by the butterfly transformation.

However, if I use the projection function from blockdiag_butterfly_projection.py (which is meant for square matrices) to try to recover the butterfly factors from this matrix, I run into a problem: it expects the matrix to decompose as M_permuted_batched = rearrange(M, '(p k) (r s) -> k r p s', k=sizes[1], r=sizes[0]), while in our case r = 4 and s = 4, which is incompatible with the matrix dimensions (r * s = 16 != 8 columns).
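
For reference, a minimal reproduction of that mismatch (assuming the square-case projection ends up with sizes r = s = 4 for this example, as described above):

from einops import rearrange
import torch

M = torch.randn(8, 8)
# The square-case projection would try to split the 8 columns into
# r * s = 4 * 4 = 16 entries, which einops rejects with a shape error:
rearrange(M, '(p k) (r s) -> k r p s', k=4, r=4, s=4)  # raises EinopsError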

Meanwhile, the einsum functions in blockdiag_butterfly_einsum.py gave different results from the original blockdiag_butterfly_multiply (comparing the forward multiplication step, not the projection step). (See this colab.)

In the paper, I did see the original derivation for Algorithm 1:

[screenshot of Algorithm 1 from the Monarch paper]

but I was unclear on how to actually perform the decomposition step when we can't decompose the tensor into an m x m x m x m shape.
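
For concreteness, here is my reading of the generalized decomposition (an assumption based on the rearrange pattern above, not code from the repo): for a (p*k) x (r*s) matrix, k and r are the numbers of blocks in the two factors, so the matrix splits into k x r sub-blocks of shape p x s. For the 8x8 example with nblocks = 2:

from einops import rearrange
import torch

M = torch.randn(8, 8)
k, r = 2, 2  # number of blocks in each butterfly factor (assumed)
M_blocks = rearrange(M, '(p k) (r s) -> k r p s', k=k, r=r)
print(M_blocks.shape)  # torch.Size([2, 2, 4, 4]): k x r blocks, each p x s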

@tridao
Contributor

tridao commented Dec 7, 2022

Maybe the function you're looking for is block_diag_butterfly_project_einsum_rank.
(You can see in our tests here that the projection recovers the original factors.)

def test_block_diag_butterfly_project_einsum_rank(device):
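
A minimal round-trip sketch of the idea (the import path and the exact signature of block_diag_butterfly_project_einsum_rank are assumptions here; see the test linked above for the real usage):

import torch
from src.ops.blockdiag_butterfly_einsum import (  # path assumed
    block_diag_butterfly_project_einsum_rank,
    blockdiag_butterfly_multiply_einsum_rank,
)

M = torch.randn(64, 64)
# Project the dense matrix onto two block-diagonal butterfly factors
# (argument names and order are assumptions).
w1_bfly, w2_bfly = block_diag_butterfly_project_einsum_rank(M, 4, 4, rank=1)
# Multiplying the identity by the factors reconstructs the projected matrix.
M_proj = blockdiag_butterfly_multiply_einsum_rank(torch.eye(64), w1_bfly, w2_bfly)
print((M - M_proj).abs().max())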

@sashaDoubov
Author

Thanks! Just to make sure, the forward function called in monarch_linear.py:

def forward(ctx, x, w1_bfly, w2_bfly):

is then equivalent to
def blockdiag_butterfly_multiply_einsum_rank(x, w1_bfly, w2_bfly):

and we can just use the block_diag_butterfly_project_einsum_rank function for the projection step? I compared the two forward functions on a number of inputs and they seemed equivalent, but I just wanted to double-check.
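
The comparison I ran looked roughly like this (a sketch; the einsum module's import path is an assumption, and I'm assuming both forwards accept factors of shape (nblocks, bfly_dim, bfly_dim)):

import torch
from src.models.layers.blockdiag_butterfly_multiply import blockdiag_butterfly_multiply
from src.ops.blockdiag_butterfly_einsum import blockdiag_butterfly_multiply_einsum_rank  # path assumed

x = torch.randn(3, 8)  # batch of inputs
w1_bfly = torch.randn(2, 4, 4)
w2_bfly = torch.randn(2, 4, 4)

# Both forwards take (x, w1_bfly, w2_bfly); check they agree numerically.
out1 = blockdiag_butterfly_multiply(x, w1_bfly, w2_bfly)
out2 = blockdiag_butterfly_multiply_einsum_rank(x, w1_bfly, w2_bfly)
print(torch.allclose(out1, out2, atol=1e-5))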

@sashaDoubov
Author

Another related question: I'm seeing relatively high projection error for arbitrary weight matrices.

I.e., if I generate a standard normal matrix M with dimensions 1024 x 4096, project it into two Monarch matrices with the function you suggested, and then compute the overall projected matrix \tilde{M}, I get a max element-wise difference of ~4. Is this expected? I'm finding that dense -> sparse fine-tuning is not performing well due to this projection error.

I'm wondering whether I'm using the function suggested correctly.

I've shown this in the colab here: https://colab.research.google.com/drive/18uQy0nWP-oH0bXcViwipzxsA-5MpfMpk?usp=sharing
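
Roughly, the colab does the following (a sketch; the import path, projection signature, and block counts are assumptions, as above):

import torch
from src.ops.blockdiag_butterfly_einsum import (  # path assumed
    block_diag_butterfly_project_einsum_rank,
    blockdiag_butterfly_multiply_einsum_rank,
)

M = torch.randn(1024, 4096)  # standard normal weight matrix
# Project onto two Monarch factors, then reconstruct the dense matrix.
w1_bfly, w2_bfly = block_diag_butterfly_project_einsum_rank(M, 32, 32, rank=1)  # args assumed
M_tilde = blockdiag_butterfly_multiply_einsum_rank(torch.eye(1024), w1_bfly, w2_bfly)
print((M - M_tilde).abs().max())  # ~4 in my runs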
