In [1]:
import sys 
sys.path.append('..')
from utils import wasserstein2
from kernels import *
from distributions import *
from kernelGAN import *
import torch, matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import HTML

device = torch.device('cpu')
%matplotlib inline
plt.rcParams.update({'font.size': 20})
%load_ext autoreload
%autoreload 2

In [30]:
# distribution params
n_true, n_gen, d = 2, 4, 2   # number of true points, number of generated points, dimension
n_mixture, sigma = 2, 1      # n_mixture is not used in this notebook; sigma scales the initial jitter

# Fix the RNG so the initial generated points (drawn with torch.randn below), and hence the
# trained trajectories and the animation, are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# initializing points
X_true = torch.tensor([[4, 0], [-4, 0]], dtype=torch.float32)
assert X_true.shape == (n_true, d), "X_true must have n_true rows of dimension d"
# n_gen-1 generated points start around X_true[0] (std sigma/2) and a single one starts
# close to X_true[1] (std sigma/10), so the initial generated mass is split unevenly.
X_gen_init = torch.vstack([
    X_true[0] + torch.randn(n_gen - 1, d) * sigma / 2,
    X_true[1] + torch.randn(1, d) * sigma / 10])
# Uniform weights: softmax of a constant vector.
p_gen = torch.ones(n_gen).softmax(-1)
p_true = torch.ones(n_true).softmax(-1)
G = PointGenerator(X_gen_init, p_gen)

# initializing discriminator kernel
depth = 5
DK = Kernel('nngp', depth=depth)

In [31]:
# training params -- forwarded to kernelGAN; the meanings noted here are inferred from the
# names and should be confirmed against kernelGAN.
T = 1000            # number of training iterations requested from model.train
lr_d = 1e-1         # presumably the discriminator step size
lr_g = 1e-1         # presumably the generator step size
lam = 1e-2          # given to both KernelDiscriminator and model.train -- TODO confirm role
e_threshold = 1e-3  # tolerance passed to model.train -- TODO confirm semantics

# find generated trajectories; the animation below reads them from model.G.X_gen
D = KernelDiscriminator(DK, d, lam, lr_d)
model = KernelGAN(G, D, device)
model.train(X_true, p_true, lr_d, lr_g, T, lam, e_threshold=e_threshold)

100%|██████████| 999/999 [00:01<00:00, 567.35it/s]


## Animation of the generated-point trajectories

In [32]:
# animate the generated point trajectories, overlaid on the field from model.get_grad_field
width = 5          # radius of the shaded disc drawn around each true point
lim = 8            # the plot window is [-lim, lim] x [-lim, lim]
plt_interval = 25  # draw one frame every `plt_interval` recorded steps

figure, axes = plt.subplots(figsize=(8, 5))
axes.set_aspect(1)
axes.set_xlim([-lim, lim])
axes.set_ylim([-lim, lim])
camera = Camera(figure)

# Bound the frames by the recorded history as well as by T, in case fewer than T states were
# stored (e.g. if training stopped early). Assumes model.G.X_gen is indexed
# [step, point, coordinate], as the scatter call below does.
n_steps = min(T, len(model.G.X_gen))

for i, sample_idx in enumerate(range(0, n_steps, plt_interval)):
    # Each frame draws its own discs, field and points, then camera.snap() records them.
    for center in X_true:
        axes.add_artist(plt.Circle(center.tolist(), width, alpha=0.1))
    # Clamp at 0: on the first frame sample_idx - 1 == -1, which with Python-style indexing
    # would select the last recorded step rather than the first.
    # NOTE(review): confirm get_grad_field indexes its history like a Python sequence.
    x0, x1, u, v = model.get_grad_field(max(sample_idx - 1, 0), xlim=[-lim, lim], ylim=[-lim, lim], nplt=20)
    axes.quiver(x0, x1, u, v)
    gen_pts = axes.scatter(model.G.X_gen[sample_idx, :, 0], model.G.X_gen[sample_idx, :, 1], c='g', label='gen')
    true_pts = axes.scatter(X_true[:, 0], X_true[:, 1], c='r', label='true', marker='*')
    if i == 0:
        # Build the legend once from the first frame's scatters so it has exactly one
        # 'true' and one 'gen' entry, rather than one pair per frame.
        axes.legend(handles=[true_pts, gen_pts], loc='upper right')
    camera.snap()

# Close the figure so only the video, not a static copy of the figure, is shown inline.
plt.close(figure)
anim = camera.animate()
# To write the animation to disk instead: anim.save('../data/nngp_weird_field.gif', dpi=300)
HTML(anim.to_html5_video())