In [None]:
#default_exp modules.pillarcalc

In [None]:
#export
import torch
import torch.nn as nn


In [None]:
#export

In [None]:
#export
class PillarCalc(nn.Module):
    """Augment zero-padded pillar point features with per-point offset features.

    For every point, the module appends two groups of values to its features:

    * its (x, y, z) offset from the mean of the real (non-padding) points in
      its pillar;
    * its (x, y) offset from the center of the pillar cell.

    :param pillars_cfg: config section exposing ``getfloat`` (e.g. a
        ``configparser.SectionProxy``) with the keys ``x_min``, ``y_min``,
        ``x_step``, ``y_step``, ``z_min`` and ``z_max``.
    """

    def __init__(self, pillars_cfg: dict):
        super().__init__()

        self.pillars_cfg = pillars_cfg
        # Keep the constants on the GPU when one is available (as before),
        # but fall back to the CPU instead of crashing without CUDA.
        # The methods below also move them to the input's device on use.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        get = pillars_cfg.getfloat
        self.min = torch.tensor([get('x_min'), get('y_min')],
                                dtype=torch.float32, device=device)
        self.step = torch.tensor([get('x_step'), get('y_step')],
                                 dtype=torch.float32, device=device)
        # Midpoint of the z range. The previous ``(z_max - z_min) / 2`` was the
        # half-extent, which only equals the center when z_min == 0.
        self.z_center = torch.tensor([(get('z_max') + get('z_min')) / 2.0],
                                     dtype=torch.float32, device=device)

    def _pillar_centers_from_index(self, xy_index: torch.Tensor) -> torch.Tensor:
        """Convert pillar grid indices into pillar center coordinates.

        :param xy_index: tensor of shape ``(batch, n_pillars, 2)`` holding the
            (x, y) grid index of each pillar.
        :returns: tensor of shape ``(batch, n_pillars, 3)`` holding the
            ``[x, y, z]`` center of each pillar. z is the constant
            ``self.z_center``.
        """
        bs, n_p, _ = xy_index.shape
        # Follow the input's device so the module works on CPU and GPU alike.
        step = self.step.to(xy_index.device)
        minimum = self.min.to(xy_index.device)

        # Lower pillar bound (index * step + grid origin), then half a step
        # further to reach the center.
        xy_center = xy_index * step + minimum + 0.5 * step

        z_center = self.z_center.to(device=xy_center.device, dtype=xy_center.dtype)
        z_center = z_center.view(1, 1, 1).expand(bs, n_p, 1)
        return torch.cat((xy_center, z_center), dim=2)

    def forward(self, pillars: torch.Tensor, pillar_index: torch.Tensor, dt=torch.float32):
        """Return ``pillars`` concatenated with mean- and center-offset features.

        Call the module instance (``module(pillars, pillar_index)``) rather
        than this method directly, so that ``nn.Module`` hooks run.

        :param pillars: zero-padded point features of shape
            ``(batch, n_pillars, n_points, n_features)``. The first three
            features are used as the x, y, z coordinates. A point whose three
            coordinates are all zero is treated as padding.
        :param pillar_index: grid index of each pillar, of shape
            ``(batch, n_pillars, 2)``.
        :param dt: unused; accepted for backward compatibility.
        :returns: tensor of shape ``(batch, n_features + 5, n_pillars, n_points)``
            holding, along dim 1, the original features, then the xyz offset
            from the pillar's point mean, then the xy offset from the pillar
            center. Padding points get zero offsets. ``pillars`` itself is not
            modified.
        """
        centers = self._pillar_centers_from_index(pillar_index)
        xyz = pillars[..., :3]
        # One flag per point, not per coordinate. A real point lying exactly
        # on an axis (e.g. x == 0) must still count in every coordinate.
        mask = xyz.ne(0).any(dim=3, keepdim=True)

        # Mean over the real points only. Clamping the count avoids 0/0 for
        # empty pillars: their masked sum is 0, so their mean becomes 0.
        counts = mask.sum(dim=2).clamp(min=1)
        mean = (xyz * mask).sum(dim=2) / counts

        # All arithmetic is out of place, so the caller's tensor (and autograd
        # on it) is untouched. Padding points have zero coordinates and a
        # False mask, so both of their offsets stay 0.
        offset_mean = xyz - mean.unsqueeze(2) * mask
        offset_center = xyz[..., :2] - centers[..., :2].unsqueeze(2) * mask

        features = torch.cat((pillars, offset_mean, offset_center), dim=3)
        return features.permute(0, 3, 1, 2)