### Import dependencies

In [0]:
#@title Install packages
!pip install thop

Collecting thop
  Downloading https://files.pythonhosted.org/packages/6c/8b/22ce44e1c71558161a8bd54471123cc796589c7ebbfc15a7e8932e522f83/thop-0.0.31.post2005241907-py3-none-any.whl
Installing collected packages: thop
Successfully installed thop-0.0.31.post2005241907


In [0]:
#@title Import packages
# coding: utf-8
import shutil
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa
import math
import random

from pathlib import Path

from torchvision.transforms import Compose
from tqdm import *

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from thop import profile
from torchvision.models.quantization import shufflenet_v2_x0_5

### Recover Google Speech Commands dataset

In [0]:
#@title Remove old data dir
!rm -rf data/

In [0]:
#@title Download and unzip dataset
!wget -nc http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz
!mkdir -p data/audio
!mkdir model
!tar -xzvf speech_commands_v0.01.tar.gz -C data/audio/
!rm speech_commands_v0.01.tar.gz

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
./up/6f342826_nohash_0.wav
./up/e0a7c5a0_nohash_0.wav
./up/4d4e17f5_nohash_1.wav
./up/b0f24c9b_nohash_0.wav
./up/735845ab_nohash_2.wav
./up/53d5b86f_nohash_0.wav
./up/1a5b9ca4_nohash_1.wav
./up/23abe1c9_nohash_2.wav
./up/bdee441c_nohash_1.wav
./up/a1cff772_nohash_1.wav
./up/1ecfb537_nohash_3.wav
./up/37fc5d97_nohash_3.wav
./up/bd8412df_nohash_1.wav
./up/e53139ad_nohash_1.wav
./up/10ace7eb_nohash_3.wav
./up/30065f33_nohash_0.wav
./up/eefd26f3_nohash_0.wav
./up/c9b653a0_nohash_2.wav
./up/02746d24_nohash_0.wav
./up/e1469561_nohash_0.wav
./up/4bba14ce_nohash_0.wav
./up/b5d1e505_nohash_1.wav
./up/531a5b8a_nohash_1.wav
./up/0135f3f2_nohash_0.wav
./up/dbb40d24_nohash_4.wav
./up/e9287461_nohash_1.wav
./up/71e6ab20_nohash_0.wav
./up/ead2934a_nohash_1.wav
./up/f9af0887_nohash_0.wav
./up/ff63ab0b_nohash_0.wav
./up/f3d06008_nohash_0.wav
./up/918a2473_nohash_4.wav
./up/e54a0f16_nohash_0.wav
./up/cb8f8307_nohash_1.wav
./up/d197e3ae_noh

In [0]:
#@title Split training e testing set
def move_files(src_folder, to_folder, list_file):
    with open(list_file) as f:
        for line in f.readlines():
            line = line.rstrip()
            dirname = os.path.dirname(line)
            dest = os.path.join(to_folder, dirname)
            if not os.path.exists(dest):
                os.mkdir(dest)
            shutil.move(os.path.join(src_folder, line), dest)


audio_folder = os.path.join('data/', 'audio')
validation_path = os.path.join(audio_folder, 'validation_list.txt')
test_path = os.path.join(audio_folder, 'testing_list.txt')

valid_folder = os.path.join('data/', 'valid')
test_folder = os.path.join('data/', 'test')
train_folder = os.path.join('data/', 'train')
os.mkdir(valid_folder)
os.mkdir(test_folder)

move_files(audio_folder, test_folder, test_path)
move_files(audio_folder, valid_folder, validation_path)
os.rename(audio_folder, train_folder)

In [0]:
#@title Remove validation set
!cp -R data/valid/* data/test/
!rm -rf data/valid/

In [0]:
#@title Manage dataset in Py
CLASSES = 'unknown, silence, yes, no, up, down, left, right, on, off, stop, go'.split(', ')


class SpeechCommandsDataset(Dataset):
    """Google speech commands dataset. Only 'yes', 'no', 'up', 'down', 'left',
    'right', 'on', 'off', 'stop' and 'go' are treated as known classes.
    All other classes are used as 'unknown' samples.
    See for more information: https://www.kaggle.com/c/tensorflow-speech-recognition-challenge
    """

    def __init__(self, folder, transform=None, classes=CLASSES, silence_percentage=0.1):
        all_classes = [d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d)) and not d.startswith('_')]
        #for c in classes[2:]:
        #    assert c in all_classes

        class_to_idx = {classes[i]: i for i in range(len(classes))}
        for c in all_classes:
            if c not in class_to_idx:
                class_to_idx[c] = 0

        data = []
        for c in all_classes:
            d = os.path.join(folder, c)
            target = class_to_idx[c]
            for f in os.listdir(d):
                path = os.path.join(d, f)
                data.append((path, target))

        # add silence
        target = class_to_idx['silence']
        data += [('', target)] * int(len(data) * silence_percentage)

        self.classes = classes
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path, target = self.data[index]
        data = {'path': path, 'target': target}

        if self.transform is not None:
            data = self.transform(data)

        return data

    def make_weights_for_balanced_classes(self):
        """adopted from https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3"""

        nclasses = len(self.classes)
        count = np.zeros(nclasses)
        for item in self.data:
            count[item[1]] += 1

        N = float(sum(count))
        weight_per_class = N / count
        weight = np.zeros(len(self))
        for idx, item in enumerate(self.data):
            weight[idx] = weight_per_class[item[1]]
        return weight


class BackgroundNoiseDataset(Dataset):
    """Dataset for silence / background noise."""

    def __init__(self, folder, transform=None, sample_rate=16000, sample_length=1):
        audio_files = [d for d in os.listdir(folder) if os.path.isfile(os.path.join(folder, d)) and d.endswith('.wav')]
        samples = []
        for f in audio_files:
            path = os.path.join(folder, f)
            s, sr = librosa.load(path, sample_rate)
            samples.append(s)

        samples = np.hstack(samples)
        c = int(sample_rate * sample_length)
        r = len(samples) // c
        self.samples = samples[:r*c].reshape(-1, c)
        self.sample_rate = sample_rate
        self.classes = CLASSES
        self.transform = transform
        self.path = folder

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = {'samples': self.samples[index], 'sample_rate': self.sample_rate, 'target': 1, 'path': self.path}

        if self.transform is not None:
            data = self.transform(data)

        return data

### Preprocess data

In [0]:
#@title
def should_apply_transform(prob=0.5):
    """Transforms are only randomly applied with the given probability."""
    return random.random() < prob


class LoadAudio(object):
    """Loads an audio into a numpy array."""

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate

    def __call__(self, data):
        path = data['path']
        if path:
            samples, sample_rate = librosa.load(path, self.sample_rate)
        else:
            # silence
            sample_rate = self.sample_rate
            samples = np.zeros(sample_rate, dtype=np.float32)
        data['samples'] = samples
        data['sample_rate'] = sample_rate
        return data


class FixAudioLength(object):
    """Either pads or truncates an audio into a fixed length."""

    def __init__(self, time=1):
        self.time = time

    def __call__(self, data):
        samples = data['samples']
        sample_rate = data['sample_rate']
        length = int(self.time * sample_rate)
        if length < len(samples):
            data['samples'] = samples[:length]
        elif length > len(samples):
            data['samples'] = np.pad(samples, (0, length - len(samples)), "constant")
        return data


class ChangeAmplitude(object):
    """Changes amplitude of an audio randomly."""

    def __init__(self, amplitude_range=(0.7, 1.1)):
        self.amplitude_range = amplitude_range

    def __call__(self, data):
        if not should_apply_transform():
            return data

        data['samples'] = data['samples'] * random.uniform(*self.amplitude_range)
        return data


class ChangeSpeedAndPitchAudio(object):
    """Change the speed of an audio. This transform also changes the pitch of the audio."""

    def __init__(self, max_scale=0.2):
        self.max_scale = max_scale

    def __call__(self, data):
        if not should_apply_transform():
            return data

        samples = data['samples']
        sample_rate = data['sample_rate']
        scale = random.uniform(-self.max_scale, self.max_scale)
        speed_fac = 1.0  / (1 + scale)
        data['samples'] = np.interp(np.arange(0, len(samples), speed_fac), np.arange(0,len(samples)), samples).astype(np.float32)
        return data


class StretchAudio(object):
    """Stretches an audio randomly."""

    def __init__(self, max_scale=0.2):
        self.max_scale = max_scale

    def __call__(self, data):
        if not should_apply_transform():
            return data

        scale = random.uniform(-self.max_scale, self.max_scale)
        data['samples'] = librosa.effects.time_stretch(data['samples'], 1+scale)
        return data


class TimeshiftAudio(object):
    """Shifts an audio randomly."""

    def __init__(self, max_shift_seconds=0.2):
        self.max_shift_seconds = max_shift_seconds

    def __call__(self, data):
        if not should_apply_transform():
            return data

        samples = data['samples']
        sample_rate = data['sample_rate']
        max_shift = (sample_rate * self.max_shift_seconds)
        shift = random.randint(-max_shift, max_shift)
        a = -min(0, shift)
        b = max(0, shift)
        samples = np.pad(samples, (a, b), "constant")
        data['samples'] = samples[:len(samples) - a] if a else samples[b:]
        return data


class AddBackgroundNoise(Dataset):
    """Adds a random background noise."""

    def __init__(self, bg_dataset, max_percentage=0.45):
        self.bg_dataset = bg_dataset
        self.max_percentage = max_percentage

    def __call__(self, data):
        if not should_apply_transform():
            return data

        samples = data['samples']
        noise = random.choice(self.bg_dataset)['samples']
        percentage = random.uniform(0, self.max_percentage)
        data['samples'] = samples * (1 - percentage) + noise * percentage
        return data


class ToMelSpectrogram(object):
    """Creates the mel spectrogram from an audio. The result is a 32x32 matrix."""

    def __init__(self, n_mels=32):
        self.n_mels = n_mels

    def __call__(self, data):
        samples = data['samples']
        sample_rate = data['sample_rate']
        s = librosa.feature.melspectrogram(samples, sr=sample_rate, n_mels=self.n_mels)
        data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max)
        return data


class ToTensor(object):
    """Converts into a tensor."""

    def __init__(self, np_name, tensor_name, normalize=None):
        self.np_name = np_name
        self.tensor_name = tensor_name
        self.normalize = normalize

    def __call__(self, data):
        tensor = torch.FloatTensor(data[self.np_name])
        if self.normalize is not None:
            mean, std = self.normalize
            tensor -= mean
            tensor /= std
        data[self.tensor_name] = tensor
        return data

In [0]:
#@title
class ToSTFT(object):
    """Applies on an audio the short time fourier transform."""

    def __init__(self, n_fft=2048, hop_length=512):
        self.n_fft = n_fft
        self.hop_length = hop_length

    def __call__(self, data):
        samples = data['samples']
        sample_rate = data['sample_rate']
        data['n_fft'] = self.n_fft
        data['hop_length'] = self.hop_length
        data['stft'] = librosa.stft(samples, n_fft=self.n_fft, hop_length=self.hop_length)
        data['stft_shape'] = data['stft'].shape
        return data


class StretchAudioOnSTFT(object):
    """Stretches an audio on the frequency domain."""

    def __init__(self, max_scale=0.2):
        self.max_scale = max_scale

    def __call__(self, data):
        if not should_apply_transform():
            return data

        stft = data['stft']
        sample_rate = data['sample_rate']
        hop_length = data['hop_length']
        scale = random.uniform(-self.max_scale, self.max_scale)
        stft_stretch = librosa.core.phase_vocoder(stft, 1+scale, hop_length=hop_length)
        data['stft'] = stft_stretch
        return data


class TimeshiftAudioOnSTFT(object):
    """A simple timeshift on the frequency domain without multiplying with exp."""

    def __init__(self, max_shift=8):
        self.max_shift = max_shift

    def __call__(self, data):
        if not should_apply_transform():
            return data

        stft = data['stft']
        shift = random.randint(-self.max_shift, self.max_shift)
        a = -min(0, shift)
        b = max(0, shift)
        stft = np.pad(stft, ((0, 0), (a, b)), "constant")
        if a == 0:
            stft = stft[:,b:]
        else:
            stft = stft[:,0:-a]
        data['stft'] = stft
        return data


class AddBackgroundNoiseOnSTFT(Dataset):
    """Adds a random background noise on the frequency domain."""

    def __init__(self, bg_dataset, max_percentage=0.45):
        self.bg_dataset = bg_dataset
        self.max_percentage = max_percentage

    def __call__(self, data):
        if not should_apply_transform():
            return data

        noise = random.choice(self.bg_dataset)['stft']
        percentage = random.uniform(0, self.max_percentage)
        data['stft'] = data['stft'] * (1 - percentage) + noise * percentage
        return data


class FixSTFTDimension(object):
    """Either pads or truncates in the time axis on the frequency domain, applied after stretching, time shifting etc."""

    def __call__(self, data):
        stft = data['stft']
        t_len = stft.shape[1]
        orig_t_len = data['stft_shape'][1]
        if t_len > orig_t_len:
            stft = stft[:,0:orig_t_len]
        elif t_len < orig_t_len:
            stft = np.pad(stft, ((0, 0), (0, orig_t_len-t_len)), "constant")

        data['stft'] = stft
        return data


class ToMelSpectrogramFromSTFT(object):
    """Creates the mel spectrogram from the short time fourier transform of a file. The result is a 32x32 matrix."""

    def __init__(self, n_mels=32):
        self.n_mels = n_mels

    def __call__(self, data):
        stft = data['stft']
        sample_rate = data['sample_rate']
        n_fft = data['n_fft']
        mel_basis = librosa.filters.mel(sample_rate, n_fft, self.n_mels)
        s = np.dot(mel_basis, np.abs(stft)**2.0)
        data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max)
        return data


class DeleteSTFT(object):
    """Pytorch doesn't like complex numbers, use this transform to remove STFT after computing the mel spectrogram."""

    def __call__(self, data):
        del data['stft']
        return data


class AudioFromSTFT(object):
    """Inverse short time fourier transform."""

    def __call__(self, data):
        stft = data['stft']
        data['istft_samples'] = librosa.core.istft(stft, dtype=data['samples'].dtype)
        return data

In [0]:
#@title
def get_data(batch_size, root, use_gpu=True):

    data_aug_transform = Compose([ChangeAmplitude(), ChangeSpeedAndPitchAudio(), FixAudioLength(), ToSTFT(),
                                  StretchAudioOnSTFT(), TimeshiftAudioOnSTFT(), FixSTFTDimension()])
    bg_dataset = BackgroundNoiseDataset(root + "train/_background_noise_", data_aug_transform)
    add_bg_noise = AddBackgroundNoiseOnSTFT(bg_dataset)

    train_feature_transform = Compose([ToMelSpectrogramFromSTFT(n_mels=40), DeleteSTFT(),
                                       ToTensor('mel_spectrogram', 'input')])
    train_dataset = SpeechCommandsDataset(root + "train",
                                          Compose([LoadAudio(), data_aug_transform,
                                                   add_bg_noise, train_feature_transform]))

    valid_feature_transform = Compose([ToMelSpectrogram(n_mels=40), ToTensor('mel_spectrogram', 'input')])
    valid_dataset = SpeechCommandsDataset(root + "test",
                                          Compose([LoadAudio(), FixAudioLength(), valid_feature_transform]))

    weights = train_dataset.make_weights_for_balanced_classes()
    sampler = WeightedRandomSampler(weights, len(weights))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler,
                                  pin_memory=use_gpu, num_workers=2)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                                  pin_memory=use_gpu, num_workers=2)

    return train_dataloader, valid_dataloader

### Build ShuffleNetV2

In [0]:
#@title
def initialize_net(model_path=''):
    net = shufflenet_v2_x0_5(num_classes=len(CLASSES), quantize=False)
    net.conv1[0] = nn.Conv2d(1, 24, 3, 2, 1, bias=False)

    if Path(model_path).exists():
      state_dict = torch.load(model_path)
      net.load_state_dict(state_dict, strict=False)

    return net

### Train

In [0]:
#@title
def get_optimizer(model, lr, wd, momentum):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
    return optimizer

  
def get_cost_function():
    cost_function = torch.nn.CrossEntropyLoss()
    return cost_function


def train(net, data_loader, optimizer, cost_function, device='cuda:0'):
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.

    net.train()  # Strictly needed if network contains layers which has different behaviours between train and test

    pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size, position=0, leave=True)
    for batch in pbar:
        inputs = batch['input']
        inputs = torch.unsqueeze(inputs, 1)
        targets = batch['target']

        # Load data into GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = net(inputs)

        # Apply the loss
        loss = cost_function(outputs, targets)

        # Reset the optimizer

        # Backward pass
        loss.backward()

        # Update parameters
        optimizer.step()

        optimizer.zero_grad()

        # Better print something, no?
        samples += inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(1)
        cumulative_accuracy += predicted.eq(targets).sum().item()

        # Free cuda memory
        torch.cuda.empty_cache()

    return cumulative_loss/samples, cumulative_accuracy/samples*100

### Test

In [0]:
#@title
def test(net, data_loader, cost_function, device='cuda:0', num_batches=0):
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.

    net.eval()  # Strictly needed if network contains layers which has different behaviours between train and test
    done_batches = 0

    with torch.no_grad():
        pbar = tqdm(data_loader, unit="audios", unit_scale=data_loader.batch_size, position=0, leave=True)
        for batch in pbar:
            inputs = batch['input']
            inputs = torch.unsqueeze(inputs, 1)
            targets = batch['target']

            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs = net(inputs)

            # Apply the loss
            loss = cost_function(outputs, targets)
            done_batches += 1

            # Better print something
            samples += inputs.shape[0]
            cumulative_loss += loss.item()  # Note: the .item() is needed to extract scalars from tensors
            _, predicted = outputs.max(1)
            cumulative_accuracy += predicted.eq(targets).sum().item()

            # Free cuda memory
            torch.cuda.empty_cache()

            if  0 < num_batches <= done_batches:
                return cumulative_loss/samples, cumulative_accuracy/samples*100

    return cumulative_loss/samples, cumulative_accuracy/samples*100

### Predict

In [0]:
#@title
def predict(input_path, net, use_gpu=True):
    print("Predicting " + input_path + "...")
    device = 'cuda:0' if use_gpu else 'cpu'

    feature_transform = Compose([ToMelSpectrogram(n_mels=40), ToTensor('mel_spectrogram', 'input')])
    transform = Compose([LoadAudio(), FixAudioLength(), feature_transform])

    audio = transform({'path': input_path})
    audio_loader = DataLoader(audio, batch_size=1, shuffle=False,
                              pin_memory=use_gpu, num_workers=2)

    net.eval()  # Strictly needed if network contains layers which has different behaviours between train and test

    audio = audio_loader.dataset['input']
    audio = torch.unsqueeze(audio, 1)
    audio = audio.view(1, 1, 40, 32)

    # Load data into GPU
    audio = audio.to(device)

    # Forward into network
    output = net.forward(audio)
    class_index = output.data.max(1, keepdim=True)[1]

    torch.cuda.empty_cache()
    return CLASSES[class_index]

### Main

In [0]:
#@title
def main(batch_size=128, 
         use_gpu=True,
         learning_rate=0.001, 
         weight_decay=0.000001, 
         momentum=0.9, 
         epochs=50,
         root="data/",
         save=True,
         perform_training=True):

    device = 'cuda:0' if use_gpu else 'cpu'
    torch.cuda.empty_cache

    model_path = 'model/weights_mn_2.pth'

    net = initialize_net(model_path=model_path).to(device)

    # Op Counter
    net.eval()
    net.to(device)
    dsize = (32, 1, 40, 32)
    r_input = torch.randn(dsize).to(device)
    macs, params = profile(net, inputs=(r_input,), verbose=False)
    print("\n%s\t| %s" % ("Params(M)", "FLOPs(G)"))
    print("%.2f\t\t| %.2f" % (params / (1000 ** 2), macs / (1000 ** 3)))
    print()

    if not perform_training:
        # print(root)
        input_audios = [predict(input_path=os.path.join(root, r), net=net, use_gpu=use_gpu) for r in os.listdir(root)]
        print(input_audios)
        return
    else:
        if root is '':
            print("You need to specify the dataset path")
            return
        train_loader, test_loader = get_data(batch_size=batch_size,
                                             root=root, use_gpu=use_gpu)

        print("{")
        optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)

        cost_function = get_cost_function()

        # train_loss, train_accuracy = test(net, train_loader, cost_function, device)
        # test_loss, test_accuracy = test(net, test_loader, cost_function, device)
        # print('\nTest before training')
        # print('\t{"Training loss": %.5f, "Training accuracy": %.2f},' % (train_loss, train_accuracy))
        # print('\t{"Test loss": %.5f, "Test accuracy": %.2f},' % (test_loss, test_accuracy))
        # print('-----------------------------------------------------')

        train_loss_curve = []
        train_accuracy_curve = []
        test_loss_curve = []
        test_accuracy_curve = []

        print('\t"training": [')

        for e in range(epochs):
            train_loss, train_accuracy = train(net, train_loader, optimizer, cost_function, device)
            test_loss, test_accuracy = test(net, test_loader, cost_function, device)
            print('\n\t\tEpoch: {:d}'.format(e+1))
            print('\t\t\t{"Training loss": %.5f, "Training accuracy": %.2f},' % (train_loss, train_accuracy))
            print('\t\t\t{"Test loss": %.5f, "Test accuracy": %.2f},' % (test_loss, test_accuracy))
            # print('-----------------------------------------------------')

            train_loss_curve.append(train_loss)
            train_accuracy_curve.append(train_accuracy)
            test_loss_curve.append(test_loss)
            test_accuracy_curve.append(test_accuracy)

            # saving model
            if save:
                print("Saving weights")
                torch.save(net.state_dict(), model_path)

        # train_loss, train_accuracy = test(net, train_loader, cost_function, device)
        # test_loss, test_accuracy = test(net, test_loader, cost_function, device)
        # print('\n\t\tTest after training')
        # print('\t\t{"Training loss": %.5f, "Training accuracy": %.2f},' % (train_loss, train_accuracy))
        # print('\t\t{"Test loss": %.5f, "Test accuracy": %.2f}' % (test_loss, test_accuracy))
        # print('-----------------------------------------------------')

        print('\t],')

### Execution

In [0]:
#@title Add a input
!mkdir -p data/input/
!cp data/test/go/022cd682_nohash_0.wav data/input/

In [0]:
#@title Run 10 training epochs
main(root='data/', epochs=1, batch_size=32, perform_training=True,
         save=True, use_gpu=True)

In [0]:
#@title Run input
main(root='data/input/', perform_training=False, save=False,
          use_gpu=True)


Params(M)	| FLOPs(G)
0.06		| 0.02

Predicting data/input/022cd682_nohash_0.wav...
['go']


##Export models

In [0]:
#@title Load high weights from drive
from google.colab import drive
drive.mount('/content/drive')
!cp "/content/drive/My Drive/Colab Notebooks/weights_shufflenet_v2_05.pth" model/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
def quantize_model(model, backend):
    _dummy_input_data = torch.rand(32, 1, 40, 32)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError("Quantized backend not supported ")
    torch.backends.quantized.engine = backend
    model.eval()
    
    # Make sure that weight qconfig matches that of the serialized models
    my_qconfig = torch.quantization.QConfig(activation=torch.quantization.MinMaxObserver.with_args(dtype=torch.quint8),
                           weight=torch.quantization.default_observer.with_args(dtype=torch.qint8))
    if backend == 'fbgemm':
        model.qconfig = torch.quantization.QConfig(
            activation=torch.quantization.default_observer,
            weight=torch.quantization.default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        model.qconfig = my_qconfig

    model.fuse_model()
    torch.quantization.prepare(model, inplace=True)
    
    # Calibrate with the training set
    batch_size = 32
    num_batches = 10
    print('Calibration in progress. Total batches: ' + str(batch_size * num_batches))
    cost_function = get_cost_function()
    train_loader, test_loader = get_data(batch_size=batch_size,
                                          root='data/', use_gpu=False)
    test_loss, test_accuracy = test(model, train_loader, cost_function, 'cpu', num_batches=num_batches)
    print()
    print('Test loss: %.5f, Test accuracy: %.2f' % (test_loss, test_accuracy))

    torch.quantization.convert(model, inplace=True)

    return

def get_net(import_model_path, quantize=False):
    empty_net = shufflenet_v2_x0_5(num_classes=len(CLASSES), quantize=False)
    empty_net.conv1[0] = nn.Conv2d(1, 24, 3, 2, 1, bias=False)

    if Path(import_model_path).exists():
        net = empty_net
        
        if quantize:
            backend = 'qnnpack'
            state_dict = torch.load(import_model_path, map_location=lambda storage, loc: storage)
            net.load_state_dict(state_dict, strict=False)
            quantize_model(model=net, backend=backend)
        else:
            state_dict = torch.load(import_model_path, map_location=lambda storage, loc: storage)
            net.load_state_dict(state_dict, strict=False)       
    else:
        raise Exception("File to import does not exists.")
    
    return net


def export_model(input_path, quantize=False):
    import_model_path = 'model/weights_shufflenet_v2_05.pth'
    export_model_path = 'model/mobile_shufflenet_v2_05.pt'
    
    dsize = (1, 1, 40, 32)

    feature_transform = Compose([ToMelSpectrogram(n_mels=40), ToTensor('mel_spectrogram', 'input')])
    transform = Compose([LoadAudio(), FixAudioLength(), feature_transform])

    audio = transform({'path': input_path})
    audio_loader = DataLoader(audio, batch_size=1, shuffle=False,
                              pin_memory=False, num_workers=2)
    audio = audio_loader.dataset['input']
    audio = torch.unsqueeze(audio, 1)
    audio = audio.view(dsize)

    net = get_net(import_model_path=import_model_path, quantize=False).to(torch.device("cpu"))
    net.cpu()
    net.eval()

    if quantize:
        # saving quantized version
        quantized_net = get_net(import_model_path=import_model_path, quantize=True).to(torch.device("cpu"))
        # quantized_model_path = import_model_path.split('.')[0] + '_quantized.pth'
        # torch.jit.save(torch.jit.script(quantized_net), quantized_model_path)
        # print("Saving quantized weights in " + quantized_model_path)

        # Test
        num_batches = 5
        batch_size = 32
        print('Test on ' + str(batch_size * num_batches) + ' samples')
        cost_function = get_cost_function()
        train_loader, test_loader = get_data(batch_size=batch_size,
                                              root='data/', use_gpu=False)
        test_loss, test_accuracy = test(quantized_net, test_loader, cost_function, 'cpu', num_batches=num_batches)
        print()
        print('Test loss: %.5f, Test accuracy: %.2f' % (test_loss, test_accuracy))

        # saving quantized version for mobile
        traced_script_module = torch.jit.trace(quantized_net, audio, check_trace=False)
        export_quantized_model_path = export_model_path.split('.')[0] + '_quantized.pt'
        traced_script_module.save(export_quantized_model_path)
        print("Exporting quantized version in " + export_quantized_model_path)

    
    # saving not-quantized version for mobile
    traced_script_module = torch.jit.trace(net, audio)
    traced_script_module.save(export_model_path)
    print("Exporting model in " + export_model_path)

In [0]:
data = 'data/input/022cd682_nohash_0.wav'
quantize = True

export_model(input_path=data, quantize=quantize)

Calibration in progress. Total batches: 320


  1%|          | 288/56224 [00:31<1:54:05,  8.17audios/s]


Test loss: 0.03246, Test accuracy: 63.44
Test on 160 samples


  1%|          | 96/15008 [00:13<43:58,  5.65audios/s]


Test loss: 0.02426, Test accuracy: 72.50
Exporting quantized version in model/mobile_shufflenet_v2_05_quantized.pt
Exporting model in model/mobile_shufflenet_v2_05.pt


In [0]:
#@title copy models into drive
!cp model/mobile_shufflenet_v2_05.pt "/content/drive/My Drive/Colab Notebooks/"
!cp model/mobile_shufflenet_v2_05_quantized.pt "/content/drive/My Drive/Colab Notebooks/"