-
Notifications
You must be signed in to change notification settings - Fork 6
/
test.py
65 lines (54 loc) · 1.98 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch.nn import functional as F
from torchvision.utils import save_image
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from maf import MAF
from utils_maf import (
val_maf,
test_maf,
sample_digits_maf,
)
from data.loader import data_loader
# Run configuration. ``string`` names the saved model file and is reused as
# the filename stem of the figures written later in this script.
string = "maf_mnist_512"
dataset = "mnist"
batch_size = 128

# Load the saved model object and build the data loaders for ``dataset``.
model = torch.load(f"model_saves/{string}.pt")
(
    train,
    train_loader,
    val_loader,
    test_loader,
    n_in,
) = data_loader(dataset, batch_size)

# Evaluate on the test split first, then on the validation split.
test_maf(model, train, test_loader)
val_maf(model, train, val_loader)
if dataset == "mnist":
    # Diagnostics for the MNIST model: scatter plots of random pairs of
    # output dimensions, marginal plots of random single dimensions, and
    # finally a call to ``sample_digits_maf``.

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs("figs", exist_ok=True)

    # Rebuild the loaders with batch_size=1000 and take the first test batch.
    _, _, _, test_loader, _ = data_loader(dataset, batch_size=1000)
    model.eval()
    batch = next(iter(test_loader))

    # Inference only: no autograd graph is needed for the forward pass.
    # ``u`` is the first element of the model's output (presumably the
    # base-space variables of the flow -- TODO confirm against MAF.forward).
    # .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
    with torch.no_grad():
        u = model(batch)[0].detach().cpu().numpy()

    # Number of dimensions to draw indices from; 28 * 28 matches a flattened
    # MNIST image (assumed to equal the width of ``u`` -- TODO confirm).
    n_dims = 28 * 28

    # --- Pairwise scatter plots: 24 random (dim1, dim2) pairs. ---
    fig, axes = plt.subplots(
        ncols=6, nrows=4, sharex=True, sharey=True, figsize=(16, 10)
    )
    for ax in axes.reshape(-1):
        dim1 = np.random.randint(n_dims)
        dim2 = np.random.randint(n_dims)
        ax.scatter(u[:, dim1], u[:, dim2], color="dodgerblue", s=0.5)
        ax.set_ylabel("dim: " + str(dim2), size=14)
        ax.set_xlabel("dim: " + str(dim1), size=14)
        ax.set_xlim(-8, 8)
        ax.set_ylim(-8, 8)
        ax.set_aspect(1)
    # Save once, after every panel is drawn, in both raster and vector form.
    for ext in ("png", "pdf"):
        fig.savefig(
            "figs/" + string + "_scatter." + ext, bbox_inches="tight", dpi=300
        )
    plt.close(fig)  # release the figure instead of leaving it open

    # --- Marginal plots: 24 random single dimensions. ---
    fig, axes = plt.subplots(
        ncols=6, nrows=4, sharex=True, sharey=True, figsize=(16, 10)
    )
    for ax in axes.reshape(-1):
        dim1 = np.random.randint(n_dims)
        # NOTE(review): sns.distplot is deprecated in recent seaborn releases
        # (histplot/displot replace it). Kept so the script still runs on the
        # older seaborn it appears to target -- confirm before upgrading.
        sns.distplot(u[:, dim1], ax=ax, color="darkorange")
        ax.set_xlabel("dim: " + str(dim1), size=14)
        ax.set_xlim(-5, 5)
    for ext in ("png", "pdf"):
        fig.savefig(
            "figs/" + string + "_marginal." + ext, bbox_inches="tight", dpi=300
        )
    plt.close(fig)

    sample_digits_maf(model, "test")