# Experimenting with a Transformer Encoder

In [None]:
from einops import rearrange
from einops import repeat
import torch

from synthmap.models.soundstream import SoundStreamEncoder
from synthmap.utils.audio_utils import load_wav_dir_as_tensor

In [None]:
# Load clips from the mars808 directory and keep only the first two entries
# so the experiment stays small.
# NOTE(review): length/sample_rate of 48000 presumably mean one-second clips
# at 48 kHz -- confirm against load_wav_dir_as_tensor.
audio = load_wav_dir_as_tensor(
    "../dataset/mars808",
    length=48000,
    sample_rate=48000,
)[:2]
print(audio.shape)

In [None]:
# Build the SoundStream encoder that turns raw audio into a feature sequence.
# output_channels (128) is the feature size handed to the aggregator below.
soundstream = SoundStreamEncoder(
    input_channels=1,
    hidden_channels=16,
    output_channels=128,
    strides=(2, 4, 4, 4),
)

In [None]:
# Indexing with None inserts a new axis at position 1 before encoding.
# NOTE(review): presumably this is the channel axis matching input_channels=1
# of the encoder -- confirm against SoundStreamEncoder.
z = soundstream(audio[:, None, :])

In [None]:
# Inspect the shape of the encoder output; its last axis is later used as the
# aggregator's clip_length.
print(z.shape)

In [None]:
class TransformerAggregator(torch.nn.Module):
    """Aggregate a feature sequence into a single vector with a transformer.

    Learned class tokens are prepended to the input sequence, a learned
    positional embedding is added, the result is passed through a transformer
    encoder, and the encoder output at the first class token is linearly
    projected to ``output_dim``.

    Args:
        input_dim: Feature size of each sequence element (the transformer's
            ``d_model``). Must be divisible by ``num_heads``.
        output_dim: Size of the projected output vector.
        clip_length: Number of sequence steps in every input.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward network.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        clip_length: int,
        *,
        num_heads: int = 4,
        num_layers: int = 6,
        dim_feedforward: int = 512,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.clip_length = clip_length

        # Create the transformer encoder
        tlayer = torch.nn.TransformerEncoderLayer(
            d_model=self.input_dim,
            nhead=num_heads,
            activation="gelu",
            batch_first=True,
            dim_feedforward=dim_feedforward,
        )
        self.transformer = torch.nn.TransformerEncoder(
            tlayer, num_layers=num_layers, norm=torch.nn.LayerNorm(self.input_dim)
        )

        # Output projection
        self.proj = torch.nn.Linear(self.input_dim, self.output_dim)

        # Class tokens (only the first one is read out in forward)
        self.num_tokens = 2
        self.cls_token = torch.nn.Parameter(
            torch.zeros(1, self.num_tokens, self.input_dim)
        )

        # Positional encoding, one slot per class token and per sequence step
        self.pos_emb = torch.nn.Parameter(
            torch.zeros(1, self.clip_length + self.num_tokens, self.input_dim)
        )

        self.apply(self._init_weights)

        # Module.apply only visits sub-modules, never bare Parameters, so the
        # learned tokens and positional embedding are initialised explicitly.
        torch.nn.init.normal_(self.cls_token, mean=0.0, std=0.02)
        torch.nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

    def _init_weights(self, m):
        """Reset biases of Linear layers and the affine terms of LayerNorm.

        Intended to be used through ``self.apply``, which calls it once for
        every sub-module. Linear weights keep their default initialisation.
        """
        if isinstance(m, torch.nn.Linear):
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
        elif isinstance(m, torch.nn.LayerNorm):
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                torch.nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Aggregate ``x`` into one vector per batch element.

        Args:
            x: Tensor of shape ``(batch, input_dim, clip_length)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            ValueError: If ``x`` does not have the expected shape.
        """
        if (
            x.dim() != 3
            or x.shape[1] != self.input_dim
            or x.shape[2] != self.clip_length
        ):
            raise ValueError(
                "expected input of shape (batch, "
                f"{self.input_dim}, {self.clip_length}), got {tuple(x.shape)}"
            )

        # (batch, features, steps) -> (batch, steps, features)
        x = x.transpose(1, 2)

        # Prepend the class tokens to the beginning of the input sequence
        tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((tokens, x), dim=1)

        # Apply positional encoding
        x = x + self.pos_emb

        out = self.transformer(x)
        # Read out the first class token only
        return self.proj(out[:, 0, :])

In [None]:
# Build the aggregator: 128 input features (the SoundStream output_channels),
# a 14-dimensional output, and one position per step on the last axis of z.
encoder = TransformerAggregator(
    input_dim=128,
    output_dim=14,
    clip_length=z.shape[-1],
)

In [None]:
# Run the aggregator on the SoundStream features and inspect the output shape.
y = encoder(z)

print(y.shape)