In [6]:
import torch

def normalize_points(grid: torch.Tensor, K: torch.Tensor, scales: torch.Tensor):
    ''' 
    :param grid: coordinates (u, v), B x 2 x H x W
    :param K: intrinsics, B x 3 x 3
    :param scales: parameters of resizing of original image, B x 2
    '''    
    fx, fy, ox, oy = K[:, 0, 0], K[:, 1, 1], K[:, 0, 2], K[:, 1, 2]
    sx, sy = scales[ :,0], scales[:, 1]
    fx = fx * sx
    fy = fy * sy
    ox = ox * sx
    oy = ox * sy
    principal_point = torch.cat([ox[..., None], oy[..., None]], -1)[..., None, None]
    focal_length = torch.cat([fx[..., None], fy[..., None]], -1)[..., None, None]
    return (grid - principal_point) / focal_length

In [7]:
grid = torch.rand(3, 2, 40, 50)
K = torch.rand(3,3, 3)
scales = torch.rand(3, 2)


normalize_points(grid, K, scales)

tensor([[[[ 2.8709e+01,  1.9291e+01,  2.8918e+01,  ...,  5.4418e+00,
            1.2518e+01,  2.8570e+01],
          [ 7.4486e+00,  4.6796e+01,  3.9531e+01,  ...,  4.5706e+01,
            4.5604e+01,  4.5450e+01],
          [-2.8348e-01,  1.6179e+00,  4.3664e+01,  ...,  2.1528e+01,
            2.9537e+01,  3.0134e+01],
          ...,
          [ 2.2850e+01,  1.0731e+01,  8.3201e+00,  ...,  3.7056e+01,
            4.0439e+01,  1.5365e+01],
          [ 3.5669e+01,  4.3325e+01,  2.7519e+01,  ...,  3.7531e+01,
            4.2100e+01,  3.5181e+01],
          [ 5.0720e+01,  3.6655e+00,  2.6669e+01,  ...,  1.9353e+01,
            3.4374e+00,  5.0985e+01]],

         [[ 1.2045e-01,  1.0658e+00,  1.6406e-01,  ...,  1.2238e+00,
            3.6447e-01,  1.0487e+00],
          [ 7.7442e-01,  5.0132e-01,  3.8012e-01,  ...,  5.0854e-01,
            6.9147e-01,  8.1739e-01],
          [ 5.3171e-01,  7.8010e-01,  1.0373e-01,  ...,  3.4662e-01,
            8.2922e-03,  3.8844e-01],
          ...,
     