In [25]:
import torch
import torch.nn as nn
import einops


class CorrelationPooling(nn.Module):
  """Second-order (covariance) pooling over the time axis.

  The input is linearly projected to ``proj_dim`` channels, each projected
  channel is mean-centred over time, channel dropout is applied, and the
  ``proj_dim x proj_dim`` matrix of time-averaged channel products is
  flattened to its upper triangle (diagonal included).

  Note: the per-channel division by the standard deviation is intentionally
  left disabled, so the entries are biased covariances (divided by the
  number of time steps), not normalised correlation coefficients.

  Args:
    input_dim: Size of the last (channel) dimension of the input.
    proj_dim: Number of channels after the linear projection.
    dropout_prob: Probability of dropping a whole projected channel.

  Attributes:
    out_dim: Length of the returned feature vector,
      ``proj_dim * (proj_dim + 1) // 2``.
  """

  def __init__(self, input_dim, proj_dim, dropout_prob=0.25):
    super().__init__()
    self.projection = nn.Linear(input_dim, proj_dim)
    # Dropout2d drops entire channels of a (N, C, H, W) tensor, which is why
    # forward() feeds it a (batch, channels, time, 1) view.
    self.dropout = nn.Dropout2d(dropout_prob)
    # Number of entries in the upper triangle (diagonal included) of a
    # proj_dim x proj_dim matrix. The previous expression
    # `proj_dim ** 2 // 2 + proj_dim // 2` floored twice and came out one
    # short for every odd proj_dim.
    self.out_dim = proj_dim * (proj_dim + 1) // 2

  def forward(self, x):
    """Pool a batch of sequences into covariance feature vectors.

    Args:
      x: Tensor of shape (batch_size, time_steps, input_dim).

    Returns:
      Tensor of shape (batch_size, out_dim) holding the upper-triangular
      entries of each sample's channel covariance matrix, in row-major
      order.
    """
    # Linear projection to a lower dimensional space: (b, t, proj_dim)
    x = self.projection(x)

    # Centre every channel over the time dimension. No variance scaling is
    # applied (see class docstring).
    x = x - x.mean(dim=1, keepdim=True)

    # Channel dropout: Dropout2d needs (b, c, t, 1) so that whole channels
    # are dropped; surviving channels are rescaled by the dropout module
    # while training, and the call is a no-op in eval mode.
    x = x.transpose(1, 2).unsqueeze(-1)
    x = self.dropout(x)
    x = x.squeeze(-1).transpose(1, 2)  # back to (b, t, c)

    # Time-averaged channel products.
    # shape: (batch_size, proj_dim, proj_dim)
    outer_product = torch.bmm(x.transpose(1, 2), x) / x.size(1)

    # Vectorize the upper triangular part (diagonal included). The indices
    # are created on the same device as the data instead of the CPU default.
    i, j = torch.triu_indices(
      outer_product.size(1),
      outer_product.size(2),
      offset=0,
      device=outer_product.device,
    )
    correlation_vector = outer_product[:, i, j]
    return correlation_vector


# Demo: pool a random batch with CorrelationPooling
batch_size, time_steps, channels = 32, 100, 256
proj_dim = 16

# Random input laid out as (batch_size, time_steps, channels)
x = torch.randn(batch_size, time_steps, channels)

# Build the layer and run one forward pass. The module is left in its
# default training mode, so channel dropout is active for this call.
correlation_pooling = CorrelationPooling(input_dim=channels, proj_dim=proj_dim)
output = correlation_pooling(x)
# Prints the full tensor; its shape is (batch_size, number_of_correlations)
print(output)


tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.4230, -0.0891,  0.6091],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.6958],
        [ 0.4521,  0.0000,  0.1144,  ...,  0.6203,  0.0227,  0.5159],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.5789,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.6036, -0.0096,  0.6381],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.7249, -0.0051,  0.6996]],
       grad_fn=<IndexBackward0>)


In [26]:
# Pooled feature size: a 16 x 16 upper triangle (diagonal included)
# has 16 * 17 / 2 = 136 entries per sample.
output.shape

torch.Size([32, 136])