In [1]:
import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch import distributions

import numpy as np
torch.manual_seed(0)
np.random.seed(0)

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
from pylab import rcParams
rcParams['figure.figsize'] = 10, 8
rcParams['figure.dpi'] = 300

from datasets import make_circles_ssl, make_moons_ssl, make_github_cat, make_npz
from distributions import SSLGaussMixture

from itertools import chain
from invertible.toy_flow import iToy
%load_ext autoreload
%autoreload 2

ModuleNotFoundError: No module named 'datasets'

In [None]:
def grid_image(mapping, xx, yy, extradim=False, extra_noise=0):
    """Push a 2-D coordinate grid through ``mapping`` and return its image.

    Args:
        mapping: Callable that takes a float32 ``torch.Tensor`` of points and
            returns a tensor whose last axis holds at least two coordinates.
        xx: Array of x coordinates (for example one output of ``np.meshgrid``).
        yy: Array of y coordinates with the same number of elements as ``xx``.
        extradim: If true, a singleton axis is inserted so ``mapping`` receives
            input of shape ``(n_points, 1, n_features)``; the result is then
            read back from index 0 of that axis.
        extra_noise: Number of additional feature columns, drawn with
            ``np.random.rand``, appended after the two grid coordinates.
            ``0`` (the default) appends nothing.

    Returns:
        Tuple ``(img_xx, img_yy)``: the first two output coordinates as NumPy
        arrays reshaped to ``xx.shape`` and ``yy.shape`` respectively.
    """
    # Accept anything array-like so the reshapes below always work.
    xx = np.asarray(xx)
    yy = np.asarray(yy)

    # One row per grid node: [x, y] (+ optional random padding columns).
    points = np.hstack([xx.reshape([-1, 1]), yy.reshape([-1, 1])])
    if extra_noise:
        points = np.hstack([points, np.random.rand(len(points), extra_noise)])
    if extradim:
        points = points[:, None, :]
    points = torch.from_numpy(points).float()

    # Only the values are needed for plotting, so drop the autograd graph.
    image = mapping(points).detach().numpy()

    if extradim:
        img_xx, img_yy = image[:, 0, 0], image[:, 0, 1]
    else:
        img_xx, img_yy = image[:, 0], image[:, 1]
    return img_xx.reshape(xx.shape), img_yy.reshape(yy.shape)

In [None]:
# Dataset selection: exactly one loader is active, the alternatives are kept
# commented out for quick switching.
data, labels = make_github_cat("github.png")
# data, labels = make_moons_ssl()
# data, labels = make_npz("8gauss.npz")

# The training cell flattens a mini-batch of `bs` points into a single row
# (`batch_x.reshape(1, -1)`), so the flow is built with input size 2 * bs.
bs = 50
inner_dim = 2 * bs
flow = iToy(2 * bs, inner_dim)

# Optional Gaussian-mixture prior (disabled):
# r=2.5
# means = torch.tensor([[-r, -r, -r], [r, r, r]])
# prior = SSLGaussMixture(means=means)

# flow.prior_nll = lambda x: -prior.log_prob(x)

In [None]:
# --- Optimisation settings -------------------------------------------------
lr_init = 3e-3
epochs = 8000
batch_size = bs
# Entries equal to -1 in `labels` are counted as one group and the rest as
# the other (presumably unlabeled vs. labeled points -- confirm with the loader).
n_ul = np.sum(labels == -1)
n_l = np.shape(labels)[0] - n_ul
label_weight = 1.  # not referenced anywhere in this cell
print_freq = 500

# Only parameters that require gradients are handed to the optimiser.
trainable_params = [param for param in flow.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=lr_init)

# Iterations after which the learning rate is divided by 10.
lr_drop_steps = (int(epochs * 0.5), int(epochs * 0.8))

for t in range(epochs):

    # Draw a random mini-batch (indices are sampled with replacement).
    batch_idx = np.random.choice(n_l + n_ul, size=batch_size)
    batch_x = torch.from_numpy(data[batch_idx]).float()
    batch_y = torch.from_numpy(labels[batch_idx]).float()  # not used by the loss below

    # The whole mini-batch is flattened into one row before scoring.
    loss = flow.nll(batch_x.reshape(1, -1)).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if t % print_freq == 0:
        print('iter %s:' % t, 'loss = %.3f' % loss)

    if t in lr_drop_steps:
        for group in optimizer.param_groups:
            group["lr"] /= 10

In [None]:
# Four-panel diagnostic figure:
#   (221) the data pushed through the flow,            z = f(X)
#   (222) samples drawn with torch.randn,              z ~ p(z)
#   (223) the raw data,                                X ~ p(X)
#   (224) latent samples pulled back with flow.inverse, X = g(z)
plt.figure(figsize=(12, 10))


grid_points = 50
grid_freq = 5
z_lims = [-4, 4]
x_lims = [-4, 4]
line_z = np.linspace(*z_lims, grid_points)
line_x = np.linspace(*x_lims, grid_points)
xx_z, yy_z = np.meshgrid(line_z, line_z)
xx_x, yy_x = np.meshgrid(line_x, line_x)


# The flow consumes rows of 2*bs values, so the data is regrouped into such
# rows and the output is unpacked back to one 2-D point per row.
# `.float()` mirrors the training and inverse calls: torch.from_numpy keeps the
# NumPy dtype, and those calls always feed float32 to the flow.
# NOTE(review): the reshape assumes data.size is a multiple of 2*bs -- confirm.
inv = flow(torch.from_numpy(data).float().reshape(-1, 2*bs)).detach().reshape(-1, 2).numpy()


plt.subplot(221)
plt.scatter(inv[:, 0], inv[:, 1], c=data[:, 0], cmap=plt.cm.rainbow)
# plt.scatter(inv[labels==0][:, 0], inv[labels==0][:, 1], marker="^", s=100, edgecolor="k")
# plt.scatter(inv[labels==1][:, 0], inv[labels==1][:, 1], marker="^", s=100, edgecolor="k")
# f_xx, f_yy = grid_image(lambda x: flow(x)[:,:2], xx_x, yy_x)
# plt.plot(f_xx[:, ::grid_freq], f_yy[:, ::grid_freq], '-r', alpha=0.35)
# f_xx, f_yy = grid_image(lambda x: flow(x)[:,:2], yy_x, xx_x)
# plt.plot(f_xx[:, ::grid_freq], f_yy[:, ::grid_freq], '-b', alpha=0.35)
plt.title(r'$z = f(X)$')
plt.xlim(z_lims)
plt.ylim(z_lims)

# z = inv + np.random.randn(*inv.shape) * np.array([1, 1, 0]) * 0.5
# z = torch.randn(data.shape[0],inner_dim).data.numpy() * 1.
z = torch.randn(1000, 2).numpy() * 1.
#z = prior.sample([50,]).numpy()
plt.subplot(222)
plt.scatter(z[:, 0], z[:, 1], c=z[:, 0], cmap=plt.cm.rainbow)
plt.plot(xx_z[:, ::grid_freq], yy_z[:, ::grid_freq], '-r', alpha=0.35)
plt.plot(yy_z[:, ::grid_freq], xx_z[:, ::grid_freq], '-b', alpha=0.35)
plt.title(r'$z \sim p(z)$')
plt.xlim(z_lims)
plt.ylim(z_lims)

# Raw data panel (alternative: x = flow.sample(1000).data.numpy()).
x = data
plt.subplot(223)
plt.scatter(x[:, 0], x[:, 1], c=x[:, 0], cmap=plt.cm.rainbow)
# plt.scatter(data[labels==0][:, 0], data[labels==0][:, 1], marker="^", s=100, edgecolor="k")
# plt.scatter(data[labels==1][:, 0], data[labels==1][:, 1], marker="^", s=100, edgecolor="k")
plt.plot(xx_x[:, ::grid_freq], yy_x[:, ::grid_freq], '-r', alpha=0.35)
plt.plot(yy_x[:, ::grid_freq], xx_x[:, ::grid_freq], '-b', alpha=0.35)
plt.title(r'$X \sim p(X)$')
plt.xlim(x_lims)
plt.ylim(x_lims)

plt.subplot(224)
# The 1000 latent points are regrouped into rows of 2*bs values for the
# inverse pass, then unpacked to 2-D points again. `x` is rebound here to the
# generated samples.
x = flow.inverse(torch.from_numpy(z).float().reshape(-1, 2*bs)).detach().reshape(-1, 2).numpy()
# g_xx, g_yy = grid_image(flow.inverse, xx_z, yy_z, extra_noise=inner_dim-2)
# plt.plot(g_xx[:, ::grid_freq], g_yy[:, ::grid_freq], '-r', alpha=0.35)
# g_xx, g_yy = grid_image(flow.inverse, yy_z, xx_z, extra_noise=inner_dim-2)
# plt.plot(g_xx[:, ::grid_freq], g_yy[:, ::grid_freq], '-b', alpha=0.35)
plt.scatter(x[:, 0], x[:, 1], alpha=0.2)# c=z[:, 0], cmap=plt.cm.rainbow)
# Overlay the data points whose label is 0 and 1 as large triangles.
plt.scatter(data[labels==0][:, 0], data[labels==0][:, 1], marker="^", s=100, edgecolor="k")
plt.scatter(data[labels==1][:, 0], data[labels==1][:, 1], marker="^", s=100, edgecolor="k")
plt.title(r'$X = g(z)$')
#plt.xlim(x_lims)
#plt.ylim(x_lims)

In [None]:
# Round-trip check: relative L2 error of inverse(forward(x)) on the first 10 points.
# `.float()` matches the other flow calls (torch.from_numpy keeps the NumPy dtype).
# NOTE(review): the flow is built with input size 2*bs and is fed flattened rows
# of that width in the training and plotting cells, whereas these rows are passed
# unflattened -- confirm that iToy accepts this shape.
roundtrip_input = data[:10]
roundtrip_output = flow.inverse(flow(torch.from_numpy(roundtrip_input).float())).detach().numpy()
np.linalg.norm(roundtrip_output - roundtrip_input) / np.linalg.norm(roundtrip_input)

In [None]:
# Quick look at the raw dataset: column 0 on the x-axis, column 1 on the y-axis.
plt.scatter(data[:,0],data[:,1])

In [None]:
# Forward pass on the first three data points.
# `.float()` matches the other flow calls (torch.from_numpy keeps the NumPy dtype).
# NOTE(review): the flow is built with input size 2*bs, while these rows are
# passed unflattened -- confirm that iToy accepts this shape.
flow(torch.from_numpy(data[:3]).float())

In [None]:
# Log-determinant as reported by the flow object itself.
# NOTE(review): presumably meant to be compared with the finite-difference
# estimate built in the cells below -- confirm what iToy.logdet() returns.
flow.logdet()

In [None]:
# Forward finite differences of the flow at the first three data points:
# j0 is the change along direction [1, 0], j1 along [0, 1], each divided by eps.
eps = 1e-3
directions = np.array([[1,0],[0,1]])
# For each direction the perturbed point is evaluated first and the base point
# second, and the base point is re-evaluated for every direction.
j0, j1 = [
    (flow(torch.from_numpy(data[:3] + eps * direction).float())
     - flow(torch.from_numpy(data[:3]).float())) / eps
    for direction in directions
]

In [None]:
# Display the finite-difference result for the first input direction.
j0

In [None]:
# Assemble the per-point Jacobian estimate: a new axis 2 indexes the input
# direction, so J[:, :, 0] is j0 and J[:, :, 1] is j1.
J = torch.stack((j0, j1), dim=2)

In [None]:
# Log of the determinant of the leading 2x2 block of each Jacobian estimate,
# written out as the explicit 2x2 formula. The result is NaN wherever that
# determinant is negative, because no absolute value is taken.
df0_dx0, df0_dx1 = J[:, 0, 0], J[:, 0, 1]
df1_dx0, df1_dx1 = J[:, 1, 0], J[:, 1, 1]
torch.log(df0_dx0 * df1_dx1 - df0_dx1 * df1_dx0)