In [1]:
import numpy as np
import warnings
warnings.filterwarnings('ignore')

import os

import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import torchmetrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

from waveform_encoder import WaveformEncoder
from spectrogram_encoder import SpectrogramEncoder
from wf_sg_cross_attn import WaveformSpectrogramCrossEncoder

In [2]:
TRAIN_DIR_PATH = 'voice-commands-classification-2025/train'
TEST_DIR_PATH = 'voice-commands-classification-2025/adv_test'

In [3]:
BATCH_SIZE = 256
N_WORKERS = 6
N_CLASSES = 35
EPOCHS = 20
LR = 0.005

N_MFCC = 120
NOISE_AMPLITUDE = 0.00
MASK_PROB = 0.1

DEVICE = torch.device('cpu')
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')

DEVICE

device(type='cuda', index=0)

In [24]:
# Dataset

def noise_waveform(waveform: torch.Tensor, noise_amplitude: float = 0.05) -> torch.Tensor:
    noise = noise_amplitude * torch.randn(waveform.shape).to(waveform.device)
    noisy_waveform = waveform + noise
    noisy_waveform = torch.clamp(noisy_waveform, -1.0, 1.0)
    return noisy_waveform

class SpeechCommandDataset(Dataset):
    def __init__(self, dir_path, data, labels=None, dict_label_to_index=None, transform=None, noise_amplitude=0.00):
        self.dir_path = dir_path
        self.data = data
        self.labels = labels
        self.dict_label_to_index = dict_label_to_index
        self.transform = transform
        self.noise_amplitude = noise_amplitude

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        file_name = self.data[idx]
        waveform = np.load(os.path.join(self.dir_path, file_name))
        if waveform.shape[1] < 16000:
            waveform = np.pad(
                waveform, pad_width=((0, 0), (0, 16000 - waveform.shape[1])),
                mode='constant',
                constant_values=0
            )

        waveform = torch.from_numpy(waveform).float()
        if self.noise_amplitude > 0:
            waveform = noise_waveform(waveform, self.noise_amplitude)

        if self.transform != None:
            spectrogram = self.transform(waveform)
        else:
            spectrogram = None
        
        out_labels = []
        if self.labels is not None:
            if self.labels[idx] in self.dict_label_to_index:
                out_labels = self.dict_label_to_index[self.labels[idx]]

        return waveform, spectrogram, out_labels, int(file_name.split('.')[0])

In [5]:
df_train = pd.read_csv(
    os.path.join(TRAIN_DIR_PATH, 'metadata.csv')
)
dict_label_to_index = {}
dict_index_to_label = {}
for index, key in enumerate(df_train['label'].unique()):
    dict_label_to_index[key] = index
    dict_index_to_label[index] = key

dict_label_to_index

{'stop': 0,
 'go': 1,
 'right': 2,
 'dog': 3,
 'left': 4,
 'yes': 5,
 'zero': 6,
 'four': 7,
 'bird': 8,
 'cat': 9,
 'five': 10,
 'off': 11,
 'learn': 12,
 'six': 13,
 'two': 14,
 'on': 15,
 'up': 16,
 'three': 17,
 'nine': 18,
 'one': 19,
 'follow': 20,
 'wow': 21,
 'seven': 22,
 'sheila': 23,
 'down': 24,
 'no': 25,
 'bed': 26,
 'eight': 27,
 'house': 28,
 'tree': 29,
 'visual': 30,
 'forward': 31,
 'marvin': 32,
 'backward': 33,
 'happy': 34}

In [6]:
df_train_data, df_val_data = train_test_split(
    df_train,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

train_data = df_train_data.file_name.values
train_labels = df_train_data.label.values

val_data = df_val_data.file_name.values
val_labels = df_val_data.label.values

In [7]:
# DataLoader, transform

train_dataloader = DataLoader(
    SpeechCommandDataset(
        dir_path=TRAIN_DIR_PATH,
        data=train_data,
        labels=train_labels,
        dict_label_to_index=dict_label_to_index,
        transform=torchaudio.transforms.MFCC(n_mfcc=N_MFCC, log_mels=True),
        noise_amplitude=NOISE_AMPLITUDE
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_WORKERS
)

valid_dataloader = DataLoader(
    SpeechCommandDataset(
        dir_path=TRAIN_DIR_PATH,
        data=val_data,
        labels=val_labels,
        dict_label_to_index=dict_label_to_index,
        transform=torchaudio.transforms.MFCC(n_mfcc=N_MFCC, log_mels=True),
        noise_amplitude=0.0
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_WORKERS
)


In [8]:
for item in train_dataloader:
    print(item[0].shape, item[1].shape)
    break

torch.Size([256, 1, 16000]) torch.Size([256, 1, 120, 81])


In [9]:
# Load waveform encoder state dict
state_dict = torch.load("waveform_encoder.pt", weights_only=True)
wf_enc_weight = {k[7:]: v for k, v in state_dict.items() if 'wf_enc' in k}
wf_enc = WaveformEncoder(n_input=1, stride=160, kernel_size=400, n_channel=N_MFCC)
wf_enc.load_state_dict(wf_enc_weight)

# Load spectrogram encoder state dict
state_dict = torch.load("spectrogram_encoder.pt", weights_only=True)
sg_enc_weight = {k[7:]: v for k, v in state_dict.items() if 'sg_enc' in k}
sg_enc = SpectrogramEncoder(n_layer=4, n_head=6, hidden_dim=N_MFCC, mask_prob=MASK_PROB)
sg_enc.load_state_dict(sg_enc_weight)

<All keys matched successfully>

In [10]:
class M5(nn.Module):
    def __init__(
            self,
            n_class,
            n_layer,
            n_head,
            wf_enc:WaveformSpectrogramCrossEncoder,
            sg_enc:SpectrogramEncoder,
            hidden_dim: int = 96,
            mask_prob: float = 0.1,
            cross_attn_dropout: float = 0.1,
            ):
        super().__init__()
        self.wf_enc = wf_enc
        self.sg_enc = sg_enc
        self.cross_attn = WaveformSpectrogramCrossEncoder(n_layer, n_head, hidden_dim, dropout=cross_attn_dropout, mask_prob=mask_prob)
        self.out = nn.Linear(hidden_dim, n_class)

    def forward(self, x, sg):
        x = self.wf_enc(x)
        sg = self.sg_enc(sg)
        logits = self.cross_attn(sg, x)
        logits = logits.transpose(-1, -2)
        logits = F.avg_pool1d(logits, logits.shape[-1])
        logits = logits.transpose(-1, -2)
        logits = self.out(logits)
        return F.log_softmax(logits, dim=2)

In [11]:
model = M5(n_class=N_CLASSES, n_layer=2, n_head=4, hidden_dim=N_MFCC, wf_enc=wf_enc, sg_enc=sg_enc, mask_prob=MASK_PROB)
model = model.to(DEVICE)

In [12]:
input_image = torch.rand(4, 1, 16000)
input_sp = torchaudio.transforms.MFCC(n_mfcc=N_MFCC, log_mels=True)(input_image).squeeze(1).transpose(-1, -2)
model = model.to(DEVICE)
result = model(input_image.to(DEVICE), input_sp.to(DEVICE))
print(result.size())

torch.Size([4, 1, 35])


In [13]:
def lr_lambda(current_step):
    return max(0.0, float(EPOCHS - current_step) / EPOCHS)


def train_model(model: nn.Module, train_data: DataLoader, valid_data: DataLoader):
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    criterion = nn.NLLLoss()

    accuracy_train = torchmetrics.classification.Accuracy(task="multiclass", num_classes=N_CLASSES).to(DEVICE)
    accuracy_val = torchmetrics.classification.Accuracy(task="multiclass", num_classes=N_CLASSES).to(DEVICE)

    for epoch in range(EPOCHS):
        train_loss = 0.0
        val_loss = 0.0

        model.train()
        for x, x_sp, y, _ in train_data:
            x = x.to(DEVICE)
            x_sp = x_sp.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            y_hat = model(x, x_sp.squeeze(1).transpose(-1, -2)).squeeze()
            loss = criterion(y_hat, y)

            loss.backward()
            optimizer.step()

            train_loss += loss.item() * x.size(0)
            _, preds = torch.max(y_hat, 1)

            accuracy_train(
                y_hat,
                y
            )

        model.eval()
        for x, x_sp, y, _ in valid_data:
            x = x.to(DEVICE)
            x_sp = x_sp.to(DEVICE)
            y = y.to(DEVICE)

            y_hat = model(x, x_sp.squeeze(1).transpose(-1, -2)).squeeze()
            loss = criterion(y_hat, y)

            val_loss += loss.item() * x.size(0)
            _, preds = torch.max(y_hat, 1)

            accuracy_val(
                y_hat,
                y
            )

        train_loss = train_loss / len(train_dataloader.dataset)
        val_loss = val_loss / len(valid_dataloader.dataset)

        scheduler.step()

        print(f"Epoch {epoch + 1}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {accuracy_train.compute():.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {accuracy_val.compute():.4f}")
        

In [None]:
# Freeze encoder weights to train cross attention layer

for param in model.wf_enc.parameters():
    param.requires_grad = False
for param in model.sg_enc.parameters():
    param.requires_grad = False

In [15]:
train_model(
    model=model,
    train_data=train_dataloader,
    valid_data=valid_dataloader
)

Epoch 1/20
Train Loss: 0.0789, Train Acc: 0.9826
Val Loss: 0.2712, Val Acc: 0.9419
Epoch 2/20
Train Loss: 0.0168, Train Acc: 0.9892
Val Loss: 0.2837, Val Acc: 0.9418
Epoch 3/20
Train Loss: 0.0150, Train Acc: 0.9915
Val Loss: 0.2736, Val Acc: 0.9424
Epoch 4/20
Train Loss: 0.0109, Train Acc: 0.9929
Val Loss: 0.2804, Val Acc: 0.9429
Epoch 5/20
Train Loss: 0.0118, Train Acc: 0.9937
Val Loss: 0.2703, Val Acc: 0.9433
Epoch 6/20
Train Loss: 0.0097, Train Acc: 0.9944
Val Loss: 0.2935, Val Acc: 0.9430
Epoch 7/20
Train Loss: 0.0084, Train Acc: 0.9949
Val Loss: 0.2698, Val Acc: 0.9436
Epoch 8/20
Train Loss: 0.0066, Train Acc: 0.9954
Val Loss: 0.2641, Val Acc: 0.9441
Epoch 9/20
Train Loss: 0.0073, Train Acc: 0.9958
Val Loss: 0.2640, Val Acc: 0.9445
Epoch 10/20
Train Loss: 0.0077, Train Acc: 0.9960
Val Loss: 0.2547, Val Acc: 0.9449
Epoch 11/20
Train Loss: 0.0056, Train Acc: 0.9963
Val Loss: 0.2576, Val Acc: 0.9452
Epoch 12/20
Train Loss: 0.0053, Train Acc: 0.9965
Val Loss: 0.2523, Val Acc: 0.9456
E

In [16]:
torch.save(model.state_dict(), 'cross_attn_encoder.pt')

In [13]:
model.load_state_dict(torch.load('cross_attn_encoder.pt'))

<All keys matched successfully>

In [14]:
def lr_lambda(current_step):
    return max(0.0, float(EPOCHS - current_step) / EPOCHS)


def train_model(model: nn.Module, train_data: DataLoader, valid_data: DataLoader):
    optimizer = torch.optim.Adam(model.parameters(), lr=LR // 10, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    criterion = nn.NLLLoss()

    accuracy_train = torchmetrics.classification.Accuracy(task="multiclass", num_classes=N_CLASSES).to(DEVICE)
    accuracy_val = torchmetrics.classification.Accuracy(task="multiclass", num_classes=N_CLASSES).to(DEVICE)

    for epoch in range(EPOCHS):
        train_loss = 0.0
        val_loss = 0.0

        model.train()
        for x, x_sp, y, _ in train_data:
            x = x.to(DEVICE)
            x_sp = x_sp.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            y_hat = model(x, x_sp.squeeze(1).transpose(-1, -2)).squeeze()
            loss = criterion(y_hat, y)

            loss.backward()
            optimizer.step()

            train_loss += loss.item() * x.size(0)
            _, preds = torch.max(y_hat, 1)

            accuracy_train(
                y_hat,
                y
            )

        model.eval()
        for x, x_sp, y, _ in valid_data:
            x = x.to(DEVICE)
            x_sp = x_sp.to(DEVICE)
            y = y.to(DEVICE)

            y_hat = model(x, x_sp.squeeze(1).transpose(-1, -2)).squeeze()
            loss = criterion(y_hat, y)

            val_loss += loss.item() * x.size(0)
            _, preds = torch.max(y_hat, 1)

            accuracy_val(
                y_hat,
                y
            )

        train_loss = train_loss / len(train_dataloader.dataset)
        val_loss = val_loss / len(valid_dataloader.dataset)

        scheduler.step()

        print(f"Epoch {epoch + 1}/{EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {accuracy_train.compute():.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {accuracy_val.compute():.4f}")
        

In [15]:
# Unfreeze encoder weights to fine-tune the model

for param in model.wf_enc.parameters():
    param.requires_grad = True
for param in model.sg_enc.parameters():
    param.requires_grad = True

In [16]:
train_model(
    model=model,
    train_data=train_dataloader,
    valid_data=valid_dataloader
)

Epoch 1/20
Train Loss: 0.0021, Train Acc: 0.9999
Val Loss: 0.2429, Val Acc: 0.9512
Epoch 2/20
Train Loss: 0.0020, Train Acc: 0.9999
Val Loss: 0.2429, Val Acc: 0.9512
Epoch 3/20
Train Loss: 0.0020, Train Acc: 0.9999
Val Loss: 0.2429, Val Acc: 0.9512
Epoch 4/20
Train Loss: 0.0020, Train Acc: 0.9998
Val Loss: 0.2427, Val Acc: 0.9512
Epoch 5/20
Train Loss: 0.0021, Train Acc: 0.9999
Val Loss: 0.2426, Val Acc: 0.9512
Epoch 6/20
Train Loss: 0.0020, Train Acc: 0.9999
Val Loss: 0.2429, Val Acc: 0.9512
Epoch 7/20
Train Loss: 0.0020, Train Acc: 0.9999
Val Loss: 0.2431, Val Acc: 0.9512
Epoch 8/20
Train Loss: 0.0020, Train Acc: 0.9999
Val Loss: 0.2431, Val Acc: 0.9512
Epoch 9/20
Train Loss: 0.0022, Train Acc: 0.9998
Val Loss: 0.2430, Val Acc: 0.9512
Epoch 10/20
Train Loss: 0.0020, Train Acc: 0.9998
Val Loss: 0.2430, Val Acc: 0.9512
Epoch 11/20
Train Loss: 0.0020, Train Acc: 0.9998
Val Loss: 0.2430, Val Acc: 0.9512
Epoch 12/20
Train Loss: 0.0021, Train Acc: 0.9998
Val Loss: 0.2431, Val Acc: 0.9512
E

In [25]:
df_test = pd.read_csv(
    os.path.join(TEST_DIR_PATH, 'metadata.csv')
)
test_dataloader = DataLoader(
    SpeechCommandDataset(
        dir_path=TEST_DIR_PATH,
        data=df_test.file_name.values,
        labels=None,
        dict_label_to_index=dict_label_to_index,
        transform=torchaudio.transforms.MFCC(n_mfcc=N_MFCC, log_mels=True),
        noise_amplitude=0.0
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_WORKERS
)


In [26]:
# ENSEMBLE PREDICTIONS AND SUBMIT
results = {
    'id': [],
    'label': []
}

model.eval()
for x, x_sp, y, ids in test_dataloader:
    x = x.to(DEVICE)
    x_sp = x_sp.to(DEVICE)
    with torch.no_grad():
        y_hat = model(x, x_sp.squeeze(1).transpose(-1, -2)).squeeze()
        _, preds = torch.max(y_hat, 1)
        for i in range(len(preds)):
            results["id"].append(ids[i].item())
            results["label"].append(dict_index_to_label[int(preds[i].item())])
        

pd.DataFrame(results).to_csv(
    'submission.csv',
    columns=['id', 'label'],
    index=False
)

In [27]:
torch.save(model.state_dict(), 'cross_attn_encoder_fine_tuned.pt')