# Audio Classification with PyTorch and Kubeflow Trainer

This example demonstrates how to train a CNN for audio classification using the GTZAN dataset and PyTorch Distributed Data Parallel (DDP).

It follows a development workflow designed for scale:
1. **Define**: Wrap training logic in a self-contained function.
2. **Test Locally**: Run the function in a local subprocess to verify code correctness.
3. **Scale Distributed**: Submit the same function as a distributed `TrainJob` to a Kubernetes cluster.

## 1. Installation

You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:

In [None]:
# !pip install -U kubeflow

Install the model dependencies.

In [None]:
# Install model dependencies
!pip install -q torch torchaudio soundfile kagglehub numpy

## 2. Define Training Function

We define `train_audio_classification` to encapsulate the entire training loop.

**Critical Requirement**: This function must be self-contained. It must import all necessary libraries inside itself because it will be serialized and executed in an isolated environment (container or subprocess).

In [None]:
def train_audio_classification():
    """
    Trains a CNN on the GTZAN audio dataset using PyTorch DDP.

    The function takes no arguments and returns nothing. Every import lives
    inside the body so the function can be shipped and executed on its own.

    It prepares the dataset (Kaggle download, or a small synthetic fallback
    when the download fails), trains for ``EPOCHS`` epochs and, on rank 0,
    writes the weights to ``model.pt`` in the output directory. When the
    ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables are all
    set it joins a distributed process group; otherwise it trains in a single
    process.

    Raises:
        Exception: any error raised during training is printed and re-raised.
    """
    import os
    import shutil
    from pathlib import Path
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import Dataset, DataLoader, DistributedSampler
    from torchaudio.transforms import MelSpectrogram
    import soundfile as sf
    import platform

    # Suppress warnings and progress bars for cleaner output
    import warnings
    warnings.filterwarnings('ignore')
    # Presumably read by kagglehub to hide its download progress bar -- TODO confirm
    os.environ['KAGGLE_DOWNLOAD_PROGRESS'] = '0'

    # --- 1. Environment & Configuration ---
    # Windows uses fixed C:/ locations; every other platform (Linux containers,
    # macOS, ...) keeps data and output under the current user's home directory.
    if platform.system() == 'Windows':
        data_dir = Path("C:/data/gtzan")
        output_dir = Path("C:/output")
    else:
        data_dir = Path.home() / "data" / "gtzan"
        output_dir = Path.home() / "output"
    
    # Create both directories up front; the dataset download treats an empty
    # data_dir as "nothing downloaded yet".
    data_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Hyperparameters
    BATCH_SIZE = 16
    EPOCHS = 5
    LR = 1e-3  # learning rate passed to the Adam optimizer
    SAMPLE_RATE = 22050  # Hz; clips at another rate are resampled to this
    AUDIO_LENGTH = 3  # seconds
    # Number of samples every clip is padded or truncated to
    FIXED_LENGTH = SAMPLE_RATE * AUDIO_LENGTH

    # --- 2. DDP Initialization ---
    def setup_ddp():
        """
        Pick the training device and join a process group when launched distributed.

        A distributed launch is recognised by LOCAL_RANK, RANK and WORLD_SIZE all
        being present in the environment. In that case the process group is
        initialized with NCCL when CUDA is available and Gloo otherwise.
        Without those variables nothing is initialized and the process runs
        standalone.

        Returns:
            tuple: ``(device, local_rank, rank, world_size)``; standalone mode
            reports ``(device, 0, 0, 1)``.
        """
        has_cuda = torch.cuda.is_available()
        required_vars = ("LOCAL_RANK", "RANK", "WORLD_SIZE")

        # Guard clause: any missing variable means we were not launched by a
        # distributed launcher.
        if any(name not in os.environ for name in required_vars):
            print("[Single Process] DDP variables not found. Running in standalone mode.")
            return torch.device("cuda" if has_cuda else "cpu"), 0, 0, 1

        local_rank, rank, world_size = (int(os.environ[name]) for name in required_vars)

        if has_cuda:
            backend = "nccl"
            device = torch.device("cuda", local_rank)
            # Bind this process to its own GPU before the group is created
            torch.cuda.set_device(device)
        else:
            backend = "gloo"
            device = torch.device("cpu")

        dist.init_process_group(backend=backend)
        print(f"[Rank {rank}/{world_size}] DDP Initialized using {backend}")
        return device, local_rank, rank, world_size

    def cleanup_ddp():
        """Destroy the process group if one was initialized; otherwise do nothing."""
        if not dist.is_initialized():
            return
        dist.destroy_process_group()

    # --- 3. Dataset Download Function ---
    def download_gtzan_dataset(target_dir):
        """
        Populate ``target_dir`` with one sub-directory of .wav files per genre.

        The GTZAN ``genres_original`` folder is downloaded from Kaggle and
        copied into ``target_dir``. If anything in that path fails (kagglehub
        missing, no network or credentials, unexpected archive layout, copy
        error) a small synthetic dataset is written instead so the rest of the
        pipeline can still be exercised.

        Args:
            target_dir: Directory to fill. A non-empty directory is treated as
                already prepared and left untouched.

        Returns:
            str: ``target_dir`` as a string.
        """
        target_path = Path(target_dir)

        # If target already exists and contains data, skip
        if target_path.exists() and any(target_path.iterdir()):
            print(f"Dataset directory {target_dir} already exists. Skipping download.")
            return str(target_path)

        # Try to download from Kaggle
        try:
            import kagglehub
            print("Downloading GTZAN dataset from Kaggle...")
            download_path = kagglehub.dataset_download(
                "andradaolteanu/gtzan-dataset-music-genre-classification"
            )
            print("Dataset downloaded successfully.")

            # The downloaded dataset contains 'genres_original' folder
            source_genres_path = Path(download_path) / "genres_original"
            if not source_genres_path.exists():
                raise RuntimeError(f"genres_original not found in {download_path}")

            # Copy the genres_original folder to our target location
            print("Preparing dataset...")
            if target_path.exists():
                shutil.rmtree(target_path)
            shutil.copytree(source_genres_path, target_path)
            print(f"Dataset ready at {target_dir}")
            return str(target_path)

        except Exception as e:
            # Deliberately broad: whatever went wrong, fall back to mock data
            # rather than aborting the example.
            print(f"Unable to download from Kaggle: {e}")
            print("Creating mock dataset for testing/demonstration purposes...")

        # A copy that failed part-way may have left real genre folders behind.
        # Start from a clean directory so mock and real files never mix.
        # (Safe: the early return above guarantees the target held no data on entry.)
        shutil.rmtree(target_path, ignore_errors=True)
        target_path.mkdir(parents=True, exist_ok=True)

        # Minimal mock dataset: 2 genres, 10 clips each. The genres differ only
        # in the amplitude of their uniform noise.
        mock_amplitudes = {"rock": 0.5, "jazz": 0.3}
        clips_per_genre = 10

        for genre, amplitude in mock_amplitudes.items():
            genre_path = target_path / genre
            genre_path.mkdir(exist_ok=True)

            for i in range(clips_per_genre):
                # Use the shared clip length and sample rate so the mock data
                # always matches what the dataset class expects.
                audio = np.random.uniform(-amplitude, amplitude, FIXED_LENGTH).astype(np.float32)
                wav_path = genre_path / f"{genre}.{i:05d}.wav"
                sf.write(str(wav_path), audio, SAMPLE_RATE)

        print(f"Mock dataset created at {target_dir} with {len(mock_amplitudes)} genres")
        print("NOTE: This is synthetic data for testing.")
        return str(target_path)

    # --- 4. Dataset Class ---
    class AudioDataset(Dataset):
        """
        Folder-per-class dataset of .wav clips served as mel spectrograms.

        ``root_dir`` must contain at least two non-hidden sub-directories, each
        holding at least one ``*.wav`` file. The sub-directory names, sorted,
        define the class indices. Every clip is converted to mono, resampled to
        ``SAMPLE_RATE`` if needed, and padded or truncated to ``FIXED_LENGTH``
        samples before the mel transform, so all items share one shape.

        Raises:
            RuntimeError: if ``root_dir`` is not a directory, has fewer than two
                class folders, or a class folder has no .wav files.
        """

        def __init__(self, root_dir):
            self.root = Path(root_dir)

            # is_dir() is already False for a missing path
            if not self.root.is_dir():
                raise RuntimeError(f"Directory does not exist: {root_dir}")

            self.classes = sorted(
                p.name for p in self.root.iterdir()
                if p.is_dir() and not p.name.startswith(".")
            )

            if len(self.classes) < 2:
                raise RuntimeError("Must contain at least two class subdirectories")

            self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

            self.files = []
            for cls in self.classes:
                # glob() order is not guaranteed. Sort so that every DDP rank
                # maps the same index to the same file; DistributedSampler
                # shards by index and relies on that.
                wavs = sorted((self.root / cls).glob("*.wav"))
                if not wavs:
                    raise RuntimeError(f"No .wav files found in class folder: {cls}")
                self.files.extend(wavs)

            self.mel = MelSpectrogram(
                sample_rate=SAMPLE_RATE,
                n_mels=64
            )

            # Shape of one processed clip, measured by running the transform on
            # silence. The error fallback in __getitem__ must match it exactly
            # or batch collation fails.
            self._mel_shape = self.mel(torch.zeros(1, FIXED_LENGTH)).squeeze(0).shape

        def __len__(self):
            return len(self.files)

        def __getitem__(self, idx):
            """Return ``(mel, label)`` for clip ``idx``; an all-zero mel with label 0 if loading fails."""
            wav_path = self.files[idx]
            label = self.class_to_idx[wav_path.parent.name]

            try:
                # Load audio using soundfile
                waveform, sr = sf.read(str(wav_path))
                waveform = torch.as_tensor(waveform, dtype=torch.float32)

                # Handle stereo to mono conversion (soundfile yields frames x channels)
                if waveform.ndim == 2:
                    waveform = waveform.mean(dim=1)

                # Ensure it's 1D then add channel dimension
                if waveform.ndim == 1:
                    waveform = waveform.unsqueeze(0)

                # Resample if needed
                if sr != SAMPLE_RATE:
                    import torchaudio.functional as F
                    waveform = F.resample(waveform, sr, SAMPLE_RATE)

                # Fix length: pad or truncate to FIXED_LENGTH
                if waveform.shape[1] > FIXED_LENGTH:
                    waveform = waveform[:, :FIXED_LENGTH]
                elif waveform.shape[1] < FIXED_LENGTH:
                    padding = FIXED_LENGTH - waveform.shape[1]
                    waveform = torch.nn.functional.pad(waveform, (0, padding))

                mel = self.mel(waveform).squeeze(0)
                return mel, label

            except Exception as e:
                # Best-effort: one unreadable file should not abort training.
                print(f"Error loading {wav_path}: {e}")
                # Return zeros as fallback, shaped like a real spectrogram so
                # the batch can still be stacked.
                return torch.zeros(self._mel_shape), 0

    # --- 5. Model Architecture ---
    class AudioCNN(nn.Module):
        """
        Small 2-D CNN that maps a batch of spectrograms to class logits.

        ``forward`` inserts the channel dimension itself, so inputs have one
        dimension fewer than the convolution layers consume. Global average
        pooling makes the classifier independent of the spectrogram size.
        """

        def __init__(self, num_classes):
            super().__init__()

            # Two conv -> ReLU -> 2x2 max-pool stages, then global average
            # pooling down to a single value per channel.
            stages = [
                nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.features = nn.Sequential(*stages)

            # One logit per class from the 32 pooled channel values
            self.classifier = nn.Linear(in_features=32, out_features=num_classes)

        def forward(self, x):
            # Add channel dimension
            pooled = self.features(x.unsqueeze(1))
            # Collapse everything after the batch dimension
            return self.classifier(pooled.flatten(start_dim=1))

    # --- 6. Main Training Logic ---
    try:
        device, local_rank, rank, world_size = setup_ddp()

        # --- Data Prep ---
        # data_dir sits under the home directory, which is presumably local to
        # each node, so one process per node (local rank 0) prepares it.
        # Gating on the global rank would leave every node except the first
        # without data.
        # NOTE(review): if data_dir is ever placed on a volume shared between
        # nodes, gate on `rank == 0` instead to avoid concurrent writers.
        if local_rank == 0:
            print("Verifying/Downloading dataset...")
            download_gtzan_dataset(data_dir)

        distributed = dist.is_initialized()
        if distributed:
            dist.barrier()  # Wait for download to finish

        # Load dataset
        dataset = AudioDataset(data_dir)
        num_classes = len(dataset.classes)

        # DDP Sampler handles data sharding across nodes
        sampler = DistributedSampler(dataset) if distributed else None
        loader = DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            sampler=sampler,
            # DataLoader rejects shuffle=True together with a sampler
            shuffle=(sampler is None),
            num_workers=0,
        )

        # --- Model & Optimizer ---
        model = AudioCNN(num_classes=num_classes).to(device)
        if distributed:
            model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None)

        optimizer = optim.Adam(model.parameters(), lr=LR)
        criterion = nn.CrossEntropyLoss()

        # --- Training Loop ---
        print(f"[Rank {rank}] Starting training on {num_classes} classes...")
        for epoch in range(EPOCHS):
            model.train()
            # Explicit None check: the sampler defines __len__, so plain
            # truthiness would test its length rather than its presence.
            if sampler is not None:
                # Re-seed the sampler so each epoch uses a different shuffle
                sampler.set_epoch(epoch)

            total_loss = 0.0
            for x, y in loader:
                x, y = x.to(device), y.to(device)

                optimizer.zero_grad()
                preds = model(x)
                loss = criterion(preds, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # Mean loss over this process's batches only (not reduced across ranks)
            avg_loss = total_loss / len(loader)
            if rank == 0:
                print(f"Epoch {epoch + 1}/{EPOCHS} | Loss: {avg_loss:.4f}")

        # Save model (only rank 0)
        if rank == 0:
            model_path = output_dir / "model.pt"
            # Save the bare module's weights so the checkpoint keys carry no
            # DDP "module." prefix.
            bare_model = model.module if distributed else model
            torch.save(bare_model.state_dict(), model_path)
            print(f"Model saved to {model_path}")

    except Exception as e:
        print(f"[Error] Training failed: {e}")
        raise
    finally:
        # Runs on success and on failure, so the process group is always released
        cleanup_ddp()

## 3. Run Training Locally

Before scaling to a cluster, we test the code locally using the `TrainerClient` with the local-process backend. We use the `torch-distributed` runtime, which creates a local DDP environment (simulating a cluster node).

In [None]:
from kubeflow.trainer import CustomTrainer, LocalProcessBackendConfig, TrainerClient

# Build a client that runs jobs through the local-process backend
local_backend = LocalProcessBackendConfig(cleanup_venv=True)
client = TrainerClient(backend_config=local_backend)

# Take the first registered runtime named 'torch-distributed' (None if absent)
local_runtime = next(
    (candidate for candidate in client.list_runtimes() if candidate.name == "torch-distributed"),
    None,
)

if not local_runtime:
    raise ValueError("Local runtime 'torch-distributed' not found")

print("Starting local training...")

# Libraries to install in the local virtual environment
local_packages = ["torch", "torchaudio", "soundfile", "kagglehub", "numpy"]

# Start training; train() hands back the job name used to look up logs
local_trainer = CustomTrainer(
    func=train_audio_classification,
    packages_to_install=local_packages,
)
job_name = client.train(trainer=local_trainer, runtime=local_runtime)

# Stream logs as they are produced
print(f"\nJob {job_name} started. Streaming logs...")
log_stream = client.get_job_logs(job_name, follow=True)
for chunk in log_stream:
    print(chunk, end='')

## 4. Run Distributed Training (Kubernetes)

Now we scale the job to a Kubernetes cluster. This requires:
1. A Kubernetes cluster with the **Kubeflow Trainer** control plane installed (it provides the `TrainJob` API).
2. The `torch-distributed` training runtime installed on the cluster.
3. Access via `~/.kube/config`.

We request minimal resources so that the job can be scheduled even on a small cluster; adjust `num_nodes` and `resources_per_node` to match your cluster's capacity.

In [None]:
from kubeflow.trainer import CustomTrainer, TrainerClient

# Initialize client (automatically loads ~/.kube/config)
client = TrainerClient()

# Forget any job name left over from the local run. Otherwise, if the
# submission below fails, the status cell would try to follow the local job
# through the Kubernetes client.
globals().pop("job_name", None)

print("Submitting distributed TrainJob...")

try:
    # Resolve the runtime to an object, as the local run does, rather than
    # passing a bare name. A missing runtime is then reported here as well.
    cluster_runtime = client.get_runtime("torch-distributed")

    job_name = client.train(
        trainer=CustomTrainer(
            func=train_audio_classification,
            num_nodes=1,  # Start with 1 node
            resources_per_node={
                "cpu": "1",         # 1 CPU core
                "memory": "2Gi",    # 2Gi memory
            },
            packages_to_install=["torch", "torchaudio", "soundfile", "kagglehub", "numpy"],
        ),
        runtime=cluster_runtime,
    )
except Exception as e:
    # Keep the notebook running; the failure is reported and no job_name is
    # defined, so the status cell is skipped.
    print(f"Failed to submit TrainJob: {e}")
else:
    print(f"TrainJob '{job_name}' submitted successfully!")

## 5. Check Status & Logs

Verify the job status and view real-time logs from the master replica.

In [None]:
!kubectl get trainjobs
!kubectl get pods

In [None]:
# Only follow a job if an earlier cell recorded its name
job_is_known = 'job_name' in locals()

if job_is_known:
    print(f"Waiting for job {job_name} to start running...")

    # Block until the job reports the Running status (timeout=300, presumably seconds)
    client.wait_for_job_status(
        name=job_name,
        status={"Running"},
        timeout=300,
    )

    print("Job is running! Streaming logs...")
    log_stream = client.get_job_logs(job_name, follow=True)
    for chunk in log_stream:
        print(chunk, end='')