In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [24]:
import math

In [19]:
class Reduction(nn.Module):
    """Stack of 1x1 convolutions that repeatedly halves the channel count.

    The number of stages is ``int(log2(input_filters)) - 2``. Every stage except
    the last is a 1x1 ``nn.Conv2d`` followed by the module-level
    ``activation_fn``. The last stage is a bare 1x1 convolution, or, when
    ``is_final`` is true, a 1x1 convolution down to a single channel followed
    by ``nn.Sigmoid``.

    When ``input_filters`` is below 8 the stage count is zero or negative, no
    layers are added, and ``forward`` returns its input unchanged.

    The module-level names ``ENABLE_BIAS`` and ``activation_fn`` are read while
    constructing an instance that has at least one stage.

    Args:
        scale: Only used to build the sub-module names
            (``"1x1_reduc_<scale>_<i>"``).
        input_filters: Channel count of the tensor passed to ``forward``.
        is_final: If true, the last stage outputs one channel through a sigmoid.
    """

    def __init__(self, scale, input_filters, is_final=False):
        super(Reduction, self).__init__()
        # math.log2 is exact for powers of two; math.log(x, 2) is documented as
        # usually less accurate, and any error right at an integer boundary
        # would change the stage count once truncated by int().
        reduction_count = int(math.log2(input_filters)) - 2
        self.reductions = torch.nn.Sequential()
        for i in range(reduction_count):
            in_channels = int(input_filters / 2 ** i)
            is_last = i == reduction_count - 1
            if is_last and is_final:
                # Final head: collapse to one channel and squash into (0, 1).
                layers = [nn.Conv2d(in_channels, 1, 1, 1, 0, bias=ENABLE_BIAS), nn.Sigmoid()]
            else:
                out_channels = int(input_filters / 2 ** (i + 1))
                layers = [nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=ENABLE_BIAS)]
                if not is_last:
                    # Only intermediate stages are followed by the activation.
                    layers.append(activation_fn)
            # Each stage stays wrapped in its own nn.Sequential so parameter
            # names keep the form "reductions.1x1_reduc_<scale>_<i>.0.weight".
            self.reductions.add_module("1x1_reduc_%d_%d" % (scale, i), nn.Sequential(*layers))

    def forward(self, ip):
        """Apply every reduction stage to ``ip`` in order and return the result."""
        return self.reductions(ip)

In [36]:
class LPGLayer(nn.Module):
    """Expand per-cell plane coefficients into a dense map at ``scale``x resolution.

    Each spatial cell of the input carries four coefficients
    ``(n1, n2, n3, n4)`` in channels 0-3. Every cell is upsampled to a
    ``scale`` x ``scale`` patch, and each patch pixel receives
    ``n4 / (n1 * u + n2 * v + n3)``. Here ``u`` (along the last axis) and ``v``
    (along the second-to-last axis) are the pixel's index inside the patch,
    shifted by ``(scale - 1) / 2`` and divided by ``scale``.

    Args:
        scale: Upsampling factor per spatial axis. It is truncated with
            ``int()`` wherever a size or repeat count is needed.
    """

    def __init__(self, scale):
        super(LPGLayer, self).__init__()
        self.scale = scale
        patch = int(scale)
        # Kept as plain tensor attributes (not buffers) so state_dict is
        # unchanged. Module.to()/cuda() therefore does not move them, and
        # forward() places them on the input's device instead.
        self.u = torch.arange(patch).reshape([1, 1, patch]).float()
        self.v = torch.arange(patch).reshape([1, patch, 1]).float()

    def forward(self, plane_eq):
        """Evaluate the plane equation on the upsampled pixel grid.

        Args:
            plane_eq: 4-D tensor indexed as ``[batch, channel, row, col]`` with
                at least four channels; only channels 0-3 are read.

        Returns:
            Tensor of shape ``(batch, 1, rows * scale, cols * scale)``.
        """
        patch = int(self.scale)
        batch, rows, cols = plane_eq.size(0), plane_eq.size(2), plane_eq.size(3)

        # Copy each cell's coefficients to every pixel of its patch.
        expanded = torch.repeat_interleave(plane_eq, patch, 2)
        expanded = torch.repeat_interleave(expanded, patch, 3)
        n1, n2, n3, n4 = (expanded[:, k, :, :] for k in range(4))

        # Move the index grids to the input's device; without this a module on
        # a non-CPU device fails with a device mismatch in the arithmetic below.
        device = plane_eq.device

        # Column offset inside each patch, tiled over the whole output.
        u = self.u.to(device).repeat(batch, rows * patch, cols)
        u = (u - (self.scale - 1) * 0.5) / self.scale

        # Row offset inside each patch, tiled over the whole output.
        v = self.v.to(device).repeat(batch, rows, cols * patch)
        v = (v - (self.scale - 1) * 0.5) / self.scale

        d = n4 / (n1 * u + n2 * v + n3)
        return d.unsqueeze(1)

In [63]:
class LPGBlock(nn.Module):
    """Reduce features, predict one plane per cell and expand it with ``LPGLayer``.

    Pipeline: ``Reduction`` -> 1x1 convolution to three channels -> sigmoid
    scaling into ``theta`` in (0, pi/6), ``phi`` in (0, 2*pi) and ``dist`` in
    (0, MAX_DEPTH) -> unit normal
    ``(sin(theta)cos(phi), sin(theta)sin(phi), cos(theta))`` -> ``LPGLayer``.

    The module-level name ``MAX_DEPTH`` is read on every ``forward`` call.

    Args:
        scale: Passed to both ``Reduction`` (for naming) and ``LPGLayer``.
        input_filters: Channel count of the tensor passed to ``forward``.
    """

    def __init__(self, scale, input_filters=128):
        super(LPGBlock, self).__init__()
        self.scale = scale

        self.reduction = Reduction(scale, input_filters)

        # The projection must accept whatever channel count the reduction
        # emits: the out_channels of its last convolution, or input_filters
        # itself when the reduction contains no convolution. A fixed value of 5
        # does not match the 4 channels produced for the default 128 filters.
        reduced_channels = int(input_filters)
        for module in self.reduction.modules():
            if isinstance(module, nn.Conv2d):
                reduced_channels = module.out_channels

        self.conv = nn.Conv2d(reduced_channels, 3, 1, 1, 0)
        self.LPGLayer = LPGLayer(scale)

    def forward(self, input):
        """Return the map produced by ``LPGLayer`` for the predicted planes.

        Args:
            input: Feature tensor with ``input_filters`` channels.

        Returns:
            The output of ``LPGLayer`` applied to a four-channel tensor holding
            the unit normal (channels 0-2) and ``dist`` (channel 3).
        """
        features = self.conv(self.reduction(input))

        theta = features[:, 0, :, :].sigmoid() * math.pi / 6
        phi = features[:, 1, :, :].sigmoid() * math.pi * 2
        dist = features[:, 2, :, :].sigmoid() * MAX_DEPTH

        # Build the normal out-of-place; no zero-filled scratch tensor or
        # clone() is needed for autograd.
        normal = torch.stack(
            [
                torch.sin(theta) * torch.cos(phi),
                torch.sin(theta) * torch.sin(phi),
                torch.cos(theta),
            ],
            dim=1,
        )
        normal = F.normalize(normal, 2, 1)

        plane_parameters = torch.cat([normal, dist.unsqueeze(1)], dim=1)

        depth = self.LPGLayer(plane_parameters.float())
        return depth

In [64]:
MAX_DEPTH = 81  # Multiplier applied to the sigmoid output that becomes the plane distance in LPGBlock
DEPTH_OFFSET = 0.1  # This is used for ensuring depth prediction gets into positive range

USE_APEX = False  # Enable if you have a GPU with Tensor Cores; otherwise it doesn't really bring any benefit.
APEX_OPT_LEVEL = "O2"

BATCH_NORM_MOMENTUM = 0.005
ENABLE_BIAS = True  # Passed as `bias=` to the 1x1 convolutions built by Reduction
activation_fn = nn.ELU()  # Activation appended after every non-final Reduction stage

In [65]:
# Block with an upsampling scale of 1 operating on 4 input filters.
LPGBlock4 = LPGBlock(scale=1, input_filters=4)

In [67]:
# NOTE(review): `x` is first assigned in a later cell (In [69]), so this cell
# fails on Restart & Run All. The recorded output below has batch size 4,
# while that later `x` has batch size 3, so it came from earlier kernel state.
LPGBlock4(x).shape

torch.Size([4, 1, 352, 704])

In [68]:
# With 4 input filters the stage count (int(log2(4)) - 2) is 0, so this
# instance holds no layers and returns its input unchanged.
reduction = Reduction(scale=1, input_filters=4)

In [69]:
# Random test input indexed as [batch, channel, row, col]; values are drawn
# without a fixed seed, so they differ between runs.
x = torch.rand(3, 4, 352, 704)

In [71]:
# `reduction` was built with 4 input filters and therefore has no layers, so
# this displays sigmoid(x) element-wise.
# NOTE(review): this dumps a large tensor repr into the notebook; consider
# showing `.shape` or summary statistics instead.
reduction(x).sigmoid()

tensor([[[[0.6199, 0.6629, 0.5168,  ..., 0.6218, 0.6811, 0.6018],
          [0.6344, 0.6472, 0.5368,  ..., 0.7207, 0.5295, 0.5785],
          [0.5695, 0.5420, 0.6487,  ..., 0.5319, 0.5907, 0.6410],
          ...,
          [0.6446, 0.6216, 0.5215,  ..., 0.5700, 0.5135, 0.5394],
          [0.5818, 0.7121, 0.6871,  ..., 0.7271, 0.7117, 0.6990],
          [0.7185, 0.5368, 0.5582,  ..., 0.6655, 0.5154, 0.7248]],

         [[0.5056, 0.7238, 0.5881,  ..., 0.7266, 0.7020, 0.6906],
          [0.5608, 0.7154, 0.5022,  ..., 0.5459, 0.6429, 0.6346],
          [0.5665, 0.5611, 0.6589,  ..., 0.5121, 0.5509, 0.6709],
          ...,
          [0.7095, 0.5400, 0.5371,  ..., 0.6107, 0.6527, 0.5869],
          [0.6618, 0.6109, 0.5155,  ..., 0.6573, 0.5714, 0.6273],
          [0.6636, 0.5847, 0.5493,  ..., 0.7166, 0.6289, 0.6497]],

         [[0.5708, 0.5820, 0.6117,  ..., 0.5183, 0.5054, 0.6664],
          [0.5726, 0.6866, 0.6817,  ..., 0.6516, 0.5238, 0.5403],
          [0.6498, 0.6354, 0.7232,  ..., 0