
Redrawing normalized samples using QR slows down training #6

Closed
Parskatt opened this issue Oct 20, 2020 · 4 comments

@Parskatt

Doing the QR decomposition:

import torch

def orthogonal_matrix_chunk(cols, device = None):
    # draw a random Gaussian block and orthogonalize it via QR
    unstructured_block = torch.randn((cols, cols), device = device)
    q, _ = torch.qr(unstructured_block, some = True)
    return q.t()

slows down training substantially (at least for batch sizes around 4). For example, in my own experiments I get ~2.5 batches/s per GPU without redrawing, and ~1.4 batches/s with redrawing.

I found one solution in GPyTorch, which dispatches to CPU for small QR factorizations:

cornellius-gp/gpytorch#1224

Perhaps a similar strategy could be used here? I think num_cols should never really be more than ~100 though, so perhaps you should always use the CPU here?
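
For illustration, a minimal sketch of the always-on-CPU variant (the function name is hypothetical and reuses the signature quoted above; this is not code from the repo):

import torch

def orthogonal_matrix_chunk_cpu(cols, device = None):
    # sketch: small QR factorizations are often faster on CPU than on GPU,
    # so factorize on CPU and only move the result to the target device
    unstructured_block = torch.randn((cols, cols))  # allocated on CPU
    q, _ = torch.qr(unstructured_block, some = True)
    return q.t().to(device)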

@Parskatt
Author

Using CPU instead of GPU gives me ~2 batches/s.
It's not perfect, but it's better.

@lucidrains
Owner

@Parskatt thank you for looking into this! I noticed this as well, but didn't know CPU would be faster.

https://github.com/lucidrains/performer-pytorch/releases/tag/0.0.9

@Parskatt
Author

Thanks for being super fast as usual :)

I think I will personally use trainable projection matrices, initialized as N(0,I). I'll let you know if it works out ;)
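
For illustration, a minimal sketch of what that could look like (the module name and shape conventions are assumptions, not code from this thread):

import torch
import torch.nn as nn

class TrainableProjection(nn.Module):
    # hypothetical sketch: learn the projection instead of redrawing it;
    # entries are initialized i.i.d. from N(0, 1), i.e. N(0, I) overall
    def __init__(self, num_rows, num_cols):
        super().__init__()
        self.projection = nn.Parameter(torch.randn(num_rows, num_cols))

    def forward(self, x):
        # project features: (..., num_cols) -> (..., num_rows)
        return x @ self.projection.t()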

I'll close this issue

@lucidrains
Owner

@Parskatt please do!
