In [1]:
# IMPORTS
import os
import numpy as np
import json
import random
import shutil
import zipfile

from transformers import ASTForAudioClassification, ASTFeatureExtractor, TrainingArguments, ASTConfig, Trainer, EarlyStoppingCallback, AutoConfig, set_seed as t_set_seed
from transformers.integrations import TensorBoardCallback
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio 
import torchaudio.transforms as T
import evaluate

from datasets import Dataset, disable_caching, DatasetDict, Audio
from PIL import Image, ImageOps
from peft import LoraConfig, get_peft_model
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_curve, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
import librosa
from audiomentations import Compose, AddGaussianSNR, GainTransition, Gain, ClippingDistortion, TimeStretch, PitchShift

In [2]:
def set_seed(seed):
    random.seed(seed) 
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    t_set_seed(seed)

In [3]:
# Clear Hugging Face datasets cache (memory issues)
cache_dir = os.path.expanduser("~/.cache/huggingface/datasets")
if os.path.exists(cache_dir):
    print(f"Clearing Hugging Face datasets cache at {cache_dir} ...")
    shutil.rmtree(cache_dir)
    print("Cache cleared.")
else:
    print("No Hugging Face datasets cache found.")

Clearing Hugging Face datasets cache at /Users/harrywills/.cache/huggingface/datasets ...
Cache cleared.


In [4]:
# PARAMETERS
training_size = int(input("Enter the amount of spectrograms per class to train on (0 for all): ")) # Number of spectrograms per class to use for training (0 for all)
print(f"Training size per class: {training_size if training_size > 0 else 'All available'}")
segments_path = "./segments"
model_output_dir = "./ast-base-manuai" # Fine-tuned model output directory
adapters_dir = "./manuai_lora_adapters" # LoRA adapters output directory
checkpoints_dir = "./manuai_checkpoints" # Checkpoints output directory
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593" # Pre-trained model name or path
# Check if fine-tuned model already exists
if os.path.exists(model_output_dir):
    model_name = model_output_dir
    print("Using existing fine-tuned model as base.")

processor = ASTFeatureExtractor.from_pretrained(model_name)
sample_rate = 16000
epochs = 20
batch_size = 8
n_proc = 2 # Number of processes for parallel processing
dataloader_num_workers=0 # Number of workers for data loading (during training)
seed = 42
segment_len = 5.0
lora_rank = 16
image_size = 224  # ViT base model image size
disable_caching() # Disable caching to avoid potential issues with large datasets
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

set_seed(seed)

accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
# tensorboard --logdir manuai_checkpoints/

Training size per class: 1000
Using device: mps
Using device: mps


In [5]:
audio_augmentations = Compose([
    AddGaussianSNR(min_snr_db=10, max_snr_db=20),
    Gain(min_gain_db=-6, max_gain_db=6),
    GainTransition(min_gain_db=-6, max_gain_db=6, min_duration=0.01, max_duration=0.3, duration_unit="fraction"),
    ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=0.5),
    TimeStretch(min_rate=0.8, max_rate=1.2),
    PitchShift(min_semitones=-4, max_semitones=4),
], p=0.8, shuffle=True)

In [None]:
def time_stretch_waveform(waveform, rate=1.1):
    rate = random.uniform(0.8, 1.2) if rate is None else rate
    waveform_np = waveform.squeeze().detach().cpu().numpy()
    if waveform_np.ndim > 1:
        waveform_np = waveform_np.mean(axis=0)  # Convert to mono
    stretched = librosa.effects.time_stretch(y=waveform_np, rate=rate)
    stretched_tensor = torch.tensor(stretched, dtype=waveform.dtype, device=waveform.device)
    if stretched_tensor.ndim == 1:
        stretched_tensor = stretched_tensor.unsqueeze(0)  # Ensure shape [1, time]
    return stretched_tensor

def preprocess_audio_with_transforms(batch):
    # we apply augmentations on each waveform
    wavs = [audio_augmentations(np.array(audio), sample_rate=sample_rate) for audio in batch["input_values"]]
    inputs = processor(wavs, sampling_rate=sample_rate, return_tensors="pt")

    output_batch = {"input_values": inputs.get("input_values"), "labels": list(batch["labels"])}
    return output_batch

def augment_audio(sample, sample_rate, max_attempts=3):
    sample = sample.to(device)
    augmentations = [
        lambda x: T.PitchShift(sample_rate, n_steps=random.choice([-2, -1, 1, 2])).to(device)(x), # Change pitch by -2, -1, +1, or +2 semitones
        #lambda x: x + torch.randn_like(x) * min(0.002, x.std().item() * 0.1), # Add Gaussian noise with stddev up to 10% of original signal's stddev, capped at 0.002
        lambda x: T.FrequencyMasking(freq_mask_param=random.randint(8, 16)).to(device)(x), # Apply frequency masking with max width of 16 bins
        lambda x: T.TimeMasking(time_mask_param=random.randint(8, 20)).to(device)(x), # Apply time masking with max width of 20 frames
        lambda x: time_stretch_waveform(x), # Time-stretch by a random rate between 0.8 and 1.2
    ]
    
    for attempt in range(max_attempts):
        num_aug = random.randint(2, 3)  # Apply 2 to 3 augmentations
        aug_funcs = random.sample(augmentations, num_aug)
        augmented = sample.clone()  # Preserve original sample
        
        for augment in aug_funcs:
            try:
                temp_augmented = augment(augmented)
                # Check if augmentation produces valid output
                if is_valid_waveform(temp_augmented, min_variance=1e-8, min_amplitude=1e-4):
                    augmented = temp_augmented
                else:
                    continue
            except Exception as e:
                print(f"Augmentation error: {e}")
                continue
        
        # Final validation before returning
        if is_valid_waveform(augmented, min_variance=1e-8, min_amplitude=1e-4):
            return augmented
        print(f"Attempt {attempt + 1} failed: aug_var={augmented.var().item():.6f}, aug_max={augmented.abs().max().item():.6f}")
    
    # If all attempts fail, return augmented sample
    print("All augmentation attempts failed, returning augmented sample")
    return augmented

def is_valid_waveform(waveform, min_variance=1e-8, min_amplitude=1e-4):
    # If waveform is empty or has low variance, it's invalid
    return waveform.abs().sum() > min_amplitude and waveform.var() > min_variance

def load_audio_segments():
    """
    Load exactly `training_size` samples per class.
    Uses augmentation to fill the gap if there aren't enough originals.
    """
    augmented_count = 0
    files_labels = {label: [] for label in labels}
    for root, dirs, files in os.walk(segments_path):
        for file in files:
            if file.endswith(".wav"):
                label = os.path.splitext(file)[0].split('_')[1]
                files_labels[label].append(os.path.join(root, file))
    for label in labels:
        print(f"Found {len(files_labels[label])} files for label '{label}'")
        files = files_labels[label]

        # Case 1: More files than training_size -> sample down
        if training_size > 0 and len(files) > training_size:
            selected_files = random.sample(files, training_size)
        else:
            selected_files = list(files)  # copy
        samples = []

        # Load original files
        for file_path in selected_files:
            waveform, sr = torchaudio.load(file_path)
            if sr != sample_rate:
                waveform = T.Resample(sr, sample_rate)(waveform)
                sr = sample_rate

            if not is_valid_waveform(waveform):
                print(f"Invalid original waveform for file: {file_path}")
                continue

            samples.append({
                "input_values": waveform.squeeze().detach().cpu().numpy(),
                "label": label_to_id[label]
            })

        # Case 2: If need to augment more samples to reach training_size
        while len(samples) < training_size and len(selected_files) > 0:
            f = random.choice(selected_files)
            waveform, sr = torchaudio.load(f)
            if sr != sample_rate:
                waveform = T.Resample(sr, sample_rate)(waveform)
                sr = sample_rate

            if not is_valid_waveform(waveform):
                continue

            augmented = augment_audio(waveform, sr)
            if not is_valid_waveform(augmented):
                continue

            augmented_count += 1
            samples.append({
                "input_values": augmented.squeeze().detach().cpu().numpy(),
                "label": label_to_id[label]
            })

        # Ensure exactly training_size (trim if overshot)
        samples = samples[:training_size]

        # Yield per-class samples
        for s in samples:
            yield s
    print(f"Total augmented samples created: {augmented_count}")

labels = sorted([d for d in os.listdir(segments_path) if not d.startswith('.')]) # Exclude hidden files
label_to_id = {lbl: i for i, lbl in enumerate(labels)}
id_to_label = {i: lbl for lbl, i in label_to_id.items()}

dataset = Dataset.from_generator(load_audio_segments, cache_dir=None)

print("Final label order:", labels)

# Preprocess dataset
dataset = dataset.map(
    preprocess_audio_with_transforms,
    batched=True,
    remove_columns=["input_values"],
    num_proc=n_proc
)
if dataset:
    print(f"✅ Dataset created successfully with {len(dataset)} samples.")
else:
    print("❌ Dataset creation failed.")

Generating train split: 0 examples [00:00, ? examples/s]

Found 61489 files for label 'bellbird'




Found 1515 files for label 'fantail'
Found 21735 files for label 'greywarbler'
Found 1486 files for label 'kaka'
Found 13 files for label 'kakapo'




Found 4412 files for label 'kea'
Found 291 files for label 'kereru'
Found 313 files for label 'kingfisher'
Found 11006 files for label 'kiwi'
Found 1519 files for label 'kokako'
Found 60698 files for label 'morepork'
Found 1813 files for label 'pukeko'
Found 9215 files for label 'robin'
Found 1238 files for label 'saddleback'
Found 50446 files for label 'silvereye'
Found 723 files for label 'stitchbird'
Found 58462 files for label 'tomtit'
Found 26785 files for label 'tui'
Found 2798 files for label 'whitehead'
Found 773 files for label 'yellowhead'
Total augmented samples created: 2887
Final label order: ['bellbird', 'fantail', 'greywarbler', 'kaka', 'kakapo', 'kea', 'kereru', 'kingfisher', 'kiwi', 'kokako', 'morepork', 'pukeko', 'robin', 'saddleback', 'silvereye', 'stitchbird', 'tomtit', 'tui', 'whitehead', 'yellowhead']


Map (num_proc=2):   0%|          | 0/20000 [00:00<?, ? examples/s]

AttributeError: 'list' object has no attribute 'dtype'

In [None]:
def collate_fn(batch):
    input_values = []
    labels = []
    for item in batch:
        val = item["input_values"]
        if isinstance(val, list):
            val = torch.tensor(val)
        input_values.append(val)
        labels.append(item["label"])
    input_values = torch.stack(input_values)
    labels = torch.tensor(labels, dtype=torch.long)
    return {
        "input_values": input_values,
        "labels": labels
    }

config = ASTConfig.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id_to_label,
    label2id=label_to_id,
)
base_model = ASTForAudioClassification.from_pretrained(
    model_name,
    config=config,
    ignore_mismatched_sizes=True
)

lora = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank * 4,
    lora_dropout=0.1,
    bias="none", 
    target_modules=["query", "key", "value", "dense"], # Attention layers
    modules_to_save=["classifier"]
)

model = get_peft_model(base_model, lora).to(device)

In [None]:
# Split data into train, test, and validation sets
split_labels = list(dataset["label"])
train_idx, temp_idx = train_test_split(
    np.arange(len(dataset)),
    test_size=0.2,
    stratify=split_labels,
    random_state=seed
)

temp_labels = [split_labels[i] for i in temp_idx]
val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,
    stratify=temp_labels,
    random_state=seed
)

# Create splits
dataset = DatasetDict({
    "train": dataset.select(train_idx),
    "validation": dataset.select(val_idx),
    "test": dataset.select(test_idx)
})

print(f"Train size: {len(dataset['train'])}, Test size: {len(dataset['test'])}, Validation size: {len(dataset['validation'])}")

In [None]:
# SET UP MODEL
def trainable_parameters(model):
    params, trainable = 0, 0
    for _, p in model.named_parameters():
        params += p.numel()
        trainable += p.numel() if p.requires_grad else 0

    print(f"{model.__class__.__name__} trainable parameters: {trainable:,}/{params:,} ({100 * trainable / params:.2f}%)")

def collate_fn(batch):
    pixel_values = torch.stack([torch.tensor(item["pixel_values"]) for item in batch])
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    return {
        "pixel_values": pixel_values, 
        "labels": labels
    }

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    acc = accuracy_metric.compute(predictions=preds, references=labels)["accuracy"]
    prec = precision_metric.compute(predictions=preds, references=labels, average="weighted")["precision"]
    rec = recall_metric.compute(predictions=preds, references=labels, average="weighted")["recall"]
    f1_weighted = f1_metric.compute(predictions=preds, references=labels, average="weighted")["f1"]
    f1_macro = f1_metric.compute(predictions=preds, references=labels, average="macro")["f1"]
    auc = roc_auc_score(labels, torch.nn.functional.softmax(torch.tensor(logits), dim=-1).numpy(), multi_class='ovr', average='macro')
    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "eval_f1": f1_weighted,
        "f1_macro": f1_macro,
        "auc": auc
    }
    
def focal_loss(inputs, targets, gamma=2.0, alpha=None, weight=None, label_smoothing=0.0, reduction='mean'):
    # per-sample CE 
    ce = F.cross_entropy(inputs, targets, weight=weight, reduction='none', label_smoothing=label_smoothing)
    pt = torch.exp(-ce)  # p_t for the true class
    if alpha is None:
        alpha_t = 1.0
    else:
        if isinstance(alpha, (list, tuple, np.ndarray)):
            alpha = torch.tensor(alpha, device=inputs.device, dtype=inputs.dtype)
            alpha_t = alpha[targets]
        elif isinstance(alpha, torch.Tensor):
            alpha_t = alpha[targets]
        else:
            alpha_t = torch.tensor(float(alpha), device=inputs.device, dtype=inputs.dtype)

    loss = alpha_t * ((1 - pt) ** gamma) * ce
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:
        return loss

class WeightedTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)   
        logits = outputs.get("logits")
        loss_fn = nn.CrossEntropyLoss(weight=self.class_weights, label_smoothing=training_args.label_smoothing_factor) # Weighted cross-entropy loss
        loss = loss_fn(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        #loss = focal_loss(logits.view(-1, self.model.config.num_labels), labels.view(-1), weight=self.class_weights, label_smoothing=training_args.label_smoothing_factor) # Focal loss alternative
        return (loss, outputs) if return_outputs else loss

training_args = TrainingArguments(
    output_dir=checkpoints_dir,
    logging_dir=str(checkpoints_dir + "/runs"),
    learning_rate=7e-5, 
    lr_scheduler_type="cosine", # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", or "constant_with_warmup"
    warmup_ratio=0.1, 
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=epochs,
    gradient_accumulation_steps=1, # (batch_size * gradient_accumulation_steps = effective batch size)
    weight_decay=0.01,
    eval_strategy="steps", # "steps" or "epoch"
    eval_steps=500, # Only if eval_strategy="steps"
    save_strategy="steps", # "steps" or "epoch"
    save_steps=500, # Only if save_strategy="steps"
    load_best_model_at_end=True,
    remove_unused_columns=False,
    bf16=False,
    logging_steps=100,
    report_to="tensorboard",
    save_total_limit=3,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    label_smoothing_factor=0.05,
    dataloader_pin_memory=False,
    max_grad_norm=1.0,
    use_mps_device = True if device == "mps" else False,
    #dataloader_num_workers=dataloader_num_workers,
)

# Model Architecture
callbacks = [EarlyStoppingCallback(early_stopping_patience=4), TensorBoardCallback()]

base_model = ASTForAudioClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id_to_label,
    label2id=label_to_id,
    ignore_mismatched_sizes=True
)

lora = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank * 2,
    lora_dropout=0.1,
    bias="none", 
    target_modules=["query", "value"], # Attention layers
    modules_to_save=["classifier"]
)

peft_model = get_peft_model(base_model, lora).to(device)

class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.unique(dataset["train"]["label"]), y=dataset["train"]["label"]), dtype=torch.float32).to(device) # Compute class weights to handle class imbalance
trainer = WeightedTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
    tokenizer=processor,
    data_collator=collate_fn,
    callbacks=callbacks,
    class_weights=class_weights
)

print(f"ID to label mapping: {id_to_label}")
trainable_parameters(peft_model)
print("Model set-up complete. Ready to begin training...")

In [None]:
# TRAIN MODEL
result = trainer.train()
trainer.log_metrics("train", result.metrics)
trainer.save_metrics("train", result.metrics)

In [None]:
# EVALUATE ON TEST SET 
metrics = trainer.evaluate(eval_dataset=dataset["test"])
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)

In [None]:
# Save final model and processor
base_model.save_pretrained(model_output_dir) # Saves full fine-tuned model (ASTForAudioClassification.from_pretrained("./ast-base-manuai"))
processor.save_pretrained(model_output_dir) # Saves feature extractor (ASTFeatureExtractor.from_pretrained("./ast-base-manuai"))
print(f"Model fine-tuned and saved to {model_output_dir}")
# Save LoRA adapters
model.save_pretrained(adapters_dir) # Saves only LoRA adapters (PEFTModel.from_pretrained(base_model, "./manuai_lora_adapters"))
print(f"LoRA adapters saved to {adapters_dir}")