# HSIC bottleneck toy summary

### pseudo code

    train_hsic.py  ->  train_misc.py     ->  hsic.py
    hsic_train()       hsic_objective()      hsic_normalized_cca()
        
    xout, [x0,x1,...,xout] = model(data)
        
    for xi, layer in zip([x0, x1, ..., xout], [lay0, lay1, ..., layn]):
        # xi, data, target flattened to shape (m, dims)
        hx = hsic(xi, data, σ)
        hy = hsic(xi, target, σ)

        loss = hx - λ·hy
        sgd_update(layer.weight, loss)
        
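    hsic(X, Y, σ)
        D  = diag(X·Xᵀ)·1ᵀ - 2·X·Xᵀ + 1·diag(X·Xᵀ)ᵀ     # pairwise squared distances
        Kx = exp(-D / var) · H,   var = 2·σ²·dims,   H = I - 1·1ᵀ/m
        Rx = Kx · (Kx + ϵ·m·I)⁻¹                          # Ky, Ry likewise from Y

    hsic = Σ(Rx ⊙ Ryᵀ) = trace(Rx · Ry)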

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
torch.backends.cudnn.benchmark=True

In [2]:
def distmat_old(X):
    """ reference pairwise squared euclidean distance matrix

        Builds |x_i|^2 - 2 <x_i, x_j> + |x_j|^2 out of place (several (m, m)
        temporaries are allocated), then takes abs() to clear small negative
        values left by floating-point cancellation.
        Args
            X (tensor) shape (m, dims)
        Returns
            (tensor) shape (m, m); stays attached to X's autograd graph
    """
    sq_norms = (X * X).sum(dim=1).reshape(-1, 1)   # (m, 1) squared row norms
    gram = X.mm(X.t())                              # (m, m) inner products
    dist = sq_norms.expand_as(gram) - 2 * gram + sq_norms.t().expand_as(gram)
    return dist.abs()

In [3]:
def distmat(x, requires_grad=False):
    """ pairwise squared euclidean distance matrix, memory-lean version

        D[i, j] = |x_i|^2 - 2 <x_i, x_j> + |x_j|^2, accumulated in place on the
        (m, m) Gram matrix so only one (m, m) buffer is allocated.
        Args
            x             (tensor) shape (m, dims)
            requires_grad (bool [False]) when False and x tracks gradients,
                          the result is computed from x.detach() and does not
                          track gradients
        Returns
            (tensor) shape (m, m), non-negative
    """
    if x.requires_grad and not requires_grad:
        # detach() shares storage with x (no copy); x is only read below
        x = x.detach()
    sq = (x * x).sum(1, keepdim=True)  # (m, 1) squared norms, computed once
    out = torch.mm(x, x.T).mul_(-2.0)
    out.add_(sq)
    out.add_(sq.T)
    # float cancellation can leave tiny negatives (e.g. on the diagonal)
    return out.abs_()

### distmat with a huge input: 256 * 256 = 65536 samples
* `distmat(data)` allocates 16 GB, `distmat_old(data)` 32 GB

In [4]:
def gen_data(m, device="cpu"):
    """ column of m*m standard-normal samples, shape (m*m, 1), on `device` """
    return torch.randn(m * m, device=device).unsqueeze(1)

# dlen = 256
# data = gen_data(dlen, "cuda") # hsic gets the data passed as one block
# print(data.shape)
# # torch.Size([65536, 1])

# x = distmat(data) # ok
# del x
# torch.cuda.empty_cache()

# x = distmat_old(data) # fails to allocate
# del x
# torch.cuda.empty_cache()

### speed test 

In [5]:
data = gen_data(5, "cuda")  # 5*5 = 25 samples as a (25, 1) column on the GPU

In [6]:
# %timeit x=distmat(data); del x
# 21.6 µs ± 98.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [7]:
# %timeit x=distmat_old(data); del x
# 27.8 µs ± 454 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### validate behavior, run model

In [8]:
import os
import sys

# Location of the HSIC-bottleneck sources (provides the `hsicbt` package).
# Set the HSICBT_SOURCE environment variable on other machines; the default
# is the original author's checkout.
hsicpath = os.environ.get("HSICBT_SOURCE", "/home/z/work/HSIC-bottleneck/source")
if hsicpath not in sys.path:
    sys.path.append(hsicpath)

In [9]:
import hsicbt
from hsicbt.model.mhconv import ModelConv
from hsicbt.model.mresconv import ModelResConv
# kernel width, taken from HSIC-bottleneck/config/resconv-hsicbt.yaml
sigma = 5.0

In [10]:
# build the ResConv model on the GPU and push one random batch of shape
# (512, 1, 28, 28) through it
M = ModelResConv(last_hidden_width=512)
M.to(device='cuda')
t = torch.randn(512,1,28,28, device="cuda") # standard MNIST
# for different image size pass: in_width=h*w

# the model returns its output and a list of hidden activations
# (their shapes are printed by the next cell)
output, hiddens = M(t)

In [11]:
# rank of every hidden activation, then their full shapes (displayed as the
# cell's last expression)
print("dims of latent data", [len(h.shape) for h in hiddens])
[h.shape for h in hiddens]#, output.shape

dims of latent data [4, 4, 4, 4, 4, 4, 2]


[torch.Size([512, 64, 12, 12]),
 torch.Size([512, 64, 12, 12]),
 torch.Size([512, 64, 12, 12]),
 torch.Size([512, 64, 12, 12]),
 torch.Size([512, 64, 12, 12]),
 torch.Size([512, 64, 12, 12]),
 torch.Size([512, 512])]

In [12]:
i=0  # index of the hidden layer to inspect
# to run hsic, flatten every sample to a vector: (batch, ...) -> (batch, features)
h = hiddens[i].view(-1, np.prod(hiddens[i].size()[1:]))
h_data = t.view(-1, np.prod(t.size()[1:]))
print("input size ", tuple(t.shape),"\t->", tuple(h_data.shape))
print("hidden size", tuple(hiddens[i].shape), "\t->", tuple(h.shape))

input size  (512, 1, 28, 28) 	-> (512, 784)
hidden size (512, 64, 12, 12) 	-> (512, 9216)


### check values of distmat() vs distmat_old(): ok

In [13]:
# reference vs memory-lean distance matrix, compared with exact float equality
do = distmat_old(h)
d = distmat(h, requires_grad=False)  # d is reused (and later overwritten in place) below
print("distmat(h) == distmat_old(h): ",(d==do).all().item(), tuple(d.shape))

distmat(h) == distmat_old(h):  True (512, 512)


In [14]:
# distmat(x, requires_grad=False) detaches its input, so its output carries no gradient;
# distmat_old keeps the autograd graph
print("hidden[i].requires_grad\t\t", h.requires_grad, "\tdevice", h.device)
print("distmat_old(h).requires_grad\t", do.requires_grad, "\tdevice", do.device)
print("distmat(h, False).requires_grad\t",d.requires_grad, "\tdevice", d.device)

hidden[i].requires_grad		 True 	device cuda:0
distmat_old(h).requires_grad	 True 	device cuda:0
distmat(h, False).requires_grad	 False 	device cuda:0


In [15]:
m, dim = h.shape
variance = (2.* sigma * sigma* dim)  # 2*sigma^2*dim, the same formula kernelmat uses
k = torch.exp(-d / variance)  # out-of-place reference kernel; d is left unchanged here

### use in-place operations

In [16]:
torch.exp_(d.mul_(-1.0/variance))  # overwrites d with exp(-d/variance); the returned d is displayed

tensor([[1.0000, 0.9864, 0.9867,  ..., 0.9852, 0.9849, 0.9873],
        [0.9864, 1.0000, 0.9863,  ..., 0.9840, 0.9857, 0.9872],
        [0.9867, 0.9863, 1.0000,  ..., 0.9855, 0.9868, 0.9868],
        ...,
        [0.9852, 0.9840, 0.9855,  ..., 1.0000, 0.9853, 0.9861],
        [0.9849, 0.9857, 0.9868,  ..., 0.9853, 1.0000, 0.9862],
        [0.9873, 0.9872, 0.9868,  ..., 0.9861, 0.9862, 1.0000]],
       device='cuda:0')

In [17]:
# d now holds the in-place kernel; compare it exactly with the out-of-place k
print("e^(-d/var) in place is equal to old:", (d==k).all().item())

e^(-d/var) in place is equal to old: True


In [18]:
def kernelmat(X, sigma=None, requires_grad=False):
    """ centered Gaussian kernel matrix  Kxc = exp(-D / variance) @ H

        D is the pairwise squared-distance matrix from distmat() and
        H = I - 1/m (identity minus 1/m in every entry) is the centering
        matrix, created with X's dtype and device so non-float32 inputs work.
        Args
            X             (tensor) shape (batchsize, datadimension)
            sigma         (float [None]) kernel width from config; when falsy the
                          width is estimated with sigma_estimation(X, X)
            requires_grad (bool [False]) when False the output is detached
                          from the autograd graph
        Returns
            (tensor) shape (batchsize, batchsize)
        Raises
            RuntimeError  "Unstable sigma ..." when the estimated-sigma path fails
    """
    m, dim = X.size()
    # dtype must match Kx for torch.mm (e.g. float64 inputs)
    H = torch.eye(m, dtype=X.dtype, device=X.device).sub_(1 / m)
    Kx = distmat(X, requires_grad=requires_grad)

    if sigma:
        variance = 2. * sigma * sigma * dim
        torch.exp_(Kx.mul_(-1.0 / variance))
    else:
        # NOTE(review): sigma_estimation is not defined or imported in this
        # notebook; it must be in scope for the sigma=None path.
        sx = None  # bound up front so the handler below can always format it
        try:
            sx = sigma_estimation(X, X)
            variance = 2. * sx * sx
            torch.exp_(Kx.mul_(-1.0 / variance))
        except RuntimeError as e:
            raise RuntimeError("Unstable sigma {} with maximum/minimum input ({},{})".format(
                sx, torch.max(X), torch.min(X))) from e

    return torch.mm(Kx, H)

In [19]:
def kernelmat_old(X, sigma=None, debug=False, fixdevices=True):
    """ reference (pre-refactor) centered Gaussian kernel matrix

        Returns exp(-D / variance) @ H with D = distmat_old(X) and
        H = I - (1/m) * ones built out of place.
        Args
            X          (tensor) shape (batchsize, datadimension)
            sigma      (float [None]) kernel width; falsy -> sigma_estimation(X, X)
            debug      (bool [False]) print requires_grad/device of intermediates
                       (explicit-sigma path only)
            fixdevices (bool [True]) move H to X's device and keep K there; when
                       False, H stays where torch.eye creates it and K is cast
                       with .type(torch.FloatTensor)
    """
    n = int(X.size()[0])
    centering = torch.eye(n) - (1. / n) * torch.ones([n, n])
    if fixdevices:
        centering = centering.to(device=X.device)
    dist = distmat_old(X)

    def gaussian(denominator):
        # exp(-D / denominator), optionally recast as in the original code
        kernel = torch.exp(-dist / denominator)
        if not fixdevices:
            kernel = kernel.type(torch.FloatTensor)
        return kernel

    if not sigma:
        try:
            sx = sigma_estimation(X, X)
            Kx = gaussian(2. * sx * sx)
        except RuntimeError as e:
            raise RuntimeError("Unstable sigma {} with maximum/minimum input ({},{})".format(
                sx, torch.max(X), torch.min(X)))
    else:
        variance = 2. * sigma * sigma * X.size()[1]
        Kx = gaussian(variance)
        if debug:
            print("X    grad", X.requires_grad, ", device", X.device)
            print("Dx   grad", dist.requires_grad, ", device", dist.device)
            print("Kx   grad", Kx.requires_grad, ", device", Kx.device, "<--.type(torch.FloatTensor) casts a new tensor")

    return torch.mm(Kx, centering)

### test kernels

In [20]:
# reference kernel; debug=True prints requires_grad/device of X, D and K
kxold = kernelmat_old(h, sigma, debug=True)
print("Kxc  grad", kxold.requires_grad, ", device", kxold.device)

X    grad True , device cuda:0
Dx   grad True , device cuda:0
Kx   grad True , device cuda:0 <--.type(torch.FloatTensor) casts a new tensor
Kxc  grad True , device cuda:0


In [21]:
# default requires_grad=False detaches the result ...
kx = kernelmat(h, sigma)
print("Kxc  grad", kx.requires_grad, ", device", kx.device)
# ... requires_grad=True keeps it in the autograd graph (this kx is compared below)
kx = kernelmat(h, sigma, requires_grad=True)
print("Kxc  grad", kx.requires_grad, ", device", kx.device)

Kxc  grad False , device cuda:0
Kxc  grad True , device cuda:0


In [22]:
# %timeit kxold = kernelmat_old(h, sigma)
# # 2.03 ms ± 49.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [23]:
# %timeit kx = kernelmat(h, sigma)
# # 817 µs ± 23.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [24]:
e = 1e-6
# .any(): True if at least one entry differs by more than e, so False means
# the two kernels agree to within e everywhere (.all() would only flag a
# difference in every single entry)
print("|K_old - K| > 1e-6", (torch.abs(kx.detach().cpu() - kxold.detach().cpu()) > e).any().item())

|K_old - K| > 1e-6 False


### test hsic

In [25]:
def hsic_normalized_cca_old(x, y, sigma, use_cuda=True, to_numpy=True, fixdevices=True):
    """ reference normalized HSIC: sum(Rx * Ry^T), R = Kc (Kc + eps*m*I)^-1

        Args
            x, y       (tensor) shape (batchsize, features)
            sigma      (float) kernel width passed to kernelmat_old
            use_cuda   accepted for compatibility, unused
            to_numpy   accepted for compatibility, unused
            fixdevices (bool [True]) forwarded to kernelmat_old; also moves the
                       identity to x's device
        Returns
            (tensor) scalar
    """
    n = int(x.size()[0])
    Kxc = kernelmat_old(x, sigma=sigma, fixdevices=fixdevices)
    Kyc = kernelmat_old(y, sigma=sigma, fixdevices=fixdevices)

    eps = 1E-5
    ident = torch.eye(n)
    if fixdevices:
        ident = ident.to(device=x.device)
    ridge = eps * n * ident

    Rx = Kxc.mm(torch.inverse(Kxc + ridge))
    Ry = Kyc.mm(torch.inverse(Kyc + ridge))
    return (Rx * Ry.t()).sum()

In [26]:
def hsic_normalized_cca(x, y, sigma=None, requires_grad=True):
    """ normalized HSIC: sum(Rx * Ry^T) with R = Kc (Kc + eps*m*I)^-1

        Refactor of hsic_normalized_cca_old: the ridge term is created directly
        on x's device, and the final product is taken in place.
        Args
            x, y          (tensor) shape (batchsize, features), same batchsize
            sigma         (float [None]) kernel width passed to kernelmat
            requires_grad (bool [True]) passed to kernelmat; False detaches
        Returns
            (tensor) scalar
    """
    n = x.size()[0]
    ridge = torch.eye(n, device=x.device).mul_(1E-5 * n)

    def regularized_projection(z):
        # Kc (Kc + eps*m*I)^-1 for one input
        kc = kernelmat(z, sigma=sigma, requires_grad=requires_grad)
        return kc.mm(kc.add(ridge).inverse())

    rx = regularized_projection(x)
    ry = regularized_projection(y)
    # multiply into rx in place to avoid one more (m, m) buffer
    return rx.mul_(ry.t()).sum()

### test new hsic

In [27]:
# fresh model and random batch for the hsic comparison below
M = ModelResConv(last_hidden_width=512) # for different image size pass: in_width=h*w
M.to(device='cuda')
t = torch.randn(512,1,28,28, device="cuda") # standard MNIST

In [28]:
output, hiddens = M(t)

i=0  # index of the hidden layer to test
# to run hsic: flatten each sample, (batch, ...) -> (batch, features)
h = hiddens[i].view(-1, np.prod(hiddens[i].size()[1:]))
h_data = t.view(-1, np.prod(t.size()[1:]))

In [29]:
# the 4th positional argument is requires_grad=True, so backward() can run
hx = hsic_normalized_cca(h, h_data, sigma, True)
hx.backward()
print(hx)

tensor(277.2852, device='cuda:0', grad_fn=<SumBackward0>)


### test old hsic, fixing devices

In [30]:
# forward pass re-run, presumably for a fresh autograd graph after the backward() above
output, hiddens = M(t)

i=0
h = hiddens[i].view(-1, np.prod(hiddens[i].size()[1:]))
h_data = t.view(-1, np.prod(t.size()[1:]))

In [31]:
# reference hsic with all tensors kept on h's device
hxo = hsic_normalized_cca_old(h, h_data, sigma, fixdevices=True)
#hxo.backward()
print(hxo)

tensor(277.2852, device='cuda:0', grad_fn=<SumBackward0>)


### test old hsic, without fixing devices

In [32]:
# same preparation as the previous forward/flatten cells
output, hiddens = M(t)

i=0
h = hiddens[i].view(-1, np.prod(hiddens[i].size()[1:]))
h_data = t.view(-1, np.prod(t.size()[1:]))

In [33]:
# fixdevices=False: kernelmat_old casts the kernels with .type(torch.FloatTensor),
# so this result is not on h's device and differs slightly from the fixed run
hxo = hsic_normalized_cca_old(h, h_data, sigma, fixdevices=False)
#hxo.backward()
print(hxo)

tensor(277.2774, grad_fn=<SumBackward0>)
