In [1]:
import torch
import torch.nn as nn

In [9]:
class local_planar_guidance(nn.Module):
    """Expand per-cell plane coefficients into a dense, upsampled map.

    Channels 0-3 of ``plane_eq`` are read as four coefficients
    ``(n1, n2, n3, n4)`` per spatial cell. Each cell is expanded into an
    ``upratio x upratio`` patch, and every pixel of the patch is evaluated as
    ``n4 / (n1 * u + n2 * v + n3)``. ``u`` (last axis) and ``v``
    (second-to-last axis) are the pixel's offsets from the patch centre,
    expressed as a fraction of the patch size.

    Args:
        upratio: Upsampling factor applied to both spatial axes. Integral
            values given as ``int`` or ``float`` (e.g. ``2`` or ``2.0``) are
            accepted.
    """

    def __init__(self, upratio):
        super(local_planar_guidance, self).__init__()
        ratio = int(upratio)
        # Pixel indices 0..ratio-1 inside one patch, shaped so that `u` varies
        # along the last axis and `v` along the second-to-last axis.
        self.u = torch.arange(ratio).reshape([1, 1, ratio]).float()
        self.v = torch.arange(ratio).reshape([1, ratio, 1]).float()
        # Stored as a float because it is used in floating-point arithmetic
        # when the offsets are normalised in forward().
        self.upratio = float(upratio)

    def forward(self, plane_eq, focal):
        """Evaluate the plane equation on the upsampled grid.

        Args:
            plane_eq: Tensor of shape ``(B, C, H, W)`` with ``C >= 4``; only
                channels 0-3 are used.
            focal: Accepted for interface compatibility; not used by this
                implementation.

        Returns:
            Tensor of shape ``(B, H * upratio, W * upratio)``. No guard is
            applied to the denominator, so values are unbounded (inf/nan
            possible) wherever ``n1 * u + n2 * v + n3`` is close to zero.
        """
        ratio = int(self.upratio)
        cells_h, cells_w = plane_eq.size(2), plane_eq.size(3)

        # Only the first four channels are ever read, so expand just those
        # instead of every channel of the input.
        coeffs = torch.repeat_interleave(plane_eq[:, :4], ratio, 2)
        coeffs = torch.repeat_interleave(coeffs, ratio, 3)
        n1 = coeffs[:, 0, :, :]
        n2 = coeffs[:, 1, :, :]
        n3 = coeffs[:, 2, :, :]
        n4 = coeffs[:, 3, :, :]

        # self.u / self.v are plain tensors created on the CPU (not buffers),
        # so module.to(device) does not move them; follow the input's device
        # here to avoid a device mismatch for non-CPU inputs.
        # The offsets are tiled along one axis only and broadcast against the
        # (B, H * ratio, W * ratio) coefficient maps.
        u = self.u.to(plane_eq.device).repeat(1, 1, cells_w)
        u = (u - (self.upratio - 1) * 0.5) / self.upratio

        v = self.v.to(plane_eq.device).repeat(1, cells_h, 1)
        v = (v - (self.upratio - 1) * 0.5) / self.upratio

        return n4 / (n1 * u + n2 * v + n3)

In [18]:
# Layer instance that expands every input cell into a 2 x 2 patch.
lpg = local_planar_guidance(upratio=2)

In [19]:
# Uniform-random stand-in of shape (3, 64, 56, 112) used to exercise the layer.
# NOTE(review): this name shadows the builtin `input`; it is kept because the
# next cell reads it.
input = torch.rand(3, 64, 56, 112)

In [20]:
# Run the layer; `focal` is required by the signature but forward() never reads it.
out = lpg(input, focal=1)

In [21]:
# Both spatial axes grow by the upsampling ratio and the channel axis is gone.
out.size()

torch.Size([3, 112, 224])

In [22]:
# Display the result tensor; the repr elides the middle of each axis with "...".
out

tensor([[[  0.2193,   0.1056,   3.1676,  ...,   0.8431,   0.9337,   0.7985],
         [  0.1284,   0.0788,   1.0557,  ...,   0.6134,   0.5277,   0.4817],
         [ -1.0997,   2.3664,  -3.0942,  ...,   1.4981,  -0.5578,   7.3888],
         ...,
         [  0.8632,   0.7793,   1.5923,  ...,   4.2318,   0.5301,   0.3695],
         [  4.8367,   1.6146,   0.5885,  ...,   1.2237,   1.1499,   0.8050],
         [  2.0622,   1.1142,   0.2089,  ...,   0.7335,   0.4999,   0.4214]],

        [[ -6.3920,   0.7919,   1.5745,  ...,   0.1941,   3.9621,   1.3609],
         [  2.3017,   0.5395,   1.5041,  ...,   0.0559,   1.6003,   0.9031],
         [  3.8978,   3.7995,   3.9643,  ...,   4.8488,   9.1098,   1.6297],
         ...,
         [  1.8510,   0.8313,   1.2431,  ...,   0.2215,   1.2058,   0.9787],
         [  5.5132,   4.6277,   2.4093,  ...,   3.4904,  -1.6510,   7.1536],
         [  2.8292,   2.5762,   1.1768,  ...,   3.0055,  12.4282,   1.2107]],

        [[ -3.2323,  47.8861, -11.9774,  ...

In [23]:
# With inputs drawn from [0, 1) and offsets of +/-0.25, the denominator
# n1 * u + n2 * v + n3 can be negative or close to zero, so this mean is
# dominated by a few very large-magnitude values.
out.mean()

tensor(-1.4884)