In [None]:
# Setup and imports
from pathlib import Path
import sys

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    AutoConfig
)
from tqdm import tqdm
import matplotlib.pyplot as plt

from utils.config_utils import load_env, add_project_root_to_path
load_env()

from config import get_settings, AST_MODEL_NAME, SAMPLE_RATE
from utils.dams_types import (
    SPEECH,
    MUSIC,
    NOISE,
    SPEECH_SCORE,
    MUSIC_SCORE,
    NOISE_SCORE,
    SEGMENT_PATH,
    BLOCS_SMAD_V2_M2D,
    BLOCS_SMAD_V2_AST,
    BLOCS_SMAD_V2_CLAP,
    BLOCS_SMAD_V2_WHISPER,
)
from utils.audio_io import load_mono_resampled

# Make modules under the project root importable.
# NOTE(review): `from config import ...` in the imports above already ran before
# this call, so it must resolve some other way (e.g. the notebook's working
# directory). Consider calling this before those imports.
add_project_root_to_path()

settings = get_settings()
# Wrap both configured locations in Path (in case settings holds plain strings)
# so the `/` joins used later work.
metadata_dir, segments_dir = (
    Path(configured) for configured in (settings.metadata_path, settings.segments_path)
)

print(f"Metadata directory: {metadata_dir}")
print(f"Segments directory: {segments_dir}")

# Device setup: choose the compute device once and reuse it in every later cell.
def get_device() -> torch.device:
    """Return the first available backend in priority order MPS -> CUDA, else CPU."""
    probes = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for device_type, is_available in probes:
        if is_available():
            return torch.device(device_type)
    return torch.device("cpu")


device = get_device()
print(f"Using device: {device}")


: 

## Configuration

**To switch data sources, change `CSV_FILE` in the cell below.** The CSV file must be in the `data/metadata/` directory (the metadata path configured in the project settings).

The notebook supports two formats:
1. **New format** (recommended): `blocs_smad_v2_finetune.csv` with `chosen_speech_label`, `chosen_music_label`, `chosen_noise_label` columns
2. **Old format**: Files like `blocs_smad_v2_m2d.csv`, `blocs_smad_v2_ast.csv`, etc. with `speech_label`, `music_label`, `noise_label` columns

The notebook detects the format automatically. If the CSV has a `chosen_speech_label` column, the new format is used; otherwise it falls back to the old format.


In [None]:
# ============================================
# CONFIGURATION: Change this to switch CSV files
# ============================================
# Source CSV, looked up inside the metadata directory.
CSV_FILE = "blocs_smad_v2_finetune.csv"  # Change this to your desired CSV file

# --- Optimisation ---
BATCH_SIZE = 16
LEARNING_RATE = 1e-4   # peak LR; the cosine scheduler anneals it down to 1% of this
NUM_EPOCHS = 10        # upper bound; early stopping may end training sooner
WEIGHT_DECAY = 0.01    # AdamW weight decay

# --- Data split / reproducibility ---
TRAIN_SPLIT = 0.8      # fraction of rows used for training; the remainder is validation
RANDOM_SEED = 42

# --- Model ---
NUM_CLASSES = 3          # output order: [speech, music, noise]
FREEZE_ENCODER = False   # True -> freeze the AST encoder and train only the new head

# Resolve the configured CSV and fail early with a clear message if it is absent.
csv_path = metadata_dir / CSV_FILE
print(f"Loading data from: {csv_path}")

if not csv_path.exists():
    raise FileNotFoundError(f"CSV file not found: {csv_path}")

df = pd.read_csv(csv_path)
n_rows = len(df)
column_names = df.columns.tolist()
print(f"Loaded {n_rows} samples")
print(f"Columns: {column_names}")
print("\nFirst few rows:")
df.head()


## Data Loading and Preprocessing

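Define a PyTorch `Dataset` that, for each CSV row, loads the audio segment, converts it to AST input features with the feature extractor, and returns its three labels (speech, music, noise).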


In [None]:
class PseudolabelDataset(Dataset):
    """Dataset for loading audio segments and multilabel targets from CSV.

    Each item is a dict with:
        'input_values': feature-extractor output for one segment, with the
            leading batch dimension removed via ``squeeze(0)``.
        'labels': float32 tensor ordered [speech, music, noise]. It holds the
            binary labels, or the continuous scores when ``use_scores`` is set
            and all three score columns exist.
        'segment_path': the segment path value exactly as stored in the CSV row.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        segments_dir: Path,
        feature_extractor: AutoFeatureExtractor,
        use_scores: bool = False,
        use_chosen_labels: bool = True,
    ):
        """
        Args:
            df: DataFrame with columns: segment_path, chosen_speech_label, chosen_music_label, chosen_noise_label
                (or speech_label, music_label, noise_label for old format)
            segments_dir: Directory containing audio segment files
            feature_extractor: AST feature extractor for preprocessing
            use_scores: If True, use continuous scores instead of binary labels
                (only when all three score columns are present; otherwise a
                warning is printed and binary labels are used)
            use_chosen_labels: If True, use chosen_* columns (new format), else use old format columns

        Raises:
            ValueError: if the segment-path column or any of the three label
                columns is missing from ``df``.
        """
        self.df = df.reset_index(drop=True)
        self.segments_dir = segments_dir
        self.feature_extractor = feature_extractor
        self.use_scores = use_scores
        self.use_chosen_labels = use_chosen_labels

        # Determine column names based on format
        if use_chosen_labels:
            self.speech_col = 'chosen_speech_label'
            self.music_col = 'chosen_music_label'
            self.noise_col = 'chosen_noise_label'
            self.speech_score_col = 'chosen_speech_score'
            self.music_score_col = 'chosen_music_score'
            self.noise_score_col = 'chosen_noise_score'
        else:
            # Fallback to old format
            self.speech_col = SPEECH
            self.music_col = MUSIC
            self.noise_col = NOISE
            self.speech_score_col = SPEECH_SCORE
            self.music_score_col = MUSIC_SCORE
            self.noise_score_col = NOISE_SCORE

        # Validate required columns
        required_cols = [SEGMENT_PATH, self.speech_col, self.music_col, self.noise_col]
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Label/score column order defines the target order [speech, music, noise].
        self.label_cols = [self.speech_col, self.music_col, self.noise_col]
        self.score_cols = [self.speech_score_col, self.music_score_col, self.noise_score_col]

        # Score columns are optional. Check their presence once here rather than
        # on every item, and say so explicitly when scores were requested but
        # cannot be used, instead of silently training on binary labels.
        self.has_score_cols = all(col in self.df.columns for col in self.score_cols)
        if use_scores and not self.has_score_cols:
            missing_scores = [col for col in self.score_cols if col not in self.df.columns]
            print(
                f"Warning: use_scores=True but score columns {missing_scores} are missing; "
                "falling back to binary labels."
            )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        segment_path = self.segments_dir / row[SEGMENT_PATH]

        # Load audio at SAMPLE_RATE
        waveform = load_mono_resampled(segment_path, target_sr=SAMPLE_RATE)

        # Preprocess with feature extractor.
        # NOTE(review): `.numpy()` assumes load_mono_resampled returns a CPU torch tensor; confirm.
        inputs = self.feature_extractor(
            waveform.numpy(),
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )

        # Targets: continuous scores if requested and available, else binary labels
        if self.use_scores and self.has_score_cols:
            target_values = [row[col] for col in self.score_cols]
        else:
            target_values = [int(row[col]) for col in self.label_cols]
        labels = torch.tensor(target_values, dtype=torch.float32)

        return {
            'input_values': inputs['input_values'].squeeze(0),
            'labels': labels,
            'segment_path': row[SEGMENT_PATH]
        }

# Initialize the feature extractor matching the pretrained AST checkpoint.
# The dataset calls it to turn raw waveforms into the model's `input_values`.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=AST_MODEL_NAME,
)
print("Feature extractor initialized")


## Model Definition

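Define an AST-based model for 3-class multilabel classification (speech, music, noise). The pretrained AST model's original classification head is replaced, and a new head (LayerNorm → Dropout → Linear) is applied to mean-pooled encoder features.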


In [None]:
class ASTStudentClassifier(nn.Module):
    """Pretrained AST backbone topped with a fresh multilabel classification head.

    The backbone is loaded with ``AutoModelForAudioClassification``, and its
    original head (``classifier`` or, failing that, ``projector``) is replaced
    by ``nn.Identity``. A new LayerNorm -> Dropout -> Linear head
    (``self.classifier``) maps mean-pooled token features to raw logits of
    shape [batch, num_classes]. Apply a sigmoid or BCEWithLogitsLoss downstream.
    """

    def __init__(
        self,
        model_name: str = AST_MODEL_NAME,
        num_classes: int = 3,
        freeze_encoder: bool = False,
    ):
        super().__init__()

        # Pretrained backbone
        backbone_config = AutoConfig.from_pretrained(model_name)
        self.encoder = AutoModelForAudioClassification.from_pretrained(
            model_name,
            config=backbone_config,
        )

        # Neutralise the pretrained head: only the first attribute found is replaced.
        for head_attr in ('classifier', 'projector'):
            if hasattr(self.encoder, head_attr):
                setattr(self.encoder, head_attr, nn.Identity())
                break

        # New head sized to the backbone's hidden width
        width = backbone_config.hidden_size
        self.classifier = nn.Sequential(
            nn.LayerNorm(width),
            nn.Dropout(0.1),
            nn.Linear(width, num_classes)
        )

        if not freeze_encoder:
            print("Training full model (encoder + head)")
            return

        # Frozen backbone: only the new head receives gradient updates
        for weight in self.encoder.parameters():
            weight.requires_grad = False
        print("Encoder frozen, only training classification head")

    def forward(self, input_values):
        """Return multilabel logits [batch, num_classes] for a batch of AST input features."""
        return self.classifier(self._pooled_features(input_values))

    def _pooled_features(self, input_values):
        """Mean-pool the backbone's last token sequence into one vector per clip."""
        encoder_out = self.encoder(
            input_values,
            output_hidden_states=True,
        )

        hidden_states = getattr(encoder_out, 'hidden_states', None)
        if hidden_states is not None:
            # Last entry of the hidden-state tuple: [batch, seq_len, hidden_size]
            return hidden_states[-1].mean(dim=1)

        # Fallback: drive the inner AST transformer directly
        if not hasattr(self.encoder, 'audio_spectrogram_transformer'):
            raise NotImplementedError(
                f"Could not extract features from AST model. "
                f"Model structure: {type(self.encoder)}. "
                f"Available attributes: {dir(self.encoder)}"
            )
        ast_core = self.encoder.audio_spectrogram_transformer
        embedded = ast_core.embeddings(input_values)
        encoded = ast_core.encoder(embedded, output_hidden_states=True)
        return encoded.last_hidden_state.mean(dim=1)

# Initialize model: build the student classifier and move its weights to `device`.
model = ASTStudentClassifier(
    model_name=AST_MODEL_NAME,
    num_classes=NUM_CLASSES,
    freeze_encoder=FREEZE_ENCODER,
).to(device)

# (element count, trainable?) for every parameter tensor
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
print(f"Model initialized on {device}")
print(f"Total parameters: {sum(count for count, _ in param_counts):,}")
print(f"Trainable parameters: {sum(count for count, trainable in param_counts if trainable):,}")


## Prepare Data Splits

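Drop rows with a missing segment path or label, then split the remaining rows into train and validation sets using a seeded random permutation (controlled by `TRAIN_SPLIT` and `RANDOM_SEED`).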


In [None]:
# Filter out rows with missing data.
# Detect the CSV format: new-format files carry chosen_* label columns,
# old-format files use the SPEECH/MUSIC/NOISE column names from utils.dams_types.
if 'chosen_speech_label' in df.columns:
    label_cols = ['chosen_speech_label', 'chosen_music_label', 'chosen_noise_label']
    use_chosen = True
    print("Using new format: chosen_* columns")
else:
    label_cols = [SPEECH, MUSIC, NOISE]
    use_chosen = False
    print("Using old format: speech_label, music_label, noise_label")

# Drop rows missing the segment path or any label
df_clean = df.dropna(subset=[SEGMENT_PATH] + label_cols)
print(f"After filtering: {len(df_clean)} samples")

# Split into train and validation.
# Seeding numpy fixes the permutation below. Seeding torch makes subsequent torch
# randomness (e.g. DataLoader shuffling) reproducible.
# NOTE(review): the model (and its new head) was created in an earlier cell,
# before this seed is set, so head initialisation is not covered by RANDOM_SEED.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

indices = np.random.permutation(len(df_clean))
split_idx = int(len(df_clean) * TRAIN_SPLIT)
train_indices = indices[:split_idx]
val_indices = indices[split_idx:]

# Fail fast on an empty split. Otherwise the problem only surfaces once training
# has started, when per-epoch metrics are averaged over an empty loader.
if len(train_indices) == 0 or len(val_indices) == 0:
    raise ValueError(
        f"Empty data split (train={len(train_indices)}, val={len(val_indices)}) "
        f"from {len(df_clean)} usable samples with TRAIN_SPLIT={TRAIN_SPLIT}; "
        "add more data or adjust TRAIN_SPLIT."
    )

df_train = df_clean.iloc[train_indices].reset_index(drop=True)
df_val = df_clean.iloc[val_indices].reset_index(drop=True)

print(f"Train samples: {len(df_train)}")
print(f"Validation samples: {len(df_val)}")

# Create datasets
train_dataset = PseudolabelDataset(
    df_train,
    segments_dir,
    feature_extractor,
    use_scores=False,  # Use binary labels
    use_chosen_labels=use_chosen,  # Use chosen_* format if available
)

val_dataset = PseudolabelDataset(
    df_val,
    segments_dir,
    feature_extractor,
    use_scores=False,
    use_chosen_labels=use_chosen,  # Use chosen_* format if available
)

# Create data loaders; pinned host memory only helps host->CUDA transfers
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Set to 0 for compatibility, increase if needed
    pin_memory=device.type == 'cuda',
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=device.type == 'cuda',
)

print("Data loaders created")


## Training Setup

Initialize loss function, optimizer, and scheduler.


In [None]:
# Multilabel objective: BCEWithLogitsLoss applies an independent sigmoid per class
# and (with its default mean reduction) averages BCE over every (sample, class) entry.
criterion = nn.BCEWithLogitsLoss()

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# The training loop steps this once per epoch: cosine decay from LEARNING_RATE
# over NUM_EPOCHS steps, bottoming out at 1% of LEARNING_RATE.
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=LEARNING_RATE * 0.01)

print("Training setup complete")
print(f"Loss function: {criterion}")
print(f"Optimizer: {optimizer}")
print(f"Scheduler: {scheduler}")


## Training Loop

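Train the model, evaluating on the validation set after every epoch and stopping early if the validation loss has not improved for `patience` consecutive epochs.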


In [None]:
# Per-epoch metric history: the training loop appends to it and the plot cell reads it.
history = {
    metric_name: []
    for metric_name in ('train_loss', 'val_loss', 'train_acc', 'val_acc')
}

def compute_accuracy(logits, labels, threshold=0.5):
    """Exact-match (subset) accuracy for multilabel predictions.

    Args:
        logits: raw model outputs, shape [batch, num_classes].
        labels: 0/1 targets with the same shape as ``logits``.
        threshold: probability cut-off; a class is predicted when
            sigmoid(logit) >= threshold.

    Returns:
        Python float: fraction of rows whose predicted labels all match the targets.
    """
    # Metrics need no gradients. Detaching keeps the sigmoid out of the autograd
    # graph when this is called on training-step logits.
    probs = torch.sigmoid(logits.detach())
    preds = (probs >= threshold).float()
    # Exact match accuracy (all labels must match)
    correct = (preds == labels).all(dim=1).float()
    return correct.mean().item()

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: classifier returning logits for a batch of ``input_values``.
        val_loader: iterable of dicts with 'input_values' and 'labels' tensors.
        criterion: loss function. It is assumed to return a per-batch mean
            (true for the default-reduction BCEWithLogitsLoss used in this notebook).
        device: device the batches are moved to.

    Returns:
        (mean_loss, mean_exact_match_accuracy), averaged per sample so that a
        smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating", leave=False):
            input_values = batch['input_values'].to(device)
            labels = batch['labels'].to(device)
            batch_size = labels.size(0)

            logits = model(input_values)
            loss = criterion(logits, labels)

            # Loss and accuracy are batch means; re-weight them by batch size
            loss_sum += loss.item() * batch_size
            correct_sum += compute_accuracy(logits, labels) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("Validation loader produced no samples; cannot compute metrics.")
    return loss_sum / num_samples, correct_sum / num_samples

# Training loop with early stopping on validation loss.
best_val_loss = float('inf')
best_epoch = None
patience = 5  # consecutive epochs without val-loss improvement before stopping
patience_counter = 0
# CPU copy of the weights from the best epoch so far. It is restored after the
# loop so later cells (evaluation, saving) use the best model, not the last one.
# Costs one extra copy of the model weights in host memory.
best_state_dict = None

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"{'='*60}")

    # Training phase
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    num_samples = 0

    progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
    for batch in progress_bar:
        input_values = batch['input_values'].to(device)
        labels = batch['labels'].to(device)
        batch_size = labels.size(0)

        # Forward pass
        optimizer.zero_grad()
        logits = model(input_values)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Metrics: compute once per batch, and weight by batch size so a smaller
        # final batch does not skew the epoch average (loss is a batch mean).
        batch_loss = loss.item()
        batch_acc = compute_accuracy(logits, labels)
        train_loss += batch_loss * batch_size
        train_acc += batch_acc * batch_size
        num_samples += batch_size

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{batch_acc:.4f}'
        })

    if num_samples == 0:
        raise ValueError("Training loader produced no samples; check the train split.")
    avg_train_loss = train_loss / num_samples
    avg_train_acc = train_acc / num_samples

    # Validation phase
    avg_val_loss, avg_val_acc = validate(model, val_loader, criterion, device)

    # Update learning rate (one scheduler step per epoch)
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]

    # Store history
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['train_acc'].append(avg_train_acc)
    history['val_acc'].append(avg_val_acc)

    # Print epoch summary
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f} | Val Acc: {avg_val_acc:.4f}")
    print(f"Learning Rate: {current_lr:.6f}")

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        patience_counter = 0
        # Snapshot the best weights (detached CPU copies, independent of later updates)
        best_state_dict = {
            name: tensor.detach().to('cpu', copy=True)
            for name, tensor in model.state_dict().items()
        }
        print("✓ New best validation loss!")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

# Roll back to the best epoch's weights (no-op if no epoch ever improved)
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
    print(f"Restored weights from best epoch {best_epoch} (val loss {best_val_loss:.4f})")

print(f"\n{'='*60}")
print("Training complete!")
print(f"{'='*60}")


## Training Curves

Visualize training and validation loss/accuracy over epochs.


In [None]:
# Plot training curves. Epochs are numbered from 1 on the x-axis to match the
# "Epoch k/N" lines printed during training.
epoch_numbers = list(range(1, len(history['train_loss']) + 1))

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (axis, history key suffix, legend short name, y-axis label, title)
panels = (
    (axes[0], 'loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    (axes[1], 'acc', 'Acc', 'Accuracy', 'Training and Validation Accuracy'),
)
for ax, key, short_name, ylabel, title in panels:
    ax.plot(epoch_numbers, history[f'train_{key}'], label=f'Train {short_name}', marker='o')
    ax.plot(epoch_numbers, history[f'val_{key}'], label=f'Val {short_name}', marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# min()/max() raise on an empty sequence, e.g. if no epoch completed.
if history['val_loss']:
    print(f"Best validation loss: {min(history['val_loss']):.4f}")
    print(f"Best validation accuracy: {max(history['val_acc']):.4f}")
else:
    print("No epochs recorded in history; nothing to summarize.")


## Model Evaluation

Evaluate the model on the validation set with detailed metrics.


In [None]:
from sklearn.metrics import (
    classification_report,
    multilabel_confusion_matrix,
    hamming_loss,
    jaccard_score,
)

def evaluate_model(model, val_loader, device, threshold=0.5, class_names=('speech', 'music', 'noise')):
    """Run ``model`` over ``val_loader`` and print multilabel evaluation metrics.

    Args:
        model: classifier returning logits of shape [batch, num_classes].
        val_loader: iterable of dicts with 'input_values' and 'labels' tensors.
        device: device the batches are moved to.
        threshold: probability cut-off applied to sigmoid(logits).
        class_names: display names for the label columns, in label order.

    Returns:
        dict with 'predictions' (0/1), 'labels' and 'probabilities', each an
        array of shape [n_samples, num_classes], plus 'confusion_matrices': one
        2x2 [[TN, FP], [FN, TP]] array per class.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()

    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating"):
            input_values = batch['input_values'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_values)
            probs = torch.sigmoid(logits)
            preds = (probs >= threshold).float()

            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    # np.vstack([]) would raise an opaque error; report the real cause instead.
    if not all_labels:
        raise ValueError("Evaluation loader produced no batches; nothing to evaluate.")

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    all_probs = np.vstack(all_probs)

    # Classification report
    class_names = list(class_names)
    print("Classification Report (per class):")
    print(classification_report(
        all_labels,
        all_preds,
        target_names=class_names,
        zero_division=0
    ))

    # Multilabel metrics
    print(f"\nHamming Loss: {hamming_loss(all_labels, all_preds):.4f}")
    print(f"Jaccard Score (micro): {jaccard_score(all_labels, all_preds, average='micro'):.4f}")
    print(f"Jaccard Score (macro): {jaccard_score(all_labels, all_preds, average='macro'):.4f}")
    print(f"Jaccard Score (per class): {jaccard_score(all_labels, all_preds, average=None)}")

    # Exact match accuracy
    exact_match = (all_preds == all_labels).all(axis=1).mean()
    print(f"\nExact Match Accuracy: {exact_match:.4f}")

    # Per-class confusion counts; each matrix is laid out [[TN, FP], [FN, TP]]
    confusion = multilabel_confusion_matrix(all_labels, all_preds)
    print("\nPer-class confusion counts:")
    for name, matrix in zip(class_names, confusion):
        tn, fp, fn, tp = matrix.ravel()
        print(f"  {name}: TN={tn} FP={fp} FN={fn} TP={tp}")

    return {
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs,
        'confusion_matrices': confusion,
    }

# Run evaluation of the current model weights on the validation split.
# The threshold is passed explicitly; 0.5 is also the function's default.
eval_results = evaluate_model(model, val_loader, device, threshold=0.5)


## Save Model (Optional)

Save the trained model for later use.


In [None]:
# Set SAVE_MODEL = True to write a checkpoint with the model weights, optimizer
# state, training history and key configuration values.
SAVE_MODEL = False

if SAVE_MODEL:
    # NOTE(review): assumes `settings` exposes `models_path`. It is wrapped in
    # Path() in case it is a plain string, and the directory is created if missing.
    models_dir = Path(settings.models_path)
    models_dir.mkdir(parents=True, exist_ok=True)
    model_save_path = models_dir / f"ast_student_{Path(CSV_FILE).stem}.pt"
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'history': history,
        'config': {
            'num_classes': NUM_CLASSES,
            'freeze_encoder': FREEZE_ENCODER,
            'csv_file': CSV_FILE,
        },
    }, model_save_path)
    print(f"Model saved to: {model_save_path}")
