In [1]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

import matplotlib.pyplot as plt

# Spatial parameters sampling

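Takes as input a batch of logit matrices of shape (B, L, L), where L is the linear size of the (square) map; an already flattened batch of shape (B, L^2) also works, because the input is reshaped internally.
Returns the integer coordinates [[x_1, y_1], ..., [x_B, y_B]] (or [y, x]) of the sampled pixels, together with their log-probabilities (shape (B,)) and the probabilities of all the pixels (shape (B, L^2)).
Note: the softmax is performed inside the function, so the input must be raw logits.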

In [26]:
def old_version(x, L, x_first=True, debug=True):
    """Sample one pixel per batch element through an explicit (x, y) lookup table.

    Parameters
    ----------
    x : torch.Tensor
        Logits with a leading batch dimension; everything after the batch
        dimension is flattened, so (B, L, L) and (B, L*L) are both accepted.
    L : int
        Linear size of the (square) map.
    x_first : bool
        Accepted for signature compatibility but never read: the coordinates
        are always returned as [x, y].
    debug : bool
        When true, print the shapes of the intermediate tensors.

    Returns
    -------
    arg_lst : list of [x, y] pairs (NumPy integer scalars), one per batch element.
    log_prob : torch.Tensor of shape (B,), log-probability of each sampled pixel.
    probs : torch.Tensor of shape (B, L*L), softmax probabilities of all pixels.
    """
    side = L
    batch = x.shape[0]
    if debug:
        print("x.shape: ", x.shape)
    flat_logits = x.reshape((batch, -1))
    if debug:
        print("x.shape: ", flat_logits.shape)

    # The softmax is applied here, in log space.
    log_probs = F.log_softmax(flat_logits, dim=-1)
    probs = log_probs.exp()
    if debug:
        print("log_probs.shape: ", log_probs.shape)
        print("log_probs.shape (reshaped): ", log_probs.view(-1, side, side).shape)

    # Square map assumed. col_grid[b, i, j] == j, and its transpose holds i.
    col_grid = torch.arange(side).unsqueeze(0).repeat(batch, side, 1)
    if debug:
        print("xx.shape: ", col_grid.shape)
    row_grid = col_grid.transpose(1, 2)

    # Entry [b, i, j] of the table is (j, i), i.e. (x, y) of that pixel.
    coord_table = torch.stack([col_grid, row_grid], dim=-1)
    if debug:
        print("args.shape (before reshaping): ", coord_table.shape)
    coord_table = coord_table.reshape(batch, -1, 2)
    if debug:
        print("args.shape (after reshaping): ", coord_table.shape)

    batch_idx = torch.arange(batch)
    flat_choice = Categorical(probs).sample()
    # Coordinates of the sampled pixels, as a (B, 2) NumPy array of (x, y).
    picked = coord_table[batch_idx, flat_choice].detach().numpy()
    # The table is already in (x, y) order, so no swap is needed here.
    arg_lst = [list(pair) for pair in picked]

    log_prob_maps = log_probs.reshape(batch, side, side)
    # Correct lookup: the maps are indexed [batch, row (y), column (x)], and
    # picked holds (x, y), so column 1 goes first. Indexing with
    # (picked[:, 0], picked[:, 1]) would read the transposed pixel and is wrong.
    return arg_lst, log_prob_maps[batch_idx, picked[:, 1], picked[:, 0]], probs

In [22]:
def working_version(x, L, x_first=True, debug=True):
    """Sample one pixel per batch element through an explicit (y, x) lookup table.

    The table stores (row, column) pairs, which matches the [batch, y, x]
    layout of the log-probability maps; the pairs are swapped to [x, y] only
    for the returned coordinate list.

    Parameters
    ----------
    x : torch.Tensor
        Logits with a leading batch dimension; everything after the batch
        dimension is flattened, so (B, L, L) and (B, L*L) are both accepted.
    L : int
        Linear size of the (square) map.
    x_first : bool
        Accepted for signature compatibility but never read: the coordinates
        are always returned as [x, y].
    debug : bool
        When true, print the shapes of the intermediate tensors.

    Returns
    -------
    arg_lst : list of [x, y] pairs (NumPy integer scalars), one per batch element.
    log_prob : torch.Tensor of shape (B,), log-probability of each sampled pixel.
    probs : torch.Tensor of shape (B, L*L), softmax probabilities of all pixels.
    """
    side = L
    batch = x.shape[0]
    if debug:
        print("x.shape: ", x.shape)
    flat_logits = x.reshape((batch, -1))
    if debug:
        print("x.shape: ", flat_logits.shape)

    # The softmax is applied here, in log space.
    log_probs = F.log_softmax(flat_logits, dim=-1)
    probs = log_probs.exp()
    if debug:
        print("log_probs.shape: ", log_probs.shape)
        print("log_probs.shape (reshaped): ", log_probs.view(-1, side, side).shape)

    # Square map assumed. col_grid[b, i, j] == j, and its transpose holds i.
    col_grid = torch.arange(side).unsqueeze(0).repeat(batch, side, 1)
    if debug:
        print("xx.shape: ", col_grid.shape)
    row_grid = col_grid.transpose(1, 2)

    # Entry [b, i, j] of the table is (i, j), i.e. (y, x) of that pixel.
    coord_table = torch.stack([row_grid, col_grid], dim=-1)
    if debug:
        print("args.shape (before reshaping): ", coord_table.shape)
    coord_table = coord_table.reshape(batch, -1, 2)
    if debug:
        print("args.shape (after reshaping): ", coord_table.shape)

    batch_idx = torch.arange(batch)
    flat_choice = Categorical(probs).sample()
    # Coordinates of the sampled pixels, as a (B, 2) NumPy array of (y, x).
    picked = coord_table[batch_idx, flat_choice].detach().numpy()
    # Swap each (y, x) pair to [x, y] for the caller.
    arg_lst = [[col, row] for row, col in picked]

    log_prob_maps = log_probs.reshape(batch, side, side)
    # picked is (y, x), which is exactly the [batch, row, column] index order.
    return arg_lst, log_prob_maps[batch_idx, picked[:, 0], picked[:, 1]], probs

In [32]:
def unravel_index(index, shape):
    """Convert a flat row-major index into one coordinate per axis of ``shape``.

    ``index`` may be anything that supports ``%`` and ``//`` (a Python int or a
    tensor of indices); the result is a tuple with ``len(shape)`` entries,
    ordered like ``shape``. No range check is performed, so an index past the
    end wraps around on the leading axis.
    """
    n_axes = len(shape)
    coords = [None] * n_axes
    remainder = index
    # Peel off the fastest-varying (last) axis first.
    for axis in range(n_axes - 1, -1, -1):
        extent = shape[axis]
        coords[axis] = remainder % extent
        remainder = remainder // extent
    return tuple(coords)

In [52]:
def clean_version(x, L, x_first=True, debug=True):
    """Sample one pixel per batch element directly from the flat softmax.

    Parameters
    ----------
    x : torch.Tensor
        Logits with a leading batch dimension; everything after the batch
        dimension is flattened, so (B, L, L) and (B, L*L) are both accepted.
    L : int
        Linear size of the (square) map.
    x_first : bool
        If true (default) each coordinate pair is [x, y] (column, row);
        otherwise it is [y, x].
    debug : bool
        When true, print the sampled flat indices and their (y, x) coordinates.

    Returns
    -------
    arg_lst : list of [int, int] pairs, one per batch element.
    log_prob : torch.Tensor of shape (B,), log-probability of each sampled pixel.
    probs : torch.Tensor of shape (B, L*L), softmax probabilities of all pixels.
    """
    size = L
    batch_size = x.shape[0]
    flat_logits = x.reshape((batch_size, -1))
    # The softmax is applied here, in log space.
    log_probs = F.log_softmax(flat_logits, dim=-1)
    probs = torch.exp(log_probs)
    index = Categorical(probs).sample()

    # Row-major layout: flat index == row * size + col.
    rows = index // size
    cols = index % size
    if debug:
        print("index: ", index)
        print("y, x: ", rows, cols)

    ordered = (cols, rows) if x_first else (rows, cols)
    arg_lst = torch.stack(ordered, dim=1).tolist()

    # The log-prob is read at the flat index, so it cannot be transposed by
    # a coordinate mix-up. The batch index lives on the same device as index.
    batch_idx = torch.arange(batch_size, device=index.device)
    log_prob = log_probs[batch_idx, index]
    return arg_lst, log_prob, probs

## Test1

In [54]:
# Test 1 fixture: two copies of the same 4x4 map. Only the pixels at
# (row 0, col 1) and (row 0, col 2) keep a finite logit; every zero entry is
# masked to -inf so it can never be sampled.
# NOTE(review): B is not used below and the batch actually holds two maps;
# confirm whether B was meant to be 2.
torch.manual_seed(1)
B, L = 1, 4
value = [
    [0., 1., 1., 0.],
    [0., 0., 0., 0.],
    [0., 0., 0., 0.],
    [0., 0., 0., 0.],
]
logits = torch.tensor([value, value])
mask = logits.eq(0)
logits = logits.masked_fill(mask, float('-inf')).flatten(start_dim=1)
print(logits)

# In x-first (x, y) order every sample should come out as (1, 0) or (2, 0),
# each with probability 0.5.

tensor([[-inf, 1., 1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, 1., 1., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]])


In [55]:
# Run the lookup-table implementation that stores (x, y) pairs on the masked
# logits. With the default debug=True it also prints the intermediate shapes.
arg_lst, log_prob, probs = old_version(logits, L)
print("arg_lst: ", arg_lst)
print("log_prob: ", log_prob)
print("probs: ", probs)

x.shape:  torch.Size([2, 16])
x.shape:  torch.Size([2, 16])
log_probs.shape:  torch.Size([2, 16])
log_probs.shape (reshaped):  torch.Size([2, 4, 4])
xx.shape:  torch.Size([2, 4, 4])
args.shape (before reshaping):  torch.Size([2, 4, 4, 2])
args.shape (after reshaping):  torch.Size([2, 16, 2])
arg_lst:  [[1, 0], [1, 0]]
log_prob:  tensor([-0.6931, -0.6931])
probs:  tensor([[0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])


In [56]:
# Run the lookup-table implementation that stores (y, x) pairs on the same
# logits. The returned coordinates are still [x, y], so the output format
# matches old_version.
arg_lst, log_prob, probs = working_version(logits, L)
print("arg_lst: ", arg_lst)
print("log_prob: ", log_prob)
print("probs: ", probs)

x.shape:  torch.Size([2, 16])
x.shape:  torch.Size([2, 16])
log_probs.shape:  torch.Size([2, 16])
log_probs.shape (reshaped):  torch.Size([2, 4, 4])
xx.shape:  torch.Size([2, 4, 4])
args.shape (before reshaping):  torch.Size([2, 4, 4, 2])
args.shape (after reshaping):  torch.Size([2, 16, 2])
arg_lst:  [[1, 0], [1, 0]]
log_prob:  tensor([-0.6931, -0.6931])
probs:  tensor([[0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])


In [57]:
# Left as a bare trailing expression so the notebook displays the returned
# (arg_lst, log_prob, probs) tuple.
clean_version(logits, L)

index:  tensor([2, 1])
y, x:  tensor([0, 0]) tensor([2, 1])


([[2, 0], [1, 0]],
 tensor([-0.6931, -0.6931]),
 tensor([[0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]))