In [1]:
import torch
from torch import Tensor
import torch.nn.functional as F
from torch import nn

In [6]:
def batch_from_list(tensors):
    pos_list = [t.shape[0] for t in tensors ]
    batch = torch.cat(tensors, dim=0)
    return batch, pos_list

In [99]:
def preprocess_poses(poses, features_dim):
    '''T[batch_size, K, 3] -> T[batch_size, 2K, D, D]'''
    batch_size, K, _ = poses.shape
    features_poses = []
    visibilities = []
    for keypoint in poses:
        x, y = keypoint[:,0], keypoint[:,1]
        visibilities.append(keypoint[:,2].unsqueeze(0))
        features = []
        for x_val, y_val in zip(x, y):
            features.append(torch.full((1,1,features_dim,features_dim), x_val))
            features.append(torch.full((1,1,features_dim,features_dim), y_val))
        features = torch.cat(features, dim=1)
        features_poses.append(features)
    features_poses = torch.cat(features_poses, dim=0)
    visibilities = torch.cat(visibilities, dim=0).unsqueeze(-1)
    return features_poses, visibilities
    
    
#     for k in range(poses.shape[1]):
#         x = torch.cat([
#             torch.full((batch_size, 1, features_dim, features_dim), v) 
#             for v in poses[:,k,0]
#         ], dim=1)
#         y = torch.cat([
#             torch.full((batch_size, 1, features_dim, features_dim), v) 
#             for v in poses[:,k,1]
#         ], dim=1)
#         print(x.shape, y.shape)

In [128]:
torch.ones(2,5)/(2*torch.ones(2,1))

tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]])

In [100]:
features = [torch.randn(10, 256, 7, 7) for _ in range(15)]
batch_features, pos_list = batch_from_list(features)
batch_features.shape

torch.Size([150, 256, 7, 7])

In [121]:
list_features = list(batch_features.split(pos_list, 0))
len(list_features), list_features[0].shape

(15, torch.Size([10, 256, 7, 7]))

In [101]:
batch_poses = torch.randn(150, 14, 3)

In [102]:
f, v = preprocess_poses(batch_poses,7)
v.shape

torch.Size([150, 14, 1])

In [14]:
batch_size, K, _ = batch_poses.shape
batch_size

150

In [114]:
f.shape
p = F.avg_pool2d(f, 7)
px = torch.cat([p[:,i] for i in range(0, p.shape[1], 2)], dim=1)
py = torch.cat([p[:,i] for i in range(1, p.shape[1]+1, 2)], dim=1)
p = torch.cat([px,py], dim=-1)

In [115]:
px.shape, py.shape, p.shape

(torch.Size([150, 14, 1]), torch.Size([150, 14, 1]), torch.Size([150, 14, 2]))

In [5]:
pooled = torch.mean(poses_features, dim=[3])
pooled.size()

torch.Size([256, 28, 7])

In [132]:
init_poses[0]

tensor([[0., 0., 1.],
        [0., 1., 1.],
        [1., 0., 1.],
        [1., 1., 1.]])

In [131]:
init_poses = torch.Tensor([[[0,0,1],
                            [0,1,1],
                            [1,0,1],
                            [1,1,1]]])
target_poses = 0.5*init_poses

In [144]:
v = init_poses[:,:,2].unsqueeze(-1)
num_iterations = 3
for it in range(1, num_iterations+1):
    w = it/(num_iterations+1)
    tmp = torch.lerp(input=init_poses[:,:,0:2], end=target_poses[:,:,0:2], weight=w)
    print(tmp.shape, v.shape)
    tmp = torch.cat([tmp, v], dim=-1)
    print(tmp)

torch.Size([1, 4, 2]) torch.Size([1, 4, 1])
tensor([[[0.0000, 0.0000, 1.0000],
         [0.0000, 0.8750, 1.0000],
         [0.8750, 0.0000, 1.0000],
         [0.8750, 0.8750, 1.0000]]])
torch.Size([1, 4, 2]) torch.Size([1, 4, 1])
tensor([[[0.0000, 0.0000, 1.0000],
         [0.0000, 0.7500, 1.0000],
         [0.7500, 0.0000, 1.0000],
         [0.7500, 0.7500, 1.0000]]])
torch.Size([1, 4, 2]) torch.Size([1, 4, 1])
tensor([[[0.0000, 0.0000, 1.0000],
         [0.0000, 0.6250, 1.0000],
         [0.6250, 0.0000, 1.0000],
         [0.6250, 0.6250, 1.0000]]])


In [148]:
bbx = [45, 2, 85, 85]
area = abs(bbx[2]-bbx[0])*abs(bbx[3]-bbx[1])
area

3320

In [201]:
input_poses = torch.cat([init_poses, init_poses], dim=0)
target_poses = torch.cat([target_poses,target_poses], dim=0)
areas = torch.randn(input_poses.shape[0], 1)
weights = torch.randn(input_poses.shape[1])

In [202]:
areas.shape, weights.shape

(torch.Size([2, 1]), torch.Size([4]))

In [203]:
input_x, input_y = input_poses[:,:,0], input_poses[:,:,1]
target_x, target_y = target_poses[:,:,0], target_poses[:,:,1]
visibilities = target_poses[:,:,2].to(torch.uint8)
offsets = (input_x-target_x)**2 + (input_y-target_y)**2
offsets /= areas
for i in range(offsets.shape[0]):
    print(offsets[i].shape)
    offsets[i] = offsets[i]/(2*weights)
    print(offsets[i])
oks = torch.exp(exponent)

torch.Size([4])
tensor([ 0.0000, -0.6091,  0.4357, -0.5313])
torch.Size([4])
tensor([ 0.0000, -0.3583,  0.2563, -0.3126])


In [204]:
oks

tensor([[1.0000, 1.4055, 1.4055, 1.9754],
        [1.0000, 1.0690, 1.0690, 1.1428],
        [1.0000, 0.9796, 0.9796, 0.9597],
        [1.0000, 0.9640, 0.9640, 0.9293]])

In [205]:
oks.shape

torch.Size([4, 4])

In [261]:
f = lambda x, **y: x

In [265]:
f

<function __main__.<lambda>(x, **y)>

In [264]:
f(oks)

tensor([[1.0000, 1.4055, 1.4055, 1.9754],
        [1.0000, 1.0690, 1.0690, 1.1428],
        [1.0000, 0.9796, 0.9796, 0.9597],
        [1.0000, 0.9640, 0.9640, 0.9293]])

In [284]:
class Oks(nn.Module):
    def __init__(self, weights: Tensor, normalize: bool):
        super().__init__()
        self.weights = weights
        self.normalize = normalize
        
    def forward(
        self, 
        input_poses: Tensor, 
        target_poses: Tensor, 
        areas: Tensor=None
    ):
        dists = ((input_poses[:,:,0:2] - target_poses[:,:,0:2])**2).sum(dim=-1)
        visible = target_poses[:,:,2].to(torch.bool)
        if self.normalize:
            assert areas is not None, f'`areas` is required when `normalize=True`'
            if len(areas.shape) < 2: areas = areas.unsqueeze(-1)
            dists /= areas
        if len(self.weights.shape) < 2: self.weights = self.weights.unsqueeze(0)
        dists /= 2*self.weights
        dists[~visible] = 0.
        oks = torch.exp(-dists).mean(1)
        return oks

In [285]:
K = 14
N = 100

input_poses = torch.randn(N, K, 3)
target_poses = torch.randn(N, K, 3)
areas = torch.randn(N)
weights = torch.full((K,), 1/K)

In [286]:
oks_fn = Oks(weights, True)

In [287]:
oks_fn(input_poses, target_poses, areas)

tensor([       inf, 6.8006e+15, 1.2298e-01, 1.3482e-01, 2.8744e-02,        inf,
        2.2375e-02,        inf, 9.8292e-02, 8.2052e-02, 2.6378e-02, 2.4942e+33,
        2.5937e+31, 4.2028e-05, 2.2975e+28,        inf, 1.0559e-01,        inf,
               inf,        inf,        inf,        inf, 6.4847e-04,        inf,
        4.2245e-02, 5.8403e+16,        inf, 2.4163e-25, 2.5934e+21, 6.2382e-03,
        4.0754e-05,        inf, 2.7037e-03, 2.2665e+27,        inf, 0.0000e+00,
        9.5822e-04,        inf,        inf, 2.3654e-03, 2.0441e-27, 1.3375e-05,
        1.2480e+21, 7.7099e-05, 1.1209e-02, 2.2767e-06, 1.0175e+22, 3.1317e-04,
        3.1393e-02,        inf, 1.0848e-02, 1.6526e-02, 1.4833e-03, 3.3481e-06,
        1.5458e-01, 7.2229e-04,        inf, 9.5561e+21,        inf, 9.5622e-03,
        6.7893e+29,        inf, 1.5659e-02, 6.8765e-08, 6.8272e-03,        inf,
        1.2151e+24, 1.2326e-08, 7.7308e-08,        inf, 1.7343e-04,        inf,
        8.0377e-06, 1.4623e+22, 1.9225e-

In [266]:
f((2, 2))

(2, 2)