In [1]:
import sys 
sys.path.append('..')
from utils import wasserstein2
from kernels import *
from distributions import *
from kernelGAN import *
import torch, matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import HTML
from tqdm import tqdm
from scipy.spatial import distance

device = torch.device('cpu')
%matplotlib inline
plt.rcParams.update({'font.size': 20})
%load_ext autoreload
%autoreload 2

### Simple Sanity Check

In [27]:
X = torch.arange(12.0).reshape((4,3))
Z = torch.arange(9.0).reshape((3,3))
M = torch.diag(torch.arange((3.0)))+1
# M = torch.eye(3)
print(euclidean(X,Z))
print(mahalanobis(X, M=M, Z=Z))
print(mahalanobis_old(X, M, Z=Z))
print(distance.mahalanobis(X[3,:], Z[2,:], M)**2)

tensor([[  0.,  27., 108.],
        [ 27.,   0.,  27.],
        [108.,  27.,   0.],
        [243., 108.,  27.]])
tensor([[  0., 108., 432.],
        [108.,   0., 108.],
        [432., 108.,   0.],
        [972., 432., 108.]])
tensor([[  0., 108., 432.],
        [108.,   0., 108.],
        [432., 108.,   0.],
        [972., 432., 108.]])
107.99999116774507


### Timing Computations

In [3]:
X = torch.arange(2000.0).reshape((1000,2))
Z = torch.arange(20.0).reshape((10,2))
M = torch.diag(torch.arange((2.0)))+1
# M = torch.zeros((3,3))
# M = torch.eye(3)
for i in tqdm(range(1000)):
    out = euclidean(X,Z)
for i in tqdm(range(1000)):
    out=mahalanobis(X, M=M, Z=Z)
for i in tqdm(range(1000)):
    out=mahalanobis(X, Z=Z)
for i in tqdm(range(1000)):
    out=mahalanobis_old(X, Z=Z)

100%|██████████| 1000/1000 [00:00<00:00, 5630.35it/s]
100%|██████████| 1000/1000 [00:00<00:00, 5783.95it/s]
100%|██████████| 1000/1000 [00:00<00:00, 5096.60it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1225.77it/s]


In [4]:
for i in tqdm(range(1000)):
    out = laplacian_grad_vp(X,M,Z)
for i in tqdm(range(1000)):
    out=laplacian_grad_vp(X,Z=Z)
for i in tqdm(range(1000)):
    out=old_laplacian_grad_vp(X,Z=Z)

100%|██████████| 1000/1000 [00:00<00:00, 1677.23it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1566.39it/s]
100%|██████████| 1000/1000 [00:00<00:00, 2116.23it/s]


In [51]:
# distribution params
n_true, n_gen, d = 2, 4, 2
n_mixture, sigma = 2, 1

# initializing points
X_true = torch.tensor([[3, 0],[-3, 0]], dtype=torch.float32)
X_gen_init = torch.vstack([
    X_true[0]+torch.randn(n_gen//2, d)*sigma/2,
    X_true[1]+torch.randn(n_gen//2, d)*sigma/10])
p_gen = torch.ones(n_gen)/n_gen
p_true = torch.randn(n_true)/n_true
G = PointGenerator(X_gen_init, p_gen)

# initializing discriminator kernel
width_param=1
# set kernel parameters
DK = DiscriminatorKernel('laplacian', width_param=width_param)

Without RFM updates

In [56]:
# training params
T, lr_d, lr_g, lam = 2000, 1e-1, 1e-2, 1e-1
# find generated trajectories
model = KernelGAN(G, DK, device)
model.train(X_true, p_true, lr_d, lr_g, T, lam, log_interval=200)

100%|██████████| 1999/1999 [00:02<00:00, 711.69it/s]


In [53]:
# animate the generated point trajectories
width=width_param
lim = 6
figure, axes = plt.subplots(figsize=(8,5)) 
camera = Camera(figure)

# plot initial distribution
cc1 = plt.Circle(X_true[0], width, alpha=0.1); cc2 = plt.Circle(X_true[1], width , alpha=0.1) 
# cc3 = plt.Circle(X_true[2], width_param , alpha=0.1); cc4 = plt.Circle(X_true[3], width_param , alpha=0.1) 
axes.set_aspect(1); axes.add_artist(cc1); axes.add_artist(cc2) 
plt.scatter(X_true[:,0],X_true[:,1],c='r', label='true', marker='*')
# plt.scatter(X_true[:,0],X_true[:,1],c='w', edgecolor='b',alpha=.5)
plt.xlim([-lim,lim]); plt.ylim([-lim,lim])
scat=plt.scatter(model.G.X_gen[0,:,0],model.G.X_gen[0,:,1],c='g', label='gen')
plt.legend(loc='upper right')

# updating generated points
plt_interval = 50
for i, sample_idx in enumerate(range(0,T, plt_interval)):
    # print(sample_idx)
    # print(X_gen[sample_idx,:,0])
    cc1 = plt.Circle(X_true[0], width , alpha=0.1); cc2 = plt.Circle(X_true[1], width , alpha=0.1) 
    # cc3 = plt.Circle(X_true[2], width_param , alpha=0.1); cc4 = plt.Circle(X_true[3], width_param , alpha=0.1) 
    axes.set_aspect(1); axes.add_artist(cc1); axes.add_artist(cc2)
    x0, x1, u, v = model.get_grad_field(sample_idx-1, xlim=[-lim,lim], ylim=[-lim,lim], nplt=20)
    plt.quiver(x0,x1,u,v)
    plt.scatter(model.G.X_gen[sample_idx,:,0], model.G.X_gen[sample_idx,:,1], c='g', label='gen')
    plt.scatter(X_true[:,0],X_true[:,1],c='r', label='true', marker='*')
    camera.snap()

plt.close()
anim = camera.animate()
HTML(anim.to_html5_video())

With RFM Updates

In [59]:
# training params
T, lr_d, lr_g, lam = 2000, 1e-1, 1e-2, 1e-1
# find generated trajectories
model = KernelGAN(G, DK, device)
model.train(X_true, p_true, lr_d, lr_g, T, lam, log_interval=200, RFM=True)

100%|██████████| 1999/1999 [00:03<00:00, 644.36it/s]


In [60]:
# animate the generated point trajectories
width=width_param
lim = 6
figure, axes = plt.subplots(figsize=(8,5)) 
camera = Camera(figure)

# plot initial distribution
cc1 = plt.Circle(X_true[0], width, alpha=0.1); cc2 = plt.Circle(X_true[1], width , alpha=0.1) 
# cc3 = plt.Circle(X_true[2], width_param , alpha=0.1); cc4 = plt.Circle(X_true[3], width_param , alpha=0.1) 
axes.set_aspect(1); axes.add_artist(cc1); axes.add_artist(cc2) 
plt.scatter(X_true[:,0],X_true[:,1],c='r', label='true', marker='*')
# plt.scatter(X_true[:,0],X_true[:,1],c='w', edgecolor='b',alpha=.5)
plt.xlim([-lim,lim]); plt.ylim([-lim,lim])
scat=plt.scatter(model.G.X_gen[0,:,0],model.G.X_gen[0,:,1],c='g', label='gen')
plt.legend(loc='upper right')

# updating generated points
plt_interval = 50
for i, sample_idx in enumerate(range(0,T, plt_interval)):
    # print(sample_idx)
    # print(X_gen[sample_idx,:,0])
    cc1 = plt.Circle(X_true[0], width , alpha=0.1); cc2 = plt.Circle(X_true[1], width , alpha=0.1) 
    # cc3 = plt.Circle(X_true[2], width_param , alpha=0.1); cc4 = plt.Circle(X_true[3], width_param , alpha=0.1) 
    axes.set_aspect(1); axes.add_artist(cc1); axes.add_artist(cc2)
    x0, x1, u, v = model.get_grad_field(sample_idx-1, xlim=[-lim,lim], ylim=[-lim,lim], nplt=20)
    plt.quiver(x0,x1,u,v)
    plt.scatter(model.G.X_gen[sample_idx,:,0], model.G.X_gen[sample_idx,:,1], c='g', label='gen')
    plt.scatter(X_true[:,0],X_true[:,1],c='r', label='true', marker='*')
    camera.snap()

plt.close()
anim = camera.animate()
HTML(anim.to_html5_video())

  length = a * (widthu_per_lenu / (self.scale * self.width))
  length = a * (widthu_per_lenu / (self.scale * self.width))
