Out of memory when the meta optimizer updates parameters #2

Open

zengxianyu opened this issue Aug 24, 2017 · 3 comments

Comments

@zengxianyu

zengxianyu commented Aug 24, 2017

Hello, I find your code very helpful, but too much memory is consumed when the meta optimizer updates the parameters of the model. On my computer, it always raises an 'out of memory' error when it executes line 140 of meta_optimizer.py.

I think it could consume less memory if the MetaModel class held a flat version of the parameters instead of wrapping a model. The MetaModel would reshape the parameters and compute results through nn.functional.conv/linear, so the meta optimizer could use this flat version directly, without allocating extra memory for flattened parameters.
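A hypothetical sketch of that suggestion (not the repository's code; layer names and sizes are made up): keep one flat parameter vector and build the forward pass from views of it with nn.functional calls, so the meta optimizer can update the flat vector directly.

```python
import math
import torch
import torch.nn.functional as F

# Made-up layer sizes for a tiny 2-layer MLP built from a single flat vector.
in_dim, hidden, out_dim = 784, 32, 10
shapes = [(hidden, in_dim), (hidden,), (out_dim, hidden), (out_dim,)]
flat_params = torch.randn(sum(math.prod(s) for s in shapes), requires_grad=True)

def unflatten(flat, shapes):
    """Yield views into `flat`, one per shape, without copying memory."""
    offset = 0
    for shape in shapes:
        numel = math.prod(shape)
        yield flat[offset:offset + numel].view(*shape)
        offset += numel

def forward(x, flat):
    # nn.functional.linear on reshaped slices instead of nn.Linear modules,
    # so no separate per-module parameter storage is needed.
    w1, b1, w2, b2 = unflatten(flat, shapes)
    h = F.relu(F.linear(x, w1, b1))
    return F.linear(h, w2, b2)

out = forward(torch.randn(4, in_dim), flat_params)
out.sum().backward()  # gradients accumulate directly in flat_params.grad
```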

@Forbu
Contributor

Forbu commented Feb 9, 2018

I have kind of the same issue.
On this line of code:
flat_params = self.f * flat_params - self.i * Variable(flat_grads), my computer takes a long time (building the computation graph for 25000 parameters) and then I can't print flat_params (in a normal run or in the debugger).
I think my Mac just doesn't have enough memory. A GPU seems to be required to train the meta-optimizer.

@Forbu
Contributor

Forbu commented Feb 9, 2018

Never mind, that was not the problem. The problem was most likely a version change in PyTorch, so the operation flat_params = self.f * flat_params - self.i * Variable(flat_grads) produces a 25450*25450 matrix (which my computer can't hold). I changed it to:

        flat_params = torch.t(self.f) * flat_params - torch.t(self.i) * Variable(flat_grads)
        flat_params = flat_params.view(-1)

and it works
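For reference, the shape blow-up described above is consistent with numpy-style broadcasting (introduced around PyTorch 0.2): if self.f and self.i carry a trailing singleton dimension while flat_params and flat_grads are 1-D, the element-wise multiply broadcasts to a full N x N matrix. A minimal reproduction with assumed shapes (modern PyTorch, no Variable wrapper):

```python
import torch

# Assumed shapes: `f` stands in for self.f with a trailing singleton dimension
# (e.g. the per-parameter output of an nn.Linear); flat_params / flat_grads
# are 1-D vectors of length n, as in the flattened parameter update.
n = 5
f = torch.rand(n, 1)
flat_params = torch.rand(n)
flat_grads = torch.rand(n)

bad = f * flat_params - f * flat_grads                      # (n, 1) op (n,) broadcasts to (n, n)
good = torch.t(f) * flat_params - torch.t(f) * flat_grads   # (1, n) op (n,) -> (1, n)
good = good.view(-1)                                        # flatten back to (n,)

print(bad.shape, good.shape)  # torch.Size([5, 5]) torch.Size([5])
```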

@ikostrikov
Owner

Sorry, I haven't had enough time recently to fix problems with this code.

Could you submit a PR with this fix?
