Lazy parameters and bijectors with metaclasses #59

Closed
stefanwebb wants to merge 1 commit
Conversation

stefanwebb
Contributor

@feynmanliang my fork of facebookincubator/flowtorch was deleted when I connected my FB employee account to GitHub, so I have duplicated #58 here (and added some additional functionality)

Motivation

Shape information for a normalizing flow becomes known only when the base distribution has been specified. We have been searching for an ideal way to express the delayed instantiation of Bijector and Params for this purpose. Several possible solutions are outlined in #57.
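Concretely (plain torch.distributions here, matching the example further down): the event shape that a bijector's parameters must match only exists once the base distribution does.

import torch

# A 2-dimensional standard normal base distribution, as in the example below
base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1
)

# Only at this point is the shape that Bijector/Params must match known
print(base_dist.event_shape)  # torch.Size([2])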

Changes proposed

The purpose of this PR is to showcase a prototype solution that uses metaclasses to express delayed instantiation. It works by intercepting .__call__ when a class is instantiated: if only some of the arguments required by .__init__ are given, a lazy wrapper around the class and its bound arguments is returned; if all arguments are given, the actual object is initialized. Additional arguments can be bound to the lazy wrapper, and it only becomes non-lazy when all the arguments are filled (or have defaults). A sketch of the mechanism follows below.
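As a rough illustration of the mechanism, here is a minimal, self-contained sketch; LazyMeta, _Lazy, and the toy Bijector are hypothetical names for this example, not the classes in this PR:

import inspect

class LazyMeta(type):
    """Sketch of a metaclass that defers construction until all
    __init__ arguments are available."""

    def __call__(cls, *args, **kwargs):
        try:
            # Bind against __init__ (None stands in for self); raises
            # TypeError while required arguments are still missing
            inspect.signature(cls.__init__).bind(None, *args, **kwargs)
        except TypeError:
            return _Lazy(cls, args, kwargs)
        return super().__call__(*args, **kwargs)

class _Lazy:
    """Wrapper holding a class and its partially bound arguments."""

    def __init__(self, cls, args, kwargs):
        self._cls, self._args, self._kwargs = cls, args, kwargs

    def __call__(self, *args, **kwargs):
        # Bind additional arguments; re-entering LazyMeta.__call__ either
        # returns another _Lazy or the fully initialized object
        return self._cls(*self._args, *args, **{**self._kwargs, **kwargs})

class Bijector(metaclass=LazyMeta):
    def __init__(self, params, shape):
        self.params, self.shape = params, shape

b = Bijector(params="dense")  # shape still missing -> _Lazy wrapper
assert isinstance(b, _Lazy)
b = b(shape=(4,))             # all arguments now bound -> real Bijector
assert isinstance(b, Bijector)

Under a scheme like this, bij.AffineAutoregressive(...) in the example below stays lazy until the flow supplies the remaining shape information.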

@stefanwebb stefanwebb added enhancement New feature or request refactor Refactoring of code labels Sep 6, 2021
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 6, 2021
@stefanwebb stefanwebb changed the title Lazy init Lazy parameters and bijectors with metaclasses Sep 6, 2021
@stefanwebb
Contributor Author

I have confirmed that you can run the (modified) example on the front page of flowtorch.ai:

import os

# Work around "duplicate OpenMP runtime" errors on some setups (e.g. conda on macOS)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import torch
import flowtorch.bijectors as bij
import flowtorch.distributions as dist
import flowtorch.parameters as params
import matplotlib.pyplot as plt

# Lazily instantiated flow plus base and target distributions
bijectors = bij.AffineAutoregressive(params=params.DenseAutoregressive(hidden_dims=(32,)))
base_dist = torch.distributions.Independent(torch.distributions.Normal(torch.zeros(2), torch.ones(2)), 1)
target_dist = torch.distributions.Normal(torch.zeros(2)+5, torch.ones(2)*0.5)

# Instantiate transformed distribution and parameters
flow = dist.Flow(base_dist, bijectors)

y_initial = flow.sample(torch.Size([300,])).detach().numpy()
y_target = target_dist.sample(torch.Size([300,])).detach().numpy()

# Training loop
opt = torch.optim.Adam(flow.parameters(), lr=5e-3)
frame = 0
for idx in range(3001):
    opt.zero_grad()

    # Minimize KL(p || q)
    y = target_dist.sample((1000,))
    loss = -flow.log_prob(y).mean()

    if idx % 500 == 0:
        print('epoch', idx, 'loss', loss.item())

        # Save SVG
        y_learnt = flow.sample(torch.Size([300,])).detach().numpy()

        plt.figure(figsize=(5, 5), dpi=100)
        plt.plot(y_target[:,0], y_target[:,1], 'o', color='blue', alpha=0.95, label='target')
        plt.plot(y_initial[:,0], y_initial[:,1], 'o', color='grey', alpha=0.95, label='initial')
        plt.plot(y_learnt[:,0], y_learnt[:,1], 'o', color='red', alpha=0.95, label='learnt')
        plt.xlim((-4,8))
        plt.ylim((-4,8))
        plt.xlabel('$x_1$')
        plt.ylabel('$x_2$')
        plt.legend(loc='lower right', facecolor=(1, 1, 1, 1.0))
        plt.savefig(f'bivariate-normal-frame-{frame}.svg', bbox_inches='tight', transparent=True)

        frame += 1
        
    loss.backward()
    opt.step()
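A note on the objective above: drawing y from the target and minimizing the negative flow log-likelihood minimizes the forward KL divergence up to an additive constant,

$$\mathrm{KL}(p \,\|\, q_\theta) = \mathbb{E}_{y \sim p}[\log p(y)] - \mathbb{E}_{y \sim p}[\log q_\theta(y)],$$

where the first term does not depend on the flow parameters $\theta$, so it suffices to maximize $\mathbb{E}_{y \sim p}[\log q_\theta(y)] \approx \frac{1}{N}\sum_{i=1}^{N} \log q_\theta(y_i)$, which is exactly the negative of the loss in the loop.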

@feynmanliang
Contributor

Nice! Would you like to try importing this into Phabricator yourself?

@facebook-github-bot
Contributor

@stefanwebb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Summary: duplicates the pull request description above.

Pull Request resolved: #59

Reviewed By: jpchen

Differential Revision: D30782184

Pulled By: stefanwebb

fbshipit-source-id: c53649d26e25478565681063361ad4e5a32110be
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D30782184

facebook-github-bot pushed a commit that referenced this pull request Sep 15, 2021
@stefanwebb stefanwebb deleted the lazy_init branch November 23, 2021 03:43