<a href="https://colab.research.google.com/github/ayyucedemirbas/sensor_fusion/blob/main/SensorFusionNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn

class SensorFusionNet(nn.Module):
    """Two-stream network that fuses a camera image with a LiDAR point cloud.

    The camera stream is a single conv/pool stage whose activations are
    flattened. The LiDAR stream applies a shared per-point MLP followed by a
    max over the point dimension, so the result does not depend on the order
    of the points. Both feature vectors are concatenated and mapped to four
    output values (used as a bounding-box prediction).

    Args:
        image_size: Spatial size of the camera images the model will receive,
            either a single int (square images) or a ``(height, width)``
            pair. Defaults to ``(64, 64)``, which yields the 15,376 camera
            features of the original fixed-size model.

    Raises:
        ValueError: If ``image_size`` is too small to survive the 3x3
            convolution followed by the 2x2 max-pool.
    """

    def __init__(self, image_size=(64, 64)):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # A 3x3 convolution without padding trims 2 pixels per dimension and
        # the 2x2 max-pool then halves each dimension (rounding down).
        pooled_h = (height - 2) // 2
        pooled_w = (width - 2) // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"image_size {(height, width)} is too small for the camera CNN; "
                "each dimension must be at least 4"
            )
        camera_dim = 16 * pooled_h * pooled_w  # 16 * 31 * 31 = 15,376 for 64x64
        self.image_size = (height, width)

        # Camera stream (CNN). Shapes below are for the default 64x64 input.
        self.camera_cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),  # -> (16, 62, 62)
            nn.ReLU(),
            nn.MaxPool2d(2),                   # -> (16, 31, 31)
            nn.Flatten(),                      # -> camera_dim
        )

        # LiDAR stream (PointNet-like): the same MLP is applied to every point.
        self.lidar_mlp = nn.Sequential(
            nn.Linear(3, 64),   # Input: XYZ coordinates per point
            nn.ReLU(),
            nn.Linear(64, 128),  # Output: 128 features per point
        )

        # Fusion layer: camera features + 128 global LiDAR features.
        self.fusion = nn.Linear(camera_dim + 128, 256)
        self.output = nn.Linear(256, 4)  # Bounding box prediction

    def forward(self, camera_img: torch.Tensor, lidar_points: torch.Tensor) -> torch.Tensor:
        """Predict four bounding-box values per sample.

        Args:
            camera_img: Image batch of shape ``[batch, 3, H, W]`` where
                ``(H, W)`` matches the ``image_size`` given at construction.
            lidar_points: Point-cloud batch of shape ``[batch, num_points, 3]``.

        Returns:
            Tensor of shape ``[batch, 4]``.
        """
        # Process camera -> [batch, camera_dim]
        camera_features = self.camera_cnn(camera_img)

        # Process LiDAR: per-point features [batch, num_points, 128], then a
        # max over the point dimension gives a global feature [batch, 128].
        per_point = self.lidar_mlp(lidar_points)
        lidar_features = torch.max(per_point, dim=1).values

        # Concatenate and predict.
        fused = torch.cat([camera_features, lidar_features], dim=1)
        # NOTE(review): `fusion` and `output` are applied back-to-back with no
        # activation in between, so together they form a single affine map.
        # Left unchanged so outputs for existing weights stay the same.
        fused = self.fusion(fused)
        return self.output(fused)

# Smoke test: one forward pass through the model on random inputs.
model = SensorFusionNet()
camera_input = torch.randn((1, 3, 64, 64))  # a single 64x64 RGB image
lidar_input = torch.randn((1, 100, 3))      # a single cloud of 100 XYZ points
predicted_bbox = model(camera_img=camera_input, lidar_points=lidar_input)  # -> [1, 4]

In [4]:
# Echo the prediction from the forward pass above (notebook cell output).
predicted_bbox

tensor([[-0.3337, -0.3898, -0.2344,  0.0316]], grad_fn=<AddmmBackward0>)