In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from net.utils import get_model_memory_nolog
# Define the MLP Gaussian Decoder
class MLP_GaussianDecoder(nn.Module):
    def __init__(self, input_dim, num_gaussians):
        super().__init__()
        self.num_gaussians = num_gaussians
        self.conv1d1 = nn.Conv1d(576, 12, kernel_size=1, stride=1, dilation=1 ,padding=0) #[351,576]-[351,96]
        # MLP to generate Gaussian parameters
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 512),  # Input -> hidden layer 0.30G
            nn.ReLU(),
            nn.Linear(512, 256),       # Hidden -> hidden
            nn.ReLU(),
            nn.Linear(256, num_gaussians * 6)  # Hidden -> Gaussian parameters
        )

    def forward(self, x): # [281, 10, 576]
        """
        z: Input encoded vector of shape [batch, length, channel]
        Returns: Decoded Gaussian-based image of shape [batch, output_h, output_w]
        """
        
        x = x.reshape(x.shape[1], x.shape[2], -1)  # 1DConv输入：Reshape to (batch_size, input_channel, seq_len) [10, 576, 281]
        batch_size, channel, length = x.shape
        x = self.conv1d1(x) #576变12 [10, 12, 281]

        # Flatten the input for MLP (combine length and channel dimensions)
        x_flat = x.view(batch_size, -1)  # Shape: [batch, length * 12channel] [10, 3372]

        # Generate Gaussian parameters
        params = self.mlp(x_flat)  # Shape: [batch, num_gaussians * 6] [10, 600]
        params = params.view(batch_size, self.num_gaussians, 6)  # [batch, num_gaussians, 6] [10, 100, 6]

        # Separate Gaussian parameters: x, y, sigma_x, sigma_y, rho, intensity
        xc, yc, sigma_x, sigma_y, rho, intensity = torch.chunk(params, 6, dim=-1)

        # Normalize xc, yc to [-1, 1] (image coordinate range)
        xc = torch.tanh(xc)
        yc = torch.tanh(yc)

        # Ensure sigma_x, sigma_y > 0 (use softplus activation)
        sigma_x = F.softplus(sigma_x) + 1e-6
        sigma_y = F.softplus(sigma_y) + 1e-6

        # Ensure rho is within a valid range [-1, 1] (correlation coefficient)
        rho = torch.tanh(rho)

        # Intensity can be scaled to [0, 1] using sigmoid
        intensity = torch.sigmoid(intensity)

        # Create a 2D grid for the output image
        x = torch.linspace(-1, 1, 720, device=x.device)
        y = torch.linspace(-1, 1, 360, device=x.device)
        X, Y = torch.meshgrid(x, y, indexing="ij")  # X, Y shape: [output_w, output_h]

        # Expand X, Y to support broadcasting across batches and gaussians
        X = X.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, output_w, output_h]
        Y = Y.unsqueeze(0).unsqueeze(0)  # Shape: [1, 1, output_w, output_h]

        # Expand Gaussian parameters to match grid dimensions
        xc = xc.unsqueeze(-1).unsqueeze(-1)  # [batch, num_gaussians, 1, 1]
        yc = yc.unsqueeze(-1).unsqueeze(-1)
        sigma_x = sigma_x.unsqueeze(-1).unsqueeze(-1)
        sigma_y = sigma_y.unsqueeze(-1).unsqueeze(-1)
        rho = rho.unsqueeze(-1).unsqueeze(-1)
        intensity = intensity.unsqueeze(-1).unsqueeze(-1)

        # Compute Gaussian function for all Gaussians in parallel
        X_diff = X - xc  # [batch, num_gaussians, output_w, output_h]
        Y_diff = Y - yc  # [batch, num_gaussians, output_w, output_h]

        gaussians = intensity * torch.exp(
            -(
                (X_diff ** 2) / (2 * sigma_x ** 2)
                + (Y_diff ** 2) / (2 * sigma_y ** 2)
                + rho * (X_diff * Y_diff) / (sigma_x * sigma_y)
            )
        )  # [batch, num_gaussians, output_w, output_h]

        # Sum over all Gaussians
        output = gaussians.sum(dim=1)  # Shape: [batch, output_w, output_h]

        return output

# Hyperparameters
batch_size = 10
length = 281
channel = 576
num_gaussians = 100  # Number of Gaussians


# Input encoded vector (randomly generated)
encoded = torch.randn(length, batch_size, channel)  # Shape: [281, 10, 576]

# Initialize the decoder
decoder = MLP_GaussianDecoder(input_dim=length * 12, num_gaussians=num_gaussians)
get_model_memory_nolog(decoder)
# Forward pass
output = decoder(encoded)  # Shape: [10, 360, 720]

# Print output shape
print("Output shape:", output.shape)  # Should be [10, 360, 720]

print(output)
#怎么output不一样了。。我笑死 不过好像问题不大，只要形状对上，能训练就行

模型占用0.0075GB
Output shape: torch.Size([10, 1, 720, 360])
tensor([[[[6.2354, 6.3076, 6.3802,  ..., 6.4246, 6.3517, 6.2791],
          [6.2712, 6.3438, 6.4168,  ..., 6.4608, 6.3874, 6.3145],
          [6.3071, 6.3801, 6.4535,  ..., 6.4971, 6.4233, 6.3499],
          ...,
          [6.6159, 6.6910, 6.7666,  ..., 6.3335, 6.2600, 6.1870],
          [6.5792, 6.6539, 6.7291,  ..., 6.2976, 6.2246, 6.1520],
          [6.5426, 6.6169, 6.6917,  ..., 6.2619, 6.1893, 6.1171]]],


        [[[6.2629, 6.3355, 6.4086,  ..., 6.6752, 6.6008, 6.5267],
          [6.2983, 6.3714, 6.4448,  ..., 6.7122, 6.6374, 6.5629],
          [6.3338, 6.4073, 6.4812,  ..., 6.7494, 6.6741, 6.5992],
          ...,
          [6.6032, 6.6781, 6.7534,  ..., 6.5245, 6.4505, 6.3770],
          [6.5670, 6.6415, 6.7164,  ..., 6.4878, 6.4142, 6.3411],
          [6.5310, 6.6050, 6.6794,  ..., 6.4512, 6.3781, 6.3054]]],


        [[[6.2888, 6.3602, 6.4321,  ..., 6.8231, 6.7489, 6.6751],
          [6.3249, 6.3968, 6.4690,  ..., 6.8610