-
Notifications
You must be signed in to change notification settings - Fork 1
/
forward.py
77 lines (64 loc) · 2.16 KB
/
forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from utils import (
    GDensity,
    GMixDensity,
    Uniform,
    EmpiricalDensity
)
# Target density: a mixture of two zero-mean 2-D Gaussians, one with
# off-diagonal entry +0.9 (g1) and one with -0.9 (g2). Presumably the second
# GDensity argument is the covariance matrix -- confirm against utils.
g1 = GDensity([0., 0.], [[1., 0.9], [0.9, 1.]])
g2 = GDensity([0., 0.], [[1., -0.9], [-0.9, 1.]])
G = GMixDensity([g2, g1])  # its score drives stoch_proc_1

# Reference densities: an identity-covariance Gaussian and a Uniform density
# (presumably over the square [-2, 2] x [-2, 2]; verify in utils).
N = GDensity([0., 0.], [[1., 0], [0., 1.]])
U = Uniform([-2, 2], [-2, 2])

# Empirical particle clouds of 1000 samples each, drawn from N first and then
# from U (the order matters for the shared NumPy RNG stream).
n, u = (EmpiricalDensity(base.sample(N=1000)) for base in (N, U))
def stoch_proc_1(x, delta=0.01):
    """Move samples ``x`` one noisy step along the score of the mixture ``G``.

    Computes ``x + delta * G.score(x) + sqrt(2 * delta) * xi``, where ``xi``
    has i.i.d. standard-normal entries and the same shape as ``x``. If
    ``G.score`` returns the gradient of log G (TODO: confirm in utils), this is
    one Euler-Maruyama step of overdamped Langevin dynamics targeting G.

    Args:
        x: NumPy array of sample positions (presumably shape (num_samples, 2);
            confirm against EmpiricalDensity.nudge).
        delta: Step size of the update.

    Returns:
        A newly allocated array of updated positions.
    """
    drift = G.score(x) * delta
    diffusion = np.sqrt(2 * delta) * np.random.randn(*x.shape)
    return x + drift + diffusion
ITER = 50  # number of simulation iterations run by the main loop
fig, ax = plt.subplots(1, 2, figsize=(11, 4))
XRANGE = [-3, 3]  # x-limits of the left (density) panel
YRANGE = [-3, 3]  # y-limits of the left (density) panel

# 10 x 10 grid over [-1, 3] x [-3, 1] on which the right-panel vector field
# is drawn.
x_grid, y_grid = np.meshgrid(np.linspace(-1, 3, 10), np.linspace(-3, 1, 10))


def _safe_radius(x, y):
    """Return sqrt(x**2 + y**2), with 1.0 substituted where the radius is 0.

    Dividing a coordinate by this yields the unit-vector component away from
    the origin and a plain 0 at the origin, instead of the nan (with a
    RuntimeWarning) that 0 / 0 would produce.
    """
    r = np.sqrt(x ** 2 + y ** 2)
    return np.where(r > 0, r, 1.0)


def dfdx(x, y):
    """Return d/dx of f(x, y) = -sqrt(x**2 + y**2), defined as 0 at the origin.

    This is the x-component of the unit vector pointing from (x, y) toward the
    origin. It accepts scalars or arrays.
    """
    return -x / _safe_radius(x, y)


def dfdy(x, y):
    """Return d/dy of f(x, y) = -sqrt(x**2 + y**2), defined as 0 at the origin.

    This is the y-component of the unit vector pointing from (x, y) toward the
    origin. It accepts scalars or arrays.
    """
    return -y / _safe_radius(x, y)


x_vf = dfdx(x_grid, y_grid)
y_vf = dfdy(x_grid, y_grid)

# Markers that walk along the field in fixed-length steps. They can reach the
# origin (q moves along the x-axis from x = 2 in exact 0.05 steps), which is
# why dfdx/dfdy must be well defined there.
p_x, p_y = 1, -2
q_x, q_y = 2, 0
# savefig does not create missing directories; make sure figs/ exists.
os.makedirs('figs', exist_ok=True)

# Render a frame every ITER // 50 iterations. max(1, ...) keeps the stride
# positive: for ITER < 50, ITER // 50 == 0 and `i % 0` would raise
# ZeroDivisionError.
FRAME_EVERY = max(1, ITER // 50)

for i in tqdm(range(ITER)):
    if i % FRAME_EVERY == 0:
        # Left panel: current density of the particle cloud n.
        # Right panel: the p-space vector field with the moving marker p.
        ax[0].cla()
        ax[1].cla()
        n.plot_density(ax=ax[0], cmap='Reds')
        ax[1].quiver(x_grid, y_grid, x_vf, y_vf)
        ax[1].scatter([p_x, ], [p_y, ], color='red')
        # q is still advanced below but currently not drawn:
        # ax[1].scatter([q_x, ], [q_y, ], color='green')
        ax[1].scatter([1, ], [-2, ], color='blue')
        ax[1].text(0.3, 0.4, r'$q_{data}$', ha='center', va='center', color='red', fontsize=20)
        ax[1].text(1.3, -2.4, r'$\mathcal{N}(0, I)$', ha='center', va='center', color='blue', fontsize=20)
        # Step both markers 0.05 along the field; the right-hand sides use the
        # pre-update coordinates because the tuple is built before assignment.
        p_x, p_y = p_x + 5.e-2 * dfdx(p_x, p_y), p_y + 5.e-2 * dfdy(p_x, p_y)
        q_x, q_y = q_x + 5.e-2 * dfdx(q_x, q_y), q_y + 5.e-2 * dfdy(q_x, q_y)
        ax[0].set_xlim(XRANGE); ax[0].set_ylim(YRANGE)
        ax[1].set_xlim([-1, 3]); ax[1].set_ylim([-3, 1])
        ax[0].set_title(r'$q_t\ |\ q_0 = q_{data}$', fontsize=20)
        ax[1].set_title(r'$p$ space', fontsize=20)
        ax[0].axis('off')
        ax[1].axis('off')
        # Frames are numbered in reverse: iteration 0 writes test_{ITER}.png.
        fig.savefig(f'figs/test_{ITER - i}.png', bbox_inches='tight', pad_inches=0)
    # Evolve both particle clouds on every iteration, not only on rendered
    # frames. Presumably nudge applies stoch_proc_1 to the stored samples and
    # forwards delta to it; confirm in utils.EmpiricalDensity.
    n.nudge(stoch_proc_1, delta=5.e-3)
    u.nudge(stoch_proc_1, delta=5.e-3)