In [1]:
import os
import argparse
import multiprocessing
from collections import Counter

import torch
import torchaudio
import numpy as np
import pandas as pd

from models import PretrainedModel, Model
from data import SLUDataset, get_SLU_datasets, read_config
from training import Trainer

# Decide once, at startup, which compute device is available to later cells.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [2]:
# Experiment configuration. Judging by the file name and the per-epoch
# frozen/unfrozen report in the training output, this run fine-tunes the
# word-level layers of the pretrained model.
config_path = 'unfreeze_word_layers.cfg'
config = read_config(config_path)

# Seed the NumPy and torch global RNGs from the config value.
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# 1. Generate datasets

In [3]:
# Build the train / validation / test splits described by `config`.
(train_dataset,
 valid_dataset,
 test_dataset) = get_SLU_datasets(config)

In [7]:
# Display the intent label vocabulary: for each slot (action / object / location),
# the mapping from label string to integer class index.
config.Sy_intent

{'action': {'change language': 0,
  'activate': 1,
  'deactivate': 2,
  'increase': 3,
  'decrease': 4,
  'bring': 5},
 'object': {'none': 0,
  'music': 1,
  'lights': 2,
  'volume': 3,
  'heat': 4,
  'lamp': 5,
  'newspaper': 6,
  'juice': 7,
  'socks': 8,
  'Chinese': 9,
  'Korean': 10,
  'English': 11,
  'German': 12,
  'shoes': 13},
 'location': {'none': 0, 'kitchen': 1, 'bedroom': 2, 'washroom': 3}}

# 2. Train SLU Model

In [7]:
# Initialize SLU Model
# Displaying `model` prints its layer structure, including the wrapped
# `pretrained_model` submodule and its phoneme layers.
# NOTE(review): DEVICE (defined in the first cell) is not used anywhere visible;
# presumably Trainer moves the model to the GPU — confirm.
model = Model(config=config)
model

Model(
  (pretrained_model): PretrainedModel(
    (phoneme_layers): ModuleList(
      (0): SincLayer()
      (1): Abs()
      (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
      (3): LeakyReLU(negative_slope=0.2)
      (4): Dropout(p=0.0, inplace=False)
      (5): Conv1d(80, 60, kernel_size=(5,), stride=(1,), padding=(2,))
      (6): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=True)
      (7): LeakyReLU(negative_slope=0.2)
      (8): Dropout(p=0.0, inplace=False)
      (9): Conv1d(60, 60, kernel_size=(5,), stride=(1,), padding=(2,))
      (10): MaxPool1d(kernel_size=1, stride=1, padding=0, dilation=1, ceil_mode=True)
      (11): LeakyReLU(negative_slope=0.2)
      (12): Dropout(p=0.0, inplace=False)
      (13): NCL2NLC()
      (14): GRU(60, 128, batch_first=True, bidirectional=True)
      (15): RNNSelect()
      (16): Dropout(p=0.5, inplace=False)
      (17): Downsample()
      (18): GRU(256, 128, batch_first=True, bidirectional=

In [8]:
%%time

# Train the model.
# Each epoch runs one training pass, one validation pass, and then a checkpoint save.
# Per-epoch metrics are also collected into `history` (a DataFrame). That lets you
# inspect the learning curve after training without scrolling through the log, and
# later cells do not have to rely on variables leaking out of the loop.
trainer = Trainer(model=model, config=config)
# To resume an interrupted run, call trainer.load_checkpoint() here, before the loop.

num_epochs = config.training_num_epochs
epoch_records = []

for epoch in range(num_epochs):
    print("========= Epoch %d of %d =========" % (epoch + 1, num_epochs))
    train_intent_acc, train_intent_loss = trainer.train(train_dataset)
    valid_intent_acc, valid_intent_loss = trainer.test(valid_dataset)

    # float() is safe here: these values are already formatted with %.2f below.
    epoch_records.append({
        "epoch": epoch + 1,
        "train_acc": float(train_intent_acc),
        "train_loss": float(train_intent_loss),
        "valid_acc": float(valid_intent_acc),
        "valid_loss": float(valid_intent_loss),
    })

    print("========= Results: epoch %d of %d =========" % (epoch + 1, num_epochs))
    print("*intents*| train accuracy: %.2f| train loss: %.2f| valid accuracy: %.2f| valid loss: %.2f\n"
          % (train_intent_acc, train_intent_loss, valid_intent_acc, valid_intent_loss))

    # NOTE(review): save_checkpoint() is called unconditionally every epoch. If it
    # writes to a fixed path, the kept checkpoint is the last epoch, not the one
    # with the best validation score. Confirm that this is intended.
    trainer.save_checkpoint()

# Explicit columns give a well-formed (empty) frame even when num_epochs == 0.
history = pd.DataFrame(
    epoch_records,
    columns=["epoch", "train_acc", "train_loss", "valid_acc", "valid_loss"],
)

  0%|          | 0/723 [00:00<?, ?it/s]

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: frozen
word_rnn1: frozen


  0%|          | 1/723 [00:00<10:20,  1.16it/s]

intent loss: 5.902858734130859
intent acc: 0.0


 14%|█▍        | 101/723 [00:20<02:05,  4.97it/s]

intent loss: 1.37888503074646
intent acc: 0.6875


 28%|██▊       | 202/723 [00:39<01:34,  5.50it/s]

intent loss: 0.900307297706604
intent acc: 0.765625


 42%|████▏     | 301/723 [00:59<01:19,  5.28it/s]

intent loss: 0.8312504291534424
intent acc: 0.78125


 56%|█████▌    | 402/723 [01:18<00:58,  5.51it/s]

intent loss: 0.7913159132003784
intent acc: 0.78125


 69%|██████▉   | 502/723 [01:37<00:40,  5.40it/s]

intent loss: 0.4167709946632385
intent acc: 0.90625


 83%|████████▎ | 602/723 [01:56<00:21,  5.53it/s]

intent loss: 0.5207804441452026
intent acc: 0.875


 97%|█████████▋| 702/723 [02:15<00:04,  5.17it/s]

intent loss: 0.450076162815094
intent acc: 0.90625


100%|██████████| 723/723 [02:19<00:00,  5.71it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.75| train loss: 0.98| valid accuracy: 0.87| valid loss: 0.53

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: frozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:01<08:28,  1.42it/s]

intent loss: 0.31375187635421753
intent acc: 0.90625


 14%|█▍        | 101/723 [00:20<02:07,  4.89it/s]

intent loss: 0.19355660676956177
intent acc: 0.96875


 28%|██▊       | 202/723 [00:41<01:48,  4.81it/s]

intent loss: 0.26075249910354614
intent acc: 0.90625


 42%|████▏     | 301/723 [01:01<01:20,  5.25it/s]

intent loss: 0.37171459197998047
intent acc: 0.890625


 55%|█████▌    | 401/723 [01:21<01:03,  5.06it/s]

intent loss: 0.19526979327201843
intent acc: 0.953125


 69%|██████▉   | 501/723 [01:40<00:45,  4.83it/s]

intent loss: 0.2167237102985382
intent acc: 0.921875


 83%|████████▎ | 601/723 [02:00<00:23,  5.26it/s]

intent loss: 0.3530408442020416
intent acc: 0.921875


 97%|█████████▋| 702/723 [02:20<00:03,  5.36it/s]

intent loss: 0.20250743627548218
intent acc: 0.953125


100%|██████████| 723/723 [02:24<00:00,  5.22it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.92| train loss: 0.29| valid accuracy: 0.91| valid loss: 0.38

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<09:27,  1.27it/s]

intent loss: 0.08004306256771088
intent acc: 0.96875


 14%|█▍        | 102/723 [00:21<02:13,  4.67it/s]

intent loss: 0.21178334951400757
intent acc: 0.9375


 28%|██▊       | 201/723 [00:41<01:34,  5.54it/s]

intent loss: 0.1158694177865982
intent acc: 0.953125


 42%|████▏     | 302/723 [01:01<01:23,  5.05it/s]

intent loss: 0.1892060488462448
intent acc: 0.96875


 56%|█████▌    | 402/723 [01:22<01:07,  4.76it/s]

intent loss: 0.23005661368370056
intent acc: 0.921875


 69%|██████▉   | 502/723 [01:43<00:41,  5.32it/s]

intent loss: 0.15211082994937897
intent acc: 0.953125


 83%|████████▎ | 602/723 [02:03<00:23,  5.17it/s]

intent loss: 0.3073040843009949
intent acc: 0.90625


 97%|█████████▋| 702/723 [02:24<00:04,  5.21it/s]

intent loss: 0.1690542995929718
intent acc: 0.96875


100%|██████████| 723/723 [02:28<00:00,  5.38it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.95| train loss: 0.18| valid accuracy: 0.92| valid loss: 0.34

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<07:12,  1.67it/s]

intent loss: 0.2126576006412506
intent acc: 0.9375


 14%|█▍        | 101/723 [00:21<02:02,  5.07it/s]

intent loss: 0.10074374824762344
intent acc: 0.984375


 28%|██▊       | 202/723 [00:42<01:47,  4.84it/s]

intent loss: 0.030516505241394043
intent acc: 1.0


 42%|████▏     | 301/723 [01:03<01:23,  5.07it/s]

intent loss: 0.17312565445899963
intent acc: 0.9375


 55%|█████▌    | 401/723 [01:23<01:05,  4.93it/s]

intent loss: 0.16899573802947998
intent acc: 0.9375


 69%|██████▉   | 502/723 [01:45<00:44,  5.02it/s]

intent loss: 0.1739170253276825
intent acc: 0.984375


 83%|████████▎ | 602/723 [02:06<00:26,  4.54it/s]

intent loss: 0.2316683530807495
intent acc: 0.953125


 97%|█████████▋| 702/723 [02:26<00:04,  4.93it/s]

intent loss: 0.10282950103282928
intent acc: 0.984375


100%|██████████| 723/723 [02:31<00:00,  3.84it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.96| train loss: 0.13| valid accuracy: 0.93| valid loss: 0.37

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<06:52,  1.75it/s]

intent loss: 0.13689573109149933
intent acc: 0.953125


 14%|█▍        | 102/723 [00:21<02:01,  5.12it/s]

intent loss: 0.13014279305934906
intent acc: 0.953125


 28%|██▊       | 202/723 [00:42<01:42,  5.11it/s]

intent loss: 0.03510870784521103
intent acc: 0.984375


 42%|████▏     | 302/723 [01:03<01:18,  5.39it/s]

intent loss: 0.04426129162311554
intent acc: 1.0


 55%|█████▌    | 401/723 [01:23<01:07,  4.74it/s]

intent loss: 0.1253020018339157
intent acc: 0.953125


 69%|██████▉   | 501/723 [01:44<00:45,  4.84it/s]

intent loss: 0.06012031435966492
intent acc: 0.984375


 83%|████████▎ | 601/723 [02:04<00:26,  4.65it/s]

intent loss: 0.23438063263893127
intent acc: 0.953125


 97%|█████████▋| 701/723 [02:25<00:04,  5.40it/s]

intent loss: 0.17000290751457214
intent acc: 0.96875


100%|██████████| 723/723 [02:29<00:00,  4.96it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.97| train loss: 0.10| valid accuracy: 0.94| valid loss: 0.33

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<08:47,  1.37it/s]

intent loss: 0.026165127754211426
intent acc: 1.0


 14%|█▍        | 101/723 [00:21<02:01,  5.13it/s]

intent loss: 0.11579971760511398
intent acc: 0.96875


 28%|██▊       | 201/723 [00:42<01:53,  4.60it/s]

intent loss: 0.1753195822238922
intent acc: 0.953125


 42%|████▏     | 302/723 [01:03<01:22,  5.12it/s]

intent loss: 0.3193781077861786
intent acc: 0.9375


 55%|█████▌    | 401/723 [01:23<01:01,  5.22it/s]

intent loss: 0.14846791326999664
intent acc: 0.96875


 69%|██████▉   | 502/723 [01:43<00:44,  4.97it/s]

intent loss: 0.10314665734767914
intent acc: 0.96875


 83%|████████▎ | 601/723 [02:03<00:26,  4.62it/s]

intent loss: 0.18057522177696228
intent acc: 0.921875


 97%|█████████▋| 701/723 [02:24<00:04,  5.20it/s]

intent loss: 0.10463625192642212
intent acc: 0.96875


100%|██████████| 723/723 [02:28<00:00,  4.75it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.09| valid accuracy: 0.93| valid loss: 0.33

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<09:16,  1.30it/s]

intent loss: 0.02566172182559967
intent acc: 0.984375


 14%|█▍        | 101/723 [00:21<02:05,  4.94it/s]

intent loss: 0.09919841587543488
intent acc: 0.984375


 28%|██▊       | 201/723 [00:42<01:48,  4.82it/s]

intent loss: 0.1786399483680725
intent acc: 0.96875


 42%|████▏     | 302/723 [01:03<01:23,  5.05it/s]

intent loss: 0.0342530831694603
intent acc: 0.984375


 55%|█████▌    | 401/723 [01:23<01:08,  4.67it/s]

intent loss: 0.2793881893157959
intent acc: 0.9375


 69%|██████▉   | 501/723 [01:43<00:45,  4.87it/s]

intent loss: 0.014937937259674072
intent acc: 1.0


 83%|████████▎ | 602/723 [02:04<00:24,  4.95it/s]

intent loss: 0.07276134192943573
intent acc: 0.984375


 97%|█████████▋| 702/723 [02:24<00:04,  5.22it/s]

intent loss: 0.08053798228502274
intent acc: 0.984375


100%|██████████| 723/723 [02:29<00:00,  5.06it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.08| valid accuracy: 0.93| valid loss: 0.38

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<06:28,  1.86it/s]

intent loss: 0.16372990608215332
intent acc: 0.96875


 14%|█▍        | 101/723 [00:20<01:51,  5.59it/s]

intent loss: 0.052596092224121094
intent acc: 0.96875


 28%|██▊       | 202/723 [00:41<01:48,  4.79it/s]

intent loss: 0.06317295134067535
intent acc: 0.96875


 42%|████▏     | 301/723 [01:02<01:24,  5.02it/s]

intent loss: 0.1673908829689026
intent acc: 0.953125


 55%|█████▌    | 401/723 [01:23<01:05,  4.95it/s]

intent loss: 0.04265959560871124
intent acc: 0.984375


 69%|██████▉   | 502/723 [01:44<00:45,  4.83it/s]

intent loss: 0.018503032624721527
intent acc: 1.0


 83%|████████▎ | 602/723 [02:05<00:24,  4.95it/s]

intent loss: 0.027199864387512207
intent acc: 1.0


 97%|█████████▋| 702/723 [02:25<00:04,  4.48it/s]

intent loss: 0.1308099776506424
intent acc: 0.96875


100%|██████████| 723/723 [02:30<00:00,  3.83it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.07| valid accuracy: 0.94| valid loss: 0.33

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<07:10,  1.67it/s]

intent loss: 0.05123135447502136
intent acc: 0.96875


 14%|█▍        | 101/723 [00:21<02:04,  5.01it/s]

intent loss: 0.02335166186094284
intent acc: 1.0


 28%|██▊       | 201/723 [00:42<01:58,  4.42it/s]

intent loss: 0.010862879455089569
intent acc: 1.0


 42%|████▏     | 301/723 [01:03<01:32,  4.57it/s]

intent loss: 0.009808160364627838
intent acc: 1.0


 55%|█████▌    | 401/723 [01:23<01:10,  4.58it/s]

intent loss: 0.05136634409427643
intent acc: 0.984375


 69%|██████▉   | 502/723 [01:45<00:45,  4.90it/s]

intent loss: 0.039281412959098816
intent acc: 0.984375


 83%|████████▎ | 602/723 [02:06<00:26,  4.49it/s]

intent loss: 0.0031185895204544067
intent acc: 1.0


 97%|█████████▋| 702/723 [02:27<00:04,  5.08it/s]

intent loss: 0.05383843928575516
intent acc: 0.96875


100%|██████████| 723/723 [02:31<00:00,  4.43it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.07| valid accuracy: 0.93| valid loss: 0.37

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<09:36,  1.25it/s]

intent loss: 0.04114902764558792
intent acc: 0.984375


 14%|█▍        | 101/723 [00:20<02:00,  5.17it/s]

intent loss: 0.016583561897277832
intent acc: 1.0


 28%|██▊       | 201/723 [00:41<01:43,  5.03it/s]

intent loss: 0.040397197008132935
intent acc: 0.96875


 42%|████▏     | 301/723 [01:02<01:24,  5.00it/s]

intent loss: 0.04849689453840256
intent acc: 0.984375


 56%|█████▌    | 402/723 [01:22<00:58,  5.48it/s]

intent loss: 0.018668845295906067
intent acc: 1.0


 69%|██████▉   | 501/723 [01:42<00:46,  4.79it/s]

intent loss: 0.12177949398756027
intent acc: 0.96875


 83%|████████▎ | 602/723 [02:04<00:23,  5.19it/s]

intent loss: 0.05466991662979126
intent acc: 0.984375


 97%|█████████▋| 701/723 [02:23<00:04,  4.56it/s]

intent loss: 0.05544475466012955
intent acc: 1.0


100%|██████████| 723/723 [02:28<00:00,  4.77it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.06| valid accuracy: 0.93| valid loss: 0.36

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<09:21,  1.29it/s]

intent loss: 0.015931442379951477
intent acc: 1.0


 14%|█▍        | 102/723 [00:22<02:00,  5.13it/s]

intent loss: 0.012609913945198059
intent acc: 1.0


 28%|██▊       | 201/723 [00:41<01:44,  5.02it/s]

intent loss: 0.08275187015533447
intent acc: 0.984375


 42%|████▏     | 302/723 [01:02<01:29,  4.72it/s]

intent loss: 0.1525733321905136
intent acc: 0.96875


 55%|█████▌    | 401/723 [01:22<01:12,  4.44it/s]

intent loss: 0.07102292031049728
intent acc: 0.984375


 69%|██████▉   | 501/723 [01:42<00:42,  5.18it/s]

intent loss: 0.07134727388620377
intent acc: 0.984375


 83%|████████▎ | 601/723 [02:03<00:25,  4.74it/s]

intent loss: 0.09763544052839279
intent acc: 0.96875


 97%|█████████▋| 701/723 [02:25<00:04,  4.70it/s]

intent loss: 0.10471023619174957
intent acc: 0.9375


100%|██████████| 723/723 [02:29<00:00,  5.26it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.06| valid accuracy: 0.93| valid loss: 0.39

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<06:45,  1.78it/s]

intent loss: 0.07417677342891693
intent acc: 0.953125


 14%|█▍        | 102/723 [00:20<01:54,  5.41it/s]

intent loss: 0.005555316805839539
intent acc: 1.0


 28%|██▊       | 201/723 [00:40<01:45,  4.96it/s]

intent loss: 0.03143634647130966
intent acc: 0.984375


 42%|████▏     | 302/723 [01:01<01:21,  5.16it/s]

intent loss: 0.04595798999071121
intent acc: 0.984375


 55%|█████▌    | 401/723 [01:21<01:09,  4.60it/s]

intent loss: 0.014566205441951752
intent acc: 1.0


 69%|██████▉   | 502/723 [01:42<00:40,  5.47it/s]

intent loss: 0.02052786573767662
intent acc: 0.984375


 83%|████████▎ | 601/723 [02:02<00:28,  4.33it/s]

intent loss: 0.09722520411014557
intent acc: 0.984375


 97%|█████████▋| 702/723 [02:23<00:04,  5.15it/s]

intent loss: 0.007953420281410217
intent acc: 1.0


100%|██████████| 723/723 [02:27<00:00,  4.76it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.06| valid accuracy: 0.94| valid loss: 0.33

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:01<07:41,  1.56it/s]

intent loss: 0.14960825443267822
intent acc: 0.9375


 14%|█▍        | 102/723 [00:21<01:59,  5.22it/s]

intent loss: 0.019187554717063904
intent acc: 1.0


 28%|██▊       | 202/723 [00:41<01:45,  4.93it/s]

intent loss: 0.09758783876895905
intent acc: 0.953125


 42%|████▏     | 302/723 [01:02<01:21,  5.15it/s]

intent loss: 0.02087213099002838
intent acc: 1.0


 56%|█████▌    | 402/723 [01:23<01:03,  5.02it/s]

intent loss: 0.011148594319820404
intent acc: 1.0


 69%|██████▉   | 502/723 [01:44<00:43,  5.03it/s]

intent loss: 0.05805689096450806
intent acc: 0.984375


 83%|████████▎ | 602/723 [02:04<00:24,  5.01it/s]

intent loss: 0.007336929440498352
intent acc: 1.0


 97%|█████████▋| 702/723 [02:25<00:04,  5.05it/s]

intent loss: 0.06598013639450073
intent acc: 0.96875


100%|██████████| 723/723 [02:30<00:00,  4.61it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.06| valid accuracy: 0.93| valid loss: 0.35

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<06:49,  1.76it/s]

intent loss: 0.02644757181406021
intent acc: 1.0


 14%|█▍        | 102/723 [00:21<02:01,  5.12it/s]

intent loss: 0.022243767976760864
intent acc: 1.0


 28%|██▊       | 202/723 [00:42<01:54,  4.57it/s]

intent loss: 0.0503859743475914
intent acc: 0.984375


 42%|████▏     | 301/723 [01:04<01:28,  4.78it/s]

intent loss: 0.02430190145969391
intent acc: 0.984375


 56%|█████▌    | 402/723 [01:25<01:01,  5.18it/s]

intent loss: 0.07800643146038055
intent acc: 0.96875


 69%|██████▉   | 501/723 [01:45<00:44,  4.94it/s]

intent loss: 0.18627110123634338
intent acc: 0.96875


 83%|████████▎ | 602/723 [02:07<00:24,  4.90it/s]

intent loss: 0.2532527446746826
intent acc: 0.953125


 97%|█████████▋| 702/723 [02:28<00:04,  4.63it/s]

intent loss: 0.14670652151107788
intent acc: 0.953125


100%|██████████| 723/723 [02:33<00:00,  5.14it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.05| valid accuracy: 0.93| valid loss: 0.36

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<07:22,  1.63it/s]

intent loss: 0.025377213954925537
intent acc: 1.0


 14%|█▍        | 102/723 [00:21<02:23,  4.32it/s]

intent loss: 0.01360706239938736
intent acc: 1.0


 28%|██▊       | 202/723 [00:42<01:45,  4.96it/s]

intent loss: 0.009695030748844147
intent acc: 1.0


 42%|████▏     | 302/723 [01:02<01:27,  4.83it/s]

intent loss: 0.020524367690086365
intent acc: 1.0


 55%|█████▌    | 401/723 [01:24<01:07,  4.76it/s]

intent loss: 0.04915823042392731
intent acc: 0.96875


 69%|██████▉   | 501/723 [01:45<00:46,  4.76it/s]

intent loss: 0.02107015997171402
intent acc: 1.0


 83%|████████▎ | 601/723 [02:06<00:30,  3.99it/s]

intent loss: 0.09864328801631927
intent acc: 0.984375


 97%|█████████▋| 701/723 [02:27<00:04,  5.26it/s]

intent loss: 0.034230247139930725
intent acc: 0.984375


100%|██████████| 723/723 [02:31<00:00,  5.47it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.99| train loss: 0.05| valid accuracy: 0.93| valid loss: 0.38

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<10:09,  1.18it/s]

intent loss: 0.010776534676551819
intent acc: 1.0


 14%|█▍        | 101/723 [00:21<02:02,  5.08it/s]

intent loss: 0.044520966708660126
intent acc: 0.984375


 28%|██▊       | 202/723 [00:41<01:40,  5.21it/s]

intent loss: 0.045178405940532684
intent acc: 0.984375


 42%|████▏     | 302/723 [01:02<01:25,  4.92it/s]

intent loss: 0.04612509906291962
intent acc: 0.984375


 55%|█████▌    | 401/723 [01:22<01:04,  4.99it/s]

intent loss: 0.06874874234199524
intent acc: 0.96875


 69%|██████▉   | 502/723 [01:43<00:50,  4.37it/s]

intent loss: 0.03659147024154663
intent acc: 0.984375


 83%|████████▎ | 602/723 [02:03<00:25,  4.76it/s]

intent loss: 0.013231679797172546
intent acc: 1.0


 97%|█████████▋| 701/723 [02:23<00:04,  4.90it/s]

intent loss: 0.22818344831466675
intent acc: 0.96875


100%|██████████| 723/723 [02:28<00:00,  5.72it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.05| valid accuracy: 0.93| valid loss: 0.41

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<07:49,  1.54it/s]

intent loss: 0.045273780822753906
intent acc: 0.984375


 14%|█▍        | 102/723 [00:20<01:54,  5.44it/s]

intent loss: 0.014841195195913315
intent acc: 1.0


 28%|██▊       | 202/723 [00:40<01:52,  4.61it/s]

intent loss: 0.006893336772918701
intent acc: 1.0


 42%|████▏     | 302/723 [01:01<01:25,  4.92it/s]

intent loss: 0.17275351285934448
intent acc: 0.9375


 55%|█████▌    | 401/723 [01:21<01:09,  4.62it/s]

intent loss: 0.01622479408979416
intent acc: 1.0


 69%|██████▉   | 501/723 [01:41<00:44,  5.02it/s]

intent loss: 0.002327732741832733
intent acc: 1.0


 83%|████████▎ | 602/723 [02:03<00:27,  4.44it/s]

intent loss: 0.05339011549949646
intent acc: 0.984375


 97%|█████████▋| 702/723 [02:23<00:03,  5.37it/s]

intent loss: 0.11486387252807617
intent acc: 0.984375


100%|██████████| 723/723 [02:27<00:00,  5.24it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.05| valid accuracy: 0.93| valid loss: 0.42

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 1/723 [00:00<09:17,  1.30it/s]

intent loss: 0.010771676898002625
intent acc: 1.0


 14%|█▍        | 102/723 [00:21<02:03,  5.02it/s]

intent loss: 0.00395597517490387
intent acc: 1.0


 28%|██▊       | 202/723 [00:41<01:36,  5.42it/s]

intent loss: 0.022342652082443237
intent acc: 0.984375


 42%|████▏     | 302/723 [01:03<01:31,  4.62it/s]

intent loss: 0.08131878823041916
intent acc: 0.96875


 56%|█████▌    | 402/723 [01:22<01:00,  5.34it/s]

intent loss: 0.022298887372016907
intent acc: 1.0


 69%|██████▉   | 502/723 [01:43<00:42,  5.25it/s]

intent loss: 0.022434517741203308
intent acc: 0.984375


 83%|████████▎ | 601/723 [02:04<00:25,  4.76it/s]

intent loss: 0.09739715605974197
intent acc: 0.96875


 97%|█████████▋| 701/723 [02:24<00:04,  4.46it/s]

intent loss: 0.11390771716833115
intent acc: 0.984375


100%|██████████| 723/723 [02:29<00:00,  5.03it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.06| valid accuracy: 0.94| valid loss: 0.41

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<07:10,  1.67it/s]

intent loss: 0.01720711961388588
intent acc: 1.0


 14%|█▍        | 102/723 [00:20<01:55,  5.38it/s]

intent loss: 0.13646696507930756
intent acc: 0.984375


 28%|██▊       | 201/723 [00:41<01:49,  4.76it/s]

intent loss: 0.08392855525016785
intent acc: 0.96875


 42%|████▏     | 301/723 [01:02<01:25,  4.94it/s]

intent loss: 0.05759446322917938
intent acc: 0.984375


 56%|█████▌    | 402/723 [01:23<01:02,  5.10it/s]

intent loss: 0.004284888505935669
intent acc: 1.0


 69%|██████▉   | 502/723 [01:44<00:42,  5.15it/s]

intent loss: 0.2676129937171936
intent acc: 0.953125


 83%|████████▎ | 602/723 [02:04<00:23,  5.05it/s]

intent loss: 0.010850086808204651
intent acc: 1.0


 97%|█████████▋| 702/723 [02:24<00:04,  4.96it/s]

intent loss: 0.09086795896291733
intent acc: 0.9375


100%|██████████| 723/723 [02:28<00:00,  5.19it/s]
  0%|          | 0/723 [00:00<?, ?it/s]

*intents*| train accuracy: 0.98| train loss: 0.05| valid accuracy: 0.94| valid loss: 0.33

sinc0: frozen
conv1: frozen
conv2: frozen
phone_rnn0: frozen
phone_rnn1: frozen
word_rnn0: unfrozen
word_rnn1: unfrozen


  0%|          | 2/723 [00:00<07:21,  1.63it/s]

intent loss: 0.21427899599075317
intent acc: 0.9375


 14%|█▍        | 101/723 [00:20<02:02,  5.06it/s]

intent loss: 0.013571903109550476
intent acc: 0.984375


 28%|██▊       | 201/723 [00:41<01:35,  5.48it/s]

intent loss: 0.224066823720932
intent acc: 0.953125


 42%|████▏     | 301/723 [01:02<01:53,  3.72it/s]

intent loss: 0.021574005484580994
intent acc: 1.0


 56%|█████▌    | 402/723 [01:24<01:06,  4.84it/s]

intent loss: 0.0017220675945281982
intent acc: 1.0


 69%|██████▉   | 502/723 [01:44<00:46,  4.80it/s]

intent loss: 0.19392277300357819
intent acc: 0.953125


 83%|████████▎ | 601/723 [02:05<00:24,  5.02it/s]

intent loss: 0.023221246898174286
intent acc: 0.984375


 97%|█████████▋| 701/723 [02:25<00:04,  4.75it/s]

intent loss: 0.05376802384853363
intent acc: 0.96875


100%|██████████| 723/723 [02:30<00:00,  5.10it/s]


*intents*| train accuracy: 0.98| train loss: 0.06| valid accuracy: 0.93| valid loss: 0.40

CPU times: user 45min 35s, sys: 6min 52s, total: 52min 28s
Wall time: 52min 39s


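# 3. Evaluate the Model

Evaluate the trained model on the held-out test set; the trainer holds the state left by the last training epoch. The test accuracy reported below (0.99) is higher than both the final train accuracy (0.98) and the validation accuracy (0.93). That is unusual, so check whether the test split overlaps with the training data.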

In [9]:
# Evaluate the trained model on the held-out test set.
# Validation metrics are recomputed here rather than read from the training
# loop's last-iteration variables. This way the cell reports numbers for the
# same model state, and it does not raise NameError when the loop ran zero
# epochs or the training cell was skipped after loading a checkpoint.
test_intent_acc, test_intent_loss = trainer.test(test_dataset)
valid_intent_acc, valid_intent_loss = trainer.test(valid_dataset)

print("========= Test results =========")
print("*intents*| test accuracy: %.2f| test loss: %.2f| valid accuracy: %.2f| valid loss: %.2f\n"
      % (test_intent_acc, test_intent_loss, valid_intent_acc, valid_intent_loss))

*intents*| test accuracy: 0.99| test loss: 0.05| valid accuracy: 0.93| valid loss: 0.40

