In [46]:
import torch

In [47]:
import torch
import torch.nn as nn

class SpaceTimeCubeEmbedding(nn.Module):
    """Embed a space-time volume as non-overlapping, layer-normalized "cubes".

    A ``Conv3d`` whose kernel size equals its stride splits the input into
    non-overlapping ``patch_size`` cubes and linearly projects each cube to
    ``out_channels`` features; a ``LayerNorm`` then normalizes each cube's
    feature vector (i.e. over the channel axis).

    Args:
        in_channels: Number of input channels ``C_in``.
        out_channels: Number of output (embedding) channels ``C_out``.
        patch_size: ``(pT, pH, pW)`` (or a single int) used as both the kernel
            size and the stride of the convolution. Defaults to ``(2, 4, 4)``,
            which reduces the temporal dimension by 2 and each spatial dimension
            by 4. Trailing remainders are dropped by the convolution
            (e.g. H=721 -> 180).

    Shape:
        - Input: ``(N, C_in, T, H, W)`` or unbatched ``(C_in, T, H, W)``.
        - Output: ``(N, C_out, T // pT, H // pH, W // pW)`` or the unbatched
          ``(C_out, T // pT, H // pH, W // pW)``, matching the input's batching.
    """

    def __init__(self, in_channels, out_channels, patch_size=(2, 4, 4)):
        super().__init__()
        # kernel_size == stride -> non-overlapping patches.
        self.conv3d = nn.Conv3d(
            in_channels, out_channels, kernel_size=patch_size, stride=patch_size
        )
        # LayerNorm normalizes over the LAST dimension, so forward() moves the
        # channel axis to the end before applying it.
        self.layer_norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        """Project ``x`` into normalized patch embeddings.

        Args:
            x: Tensor of shape ``(N, C_in, T, H, W)`` or ``(C_in, T, H, W)``.

        Returns:
            Tensor of shape ``(N, C_out, T', H', W')`` or ``(C_out, T', H', W')``
            (same batching as ``x``), channels-first like the conv output.
        """
        x = self.conv3d(x)
        # The channel axis is 4th-from-last for both batched (N, C, T, H, W)
        # and unbatched (C, T, H, W) conv outputs, so negative indices work for
        # both layouts; a fixed permute(1, 2, 3, 0) only handles the 4-D case.
        x = self.layer_norm(x.movedim(-4, -1))
        # Restore channels-first layout.
        return x.movedim(-1, -4)

# Example: build the embedding used in the shape check below.
# - in_channels=2 matches the 2-channel tensor fed in the next cell
#   (NOTE(review): the meaning of those 2 channels is not shown here — confirm).
# - out_channels=128 is the embedding width per (2, 4, 4) space-time cube.
# The next cell passes an unbatched (in_channels, T, H, W) tensor; the output is
# (out_channels, T // 2, H // 4, W // 4), e.g. (2, 70, 721, 1440) -> (128, 35, 180, 360).
model = SpaceTimeCubeEmbedding(
    in_channels=2,
    out_channels=128,
)


In [50]:
# Shape check on a random unbatched (C_in=2, T=70, H=721, W=1440) input.
# Call the module (not .forward) so nn.Module hooks run, and disable autograd:
# only the shape is needed, and the float32 output alone is ~1.2 GB, so keeping
# a backward graph over it would waste memory.
random_tensor = torch.randn(2, 70, 721, 1440)
with torch.no_grad():
    x = model(random_tensor)
# T: 70 // 2 = 35; H: 721 // 4 = 180 (remainder dropped); W: 1440 // 4 = 360.
assert x.shape[-3:] == (35, 180, 360), x.shape
x.shape


torch.Size([128, 35, 180, 360])