In [5]:
# NOTE(review): absolute, cluster-specific working directory (CSCS scratch); adjust for your environment.
%cd /capstor/scratch/cscs/ckuya/


/capstor/scratch/cscs/ckuya


In [13]:
import os
import numpy as np
import librosa
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import jiwer


In [14]:
# Pick the GPU when PyTorch can see one, otherwise fall back to CPU.
# NB: `is_available` must be *called* — the bare function object is always truthy,
# which would silently select 'cuda' even on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Model

In [None]:
class MelCNN(nn.Module):
    """Small two-stage CNN classifier over a single-channel log-mel patch.

    Args:
        num_classes: number of output logits.
        n_mels: height (frequency bins) of the input. Defaults to 40.
        n_frames: width (time frames) of the input. Defaults to 20.

    Input is ``(B, 1, n_mels, n_frames)``; output is ``(B, num_classes)``.
    Each ``MaxPool2d(2)`` floor-halves both spatial dims, so the flattened
    feature size is ``32 * (n_mels // 4) * (n_frames // 4)`` — ``32 * 10 * 5``
    for the default 40x20 input, matching the original hard-coded value.

    Raises:
        ValueError: if either spatial dimension is smaller than 4, which would
            pool away to nothing.
    """

    def __init__(self, num_classes, n_mels=40, n_frames=20):
        super().__init__()
        if n_mels < 4 or n_frames < 4:
            raise ValueError(
                f"MelCNN needs n_mels >= 4 and n_frames >= 4, got {n_mels}x{n_frames}")

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),   # (B, 16, n_mels, n_frames)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # (B, 16, n_mels//2, n_frames//2)

            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # (B, 32, n_mels//2, n_frames//2)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # (B, 32, n_mels//4, n_frames//4)
        )

        flat_dim = 32 * (n_mels // 4) * (n_frames // 4)
        self.classifier = nn.Sequential(
            nn.Flatten(),                                 # (B, flat_dim)
            nn.Linear(flat_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Return ``(B, num_classes)`` logits for input ``(B, 1, n_mels, n_frames)``."""
        x = self.conv_layers(x)
        return self.classifier(x)

class SbuLSTMClassifier(nn.Module):
    """Stacked bidirectional + unidirectional LSTM keyword classifier.

    One bidirectional LSTM layer encodes the frame sequence; ``uni_layers``
    forward-only LSTM layers refine it; the output at the final time step is
    classified by a small MLP.

    Args:
        n_mels: feature size per frame (input size of the first LSTM).
        hidden_dim: hidden size of each LSTM direction.
        uni_layers: number of stacked unidirectional LSTM layers.
        num_classes: number of output logits.
        dropout: dropout probability inside the MLP head.
    """

    def __init__(self,
                 n_mels: int            = 40,
                 hidden_dim: int        = 128,
                 uni_layers: int        = 1,
                 num_classes: int       = 30,
                 dropout: float         = 0.3):
        super().__init__()
        # Stage 1: bidirectional encoder; its output width is 2 * hidden_dim.
        self.bdlstm = nn.LSTM(input_size=n_mels, hidden_size=hidden_dim,
                              num_layers=1, batch_first=True,
                              bidirectional=True)
        # Stage 2: forward-only refinement over the bidirectional features.
        self.lstm = nn.LSTM(input_size=2 * hidden_dim, hidden_size=hidden_dim,
                            num_layers=uni_layers, batch_first=True,
                            bidirectional=False)
        # Stage 3: MLP head applied to the final time step only.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(B, 1, n_mels, T)`` batch to ``(B, num_classes)`` logits."""
        batch, _channels, n_freq, n_time = x.shape
        # Drop the singleton channel axis, then swap to time-major frames:
        # (B, 1, F, T) -> (B, F, T) -> (B, T, F).
        frames = x.view(batch, n_freq, n_time).transpose(1, 2)

        bi_seq, _ = self.bdlstm(frames)    # (B, T, 2*hidden_dim)
        uni_seq, _ = self.lstm(bi_seq)     # (B, T, hidden_dim)
        final_step = uni_seq[:, -1, :]     # (B, hidden_dim)
        return self.classifier(final_step)


class SbuLSTMFusion(nn.Module):
    """Two-stream (log-mel + MFCC) CNN/BiLSTM keyword classifier with attention pooling.

    Each stream passes through a small conv front-end, is flattened per frame
    to ``(B, T, C*F)`` and encoded by a 2-layer bidirectional LSTM. The two
    ``(B, T, 2*hidden_dim)`` sequences are concatenated, layer-normalized,
    refined by multi-head self-attention with a residual connection, pooled
    over time and classified by an MLP.

    Args:
        n_mels: mel bands in the ``mel`` input.
        n_mfcc: coefficients in the ``mfcc`` input.
        hidden_dim: hidden size of each LSTM direction. ``4 * hidden_dim`` must
            be divisible by the 8 attention heads (i.e. ``hidden_dim`` even).
        num_classes: number of output logits.
        dropout: dropout between the stacked LSTM layers and inside the MLP head.

    Note:
        Submodule attribute names and ``nn.Sequential`` indices determine the
        ``state_dict`` keys; keep them unchanged so existing checkpoints load.
    """

    def __init__(self,
                 n_mels=40,
                 n_mfcc=13,
                 hidden_dim=256,
                 num_classes=10,
                 dropout=0.3):
        super().__init__()

        # Conv front-end for mel spectrograms. Pools along frequency only, so the
        # time axis T is preserved (3x3 convs with padding=1 keep the size).
        self.conv_mel = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1)),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 1)),
        )

        # Conv front-end for MFCCs. The (1, 1) max-pool is an identity op, but it
        # occupies index 3 of this Sequential: removing it would shift the
        # state_dict keys of the layers after it and break existing checkpoints.
        self.conv_mfcc = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 1)),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        # Frequency bins remaining after each conv stack.
        mel_conv_out = n_mels // 4   # two floor-halving (2, 1) pools
        mfcc_conv_out = n_mfcc       # no frequency pooling

        # 2-layer BiLSTM per stream over (64 channels x remaining bins) per frame.
        self.bdlstm_mel = nn.LSTM(
            input_size=64 * mel_conv_out,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.bdlstm_mfcc = nn.LSTM(
            input_size=64 * mfcc_conv_out,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )

        # Self-attention over the fused sequence (2*hidden_dim from each stream).
        self.attention = nn.MultiheadAttention(
            embed_dim=4 * hidden_dim,
            num_heads=8,
            dropout=0.1,
            batch_first=True,
        )

        # Normalizes the fused features before attention.
        self.layer_norm = nn.LayerNorm(4 * hidden_dim)

        # MLP head on the pooled (B, 4*hidden_dim) context vector.
        self.classifier = nn.Sequential(
            nn.Linear(4 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, mel, mfcc):
        """Classify a batch from its two feature streams.

        Args:
            mel: ``(B, 1, n_mels, T)`` log-mel features.
            mfcc: ``(B, 1, n_mfcc, T)`` MFCC features with the same ``T``.

        Returns:
            ``(B, num_classes)`` logits.

        Raises:
            ValueError: if the two streams have different numbers of frames.
        """
        mel_conv = self.conv_mel(mel)       # (B, 64, n_mels//4, T)
        mfcc_conv = self.conv_mfcc(mfcc)    # (B, 64, n_mfcc, T)

        # (B, C, F, T) -> (B, T, C, F) -> (B, T, C*F). Each stream uses its own
        # time axis; the channel-major flattening order must stay as-is to match
        # checkpoints trained with this layout.
        mel_seq = mel_conv.permute(0, 3, 1, 2).flatten(2)
        mfcc_seq = mfcc_conv.permute(0, 3, 1, 2).flatten(2)
        if mel_seq.shape[1] != mfcc_seq.shape[1]:
            raise ValueError(
                f"mel and mfcc must have the same number of frames, "
                f"got {mel_seq.shape[1]} and {mfcc_seq.shape[1]}")

        mel_feats, _ = self.bdlstm_mel(mel_seq)      # (B, T, 2*hidden_dim)
        mfcc_feats, _ = self.bdlstm_mfcc(mfcc_seq)   # (B, T, 2*hidden_dim)

        fused = torch.cat([mel_feats, mfcc_feats], dim=2)   # (B, T, 4*hidden_dim)
        fused = self.layer_norm(fused)

        # Self-attention for temporal context, with a residual connection.
        attn_out, _ = self.attention(fused, fused, fused)
        fused = fused + attn_out

        # Attention pooling over time. The query is data-dependent (the mean
        # frame embedding), not a learned parameter. Scores are unscaled dot
        # products; left unscaled because trained checkpoints depend on it.
        query = fused.mean(dim=1, keepdim=True)                  # (B, 1, 4*hidden_dim)
        scores = torch.bmm(query, fused.transpose(1, 2))         # (B, 1, T)
        # torch.softmax: the notebook never imports torch.nn.functional as F.
        weights = torch.softmax(scores, dim=2)
        context = torch.bmm(weights, fused).squeeze(1)           # (B, 4*hidden_dim)

        return self.classifier(context)                          # (B, num_classes)


## Dataloader

In [197]:
class AudioFeatureDataset(Dataset):
    """Evaluation dataset laid out as one subfolder per class, one clip per file.

    Each item yields a per-utterance standardized log-mel spectrogram and MFCC
    matrix (optionally padded/truncated to ``max_len`` frames) together with
    the integer label, the label name and file identifiers.

    Args:
        root_dir: directory whose immediate subfolders are the class labels.
        sr: target sample rate passed to ``librosa.load``.
        n_fft: FFT size.
        hop_percent: fraction of ``win_length`` by which consecutive frames
            *overlap*; the hop is ``win_length * (1 - hop_percent)`` samples
            (default 0.75 -> hop 400 for ``win_length=1600``).
        win_length: analysis window length in samples.
        n_mels: number of mel bands.
        n_mfcc: number of MFCC coefficients.
        max_len: if not None, zero-pad or truncate both features to this many frames.
        exclude_folders: subfolder names to skip (e.g. ``.ipynb_checkpoints``).

    Raises:
        ValueError: if the derived hop length is smaller than one sample.
    """

    def __init__(self, root_dir,
                 sr=16000,
                 n_fft=2048,
                 hop_percent=0.75,
                 win_length=1600,
                 n_mels=40,
                 n_mfcc=13,
                 max_len=None,
                 exclude_folders=None):
        super().__init__()
        self.root_dir       = root_dir
        self.sr             = sr
        self.n_fft          = n_fft
        self.win_length     = win_length
        # Round away float noise before flooring: 1600 * (1 - 0.9) evaluates to
        # 159.99999999999997, which a bare int() would truncate to 159.
        self.hop_length     = int(round(win_length * (1 - hop_percent), 6))
        if self.hop_length < 1:
            raise ValueError(
                f"hop_percent={hop_percent} with win_length={win_length} gives "
                f"hop_length={self.hop_length}; it must be at least 1 sample")
        self.n_mels         = n_mels
        self.n_mfcc         = n_mfcc
        self.max_len        = max_len
        self.exclude_folders= set(exclude_folders or [])

        self._prepare_dataset()

    def _prepare_dataset(self):
        """Discover class folders and collect ``(path, idx, label, filename)`` samples."""
        # Classes are the non-excluded subfolders, in sorted order; the index of
        # each label is its position in that order.
        labels = sorted(
            d for d in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, d))
            and d not in self.exclude_folders
        )
        self.label_to_idx = {lab: idx for idx, lab in enumerate(labels)}
        self.idx_to_label = dict(enumerate(labels))

        self.samples = []
        for lab in labels:
            folder = os.path.join(self.root_dir, lab)
            # os.listdir order is filesystem-dependent; sorting makes the
            # (unshuffled) evaluation order reproducible across machines.
            for fn in sorted(os.listdir(folder)):
                if fn.lower().endswith(('.wav', '.mp3')):
                    file_path = os.path.join(folder, fn)
                    self.samples.append((file_path, self.label_to_idx[lab], lab, fn))

        print(f"[Evaluation Dataset] Found {len(self.samples)} files "
              f"under {len(labels)} classes: {labels}")

    def __len__(self):
        return len(self.samples)

    def _extract_features(self, path):
        """Load ``path`` and return ``(log_mel, mfcc)`` arrays of shape ``(bins, frames)``."""
        y, _ = librosa.load(path, sr=self.sr)
        # Log-mel spectrogram (power -> dB).
        m = librosa.feature.melspectrogram(
            y=y, sr=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            n_mels=self.n_mels)
        log_mel = librosa.power_to_db(m)
        # MFCCs are computed from the waveform, i.e. from librosa's own internal
        # mel spectrogram (its default mel settings), not from log_mel above.
        mfcc = librosa.feature.mfcc(
            y=y, sr=self.sr,
            n_mfcc=self.n_mfcc,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length)
        # Per-utterance standardization (zero mean, unit variance over the matrix).
        log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-8)
        mfcc    = (mfcc    - mfcc.mean())    / (mfcc.std()    + 1e-8)
        # Pad (with zeros, i.e. the post-standardization mean) or truncate in time.
        if self.max_len is not None:
            T = log_mel.shape[1]
            if T >= self.max_len:
                log_mel = log_mel[:, :self.max_len]
                mfcc    = mfcc   [:, :self.max_len]
            else:
                pw = self.max_len - T
                log_mel = np.pad(log_mel, ((0,0),(0,pw)), mode='constant')
                mfcc    = np.pad(mfcc,    ((0,0),(0,pw)), mode='constant')
        return log_mel, mfcc

    def __getitem__(self, idx):
        path, label_idx, label_name, filename = self.samples[idx]
        log_mel, mfcc = self._extract_features(path)
        return {
            'mel':         torch.tensor(log_mel, dtype=torch.float32).unsqueeze(0),  # (1, n_mels, T)
            'mfcc':        torch.tensor(mfcc, dtype=torch.float32).unsqueeze(0),     # (1, n_mfcc, T)
            'label':       torch.tensor(label_idx, dtype=torch.long),  # 0-d class index
            'label_name':  label_name,                   # Ground truth word for WER
            'filename':    filename,                     # For tracking individual files
            'file_path':   path                          # Full path for debugging
        }

def get_eval_dataloader(root_dir,
                       mel_dim     = 40,
                       mfcc_dim    = 13,
                       max_len     = 20,
                       batch_size  = 32,
                       num_workers = 4,
                       exclude_folders = None,
                       hop_percent = 0.75):
    """Build a deterministic (unshuffled) evaluation DataLoader.

    Args:
        root_dir: dataset root with one subfolder per class.
        mel_dim: number of mel bands.
        mfcc_dim: number of MFCC coefficients.
        max_len: frames each feature matrix is padded/truncated to.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        exclude_folders: subfolder names to ignore (None means none).
        hop_percent: frame-overlap fraction forwarded to ``AudioFeatureDataset``;
            it must match the value used when the evaluated model was trained.

    Returns:
        (eval_loader, num_classes, label_to_idx, idx_to_label)
    """
    ds = AudioFeatureDataset(
        root_dir        = root_dir,
        hop_percent     = hop_percent,
        n_mels          = mel_dim,
        n_mfcc          = mfcc_dim,
        max_len         = max_len,
        exclude_folders = exclude_folders,
    )

    num_classes = len(ds.label_to_idx)

    eval_loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,                          # consistent evaluation order
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),   # pinning only helps host->GPU copies
        drop_last=False,                        # keep every sample
    )

    return eval_loader, num_classes, ds.label_to_idx, ds.idx_to_label

## Evaluation 

In [198]:
# NOTE(review): the checkpoint loaded in the next cell lives under a `hop50/`
# directory; if that denotes hop_percent=0.5 at training time, the dataset
# default used here (hop_percent=0.75) yields different frames than training — confirm.
eval_loader, num_classes, label_to_idx, idx_to_label = get_eval_dataloader(
    "/capstor/scratch/cscs/ckuya/mini_speech_commands",
    batch_size=32,
    exclude_folders=['.ipynb_checkpoints'],
)

[Evaluation Dataset] Found 12724 files under 10 classes: ['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes']


In [199]:
# Checkpoint to evaluate; the architecture hyper-parameters below must match
# the ones it was trained with.
CHECKPOINT_PATH = ("/capstor/scratch/cscs/ckuya/speech-processing/keyword-spotting/"
                   "source/pipeline/hop50/model9.pth")

model = SbuLSTMFusion(
    n_mels      = 40,
    n_mfcc      = 13,
    hidden_dim  = 128,
    # Taken from the eval dataloader so a class-count mismatch with the
    # checkpoint fails loudly in load_state_dict.
    # NOTE(review): assumes training used the same sorted-folder label order.
    num_classes = num_classes,
    dropout     = 0.3,
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers (no
# arbitrary code execution) and silences torch.load's weights_only FutureWarning.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))

  model.load_state_dict(torch.load("/capstor/scratch/cscs/ckuya/speech-processing/keyword-spotting/source/pipeline/hop50/model9.pth", map_location=device))


<All keys matched successfully>

In [200]:
def evaluate_with_wer(model, eval_loader, idx_to_label, device='cpu'):
    """Run ``model`` over ``eval_loader`` and collect word-level predictions.

    Despite the name, this function does not compute WER itself; it returns
    the predicted and reference word lists so the caller can score them
    (e.g. ``jiwer.wer(ground_truths, predictions)``).

    Args:
        model: module called as ``model(mel, mfcc)`` returning
            ``(B, num_classes)`` logits. It is switched to eval mode.
        eval_loader: iterable of dict batches with ``'mel'`` and ``'mfcc'``
            tensors and a ``'label_name'`` sequence of strings.
        idx_to_label: mapping from class index to word label.
        device: device the feature tensors are moved to; must match the model.

    Returns:
        (predictions, ground_truths): two lists of words, aligned by sample.
    """
    model.eval()
    predictions = []
    ground_truths = []

    with torch.no_grad():
        for batch in eval_loader:
            mel_features = batch['mel'].to(device)
            mfcc_features = batch['mfcc'].to(device)
            logits = model(mel_features, mfcc_features)

            # .tolist() yields plain Python ints, matching idx_to_label's keys exactly.
            pred_indices = logits.argmax(dim=1).tolist()
            predictions.extend(idx_to_label[i] for i in pred_indices)
            ground_truths.extend(batch['label_name'])

    return predictions, ground_truths



In [201]:
# Run evaluation
# Collect aligned (predicted, reference) keyword lists over the whole eval set.
predictions, ground_truths = evaluate_with_wer(
    model, eval_loader, idx_to_label, device=device
)

In [202]:
#computing Word Error Rate
wer = jiwer.wer(ground_truths, predictions)
print(f"Word Error Rate: {wer:.4f}")

Word Error Rate: 0.4281
