In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class FourierLayer(nn.Module):
    """Append sinusoidal (Fourier) features of the input along the channel dim.

    For an input ``x`` of shape ``(N, C, *rest)`` (at least 2-D) the layer
    computes ``p = 2 * pi * B * x`` and returns::

        cat([x, cos(p), sin(p)], dim=1)   # concat_original=True
        cat([cos(p), sin(p)], dim=1)      # concat_original=False

    With a scalar ``B`` the output has ``3 * C`` (or ``2 * C``) channels and
    the same trailing dimensions ``*rest`` as ``x``.

    Args:
        B: Frequency scale. Either a Python number or a tensor that
            broadcasts against ``x``. A plain tensor is registered as a
            non-persistent buffer so it follows ``.to(device)`` / ``.cuda()``
            without adding an entry to ``state_dict``. An ``nn.Parameter``
            is kept as a learnable parameter.
        concat_original: If True, the raw input channels are placed before
            the cosine and sine channels.
    """

    def __init__(self, B, concat_original=True):
        super().__init__()
        if isinstance(B, torch.Tensor) and not isinstance(B, nn.Parameter):
            # Buffers move with the module across devices/dtypes. Keeping it
            # non-persistent leaves the state_dict exactly as before.
            self.register_buffer("B", B, persistent=False)
        else:
            # Python numbers stay plain attributes. Assigning an
            # nn.Parameter here auto-registers it as a learnable parameter.
            self.B = B
        self.concat_original = concat_original

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Fourier features of ``x`` (optionally prefixed by ``x``).

        Args:
            x: Input tensor with the channel axis at dim 1 and at least 2 dims.

        Returns:
            Tensor whose dim-1 channels are ordered ``[x?, cos(p), sin(p)]``.
        """
        # Compute the phase once instead of once per trig function.
        phase = 2 * np.pi * self.B * x

        # Concatenating along dim 1 gives the same [cos..., sin...] channel
        # layout as stacking and flattening. It also leaves every trailing
        # dim intact, so 2-D and 4-D inputs work, not only (N, C, L).
        transformed = torch.cat([torch.cos(phase), torch.sin(phase)], dim=1)

        if self.concat_original:
            return torch.cat([x, transformed], dim=1)
        return transformed

# Smoke-test FourierLayer on a random batch of 8 two-channel signals,
# each 16384 samples long, and display the resulting shape.
dummy_input = torch.randn((8, 2, 16384))
fourier_layer = FourierLayer(concat_original=True, B=0.1)
fourier_output = fourier_layer(dummy_input)
fourier_output.size()


torch.Size([8, 6, 16384])