# Dash

## Load libraries


In [1]:
import importlib

import os
import numpy as np
import time

import torch
from torch import nn
import torch.optim as optim
from torch.utils import data
from torchvision import transforms
import distiller.apputils as apputils

import sys

import ai8x

kws20 = importlib.import_module("datasets.kws20")

## Prepare dataset

Calculate the weights of the classes relative to the sample size

In [2]:
from pathlib import Path

raw_data_path = Path('data/KWS_DASH/raw')
class_file_count = {}

class_dirs = [d for d in raw_data_path.iterdir() if d.is_dir() and d.stem != "_background_noise_"]


for d in class_dirs:
    class_file_count[d] = len(list(d.iterdir()))

min_file_count = float(min(class_file_count.values()))

for d in class_dirs:
    class_file_count[d] = min_file_count / class_file_count[d]
    print(f"{d.stem}: {round(class_file_count[d], 7)}")

blinds: 1.0
dash: 1.0
down: 0.0079142
energy: 0.96875
lights: 0.96875
off: 0.0082777
on: 0.0080624
up: 0.0083266


Generate processed dataset

In [3]:
train_set, test_set = kws20.KWS_DASH_get_datasets(('data', ''))

No key `noise_var` in input augmentation dictionary!  Using defaults: [Min: 0., Max: 1.]
No key `shift` in input augmentation dictionary! Using defaults: [Min:-0.1, Max: 0.1]
No key `strech` in input augmentation dictionary! Using defaults: [Min: 0.8, Max: 1.3]

Processing train...
Class blinds (# 0): 81 elements
Class dash (# 1): 81 elements
Class down (# 2): 10665 elements
Class energy (# 3): 84 elements
Class lights (# 4): 84 elements
Class off (# 5): 10158 elements
Class on (# 6): 10464 elements
Class up (# 7): 10152 elements
Class UNKNOWN: 0 elements

Processing test...
Class blinds (# 0): 12 elements
Class dash (# 1): 12 elements
Class down (# 2): 1086 elements
Class energy (# 3): 12 elements
Class lights (# 4): 12 elements
Class off (# 5): 1077 elements
Class on (# 6): 1071 elements
Class up (# 7): 1017 elements
Class UNKNOWN: 0 elements


In [4]:
train_batch_size = 32
train_loader, val_loader, test_loader, _ = apputils.get_data_loaders(
    kws20.KWS_DASH_get_datasets, ("data", None), train_batch_size, 1)
print(f"Dataset sizes:\n\ttraining={len(train_loader.sampler)}\n\tvalidation={len(val_loader.sampler)}\n\ttest={len(test_loader.sampler)}")

No key `noise_var` in input augmentation dictionary!  Using defaults: [Min: 0., Max: 1.]
No key `shift` in input augmentation dictionary! Using defaults: [Min:-0.1, Max: 0.1]
No key `strech` in input augmentation dictionary! Using defaults: [Min: 0.8, Max: 1.3]

Processing train...
Class blinds (# 0): 81 elements
Class dash (# 1): 81 elements
Class down (# 2): 10665 elements
Class energy (# 3): 84 elements
Class lights (# 4): 84 elements
Class off (# 5): 10158 elements
Class on (# 6): 10464 elements
Class up (# 7): 10152 elements
Class UNKNOWN: 0 elements

Processing test...
Class blinds (# 0): 12 elements
Class dash (# 1): 12 elements
Class down (# 2): 1086 elements
Class energy (# 3): 12 elements
Class lights (# 4): 12 elements
Class off (# 5): 1077 elements
Class on (# 6): 1071 elements
Class up (# 7): 1017 elements
Class UNKNOWN: 0 elements
Dataset sizes:
	training=37593
	validation=4176
	test=4299


In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

Running on device: cuda:0


In [None]:
    class_dict = {'backward': 0, 'bed': 1, 'bird': 2, 'cat': 3, 'dog': 4, 'down': 5,
                  'eight': 6, 'five': 7, 'follow': 8, 'forward': 9, 'four': 10, 'go': 11,
                  'happy': 12, 'house': 13, 'learn': 14, 'left': 15, 'marvin': 16, 'nine': 17,
                  'no': 18, 'off': 19, 'on': 20, 'one': 21, 'right': 22, 'seven': 23,
                  'sheila': 24, 'six': 25, 'stop': 26, 'three': 27, 'tree': 28, 'two': 29,
                  'up': 30, 'visual': 31, 'wow': 32, 'yes': 33, 'zero': 34, 'dash': 35, 'energy': 36, 
                  'lights': 37, 'blinds': 38}
    classes = list(class_dict.keys())
    classes.sort()
    for i, c in enumerate(classes):
        print(f"\"{c}\": {i},")

## Load pretrained model


In [6]:
def count_params(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params

In [7]:
ai8x.set_device(device=85, simulate=False, round_avg=False)

mod = importlib.import_module("models.ai85net-kws20-v3")

model = mod.AI85KWS20Netv3(num_classes=21, num_channels=128, dimensions=(128, 1), bias=False)
print(f'Number of Model Params: {count_params(model)}')

model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint(
            model, "logs/2022.05.20-235449/qat_checkpoint.pth.tar", model_device='cuda')



Configuring device: MAX78000, simulate=False.
Number of Model Params: 169472


## Replace FC layer and freeze the rest of the layers

In [8]:
def freeze_layer(layer):
    for p in layer.parameters():
        p.requires_grad = False

In [9]:
freeze_layer(model.voice_conv1)
freeze_layer(model.voice_conv2)
freeze_layer(model.voice_conv3)
freeze_layer(model.voice_conv4)
freeze_layer(model.kws_conv1)
freeze_layer(model.kws_conv2)
freeze_layer(model.kws_conv3)
freeze_layer(model.kws_conv4)
model.fc = ai8x.Linear(256, 8, bias=False, wide=True)

model = model.to(device)

## Train the model

In [15]:
num_epochs = 50
epoch = 0
optimizer = optim.Adam(model.parameters(), lr=0.001)
ms_lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 80], gamma=0.5)
criterion = torch.nn.CrossEntropyLoss(
    torch.Tensor((1, 1, 0.01, 1, 1, 0.01, 0.01, 0.01))
)
criterion.to(device)

qat_policy = {
    'start_epoch': 20,
    'weight_bits': 8
    }

In [16]:
best_acc = 0
best_qat_acc = 0
for epoch in range(0, num_epochs):
    if epoch > 0 and epoch == qat_policy['start_epoch']:
        print('QAT is starting!')
        # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
        ai8x.fuse_bn_layers(model)

        # Switch model from unquantized to quantized for QAT
        ai8x.initiate_qat(model, qat_policy)

        # Model is re-transferred to GPU in case parameters were added
        model.to(device)
    running_loss = []
    train_start = time.time()
    model.train()
    for idx, (inputs, target) in enumerate(train_loader):
        inputs = inputs.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        
        model_out = model(inputs)
        
        loss = criterion(model_out, target)
        loss.backward()
        optimizer.step()
        
        running_loss.append(loss.cpu().detach().numpy())

    mean_loss = np.mean(running_loss)
    train_end = time.time()
    print("Epoch: {}/{}\t LR: {}\t Train Loss: {:.4f}\t Dur: {:.2f} sec.".format(
        epoch+1, num_epochs, ms_lr_scheduler.get_lr(), mean_loss, (train_end-train_start)))
    
    model.eval()
    acc = 0.
    acc_weight = 0
    with torch.no_grad():
        for inputs, target in test_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            model_out = model(inputs)
            target_out = torch.argmax(model_out, dim=1)
            
            tp = torch.sum(target_out == target)
            acc_batch = (tp / target_out.numel()).detach().item()
            acc += target_out.shape[0] * acc_batch
            acc_weight += target_out.shape[0]
            
        total_acc = 100 * (acc / acc_weight)
        if epoch == qat_policy['start_epoch']: best_acc = 0
        if total_acc > best_acc:
            best_acc = total_acc
            checkpoint_extras = {'current_top1': best_acc,
                                 'best_top1': best_acc,
                                 'best_epoch': epoch}
            model_name = 'ai85net_kws_dash'
            model_prefix = f'{model_name}' if epoch < qat_policy['start_epoch'] else (f'qat_{model_name}')
            apputils.save_checkpoint(epoch, model_name, model, optimizer=optimizer,
                                     scheduler=None, extras=checkpoint_extras,
                                     is_best=True, name=model_prefix,
                                     dir='.')
            print(f'Best model saved with accuracy: {best_acc:.2f}%')
            
        print('\t\t Test Acc: {:.2f}'.format(total_acc))
    ms_lr_scheduler.step()

Epoch: 1/50	 LR: [0.001]	 Train Loss: 0.4758	 Dur: 24.67 sec.
Best model saved with accuracy: 89.02%
		 Test Acc: 89.02
Epoch: 2/50	 LR: [0.001]	 Train Loss: 0.4319	 Dur: 20.29 sec.
		 Test Acc: 88.81
Epoch: 3/50	 LR: [0.001]	 Train Loss: 0.4262	 Dur: 22.08 sec.
Best model saved with accuracy: 89.39%
		 Test Acc: 89.39
Epoch: 4/50	 LR: [0.001]	 Train Loss: 0.4199	 Dur: 25.70 sec.
		 Test Acc: 88.58
Epoch: 5/50	 LR: [0.001]	 Train Loss: 0.4142	 Dur: 35.91 sec.
		 Test Acc: 88.81
Epoch: 6/50	 LR: [0.001]	 Train Loss: 0.4123	 Dur: 28.17 sec.
		 Test Acc: 89.11
Epoch: 7/50	 LR: [0.001]	 Train Loss: 0.4072	 Dur: 27.57 sec.
		 Test Acc: 88.74
Epoch: 8/50	 LR: [0.001]	 Train Loss: 0.4092	 Dur: 25.46 sec.
		 Test Acc: 88.51
Epoch: 9/50	 LR: [0.001]	 Train Loss: 0.3908	 Dur: 30.96 sec.
		 Test Acc: 89.32
Epoch: 10/50	 LR: [0.001]	 Train Loss: 0.3900	 Dur: 30.26 sec.
		 Test Acc: 88.88
Epoch: 11/50	 LR: [0.001]	 Train Loss: 0.4075	 Dur: 31.53 sec.
		 Test Acc: 89.35
Epoch: 12/50	 LR: [0.001]	 Tr