# Study Notes and Implementation: Audio Transformers for Large Scale Audio Understanding

## Introduction

The research paper _Audio Transformers: Transformer Architectures for Large Scale Audio Understanding_ by Prateek Verma and Jonathan Berger ([arXiv:2105.00335](https://arxiv.org/abs/2105.00335)) introduces a novel approach to processing raw audio signals using Transformer-based architectures, bypassing traditional convolutional neural networks (CNNs). The Stanford CS25 lecture, _Transformers for Applications in Audio, Speech, Music_ ([YouTube](https://www.youtube.com/watch?v=wvE2n8u3drA)), presented by Verma, likely contextualizes this work, discussing its applications in music generation, speech recognition, and acoustic scene understanding.

Transformers, introduced in _Attention is All You Need_ ([arXiv:1706.03762](https://arxiv.org/abs/1706.03762)), use self-attention to model long-range dependencies, which makes them well suited to sequential data like audio, where a 16 kHz recording already contains 16,000 samples per second. The paper reports that Transformers outperform CNNs on the FSD50K dataset, achieving state-of-the-art results without unsupervised pre-training, a significant departure from practices in natural language processing (NLP) and computer vision.

**Example**: Imagine classifying a sound clip as a dog barking or a car horn. A Transformer can focus on specific audio samples (e.g., high-frequency barks) while considering the entire clip’s context, unlike CNNs, which use fixed filters.

## Dataset: FSD50K

The FSD50K dataset ([Zenodo](https://zenodo.org/record/4060432)) contains 51,197 audio clips, totaling 108.3 hours, labeled with 200 classes from the AudioSet ontology. Clips range from 0.3 to 30 seconds, with an average duration of 7.6 seconds and 1.22 labels per clip. The dataset is freely available under a Creative Commons license, making it ideal for research.

- **Preprocessing**: Audio is downsampled to 16 kHz. Training uses 1-second chunks; clips shorter than 1 second are repeated to reach 1 second, while longer clips are split into multiple chunks with inherited labels.
- **Splits**: The dataset provides predefined training, validation, and evaluation splits.

**Example**: A 0.5-second clip of a bird chirping is repeated to form a 1-second training example, labeled as “bird.” A 10-second clip of a street scene might yield ten 1-second chunks, each labeled with sounds like “car” or “siren.”
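
The chunking rule described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' preprocessing code: the function name `to_one_second_chunks` is hypothetical, and dropping a trailing remainder shorter than one second is an assumption of this sketch. Each returned chunk would inherit the labels of its parent clip.

```python
def to_one_second_chunks(samples, sample_rate=16000):
    """Split a clip into 1-second chunks; tile clips shorter than 1 second."""
    chunk_len = sample_rate
    if len(samples) == 0:
        raise ValueError("empty clip")
    if len(samples) < chunk_len:
        # Repeat the clip until it covers one second, then trim the excess.
        repeats = -(-chunk_len // len(samples))  # ceiling division
        return [(samples * repeats)[:chunk_len]]
    # Assumption: a trailing remainder shorter than 1 second is dropped.
    n_chunks = len(samples) // chunk_len
    return [samples[i * chunk_len:(i + 1) * chunk_len] for i in range(n_chunks)]


short_clip = [0.1] * 8000     # 0.5 s at 16 kHz
long_clip = [0.0] * 160000    # 10 s at 16 kHz
short_chunks = to_one_second_chunks(short_clip)
long_chunks = to_one_second_chunks(long_clip)
print(len(short_chunks), len(short_chunks[0]))  # 1 16000
print(len(long_chunks), len(long_chunks[0]))    # 10 16000
```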

## Methodology

The paper proposes a Transformer architecture tailored for raw audio, replacing CNNs with a learnable front end and attention-based processing. The methodology includes several components, detailed below.

### 3.1 Baseline Transformer Architectures

The baseline Transformer is adapted from the original model ([arXiv:1706.03762](https://arxiv.org/abs/1706.03762)), with modifications for audio:

- **Input**: A fixed-length sequence of audio samples (e.g., 16,000 samples for 1 second at 16 kHz).
- **Embedding Size**: 64 dimensions.
- **Components**: Each Transformer layer includes a multi-head self-attention block and a feed-forward block, with layer normalization and residual connections.

#### 3.1.1 Multi-Head Causal Attention

Multi-head attention allows the model to focus on different parts of the input sequence simultaneously. Causal attention ensures predictions depend only on previous samples, crucial for autoregressive tasks like audio generation.
- **Formula**: For query \( Q \), key \( K \), and value \( V \), attention is computed as:  
    $$  
    \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V  
    $$  
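    where \( d_k \) is the dimension of the keys. Multi-head attention concatenates the outputs of \( h \) attention heads and projects the result with an output matrix \( W^O \):  
    $$  
    \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)\,W^O  
    $$  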

**Example**: In a music clip, the model might attend to earlier notes to predict the next note, using multiple heads to capture different musical patterns (e.g., rhythm vs. melody).
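
To make the attention formula concrete, here is a small single-head sketch in plain Python with a causal mask. It is illustrative only: the function names and the toy `Q`, `K`, `V` values are made up for this example, and a multi-head layer would run several such heads on separately projected inputs and concatenate their outputs.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(Q, K, V):
    """Single-head scaled dot-product attention with a causal mask.

    Q, K, V are lists of vectors, one per time step.
    """
    d_k = len(K[0])
    outputs, weights = [], []
    for t, q in enumerate(Q):
        # Causal mask: step t only scores keys at positions 0..t.
        scores = [sum(a * b for a, b in zip(q, K[s])) / math.sqrt(d_k)
                  for s in range(t + 1)]
        w = softmax(scores)
        out = [sum(w[s] * V[s][j] for s in range(t + 1))
               for j in range(len(V[0]))]
        # Pad with zeros so every weight row has one entry per key.
        weights.append(w + [0.0] * (len(K) - t - 1))
        outputs.append(out)
    return outputs, weights


Q = K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
attn_out, attn_weights = causal_attention(Q, K, V)
print(attn_weights[0])  # [1.0, 0.0, 0.0]: the first step can only see itself
```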

#### 3.1.2 Feed-Forward Architecture & Positional Information

The feed-forward network processes each position independently:
\[
FF(x) = \max(0, xW_1 + b_1)W_2 + b_2
\]
Positional encodings, using sinusoidal functions, provide sequence order information:
\[
PE_{pos, 2i} = \sin\left(\frac{pos}{10000^{2i/E}}\right), \quad PE_{pos, 2i+1} = \cos\left(\frac{pos}{10000^{2i/E}}\right)
\]

**Example**: For a speech signal, positional encodings help the model distinguish between “cat” and “act” by encoding the order of phonemes.
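
The sinusoidal encoding can be computed directly from the formula. The sketch below is a plain-Python illustration, with \( E \) taken as the embedding size; the function name `positional_encoding` is hypothetical, and the 40 x 64 shape matches the sequence length and embedding size used elsewhere in these notes.

```python
import math


def positional_encoding(num_positions, E):
    """Return a num_positions x E table of sinusoidal encodings (E must be even)."""
    table = []
    for pos in range(num_positions):
        row = [0.0] * E
        for i in range(E // 2):
            angle = pos / (10000 ** (2 * i / E))
            row[2 * i] = math.sin(angle)      # even dimensions use sine
            row[2 * i + 1] = math.cos(angle)  # odd dimensions use cosine
        table.append(row)
    return table


pe = positional_encoding(40, 64)
print(pe[0][:2])  # [0.0, 1.0]: sin(0) and cos(0) at position 0
```

Because every position gets a distinct vector, adding this table to the embeddings lets the otherwise order-agnostic attention layers tell positions apart.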

### 3.2 Adapting for Raw Waveforms

Raw audio at 16 kHz produces 16,000 samples per second, posing computational challenges due to the quadratic complexity of Transformers (\( O(n^2) \)). The paper addresses this by:

- **Windowing**: Dividing the audio into 25ms non-overlapping windows (400 samples at 16 kHz), yielding 40 windows per second.
- **Front End**: Each window is processed through a dense layer (2048 neurons) followed by another dense layer (64 neurons), producing a sequence of 40 embeddings, each 64-dimensional.

**Example**: A 1-second clip of a piano note is split into 40 windows. Each window’s 400 samples are transformed into a 64-dimensional vector, forming a sequence the Transformer can process.
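
The arithmetic behind the windowing step is easy to check. The following sketch frames one second of audio into non-overlapping 25 ms windows and compares the number of pairwise attention interactions before and after; `frame_signal` is a hypothetical helper name, and the silent waveform is a stand-in for real audio.

```python
def frame_signal(samples, window_size=400):
    """Split a waveform into non-overlapping windows of window_size samples."""
    n_windows = len(samples) // window_size
    return [samples[i * window_size:(i + 1) * window_size]
            for i in range(n_windows)]


sample_rate = 16000
window_size = sample_rate * 25 // 1000   # 25 ms -> 400 samples
waveform = [0.0] * sample_rate           # 1 second of (silent) audio
windows = frame_signal(waveform, window_size)

pairs_raw = sample_rate ** 2             # attention pairs over raw samples
pairs_windowed = len(windows) ** 2       # attention pairs over windows
print(len(windows), len(windows[0]))     # 40 400
print(pairs_raw // pairs_windowed)       # 160000
```

Attending over 40 window embeddings instead of 16,000 raw samples cuts the quadratic attention cost by a factor of 160,000 for a 1-second clip.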

### 3.3 Pooling Inspired by CNNs

Inspired by CNNs like ResNet-50 ([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), the authors incorporate average pooling after every two Transformer layers to reduce sequence length and enable hierarchical feature learning.

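- **Pooling**: Halves the sequence length (the time dimension) at each pooling step. Halving requires a pooling window of 2 applied with a stride of 2; with a stride of 1 the sequence would shrink by only one element.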
- **Benefit**: Lowers computational cost and allows higher layers to capture broader patterns.

**Example**: In a crowd noise clip, early layers might detect individual voices, while pooled layers identify the overall crowd ambiance.
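
A minimal sketch of average pooling along the time axis, in plain Python. The pooling window of 2 with stride 2 is an assumption of this sketch, chosen because it halves the sequence length; the helper name `average_pool` and the toy embeddings are illustrative only.

```python
def average_pool(sequence, pool_size=2, stride=2):
    """Average-pool a list of embedding vectors along the time axis."""
    pooled = []
    for start in range(0, len(sequence) - pool_size + 1, stride):
        window = sequence[start:start + pool_size]
        # Average each embedding dimension across the time steps in the window.
        pooled.append([sum(col) / pool_size for col in zip(*window)])
    return pooled


embeddings = [[float(t), float(-t)] for t in range(40)]  # 40 steps, 2 dims
after_first_pool = average_pool(embeddings)               # 40 -> 20
after_second_pool = average_pool(after_first_pool)        # 20 -> 10
stride_one = average_pool(embeddings, stride=1)           # 40 -> 39
print(len(after_first_pool), len(after_second_pool), len(stride_one))  # 20 10 39
```

Since attention cost grows quadratically with sequence length, each halving reduces the cost of the following layers by roughly a factor of four.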

### 3.4 Multi-Scale Embeddings

Drawing from wavelet decomposition, the authors apply average operations with variable window sizes (e.g., 1, 2, 4, 8) on Transformer embeddings to capture multi-scale features without reducing sequence length.

- **Implementation**: Half the embeddings remain unchanged, while the other half are processed with different window sizes, creating a hierarchical representation.

**Example**: For a thunderstorm clip, small windows capture lightning cracks, while larger windows detect prolonged rain sounds.
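
One way to picture the multi-scale idea is the sketch below. It is an interpretation under stated assumptions, not the paper's implementation: the second half of the embedding dimensions is split evenly across the window sizes, and each group is smoothed with a trailing moving average that keeps the sequence length fixed. The names `moving_average` and `multi_scale` are hypothetical.

```python
def moving_average(values, window):
    """Trailing moving average that keeps the sequence length unchanged."""
    out = []
    for t in range(len(values)):
        start = max(0, t - window + 1)
        chunk = values[start:t + 1]
        out.append(sum(chunk) / len(chunk))
    return out


def multi_scale(sequence, window_sizes=(1, 2, 4, 8)):
    """Smooth the second half of the embedding dimensions at several scales."""
    T, E = len(sequence), len(sequence[0])
    half = E // 2
    group = half // len(window_sizes)   # dimensions assigned to each window size
    result = [list(vec) for vec in sequence]  # first half stays untouched
    for g, window in enumerate(window_sizes):
        for d in range(half + g * group, half + (g + 1) * group):
            column = [sequence[t][d] for t in range(T)]
            smoothed = moving_average(column, window)
            for t in range(T):
                result[t][d] = smoothed[t]
    return result


seq = [[float(t)] * 16 for t in range(40)]   # 40 steps, 16 dims, value = time index
ms = multi_scale(seq)
print(len(ms), len(ms[0]))                   # 40 16
print(ms[10][0], ms[10][10], ms[10][12], ms[10][14])  # 10.0 9.5 8.5 6.5
```

The output keeps all 40 time steps, while dimensions smoothed with larger windows lag further behind the raw value, which is the coarse-scale view.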

## Results

The Transformer models were trained on FSD50K using TensorFlow with Huber Loss and the Adam optimizer. The results, measured by mean Average Precision (mAP), are summarized below:

| Neural Model Architecture                                  | mAP   | # Parameters |
| ---------------------------------------------------------- | ----- | ------------ |
| CRNN ([Zenodo](https://zenodo.org/record/4060432))         | 0.417 | 0.96M        |
| VGG-like ([Zenodo](https://zenodo.org/record/4060432))     | 0.434 | 0.27M        |
| ResNet-18 ([Zenodo](https://zenodo.org/record/4060432))    | 0.373 | 11.3M        |
| DenseNet-121 ([Zenodo](https://zenodo.org/record/4060432)) | 0.425 | 12.5M        |
| Small Transformer (3 layers)                               | 0.469 | 0.9M         |
| Large 6-Layer Transformer                                  | 0.525 | 2.3M         |
| Large Transformer with Pooling                             | 0.537 | 2.3M         |
| Large Transformer with Multi-Scale Filters                 | ~0.54 | 2.3M         |

- **Key Insight**: The large Transformer with pooling reaches an mAP of 0.537, about 0.10 above the best CNN baseline in the table (VGG-like, 0.434), while using 2.3M parameters.
- **Front End Analysis**: The front end learns a non-linear, non-constant bandwidth filter-bank, adapting to the task (e.g., different from pitch estimation filters in [INTERSPEECH 2016](https://www.isca-speech.org/archive/Interspeech_2016/pdfs/1517.PDF)).

**Example**: The learned filters might emphasize low-frequency rumbles for thunder detection, unlike fixed Mel-frequency filters used in traditional methods.
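
For readers unfamiliar with the mAP metric used in the results table, the sketch below computes it for a toy multi-label problem. It uses a common definition of average precision (the mean of the precision values at each correctly ranked positive); the evaluation code behind the reported numbers may differ in details such as tie handling, and the scores and targets here are invented for illustration.

```python
def average_precision(scores, targets):
    """Average precision for one class: mean precision at each positive."""
    ranked = sorted(zip(scores, targets), key=lambda p: p[0], reverse=True)
    hits, precisions = 0, []
    for rank, (_, is_positive) in enumerate(ranked, start=1):
        if is_positive:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_average_precision(score_matrix, target_matrix):
    """Average the per-class AP over all classes (columns)."""
    n_classes = len(score_matrix[0])
    aps = [average_precision([row[c] for row in score_matrix],
                             [row[c] for row in target_matrix])
           for c in range(n_classes)]
    return sum(aps) / n_classes


# Four clips, two classes: rows are clips, columns are classes.
scores = [[0.9, 0.2], [0.8, 0.7], [0.3, 0.6], [0.1, 0.4]]
targets = [[1, 0], [0, 1], [1, 1], [0, 0]]
map_value = mean_average_precision(scores, targets)
print(round(map_value, 4))  # 0.9167
```

Class 0 has AP (1 + 2/3) / 2 and class 1 has AP 1.0, so the mean is 11/12.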

## Conclusion

The paper demonstrates that Transformers can effectively process raw audio, surpassing CNN performance on FSD50K. Future directions include exploring sparse Transformers ([arXiv:1904.10509](https://arxiv.org/abs/1904.10509)) and unsupervised pre-training ([arXiv:2010.11459](https://arxiv.org/abs/2010.11459)).

## Additional Topics from Stanford CS25 Lecture

The lecture likely covers broader applications of Transformers in audio, including:

- **Spectrograms**: Visual representations of audio’s frequency content over time, computed via Short-Time Fourier Transform (STFT).
- **Raw Audio Synthesis**: Challenges with classical methods like FM synthesis and Karplus-Strong, compared to modern approaches like WaveNet ([arXiv:1609.03499](https://arxiv.org/abs/1609.03499)).
- **WaveNet Baseline**: Uses dilated causal convolutions for audio generation, outperformed by Transformers in next-sample prediction.
- **Vector Quantization (VQ)**: Discretizes audio representations for efficient modeling, combined with Transformers for generative tasks.

**Example**: VQ might convert a spectrogram into discrete tokens, like musical notes, which a Transformer uses to generate a melody.
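
The core operation of vector quantization, mapping each continuous vector to its nearest codebook entry, can be sketched as follows. The codebook and frames are toy values chosen for illustration; in practice the codebook is learned and the vectors would be audio or spectrogram embeddings.

```python
def quantize(vectors, codebook):
    """Map each vector to the index of its nearest codebook entry (squared L2)."""
    tokens = []
    for v in vectors:
        dists = [sum((a - b) ** 2 for a, b in zip(v, code)) for code in codebook]
        tokens.append(dists.index(min(dists)))
    return tokens


codebook = [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
frames = [[0.1, -0.1], [0.9, 1.2], [0.2, 0.8], [1.1, 0.9]]
tokens = quantize(frames, codebook)
print(tokens)  # [0, 1, 2, 1]
```

The resulting integer tokens form a discrete sequence that a Transformer can model in the same way it models word tokens in text.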

## Python Implementation

Below is a PyTorch implementation of the baseline Audio Transformer (the paper itself used TensorFlow), including data loading, model definition, and visualization of learned filters.

### Dependencies

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
```

### Data Loading

We use a custom dataset class to load FSD50K audio files, assuming the dataset is downloaded from [Zenodo](https://zenodo.org/record/4060432).

```python
class FSD50KDataset(Dataset):
    def __init__(self, file_paths, labels, sample_rate=16000, duration=1.0):
        self.file_paths = file_paths
        self.labels = labels
        self.sample_rate = sample_rate
        self.duration = duration

    def __len__(self):
        return len(self.file_paths)

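    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(self.file_paths[idx])  # (channels, samples)
        # Downmix multi-channel audio to mono so every item has the same shape
        waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(sr, self.sample_rate)(waveform)
        target_length = int(self.duration * self.sample_rate)
        if waveform.size(1) < target_length:
            # Tile short clips until they cover the target length, then trim
            waveform = waveform.repeat(1, int(target_length / waveform.size(1)) + 1)[:, :target_length]
        elif waveform.size(1) > target_length:
            # Keep only the first chunk; splitting long clips into several chunks is omitted here
            waveform = waveform[:, :target_length]
        # Return the label as a float tensor so the default collate stacks it to (batch, num_classes)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return waveform.squeeze(0), label

# Example usage (replace with actual file paths and labels)
file_paths = ['path/to/audio1.wav', 'path/to/audio2.wav']  # Placeholder
# Toy 3-class multi-hot labels; real FSD50K targets have 200 entries, matching num_classes in the model
labels = [[1, 0, 0], [0, 1, 0]]
dataset = FSD50KDataset(file_paths, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)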
```

### Model Definition

The Audio Transformer model includes a front end, positional encoding, Transformer encoder, and classification head.

```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

class FrontEnd(nn.Module):
    def __init__(self, window_size=400, embedding_dim=64):
        super().__init__()
        self.window_size = window_size
        self.dense1 = nn.Linear(window_size, 2048)
        self.dense2 = nn.Linear(2048, embedding_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        batch_size = x.size(0)
        # Split each waveform into non-overlapping windows,
        # e.g. 16,000 samples -> 40 windows of 400 samples
        x = x.reshape(batch_size, -1, self.window_size)
        x = self.dense1(x)
        x = self.relu(x)
        x = self.dense2(x)
        return x  # (batch_size, num_windows, embedding_dim)

class AudioTransformer(nn.Module):
    def __init__(self, embedding_dim=64, num_classes=200):
        super().__init__()
        self.front_end = FrontEnd(embedding_dim=embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8, dim_feedforward=128)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.front_end(x)  # (batch_size, 40, 64)
        x = x.permute(1, 0, 2)  # (40, batch_size, 64)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # (batch_size, 40, 64)
        x = x.mean(dim=1)  # Average pooling
        x = self.classifier(x)
        return x

# Initialize model
model = AudioTransformer()
```

### Training Loop

A basic training loop using Huber Loss, as specified in the paper.

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = nn.HuberLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device).float()
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
```
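
To show what the loss in the training loop computes, here is the Huber loss written out in plain Python. The sketch assumes a threshold of `delta=1.0` and mean reduction, which are the defaults of `nn.HuberLoss`; the helper names are hypothetical.

```python
def huber(prediction, target, delta=1.0):
    """Huber loss for one value: quadratic near zero, linear for large errors."""
    error = abs(prediction - target)
    if error <= delta:
        return 0.5 * error ** 2
    return delta * (error - 0.5 * delta)


def mean_huber(predictions, targets, delta=1.0):
    """Mean Huber loss over paired predictions and targets."""
    losses = [huber(p, t, delta) for p, t in zip(predictions, targets)]
    return sum(losses) / len(losses)


small = huber(0.5, 0.0)   # quadratic region
large = huber(3.0, 0.0)   # linear region
print(small, large)       # 0.125 2.5
print(mean_huber([0.5, 3.0], [0.0, 0.0]))  # 1.3125
```

The linear region limits the influence of large errors on the gradient compared with a pure squared-error loss.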

### Visualizing Learned Filters

To visualize the front end’s learned filters, we plot the weights of the first dense layer.

```python
def plot_filters(model, num_filters=10):
    weights = model.front_end.dense1.weight.data.cpu().numpy()[:num_filters]
    plt.figure(figsize=(12, 8))
    for i in range(num_filters):
        plt.subplot(num_filters, 1, i+1)
        plt.plot(weights[i])
        plt.title(f'Filter {i+1}')
        plt.xlabel('Sample Index')
        plt.ylabel('Weight')
    plt.tight_layout()
    plt.savefig('filters.png')
    plt.close()

plot_filters(model)
```

## Insights

- **Advantages of Transformers**: Self-attention can relate any two positions of the sequence within a single layer, which helps on tasks that need clip-level context; a CNN's receptive field, by contrast, is fixed by its kernel sizes and depth.
- **Learned Filter-Banks**: The front end’s adaptability allows it to optimize for specific tasks, unlike traditional MFCCs.
- **Scalability**: The paper’s approach scales to large datasets without pre-training, a significant advancement for audio research.

**Example Application**: In environmental sound detection, the model could identify a gunshot in a noisy urban clip by attending to transient high-frequency spikes, ignoring background chatter.

## Conclusion

This notebook summarizes the Audio Transformers paper and the CS25 lecture and pairs them with a practical PyTorch implementation of the six-layer baseline, including a visualization of the learned front-end filters. For further exploration, consider implementing the pooling or multi-scale variants or experimenting with advanced Transformer architectures.
