In [1]:
import torch
import rf
from utils import lattice

In [2]:
# test rf.average_pool
# Seed before drawing the input so a failing run can be reproduced, then
# re-seed before each sampling step so rf.average_pool and the manual
# reference below consume the RNG from the same state.
torch.random.manual_seed(0)
rf_u = torch.rand(1,10,16)
out_shape = (1,10,4,4)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.average_pool(rf_u, out_shape)
# manually perform average_pool: divide every value by the size of the last dim
probs = torch.div(rf_u, rf_u.shape[-1])
assert torch.all(torch.eq(h_mean, torch.reshape(probs, out_shape)))
# one multinomial draw per row (total_count defaults to 1) from softmax(probs)
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=torch.softmax(probs, -1)).sample()
assert torch.all(torch.eq(h_sample, torch.reshape(samples, out_shape)))
# pooled mean: the probabilities gated by the sampled (one-hot) position
p_probs = torch.mul(probs, samples)
assert torch.all(torch.eq(p_mean, torch.reshape(p_probs, out_shape)))
# pooled sample is expected to equal the sampled positions themselves
p_samples = samples.clone()
assert torch.all(torch.eq(p_sample, torch.reshape(p_samples, out_shape)))

In [3]:
# test rf.div_norm_pool
# Seed before drawing the input so a failing run can be reproduced, then
# re-seed so rf.div_norm_pool and the manual reference sample identically.
torch.random.manual_seed(0)
rf_u = torch.rand(1,10,16)
out_shape = (1,10,4,4)
n = 2.
sigma = 0.5
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.div_norm_pool(rf_u, out_shape, n=n, sigma=sigma)
# manually perform div_norm_pool:
#   probs = rf_u**n / (sigma**n + sum(rf_u**n over the last dim))
rf_u_n = torch.pow(rf_u, n)
sigma_n = torch.pow(torch.as_tensor(sigma, dtype=rf_u.dtype), n)
probs = torch.div(rf_u_n, sigma_n + torch.sum(rf_u_n, dim=-1, keepdim=True))
assert torch.all(torch.eq(h_mean, torch.reshape(probs, out_shape)))
# one multinomial draw per row (total_count defaults to 1) from softmax(probs)
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=torch.softmax(probs, -1)).sample()
assert torch.all(torch.eq(h_sample, torch.reshape(samples, out_shape)))
# pooled mean: the probabilities gated by the sampled (one-hot) position
p_probs = torch.reshape(torch.mul(probs, samples), out_shape)
assert torch.all(torch.eq(p_mean, p_probs))
# pooled sample is expected to equal the sampled positions themselves
p_samples = torch.reshape(samples, out_shape)
assert torch.all(torch.eq(p_sample, p_samples))

In [4]:
# test rf.prob_max_pool
# Seed before drawing the input so a failing run can be reproduced, then
# re-seed so rf.prob_max_pool and the manual reference sample identically.
torch.random.manual_seed(0)
rf_u = torch.rand(1,10,16)
out_shape = (1,10,4,4)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.prob_max_pool(rf_u, out_shape)
# manually perform prob_max_pool: append a zero logit to every row and take
# the softmax; the leading entries are the unit probabilities and the final
# entry is the probability that no unit is selected.
# Shapes are derived from rf_u rather than hard-coded as (1,10,...).
batch_shape = rf_u.shape[:-1]
rf_u_0 = torch.cat([rf_u, torch.zeros_like(rf_u[..., :1])], -1)
rf_u_0_softmax = torch.softmax(rf_u_0, -1)
probs = torch.reshape(rf_u_0_softmax[..., :-1], out_shape)
assert torch.all(torch.eq(h_mean, probs))
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=rf_u_0_softmax).sample()
# drop the "none selected" column before comparing
samples = torch.reshape(samples[..., :-1], out_shape)
assert torch.all(torch.eq(h_sample, samples))
# pooled mean: 1 - P(none selected), broadcast over the two trailing dims of
# out_shape and gated by the sampled positions
p_on = 1. - rf_u_0_softmax[..., -1]
p_probs = torch.mul(torch.reshape(p_on, (*batch_shape, 1, 1)), samples)
assert torch.all(torch.eq(p_mean, p_probs))
# TODO: p_sample is returned but not checked here, unlike the other pooling
# tests; add an assertion once its expected value for prob_max_pool is confirmed.

In [5]:
# test rf.stochastic_max_pool
# Seed before drawing the input so a failing run can be reproduced, then
# re-seed so rf.stochastic_max_pool and the manual reference sample identically.
torch.random.manual_seed(0)
rf_u = torch.rand(1,10,16)
out_shape = (1,10,4,4)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.stochastic_max_pool(rf_u, out_shape)
# manually perform stochastic_max_pool: softmax over the last dim
rf_u_softmax = torch.softmax(rf_u, -1)
probs = torch.reshape(rf_u_softmax, out_shape)
assert torch.all(torch.eq(h_mean, probs))
# one multinomial draw per row (total_count defaults to 1) from the softmax
torch.random.manual_seed(0)
samples = torch.distributions.Multinomial(probs=rf_u_softmax).sample()
samples = torch.reshape(samples, out_shape)
assert torch.all(torch.eq(h_sample, samples))
# pooled mean: the probabilities gated by the sampled (one-hot) position
p_probs = torch.mul(probs, samples)
assert torch.all(torch.eq(p_mean, p_probs))
# pooled sample is expected to equal the sampled positions themselves
p_samples = samples.clone()
assert torch.all(torch.eq(p_sample, p_samples))

In [6]:
# test rf.sum_pool
# Seed before drawing the input so a failing run can be reproduced, then
# re-seed so rf.sum_pool and the manual reference sample identically.
torch.random.manual_seed(0)
rf_u = torch.rand(1,10,16)
out_shape = (1,10,4,4)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.sum_pool(rf_u, out_shape)
# manually perform sum_pool: the mean output is the input itself
probs = torch.reshape(rf_u, out_shape)
assert torch.all(torch.eq(h_mean, probs))
torch.random.manual_seed(0)
# every position holds the sum over the last dim of rf_u ...
row_sums = torch.zeros_like(rf_u) + torch.sum(rf_u, dim=-1, keepdim=True)
row_sums = torch.reshape(row_sums, out_shape)
# ... and only the position drawn from softmax(rf_u) keeps it
sampled_pos = torch.reshape(torch.distributions.Multinomial(probs=torch.softmax(rf_u, -1)).sample(), out_shape)
samples = torch.mul(row_sums, sampled_pos)
assert torch.all(torch.eq(h_sample, samples))
# Build the pooled references from the independent reference values
# (probs, samples), not from the outputs under test (h_mean, h_sample).
p_probs = torch.mul(probs, sampled_pos)
assert torch.all(torch.eq(p_mean, p_probs))
p_samples = torch.mul(samples, sampled_pos)
assert torch.all(torch.eq(p_sample, p_samples))

In [7]:
# test rf.pool defaults
# Seed before drawing the input so a failing run can be reproduced, then
# re-seed so rf.pool and the manual reference sample identically.
torch.random.manual_seed(0)
u = torch.rand(1,10,16,16)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.pool(u)
# manually perform rf.pool with defaults; this reference assumes the defaults
# are 2x2 blocks pooled with rf.prob_max_pool -- confirm against rf.pool
block_size = (2, 2)
torch.random.manual_seed(0)
# (row, col) offsets inside a block in row-major order, so the position in
# this list is the flat index r * block_size[1] + c
offsets = [(r, c) for r in range(block_size[0]) for c in range(block_size[1])]
b = torch.cat(
    [u[:,:,r::block_size[0],c::block_size[1]].unsqueeze(-1) for r, c in offsets],
    -1,
)
tmp_probs, tmp_samples, tmp_p_probs, tmp_p_samples = rf.prob_max_pool(b, b.shape)
probs = torch.zeros_like(u)
samples = torch.zeros_like(u)
# one pooled value per block: b's shape without the trailing block dim
p_probs = torch.zeros_like(b[..., 0])
p_samples = torch.zeros_like(b[..., 0])
for idx, (r, c) in enumerate(offsets):
    probs[:,:,r::block_size[0],c::block_size[1]] = tmp_probs[..., idx]
    samples[:,:,r::block_size[0],c::block_size[1]] = tmp_samples[..., idx]
    # pooled outputs: element-wise max over the positions in each block
    p_probs = torch.max(p_probs, tmp_p_probs[..., idx])
    p_samples = torch.max(p_samples, tmp_p_samples[..., idx])
assert torch.all(torch.eq(h_mean, probs))
assert torch.all(torch.eq(h_sample, samples))
assert torch.all(torch.eq(p_mean, p_probs))
assert torch.all(torch.eq(p_sample, p_samples))

In [8]:
# test rf.pool top-down, div_norm, block_size, pool_args
# Seed before drawing the inputs so a failing run can be reproduced, then
# re-seed so rf.pool and the manual reference sample identically.
torch.random.manual_seed(0)
u = torch.rand(1,10,16,16)
t = torch.rand(1,10,4,4)
pool_type = 'div_norm'
block_size = (4,4)
pool_args = [1., 0.5]
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.pool(u, t, pool_type=pool_type, 
                                             block_size=block_size, pool_args=pool_args)
# manually perform rf.pool with top-down, div_norm, block_size, pool_args
# (row, col) offsets inside a block in row-major order, so the position in
# this list is the flat index r * block_size[1] + c (the original used
# r * block_size[0] + c, which is only correct for square blocks)
offsets = [(r, c) for r in range(block_size[0]) for c in range(block_size[1])]
# add the top-down input t to every position of each block; done out of place
# so the u that was passed to rf.pool is left unmodified
b = torch.cat(
    [(u[:,:,r::block_size[0],c::block_size[1]] + t).unsqueeze(-1) for r, c in offsets],
    -1,
)
# get tmp_probs, tmp_samples from div_norm_pool using pool_args (n, sigma)
torch.random.manual_seed(0)
tmp_probs, tmp_samples, tmp_p_probs, tmp_p_samples = rf.div_norm_pool(b, b.shape, *pool_args)
probs = torch.zeros_like(u)
samples = torch.zeros_like(u)
# get probs, samples, p_probs, p_samples
p_probs = torch.zeros_like(t)
p_samples = torch.zeros_like(t)
for idx, (r, c) in enumerate(offsets):
    probs[:,:,r::block_size[0],c::block_size[1]] = tmp_probs[..., idx]
    samples[:,:,r::block_size[0],c::block_size[1]] = tmp_samples[..., idx]
    # pooled outputs: element-wise max over the positions in each block
    p_probs = torch.max(p_probs, tmp_p_probs[..., idx])
    p_samples = torch.max(p_samples, tmp_p_samples[..., idx])
assert torch.all(torch.eq(h_mean, probs))
assert torch.all(torch.eq(h_sample, samples))
assert torch.all(torch.eq(p_mean, p_probs))
assert torch.all(torch.eq(p_sample, p_samples))

In [9]:
# test rf.pool kernels
# kernels are generated on a (16,16) grid, matching u's spatial size below
mu, sigma = lattice.init_uniform_lattice((8,8), 4, 4, 1.)
rfs = lattice.gaussian_kernel_lattice(mu, sigma, (16,16))
# Seed before drawing the input so a failing run can be reproduced, then
# re-seed so rf.pool and the manual reference sample identically.
torch.random.manual_seed(0)
u = torch.rand(1,10,16,16)
torch.random.manual_seed(0)
h_mean, h_sample, p_mean, p_sample = rf.pool(u, rfs=rfs, block_size=(1,1))
# manually perform pooling with gaussian rfs
torch.random.manual_seed(0)
# channel count and spatial size come from u instead of being hard-coded
n_channels = u.shape[1]
probs = torch.zeros_like(u)
samples = torch.zeros_like(u)
p_probs = torch.zeros_like(u)
p_samples = torch.zeros_like(u)
# one kernel per leading index, broadcast against u's batch and channel dims
rfs = rfs.reshape(-1, 1, 1, *u.shape[-2:])
g_u = torch.mul(u.unsqueeze(0), rfs)
# pixels whose kernel weight exceeds 1e-5 are treated as inside that kernel's
# field (presumably the same threshold rf.pool uses -- confirm)
rf_index = torch.gt(rfs, 1e-5)
rf_index = rf_index.repeat(1, 1, n_channels, 1, 1)
for i, rf_i in enumerate(rf_index):
    # masked values come out channel-major, so reshape to (1, channels, -1)
    # groups each channel's in-field pixels together; later kernels
    # overwrite overlapping positions written by earlier ones
    rf_u = g_u[i][rf_i]
    [
        probs[rf_i], samples[rf_i], 
        p_probs[rf_i], p_samples[rf_i]
    ] = rf.prob_max_pool(rf_u.reshape(1, n_channels, -1), rf_u.shape)
assert torch.all(torch.eq(h_mean, probs))
assert torch.all(torch.eq(h_sample, samples))
assert torch.all(torch.eq(p_mean, p_probs))
assert torch.all(torch.eq(p_sample, p_samples))