# Invertible Residual Network

A very basic implementation of the [invertible residual network](https://arxiv.org/pdf/1811.00995.pdf) (i-ResNet), for illustration purposes only. I am not sure the implementation is correct...

In [14]:
import math
import numpy as np
import torch
from SpectralNormGouk import spectral_norm

from bokeh.plotting import figure, show, output_notebook, gridplot

# Send Bokeh `show(...)` output inline into this notebook.
output_notebook()

In [8]:
# Toy 2-D dataset: K samples from each of two correlated Gaussians; the
# first cluster is shifted by +2 along both axes.
# NOTE: no RNG seed is set, so the data (and training results) differ per run.
K = 100

cov1 = np.array([[1.5, -.6], [-.6, .5]])
cov2 = np.array([[2, .8], [.8, .4]])
d1 = np.random.multivariate_normal(np.zeros(2), cov1, K) + 2
d2 = np.random.multivariate_normal(np.zeros(2), cov2, K)

data = np.vstack([d1, d2])     # shape (2 * K, 2), float64
X = torch.from_numpy(data)     # shares memory with `data`

fig = figure()
fig.circle(data[:, 0], data[:, 1])
show(fig)

In [9]:
class SNLayer(torch.nn.Module):
    """Spectrally normalised affine layer, ``y = W x + b``.

    A ``torch.nn.Linear`` wrapped by ``spectral_norm`` from the local
    ``SpectralNormGouk`` module. ``magnitude`` and ``n_power_iterations`` are
    forwarded to it unchanged (``magnitude`` is presumably the target bound on
    the spectral norm of ``W`` -- confirm against ``SpectralNormGouk``).
    """

    def __init__(self, dim_in, dim_out, magnitude=0.89, n_power_iterations=10):
        super().__init__()
        linear = torch.nn.Linear(dim_in, dim_out)
        self.affine = spectral_norm(
            linear,
            magnitude=magnitude,
            n_power_iterations=n_power_iterations,
        )

    def forward(self, X):
        """Apply the normalised affine map to ``X``."""
        return self.affine(X)
    

class InvResidualBlock(torch.nn.Module):
    """One i-ResNet block ``y = x + g(x)`` that also tracks ``log|det dy/dx|``.

    ``g`` is ``layer_out . fn . layer_mid . fn . layer_in``, built from
    spectrally normalised affine layers. ``forward`` adds a stochastic
    estimate of ``log det(I + J_g(x))`` (one value per row of ``x``) to the
    running log-determinant it receives.
    """

    @staticmethod
    def _approx_log_det(X, fn_outputs, niters):
        """Stochastic approximation of the log-determinant.

        Uses the truncated power series
        ``log det(I + J) = sum_{k=1}^{niters} (-1)^(k+1) tr(J^k) / k``.
        Each trace is estimated with Hutchinson's estimator ``v^T J^k v``,
        using one standard-normal probe ``v`` per row. The series only
        converges when ``g`` is contractive (Lipschitz constant < 1).

        Args:
            X: tensor ``fn_outputs`` was computed from; it must be part of the
                autograd graph. Assumed 2-D ``(batch, dim)``.
            fn_outputs: ``g(X)``, same shape as ``X``.
            niters: number of power-series terms.

        Returns:
            Tensor of shape ``(len(X),)`` with one estimate per row.
        """
        logdet = torch.zeros(len(X), device=fn_outputs.device,
                             dtype=fn_outputs.dtype)
        V = torch.randn_like(fn_outputs)
        W = V
        for k in range(1, niters + 1):
            # Vector-Jacobian product: W <- W^T J, so after k steps
            # W = V^T J^k. create_graph keeps the estimate differentiable
            # w.r.t. the parameters (needed for training).
            W = torch.autograd.grad(fn_outputs, X, grad_outputs=W,
                                    create_graph=True,
                                    retain_graph=True)[0]
            logdet = logdet + (-1) ** (k + 1) * (W * V).sum(dim=-1) / k
        return logdet

    def __init__(self, dim_in, dim_expand, fn=torch.nn.Tanh, depth=1):
        """Build ``g``: ``dim_in -> dim_expand -> dim_expand -> dim_in``.

        Args:
            dim_in: feature dimension of the block's input and output.
            dim_expand: hidden width of ``g``.
            fn: activation class, instantiated once and reused between layers.
            depth: accepted for backward compatibility; currently unused.
        """
        super().__init__()
        self.layer_in = SNLayer(dim_in, dim_expand)
        self.layer_mid = SNLayer(dim_expand, dim_expand)
        self.layer_out = SNLayer(dim_expand, dim_in)
        self.fn = fn()

    def forward(self, inputs, niters=10, nsamples=1):
        """Apply the block and accumulate its log-determinant estimate.

        Args:
            inputs: pair ``(X, prev_log_det)``; ``prev_log_det`` holds one
                value per row of ``X``.
            niters: number of power-series terms per estimate.
            nsamples: number of independent Hutchinson probes to average.

        Returns:
            ``(X + g(X), prev_log_det + logdet_estimate)``.
        """
        X = inputs[0]
        prev_logdet = inputs[1]
        if not X.requires_grad:
            # autograd.grad needs X in the graph. A tensor that does not
            # require grad has nothing upstream to receive gradients, so a
            # detached leaf is equivalent and leaves the caller's tensor as is.
            X = X.detach().requires_grad_(True)
        # Evaluate g(X) once: the returned output and every log-det probe
        # then refer to the same function evaluation.
        H = self.layer_out(self.fn(self.layer_mid(self.fn(self.layer_in(X)))))
        logdet = torch.zeros(X.shape[0], dtype=X.dtype, device=X.device)
        for _ in range(nsamples):
            logdet = logdet + self._approx_log_det(X, H, niters=niters) / nsamples
        return X + H, prev_logdet + logdet
                                                

class TransformedNormal(torch.nn.Module):
    """Density model: a standard normal base distribution behind an i-ResNet.

    ``forward`` returns, per row of ``X``,
    ``log N(z; 0, I) + logdet_scale * logdet``. Here ``z`` is the output of
    the last block and ``logdet`` is the accumulated log-determinant
    estimate returned by the blocks.
    """

    def __init__(self, iresnet):
        super().__init__()
        # Iterable of blocks, each called as block((X, logdet), niters=...,
        # nsamples=...) and returning the updated (X, logdet) pair.
        self.iresnet = iresnet

    def forward(self, X, niters=10, nsamples=1, logdet_scale=1.):
        """Per-row log-likelihood of ``X`` under the transformed normal.

        Args:
            X: input tensor, assumed ``(batch, dim)``. It is never modified.
            niters: power-series terms passed to each block.
            nsamples: Hutchinson probes passed to each block.
            logdet_scale: weight applied to the log-determinant term.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        logdet = torch.zeros(X.shape[0], dtype=X.dtype, device=X.device)
        if not X.requires_grad:
            # The blocks differentiate w.r.t. their input, so X must be in
            # the graph. Use a detached leaf rather than flipping
            # requires_grad on the caller's tensor in place.
            X = X.detach().requires_grad_(True)
        for block in self.iresnet:
            X, logdet = block((X, logdet), niters=niters, nsamples=nsamples)

        # Standard normal log-likelihood of the latent z = X.
        dim = X.shape[-1]
        log_p_z = -.5 * (dim * math.log(2 * math.pi) + (X ** 2).sum(dim=-1))

        return log_p_z + logdet_scale * logdet
        

def create_iresnet(dim_in, dim_expand, depth):
    """Build a float64 ``TransformedNormal`` over ``depth`` i-ResNet blocks.

    Every block maps ``dim_in`` features through a hidden width of
    ``dim_expand``; the blocks are chained with ``torch.nn.Sequential``.
    """
    blocks = [
        InvResidualBlock(dim_in=dim_in, dim_expand=dim_expand)
        for _ in range(depth)
    ]
    model = TransformedNormal(torch.nn.Sequential(*blocks))
    return model.double()

In [10]:
# 2-D density model: 3 i-ResNet blocks with hidden width 20, in float64
# (matching the float64 data tensor X).
model = create_iresnet(dim_in=2, dim_expand=20, depth=3)

In [11]:
# Density of the untrained model on a 100 x 100 grid over [-10, 10]^2.
# Row i * 100 + j of `xy` is (x_i, y_j), so after the reshape llh_xy[i, j]
# is the density at (x_i, y_j).
xy = np.mgrid[-10:10:100j, -10:10:100j].reshape(2, -1).T
XY = torch.tensor(xy, requires_grad=True)
llh_xy = (
    model(XY, niters=10, nsamples=10)
    .detach()
    .exp()
    .numpy()
    .reshape(100, 100)
)

# Bokeh image rows run along y, hence the transpose.
fig = figure(x_range=(-10, 10), y_range=(-10, 10))
fig.image(image=[llh_xy.T], x=-10, y=-10, dw=20, dh=20, palette='Viridis256')
fig.circle(data[:, 0], data[:, 1], color='white')
show(fig)

In [12]:
# Maximum-likelihood training: each step minimises the negative mean
# log-likelihood of the whole dataset and records the mean log-likelihood.
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 5_000
avg_llhs = []
for epoch in range(epochs):
    optim.zero_grad()
    mean_llh = model(X).mean()
    loss = -mean_llh
    loss.backward()
    optim.step()
    avg_llhs.append(float(mean_llh))

# Learning curve: mean log-likelihood per step.
fig = figure()
fig.line(range(len(avg_llhs)), avg_llhs)
show(fig)

In [15]:
# Density of the trained model on the same 100 x 100 grid over [-10, 10];
# llh_xy[i, j] is the density at (x_i, y_j).
xy = np.mgrid[-10:10:100j, -10:10:100j].reshape(2, -1).T
XY = torch.from_numpy(xy)
llh_xy = (
    model(XY, nsamples=20, logdet_scale=1.)
    .detach()
    .exp()
    .numpy()
    .reshape(100, 100)
)

# Transpose so that image rows correspond to y, as Bokeh expects.
fig = figure(x_range=(-10, 10), y_range=(-10, 10))
fig.image(image=[llh_xy.T], x=-10, y=-10, dw=20, dh=20, palette='Viridis256')
fig.circle(data[:, 0], data[:, 1], color='white', alpha=1.)
show(fig)