In [1]:
import json
import pandas as pd
from datetime import datetime

import evaluate
import numpy as np
import torch
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
speech_data = load_dataset(
    "speech_commands", "v0.02"
)

In [3]:
speech_data

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 84848
    })
    validation: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 9982
    })
    test: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 4890
    })
})

In [4]:
speech_data = speech_data.filter(lambda x: len(x["audio"]["array"]) / 16_000 < 10)

In [5]:
speech_data

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 84843
    })
    validation: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 9981
    })
    test: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 4890
    })
})

In [6]:
subset_of_labels = ["up", "down", "left", "right", "on", "off", "yes", "no"]

In [7]:
speech_data = speech_data.filter(lambda x: x['file'].split("/")[0] in subset_of_labels)

In [8]:
speech_data

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 24552
    })
    validation: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 2981
    })
    test: Dataset({
        features: ['file', 'audio', 'label', 'is_unknown', 'speaker_id', 'utterance_id'],
        num_rows: 3261
    })
})

In [9]:
train = speech_data["train"]
validation = speech_data["validation"]
test = speech_data["test"]

In [10]:
SEED = 1
train = train.shuffle(seed=SEED)
validation = validation.shuffle(seed=SEED)
test = test.shuffle(seed=SEED)

In [11]:
def extract_fields(example):
    x = example["audio"]["array"]
    return {"label": example["label"], "array": np.pad(x, (0, 16000 - len(x)), constant_values=0)}

In [12]:
train = train.map(extract_fields)
validation = validation.map(extract_fields)
test = test.map(extract_fields)

In [13]:
train = train.map(remove_columns=["file", "audio", "speaker_id", "utterance_id"])
validation = validation.map(remove_columns=["file", "audio", "speaker_id", "utterance_id"])
test = test.map(remove_columns=["file", "audio", "speaker_id", "utterance_id"])

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cuda:3'
print(device)

cuda:3


In [15]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import numpy as np

from torch.autograd import Variable

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        out, _ = self.lstm(x, (h0, c0))

        out = self.fc(out[:, -1, :])
        return out


from datasets import load_metric

accuracy = load_metric("accuracy")

NUM_EPOCHS = 100
BATCH_SIZE = 256

torch.cuda.empty_cache()
model = LSTMModel(16_000, 64, 2, len(subset_of_labels))

model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer, T_max=NUM_EPOCHS, eta_min=0
)
best_val_loss = float("inf")

  accuracy = load_metric("accuracy")


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [16]:
def calculate_accuracy(preds, y):
    temp = torch.nn.functional.softmax(preds, dim=1)
    temp = torch.argmax(temp, dim=1)
    return (torch.sum(temp == y) / len(y)).item()

In [17]:
from tqdm import tqdm

results = pd.DataFrame(columns=["epoch", "train_loss", "train_accuracy", "val_loss", "val_accuracy"])

train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
for epoch in tqdm(range(1, NUM_EPOCHS+1)):
    model.train()
    train_loss = 0
    train_accuracy = 0
    for batch in train_loader:
        x = batch["array"]
        x = torch.stack(x).to(device)
        x = x.unsqueeze(1)
        x = x.permute(2, 1, 0)
        y = batch["label"].to(device)
        y_pred = model(x.float())
        loss = criterion(y_pred, y)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_accuracy += calculate_accuracy(y_pred, y)
    train_loss /= len(train_loader)
    train_accuracy = train_accuracy / len(train_loader)
    print(f"Epoch {epoch} train loss: {train_loss}, train accuracy: {train_accuracy}")

    model.eval()
    with torch.no_grad():
        val_loss = 0
        val_accuracy = 0
        for batch in DataLoader(validation, batch_size=BATCH_SIZE):
            x = batch["array"]
            x = torch.stack(x).to(device)
            x = x.unsqueeze(1)
            x = x.permute(2, 1, 0)
            y = batch["label"].to(device)
            y_pred = model(x.float())
            val_loss += criterion(y_pred, y).item()
            val_accuracy += calculate_accuracy(y_pred, y)
        val_loss /= len(validation)
        val_accuracy = val_accuracy / len(validation)
        print(f"Epoch {epoch} val loss: {val_loss}, val accuracy: {val_accuracy}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_model.pth")
        row = pd.DataFrame(
            {
                "epoch": [epoch],
                "train_loss": [loss.item()],
                "train_accuracy": [train_accuracy],
                "val_loss": [val_loss],
                "val_accuracy": [val_accuracy],
            }
        )
        results = pd.merge(results, row, how="outer")
    # break

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1 train loss: 2.0787476648887, train accuracy: 0.13816535333171487


  1%|          | 1/100 [04:34<7:33:05, 274.60s/it]

Epoch 1 val loss: 0.008367870206682525, val accuracy: 0.0005528860816376515
Epoch 2 train loss: 2.0660356308023133, train accuracy: 0.20729924350356063


  2%|▏         | 2/100 [09:06<7:25:43, 272.89s/it]

Epoch 2 val loss: 0.008363358249043667, val accuracy: 0.0005937779521590229
Epoch 3 train loss: 1.9346177950501442, train accuracy: 0.3113915454596281


  3%|▎         | 3/100 [13:36<7:19:12, 271.68s/it]

Epoch 3 val loss: 0.00844027419411788, val accuracy: 0.0006077394804191525
Epoch 4 train loss: 1.6637613487740357, train accuracy: 0.461559076483051


  4%|▍         | 4/100 [18:06<7:13:25, 270.89s/it]

Epoch 4 val loss: 0.008768760930048223, val accuracy: 0.0005952233447999772
Epoch 5 train loss: 1.381819939861695, train accuracy: 0.5245571794609228


  5%|▌         | 5/100 [22:36<7:08:43, 270.77s/it]

Epoch 5 val loss: 0.009320036751517751, val accuracy: 0.0006604247977372905
Epoch 6 train loss: 1.0208840544025104, train accuracy: 0.6846404677877823


  6%|▌         | 6/100 [27:07<7:04:03, 270.67s/it]

Epoch 6 val loss: 0.009906098091614322, val accuracy: 0.0006656663274253147
Epoch 7 train loss: 0.8918028641492128, train accuracy: 0.7081803946445385


  7%|▋         | 7/100 [31:36<6:58:49, 270.21s/it]

Epoch 7 val loss: 0.01086152843563639, val accuracy: 0.0006732586015194544
Epoch 8 train loss: 0.6053905729204416, train accuracy: 0.8251812811940908


  8%|▊         | 8/100 [36:06<6:53:58, 269.99s/it]

Epoch 8 val loss: 0.011377588529308464, val accuracy: 0.0006845993649355719
Epoch 9 train loss: 0.5652038526410857, train accuracy: 0.8219372977813085


  9%|▉         | 9/100 [40:35<6:49:27, 269.98s/it]

Epoch 9 val loss: 0.012860960866012497, val accuracy: 0.0006962578015133903
Epoch 10 train loss: 0.364815104752779, train accuracy: 0.9018807243555784


 10%|█         | 10/100 [45:05<6:44:34, 269.72s/it]

Epoch 10 val loss: 0.013252331162010892, val accuracy: 0.0007227354653926953
Epoch 11 train loss: 0.3606899653871854, train accuracy: 0.8937595412135124


 11%|█         | 11/100 [49:35<6:40:14, 269.83s/it]

Epoch 11 val loss: 0.014903972069715662, val accuracy: 0.0006965278169525746
Epoch 12 train loss: 0.22830977647875747, train accuracy: 0.9441549368202686


 12%|█▏        | 12/100 [54:05<6:35:49, 269.88s/it]

Epoch 12 val loss: 0.015192518295037271, val accuracy: 0.0006986959034146498
Epoch 13 train loss: 0.23059594770893455, train accuracy: 0.9380836927642425


 13%|█▎        | 13/100 [58:35<6:31:38, 270.10s/it]

Epoch 13 val loss: 0.017068658647662, val accuracy: 0.0006886655224205383
Epoch 14 train loss: 0.14729188506801924, train accuracy: 0.9686349450300137


 14%|█▍        | 14/100 [1:03:05<6:26:56, 269.96s/it]

Epoch 14 val loss: 0.0173314273537345, val accuracy: 0.0006950824243116203
Epoch 15 train loss: 0.14714672796738645, train accuracy: 0.965343256170551


 15%|█▌        | 15/100 [1:07:36<6:22:54, 270.29s/it]

Epoch 15 val loss: 0.019029915472758133, val accuracy: 0.0006895708791844119
Epoch 16 train loss: 0.0978631266237547, train accuracy: 0.9818311668932438


 16%|█▌        | 16/100 [1:12:08<6:18:59, 270.71s/it]

Epoch 16 val loss: 0.01928243553906389, val accuracy: 0.0006693274542534363
Epoch 17 train loss: 0.094545964927723, train accuracy: 0.9819364001353582


 17%|█▋        | 17/100 [1:16:37<6:14:04, 270.42s/it]

Epoch 17 val loss: 0.020811576324995234, val accuracy: 0.0006900235575663488
Epoch 18 train loss: 0.06336027432310705, train accuracy: 0.9905472677201033


 18%|█▊        | 18/100 [1:21:07<6:09:08, 270.10s/it]

Epoch 18 val loss: 0.021195266759463583, val accuracy: 0.0006940897150513151
Epoch 19 train loss: 0.061624586795611926, train accuracy: 0.9905879578242699


 19%|█▉        | 19/100 [1:25:36<6:04:26, 269.96s/it]

Epoch 19 val loss: 0.022380305952131647, val accuracy: 0.0006803982022303696
Epoch 20 train loss: 0.045839407083500795, train accuracy: 0.99462890625


 20%|██        | 20/100 [1:30:06<5:59:38, 269.73s/it]

Epoch 20 val loss: 0.022759989738784275, val accuracy: 0.0006817085846523757
Epoch 21 train loss: 0.03869080950971693, train accuracy: 0.9958860899011294


 21%|██        | 21/100 [1:34:35<5:55:07, 269.72s/it]

Epoch 21 val loss: 0.023778872056440872, val accuracy: 0.0007014993312014145
Epoch 22 train loss: 0.03279817634029314, train accuracy: 0.9963743711511294


 22%|██▏       | 22/100 [1:39:05<5:50:33, 269.66s/it]

Epoch 22 val loss: 0.02417895714735193, val accuracy: 0.0006836066556752668
Epoch 23 train loss: 0.026487521584688995, train accuracy: 0.9973958333333334


 23%|██▎       | 23/100 [1:43:34<5:46:02, 269.64s/it]

Epoch 23 val loss: 0.024921001343469577, val accuracy: 0.0006894358689654637
Epoch 24 train loss: 0.022595250299976517, train accuracy: 0.9975992838541666


 24%|██▍       | 24/100 [1:48:07<5:42:28, 270.38s/it]

Epoch 24 val loss: 0.02544753018020584, val accuracy: 0.000688848185363291
Epoch 25 train loss: 0.017352224337325122, train accuracy: 0.9987386067708334


 25%|██▌       | 25/100 [1:52:37<5:37:55, 270.34s/it]

Epoch 25 val loss: 0.026027688401051713, val accuracy: 0.0007073285444916114
Epoch 26 train loss: 0.017090587362569448, train accuracy: 0.9980875651041666


 26%|██▌       | 26/100 [1:57:08<5:33:45, 270.61s/it]

Epoch 26 val loss: 0.026543256861377346, val accuracy: 0.0006895708791844119
Epoch 27 train loss: 0.01308753838626823, train accuracy: 0.9988199869791666


 27%|██▋       | 27/100 [2:01:39<5:29:25, 270.76s/it]

Epoch 27 val loss: 0.026960029653237596, val accuracy: 0.0007041200960454265
Epoch 28 train loss: 0.018225314963880617, train accuracy: 0.9970296223958334


 28%|██▊       | 28/100 [2:06:09<5:24:42, 270.59s/it]

Epoch 28 val loss: 0.0274172485054679, val accuracy: 0.0007067408608894387
Epoch 29 train loss: 0.010831462442486858, train accuracy: 0.9988606770833334


 29%|██▉       | 29/100 [2:10:40<5:20:07, 270.53s/it]

Epoch 29 val loss: 0.027749620160573443, val accuracy: 0.0007147381606417109
Epoch 30 train loss: 0.01528173558714722, train accuracy: 0.9977213541666666


 30%|███       | 30/100 [2:15:11<5:16:03, 270.90s/it]

Epoch 30 val loss: 0.028321418775003736, val accuracy: 0.0006792228300273118
Epoch 31 train loss: 0.00964407212450169, train accuracy: 0.9988606770833334


 31%|███       | 31/100 [2:19:42<5:11:16, 270.68s/it]

Epoch 31 val loss: 0.028578630272712693, val accuracy: 0.0006699151428543213
Epoch 32 train loss: 0.01687245453649666, train accuracy: 0.9972695534427961


 32%|███▏      | 32/100 [2:24:12<5:06:40, 270.59s/it]

Epoch 32 val loss: 0.029139237850154657, val accuracy: 0.0006828839568554336
Epoch 33 train loss: 0.011911250447155908, train accuracy: 0.9982334884504477


 33%|███▎      | 33/100 [2:28:43<5:02:16, 270.69s/it]

Epoch 33 val loss: 0.029223881484437653, val accuracy: 0.0006731235962992185
Epoch 34 train loss: 0.02068909621448256, train accuracy: 0.9957191205273072


 34%|███▍      | 34/100 [2:33:13<4:57:29, 270.45s/it]

Epoch 34 val loss: 0.029703619495329143, val accuracy: 0.0006855047216994456
Epoch 35 train loss: 0.018143160079489462, train accuracy: 0.9965006510416666


 35%|███▌      | 35/100 [2:37:43<4:52:53, 270.36s/it]

Epoch 35 val loss: 0.02962800473504641, val accuracy: 0.0006849170380972728
Epoch 36 train loss: 0.022246064399951138, train accuracy: 0.9956784304231405


 36%|███▌      | 36/100 [2:42:13<4:48:08, 270.13s/it]

Epoch 36 val loss: 0.030373254665309494, val accuracy: 0.0006739813003392877
Epoch 37 train loss: 0.02417115735685608, train accuracy: 0.9947341394921144


 37%|███▋      | 37/100 [2:46:42<4:43:22, 269.89s/it]

Epoch 37 val loss: 0.030243091077782173, val accuracy: 0.0006830189670743817
Epoch 38 train loss: 0.02361199142372546, train accuracy: 0.995372553045551


 38%|███▊      | 38/100 [2:51:11<4:38:40, 269.69s/it]

Epoch 38 val loss: 0.030910008867808293, val accuracy: 0.0006725359076983334
Epoch 39 train loss: 0.025888973197046045, train accuracy: 0.9941364154219627


 39%|███▉      | 39/100 [2:55:41<4:34:05, 269.60s/it]

Epoch 39 val loss: 0.030832416217065586, val accuracy: 0.0006793578352475478
Epoch 40 train loss: 0.011472567542417286, train accuracy: 0.9983274961511294


 40%|████      | 40/100 [3:00:11<4:29:46, 269.77s/it]

Epoch 40 val loss: 0.031326977609188435, val accuracy: 0.0006771897487854726
Epoch 41 train loss: 0.008300548084662296, train accuracy: 0.9988199869791666


 41%|████      | 41/100 [3:04:40<4:25:08, 269.64s/it]

Epoch 41 val loss: 0.03142430789516261, val accuracy: 0.0006725359076983334
Epoch 42 train loss: 0.006369074624672066, train accuracy: 0.9994303385416666


 42%|████▏     | 42/100 [3:09:10<4:20:41, 269.69s/it]

Epoch 42 val loss: 0.03182098282459237, val accuracy: 0.0006824312784734967
Epoch 43 train loss: 0.004402965827466687, train accuracy: 0.99951171875


 43%|████▎     | 43/100 [3:13:39<4:16:06, 269.58s/it]

Epoch 43 val loss: 0.031963706216489016, val accuracy: 0.000681843594871324
Epoch 44 train loss: 0.004093705266617083, train accuracy: 0.9996337890625


 44%|████▍     | 44/100 [3:18:09<4:11:43, 269.70s/it]

Epoch 44 val loss: 0.032320102082527315, val accuracy: 0.0006883955069813541
Epoch 45 train loss: 0.003775933545208924, train accuracy: 0.9995524088541666


 45%|████▌     | 45/100 [3:22:39<4:07:23, 269.89s/it]

Epoch 45 val loss: 0.03247049005988939, val accuracy: 0.000685774742137342
Epoch 46 train loss: 0.003997379310021643, train accuracy: 0.9996337890625


 46%|████▌     | 46/100 [3:27:09<4:02:53, 269.88s/it]

Epoch 46 val loss: 0.03281387311826013, val accuracy: 0.000684464359715336
Epoch 47 train loss: 0.00345980384251258, train accuracy: 0.9995930989583334


 47%|████▋     | 47/100 [3:31:39<3:58:21, 269.84s/it]

Epoch 47 val loss: 0.03293458618041214, val accuracy: 0.0006910162718253662
Epoch 48 train loss: 0.0030524964871195457, train accuracy: 0.9996744791666666


 48%|████▊     | 48/100 [3:36:08<3:53:41, 269.63s/it]

Epoch 48 val loss: 0.0331967024945525, val accuracy: 0.0006805332124493179
Epoch 49 train loss: 0.003359771246323362, train accuracy: 0.9994303385416666


 49%|████▉     | 49/100 [3:40:37<3:48:55, 269.32s/it]

Epoch 49 val loss: 0.03343164644877969, val accuracy: 0.0006792228300273118
Epoch 50 train loss: 0.0030360104634989207, train accuracy: 0.9995930989583334


 50%|█████     | 50/100 [3:45:07<3:44:43, 269.66s/it]

Epoch 50 val loss: 0.03361895979350543, val accuracy: 0.000685774742137342
Epoch 51 train loss: 0.002761133214031967, train accuracy: 0.99951171875


 51%|█████     | 51/100 [3:49:38<3:40:22, 269.84s/it]

Epoch 51 val loss: 0.03387122136640373, val accuracy: 0.0006792228300273118
Epoch 52 train loss: 0.002640938300222236, train accuracy: 0.9995524088541666


 52%|█████▏    | 52/100 [3:54:07<3:35:54, 269.88s/it]

Epoch 52 val loss: 0.034056427414006975, val accuracy: 0.0006726709179172817
Epoch 53 train loss: 0.0027280685941756624, train accuracy: 0.9995930989583334


 53%|█████▎    | 53/100 [3:58:38<3:31:28, 269.97s/it]

Epoch 53 val loss: 0.0343045319778253, val accuracy: 0.0006792228300273118
Epoch 54 train loss: 0.002291918997798348, train accuracy: 0.9996744791666666


 54%|█████▍    | 54/100 [4:03:08<3:26:58, 269.97s/it]

Epoch 54 val loss: 0.034449078739989085, val accuracy: 0.0006799455238484329
Epoch 55 train loss: 0.005268033874623749, train accuracy: 0.9989378477136294


 55%|█████▌    | 55/100 [4:07:39<3:22:42, 270.28s/it]

Epoch 55 val loss: 0.03475458227680338, val accuracy: 0.0006806682176695539
Epoch 56 train loss: 0.003768288649249977, train accuracy: 0.9993040586511294


 56%|█████▌    | 56/100 [4:12:08<3:18:04, 270.11s/it]

Epoch 56 val loss: 0.034852163218204224, val accuracy: 0.0006792228300273118
Epoch 57 train loss: 0.00940329623396489, train accuracy: 0.9975992838541666


 57%|█████▋    | 57/100 [4:16:39<3:13:35, 270.14s/it]

Epoch 57 val loss: 0.03498948208240405, val accuracy: 0.0006688747758714995
Epoch 58 train loss: 0.004122256999835372, train accuracy: 0.9991861979166666


 58%|█████▊    | 58/100 [4:21:08<3:08:59, 269.99s/it]

Epoch 58 val loss: 0.035041812921042255, val accuracy: 0.0006754266879815297
Epoch 59 train loss: 0.005514287663269594, train accuracy: 0.9988157774011294


 59%|█████▉    | 59/100 [4:25:38<3:04:28, 269.97s/it]

Epoch 59 val loss: 0.03544827172996933, val accuracy: 0.0006726709179172817
Epoch 60 train loss: 0.004003303204929883, train accuracy: 0.9991819883386294


 60%|██████    | 60/100 [4:30:08<3:00:01, 270.03s/it]

Epoch 60 val loss: 0.03550691454252994, val accuracy: 0.0006694624644723845
Epoch 61 train loss: 0.00415193357609193, train accuracy: 0.999267578125


 61%|██████    | 61/100 [4:34:38<2:55:28, 269.95s/it]

Epoch 61 val loss: 0.03596621285905297, val accuracy: 0.0006569463288532093
Epoch 62 train loss: 0.004797116671397816, train accuracy: 0.9991455078125


 62%|██████▏   | 62/100 [4:39:08<2:50:56, 269.90s/it]

Epoch 62 val loss: 0.03590946843743284, val accuracy: 0.0006470509530793337
Epoch 63 train loss: 0.0254128664109885, train accuracy: 0.9936102504531542


 63%|██████▎   | 63/100 [4:43:38<2:46:29, 269.97s/it]

Epoch 63 val loss: 0.03656950671974343, val accuracy: 0.0006626405369231701
Epoch 64 train loss: 0.09166296509404977, train accuracy: 0.9730547325064739


 64%|██████▍   | 64/100 [4:48:09<2:42:06, 270.17s/it]

Epoch 64 val loss: 0.03551026919831049, val accuracy: 0.0006752916827612937
Epoch 65 train loss: 0.02840786589270768, train accuracy: 0.992397965863347


 65%|██████▌   | 65/100 [4:52:37<2:37:20, 269.73s/it]

Epoch 65 val loss: 0.03561214695277529, val accuracy: 0.0006640859245654122
Epoch 66 train loss: 0.048981552419718355, train accuracy: 0.9866452272981405


 66%|██████▌   | 66/100 [4:57:07<2:32:54, 269.85s/it]

Epoch 66 val loss: 0.03526096382384251, val accuracy: 0.0006686047604323154
Epoch 67 train loss: 0.012878008060700571, train accuracy: 0.9978434244791666


 67%|██████▋   | 67/100 [5:01:38<2:28:34, 270.14s/it]

Epoch 67 val loss: 0.035702564939481624, val accuracy: 0.0006686047604323154
Epoch 68 train loss: 0.009704232472965183, train accuracy: 0.9982096354166666


 68%|██████▊   | 68/100 [5:06:09<2:24:07, 270.23s/it]

Epoch 68 val loss: 0.035708400556122205, val accuracy: 0.0006739813003392877
Epoch 69 train loss: 0.004689524575951509, train accuracy: 0.9995524088541666


 69%|██████▉   | 69/100 [5:10:39<2:19:34, 270.14s/it]

Epoch 69 val loss: 0.03594475022821129, val accuracy: 0.0006779124476053058
Epoch 70 train loss: 0.002965708226838615, train accuracy: 0.999755859375


 70%|███████   | 70/100 [5:15:08<2:15:01, 270.05s/it]

Epoch 70 val loss: 0.03604662798267609, val accuracy: 0.0006680170718314303
Epoch 71 train loss: 0.002850937773473561, train accuracy: 0.9997109596927961


 71%|███████   | 71/100 [5:19:39<2:10:33, 270.12s/it]

Epoch 71 val loss: 0.0362542042673054, val accuracy: 0.0006686047604323154
Epoch 72 train loss: 0.0026970313665515278, train accuracy: 0.9995930989583334


 72%|███████▏  | 72/100 [5:24:09<2:06:01, 270.07s/it]

Epoch 72 val loss: 0.036341005265812715, val accuracy: 0.0006725359076983334
Epoch 73 train loss: 0.0026266884051437955, train accuracy: 0.9997151692708334


 73%|███████▎  | 73/100 [5:28:38<2:01:26, 269.89s/it]

Epoch 73 val loss: 0.03650949338025512, val accuracy: 0.0006705028314552063
Epoch 74 train loss: 0.0021521355586931654, train accuracy: 0.9996744791666666


 74%|███████▍  | 74/100 [5:33:09<1:57:01, 270.04s/it]

Epoch 74 val loss: 0.03661918960054302, val accuracy: 0.0006752916827612937
Epoch 75 train loss: 0.002066273134914809, train accuracy: 0.999755859375


 75%|███████▌  | 75/100 [5:37:39<1:52:30, 270.02s/it]

Epoch 75 val loss: 0.03677872098240025, val accuracy: 0.0006751566725423455
Epoch 76 train loss: 0.0029855045898633157, train accuracy: 0.99951171875


 76%|███████▌  | 76/100 [5:42:08<1:47:52, 269.70s/it]

Epoch 76 val loss: 0.036924643070255496, val accuracy: 0.0006725359076983334
Epoch 77 train loss: 0.002117586912087669, train accuracy: 0.9997151692708334


 77%|███████▋  | 77/100 [5:46:38<1:43:28, 269.93s/it]

Epoch 77 val loss: 0.037003647037417806, val accuracy: 0.0006751566725423455
Epoch 78 train loss: 0.0030293429463199573, train accuracy: 0.9994303385416666


 78%|███████▊  | 78/100 [5:51:08<1:38:56, 269.82s/it]

Epoch 78 val loss: 0.03712841841407528, val accuracy: 0.0006718132138772124
Epoch 79 train loss: 0.0020766196539625525, train accuracy: 0.9996337890625


 79%|███████▉  | 79/100 [5:55:38<1:34:29, 269.99s/it]

Epoch 79 val loss: 0.0372708300462868, val accuracy: 0.0006699151428543213
Epoch 80 train loss: 0.002333519210878876, train accuracy: 0.9995524088541666


 80%|████████  | 80/100 [6:00:08<1:29:58, 269.91s/it]

Epoch 80 val loss: 0.03744087008092676, val accuracy: 0.0006712255252763274
Epoch 81 train loss: 0.0019832484558719443, train accuracy: 0.9996744791666666


 81%|████████  | 81/100 [6:04:37<1:25:26, 269.82s/it]

Epoch 81 val loss: 0.03752901089427376, val accuracy: 0.0006725359076983334
Epoch 82 train loss: 0.0019392524282011436, train accuracy: 0.9995930989583334


 82%|████████▏ | 82/100 [6:09:08<1:20:59, 269.95s/it]

Epoch 82 val loss: 0.03773094639207084, val accuracy: 0.0006751566725423455
Epoch 83 train loss: 0.001507156403931731, train accuracy: 0.999755859375


 83%|████████▎ | 83/100 [6:13:37<1:16:26, 269.78s/it]

Epoch 83 val loss: 0.03779505755588617, val accuracy: 0.0006712255252763274
Epoch 84 train loss: 0.002012868617081646, train accuracy: 0.9996337890625


 84%|████████▍ | 84/100 [6:18:06<1:11:55, 269.72s/it]

Epoch 84 val loss: 0.03801826725306826, val accuracy: 0.0006751566725423455
Epoch 85 train loss: 0.0013137802707205992, train accuracy: 0.9997965494791666


 85%|████████▌ | 85/100 [6:22:37<1:07:27, 269.81s/it]

Epoch 85 val loss: 0.03804256531345249, val accuracy: 0.0006705028314552063
Epoch 86 train loss: 0.0061552376776793, train accuracy: 0.9985758463541666


 86%|████████▌ | 86/100 [6:27:06<1:02:55, 269.66s/it]

Epoch 86 val loss: 0.03802501463482181, val accuracy: 0.0006757443611432306
Epoch 87 train loss: 0.00284158245024931, train accuracy: 0.9993489583333334


 87%|████████▋ | 87/100 [6:31:36<58:26, 269.72s/it]  

Epoch 87 val loss: 0.03805941409290497, val accuracy: 0.0006895708791844119
Epoch 88 train loss: 0.0023280613831957453, train accuracy: 0.9995524088541666


 88%|████████▊ | 88/100 [6:36:05<53:56, 269.71s/it]

Epoch 88 val loss: 0.038348487462424787, val accuracy: 0.0006770547435652366
Epoch 89 train loss: 0.002270266202079559, train accuracy: 0.9994710286458334


 89%|████████▉ | 89/100 [6:40:35<49:25, 269.63s/it]

Epoch 89 val loss: 0.03842313639152294, val accuracy: 0.0006764670549643516
Epoch 90 train loss: 0.002037372736898154, train accuracy: 0.9995930989583334


 90%|█████████ | 90/100 [6:45:04<44:53, 269.38s/it]

Epoch 90 val loss: 0.03864455071122644, val accuracy: 0.0006737112799013913
Epoch 91 train loss: 0.0015413953627406347, train accuracy: 0.9996744791666666


 91%|█████████ | 91/100 [6:49:32<40:20, 268.97s/it]

Epoch 91 val loss: 0.03868032165599005, val accuracy: 0.0006770547435652366
Epoch 92 train loss: 0.0018795455783523114, train accuracy: 0.9997151692708334


 92%|█████████▏| 92/100 [6:54:01<35:51, 268.95s/it]

Epoch 92 val loss: 0.038896935932270754, val accuracy: 0.0006757443611432306
Epoch 93 train loss: 0.0018030191498231336, train accuracy: 0.9996744791666666


 93%|█████████▎| 93/100 [6:58:30<31:24, 269.16s/it]

Epoch 93 val loss: 0.03892808694721108, val accuracy: 0.0006796755084092486
Epoch 94 train loss: 0.0018138640916731674, train accuracy: 0.9996337890625


 94%|█████████▍| 94/100 [7:03:00<26:55, 269.21s/it]

Epoch 94 val loss: 0.03907679039853725, val accuracy: 0.0006836066556752668
Epoch 95 train loss: 0.005388520754725808, train accuracy: 0.998779296875


 95%|█████████▌| 95/100 [7:07:29<22:26, 269.21s/it]

Epoch 95 val loss: 0.039040443922040924, val accuracy: 0.0006822962732532607
Epoch 96 train loss: 0.016080084934704548, train accuracy: 0.9961133934557438


 96%|█████████▌| 96/100 [7:11:58<17:56, 269.10s/it]

Epoch 96 val loss: 0.038715381743883136, val accuracy: 0.0006652613017671822
Epoch 97 train loss: 0.1556118214405918, train accuracy: 0.957408685858051


 97%|█████████▋| 97/100 [7:16:27<13:27, 269.30s/it]

Epoch 97 val loss: 0.03872027200406448, val accuracy: 0.0006530151815871911
Epoch 98 train loss: 0.05457333811015511, train accuracy: 0.9849923669050137


 98%|█████████▊| 98/100 [7:20:56<08:58, 269.20s/it]

Epoch 98 val loss: 0.03823487065215112, val accuracy: 0.0007016343364216504
Epoch 99 train loss: 0.038398023849974074, train accuracy: 0.9892858744909366


 99%|█████████▉| 99/100 [7:25:25<04:29, 269.08s/it]

Epoch 99 val loss: 0.038275171309979476, val accuracy: 0.0006954000974733212
Epoch 100 train loss: 0.011081159510164676, train accuracy: 0.9982503255208334


100%|██████████| 100/100 [7:29:55<00:00, 269.95s/it]

Epoch 100 val loss: 0.03846991122148854, val accuracy: 0.0007016343364216504





In [18]:
results

Unnamed: 0,epoch,train_loss,train_accuracy,val_loss,val_accuracy
0,1,2.075794,0.138165,0.008368,0.000553
1,2,2.048750,0.207299,0.008363,0.000594
2,3,1.868140,0.311392,0.008440,0.000608
3,4,1.539145,0.461559,0.008769,0.000595
4,5,1.368787,0.524557,0.009320,0.000660
...,...,...,...,...,...
95,96,0.077500,0.996113,0.038715,0.000665
96,97,0.123601,0.957409,0.038720,0.000653
97,98,0.067961,0.984992,0.038235,0.000702
98,99,0.036201,0.989286,0.038275,0.000695


In [23]:
results['val_accuracy'].max()

0.0007227354653926953

In [20]:
results.to_csv('lstm_only_selected.csv')

In [19]:
loaded_model = LSTMModel(16_000, 64, 2, len(subset_of_labels))
loaded_model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu')))

loaded_model.eval()

with torch.no_grad():
    val_loss = 0
    val_accuracy = 0
    for batch in DataLoader(validation, batch_size=BATCH_SIZE):
        x = batch["array"]
        x = torch.stack(x).to(device)
        x = x.unsqueeze(1)
        x = x.permute(2, 1, 0)
        y = batch["label"].to(device)
        y_pred = model(x.float())
        val_loss += criterion(y_pred, y).item()
        val_accuracy += calculate_accuracy(y_pred, y)
    val_loss /= len(validation)
    val_accuracy = val_accuracy / len(validation)
print(val_loss)
print(val_accuracy)

0.03846991122148854
0.0007016343364216504
