# SpecTTTra: A Mock Illustration of Audio Input to Spectral & Temporal Clips

## 1. Load Audio and Extract the Mel Spectrogram

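We use `librosa` to load an audio file and extract its mel spectrogram. The mel spectrogram is a time-frequency representation that captures both the spectral (frequency) and temporal (time) characteristics of the audio signal. The audio is resampled to 16 kHz and trimmed or zero-padded to exactly 5 seconds, so the spectrogram always has the same number of time frames.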

In [None]:
import librosa
import numpy as np
import matplotlib.pyplot as plt

# Load an example audio file, resampled by librosa to `sample_rate` Hz.
audio_path = 'audio/folded.mp3'
sample_rate = 16000
max_time = 5  # seconds of audio to keep
max_len = sample_rate * max_time  # number of samples to keep

# Force the signal to exactly `max_len` samples: longer recordings are
# truncated, shorter ones are zero-padded at the end. A fixed length gives a
# fixed number of spectrogram frames for the clipping step below.
audio, sr = librosa.load(audio_path, sr=sample_rate)
if len(audio) < max_len:
    audio = np.pad(audio, (0, max_len - len(audio)), mode='constant')
else:
    audio = audio[:max_len]

# Extract mel spectrogram
n_fft = 2048  # FFT window size (samples)
hop_length = 512  # samples between successive frames
n_mels = 128  # number of mel bands
melspec = librosa.feature.melspectrogram(
    y=audio,
    sr=sample_rate,
    n_fft=n_fft,
    hop_length=hop_length,
    n_mels=n_mels,
    fmin=20,
    # Upper edge tied to the Nyquist frequency instead of a hard-coded 8000:
    # identical at 16 kHz, and never exceeds Nyquist if sample_rate changes.
    fmax=sample_rate // 2,
    power=2,  # power (squared-magnitude) spectrogram
)
# Convert power to decibels relative to the loudest value (ref=np.max).
melspec_db = librosa.power_to_db(melspec, ref=np.max)

## 2. Divide the Spectrogram into Temporal and Spectral Clips

SpecTTTra divides the spectrogram into two types of non-overlapping clips:
- **Temporal Clips:** Slices along the time axis, capturing short time windows across all frequencies.
- **Spectral Clips:** Slices along the frequency axis, capturing frequency bands across all time frames.

In [None]:
# Clip sizes. Clips are non-overlapping: each one starts where the previous ends.
t_clip = 5  # temporal clip size (frames)
f_clip = 3  # spectral clip size (mel bins)

# Count only complete clips; a trailing remainder shorter than one clip is
# ignored. For non-negative sizes this equals (size - clip) // clip + 1.
num_temporal_clips = melspec_db.shape[1] // t_clip
num_spectral_clips = melspec_db.shape[0] // f_clip

# Temporal clips: all mel bins, `t_clip` consecutive frames (column slices).
temporal_clips = [
    melspec_db[:, k * t_clip:(k + 1) * t_clip]
    for k in range(num_temporal_clips)
]

# Spectral clips: `f_clip` consecutive mel bins, all frames (row slices).
spectral_clips = [
    melspec_db[k * f_clip:(k + 1) * f_clip, :]
    for k in range(num_spectral_clips)
]

## 3. Visualization

We visualize the mel spectrogram and overlay rectangles to show the first temporal and spectral clips (up to 10 of each).

In [None]:
# Draw on an explicit Figure/Axes so every overlay targets the same plot.
fig, ax = plt.subplots(figsize=(12, 6))
image = ax.imshow(melspec_db, aspect='auto', origin='lower', cmap='magma')
ax.set_title('Mel Spectrogram with Temporal and Spectral Clips')
ax.set_xlabel('Time Frames')
ax.set_ylabel('Mel Frequency Bins')

num_bins = melspec_db.shape[0]
num_frames = melspec_db.shape[1]
# With the default extent, imshow centres pixel k on coordinate k, so pixel k
# spans [k - 0.5, k + 0.5]. Offset every rectangle by half a pixel so its
# edges fall on the boundaries of the frames/bins the clip actually contains.
edge = -0.5

# Overlay temporal clips (full-height rectangles)
for idx in range(min(10, num_temporal_clips)):
    ax.add_patch(plt.Rectangle(
        (idx * t_clip + edge, edge), t_clip, num_bins,
        edgecolor='cyan', facecolor='none', lw=2,
        label='Temporal Clip' if idx == 0 else None,  # one legend entry only
    ))

# Overlay spectral clips (full-width rectangles)
for idx in range(min(10, num_spectral_clips)):
    ax.add_patch(plt.Rectangle(
        (edge, idx * f_clip + edge), num_frames, f_clip,
        edgecolor='lime', facecolor='none', lw=2,
        label='Spectral Clip' if idx == 0 else None,  # one legend entry only
    ))

handles, labels = ax.get_legend_handles_labels()
if handles:
    ax.legend(handles, labels)
fig.colorbar(image, ax=ax, label='dB')
fig.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn

# Size of every token vector produced by the clip tokenizers (embedding size).
token_dim = 384


# Define SpecTTTra-repo-based tokenizer
class ClipTokenizer(nn.Module):
    """Project flattened clips to token vectors with a single linear layer.

    Args:
        input_dim: Number of values in one flattened clip.
        token_dim: Length of each output token.
    """

    def __init__(self, input_dim, token_dim):
        super().__init__()
        self.proj = nn.Linear(in_features=input_dim, out_features=token_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_dim)`` to ``(..., token_dim)``.

        In this notebook ``x`` is ``(num_clips, clip_size)``.
        """
        tokens = self.proj(x)
        return tokens

def _clips_to_batch(clips):
    """Flatten each 2-D clip in row-major order and stack them as float32 rows."""
    return torch.stack(
        [torch.tensor(c.reshape(-1), dtype=torch.float32) for c in clips]
    )


# Temporal clips: each one holds n_mels * t_clip values.
temporal_clip_size = n_mels * t_clip
temporal_tokenizer = ClipTokenizer(temporal_clip_size, token_dim)
temporal_clips_batch = _clips_to_batch(temporal_clips)  # (num_temporal_clips, clip_size)
temporal_tokens = temporal_tokenizer(temporal_clips_batch)  # (num_temporal_clips, token_dim)

# Spectral clips: each one holds f_clip * num_frames values.
spectral_clip_size = f_clip * melspec_db.shape[1]
spectral_tokenizer = ClipTokenizer(spectral_clip_size, token_dim)
spectral_clips_batch = _clips_to_batch(spectral_clips)  # (num_spectral_clips, clip_size)
spectral_tokens = spectral_tokenizer(spectral_clips_batch)  # (num_spectral_clips, token_dim)

# Report token shapes and one example token of each kind.
print("Temporal tokens shape:", temporal_tokens.shape)
print("Spectral tokens shape:", spectral_tokens.shape)
print("Sample temporal token:", temporal_tokens[0].detach().numpy())
print("Sample spectral token:", spectral_tokens[0].detach().numpy())

In [None]:
import torch
import torch.nn as nn

# Parameters from config
embed_dim = 384
num_heads = 6
num_layers = 12
pe_learnable = True


def get_sinusoid_encoding(n_position, d_hid):
    """Return an ``(n_position, d_hid)`` sinusoidal position-encoding table.

    Column i uses the angle ``pos / 10000 ** (2 * (i // 2) / d_hid)``; even
    columns hold its sine, odd columns its cosine. Built with vectorized tensor
    ops instead of an O(n_position * d_hid) Python double loop.
    """
    positions = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)  # (n, 1)
    pair_index = torch.arange(d_hid, dtype=torch.float64).div(2).floor()  # i // 2
    angles = positions / torch.pow(10000.0, 2 * pair_index / d_hid)  # (n, d)
    table = torch.empty(n_position, d_hid)  # default dtype, like torch.zeros before
    table[:, 0::2] = torch.sin(angles[:, 0::2])
    table[:, 1::2] = torch.cos(angles[:, 1::2])
    return table


# temporal_tokens and spectral_tokens come from the tokenizer cell and have
# shape (num_clips, token_dim). Concatenate both kinds into one sequence.
all_tokens = torch.cat([temporal_tokens, spectral_tokens], dim=0)  # (num_total_clips, token_dim)

# Project tokens to embed_dim only if the tokenizer width differs.
if all_tokens.shape[1] != embed_dim:
    projector = nn.Linear(all_tokens.shape[1], embed_dim)
    all_tokens = projector(all_tokens)

num_tokens = all_tokens.shape[0]

# Positional embedding: learnable (truncated-normal init) or fixed sinusoidal.
if pe_learnable:
    pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
    nn.init.trunc_normal_(pos_embed, std=0.02)
else:
    pos_embed = get_sinusoid_encoding(num_tokens, embed_dim).unsqueeze(0)
tokens_with_pos = all_tokens.unsqueeze(0) + pos_embed  # (1, num_tokens, embed_dim)

# Transformer Encoder (PyTorch nn.TransformerEncoder)
encoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=num_heads,
    dim_feedforward=int(embed_dim * 2.67),  # mlp_ratio from config
    dropout=0.1,  # pos_drop_rate, attn_drop_rate
    activation='gelu',
    batch_first=True
)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# This cell only demonstrates a forward pass, so turn dropout off; modules are
# created in training mode. Call transformer_encoder.train() before training.
transformer_encoder.eval()

# Forward pass through transformer
encoded_tokens = transformer_encoder(tokens_with_pos)  # shape: (1, num_tokens, embed_dim)

print("Encoded tokens shape:", encoded_tokens.shape)
print("Sample encoded token:", encoded_tokens[0, 0].detach().numpy())