In [12]:
import torch

In [29]:
# torch.unique_consecutive only collapses *adjacent* duplicates. The two 2s here
# are not neighbours, so nothing is removed and the tensor comes back unchanged.
indices = torch.tensor([2, 0, 1, 2])
unique_indices = indices.unique_consecutive()
print(unique_indices)

tensor([2, 0, 1, 2])


In [4]:
import torch
from torch.nn import functional as F

def _row_normalized_gram(feats: torch.Tensor, num_keypoints) -> torch.Tensor:
    """Return the (B, N, N) keypoint Gram matrix of `feats`, L2-normalized along each row."""
    channels = feats.shape[1]
    if num_keypoints is None:
        # (B, C, *K) -> (B, C, N): keep batch and channel dims, flatten the rest.
        flat = feats.reshape(feats.shape[0], channels, -1)
    else:
        # Explicit N: same grouping as the original hard-coded `view(-1, C, N)`.
        flat = feats.reshape(-1, channels, num_keypoints)
    gram = flat.transpose(1, 2) @ flat  # (B, N, N) pairwise keypoint dot products
    return F.normalize(gram, dim=2)


def get_inter_keypoint_loss(img_feats_kd: torch.Tensor,
                            pts_feats_kd: torch.Tensor,
                            num_keypoints: "int | None" = None) -> torch.Tensor:
    """Calculate the inter-keypoint similarity (relational) distillation loss.

    The loss guides the student keypoint features to mimic the relationships
    between the teacher's N keypoints. It does not force the features to match
    directly. For each input it builds the N x N matrix of dot products between
    keypoint feature vectors and L2-normalizes every row. It then compares the two
    matrices. Only N x N relations are compared, so the two inputs may have
    different channel counts.
    NOTE(review): presumably `img_feats_kd` is the student and `pts_feats_kd`
    the teacher. No gradient is stopped here; confirm against the caller.

    Args:
        img_feats_kd: features of shape (B, C_img, *K), where prod(K) == N.
        pts_feats_kd: features of shape (B, C_pts, *K) with the same B and N.
        num_keypoints: optional explicit N. When given, each input is reshaped to
            (-1, C, num_keypoints), exactly like the original hard-coded N = 10.
            When None (the default), N is inferred per input from its trailing dims.

    Returns:
        Scalar tensor: squared differences of the normalized relation matrices,
        summed over each row and averaged over batch and rows.

    Raises:
        ValueError: if the two inputs produce relation matrices of different shape
            (different batch size or keypoint count).
    """
    img_gram = _row_normalized_gram(img_feats_kd, num_keypoints)
    pts_gram = _row_normalized_gram(pts_feats_kd, num_keypoints)
    if img_gram.shape != pts_gram.shape:
        raise ValueError(
            "keypoint relation matrices differ in shape: "
            f"{tuple(img_gram.shape)} vs {tuple(pts_gram.shape)}")

    per_entry = F.mse_loss(img_gram, pts_gram, reduction='none')  # (B, N, N)
    return per_entry.sum(-1).mean()


# Smoke test on random features. Seed first so the printed loss is reproducible.
torch.manual_seed(0)

C_img = 8   # channel count of img_feats_kd
C_pts = 12  # channel count of pts_feats_kd (may differ from C_img)
N = 10      # number of keypoints

img_feats_kd = torch.randn(1, C_img, N)
pts_feats_kd = torch.randn(1, C_pts, N)

loss_inter_keypoint = get_inter_keypoint_loss(img_feats_kd, pts_feats_kd)
# Mean of summed squared differences: must be a non-negative scalar.
assert loss_inter_keypoint.ndim == 0 and loss_inter_keypoint.item() >= 0
print(loss_inter_keypoint)


## `torch_scatter`: `scatter_mean` and index ordering

In [3]:
from torch_scatter import scatter_mean, scatter_add
import torch


# Reproducible random integer labels in {0, 1, 2}, one row per example.
torch.manual_seed(0)
x = torch.randint(3, (3, 4))
print(x, "x")
# Named `x_sorted` (not `sorted`) so the Python builtin isn't shadowed in the kernel.
x_sorted, indices = torch.sort(x, -1)
print(x_sorted, "sorted")
print(indices, "indices")

data = torch.arange(12).reshape(3, 4).float()
print(data, 'data')

# Row-wise mean of the data values that share each label in x (empty groups -> 0).
output = scatter_mean(data, x, dim=-1)
print(output, 'scatter_mean output')

# Apply the permutation that sorted x to data as well, so (value, label) pairs stay together.
sorted_data = torch.gather(data, -1, indices)
print(sorted_data)

# Counter-example: sorted labels with UNsorted data pair values with the wrong
# labels, so this result generally differs from `output`.
misaligned_output = scatter_mean(data, x_sorted, dim=-1)
print(misaligned_output)

# Permuting data and labels together leaves every group's mean unchanged.
sorted_output = scatter_mean(sorted_data, x_sorted, dim=-1)
print(sorted_output, "sorted_output")
assert torch.allclose(sorted_output, output)




tensor([[0, 2, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 0, 2]]) x
tensor([[0, 0, 0, 2],
        [0, 0, 0, 1],
        [0, 1, 1, 2]]) sorted
tensor([[0, 2, 3, 1],
        [1, 2, 3, 0],
        [2, 0, 1, 3]]) indices
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]]) data
tensor([[ 1.6667,  0.0000,  1.0000],
        [ 6.0000,  4.0000,  0.0000],
        [10.0000,  8.5000, 11.0000]]) scatter_mean output
tensor([[ 0.,  2.,  3.,  1.],
        [ 5.,  6.,  7.,  4.],
        [10.,  8.,  9., 11.]])
tensor([[ 1.0000,  0.0000,  3.0000],
        [ 5.0000,  7.0000,  0.0000],
        [ 8.0000,  9.5000, 11.0000]])
tensor([[ 1.6667,  0.0000,  1.0000],
        [ 6.0000,  4.0000,  0.0000],
        [10.0000,  8.5000, 11.0000]]) sorted_output
