In [1]:
# %pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
# %pip install pytorch-lightning --q --upgrade
# %pip install torchmetrics --q --upgrade

# %pip install wandb --q
# %pip install einops --q
# %pip install soundfile --q

In [2]:
import torch
import torch.nn.functional as F
from torch import nn, Tensor
import torchaudio, torchvision
from torchvision.transforms import Compose, Resize, ToTensor
import os
import matplotlib.pyplot as plt 
import argparse
import numpy as np
import wandb
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, LightningDataModule, Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torchmetrics.functional import accuracy
from torchaudio.datasets import SPEECHCOMMANDS
from torchaudio.datasets.speechcommands import load_speechcommands_item
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.functional import accuracy

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

## Data Module

Custom dataset classes for unknown speech commands and silence samples

In [3]:
class SilenceDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(SilenceDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35
        path = os.path.join(self._path, torchaudio.datasets.speechcommands.EXCEPT_FOLDER)
        self.paths = [os.path.join(path, p) for p in os.listdir(path) if p.endswith('.wav')]

    def __getitem__(self, index):
        index = np.random.randint(0, len(self.paths))
        filepath = self.paths[index]
        waveform, sample_rate = torchaudio.load(filepath)
        return waveform, sample_rate, "silence", 0, 0

    def __len__(self):
        return self.len

class UnknownDataset(SPEECHCOMMANDS):
    def __init__(self, root):
        super(UnknownDataset, self).__init__(root, subset='training')
        self.len = len(self._walker) // 35

    def __getitem__(self, index):
        index = np.random.randint(0, len(self._walker))
        fileid = self._walker[index]
        waveform, sample_rate, _, speaker_id, utterance_number = load_speechcommands_item(fileid, self._path)
        return waveform, sample_rate, "unknown", speaker_id, utterance_number

    def __len__(self):
        return self.len

KWS DataModule Handles transformation of waveform to MEL spectrum and the turning the "image" into patches

In [4]:
class KWSDataModule(LightningDataModule):
    def __init__(self, path, batch_size=128, num_workers=0, n_fft=512, 
                 n_mels=128, win_length=None, hop_length=256, patch_num=16, class_dict={}, 
                 **kwargs):
        super().__init__(**kwargs)
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.win_length = win_length
        self.hop_length = hop_length
        self.class_dict = class_dict
        self.patch_num = patch_num

    def prepare_data(self):
        self.train_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                                download=True,
                                                                subset='training')

        silence_dataset = SilenceDataset(self.path)
        unknown_dataset = UnknownDataset(self.path)
        self.train_dataset = torch.utils.data.ConcatDataset([self.train_dataset, silence_dataset, unknown_dataset])
                                                                
        self.val_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                              download=True,
                                                              subset='validation')
        self.test_dataset = torchaudio.datasets.SPEECHCOMMANDS(self.path,
                                                               download=True,
                                                               subset='testing')                                                    
        _, sample_rate, _, _, _ = self.train_dataset[0]
        self.sample_rate = sample_rate
        self.transform = torchvision.transforms.Compose([
            torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,
                                                              n_fft=self.n_fft,
                                                              win_length=self.win_length,
                                                              hop_length=self.hop_length,
                                                              n_mels=self.n_mels,
                                                              power=2.0),
            torchaudio.transforms.AmplitudeToDB(),
            torchvision.transforms.Resize((128,128))
        ])

    def setup(self, stage=None):
        self.prepare_data()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )
    
    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True,
            collate_fn=self.collate_fn
        )

    def collate_fn(self, batch):
        mels = []
        labels = []
        wavs = []
        for sample in batch:
            waveform, sample_rate, label, speaker_id, utterance_number = sample
            # ensure that all waveforms are 1sec in length; if not pad with zeros
            if waveform.shape[-1] < sample_rate:
                waveform = torch.cat([waveform, torch.zeros((1, sample_rate - waveform.shape[-1]))], dim=-1)
            elif waveform.shape[-1] > sample_rate:
                waveform = waveform[:,:sample_rate]

            # mel from power to db
            mels.append(self.transform(waveform))
            labels.append(torch.tensor(self.class_dict[label]))
            wavs.append(waveform)

        mels = torch.stack(mels)
        # mels = rearrange(mels, 'b c (p1 h) (p2 w) -> b (p1 p2) (c h w)', p1=self.patch_num, p2=self.patch_num)
        labels = torch.stack(labels)
        wavs = torch.stack(wavs)
   
        return mels, labels, wavs

## Defining the Transformer Architecture

In [5]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.cls_token = nn.Parameter(torch.randn(1,1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        # prepend the cls token to the input
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        
    def forward(self, x : Tensor, mask: Tensor = None) -> Tensor:
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) # batch, num_heads, query_len, key_len
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)
            
        scaling = self.emb_size ** (1/2)
        att = F.softmax(energy, dim=-1) / scaling
        att = self.att_drop(att)
        # sum up over the third axis
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out

In [7]:
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        
    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])

class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        super().__init__(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))

def init_weights_vit_timm(module: nn.Module):
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()

## Transformers Lightning Module

In [8]:
class LitTransformer(LightningModule):
    def __init__(self, num_classes=10, lr=0.001, max_epochs=30, depth=12, emb_size=64,
                 head=4, patch_size=192, img_size=64, in_channels=1, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.encoder = TransformerEncoder(depth, emb_size=emb_size, **kwargs)
        self.classifier = ClassificationHead(emb_size, num_classes)

        self.loss = torch.nn.CrossEntropyLoss()
        
        self.reset_parameters()


    def reset_parameters(self):
        init_weights_vit_timm(self)

    def forward(self, x):
        x = self.embedding(x)
        x = self.encoder(x)
        x = self.classifier(x)
        return x
    
    def configure_optimizers(self):
        optimizer = Adam(self.parameters(), lr=self.hparams.lr)
        # this decays the learning rate to 0 after max_epochs using cosine annealing
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams.max_epochs)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        return loss
    

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        acc = accuracy(y_hat, y)
        return {"y_hat": y_hat, "test_loss": loss, "test_acc": acc}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        self.log("test_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", avg_acc*100., on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        return self.test_epoch_end(outputs)

In [9]:
class WandbCallback(Callback):

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # log 10 sample audio predictions from the first batch
        if batch_idx == 0:
            n = 10
            mels, labels, wavs = batch
            preds = outputs["y_hat"]
            preds = torch.argmax(preds, dim=1)

            labels = labels.cpu().numpy()
            preds = preds.cpu().numpy()
            
            wavs = torch.squeeze(wavs, dim=1)
            wavs = [ (wav.cpu().numpy()*32768.0).astype("int16") for wav in wavs]
            
            sample_rate = pl_module.hparams.sample_rate
            idx_to_class = pl_module.hparams.idx_to_class
            
            # log audio samples and predictions as a W&B Table
            columns = ['audio', 'mel', 'ground truth', 'prediction']
            data = [[wandb.Audio(wav, sample_rate=sample_rate), wandb.Image(mel), idx_to_class[label], idx_to_class[pred]] for wav, mel, label, pred in list(
                zip(wavs[:n], mels[:n], labels[:n], preds[:n]))]
            wandb_logger.log_table(
                key='Transformers on KWS using PyTorch Lightning',
                columns=columns,
                data=data)


## Utility Functions

In [10]:
def get_args():
    parser = ArgumentParser(description='PyTorch Transformer')
    
    # where dataset will be stored
    parser.add_argument("--path", type=str, default="data/speech_commands/")

    # 35 keywords + silence + unknown
    parser.add_argument("--num-classes", type=int, default=37)
   
    # mel spectrogram parameters
    parser.add_argument("--n-fft", type=int, default=1024)
    parser.add_argument("--n-mels", type=int, default=128)
    parser.add_argument("--win-length", type=int, default=None)
    parser.add_argument("--hop-length", type=int, default=512)
    
    # model hyperparameters
    parser.add_argument('--depth', type=int, default=12, help='depth')
    parser.add_argument('--embed_dim', type=int, default=64, help='embedding dimension')
    parser.add_argument('--num_heads', type=int, default=4, help='num_heads')
    parser.add_argument('--patch_size', type=int, default=16, help='patch_num')
    
    parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: )')
    parser.add_argument('--max-epochs', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: 0)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.0)')
    
    # 16-bit fp model to reduce the size
    parser.add_argument("--precision", default=16)
    parser.add_argument("--accelerator", default='gpu')
    parser.add_argument("--devices", default=1)
    parser.add_argument("--num-workers", type=int, default=4)
    
    parser.add_argument("--no-wandb", default=False, action='store_true')
    
    args = parser.parse_args("")
    return args

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f'Channel {c+1}')
        if xlim:
            axes[c].set_xlim(xlim)
        if ylim:
            axes[c].set_ylim(ylim)
    figure.suptitle(title)
    plt.show(block=False)

## Trainer

In [16]:

if __name__ == "__main__":

    args = get_args()
    CLASSES = ['silence', 'unknown', 'backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow',
               'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no',
               'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three',
               'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']
    
    # make a dictionary from CLASSES to integers
    CLASS_TO_IDX = {c: i for i, c in enumerate(CLASSES)}

    if not os.path.exists(args.path):
        os.makedirs(args.path, exist_ok=True)

    
    datamodule = KWSDataModule(batch_size=args.batch_size, num_workers=args.num_workers,
                               path=args.path, n_fft=args.n_fft, n_mels=args.n_mels,
                               win_length=args.win_length, hop_length=args.hop_length, class_dict=CLASS_TO_IDX)
    datamodule.setup()
    
    mels, _,_ = iter(datamodule.train_dataloader()).next()
    b, c, w, h = mels.shape
    
    model = LitTransformer(num_classes=args.num_classes, lr=args.lr, max_epochs=args.max_epochs, 
                           depth=args.depth, emb_size=args.embed_dim, head=args.num_heads,
                           patch_size=args.patch_size, img_size=w, in_channels=c)

    # wandb is a great way to debug and visualize this model
    wandb_logger = WandbLogger(project="pl-kws", log_model="all")

    model_checkpoint = ModelCheckpoint(
        dirpath=os.path.join(args.path, "checkpoints"),
        filename="transformers-kws-best-acc",
        save_top_k=1,
        verbose=True,
        monitor='test_acc',
        mode='max',
    )
    idx_to_class = {v: k for k, v in CLASS_TO_IDX.items()}
    trainer = Trainer(accelerator=args.accelerator,
                      devices=args.devices,
                      precision=args.precision,
                      max_epochs=args.max_epochs,
                      logger=wandb_logger if not args.no_wandb else None,
                      callbacks=[model_checkpoint, WandbCallback() if not args.no_wandb else None])
    model.hparams.sample_rate = datamodule.sample_rate
    model.hparams.idx_to_class = idx_to_class
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    wandb.finish()
    #trainer.save_checkpoint('../mnist/checkpoint.ckpt')


[34m[1mwandb[0m: Currently logged in as: [33mkhizon[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params
--------------------------------------------------
0 | embedding  | PatchEmbedding     | 20.7 K
1 | encoder    | TransformerEncoder | 599 K 
2 | classifier | ClassificationHead | 2.5 K 
3 | loss       | CrossEntropyLoss   | 0     
--------------------------------------------------
623 K     Trainable params
0         Non-trainable params
623 K     Total params
1.246     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 2803: 'test_acc' reached 30.21559 (best 30.21559), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 5606: 'test_acc' reached 47.79302 (best 47.79302), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 8409: 'test_acc' reached 59.43475 (best 59.43475), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 11212: 'test_acc' reached 61.39997 (best 61.39997), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 14015: 'test_acc' reached 69.70912 (best 69.70912), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 16818: 'test_acc' reached 70.28210 (best 70.28210), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 19621: 'test_acc' reached 75.51462 (best 75.51462), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 7, global step 22424: 'test_acc' reached 77.96647 (best 77.96647), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 8, global step 25227: 'test_acc' reached 78.57641 (best 78.57641), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 9, global step 28030: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 10, global step 30833: 'test_acc' reached 80.35719 (best 80.35719), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 11, global step 33636: 'test_acc' reached 81.56844 (best 81.56844), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 12, global step 36439: 'test_acc' reached 81.90691 (best 81.90691), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 13, global step 39242: 'test_acc' reached 82.32758 (best 82.32758), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 14, global step 42045: 'test_acc' reached 82.45261 (best 82.45261), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 15, global step 44848: 'test_acc' reached 83.29913 (best 83.29913), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 16, global step 47651: 'test_acc' reached 83.91702 (best 83.91702), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 50454: 'test_acc' reached 84.71243 (best 84.71243), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 18, global step 53257: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 56060: 'test_acc' reached 86.05562 (best 86.05562), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 20, global step 58863: 'test_acc' reached 86.54433 (best 86.54433), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 21, global step 61666: 'test_acc' reached 86.77781 (best 86.77781), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 22, global step 64469: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 67272: 'test_acc' reached 87.24545 (best 87.24545), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 24, global step 70075: 'test_acc' reached 87.30659 (best 87.30659), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 25, global step 72878: 'test_acc' reached 87.37463 (best 87.37463), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 26, global step 75681: 'test_acc' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 27, global step 78484: 'test_acc' reached 87.59290 (best 87.59290), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 81287: 'test_acc' reached 87.66716 (best 87.66716), saving model to '/home/jupyter/Keyword_Spotting_Transformers/data/speech_commands/checkpoints/transformers-kws-best-acc-v3.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 84090: 'test_acc' was not in top 1
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             86.41520690917969
        test_loss           0.5097184181213379
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


VBox(children=(Label(value='194.755 MB of 194.755 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇███
test_acc,▁▃▅▅▆▆▇▇▇▇▇▇▇▇▇▇███████████████
test_loss,█▆▄▄▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████

0,1
epoch,30.0
test_acc,86.41521
test_loss,0.50972
trainer/global_step,84090.0


In [14]:
b, c, w, h = mels.shape
h

128

In [15]:
mels.shape

torch.Size([32, 1, 128, 128])