In [1]:
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch
# Fail fast, with an actionable message, when the Hugging Face token is absent.
# NOTE(review): presumably required for Hugging Face Hub downloads — confirm it is
# actually needed for the public checkpoint used below.
assert "HF_TOKEN" in os.environ, (
    "HF_TOKEN is not set; export it in the environment before starting the kernel."
)

# The project root is taken to be the parent of the current working directory
# (assumes the kernel was started from a direct sub-directory of the repo — TODO confirm).
project_root = Path().resolve().parent

# Make the `src` package importable. The membership check keeps the cell
# idempotent: re-running it no longer appends duplicate entries to sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))


from src.models.hubert_wrapper import HubertProtoNet
from src.datasets.registry import DATASET_REGISTRY
from src.methods.fewshot.fewshot import create_episode, evaluate_fewshot
from src.utils.config import load_yaml

In [2]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if has_cuda else torch.device("cpu")
print(device)

cpu


In [6]:
# Read the Speech Commands dataset config and build the dataset through the
# registry entry selected by the config's "name" field.
dataset_cfg_path = project_root / "configs/datasets/speech_commands.yaml"
cfg = load_yaml(dataset_cfg_path)
dataset_factory = DATASET_REGISTRY[cfg["name"]]
dataset = dataset_factory(cfg)

# Report how many classes were loaded and how many items each one holds.
print(f"{len(dataset)} classes loaded")
for label in dataset:
    class_items = dataset[label]
    print(label, len(class_items))


Loading Speech Commands: 100%|██████████| 105829/105829 [02:02<00:00, 866.60it/s] 

6 classes loaded
down 3917
go 3880
left 3801
right 3778
stop 3872
up 3723





Sanity check: inspect the type and length of the first sample stored for each class.

In [7]:
# For every class, show the Python type and the length of its first stored item.
for label in dataset:
    first_item = dataset[label][0]
    print(label, type(first_item), len(first_item))


down <class 'numpy.ndarray'> 16000
go <class 'numpy.ndarray'> 11889
left <class 'numpy.ndarray'> 16000
right <class 'numpy.ndarray'> 14336
stop <class 'numpy.ndarray'> 13375
up <class 'numpy.ndarray'> 14861


In [8]:
# Build the prototypical-network wrapper around the pretrained HuBERT base
# checkpoint, requesting a frozen backbone, then move it to the selected device.
model = HubertProtoNet(hubert_model="facebook/hubert-base-ls960", freeze_hubert=True)
model = model.to(device)

# Switch to evaluation mode; as the cell's last expression this also displays the module.
model.eval()


HubertProtoNet(
  (hubert): HubertModel(
    (feature_extractor): HubertFeatureEncoder(
      (conv_layers): ModuleList(
        (0): HubertGroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x HubertNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x HubertNoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): HubertFeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): HubertEncode

In [9]:
# Episode settings, named so they can be tuned in one place and so the printed
# label below always matches the shot count actually evaluated.
SEED = 0
N_WAY = 5
N_SHOT = 1
N_QUERY = 5
N_EPISODES = 50

# Episodes are presumably sampled at random inside evaluate_fewshot (its source
# is not visible here), so seed every RNG it could draw from. Without this the
# reported accuracy changes on each Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# evaluate_fewshot returns a (mean, std) pair; both are scaled by 100 below,
# so they are presumably accuracy fractions in [0, 1] — TODO confirm.
mean, std = evaluate_fewshot(
    dataset,
    model,
    device=device,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_episodes=N_EPISODES,
)

print(f"{N_SHOT}-shot accuracy: {mean*100:.2f}% ± {std*100:.2f}%")


5-way 1-shot:   0%|          | 0/50 [00:00<?, ?it/s]

Creating episode...
Extracting features...
Running model...


5-way 1-shot:   2%|▏         | 1/50 [00:03<02:35,  3.17s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:   4%|▍         | 2/50 [00:06<02:33,  3.20s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:   6%|▌         | 3/50 [00:09<02:29,  3.18s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:   8%|▊         | 4/50 [00:14<03:05,  4.02s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  10%|█         | 5/50 [00:18<02:58,  3.97s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  12%|█▏        | 6/50 [00:24<03:23,  4.62s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  14%|█▍        | 7/50 [00:30<03:38,  5.08s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  16%|█▌        | 8/50 [00:34<03:12,  4.57s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  18%|█▊        | 9/50 [00:39<03:20,  4.90s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  20%|██        | 10/50 [00:43<03:00,  4.51s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  22%|██▏       | 11/50 [00:47<02:50,  4.38s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  24%|██▍       | 12/50 [00:54<03:22,  5.32s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  26%|██▌       | 13/50 [01:00<03:16,  5.32s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  28%|██▊       | 14/50 [01:06<03:16,  5.46s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  30%|███       | 15/50 [01:11<03:06,  5.32s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  32%|███▏      | 16/50 [01:14<02:44,  4.84s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  34%|███▍      | 17/50 [01:18<02:30,  4.56s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  36%|███▌      | 18/50 [01:21<02:13,  4.16s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  38%|███▊      | 19/50 [01:25<02:02,  3.95s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  40%|████      | 20/50 [01:28<01:52,  3.76s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  42%|████▏     | 21/50 [01:31<01:45,  3.63s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  44%|████▍     | 22/50 [01:35<01:37,  3.48s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  46%|████▌     | 23/50 [01:38<01:35,  3.54s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  48%|████▊     | 24/50 [01:42<01:29,  3.44s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  50%|█████     | 25/50 [01:45<01:23,  3.33s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  52%|█████▏    | 26/50 [01:48<01:19,  3.32s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  54%|█████▍    | 27/50 [01:51<01:16,  3.31s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  56%|█████▌    | 28/50 [01:54<01:12,  3.31s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  58%|█████▊    | 29/50 [01:58<01:11,  3.39s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  60%|██████    | 30/50 [02:01<01:07,  3.39s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  62%|██████▏   | 31/50 [02:05<01:06,  3.51s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  64%|██████▍   | 32/50 [02:10<01:12,  4.02s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  66%|██████▌   | 33/50 [02:16<01:15,  4.47s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  68%|██████▊   | 34/50 [02:19<01:06,  4.18s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  70%|███████   | 35/50 [02:24<01:05,  4.34s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  72%|███████▏  | 36/50 [02:28<00:58,  4.16s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  74%|███████▍  | 37/50 [02:32<00:53,  4.11s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  76%|███████▌  | 38/50 [02:36<00:48,  4.02s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  78%|███████▊  | 39/50 [02:40<00:45,  4.10s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  80%|████████  | 40/50 [02:43<00:38,  3.81s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  82%|████████▏ | 41/50 [02:46<00:32,  3.64s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  84%|████████▍ | 42/50 [02:50<00:28,  3.52s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  86%|████████▌ | 43/50 [02:53<00:24,  3.52s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  88%|████████▊ | 44/50 [02:56<00:20,  3.39s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  90%|█████████ | 45/50 [02:59<00:16,  3.29s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  92%|█████████▏| 46/50 [03:03<00:13,  3.30s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  94%|█████████▍| 47/50 [03:06<00:10,  3.47s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  96%|█████████▌| 48/50 [03:10<00:07,  3.54s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot:  98%|█████████▊| 49/50 [03:13<00:03,  3.41s/it]

Computing prototypes and accuracy...
Creating episode...
Extracting features...
Running model...


5-way 1-shot: 100%|██████████| 50/50 [03:17<00:00,  3.95s/it]

Computing prototypes and accuracy...
1-shot accuracy: 47.52% ± 12.46%



