# PyTorch Audio Classification with M5 Network and Speech Commands

This example demonstrates how to train an audio classification model using the M5 Network architecture on the Google Speech Commands dataset with PyTorch and Kubeflow Trainer.

**Model**: [M5 Network](https://arxiv.org/abs/1610.00087) - A lightweight CNN that processes raw audio waveforms directly using 1D convolutions.

**Dataset**: [Speech Commands](https://arxiv.org/abs/1804.03209) - One-second audio clips of 35 spoken words ("yes", "no", "up", "down", etc.).

This notebook shows how to run distributed training locally and scale to a Kubernetes cluster using Kubeflow Trainer.

## Install the Kubeflow SDK

In [1]:
!pip install -U kubeflow



## Install PyTorch Dependencies

In [2]:
!pip install torch torchaudio



## Define the Training Function

This function includes the M5 model architecture, data loading, and distributed training logic.

In [3]:
def train_m5_speechcommands(batch_size: int = 256, epochs: int = 2, lr: float = 0.01):
    """Train the M5 network on Speech Commands with DistributedDataParallel.

    All imports and helper definitions live inside the function body, so the
    function carries everything it needs when it is handed to a trainer as a
    single callable.

    Args:
        batch_size: Number of samples per batch on each process.
        epochs: Number of passes over the training subset.
        lr: Initial learning rate for the Adam optimizer.
    """
    import os
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import torch.distributed as dist
    import torchaudio
    from torchaudio.datasets import SPEECHCOMMANDS
    from torch.utils.data import DataLoader, DistributedSampler

    # M5 Network - lightweight CNN for audio classification
    class M5(nn.Module):
        """Four Conv1d/BatchNorm/ReLU/MaxPool blocks followed by a linear classifier."""

        def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
            super().__init__()
            # 4 conv blocks with increasing channels
            self.feature_extractor = nn.Sequential(
                nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride),
                nn.BatchNorm1d(n_channel),
                nn.ReLU(),
                nn.MaxPool1d(4),
                nn.Conv1d(n_channel, n_channel, kernel_size=3),
                nn.BatchNorm1d(n_channel),
                nn.ReLU(),
                nn.MaxPool1d(4),
                nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3),
                nn.BatchNorm1d(2 * n_channel),
                nn.ReLU(),
                nn.MaxPool1d(4),
                nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3),
                nn.BatchNorm1d(2 * n_channel),
                nn.ReLU(),
                nn.MaxPool1d(4),
            )
            self.classifier = nn.Linear(2 * n_channel, n_output)

        def forward(self, x):
            """Return log-probabilities of shape (batch, 1, n_output)."""
            x = self.feature_extractor(x)
            # Average over the whole remaining time axis, then move channels last
            # so the linear layer sees (batch, 1, channels).
            x = F.avg_pool1d(x, x.shape[-1]).permute(0, 2, 1)
            return F.log_softmax(self.classifier(x), dim=2)

    # Custom collate to handle variable-length audio samples.
    # `label_to_index` is bound further down, before the loader is iterated.
    def collate_fn(batch):
        tensors, targets = [], []
        for waveform, _, label, *_ in batch:
            tensors.append(waveform.t())
            targets.append(torch.tensor(label_to_index[label]))
        # Pad to same length
        tensors = torch.nn.utils.rnn.pad_sequence(
            tensors, batch_first=True, padding_value=0.
        ).permute(0, 2, 1)
        return tensors, torch.stack(targets)

    # Setup distributed training
    device_type, backend = (
        ("cuda", "nccl") if torch.cuda.is_available() else ("cpu", "gloo")
    )
    print(f"Using Device: {device_type}, Backend: {backend}")

    local_rank = int(os.getenv("LOCAL_RANK", 0))
    if device_type == "cuda":
        # Bind this process to its own GPU before creating the process group so
        # NCCL collectives (including barrier) run on the right device.
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        # There is a single CPU device; it must not be indexed by local rank.
        device = torch.device("cpu")

    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    print(
        f"Distributed Training - WORLD_SIZE: {world_size}, "
        f"RANK: {rank}, LOCAL_RANK: {local_rank}"
    )

    # Download once per node (local rank 0) so that every node has the data
    # while processes on the same node do not write the same files at once.
    # NOTE(review): assumes nodes do not share "./"; with a shared volume,
    # gate on the global rank instead.
    if local_rank == 0:
        SPEECHCOMMANDS("./", download=True, subset="training")
    dist.barrier()

    train_set = SPEECHCOMMANDS("./", download=False, subset="training")
    labels = sorted({sample[2] for sample in train_set})
    label_to_index = {label: index for index, label in enumerate(labels)}
    print(f"Number of classes: {len(labels)}")

    # Resample to 8kHz for faster processing
    waveform, sample_rate, *_ = train_set[0]
    resampler = torchaudio.transforms.Resample(sample_rate, 8000).to(device)

    # Distributed data loader
    sampler = DistributedSampler(train_set)
    loader = DataLoader(
        train_set,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        num_workers=1 if device_type == "cuda" else 0
    )

    # Initialize model with DDP
    model = M5(n_input=waveform.shape[0], n_output=len(labels)).to(device)
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank] if device_type == "cuda" else None,
    )

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    # Training loop
    model.train()
    dist.barrier()

    for epoch in range(1, epochs + 1):
        sampler.set_epoch(epoch)  # Ensure different shuffle each epoch

        for batch_idx, (data, target) in enumerate(loader):
            data = resampler(data.to(device))
            target = target.to(device)

            optimizer.zero_grad()
            output = model(data)
            # Drop only the singleton middle axis; a bare squeeze() would also
            # drop the batch axis when the last batch holds a single sample.
            loss = F.nll_loss(output.squeeze(1), target)
            loss.backward()
            optimizer.step()

            # Log from rank 0 only
            if batch_idx % 20 == 0 and rank == 0:
                print(
                    f"Epoch: {epoch} "
                    f"[{batch_idx * len(data) * world_size}/{len(train_set)} "
                    f"({100.0 * batch_idx / len(loader):.0f}%)]\t"
                    f"Loss: {loss.item():.6f}"
                )

        scheduler.step()

    dist.barrier()
    if rank == 0:
        print("Training Complete")

    dist.destroy_process_group()

## Scale with Kubeflow TrainJob

Now run the training function as a TrainJob on your Kubernetes cluster. The job below uses a single node; increase `num_nodes` to scale the training across multiple nodes.

In [4]:
from kubeflow.trainer import CustomTrainer, TrainerClient

# Client used by the following cells to list runtimes and to submit, monitor,
# and delete the TrainJob.
client = TrainerClient()

## List Available Training Runtimes

In [5]:
# Print every available runtime and remember the PyTorch one for the job below.
torch_runtime = None
for runtime in client.list_runtimes():
    print(runtime)
    if runtime.name == "torch-distributed":
        torch_runtime = runtime

# Fail here with a clear message instead of a NameError when submitting the job.
if torch_runtime is None:
    raise RuntimeError(
        "Runtime 'torch-distributed' was not found; "
        "check the runtimes installed on the cluster."
    )

Runtime(name='deepspeed-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='deepspeed', image='ghcr.io/kubeflow/trainer/deepspeed-runtime:v2.1.0', num_nodes=1, device='Unknown', device_count='1'), pretrained_model=None)
Runtime(name='mlx-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='mlx', image='ghcr.io/kubeflow/trainer/mlx-runtime:v2.1.0', num_nodes=1, device='Unknown', device_count='1'), pretrained_model=None)
Runtime(name='torch-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', image='pytorch/pytorch:2.7.1-cuda12.8-cudnn9-runtime', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None)
Runtime(name='torchtune-llama3.2-1b', trainer=RuntimeTrainer(trainer_type=<TrainerType.BUILTIN_TRAINER: 'BuiltinTrainer'>, framework='torchtune', image='ghcr.io/kubeflow/trainer/torchtune-trainer:

## Submit Distributed Training Job

In [6]:
# Resources requested for each training node.
node_resources = {
    "cpu": "8",
    "memory": "32Gi",
    # Comment out the next line to run on a CPU-only node.
    "nvidia.com/gpu": 1,  # Request 1 GPU per node
}

# Wrap the training function together with its arguments and extra packages.
m5_trainer = CustomTrainer(
    func=train_m5_speechcommands,
    func_args={"epochs": 5, "batch_size": 256, "lr": 0.01},
    num_nodes=1,
    packages_to_install=["torchaudio", "torch", "soundfile"],
    resources_per_node=node_resources,
)

# Submit the TrainJob on the PyTorch runtime selected above.
job_id = client.train(trainer=m5_trainer, runtime=torch_runtime)
print(f"Job ID: {job_id}")

Job ID: nb39c911399d


## Monitor Training Job

In [7]:
# Wait for the TrainJob to report the "Running" status; the returned TrainJob
# is displayed as the cell output.
client.wait_for_job_status(name=job_id, status={"Running"})

TrainJob(name='nb39c911399d', runtime=Runtime(name='torch-distributed', trainer=RuntimeTrainer(trainer_type=<TrainerType.CUSTOM_TRAINER: 'CustomTrainer'>, framework='torch', image='pytorch/pytorch:2.7.1-cuda12.8-cudnn9-runtime', num_nodes=1, device='Unknown', device_count='Unknown'), pretrained_model=None), steps=[Step(name='node-0', status='Running', pod_name='nb39c911399d-node-0-0-lgsnk', device='gpu', device_count='1')], num_nodes=1, creation_timestamp=datetime.datetime(2026, 1, 3, 21, 26, 29, tzinfo=TzInfo(0)), status='Running')

In [8]:
# Report the status and device allocation of every step of the TrainJob.
train_job = client.get_job(name=job_id)
for job_step in train_job.steps:
    print(
        f"Step: {job_step.name}, Status: {job_step.status}, "
        f"Devices: {job_step.device} x {job_step.device_count}"
    )

Step: node-0, Status: Running, Devices: gpu x 1


## View Training Logs

In [9]:
# Print the TrainJob logs line by line.
# NOTE(review): follow=True presumably keeps streaming new lines while the job
# runs - confirm against the SDK documentation.
for logline in client.get_job_logs(job_id, follow=True):
    print(logline)

W0103 21:26:32.268000 1 site-packages/torch/distributed/run.py:766] 
W0103 21:26:32.268000 1 site-packages/torch/distributed/run.py:766] *****************************************
W0103 21:26:32.268000 1 site-packages/torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0103 21:26:32.268000 1 site-packages/torch/distributed/run.py:766] *****************************************
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Using Device: cpu, Backend: gloo
Usin

## Cleanup

In [10]:
# Delete the TrainJob submitted above once it is no longer needed.
client.delete_job(job_id)