In [17]:
# Make the project packages (`src`, `utils`) importable from this notebook.
# NOTE(review): this is a hardcoded absolute path inside one user's home
# directory, so the imports below only resolve on that machine. Consider
# installing the project (e.g. `pip install -e .`) or deriving the path from a
# configurable project root instead.
import sys
sys.path.append('/home/daniel/Documents/Uni/MT/poi-prediction/')
from src.models.SubmDenseNet import SubmDenseNet
# NOTE(review): heatmaps_to_coords is not referenced in the cells visible here.
from utils.heatmap_utils import heatmaps_to_coords

In [7]:
import torch
import torch.nn as nn

class SoftArgmax3D(nn.Module):
    """Differentiable argmax over 3D heatmaps.

    Each heatmap is turned into a probability distribution with a softmax over
    all of its voxels; the returned coordinate is the expected voxel index
    under that distribution, computed independently along each spatial axis.
    The module has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, heatmap):
        """
        Apply the soft-argmax operation on a batch of 3D heatmaps.

        Args:
            heatmap (torch.Tensor): Input tensor of shape (b, n, h, w, d).
                It does not need to be contiguous (transposed or cropped
                views are accepted).

        Returns:
            torch.Tensor: Soft-argmax coordinates of shape (b, n, 3), in voxel
            index units and ordered (h, w, d).

        Raises:
            ValueError: If ``heatmap`` is not 5-dimensional (raised by the
                shape unpacking).
        """
        batch_size, num_maps, height, width, depth = heatmap.shape

        # Softmax over all voxels of each map. `reshape` is used instead of
        # `view` because `view` raises on non-contiguous inputs.
        flat = heatmap.reshape(batch_size, num_maps, -1)
        probs = torch.softmax(flat, dim=-1).reshape(heatmap.shape)

        spatial_axes = (2, 3, 4)
        coords = []
        for axis, size in zip(spatial_axes, (height, width, depth)):
            # Marginal distribution along this axis: sum out the other two.
            # This equals summing probs * index over the whole volume, but
            # avoids materialising a full-volume temporary per axis.
            other_axes = tuple(a for a in spatial_axes if a != axis)
            marginal = probs.sum(dim=other_axes)  # (b, n, size)
            index = torch.linspace(0, size - 1, steps=size, device=heatmap.device)
            coords.append((marginal * index).sum(dim=-1))

        # (b, n) per axis -> (b, n, 3)
        return torch.stack(coords, dim=-1)

# Sanity check: with a single dominant voxel per map, the soft-argmax should
# land on that voxel.
soft_argmax = SoftArgmax3D()

# One batch entry with two 10x10x10 maps, filled with -100 everywhere except
# for a single peak of value 1 in each map.
heatmap = torch.full((1, 2, 10, 10, 10), -100.0)
heatmap[0, 0, 5, 5, 5] = 1.0
heatmap[0, 1, 2, 3, 4] = 1.0

# Expected result: the two peak locations.
coords = soft_argmax(heatmap)
print("Soft-Argmax Coordinates:", coords)

Soft-Argmax Coordinates: tensor([[[5., 5., 5.],
         [2., 3., 4.]]])


In [18]:
import spconv.pytorch as spconv
import numpy as np
from spconv.pytorch import SparseSequential, SparseModule
import torch

In [19]:
# Build the network under test: 1 input channel, 10 landmarks, feature_l = 0.
sdn = SubmDenseNet(in_channels=1, n_landmarks=10, feature_l=0)

In [20]:
def generate_sparse_data(shape,
                         num_points,
                         num_channels,
                         integer=False,
                         data_range=(-1, 1),
                         with_dense=True,
                         dtype=np.float32,
                         shape_scale = 1):
    """Generate a random batch of sparse grid features, optionally with a dense copy.

    For every batch entry, ``num_points[i]`` distinct coordinates are drawn
    from the grid described by ``shape`` and paired with random feature
    vectors. Randomness comes from NumPy's global RNG (``np.random``), so call
    ``np.random.seed`` beforehand for reproducible output.

    Args:
        shape: Spatial extent of the dense grid, one entry per dimension.
        num_points: Sequence with the number of active points per batch entry;
            its length is the batch size.
        num_channels: Number of feature channels per point.
        integer: If True, draw features with ``np.random.randint`` (upper
            bound exclusive); otherwise with ``np.random.uniform``.
        data_range: ``(low, high)`` bounds passed to the random generator.
        with_dense: If True, also return the dense equivalent of the batch.
        dtype: dtype of the returned feature arrays.
        shape_scale: Coordinates are drawn from ``range(0, s // shape_scale)``
            per axis and then multiplied by ``shape_scale``.

    Returns:
        dict with
            ``"features"``: ``(sum(num_points), num_channels)`` array,
            ``"indices"``: ``(sum(num_points), ndim + 1)`` int32 array whose
            last column is the batch index,
            ``"features_dense"`` (only if ``with_dense``):
            ``(batch, num_channels, *shape)`` array, zero where no point was
            drawn.

    Raises:
        ValueError: If an entry of ``num_points`` is negative or larger than
            the number of available grid coordinates. Without this check the
            indices would be silently truncated while ``features`` kept the
            requested row count.
    """
    dense_shape = tuple(shape)
    ndim = len(dense_shape)
    num_points = np.array(num_points)
    batch_size = len(num_points)
    total_points = int(num_points.sum())

    # Enumerate every admissible coordinate once; each row is one grid point.
    axes = [np.arange(0, s // shape_scale) for s in dense_shape]
    coors_total = np.stack(np.meshgrid(*axes), axis=-1)
    coors_total = coors_total.reshape(-1, ndim) * shape_scale

    if batch_size and (num_points.min() < 0
                       or num_points.max() > len(coors_total)):
        raise ValueError(
            "each entry of num_points must be between 0 and "
            f"{len(coors_total)} (the number of grid coordinates), "
            f"got {num_points.tolist()}")

    batch_indices = []
    for i in range(batch_size):
        # Shuffle in place and take the first rows: distinct coordinates.
        np.random.shuffle(coors_total)
        inds = coors_total[:num_points[i]]
        # np.pad copies, so the next shuffle cannot disturb this selection;
        # the appended last column carries the batch index.
        inds = np.pad(inds, ((0, 0), (0, 1)),
                      mode="constant",
                      constant_values=i)
        batch_indices.append(inds)

    size = [total_points, num_channels]
    if integer:
        sparse_data = np.random.randint(data_range[0], data_range[1], size=size)
    else:
        sparse_data = np.random.uniform(data_range[0], data_range[1], size=size)
    sparse_data = sparse_data.astype(dtype)

    res = {
        "features": sparse_data,
    }
    if with_dense:
        dense_data = np.zeros([batch_size, num_channels, *dense_shape],
                              dtype=sparse_data.dtype)
        start = 0
        for i, inds in enumerate(batch_indices):
            stop = start + len(inds)
            # Scatter all points of this batch entry at once. Indexing the
            # (channels, *shape) view with one index array per spatial axis
            # yields a (channels, n_points) target, hence the transpose.
            # Coordinates are distinct within an entry, so there are no
            # write conflicts.
            spatial = tuple(inds[:, :-1].T)
            dense_data[i][(slice(None), *spatial)] = sparse_data[start:stop].T
            start = stop
        res["features_dense"] = dense_data

    if batch_indices:
        all_indices = np.concatenate(batch_indices, axis=0)
    else:
        # Empty batch: keep the column layout so callers can still slice it.
        all_indices = np.zeros((0, ndim + 1))
    res["indices"] = all_indices.astype(np.int32)
    return res

In [21]:
# Random sparse test batch: four samples on a 128 x 128 x 96 grid with one
# feature channel; the list gives the number of active points per sample.
sparse_dict = generate_sparse_data(
    shape=[128, 128, 96],
    num_points=[1077, 987, 1501, 1324],
    num_channels=1,
)

In [22]:
# Reorder the index columns with [3, 0, 1, 2] so the batch index (stored in
# the last column by generate_sparse_data) comes first.
indices = np.ascontiguousarray(
    sparse_dict["indices"][:, [3, 0, 1, 2]]
).astype(np.int32)
features = np.ascontiguousarray(sparse_dict["features"]).astype(np.float32)

# Upload both arrays to the GPU.
features_t = torch.from_numpy(features).cuda()
indices_t = torch.from_numpy(indices).int().cuda()

# Spatial shape and batch size must match the values used to generate sparse_dict.
sparse_tensor = spconv.SparseConvTensor(
    features=features_t,
    indices=indices_t,
    spatial_shape=(128, 128, 96),
    batch_size=4,
)

In [23]:
# Densify the sparse tensor and display the resulting shape.
sparse_tensor.dense().shape

torch.Size([4, 1, 128, 128, 96])

In [24]:
# Move the model to the GPU, then run one forward pass on the densified input.
# The call returns two values, unpacked as heatmaps and feature encodings.
# NOTE(review): the forward pass runs with autograd enabled; if this is only a
# smoke test, wrapping it in torch.no_grad() would save memory — confirm intent.
sdn.cuda()
heatmaps, feature_encodings = sdn(sparse_tensor.dense())

In [25]:
# Display the shapes of both model outputs.
heatmaps.shape, feature_encodings.shape

(torch.Size([4, 10, 8, 8, 6]), torch.Size([4, 10, 0]))

In [11]:
# Summary statistics of the predicted heatmaps.
# NOTE(review): this cell's execution count (In [11]) is lower than that of the
# cells above it, so the saved output may come from an earlier run — re-run
# after Restart & Run All before relying on these numbers.
heatmaps.mean(), heatmaps.std(), heatmaps.max(), heatmaps.min()

(tensor(0.0026, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.0008, device='cuda:0', grad_fn=<StdBackward0>),
 tensor(0.0645, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(0.0002, device='cuda:0', grad_fn=<MinBackward1>))

In [16]:
# Summary statistics of the feature encodings.
# NOTE(review): the saved output (In [16]) predates the current model setup.
# The shapes cell above reports feature_encodings as (4, 10, 0), i.e. an empty
# tensor, for which mean() is NaN and max()/min() raise — so this cell will not
# reproduce the saved numbers with feature_l = 0.
feature_encodings.mean(), feature_encodings.std(), feature_encodings.max(), feature_encodings.min()

(tensor(7.1419e-06, device='cuda:0', grad_fn=<MeanBackward0>),
 tensor(0.0276, device='cuda:0', grad_fn=<StdBackward0>),
 tensor(0.4506, device='cuda:0', grad_fn=<MaxBackward1>),
 tensor(-0.4417, device='cuda:0', grad_fn=<MinBackward1>))