Skip to content

Commit

Permalink
mnf_mnist remove loss_accum, just use momentary loss
Browse files (browse the repository at this point in the history)
  • Loading branch information
janosh committed Sep 25, 2020
1 parent fc4d951 commit 0dc1a5f
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions torch_mnf/notebooks/mnf_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,9 @@
def loss_fn(labels, preds):
nll = F.nll_loss(preds, labels).mean()

# The weighting factor dividing the KL divergence can be used as a hyperparameter.
# Decreasing it makes learning more difficult but prevents model overconfidence. If
# not seen as hyperparameter, it should be applied once per epoch, i.e. divided by
# the total number of samples in an epoch (batch_size * steps_per_epoch)
# The KL divergence acts as a regularizer to prevent overfitting.
batch_size = labels.size(0)
kl_div = mnf_lenet.kl_div() / (2 * batch_size)
kl_div = mnf_lenet.kl_div() / batch_size
loss = nll + kl_div

writer.add_scalar("NLL", nll, mnf_lenet.step)
Expand Down Expand Up @@ -91,39 +88,35 @@ def train_step(images, labels):
y_val = test_set.targets[:100]


# %%
grid = tv.utils.make_grid(X_val)
writer.add_image("images", grid, 0)
writer.add_graph(mnf_lenet, X_val) # add model graph to TensorBoard summary
writer.close()


# %%
epochs = 2
log_every = 50

for epoch in range(epochs):
pbar = tqdm(train_loader, desc=f"epoch {epoch + 1}/{epochs}")
loss_accum = 0
for samples, labels in pbar:
mnf_lenet.train()
loss_accum += train_step(samples, labels)
mnf_lenet.step += 1

loss = train_step(samples, labels)

if mnf_lenet.step % log_every == 0:

# Accuracy estimated by single call for speed. Would be more accurate to
# approximately integrate over the parameter posteriors by averaging across
# multiple calls.
mnf_lenet.eval()
val_preds = mnf_lenet(X_val)
val_acc = (y_val == val_preds.argmax(1)).float().mean()
pbar.set_postfix(
loss=f"{loss_accum/log_every:.4g}", val_acc=f"{val_acc:.4g}"
)
loss_accum = 0
pbar.set_postfix(loss=f"{loss:.4g}", val_acc=f"{val_acc:.4g}")

writer.add_scalar("validation accuracy", val_acc, mnf_lenet.step)


# %%
images, labels = next(iter(train_loader))

grid = tv.utils.make_grid(images)
writer.add_image("images", grid, 0)
writer.add_graph(mnf_lenet, images)
writer.close()
mnf_lenet.step += 1


# %%
Expand All @@ -139,4 +132,4 @@ def train_step(images, labels):


# %%
rot_img(lambda x: lenet(torch.tensor(x)), img9, plot_type="bar")
rot_img(lambda x: lenet(torch.tensor(x)).exp(), img9, plot_type="bar")

0 comments on commit 0dc1a5f

Please sign in to comment.