# Import and misc

In [1]:
# Instal latest torch and torchaudio

In [2]:
# !pip install thop

In [3]:
from typing import Tuple, Union, List, Callable, Optional
from tqdm import tqdm
from itertools import islice
import pathlib
import dataclasses

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
from torch import nn
from torch import distributions
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torch.nn.utils.rnn import pad_sequence

import torchaudio
from IPython import display as display_
import os 
import random

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


In [5]:
seed_everything(42)

# Task

In this notebook we will implement a model for finding a keyword in a stream.

We will implement the version with CRNN because it is easy and improves the model. 
(from https://www.dropbox.com/s/22ah2ba7dug6pzw/KWS_Attention.pdf)

In [6]:
@dataclasses.dataclass
class TaskConfig:
    keyword: str = 'sheila'  # We will use 1 key word -- 'sheila'
    batch_size: int = 128
    learning_rate: float = 3e-4
    weight_decay: float = 1e-5
    num_epochs: int = 20
    n_mels: int = 40
    cnn_out_channels: int = 8
    kernel_size: Tuple[int, int] = (5, 20)
    stride: Tuple[int, int] = (2, 8)
    hidden_size: int = 64
    gru_num_layers: int = 2
    bidirectional: bool = False
    num_classes: int = 2
    sample_rate: int = 16000
    device: torch.device = torch.device(
        'cuda:0' if torch.cuda.is_available() else 'cpu')
    temperature: int = 2

# Data

In [7]:
!wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz -O speech_commands_v0.01.tar.gz
!mkdir speech_commands && tar -C speech_commands -xvzf speech_commands_v0.01.tar.gz 1> log

--2022-11-06 19:59:23--  http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz
Resolving download.tensorflow.org (download.tensorflow.org)... 108.177.119.128, 2a00:1450:4013:c00::80
Connecting to download.tensorflow.org (download.tensorflow.org)|108.177.119.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1489096277 (1.4G) [application/gzip]
Saving to: ‘speech_commands_v0.01.tar.gz’


2022-11-06 19:59:39 (90.9 MB/s) - ‘speech_commands_v0.01.tar.gz’ saved [1489096277/1489096277]

mkdir: cannot create directory ‘speech_commands’: File exists


In [8]:
class SpeechCommandDataset(Dataset):

    def __init__(
        self,
        transform: Optional[Callable] = None,
        path2dir: str = None,
        keywords: Union[str, List[str]] = None,
        csv: Optional[pd.DataFrame] = None
    ):        
        self.transform = transform

        if csv is None:
            path2dir = pathlib.Path(path2dir)
            keywords = keywords if isinstance(keywords, list) else [keywords]
            
            all_keywords = [
                p.stem for p in path2dir.glob('*')
                if p.is_dir() and not p.stem.startswith('_')
            ]

            triplets = []
            for keyword in all_keywords:
                paths = (path2dir / keyword).rglob('*.wav')
                if keyword in keywords:
                    for path2wav in paths:
                        triplets.append((path2wav.as_posix(), keyword, 1))
                else:
                    for path2wav in paths:
                        triplets.append((path2wav.as_posix(), keyword, 0))
            
            self.csv = pd.DataFrame(
                triplets,
                columns=['path', 'keyword', 'label']
            )

        else:
            self.csv = csv
    
    def __getitem__(self, index: int):
        instance = self.csv.iloc[index]

        path2wav = instance['path']
        wav, sr = torchaudio.load(path2wav)
        wav = wav.sum(dim=0)
        
        if self.transform:
            wav = self.transform(wav)

        return {
            'wav': wav,
            'keywors': instance['keyword'],
            'label': instance['label']
        }

    def __len__(self):
        return len(self.csv)

In [9]:
dataset = SpeechCommandDataset(
    path2dir='speech_commands', keywords=TaskConfig.keyword
)

In [10]:
dataset.csv.sample(5)

Unnamed: 0,path,keyword,label
50866,speech_commands/two/f4cae173_nohash_0.wav,two,0
1003,speech_commands/happy/d3a18257_nohash_1.wav,happy,0
38126,speech_commands/right/471a0925_nohash_0.wav,right,0
144,speech_commands/happy/518588b6_nohash_1.wav,happy,0
31651,speech_commands/go/5588c7e6_nohash_0.wav,go,0


### Augmentations

In [11]:
class AugsCreation:

    def __init__(self):
        self.background_noises = [
            'speech_commands/_background_noise_/white_noise.wav',
            'speech_commands/_background_noise_/dude_miaowing.wav',
            'speech_commands/_background_noise_/doing_the_dishes.wav',
            'speech_commands/_background_noise_/exercise_bike.wav',
            'speech_commands/_background_noise_/pink_noise.wav',
            'speech_commands/_background_noise_/running_tap.wav'
        ]

        self.noises = [
            torchaudio.load(p)[0].squeeze()
            for p in self.background_noises
        ]

    def add_rand_noise(self, audio):

        # randomly choose noise
        noise_num = torch.randint(low=0, high=len(
            self.background_noises), size=(1,)).item()
        noise = self.noises[noise_num]

        noise_level = torch.Tensor([1])  # [0, 40]

        noise_energy = torch.norm(noise)
        audio_energy = torch.norm(audio)
        alpha = (audio_energy / noise_energy) * \
            torch.pow(10, -noise_level / 20)

        start = torch.randint(
            low=0,
            high=max(int(noise.size(0) - audio.size(0) - 1), 1),
            size=(1,)
        ).item()
        noise_sample = noise[start: start + audio.size(0)]

        audio_new = audio + alpha * noise_sample
        audio_new.clamp_(-1, 1)
        return audio_new

    def __call__(self, wav):
        aug_num = torch.randint(low=0, high=4, size=(1,)).item()   # choose 1 random aug from augs
        augs = [
            lambda x: x,
            lambda x: (x + distributions.Normal(0, 0.01).sample(x.size())).clamp_(-1, 1),
            lambda x: torchaudio.transforms.Vol(.25)(x),
            lambda x: self.add_rand_noise(x)
        ]

        return augs[aug_num](wav)

In [12]:
indexes = torch.randperm(len(dataset))
train_indexes = indexes[:int(len(dataset) * 0.8)]
val_indexes = indexes[int(len(dataset) * 0.8):]

train_df = dataset.csv.iloc[train_indexes].reset_index(drop=True)
val_df = dataset.csv.iloc[val_indexes].reset_index(drop=True)

In [13]:
# Sample is a dict of utt, word and label
train_set = SpeechCommandDataset(csv=train_df, transform=AugsCreation())
val_set = SpeechCommandDataset(csv=val_df)

### Sampler for oversampling:

In [14]:
# We should provide to WeightedRandomSampler _weight for every sample_; by default it is 1/len(target)

def get_sampler(target):
    class_sample_count = np.array(
        [len(np.where(target == t)[0]) for t in np.unique(target)])   # for every class count it's number of occ.
    weight = 1. / class_sample_count
    samples_weight = np.array([weight[t] for t in target])
    samples_weight = torch.from_numpy(samples_weight)
    samples_weigth = samples_weight.float()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    return sampler

In [15]:
train_sampler = get_sampler(train_set.csv['label'].values)

In [16]:
class Collator:
    
    def __call__(self, data):
        wavs = []
        labels = []    

        for el in data:
            wavs.append(el['wav'])
            labels.append(el['label'])

        # torch.nn.utils.rnn.pad_sequence takes list(Tensors) and returns padded (with 0.0) Tensor
        wavs = pad_sequence(wavs, batch_first=True)    
        labels = torch.Tensor(labels).long()
        return wavs, labels

###  Dataloaders

In [17]:
# Here we are obliged to use shuffle=False because of our sampler with randomness inside.

train_loader = DataLoader(train_set, batch_size=TaskConfig.batch_size,
                          shuffle=False, collate_fn=Collator(),
                          sampler=train_sampler,
                          num_workers=2, pin_memory=True)

val_loader = DataLoader(val_set, batch_size=TaskConfig.batch_size,
                        shuffle=False, collate_fn=Collator(),
                        num_workers=2, pin_memory=True)

### Creating MelSpecs on GPU for speeeed: 


In [18]:
class LogMelspec:

    def __init__(self, is_train, config):
        # with augmentations
        if is_train:
            self.melspec = nn.Sequential(
                torchaudio.transforms.MelSpectrogram(
                    sample_rate=config.sample_rate,
                    n_fft=400,
                    win_length=400,
                    hop_length=160,
                    n_mels=config.n_mels
                ),
                torchaudio.transforms.FrequencyMasking(freq_mask_param=15),
                torchaudio.transforms.TimeMasking(time_mask_param=35),
            ).to(config.device)

        # no augmentations
        else:
            self.melspec = torchaudio.transforms.MelSpectrogram(
                sample_rate=config.sample_rate,
                n_fft=400,
                win_length=400,
                hop_length=160,
                n_mels=config.n_mels
            ).to(config.device)

    def __call__(self, batch):
        # already on device
        return torch.log(self.melspec(batch).clamp_(min=1e-9, max=1e9))

In [19]:
melspec_train = LogMelspec(is_train=True, config=TaskConfig)
melspec_val = LogMelspec(is_train=False, config=TaskConfig)

In [20]:
# !git clone https://github.com/leksious/KWS.git

### Quality measurment functions:

In [21]:
# FA - true: 0, model: 1
# FR - true: 1, model: 0

def count_FA_FR(preds, labels):
    FA = torch.sum(preds[labels == 0])
    FR = torch.sum(labels[preds == 0])
    
    # torch.numel - returns total number of elements in tensor
    return FA.item() / torch.numel(preds), FR.item() / torch.numel(preds)

In [22]:
def get_au_fa_fr(probs, labels):
    sorted_probs, _ = torch.sort(probs)
    sorted_probs = torch.cat((torch.Tensor([0]), sorted_probs, torch.Tensor([1])))
    labels = torch.cat(labels, dim=0)
        
    FAs, FRs = [], []
    for prob in sorted_probs:
        preds = (probs >= prob) * 1
        FA, FR = count_FA_FR(preds, labels)        
        FAs.append(FA)
        FRs.append(FR)
    # plt.plot(FAs, FRs)
    # plt.show()

    # ~ area under curve using trapezoidal rule
    return -np.trapz(FRs, x=FAs)

# Model

In [23]:
class Attention(nn.Module):

    def __init__(self, hidden_size: int):
        super().__init__()

        self.energy = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1)
        )
    
    def forward(self, input):
        energy = self.energy(input)
        alpha = torch.softmax(energy, dim=-2)
        return (input * alpha).sum(dim=-2)

class CRNN(nn.Module):

    def __init__(self, config: TaskConfig):
        super().__init__()
        self.config = config

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=config.cnn_out_channels,
                kernel_size=config.kernel_size, stride=config.stride
            ),
            nn.Flatten(start_dim=1, end_dim=2),
        )

        self.conv_out_frequency = (config.n_mels - config.kernel_size[0]) // \
            config.stride[0] + 1
        
        self.gru = nn.GRU(
            input_size=self.conv_out_frequency * config.cnn_out_channels,
            hidden_size=config.hidden_size,
            num_layers=config.gru_num_layers,
            dropout=0.1,
            bidirectional=config.bidirectional,
            batch_first=True
        )

        self.attention = Attention(config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_classes)
    
    def forward(self, input):
        input = input.unsqueeze(dim=1)
        conv_output = self.conv(input).transpose(-1, -2)
        gru_output, _ = self.gru(conv_output)
        contex_vector = self.attention(gru_output)
        output = self.classifier(contex_vector)
        return output

config = TaskConfig()
model = CRNN(config)
model

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): GRU(144, 64, num_layers=2, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (classifier): Linear(in_features=64, out_features=2, bias=True)
)

In [24]:
def train_epoch(model, opt, loader, log_melspec, device):
    model.train()
    for i, (batch, labels) in tqdm(enumerate(loader), total=len(loader)):
        batch, labels = batch.to(device), labels.to(device)
        batch = log_melspec(batch)

        opt.zero_grad()

        # run model # with autocast():
        logits = model(batch)
        # we need probabilities so we use softmax & CE separately
        probs = F.softmax(logits, dim=-1)
        loss = F.cross_entropy(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

        opt.step()

        # logging
        argmax_probs = torch.argmax(probs, dim=-1)
        FA, FR = count_FA_FR(argmax_probs, labels)
        acc = torch.sum(argmax_probs == labels) / torch.numel(argmax_probs)

    return acc

In [25]:
@torch.no_grad()
def validation(model, loader, log_melspec, device):
    model.eval()

    val_losses, accs, FAs, FRs = [], [], [], []
    all_probs, all_labels = [], []
    for i, (batch, labels) in tqdm(enumerate(loader)):
        batch, labels = batch.to(device), labels.to(device)
        batch = log_melspec(batch)

        output = model(batch)
        # we need probabilities so we use softmax & CE separately
        probs = F.softmax(output, dim=-1)
        loss = F.cross_entropy(output, labels)

        # logging
        argmax_probs = torch.argmax(probs, dim=-1)
        all_probs.append(probs[:, 1].cpu())
        all_labels.append(labels.cpu())
        val_losses.append(loss.item())
        accs.append(
            torch.sum(argmax_probs == labels).item() /  # ???
            torch.numel(argmax_probs)
        )
        FA, FR = count_FA_FR(argmax_probs, labels)
        FAs.append(FA)
        FRs.append(FR)

    # area under FA/FR curve for whole loader
    au_fa_fr = get_au_fa_fr(torch.cat(all_probs, dim=0).cpu(), all_labels)
    return au_fa_fr

In [26]:
from collections import defaultdict
from IPython.display import clear_output
from matplotlib import pyplot as plt

history = defaultdict(list)

# Training

In [None]:
config = TaskConfig(hidden_size=32)
model = CRNN(config).to(config.device)

print(model)

opt = torch.optim.Adam(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): GRU(144, 32, num_layers=2, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): Linear(in_features=32, out_features=32, bias=True)
      (1): Tanh()
      (2): Linear(in_features=32, out_features=1, bias=True)
    )
  )
  (classifier): Linear(in_features=32, out_features=2, bias=True)
)


In [None]:
model.load_state_dict(torch.load("/content/checkpoin_teacher.pth"))

<All keys matched successfully>

In [None]:
config_2 = TaskConfig(hidden_size=8, gru_num_layers=1)
model_2 = CRNN(config_2).to(config_2.device)

print(model_2)

opt = torch.optim.Adam(
    model_2.parameters(),
    lr=config_2.learning_rate,
    weight_decay=config_2.weight_decay
)

In [None]:
from train_val import train_epoch_distillation
from config import TaskConfig as cfg

In [None]:
cfg = cfg(num_epochs=150, temperature=1, learning_rate = 3e-5)

In [None]:
for n in range(cfg.num_epochs):

    train_epoch_distillation(model_2, model, opt, train_loader, melspec_train, config.device, cfg)

    au_fa_fr = validation(model_2, val_loader,
                          melspec_val, config.device)
    history['val_metric'].append(au_fa_fr)


    clear_output()
    plt.plot(history['val_metric'])
    plt.ylabel('Metric')
    plt.xlabel('Epoch')
    plt.grid()
    plt.show()
    print(au_fa_fr)

    print('END OF EPOCH', n)

#Pruning

###Прунинг исходной модели

In [27]:
import torch.nn.utils.prune as prune

In [28]:
config = TaskConfig(hidden_size=32)
model_prune = CRNN(config).to(config.device)
model_prune.load_state_dict(torch.load("/content/checkpoin_teacher.pth"))
module = model_prune.conv[0]
prune.random_unstructured(module, name="weight", amount=0.1)
prune.remove(module, 'weight')

Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))

### Прунинг дистилированной модели

In [29]:
config_2 = TaskConfig(hidden_size=8, gru_num_layers=1)
model_2_pruned = CRNN(config_2).to(config_2.device)
model_2_pruned.load_state_dict(torch.load("/content/model_squeezed_best_quality.pth"))

  "num_layers={}".format(dropout, num_layers))


<All keys matched successfully>

In [30]:
module = model_2_pruned.conv[0]

In [31]:
prune.random_unstructured(module, name="weight", amount=0.1)

Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))

In [32]:
prune.remove(module, 'weight')

Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))

In [33]:
validation(model_2_pruned, val_loader, melspec_val, 'cuda') 

102it [00:06, 14.67it/s]


4.039432957661294e-05

# Quantization 

In [34]:
import torch.quantization

In [35]:
config = TaskConfig(hidden_size=32)
model = CRNN(config).to(config.device)
model.load_state_dict(torch.load("/content/checkpoin_teacher.pth"))

<All keys matched successfully>

In [36]:
config_2 = TaskConfig(hidden_size=8, gru_num_layers=1)
model_2 = CRNN(config_2).to(config_2.device)
model_2.load_state_dict(torch.load("/content/model_squeezed_best_quality.pth"))

<All keys matched successfully>

### Квантизация дистиллированной модели

In [37]:
model_2.cpu()
quantizted_model = torch.quantization.quantize_dynamic(
    model_2, {nn.GRU, nn.Linear, nn.Conv2d}, dtype=torch.qint8
)
quantizted_model.eval()

  "num_layers={}".format(dropout, num_layers))


CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 8, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=8, out_features=8, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=8, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=8, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [38]:
quantizted_model.to('cpu')

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 8, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=8, out_features=8, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=8, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=8, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [39]:
config = TaskConfig(device='cpu')

In [40]:
melspec_val_quant = LogMelspec(is_train=False, config=config)

In [41]:
validation(quantizted_model, val_loader, melspec_val_quant, 'cpu') 

102it [00:08, 11.52it/s]


4.0943343954076135e-05

### Квантизация исходной модели

In [42]:
model.cpu()
quantizted_model_full = torch.quantization.quantize_dynamic(
    model, {nn.GRU, nn.Linear, nn.Conv2d}, dtype=torch.qint8
)
quantizted_model_full.eval()

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 32, num_layers=2, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=32, out_features=32, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=32, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=32, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [43]:
quantizted_model_full.to('cpu')

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 32, num_layers=2, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=32, out_features=32, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=32, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=32, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [44]:
validation(quantizted_model_full, val_loader, melspec_val_quant, 'cpu') 

102it [00:09, 10.90it/s]


3.076270778068248e-05

### Квантизация дистилированной модели после прунинга 

In [45]:
model_2_pruned.cpu()
quantizted_distilled_pruned = torch.quantization.quantize_dynamic(
    model_2_pruned, {nn.GRU, nn.Linear, nn.Conv2d}, dtype=torch.qint8
)
quantizted_distilled_pruned.eval()

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 8, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=8, out_features=8, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=8, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=8, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [46]:
quantizted_distilled_pruned.to('cpu')

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 8, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=8, out_features=8, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=8, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=8, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [47]:
config = TaskConfig(device='cpu')

In [48]:
melspec_val_quant = LogMelspec(is_train=False, config=config)

In [49]:
validation(quantizted_distilled_pruned, val_loader, melspec_val_quant, 'cpu') 

102it [00:08, 11.80it/s]


4.3789864150271194e-05

### Квантизация исходной модели после прунинга

In [50]:
model_prune.cpu()
quantizted_pruned = torch.quantization.quantize_dynamic(
    model_prune, {nn.GRU, nn.Linear, nn.Conv2d}, dtype=torch.qint8
)
quantizted_pruned.eval()

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 32, num_layers=2, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=32, out_features=32, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=32, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=32, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [51]:
quantizted_pruned.to('cpu')

CRNN(
  (conv): Sequential(
    (0): Conv2d(1, 8, kernel_size=(5, 20), stride=(2, 8))
    (1): Flatten(start_dim=1, end_dim=2)
  )
  (gru): DynamicQuantizedGRU(144, 32, num_layers=2, batch_first=True, dropout=0.1)
  (attention): Attention(
    (energy): Sequential(
      (0): DynamicQuantizedLinear(in_features=32, out_features=32, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
      (1): Tanh()
      (2): DynamicQuantizedLinear(in_features=32, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    )
  )
  (classifier): DynamicQuantizedLinear(in_features=32, out_features=2, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [52]:
config = TaskConfig(device='cpu')

In [53]:
melspec_val_quant = LogMelspec(is_train=False, config=config)

In [54]:
validation(quantizted_pruned, val_loader, melspec_val_quant, 'cpu') 

102it [00:08, 11.57it/s]


5.2031047359800296e-05

# Results

Для квантизованных моделей флопсы считать не будем, так как они не сильно будут отличаться от исходных

In [58]:
# !pip install wandb

####Изначальная модель

In [59]:
import wandb
wandb.login()

True

In [60]:
# !pip install thop

In [61]:
from validate_speed import get_size_in_megabytes
from validate_speed import Timer
from validate_speed import calc_flops_macs
from thop import profile

In [62]:
quality_1 = validation(model.to('cuda'), val_loader,
                          melspec_val, 'cuda')

102it [00:05, 17.17it/s]


In [63]:
calc_time = Timer()
calc_time.__enter__()
for iter in range(1000):
    model(torch.zeros(1, config_2.n_mels, 101).to('cuda'))
calc_time.__exit__()
time_1 = calc_time.t / 1000

In [64]:
size_rate_1 = get_size_in_megabytes(model) / get_size_in_megabytes(model)

In [65]:
from copy import deepcopy

In [66]:
model_copy = deepcopy(model)

In [67]:
config = TaskConfig(device='cuda')

In [68]:
flops, macs_1 = calc_flops_macs(model_copy.to('cuda'), config)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_gru() for <class 'torch.nn.modules.rnn.GRU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


In [69]:
wandb.init(project="KWS")

In [70]:
wandb.log({'model_size' : get_size_in_megabytes(model),
           'model_quality_compared_to_original':quality_1 / quality_1, 
           'model_quality':quality_1,
           'time_for_2_sec_frame':time_1, 
           'compression_rate': size_rate_1, 
           'macs':macs_1})

In [71]:
wandb.finish()

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,1.0
macs,25387.0
model_quality,3e-05
model_quality_compared_to_original,1.0
model_size,0.10011
time_for_2_sec_frame,0.00066


####Дистиллированная модель

In [72]:
quality_2 = validation(model_2.to('cuda'), val_loader,
                          melspec_val, 'cuda')

102it [00:06, 16.80it/s]


In [73]:
calc_time = Timer()
calc_time.__enter__()


for iter in range(1000):
    model_2(torch.zeros(1, config_2.n_mels, 101).to('cuda'))
    
calc_time.__exit__()
time_2 = calc_time.t / 1000

In [74]:
size_rate_2 = get_size_in_megabytes(model) / get_size_in_megabytes(model_2)

In [75]:
model_copy = deepcopy(model_2)

In [76]:
config = TaskConfig(device='cuda')

In [77]:
flops, macs_2 = calc_flops_macs(model_copy.to('cuda'), config)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_gru() for <class 'torch.nn.modules.rnn.GRU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


In [78]:
wandb.init(project="KWS")

In [79]:
wandb.log({'model_size' : get_size_in_megabytes(model_2),
           'model_quality_compared_to_original':quality_1 / quality_2,
           'model_quality':quality_2, 
           'time_for_2_sec_frame':time_2, 
           'compression_rate': size_rate_2, 
           'macs':macs_2})

In [80]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085482…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,4.8645
macs,4603.0
model_quality,4e-05
model_quality_compared_to_original,0.72943
model_size,0.02058
time_for_2_sec_frame,0.00066


####Исходная модель с пруннингом

In [81]:
quality_3 = validation(model_prune.to('cuda'), val_loader,
                          melspec_val, 'cuda')

102it [00:05, 17.45it/s]


In [82]:
calc_time = Timer()
calc_time.__enter__()


for iter in range(1000):
    model_prune(torch.zeros(1, config_2.n_mels, 101).to('cuda'))
    
calc_time.__exit__()
time_3 = calc_time.t / 1000

In [83]:
size_rate_3 = get_size_in_megabytes(model) / get_size_in_megabytes(model_prune)

In [84]:
model_copy = deepcopy(model_prune)

In [85]:
config = TaskConfig(device='cuda')

In [86]:
flops, macs_3 = calc_flops_macs(model_copy.to('cuda'), config)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_gru() for <class 'torch.nn.modules.rnn.GRU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


In [87]:
wandb.init(project="KWS")

In [88]:
wandb.log({'model_size' : get_size_in_megabytes(model_prune),
           'model_quality_compared_to_original':quality_1 / quality_3, 
           'model_quality':quality_3,
           'time_for_2_sec_frame':time_3, 
           'compression_rate': size_rate_3, 
           'macs':macs_3})

In [89]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085617…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,1.0
macs,25387.0
model_quality,5e-05
model_quality_compared_to_original,0.60089
model_size,0.10011
time_for_2_sec_frame,0.00066


#### Дистиллированная модель с прунингом

In [90]:
quality_4 = validation(model_2_pruned.to('cuda'), val_loader,
                          melspec_val, 'cuda')

102it [00:06, 16.99it/s]


In [91]:
calc_time = Timer()
calc_time.__enter__()


for iter in range(1000):
    model_2_pruned(torch.zeros(1, config_2.n_mels, 101).to('cuda'))
    
calc_time.__exit__()
time_4 = calc_time.t / 1000

In [92]:
size_rate_4 = get_size_in_megabytes(model) / get_size_in_megabytes(model_2_pruned)

In [93]:
model_copy = deepcopy(model_2_pruned)

In [94]:
config = TaskConfig(device='cuda')

In [95]:
flops, macs_4 = calc_flops_macs(model_copy.to('cuda'), config)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_gru() for <class 'torch.nn.modules.rnn.GRU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


In [96]:
wandb.init(project="KWS")

In [97]:
wandb.log({'model_size' : get_size_in_megabytes(model_2_pruned),
           'model_quality_compared_to_original':quality_1 / quality_4, 
           'model_quality':quality_4,
           'time_for_2_sec_frame':time_4, 
           'compression_rate': size_rate_4, 
           'macs':macs_4})

In [98]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085566…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,4.8645
macs,4603.0
model_quality,4e-05
model_quality_compared_to_original,0.736
model_size,0.02058
time_for_2_sec_frame,0.00065


#### Квантизованная дистиллированная модель

In [99]:
quality_5 = validation(quantizted_model.to('cpu'), val_loader,
                          melspec_val_quant, 'cpu')

102it [00:08, 11.64it/s]


In [100]:
calc_time = Timer()
calc_time.__enter__()


for iter in range(1000):
    quantizted_model(torch.zeros(1, config_2.n_mels, 101))
    
calc_time.__exit__()
time_5 = calc_time.t / 1000

In [101]:
size_rate_5 = get_size_in_megabytes(model) / get_size_in_megabytes(quantizted_model)

In [102]:
wandb.init(project="KWS")

In [103]:
wandb.log({'model_size' : get_size_in_megabytes(quantizted_model),
           'model_quality_compared_to_original':quality_1 / quality_5, 
           'model_quality':quality_5,
           'time_for_2_sec_frame':time_5, 
           'compression_rate': size_rate_5, 
           'macs':macs_2})

In [104]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085566…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,7.93372
macs,4603.0
model_quality,4e-05
model_quality_compared_to_original,0.72613
model_size,0.01262
time_for_2_sec_frame,0.00139


#### Квантизованная изначальная модель 

In [105]:
quality_6 = validation(quantizted_model_full.to('cpu'), val_loader,
                          melspec_val_quant, 'cpu')

102it [00:08, 11.37it/s]


In [106]:
calc_time = Timer()
calc_time.__enter__()


for iter in range(1000):
    quantizted_model_full(torch.zeros(1, config_2.n_mels, 101))
    
calc_time.__exit__()
time_6 = calc_time.t / 1000

In [107]:
size_rate_6 = get_size_in_megabytes(model) / get_size_in_megabytes(quantizted_model_full)

In [108]:
wandb.init(project="KWS")

In [109]:
wandb.log({'model_size' : get_size_in_megabytes(quantizted_model_full),
           'model_quality_compared_to_original':quality_1 / quality_6, 
           'model_quality':quality_6,
           'time_for_2_sec_frame':time_6, 
           'compression_rate': size_rate_6, 
           'macs':macs_1})

In [110]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085566…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,2.88485
macs,25387.0
model_quality,3e-05
model_quality_compared_to_original,0.96644
model_size,0.0347
time_for_2_sec_frame,0.0017


#### квантизованная дистилированная модель с прунингом

In [111]:
quality_7 = validation(quantizted_distilled_pruned.to('cpu'), val_loader,
                          melspec_val_quant, 'cpu')

102it [00:08, 11.64it/s]


In [112]:
calc_time = Timer()
calc_time.__enter__()


for iter in range(1000):
    quantizted_distilled_pruned(torch.zeros(1, config_2.n_mels, 101))
    
calc_time.__exit__()
time_7 = calc_time.t / 1000

In [113]:
size_rate_7 = get_size_in_megabytes(model) / get_size_in_megabytes(quantizted_distilled_pruned)

In [114]:
wandb.init(project="KWS")

In [115]:
wandb.log({'model_size' : get_size_in_megabytes(quantizted_distilled_pruned),
           'model_quality_compared_to_original':quality_1 / quality_7, 
           'model_quality':quality_7,
           'time_for_2_sec_frame':time_7, 
           'compression_rate': size_rate_7, 
           'macs':macs_4})

In [116]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085556…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,7.89552
macs,4603.0
model_quality,4e-05
model_quality_compared_to_original,0.67893
model_size,0.01268
time_for_2_sec_frame,0.00142


#### Квантизованная исходная модель с прунингом

In [117]:
quality_8 = validation(quantizted_pruned.to('cpu'), val_loader,
                          melspec_val_quant, 'cpu')

102it [00:09, 11.03it/s]


In [118]:
calc_time = Timer()
calc_time.__enter__()


for iter in range(1000):
    quantizted_pruned(torch.zeros(1, config_2.n_mels, 101))
    
calc_time.__exit__()
time_8 = calc_time.t / 1000

In [119]:
size_rate_8 = get_size_in_megabytes(model) / get_size_in_megabytes(quantizted_pruned)

In [120]:
wandb.init(project="KWS")

In [121]:
wandb.log({'model_size' : get_size_in_megabytes(quantizted_pruned),
           'model_quality_compared_to_original':quality_1 / quality_8, 
           'model_quality':quality_8,
           'time_for_2_sec_frame':time_8, 
           'compression_rate': size_rate_8, 
           'macs':macs_3})

In [122]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.085547…

0,1
compression_rate,▁
macs,▁
model_quality,▁
model_quality_compared_to_original,▁
model_size,▁
time_for_2_sec_frame,▁

0,1
compression_rate,2.87978
macs,25387.0
model_quality,5e-05
model_quality_compared_to_original,0.5714
model_size,0.03476
time_for_2_sec_frame,0.00173
