2d notebook refactor grid and point flow plotting
janosh committed Sep 22, 2020
1 parent b451804 commit e570530
Showing 1 changed file with 97 additions and 101 deletions: torch_mnf/notebooks/2d.py
@@ -1,6 +1,4 @@
# %%
import matplotlib.pyplot as plt
import numpy as np
import torch
@@ -24,50 +22,56 @@
base = MultivariateNormal(torch.zeros(2), torch.eye(2))

# %% -- RNVP --
flows = [nf.AffineHalfFlow(dim=2, parity=i % 2) for i in range(9)]

# # %% -- NICE --
# flows = [nf.AffineHalfFlow(dim=2, parity=i % 2, scale=False) for i in range(4)]
# flows.append(nf.AffineConstantFlow(dim=2, shift=False))

# # %% -- MAF --
# # for fast density estimation
# flows = [nf.MAF(dim=2, parity=i % 2) for i in range(9)]

# # %% -- IAF --
# # for fast sampling
# flows = [nf.IAF(dim=2, parity=i % 2) for i in range(9)]

# # %% -- ActNorm --
# # prepend an ActNormFlow to every layer in any of the flows above
# for idx in reversed(range(len(flows))):
#     flows.insert(idx, nf.ActNormFlow(dim=2))

# # %% -- Glow --
# flows = [nf.Glow(dim=2) for _ in range(3)]
# # prepend each Glow (1x1 convolution) with an ActNormFlow and append an AffineHalfFlow
# for idx in reversed(range(len(flows))):
#     flows.insert(idx, nf.ActNormFlow(dim=2))
#     flows.insert(idx + 2, nf.AffineHalfFlow(dim=2, parity=idx % 2, nh=32))

# %% -- Neural Spline Flow --
flows = [nf.NSF_CL(dim=2, K=8, B=3, hidden_dim=16) for _ in range(3)]
# prepend each NSF flow with an ActNormFlow and a Glow
for idx in reversed(range(len(flows))):
    flows.insert(idx, nf.ActNormFlow(dim=2))
    flows.insert(idx + 1, nf.Glow(dim=2))


# %%
# Construct the flow model.
model = nf.NormalizingFlowModel(base, flows)
# TODO: tune weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
print(f"number of params: {sum(p.numel() for p in model.parameters()):,}")


def train(steps=1000, n_samples=128, report_every=100, cb=None):
    model.train()
    for step in range(steps + 1):
        x = sample_target_dist(n_samples)

@@ -91,85 +95,91 @@ def train(steps=1000, n_samples=128, report_every=100, cb=None):


# %%
model.eval()

# draw samples from the target dist and flow them through the model to the base dist
target_samples = sample_target_dist(128)
zs, *_ = model.inverse(target_samples)
target_samples = target_samples.detach().numpy()
z_last = zs[-1].detach().numpy()

# draw samples from the model's base dist to compare how well the model maps real data
# to latent space
base_samples = model.base.sample([128, 2])
_, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))
ax1.scatter(*base_samples.T, c="y", s=5)
ax1.scatter(*z_last.T, c="r", s=5)
ax1.scatter(*target_samples.T, c="b", s=5)
ax1.legend(["base", "x->z", "data"])
ax1.axis("scaled")
ax1.set(title="x -> z")

# draw samples from the model's output dist and compare with real data
xs = model.sample(128 * 4)
x_last = xs[-1].detach().numpy()
ax2.scatter(*target_samples.T, c="b", s=5, alpha=0.5)
ax2.scatter(*x_last.T, c="r", s=5, alpha=0.5)
ax2.legend(["data", "z->x"])
ax2.axis("scaled")
ax2.set(title="z -> x")
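
# %%
# Editor's note (hedged sketch; the exact torch_mnf API for the log-det terms is
# an assumption, not verified against the library): normalizing flows assign
# densities via the change-of-variables formula
#     log p(x) = log p_base(z) + log |det dz/dx|.
# If model.inverse also returns that accumulated log-determinant, the data
# likelihood would be computed roughly as:
# zs, log_det = model.inverse(x)
# log_prob = model.base.log_prob(zs[-1]) + log_det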


# %%
def plot_grid_warp(ax, z, target_samples, n_lines, title=None):
    """Plot how the flow warps a square grid of points."""
    grid = z.reshape((n_lines, n_lines, 2))
    # pair each point with its neighbor in the next row (one set of grid lines)
    p1 = np.reshape(grid[1:, :, :], (n_lines ** 2 - n_lines, 2))
    p2 = np.reshape(grid[:-1, :, :], (n_lines ** 2 - n_lines, 2))
    lcy = LineCollection(zip(p1, p2), alpha=0.3)
    # pair each point with its neighbor in the next column (the other set)
    p1 = np.reshape(grid[:, 1:, :], (n_lines ** 2 - n_lines, 2))
    p2 = np.reshape(grid[:, :-1, :], (n_lines ** 2 - n_lines, 2))
    lcx = LineCollection(zip(p1, p2), alpha=0.3)
    # draw the lines
    ax.add_collection(lcx)
    ax.add_collection(lcy)
    ax.axis([-3, 3, -3, 3])
    if title:
        ax.set_title(title)

    # draw the data too
    ax.scatter(*target_samples.T, c="r", s=5)
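
# %%
# Editor's sketch (illustrative, not in the original): how the segment pairing in
# plot_grid_warp works, on a tiny 3x3 integer grid. Pairing grid[1:] with
# grid[:-1] links every point to its neighbor one row over; pairing grid[:, 1:]
# with grid[:, :-1] links column neighbors.
demo_grid = np.stack(np.meshgrid(np.arange(3), np.arange(3)), axis=-1)
starts = demo_grid[1:].reshape(-1, 2)
ends = demo_grid[:-1].reshape(-1, 2)
print(starts[0], "->", ends[0])  # prints: [0 1] -> [0 0]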


def plot_point_flow(ax, z0, z1):
    """Plot how the samples travel from one flow layer to the next."""
    ax.scatter(*z0.T, c="r", s=3, alpha=0.3)
    ax.scatter(*z1.T, c="b", s=3)


# %%
# plot the coordinate warp
n_grid = 20 # number of grid points
ticks = np.linspace(-3, 3, n_grid)
latent_grid = np.stack(np.meshgrid(ticks, ticks), axis=-1)
# mask for points inside a circle of radius 3, which seems appropriate since we
# use radial distributions as base distributions
in_circle = np.sqrt((latent_grid ** 2).sum(axis=-1)) <= 3
latent_grid = latent_grid.reshape((n_grid * n_grid, 2))
latent_grid = torch.from_numpy(latent_grid.astype("float32"))

xs, *_ = model.forward(latent_grid)
xs = [z.detach().numpy() for z in xs]

for idx, [z0, z1] in enumerate(zip(xs, xs[1:])):
    _, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))

    # plot how the samples travel at this stage
    plot_point_flow(ax1, z0, z1)
    title = f"layer {idx} -> {idx + 1} ({model.flows[idx].__class__.__name__})"
    ax1.set(xlim=[-3, 3], ylim=[-3, 3], title=title)

    grid_title = f"grid warp after layer {idx + 1}"
    plot_grid_warp(ax2, z1, target_samples, n_grid, title=grid_title)


# %%
# Callback to render progress while training. Do this with an untrained model to see
# significant changes.
def plot_learning():
    xs, _ = model.forward(latent_grid)
    xs = [z.detach().numpy() for z in xs]

    # create a square grid of subplots, one for each step in the flow (as many
    # as the largest square that can be filled completely)
@@ -178,29 +188,15 @@ def plot_learning():
    fig, axes = plt.subplots(plot_grid_height, 2 * plot_grid_height, figsize=(20, 10))
    fig.subplots_adjust(wspace=0.05, hspace=0.05)

    for z0, z1, ax in zip(xs, xs[1:], axes[:, :plot_grid_height].flat):
        plot_point_flow(ax, z0, z1)
        ax.set(xlim=[-4, 4], ylim=[-4, 4], xticks=[], yticks=[])

    big_ax = fig.add_subplot(122)
    plot_grid_warp(big_ax, xs[-1], target_samples, n_grid)
    big_ax.set(xlim=[-2, 3], ylim=[-1.5, 2], xticks=[], yticks=[])

    # hide unused axes below the big one
    for ax in axes[:, plot_grid_height:].flat:
        ax.axis("off")

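
# %%
# Illustrative usage (an editor's assumption -- the elided train() body
# presumably calls cb() with no arguments every report_every steps): re-run the
# model construction cell to reset the weights, then train with the callback to
# watch the flow learn:
# train(steps=500, report_every=100, cb=plot_learning)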
