In [14]:
import imp
import timeit
import torch
import numpy as np
import ops
from utils import lattice

In [374]:
# `imp` is deprecated (removed in Python 3.12); importlib.reload is the supported
# equivalent and likewise returns the reloaded module (displayed as the cell output).
import importlib
importlib.reload(ops)

<module 'ops' from '/Users/theissjd/Documents/Berkeley/code/rf_pool/rf_pool/ops.py'>

In [375]:
# Test ops.local_softmax: a masked softmax must equal a plain softmax taken after
# every masked-out entry (mask == 0) has been pushed to -inf.
rf_u = torch.rand(1,10,16)
mask = torch.zeros_like(rf_u)
mask[..., :8] = 1.
out0 = ops.local_softmax(rf_u, -1, mask)

# Reference: additive mask that is -0. where mask == 1 and -inf where mask == 0.
additive_mask = -torch.exp(np.inf * (1. - 2. * mask))
out1 = torch.softmax(rf_u + additive_mask, -1)
assert torch.eq(out0, out1).all()

In [376]:
# test ops.prob_max_pool
rf_u = torch.rand(1,10,16)
torch.random.manual_seed(0)
out_shape = (1,10,4,4)
h_mean, h_sample, p_mean, p_sample = ops.prob_max_pool(rf_u, out_shape)
# manually perform prob_max_pool: softmax over each region plus one extra zero logit
# (presumably the pooling unit's "off" state)
rf_u_0 = torch.cat([rf_u, torch.zeros(1,10,1)], -1)
rf_u_0_softmax = torch.softmax(rf_u_0, -1)
probs = torch.reshape(rf_u_0_softmax[:,:,:-1], out_shape)
assert torch.all(torch.eq(h_mean, probs))
# reseed so the reference sampler draws the same random numbers as ops.prob_max_pool
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=rf_u_0_softmax).sample()
samples = torch.reshape(samples[:,:,:-1], out_shape)
assert torch.all(torch.eq(h_sample, samples))
# 1 - P(extra zero logit), placed at the sampled location of each region
p_probs = torch.mul(torch.reshape(1. - rf_u_0_softmax[:,:,-1], (1,10,1,1)), samples)
assert torch.all(torch.eq(p_mean, p_probs))
# p_sample was otherwise unchecked; at minimum it must have the detection-layer shape.
# TODO(review): also pin p_sample's values once the intended reference is confirmed.
assert p_sample.shape == h_sample.shape

In [377]:
# test ops.stochastic_max_pool: softmax over each region (no extra zero logit), with
# one unit per region sampled on via a single-draw Multinomial
rf_u = torch.rand(1,10,16)
torch.random.manual_seed(0)
out_shape = (1,10,4,4)
h_mean, h_sample, p_mean, p_sample = ops.stochastic_max_pool(rf_u, out_shape)
# manually perform stochastic_max_pool
rf_u_softmax = torch.softmax(rf_u, -1)
probs = torch.reshape(rf_u_softmax, out_shape)
assert torch.all(torch.eq(h_mean, probs))
# reseed so the reference sampler draws the same random numbers as the op
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=rf_u_softmax).sample()
samples = torch.reshape(samples, out_shape)
assert torch.all(torch.eq(h_sample, samples))
p_probs = torch.mul(probs, samples)
assert torch.all(torch.eq(p_mean, p_probs))
p_samples = samples.clone()
assert torch.all(torch.eq(p_sample, p_samples))

In [378]:
# Test ops.div_norm_pool: divisive normalization r**n / (sigma**n + sum(r**n)) per
# region, with samples drawn from a softmax over the normalized responses.
rf_u = torch.rand(1,10,16)
torch.random.manual_seed(0)
out_shape = (1,10,4,4)
n = 2.
sigma = 0.5
h_mean, h_sample, p_mean, p_sample = ops.div_norm_pool(rf_u, out_shape, n=n, sigma=sigma)

# reference computation
powered = rf_u.pow(n)
denom = torch.as_tensor(sigma, dtype=rf_u.dtype).pow(n) + powered.sum(dim=-1, keepdim=True)
probs = powered / denom
assert torch.eq(h_mean, probs.reshape(out_shape)).all()

# reseed so the reference sampler matches the op's draws
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=probs.softmax(-1)).sample()
assert torch.eq(h_sample, samples.reshape(out_shape)).all()
assert torch.eq(p_mean, (probs * samples).reshape(out_shape)).all()
assert torch.eq(p_sample, samples.reshape(out_shape)).all()

In [379]:
# Test ops.average_pool: responses divided by the number of units per region
# (all units when mask is not a tensor, otherwise the mask's per-region count).
rf_u = torch.rand(1,10,16)
mask = None
torch.random.manual_seed(0)
out_shape = (1,10,4,4)
h_mean, h_sample, p_mean, p_sample = ops.average_pool(rf_u, out_shape, mask)

# reference computation
if type(mask) is torch.Tensor:
    n_units = mask.sum(-1, keepdim=True)
else:
    n_units = torch.as_tensor(rf_u.shape[-1], dtype=rf_u.dtype)
probs = rf_u / n_units
assert torch.eq(h_mean, probs.reshape(out_shape)).all()

# reseed so the reference sampler matches the op's draws
torch.random.manual_seed(0)
sampler = torch.distributions.Multinomial(probs=ops.local_softmax(probs, -1, mask))
samples = sampler.sample()
assert torch.eq(h_sample, samples.reshape(out_shape)).all()
assert torch.eq(p_mean, (probs * samples).reshape(out_shape)).all()
assert torch.eq(p_sample, samples.reshape(out_shape)).all()

In [380]:
# Test ops.sum_pool: h_mean is the raw response; the pooled value is the region's
# summed response, placed at the sampled location.
rf_u = torch.rand(1,10,16)
torch.random.manual_seed(0)
out_shape = (1,10,4,4)
h_mean, h_sample, p_mean, p_sample = ops.sum_pool(rf_u, out_shape)

# reference computation
assert torch.eq(h_mean, rf_u.reshape(out_shape)).all()
# reseed so the reference sampler matches the op's draws
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=rf_u.softmax(-1)).sample().reshape(out_shape)
assert torch.eq(h_sample, samples).all()
# broadcast each region's sum back over all of its units
region_sum = (torch.zeros_like(rf_u) + rf_u.sum(dim=-1, keepdim=True)).reshape(out_shape)
assert torch.eq(p_mean, region_sum * samples).all()
assert torch.eq(p_sample, samples).all()

In [381]:
# Test ops.rf_pool with default arguments. The reference reproduces prob_max_pool
# over non-overlapping 2x2 blocks, which is what the defaults are expected to do.
u = torch.rand(1,10,16,16)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = ops.rf_pool(u)

# reference: stack each (r, c) phase of every 2x2 block along a trailing axis (row-major)
torch.random.manual_seed(0)
bh, bw = 2, 2
offsets = [(r, c) for r in range(bh) for c in range(bw)]
blocks = torch.cat([u[:, :, r::bh, c::bw].unsqueeze(-1) for r, c in offsets], -1)
tmp_probs, tmp_samples, tmp_p_probs, tmp_p_samples = ops.prob_max_pool(blocks, blocks.shape)

# scatter detection outputs back to image positions; max-reduce the pooling outputs
probs = torch.zeros_like(u)
samples = torch.zeros_like(u)
p_probs = torch.zeros(1,10,8,8)
p_samples = torch.zeros(1,10,8,8)
for k, (r, c) in enumerate(offsets):
    probs[:, :, r::bh, c::bw] = tmp_probs[:, :, :, :, k]
    samples[:, :, r::bh, c::bw] = tmp_samples[:, :, :, :, k]
    p_probs = torch.max(p_probs, tmp_p_probs[:, :, :, :, k])
    p_samples = torch.max(p_samples, tmp_p_samples[:, :, :, :, k])

assert torch.eq(h_mean, probs).all()
assert torch.eq(h_sample, samples).all()
assert torch.eq(p_mean, p_probs).all()
assert torch.eq(p_sample, p_samples).all()

In [382]:
# test ops.rf_pool top-down, div_norm, block_size, pool_args
u = torch.rand(1,10,16,16)
t = torch.rand(1,10,4,4)
torch.random.manual_seed(0)
pool_type = 'div_norm'
block_size = (4,4)
pool_args = dict(n=1., sigma=0.5)  # shared by ops.rf_pool and the reference below
h_mean, h_sample, p_mean, p_sample = ops.rf_pool(u, t, pool_type=pool_type,
                                                 block_size=block_size, **pool_args)
# manually perform ops.rf_pool with top-down, div_norm, block_size, pool_args
bh, bw = block_size
b = []
for r in range(bh):
    for c in range(bw):
        # add the top-down input to phase (r, c) of every block; out-of-place so `u`
        # keeps the values that were passed to ops.rf_pool
        b.append(torch.add(u[:,:,r::bh,c::bw], t).unsqueeze(-1))
b = torch.cat(b, -1)
# get tmp_probs, tmp_samples from div_norm_pool using pool_args (same RNG state as the op)
torch.random.manual_seed(0)
tmp_probs, tmp_samples, tmp_p_probs, tmp_p_samples = ops.div_norm_pool(b, b.shape, mask=None,
                                                                       **pool_args)
probs = torch.zeros_like(u)
samples = torch.zeros_like(u)
# get probs, samples, p_probs, p_samples
p_probs = torch.zeros_like(t)
p_samples = torch.zeros_like(t)
for r in range(bh):
    for c in range(bw):
        # b was stacked row-major over (r, c) with bw columns, so phase (r, c) is at
        # r * bw + c (r * bh + c is only correct for square blocks)
        k = r * bw + c
        probs[:,:,r::bh,c::bw] = tmp_probs[:,:,:,:,k]
        samples[:,:,r::bh,c::bw] = tmp_samples[:,:,:,:,k]
        p_probs = torch.max(p_probs, tmp_p_probs[:,:,:,:,k])
        p_samples = torch.max(p_samples, tmp_p_samples[:,:,:,:,k])
assert torch.all(torch.eq(h_mean, probs))
assert torch.all(torch.eq(h_sample, samples))
assert torch.all(torch.eq(p_mean, p_probs))
assert torch.all(torch.eq(p_sample, p_samples))

In [385]:
# Test ops.rf_pool with lattice kernels (lattice.gaussian_kernel_lattice) whose centers
# and widths are randomly jittered per batch example (batch_size = 2).
mu, sigma = lattice.init_uniform_lattice((8,8), 4, 4, 1.)
mu = mu + torch.rand(2,1,16,2)
sigma = sigma + torch.rand(2,1,16,1)
rfs = lattice.gaussian_kernel_lattice(mu, sigma, (16,16))
u = torch.rand(2,10,16,16)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = ops.rf_pool(u, rfs=rfs, pool_type='sum', block_size=(1,1))

# reference: weight the input by every kernel, sum_pool within each kernel's support
# (rfs > 1e-5), then take the max over the kernel axis.
# Rebinding keeps `u` 5-D afterwards, which the following timing cell relies on.
u = u.unsqueeze(2)
rf_u = u * rfs
rf_index = torch.gt(rfs, 1e-5).to(rf_u.dtype)
torch.random.manual_seed(0)
pooled = ops.sum_pool(rf_u.flatten(-2), rf_u.shape, rf_index.flatten(-2))
probs, samples, p_probs, p_samples = (x.max(dim=-3).values for x in pooled)

assert torch.eq(h_mean, probs).all()
assert torch.eq(h_sample, samples).all()
assert torch.eq(p_mean, p_probs).all()
assert torch.eq(p_sample, p_samples).all()

In [386]:
# Time two ways of masking the kernel-weighted input outside each receptive field:
# boolean indexing into a -inf tensor vs. adding a 0/-inf mask. Both must agree.
# NOTE: uses the 5-D `u` and `rfs` produced by the previous cell; `rfs` is not
# overwritten here, so re-running this cell does not change its inputs.
thr = 1e-5  # one threshold shared by both methods

s = timeit.default_timer()
rfs_full = torch.add(torch.zeros_like(u), rfs)  # broadcast rfs over u's leading dims
g_u = torch.mul(u, rfs_full).permute(2,0,1,3,4)
rf_index = torch.gt(rfs_full, thr).permute(2,0,1,3,4)
rf_u0 = torch.zeros_like(g_u) - np.inf
rf_u0[rf_index] = g_u[rf_index]
e = timeit.default_timer()
print('index:', e - s)

s = timeit.default_timer()
rfs_full = torch.add(torch.zeros_like(u), rfs)
g_u = torch.mul(u, rfs_full).permute(2,0,1,3,4)
# create rf_mask of 0s at rf and -inf elsewhere (optimized)
rf_mask = torch.as_tensor(torch.le(rfs_full.permute(2,0,1,3,4), thr), dtype=g_u.dtype)
rf_mask = -1. * torch.exp(np.inf * (2. * rf_mask - 1.))
rf_u1 = torch.add(g_u, rf_mask)
e = timeit.default_timer()
print('add:', e - s)
assert torch.all(torch.eq(rf_u0, rf_u1))

index: 0.025917944964021444
add: 0.003499198006466031


In [387]:
# Compare applying prob_max_pool per receptive field in a Python loop vs. over all
# kernels at once, each followed by a max across receptive fields.
n_kernels = 25
mask_thr = 1e-5
# both approaches share one seeded input so their timings and results are comparable
torch.random.manual_seed(0)
rf_kernels = torch.rand(2,10,n_kernels,28,28)
u_t = torch.rand(2,10,n_kernels,28,28)

# --- for loop over receptive fields ---
s = timeit.default_timer()
rf_u = torch.mul(u_t, rf_kernels).permute(2,0,1,3,4)
out_shape = rf_u.shape[1:]  # (batch, channels, height, width)
h_mean_loop = torch.zeros(out_shape)
h_sample_loop = torch.zeros(out_shape)
p_mean_loop = torch.zeros(out_shape)
p_sample_loop = torch.zeros(out_shape)

# create rf_mask of receptive field kernels
rf_mask = torch.as_tensor(torch.gt(rf_kernels, mask_thr).permute(2,0,1,3,4),
                          dtype=rf_u.dtype)
rf_mask = torch.flatten(rf_mask, -2)
# loop names must not be `u`: that would clobber the notebook-level `u`
for u_k, rf_k in zip(rf_u, rf_mask):
    # apply pool function across image dims
    h_mean_k, h_sample_k, p_mean_k, p_sample_k = ops.prob_max_pool(u_k.flatten(-2), u_k.shape,
                                                                   mask=rf_k)
    # max across receptive fields
    h_mean_loop = torch.max(h_mean_loop, h_mean_k)
    h_sample_loop = torch.max(h_sample_loop, h_sample_k)
    p_mean_loop = torch.max(p_mean_loop, p_mean_k)
    p_sample_loop = torch.max(p_sample_loop, p_sample_k)
e = timeit.default_timer()
print(e - s)

# --- all at once ---
s = timeit.default_timer()
rf_u = torch.mul(u_t, rf_kernels)

# create rf_mask of receptive field kernels
rf_mask = torch.as_tensor(torch.gt(rf_kernels, mask_thr),
                          dtype=rf_u.dtype)

# apply pool function across image dims
h_mean, h_sample, p_mean, p_sample = ops.prob_max_pool(rf_u.flatten(-2), rf_u.shape,
                                                         mask=rf_mask.flatten(-2))
# max across receptive fields
h_mean = torch.max(h_mean, -3)[0]
h_sample = torch.max(h_sample, -3)[0]
p_mean = torch.max(p_mean, -3)[0]
p_sample = torch.max(p_sample, -3)[0]
e = timeit.default_timer()
print(e - s)

# sampled outputs consume the RNG in a different order, but the deterministic
# means from the two approaches must agree
assert torch.allclose(h_mean_loop, h_mean)

0.05273100297199562
0.04342709400225431
