In [None]:
# This notebook has been run on Kaggle. To run it, uncomment the following lines
# %cd /kaggle/working/
# !rm -rf /kaggle/working/Few-shot-learning-with-HuBERT/
# !git clone -b add-prototyphubert-training https://github.com/luckyman94/Few-shot-learning-with-HuBERT.git
# !pip install -U kagglehub

In [None]:
#%cd /kaggle/working/Few-shot-learning-with-HuBERT

/kaggle/working/Few-shot-learning-with-HuBERT


In [None]:
import os
import sys
from pathlib import Path
import torch
from transformers import HubertModel


# ---- Hugging Face token ----
# Uncomment the following line if you want to run the notebook on Google Colab. Make sure to replace "YOUR
#os.environ["HF_TOKEN"] = "YOUR TOKEN HERE"

# Uncomment the following lines if you want to run the notebook on Google Colab
#ROOT = Path.cwd() / "Few-shot-learning-with-HuBERT"
#sys.path.append(str(ROOT))


assert "HF_TOKEN" in os.environ
project_root = Path().resolve().parent
sys.path.append(str(project_root))


from src.methods.fewshot.train import (
    prototypical_train,
    build_embedding_cache,
    EmbeddingDataset,
)
from src.methods.fewshot.benchmark import benchmark_fewshot_training

from src.datasets.factory import build_dataset
from src.datasets.registry import DATASET_REGISTRY
from src.datasets.split import split_dataset_by_classes

from src.datasets.bird_dog_cat import AnimalAudioDataset
from src.datasets.speech_commands import SpeechCommandsDataset
from src.datasets.timit import TimitDataset
from src.datasets.snoring_dataset import SnoringDataset
from src.datasets.noise_dataset import SyntheticAudioNoiseDataset
from src.datasets.harmonics_dataset import SyntheticAudioHarmonicsDataset
from src.datasets.urban_sound_8k import UrbanDataset
from src.datasets.crema import CremaDDataset


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


## Loading of the model

In [5]:
hubert = HubertModel.from_pretrained(
    "facebook/hubert-base-ls960",
    use_safetensors=True,
).to(device)

for p in hubert.parameters():
    p.requires_grad = False

hubert.eval()

Loading weights:   0%|          | 0/211 [00:00<?, ?it/s]

HubertModel(
  (feature_extractor): HubertFeatureEncoder(
    (conv_layers): ModuleList(
      (0): HubertGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): HubertFeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): HubertEncoder(
    (pos_conv_embed): HubertPositionalConvEmbedding(
      (conv): Para

## Build the trainable head of the HuBERT

In [12]:
head = torch.nn.Sequential(
    torch.nn.Linear(hubert.config.hidden_size, 128),
    torch.nn.LayerNorm(128),
).to(device)

optimizer = torch.optim.Adam(
    head.parameters(),
    lr=1e-3
)

## Training

In [None]:
results = []

SHOTS = [1, 5, 10]

for name, cfg in DATASET_REGISTRY.items():

    if name == "urban":
        n_episodes = 500
        n_query = 5
        batch_size_cache = 32
    else:
        n_episodes = 1000
        n_query = 10
        batch_size_cache = 16

    print(f"\n=== Dataset: {name} ===")

    full_dataset = build_dataset(name)

    train_dataset, test_dataset = split_dataset_by_classes(
        full_dataset,
        train_ratio=0.7,
        seed=42,
    )

    X_train, y_train = build_embedding_cache(
        train_dataset,
        hubert,
        device,
        batch_size=batch_size_cache,
    )

    train_emb_dataset = EmbeddingDataset(
        X_train,
        y_train,
        classes=train_dataset.classes,
    )

    X_test, y_test = build_embedding_cache(
        test_dataset,
        hubert,
        device,
        batch_size=batch_size_cache,
    )

    test_emb_dataset = EmbeddingDataset(
        X_test,
        y_test,
        classes=test_dataset.classes,
    )

    n_way_train = len(train_emb_dataset.classes)
    n_way_test  = len(test_emb_dataset.classes)

    print("Train classes:", train_emb_dataset.classes)
    print("Test classes :", test_emb_dataset.classes)

    for k_shot in SHOTS:

        print(f"\n--- {k_shot}-shot ---")

        min_train = min(
            (y_train == c).sum().item() for c in train_emb_dataset.classes
        )
        min_test = min(
            (y_test == c).sum().item() for c in test_emb_dataset.classes
        )

        if min_train < k_shot + n_query or min_test < k_shot + n_query:
            print("⚠️ Not enough samples for this k-shot, skipping")
            continue

        head.apply(
            lambda m: isinstance(m, torch.nn.Linear) and m.reset_parameters()
        )

        optimizer = torch.optim.Adam(head.parameters(), lr=5e-4)

        # Train
        train_metrics = prototypical_train(
            dataset=train_emb_dataset,
            head=head,
            optimizer=optimizer,
            device=device,
            n_way=n_way_train,
            k_shot=k_shot,
            n_query=n_query,
            n_episodes=n_episodes,
            episodes_per_batch=8,
        )

        # Test
        test_metrics = benchmark_fewshot_training(
            dataset=test_emb_dataset,
            head=head,
            device=device,
            n_tasks=200,
            n_way=n_way_test,
            k_shot=k_shot,
            n_query=n_query,
        )

        
        results.append({
            "dataset": name,
            "k_shot": k_shot,
            "n_train_classes": n_way_train,
            "n_test_classes": n_way_test,

            "train_loss": train_metrics["train_loss_mean"],
            "train_loss_std": train_metrics["train_loss_std"],

            "test_accuracy": test_metrics["accuracy_mean"],
            "test_accuracy_std": test_metrics["accuracy_std"],
            "test_f1": test_metrics["f1_macro"],
        })


=== Dataset: animals ===


Caching HuBERT embeddings: 100%|██████████| 25/25 [00:01<00:00, 19.49it/s]
Caching HuBERT embeddings: 100%|██████████| 14/14 [00:01<00:00, 12.15it/s]


Train classes: [0, 1]
Test classes : [2]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 204.16it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 501.39it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 218.24it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 491.50it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 213.34it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 488.93it/s]



=== Dataset: speech_commands ===
[INFO] Building balanced dataset: 142 samples × 7 classes = 994
[INFO] Final dataset: 994 samples (142 per class)


Caching HuBERT embeddings: 100%|██████████| 36/36 [00:02<00:00, 16.24it/s]
Caching HuBERT embeddings: 100%|██████████| 27/27 [00:01<00:00, 20.55it/s]


Train classes: [0, 1, 2, 5]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:06<00:00, 155.95it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 281.08it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:06<00:00, 156.11it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 271.25it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:06<00:00, 152.88it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 268.58it/s]



=== Dataset: synthetic_noise_low ===
[INFO] SyntheticAudioNoiseDataset | 8 classes | 200 samples | SNR = 30 dB


Caching HuBERT embeddings: 100%|██████████| 8/8 [00:00<00:00, 17.00it/s]
Caching HuBERT embeddings: 100%|██████████| 5/5 [00:00<00:00, 14.82it/s]


Train classes: [0, 1, 2, 5, 7]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 217.73it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 540.21it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 216.40it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 520.27it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 208.33it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 493.17it/s]



=== Dataset: synthetic_noise_medium ===
[INFO] SyntheticAudioNoiseDataset | 8 classes | 200 samples | SNR = 10 dB


Caching HuBERT embeddings: 100%|██████████| 8/8 [00:00<00:00, 17.82it/s]
Caching HuBERT embeddings: 100%|██████████| 5/5 [00:00<00:00, 15.16it/s]


Train classes: [0, 1, 2, 5, 7]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 219.36it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 538.63it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 216.49it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 516.10it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 205.89it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 483.86it/s]



=== Dataset: synthetic_noise_high ===
[INFO] SyntheticAudioNoiseDataset | 8 classes | 200 samples | SNR = 0 dB


Caching HuBERT embeddings: 100%|██████████| 8/8 [00:00<00:00, 17.46it/s]
Caching HuBERT embeddings: 100%|██████████| 5/5 [00:00<00:00, 15.12it/s]


Train classes: [0, 1, 2, 5, 7]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 227.47it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 539.91it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 216.23it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 512.37it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 209.79it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 482.99it/s]



=== Dataset: synthetic_harmonics_low ===
[INFO] SyntheticAudioHarmonicsDataset | 8 classes | 200 samples | Max harmonics=2


Caching HuBERT embeddings: 100%|██████████| 8/8 [00:00<00:00, 16.91it/s]
Caching HuBERT embeddings: 100%|██████████| 5/5 [00:00<00:00, 14.92it/s]


Train classes: [0, 1, 2, 5, 7]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 223.22it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 538.83it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 218.97it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 514.71it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 209.93it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 499.39it/s]



=== Dataset: synthetic_harmonics_medium ===
[INFO] SyntheticAudioHarmonicsDataset | 8 classes | 200 samples | Max harmonics=4


Caching HuBERT embeddings: 100%|██████████| 8/8 [00:00<00:00, 17.60it/s]
Caching HuBERT embeddings: 100%|██████████| 5/5 [00:00<00:00, 14.97it/s]


Train classes: [0, 1, 2, 5, 7]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 222.51it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 541.57it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 210.50it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 518.12it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 209.53it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 490.31it/s]



=== Dataset: synthetic_harmonics_high ===
[INFO] SyntheticAudioHarmonicsDataset | 8 classes | 200 samples | Max harmonics=8


Caching HuBERT embeddings: 100%|██████████| 8/8 [00:00<00:00, 17.24it/s]
Caching HuBERT embeddings: 100%|██████████| 5/5 [00:00<00:00, 14.32it/s]


Train classes: [0, 1, 2, 5, 7]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 226.51it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 543.29it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 218.77it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 464.94it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 204.92it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 491.36it/s]



=== Dataset: urban ===


Caching HuBERT embeddings: 100%|██████████| 188/188 [00:41<00:00,  4.50it/s]
Caching HuBERT embeddings: 100%|██████████| 86/86 [00:19<00:00,  4.46it/s]


Train classes: [0, 1, 2, 5, 7, 8, 9]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 500/500 [00:16<00:00, 29.69it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:02<00:00, 68.39it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 500/500 [00:16<00:00, 29.57it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:02<00:00, 68.49it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 500/500 [00:17<00:00, 29.28it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:02<00:00, 67.83it/s]



=== Dataset: crema ===


Caching HuBERT embeddings: 100%|██████████| 318/318 [00:27<00:00, 11.75it/s]
Caching HuBERT embeddings: 100%|██████████| 148/148 [00:12<00:00, 12.12it/s]


Train classes: [0, 1, 2, 5]
Test classes : [3, 4]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:28<00:00, 35.24it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:02<00:00, 79.31it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:28<00:00, 35.27it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:02<00:00, 78.69it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:28<00:00, 34.96it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:02<00:00, 78.22it/s]



=== Dataset: timit ===
[INFO] Total speakers in TIMIT: 462
[INFO] TIMIT Dataset loaded: 200 files | 10 speakers


Caching HuBERT embeddings: 100%|██████████| 9/9 [00:00<00:00, 11.58it/s]
Caching HuBERT embeddings: 100%|██████████| 4/4 [00:00<00:00,  8.51it/s]


Train classes: [0, 1, 2, 5, 7, 8, 9]
Test classes : [3, 4, 6]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:05<00:00, 187.93it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 550.98it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:05<00:00, 178.35it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 532.57it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:05<00:00, 172.51it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 499.71it/s]



=== Dataset: snoring ===


Caching HuBERT embeddings: 100%|██████████| 32/32 [00:03<00:00,  8.84it/s]
Caching HuBERT embeddings: 100%|██████████| 32/32 [00:04<00:00,  7.64it/s]


Train classes: [1]
Test classes : [0]

--- 1-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 215.58it/s]
1-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 297.89it/s]



--- 5-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 217.92it/s]
5-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 295.05it/s]



--- 10-shot ---


Prototypical training (batched): 100%|██████████| 1000/1000 [00:04<00:00, 215.20it/s]
10-shot benchmark: 100%|██████████| 200/200 [00:00<00:00, 290.31it/s]


## Results

In [20]:
import pandas as pd

df = pd.DataFrame(results)
df = df.round(3)   

display(df)

Unnamed: 0,dataset,k_shot,n_train_classes,n_test_classes,train_loss,train_loss_std,test_accuracy,test_accuracy_std,test_f1
0,animals,1,2,1,0.002,0.011,1.0,0.0,1.0
1,animals,5,2,1,0.001,0.004,1.0,0.0,1.0
2,animals,10,2,1,0.001,0.003,1.0,0.0,1.0
3,speech_commands,1,4,3,0.006,0.02,0.87,0.1,0.869
4,speech_commands,5,4,3,0.003,0.012,0.931,0.045,0.932
5,speech_commands,10,4,3,0.003,0.012,0.937,0.044,0.937
6,synthetic_noise_low,1,5,3,0.0,0.0,0.989,0.022,0.989
7,synthetic_noise_low,5,5,3,0.0,0.0,0.989,0.017,0.989
8,synthetic_noise_low,10,5,3,0.0,0.0,0.989,0.016,0.989
9,synthetic_noise_medium,1,5,3,0.0,0.001,0.953,0.054,0.953
