In [69]:
import torch

In [70]:
import torch
import torch.nn as nn

class SpaceTimeCubeEmbedding(nn.Module):
    """Embed a time-first stack of frames into non-overlapping space-time cubes.

    A ``Conv3d`` whose kernel size equals its stride cuts the
    (time, height, width) volume into non-overlapping cubes of size
    ``patch_size`` and projects each cube to ``out_channels`` features.
    ``LayerNorm`` then normalizes each cube's feature vector.

    Input layout is time-first: ``(T, C, H, W)`` for a single sample or
    ``(B, T, C, H, W)`` for a batch. The output keeps the same time-first
    layout: ``(T', out_channels, H', W')`` or ``(B, T', out_channels, H', W')``.
    Here ``T' = T // pt``, ``H' = H // ph`` and ``W' = W // pw``. Trailing
    rows/columns that do not fill a whole cube are dropped by the
    convolution, e.g. H=721 -> 180 with ph=4.

    Args:
        in_channels: number of input channels ``C`` per time step.
        out_channels: embedding dimension of each cube.
        patch_size: cube size ``(pt, ph, pw)`` (or a single int), used as
            both kernel size and stride. Defaults to ``(2, 4, 4)``.
    """

    def __init__(self, in_channels, out_channels, patch_size=(2, 4, 4)):
        super(SpaceTimeCubeEmbedding, self).__init__()
        # kernel_size == stride -> non-overlapping cubes. With the default
        # (2, 4, 4), time is reduced 2x and each spatial dimension 4x.
        self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
        # Normalizes over the embedding (out_channels) dimension.
        self.layer_norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        """Embed ``x`` of shape ``(T, C, H, W)`` or ``(B, T, C, H, W)``.

        Returns a tensor of shape ``(T', out_channels, H', W')`` for 4D input
        or ``(B, T', out_channels, H', W')`` for 5D input.

        Raises:
            ValueError: if ``x`` is neither 4D nor 5D.
        """
        unbatched = x.dim() == 4
        if unbatched:
            # Treat a single sample as a batch of one; squeezed back below.
            x = x.unsqueeze(0)
        elif x.dim() != 5:
            raise ValueError(
                f"expected a 4D (T, C, H, W) or 5D (B, T, C, H, W) tensor, got {x.dim()}D"
            )
        # (B, T, C, H, W) -> (B, C, T, H, W): Conv3d expects channels before
        # the depth (time) axis.
        x = self.conv3d(x.transpose(1, 2))
        # (B, C_out, T', H', W') -> (B, T', H', W', C_out): LayerNorm
        # normalizes over the last dimension.
        x = self.layer_norm(x.permute(0, 2, 3, 4, 1))
        # (B, T', H', W', C_out) -> (B, T', C_out, H', W'): restore the
        # time-first layout of the input.
        x = x.permute(0, 1, 4, 2, 3)
        return x.squeeze(0) if unbatched else x


# 70 input channels -> 128-dimensional embedding per space-time cube.
model = SpaceTimeCubeEmbedding(70, 128)


In [71]:
# Time-first input: (T=2, C=70, H=721, W=1440); dim 0 becomes the Conv3d depth axis.
random_tensor = torch.randn(2, 70, 721, 1440)
# Call the module rather than .forward() so nn.Module hooks are honored.
x = model(random_tensor)
# T: 2 -> 1; H, W: floor(/4) -> 180, 360; embedding dim 128 in position 1.
assert x.shape == (1, 128, 180, 360), x.shape
x.shape


torch.Size([1, 128, 180, 360])