# PyTorch Speech Recognition with Kubeflow Trainer

This example demonstrates how to train a speech recognition model using the [Google Speech Commands](https://www.tensorflow.org/datasets/catalog/speech_commands) dataset and PyTorch Distributed Data Parallel (DDP).

It follows a development workflow designed for scale:
1. **Define**: Wrap training logic in a self-contained function.
2. **Test Locally**: Run the function in a local subprocess to verify code correctness.
3. **Scale Distributed**: Submit the same function as a distributed `TrainJob` to a Kubernetes cluster.

## 1. Installation

Install the Kubeflow Trainer SDK and the model dependencies (PyTorch, Torchaudio, etc.).

In [7]:
# Install Kubeflow Trainer SDK (Development Mode)
# If you are not developing the SDK, install the released package that provides
# `kubeflow.trainer` instead, e.g.: pip install kubeflow
!pip install -q -e ../../../api/python_api

# Install model dependencies
!pip install -q torch torchaudio librosa soundfile tensorboard



## 2. Define Training Function

We define `train_speech_recognition` to encapsulate the entire training loop.

**Critical Requirement**: This function must be self-contained. It must import all necessary libraries inside itself because it will be serialized and executed in an isolated environment (container or subprocess).

In [8]:
def train_speech_recognition():
    """
    Trains an Audio Transformer on the Speech Commands dataset using PyTorch DDP.

    The function is self-contained (all imports happen inside it) so that it can
    be serialized and executed in a separate process or container.

    Behavior:
      * If LOCAL_RANK, RANK and WORLD_SIZE are all set in the environment, a
        distributed process group is initialized (NCCL when CUDA is available,
        Gloo otherwise); otherwise the function runs as a single process.
      * Rank 0 downloads the dataset; the other ranks wait on a barrier.
      * A fixed-size, deterministically chosen subset of the audio files is
        used so that every rank sees the same dataset, which is what
        DistributedSampler needs to shard it correctly.
      * Rank 0 saves the trained weights to ``<output_dir>/speech_model.pt``.

    Raises:
        RuntimeError: If no ``.wav`` files are found under the dataset root.
        Exception: Any other failure is printed with a traceback and re-raised.
    """
    import os
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import Dataset, DataLoader, DistributedSampler
    import torchaudio
    import random
    from pathlib import Path
    import platform

    # Suppress warnings and progress bars for cleaner output
    import warnings
    warnings.filterwarnings('ignore')
    os.environ['TORCHAUDIO_DOWNLOAD_PROGRESS'] = '0'

    # --- 1. Environment & Configuration ---
    # Detect if we are running on Windows (Local) or Linux (Cluster/Container)
    if platform.system() == 'Windows':
        data_path = Path("C:/data")
        output_dir = Path("C:/output")
    else:
        data_path = Path.home() / "data"
        output_dir = Path.home() / "output"

    data_path.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Hyperparameters
    BATCH_SIZE = 32
    EPOCHS = 3
    LR = 0.001
    MAX_TRAIN_FILES = 1000  # Subset size for a faster demonstration
    SEED = 42  # Shared by all ranks so they select the same subset

    # --- 2. DDP Initialization ---
    def setup_ddp():
        """Initializes the distributed process group (NCCL for GPU, Gloo for CPU).

        Returns:
            Tuple of (device, local_rank, rank, world_size). In standalone mode
            the ranks are 0 and the world size is 1.
        """
        if all(k in os.environ for k in ["LOCAL_RANK", "RANK", "WORLD_SIZE"]):
            local_rank = int(os.environ["LOCAL_RANK"])
            rank = int(os.environ["RANK"])
            world_size = int(os.environ["WORLD_SIZE"])

            if torch.cuda.is_available():
                backend = "nccl"
                device = torch.device("cuda", local_rank)
                torch.cuda.set_device(device)
            else:
                backend = "gloo"
                device = torch.device("cpu")

            dist.init_process_group(backend=backend)
            print(f"[Rank {rank}/{world_size}] DDP Initialized using {backend}")
            return device, local_rank, rank, world_size
        else:
            print("[Single Process] DDP variables not found. Running in standalone mode.")
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            return device, 0, 0, 1

    def cleanup_ddp():
        """Tears down the process group if one was initialized."""
        if dist.is_initialized():
            dist.destroy_process_group()

    # --- 3. Model Architecture ---
    class AudioTransformer(nn.Module):
        """Transformer encoder over mel-spectrogram frames with a linear classifier head."""

        def __init__(self, num_classes=35, d_model=128, nhead=4, num_layers=2):
            super().__init__()
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
                num_layers=num_layers
            )
            self.classifier = nn.Linear(d_model, num_classes)

        def forward(self, x):
            # Input: (batch, n_mels, time) -> Permute to (batch, time, n_mels) for Transformer
            x = x.permute(0, 2, 1)
            x = self.encoder(x)
            # Simple global average pooling
            x = x.mean(dim=1)
            return self.classifier(x)

    # --- 4. Dataset Wrapper ---
    class SpeechDataset(Dataset):
        """Loads (path, label) pairs and returns (mel spectrogram, class index)."""

        def __init__(self, audio_paths, label_map):
            self.audio_paths = audio_paths
            self.label_map = label_map
            self.transform = torchaudio.transforms.MelSpectrogram(n_mels=128)

        def __len__(self):
            return len(self.audio_paths)

        def __getitem__(self, idx):
            try:
                path, label = self.audio_paths[idx]
                waveform, _ = torchaudio.load(path)
                # Ensure fixed length (1 second @ 16kHz)
                target_len = 16000
                if waveform.size(1) < target_len:
                    waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.size(1)))
                else:
                    waveform = waveform[:, :target_len]

                spec = self.transform(waveform).squeeze(0)
                return spec, self.label_map[label]
            except Exception as e:
                # Best-effort: keep training going if a single file is unreadable.
                print(f"Error loading audio file: {e}")
                # Same shape as a real sample (128 mel bins x 81 frames for 16000 samples)
                return torch.zeros(128, 81), 0  # Fallback for corrupted files

    def collate_fn(batch):
        """Stacks spectrograms into one tensor and labels into a long tensor."""
        specs, labels = zip(*batch)
        return torch.stack(specs), torch.tensor(labels, dtype=torch.long)

    # --- 5. Main Training Logic ---
    try:
        device, local_rank, rank, world_size = setup_ddp()

        # --- Data Prep ---
        # Only Rank 0 downloads the data to the shared volume/cache
        if rank == 0:
            print("Verifying/Downloading dataset...")
            torchaudio.datasets.SPEECHCOMMANDS(root=str(data_path), download=True)

        if dist.is_initialized():
            dist.barrier()  # Wait for download to finish

        # Load file list manually for flexibility
        data_root = data_path / "SpeechCommands" / "speech_commands_v0.02"
        labels = sorted([d.name for d in data_root.iterdir() if d.is_dir() and not d.name.startswith("_")])
        label_map = {l: i for i, l in enumerate(labels)}

        # Sort the glob results: directory listing order is filesystem-dependent,
        # and every rank must start from the same ordering.
        all_files = [
            (str(f), label)
            for label in labels
            for f in sorted((data_root / label).glob("*.wav"))
        ]
        if not all_files:
            raise RuntimeError(f"No .wav files found under {data_root}")

        # Subset for faster demonstration. A seeded, private RNG makes every rank
        # pick the same files; DistributedSampler shards by index and therefore
        # requires an identical dataset on all ranks.
        random.Random(SEED).shuffle(all_files)
        train_files = all_files[:MAX_TRAIN_FILES]

        # DDP Sampler handles data sharding across nodes
        dataset = SpeechDataset(train_files, label_map)
        sampler = DistributedSampler(dataset) if dist.is_initialized() else None
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler,
                            shuffle=(sampler is None), collate_fn=collate_fn, num_workers=0)

        # --- Model & Optimizer ---
        model = AudioTransformer(num_classes=len(labels)).to(device)
        if dist.is_initialized():
            model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None)

        optimizer = torch.optim.Adam(model.parameters(), lr=LR)
        criterion = nn.CrossEntropyLoss()

        # --- Training Loop ---
        print(f"[Rank {rank}] Starting training on {len(labels)} classes...")
        for epoch in range(1, EPOCHS + 1):
            model.train()
            # Explicit None check: truthiness of a sampler would go through __len__
            if sampler is not None:
                sampler.set_epoch(epoch)

            total_loss = 0.0
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            # Average of this rank's batch losses (not reduced across ranks)
            avg_loss = total_loss / len(loader)
            if rank == 0:
                print(f"Epoch {epoch}/{EPOCHS} | Loss: {avg_loss:.4f}")

        # Save model (only rank 0)
        if rank == 0:
            model_path = output_dir / "speech_model.pt"
            if dist.is_initialized():
                # Unwrap the DDP container so the checkpoint loads into a plain model
                torch.save(model.module.state_dict(), model_path)
            else:
                torch.save(model.state_dict(), model_path)
            print(f"Model saved to {model_path}")

    except Exception as e:
        print(f"[Error] Training failed: {e}")
        import traceback
        traceback.print_exc()
        raise
    finally:
        # Runs on both success and failure so the process group is always released
        cleanup_ddp()

## 3. Run Training Locally

Before scaling to a cluster, we test the code locally using the `TrainerClient`. We use the `torch-distributed` runtime, which sets up a local DDP environment that simulates a single cluster node.

In [9]:
from kubeflow.trainer import CustomTrainer, TrainerClient, LocalProcessBackendConfig

# Build a client that runs jobs through the local-process backend
client = TrainerClient(backend_config=LocalProcessBackendConfig(cleanup_venv=True))

# Pick the 'torch-distributed' runtime out of the runtimes the client lists
local_runtime = next(
    (candidate for candidate in client.list_runtimes() if candidate.name == "torch-distributed"),
    None,
)
if not local_runtime:
    raise ValueError("Local runtime 'torch-distributed' not found")

print("Starting local training...")
print("Note: First run will download the dataset (~2GB), which may take a few minutes.\n")

# Describe the job: the training function plus the libraries to install
# in the local virtual environment
local_trainer = CustomTrainer(
    func=train_speech_recognition,
    packages_to_install=["torch", "torchaudio", "librosa", "soundfile", "tensorboard"],
)

# Start training
local_job_name = client.train(trainer=local_trainer, runtime=local_runtime)

# Stream logs until the job finishes or the user interrupts
print(f"Job {local_job_name} started. Streaming logs...\n")
try:
    for log_chunk in client.get_job_logs(local_job_name, follow=True):
        print(log_chunk, end='')
except KeyboardInterrupt:
    print("\n\nLog streaming interrupted by user.")

Starting local training...
Note: First run will download the dataset (~2GB), which may take a few minutes.

Job o986a4180f2d started. Streaming logs...

T h e   R P C   c a l l   c o n t a i n s   a   h a n d l e   t h a t   d i f f e r s   f r o m   t h e   d e c l a r e d   h a n d l e   t y p e .   
 
 E r r o r   c o d e :   B a s h / S e r v i c e / 0 x 8 0 0 7 0 7 2 c 
 
 [o986a4180f2d-train] Completed with code 1 in 0:00:00.147117 seconds.T h e   R P C   c a l l   c o n t a i n s   a   h a n d l e   t h a t   d i f f e r s   f r o m   t h e   d e c l a r e d   h a n d l e   t y p e .     E r r o r   c o d e :   B a s h / S e r v i c e / 0 x 8 0 0 7 0 7 2 c   [o986a4180f2d-train] Completed with code 1 in 0:00:00.147117 seconds.

## 4. Run Distributed Training (Kubernetes)

Now we scale the job to a Kubernetes cluster. This requires:
1. A Kubernetes cluster with **Kubeflow Trainer** (the controller that manages `TrainJob` resources) installed.
2. The `torch-distributed` training runtime (a **ClusterTrainingRuntime** resource) installed on the cluster.
3. Access via `~/.kube/config`.

We use a minimal configuration here; adjust `num_nodes` and `resources_per_node` to match your cluster's resources.

In [10]:
from kubeflow.trainer import CustomTrainer, TrainerClient

# Create the client (automatically loads ~/.kube/config)
client = TrainerClient()

print("Submitting distributed TrainJob...")

try:
    # Per-node resource requests; add "nvidia.com/gpu": "1" for GPU training
    node_resources = {"cpu": "2", "memory": "4Gi"}

    k8s_trainer = CustomTrainer(
        func=train_speech_recognition,
        num_nodes=1,  # Start with 1 node (can scale to 3 if resources allow)
        resources_per_node=node_resources,
        packages_to_install=["torch", "torchaudio", "librosa", "soundfile", "tensorboard"],
    )

    # The runtime name matches the TrainingRuntime on the cluster
    k8s_job_name = client.train(trainer=k8s_trainer, runtime="torch-distributed")

    print(f"TrainJob '{k8s_job_name}' submitted successfully!")
    print(f"\nYou can check the job status with: kubectl get trainjobs {k8s_job_name}")
except Exception as e:
    # Report the failure in the notebook instead of aborting the cell
    print(f"Failed to submit TrainJob: {e}")
    import traceback
    traceback.print_exc()

Submitting distributed TrainJob...
TrainJob 's13cdaf8c7b2' submitted successfully!

You can check the job status with: kubectl get trainjobs s13cdaf8c7b2


## 5. Check Status & Logs

Verify the job status and view real-time logs from the master replica.

In [11]:
# List the TrainJobs, then the pods, in the current kubectl namespace
!kubectl get trainjobs
!kubectl get pods

NAME           STATE      AGE
c5cb3f1787ac   Complete   26m
l29bae43c5be   Complete   75m
s13cdaf8c7b2              1s
NAME                          READY   STATUS              RESTARTS   AGE
c5cb3f1787ac-node-0-0-9rpgk   0/1     Completed           0          26m
l29bae43c5be-node-0-0-jr5rf   0/1     Completed           0          75m
s13cdaf8c7b2-node-0-0-g2ns4   0/1     ContainerCreating   0          0s


In [12]:
# This cell needs the job name produced by the submission cell above it
if 'k8s_job_name' not in locals():
    print("No Kubernetes job found. Please run the previous cell to create a job first.")
else:
    print(f"Waiting for job {k8s_job_name} to start running...")

    try:
        # Block until the job reports the Running status (up to 300 seconds)
        client.wait_for_job_status(name=k8s_job_name, status={"Running"}, timeout=300)

        print(f"Job is running! Streaming logs...\n")
        for log_chunk in client.get_job_logs(k8s_job_name, follow=True):
            print(log_chunk, end='')
    except KeyboardInterrupt:
        print("\n\nLog streaming interrupted by user.")
    except Exception as err:
        # Fall back to pointing the user at kubectl for manual inspection
        print(f"\nError while waiting for job or streaming logs: {err}")
        print(f"You can manually check logs with: kubectl logs -f {k8s_job_name}-node-0-0-<pod-id>")

Waiting for job s13cdaf8c7b2 to start running...
Job is running! Streaming logs...
