Commit

2d.py add model checkpointing
janosh committed Sep 25, 2020
1 parent 20ca441 commit cf266ec
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions torch_mnf/notebooks/2d.py
@@ -4,11 +4,15 @@
 import torch
 from matplotlib.collections import LineCollection
 from torch.distributions import MultivariateNormal
+from tqdm import tqdm

 import torch_mnf.flows as nf
 from torch_mnf import data
+from torch_mnf.data import ROOT

 # %%
+torch.manual_seed(0)  # ensure reproducible results
+
 sample_target_dist = data.sample_moons
 # sample_target_dist = data.sample_siggraph
 # sample_target_dist = data.sample_gaussian_mixture
@@ -66,13 +70,19 @@
 # %%
 # Construct the flow model.
 model = nf.NormalizingFlowModel(base, flows)
+model.step = 0  # running count of training steps, saved in checkpoints below
 # TODO: tune WD
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
 print(f"number of params: {sum(p.numel() for p in model.parameters()):,}")
+SAVE_TO = f"{ROOT}/results/maf/"


 # %%
 def train_flow(steps=1000, n_samples=128, report_every=100, cb=None):
-    for step in range(steps + 1):
+    losses = []
+    model.step += steps  # record upfront how many steps this run adds
+    pbar = tqdm(range(steps))
+    for step in pbar:
         x = sample_target_dist(n_samples)
+
         _, log_det = model.inverse(x)
@@ -85,10 +95,14 @@ def train_flow(steps=1000, n_samples=128, report_every=100, cb=None):
         optimizer.step()  # update weights

         if step % report_every == 0:
-            print(f"loss at step {step}: {loss:.4g}")
+            losses.append([step, loss.item()])  # .item() detaches the scalar from the graph
+            pbar.set_postfix(loss=f"{loss:.4g}")
             if callable(cb):
                 cb()

+    plt.plot(*zip(*losses))  # plot the recorded loss curve
+    return losses
+

 # %%
 train_flow()
@@ -108,18 +122,19 @@ def train_flow(steps=1000, n_samples=128, report_every=100, cb=None):
 ax1.scatter(*base_samples.T, c="y", s=5)
 ax1.scatter(*z_last.T, c="r", s=5)
 ax1.scatter(*target_samples.T, c="b", s=5)
-ax1.legend(["base", "x->z", "data"])
+ax1.legend(["base", r"x $\to$ z", "data"])
 ax1.axis("scaled")
-ax1.set(title="x -> z")
+ax1.set(title=r"x $\to$ z")

 # draw samples from the model's output dist and compare with real data
 xs = model.sample(128 * 4)
 x_last = xs[-1].detach().numpy()
 ax2.scatter(*target_samples.T, c="b", s=5, alpha=0.5)
 ax2.scatter(*x_last.T, c="r", s=5, alpha=0.5)
-ax2.legend(["data", "z->x"])
+ax2.legend(["data", r"z $\to$ x"])
 ax2.axis("scaled")
-ax1.set(title="z -> x")
+ax2.set(title=r"z $\to$ x")  # ax2 (not ax1) shows the sampling direction
+# plt.savefig(SAVE_TO + "z2x+x2z.pdf", bbox_inches="tight")


 # %%
@@ -168,7 +183,7 @@ def plot_point_flow(ax, z0, z1):
 _, [ax1, ax2] = plt.subplots(1, 2, figsize=(10, 5))

 plot_point_flow(ax1, z0, z1)
-title = f"layer {idx} ->{idx+1} ({model.flows[idx].__class__.__name__})"
+title = f"layer {idx} $\\to$ {idx+1} ({model.flows[idx].__class__.__name__})"
 ax1.set(xlim=[-3, 3], ylim=[-3, 3], title=title)

 plot_grid_warp(ax2, z1, target_samples, n_grid)
@@ -202,4 +217,17 @@ def plot_learning():


 # %%
-train_flow(steps=400, cb=plot_learning)
+losses = train_flow(steps=400, cb=plot_learning)
+# plt.savefig(SAVE_TO + "point-flow.pdf", bbox_inches="tight")
+
+
+# %%
+# Save model state for later restoring with `checkpoint = torch.load(PATH)`.
+# See https://pytorch.org/tutorials/beginner/saving_loading_models#save.
+check_pt = {
+    "step": model.step,
+    "model_state_dict": model.state_dict(),
+    "optimizer_state_dict": optimizer.state_dict(),
+    "loss": losses[-1][1],
+}
+torch.save(check_pt, SAVE_TO + "checkpoint.pt")

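For reference, restoring the checkpoint saved above in a later session might look like the sketch below. This is not part of the commit; it assumes `model` and `optimizer` have been rebuilt exactly as in the cells above, following the PyTorch tutorial linked in the diff:

checkpoint = torch.load(SAVE_TO + "checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.step = checkpoint["step"]  # resume the training-step counter
model.train()  # or model.eval() when only sampling from the restored flow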