# Part 3: Evaluation of the resulting model

### Welcome to the Part 3 tutorial!
Now that you have trained a linear model with trainable basis functions in the Part 2 tutorial, it is time to evaluate the model's performance and visualize the results!\
You can use your own model weights or the pretrained weights we have prepared for you (in the `trained_weight_for_tutorial3` folder).

### Let's start!

[Step 1: import related libraries](#Step-1:-import-related-libraries)\
[Step 2: setting up configuration](#Step-2:-setting-up-configuration)\
[Step 3: setting up nnAudio basis functions](#Step-3:-setting-up-nnAudio-basis-functions)\
[Step 4: loading the dataset](#Step-4:-loading-the-dataset)\
[Step 5: data processing and loading](#Step-5:-data-processing-and-loading)\
[Step 6: setting up the lightning module](#Step-6:-setting-up-the-lightning-module)\
[Step 7: defining the model](#Step-7:-defining-the-model)\
[Step 8: loading pre-trained weight to the model](#Step-8:-loading-pre-trained-weight-to-the-model)\
[Step 9: evaluating the model performance](#Step-9:-evaluating-the-model-performance)

[Visualizing the result](#Visualizing-the-result)
* [Visualizing the Mel bins](#Visualizing-the-Mel-bins)
* [Visualizing the short-time Fourier transform (STFT)](#Visualizing-the-STFT)



## Step 1: import related libraries

In [None]:
# Libraries related to PyTorch
import torch
from torch import Tensor 
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Libraries related to PyTorch Lightning
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule  # the old `pytorch_lightning.core.lightning` import path is deprecated

# Libraries used for visualization
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import confusion_matrix
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import numpy as np
import re
import itertools

# Libraries related to the dataset
from AudioLoader.Speech import SPEECHCOMMANDS_12C  # for the 12-class KWS task
from AudioLoader.Speech import idx2name, name2idx

# nnAudio Front-end
from nnAudio.features.mel import MelSpectrogram, STFT


## Step 2: setting up configuration

In [None]:
device = 'cuda:0'
gpus = 1
batch_size= 100
max_epochs = 200
check_val_every_n_epoch = 2
num_sanity_val_steps = 5

data_root= './' # Download the data here
download_option= False

n_mels = 40  # number of Mel bins

# 101 time frames per 16000-sample input (hop_length=160, center=True)
input_dim = n_mels * 101
output_dim = 12  # number of KWS classes

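The hard-coded `101` in `input_dim` comes from the framing arithmetic of the STFT. A minimal sketch, assuming 16000-sample inputs (1 s at 16 kHz) and `center=True` padding as in the `MelSpectrogram` settings of Step 3; `n_frames` is a hypothetical helper name used only for illustration:

```python
def n_frames(n_samples: int, hop_length: int, center: bool = True, n_fft: int = 480) -> int:
    """Number of STFT frames for a signal of n_samples samples."""
    if center:
        # center=True pads n_fft//2 samples on both sides -> 1 + n_samples // hop_length frames
        return 1 + n_samples // hop_length
    # without padding, only full windows of n_fft samples are used
    return 1 + (n_samples - n_fft) // hop_length


sr = 16000          # sampling rate (Hz); 1 s of audio = 16000 samples
hop_length = 160    # 10 ms hop
n_mels = 40

frames = n_frames(sr, hop_length)
input_dim = n_mels * frames
print(frames, input_dim)  # 101 4040
```

Because the linear layer expects exactly `n_mels * 101` inputs, a batch whose longest waveform is shorter than 16000 samples would produce fewer frames and a shape mismatch in the linear layer.
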
## Step 3: setting up nnAudio basis functions

* The model weights inside the `trained_weight_for_tutorial3` folder were trained with `setting D` (both Mel and STFT trainable).
* If you are using your own trained weights, make sure the `MelSpectrogram()` arguments below match the settings of your experiment.

In [None]:
mel_layer = MelSpectrogram(sr=16000, 
                           n_fft=480,
                           win_length=None,
                           n_mels=n_mels, 
                           hop_length=160,
                           window='hann',
                           center=True,
                           pad_mode='reflect',
                           power=2.0,
                           htk=False,
                           fmin=0.0,
                           fmax=None,
                           norm=1,
                           trainable_mel=True,
                           trainable_STFT=True,
                           verbose=True)

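As a quick sanity check on these settings, the sketch below works out the time and frequency resolution implied by `sr=16000`, `n_fft=480` and `hop_length=160`. It is plain arithmetic (no nnAudio calls); with `win_length=None` the window is assumed to span the full `n_fft` samples, and `fmax=None` is taken to mean the Nyquist frequency `sr/2`.

```python
sr = 16000        # sampling rate (Hz)
n_fft = 480       # FFT size (samples)
hop_length = 160  # hop between frames (samples)

window_ms = 1000 * n_fft / sr        # 30.0 ms analysis window (win_length=None -> n_fft)
hop_ms = 1000 * hop_length / sr      # 10.0 ms between frames
n_freq_bins = n_fft // 2 + 1         # 241 one-sided frequency bins
bin_spacing_hz = sr / n_fft          # ~33.33 Hz between neighboring bins
nyquist_hz = sr / 2                  # 8000.0 Hz, the upper edge when fmax=None

# center frequency of every STFT bin; the last one lands exactly on the Nyquist frequency
bin_freqs_hz = [k * sr / n_fft for k in range(n_freq_bins)]

print(window_ms, hop_ms, n_freq_bins, round(bin_spacing_hz, 2), bin_freqs_hz[-1])
```

The `n_freq_bins = 241` computed here is the second dimension of the Mel basis and the first dimension of the STFT kernels inspected at the end of this tutorial.
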
## Step 4: loading the dataset

In [None]:
testset = SPEECHCOMMANDS_12C(root=data_root,
                              url='speech_commands_v0.02',
                              folder_in_archive='SpeechCommands',
                              download= download_option,subset= 'testing')

## Step 5: data processing and loading

In [None]:
# Collate function: pad each batch of waveforms to the same length
def data_processing(data):
    waveforms = []
    labels = []
    
    for batch in data:
        waveforms.append(batch[0].squeeze(0))  # [1, audio_len] -> [audio_len]: remove the channel dim
        labels.append(batch[2])
        
    # Zero-pad every waveform to the longest one in the batch -> [B, max_len]
    waveform_padded = nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    
    output_batch = {'waveforms': waveform_padded,
                    'labels': torch.tensor(labels),
                    }
    return output_batch

# Data loading
testloader = DataLoader(testset,
                        collate_fn=data_processing,  # a named function (unlike a lambda) can be pickled by worker processes
                        batch_size=batch_size,
                        num_workers=1)

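`pad_sequence(..., batch_first=True)` zero-pads every waveform at the end to the length of the longest one in the batch. The plain-Python sketch below mimics that behavior on lists so the shapes are easy to follow; `pad_batch` is a hypothetical helper for illustration, not part of PyTorch.

```python
from typing import List


def pad_batch(sequences: List[List[float]], padding_value: float = 0.0) -> List[List[float]]:
    """Right-pad each sequence to the longest length, like pad_sequence(batch_first=True)."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [padding_value] * (max_len - len(seq)) for seq in sequences]


# three toy "waveforms" of different lengths
batch = [[0.1, 0.2, 0.3], [0.5], [0.7, 0.8]]
padded = pad_batch(batch)
for row in padded:
    print(row)
# [0.1, 0.2, 0.3]
# [0.5, 0.0, 0.0]
# [0.7, 0.8, 0.0]
```

The padded rows all share one length, which is what lets them be stacked into a single `[B, max_len]` tensor.
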
## Step 6: setting up the lightning module

In [None]:
class SpeechCommand(LightningModule):     
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        
        optimizer.step(closure=optimizer_closure)
        # after the optimizer step, clamp the Mel basis to [0, 1]
        with torch.no_grad():
            torch.clamp_(self.mel_layer.mel_basis, 0, 1)
        
    def test_step(self, batch, batch_idx):               
        outputs, spec = self(batch['waveforms'])
        loss = self.criterion(outputs, batch['labels'].long())        

        self.log('Test/Loss', loss, on_step=False, on_epoch=True)          
        
        output_dict = {'outputs': outputs,
                       'labels': batch['labels']}        
        return output_dict
    

    def test_epoch_end(self, outputs):
        pred = []
        label = []
        for output in outputs:
            pred.append(output['outputs'])
            label.append(output['labels'])
        label = torch.cat(label, 0)
        pred = torch.cat(pred, 0)
        
        result_dict = {}
        for key in [None, 'micro', 'macro', 'weighted']:
            result_dict[key] = {}
            p, r, f1, _ = precision_recall_fscore_support(label.cpu(), pred.argmax(-1).cpu(), average=key, zero_division=0)
            result_dict[key]['precision'] = p
            result_dict[key]['recall'] = r
            result_dict[key]['f1'] = f1
            
        barplot(result_dict, 'precision', figsize=(4,6))
        barplot(result_dict, 'recall',figsize=(4,6))
        barplot(result_dict, 'f1',figsize=(4,6))
            
        acc = (pred.argmax(-1) == label).float().mean()  # fraction of correct predictions
        self.log('Test/acc', acc, on_step=False, on_epoch=True)
        
        self.log('Test/micro_f1', result_dict['micro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/macro_f1', result_dict['macro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/weighted_f1', result_dict['weighted']['f1'], on_step=False, on_epoch=True)
        
        cm = plot_confusion_matrix(label.cpu(),
                                   pred.argmax(-1).cpu(),
                                   [idx2name[i] for i in range(len(name2idx))],  # class names in index order
                                   title='Test: Confusion matrix',
                                   normalize=False)                    
        return result_dict

    
def plot_confusion_matrix(correct_labels,
                          predict_labels,
                          labels,
                          title='Confusion matrix',
                          normalize=False):
    '''
    Plot a confusion matrix with matplotlib.

    Parameters:
        correct_labels : the true class indices.
        predict_labels : the predicted class indices.
        labels         : list of class names, used as the axis tick labels.
        title          : title of the figure.
        normalize      : if True, each row is scaled to tenths of its row total (integers 0 to 10).
    Returns:
        fig: the matplotlib Figure containing the confusion matrix.
    Notes:
        - Depending on the number of categories and the data, you may need to adjust the figure and font sizes.
        - Some tick labels may not line up due to rotation.
    '''
    cm = confusion_matrix(correct_labels, predict_labels, labels=range(len(labels)))
    if normalize:
        cm = cm.astype('float')*10 / cm.sum(axis=1)[:, np.newaxis]
        cm = np.nan_to_num(cm, copy=True)
        cm = cm.astype('int')

    np.set_printoptions(precision=2)

    fig, ax = plt.subplots(1, 1, figsize=(4.5, 4.5), dpi=160, facecolor='w', edgecolor='k')
    fig.suptitle(title, fontsize=7)  # use the `title` argument instead of a fixed string
    im = ax.imshow(cm, cmap='Oranges')

    classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]
    #classes = ['\n'.join(l) for l in classes]

    tick_marks = np.arange(len(classes))

    ax.set_xlabel('Predicted', fontsize=7)
    ax.set_xticks(tick_marks)
    c = ax.set_xticklabels(classes, fontsize=5, rotation=0,  ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=7)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, fontsize=5, va ='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "black")
    fig.tight_layout()  # `set_tight_layout` is deprecated in newer Matplotlib

    return fig

def barplot(result_dict, title, figsize=(4,12), minor_interval=0.2, log=False):
    fig, ax = plt.subplots(1,1, figsize=figsize)
    metric = {}
    for idx, item in enumerate(result_dict[None][title]):
        metric[idx2name[idx]] = item
    xlabels = list(metric.keys())
    values = list(metric.values())
    if log:
        values = np.log(values)
    ax.barh(xlabels, values)
    ax.tick_params(labeltop=True, labelright=False)
    ax.xaxis.grid(True, which='minor')
    ax.xaxis.set_minor_locator(MultipleLocator(minor_interval))
    ax.set_ylim([-1,len(xlabels)])
    ax.set_title(title)
    ax.grid(axis='x')
    ax.grid(True, which='minor', linestyle='--')  # the `b=` keyword was renamed to `visible=` in Matplotlib 3.5
    fig.tight_layout()  # adjust the layout before saving so the edges are not cut off
    fig.savefig(f'{title}.png', bbox_inches='tight')
    return fig

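For reference, `confusion_matrix` counts how often true class `i` (row) was predicted as class `j` (column), and the `normalize=True` branch of `plot_confusion_matrix` rescales each row to tenths of its total before casting to integers. A minimal pure-Python sketch of both steps, with hypothetical helper names and toy labels:

```python
from typing import List


def confusion_counts(y_true: List[int], y_pred: List[int], n_classes: int) -> List[List[int]]:
    """cm[i][j] = number of samples with true class i predicted as class j."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def normalize_rows_to_tenths(cm: List[List[int]]) -> List[List[int]]:
    """Mirror of the normalize branch: row * 10 / row_sum, empty rows -> 0, then truncate to int."""
    out = []
    for row in cm:
        total = sum(row)
        out.append([int(v * 10 / total) if total else 0 for v in row])
    return out


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_counts(y_true, y_pred, n_classes=3)
print(cm)                            # [[1, 1, 0], [0, 2, 0], [1, 0, 1]]
print(normalize_rows_to_tenths(cm))  # [[5, 5, 0], [0, 10, 0], [5, 0, 5]]
```

Correct predictions sit on the diagonal; off-diagonal cells show which keywords are confused with which.
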
## Step 7: defining the model 

In [None]:
class Linearmodel_nnAudio(SpeechCommand):
    def __init__(self): 
        super().__init__()
        self.mel_layer = mel_layer       
        self.criterion = nn.CrossEntropyLoss()
        self.linearlayer = nn.Linear(input_dim, output_dim)

    
    def forward(self, x): 
        #x: 2D [B, 16000]
        spec = self.mel_layer(x)  
        #spec: 3D [B, F40, T101]
        
        spec = torch.log(spec+1e-10)
        
        flatten_spec = torch.flatten(spec, start_dim=1) 
        #flatten_spec: 2D [B, F*T(40*101)] 
        # start_dim=1: flatten every dimension except the batch dimension
        
        out = self.linearlayer(flatten_spec) 
        #out: 2D [B,number of class(12)]                               
        return out, spec 

model_nnAudo = Linearmodel_nnAudio()
model_nnAudo = model_nnAudo.to(device)

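The shape comments in `forward` also make it easy to count how many values each part of the model holds. A small arithmetic sketch using the shapes stated in this tutorial (`mel_basis` is `[40, 241]`, `wsin`/`wcos` are `[241, 1, 480]`, and `nn.Linear` stores its weight as `[out_features, in_features]`); the dictionary keys mirror the `state_dict` names shown later, and the counts are element counts, not a measured parameter summary:

```python
from math import prod

n_mels, n_frames, n_classes = 40, 101, 12
n_fft = 480
n_freq_bins = n_fft // 2 + 1  # 241

# forward(): [B, 16000] -> mel_layer -> [B, 40, 101] -> flatten -> [B, 4040] -> linear -> [B, 12]
flat_dim = n_mels * n_frames

# number of elements in each tensor, from the shapes stated in the tutorial
sizes = {
    "linearlayer.weight": prod([n_classes, flat_dim]),
    "linearlayer.bias": n_classes,
    "mel_layer.mel_basis": prod([n_mels, n_freq_bins]),
    "mel_layer.stft.wsin": prod([n_freq_bins, 1, n_fft]),
    "mel_layer.stft.wcos": prod([n_freq_bins, 1, n_fft]),
}
for name, n in sizes.items():
    print(f"{name:22s} {n:>8d}")
print("total".ljust(22), sum(sizes.values()))
```

If both basis functions are trainable (setting D), most of the front-end's values sit in the two STFT kernels (2 x 115680) rather than in the Mel basis (9640).
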
## Step 8: loading pre-trained weight to the model
Every time you train a model as in the Part 2 tutorial, the trained weights are saved in the `lightning_logs` folder.

We have prepared a checkpoint file in the `trained_weight_for_tutorial3` folder, which is used for the demonstration below.

Details of the trained weights:
* Linear model for the keyword spotting (KWS) task
* Setting D: Both Mel and STFT are trainable ( `trainable_mel=True, trainable_STFT=True`)
* `n_mels = 40`
* Test/acc = 45.1%

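If you trained several models in Part 2, the default logger creates a new `version_*` subfolder for each run, and locating the right `.ckpt` file by hand can get tedious. A small standard-library sketch that lists every checkpoint below a folder, newest first; `list_checkpoints` is a hypothetical helper, and the exact folder names depend on how your logger was configured:

```python
from pathlib import Path
from typing import List


def list_checkpoints(root: str) -> List[Path]:
    """Return all *.ckpt files below `root`, most recently modified first."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []  # nothing to search
    ckpts = root_path.rglob("*.ckpt")
    return sorted(ckpts, key=lambda p: p.stat().st_mtime, reverse=True)


for ckpt in list_checkpoints("lightning_logs"):
    print(ckpt)
```

Any path it prints can be passed to `load_from_checkpoint` in place of the prepared checkpoint.
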
In [None]:
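# load_from_checkpoint is a classmethod, so call it on the class rather than on an instance
trained_weight = Linearmodel_nnAudio.load_from_checkpoint('./trained_weight_for_tutorial3/Linearmodel_nnAudio-speechcommand-mel=trainable-STFT=trainable/version_1/checkpoints/last.ckpt')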

## Step 9: evaluating the model performance 
Model performance on the KWS task is evaluated on the test set with the following metrics:
* Test/Loss (cross-entropy loss)
* Test/acc (accuracy)
* Precision, recall and F1 scores (per class, plus micro, macro and weighted F1)
* Confusion matrix

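To see how the micro, macro and weighted averages reported by `precision_recall_fscore_support` differ, here is a pure-Python sketch of the same definitions on toy labels. `prf_averages` is a hypothetical helper; like `zero_division=0`, an undefined ratio is treated as 0. One assumption to note: without a `labels` argument, scikit-learn averages only over classes that appear in `y_true` or `y_pred`, whereas this sketch assumes all `n_classes` appear.

```python
from typing import Dict, List


def _safe_div(a: float, b: float) -> float:
    return a / b if b else 0.0  # mirrors zero_division=0


def prf_averages(y_true: List[int], y_pred: List[int], n_classes: int) -> Dict[str, float]:
    tp, fp, fn = [0] * n_classes, [0] * n_classes, [0] * n_classes
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but it was wrong
            fn[t] += 1  # true class t was missed
    support = [tp[c] + fn[c] for c in range(n_classes)]

    f1 = []
    for c in range(n_classes):
        prec = _safe_div(tp[c], tp[c] + fp[c])
        rec = _safe_div(tp[c], tp[c] + fn[c])
        f1.append(_safe_div(2 * prec * rec, prec + rec))

    # micro: pool the counts over all classes first, then compute P, R and F1
    micro_p = _safe_div(sum(tp), sum(tp) + sum(fp))
    micro_r = _safe_div(sum(tp), sum(tp) + sum(fn))
    weighted = _safe_div(sum(f * s for f, s in zip(f1, support)), sum(support))
    return {
        "accuracy": _safe_div(sum(tp), len(y_true)),
        "micro_f1": _safe_div(2 * micro_p * micro_r, micro_p + micro_r),
        "macro_f1": sum(f1) / n_classes,  # every class counts equally
        "weighted_f1": weighted,          # classes weighted by their support
    }


print(prf_averages([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], n_classes=3))
```

For single-label multi-class data like this, micro F1 equals accuracy (both 4/6 here), so macro and weighted F1 are the averages that reveal differences between keywords.
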

In [None]:
trainer = Trainer(gpus=gpus, max_epochs=max_epochs,
    check_val_every_n_epoch= check_val_every_n_epoch,
    num_sanity_val_steps=num_sanity_val_steps)

trainer.test(trained_weight, testloader)

# Visualizing the result 

* [Visualizing the Mel bins](#Visualizing-the-Mel-bins)
* [Visualizing the short-time Fourier transform (STFT)](#Visualizing-the-STFT)

We can visualize some of the learned kernels in the first layer of our model (the nnAudio front-end), since their weights are stored in the checkpoint file.

The structure inside the checkpoint file looks like this:
```
weight=torch.load('xxxx/checkpoints/xxxx.ckpt')
├── epoch
├── global_step
├── pytorch-lightning_version
│     
├── state_dict
│     ├─ mel_layer.mel_basis
│     ├─ mel_layer.stft.wsin
│     ├─ mel_layer.stft.wcos
│     ├─ mel_layer.stft.window_mask   
│     ├─ linearlayer.weight
│     ├─ linearlayer.bias
│     │
│     
├── callbacks
├── optimizer_states
├── lr_schedulers
```
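Rather than printing the keys level by level, a short recursive helper can print the whole nesting at once. A minimal sketch (`print_tree` is a hypothetical helper): it works on any nested `dict`, including the `OrderedDict` stored under `'state_dict'`, and shows tensor shapes instead of values. Here it is demonstrated on a hand-made dictionary that only imitates the checkpoint layout above.

```python
from collections import OrderedDict
from typing import Any, List, Optional


def print_tree(d: Any, indent: str = "", lines: Optional[List[str]] = None) -> List[str]:
    """Print (and return) the keys of a nested dict as an indented tree."""
    if lines is None:
        lines = []
    for key, value in d.items():
        line = f"{indent}{key}"
        if hasattr(value, "shape"):  # tensors / arrays: show the shape instead of the values
            line += f"  {tuple(value.shape)}"
        lines.append(line)
        print(line)
        if isinstance(value, dict):  # OrderedDict is a dict subclass
            print_tree(value, indent + "    ", lines)
    return lines


# toy stand-in for torch.load(...) output; in a real checkpoint the leaves are tensors
fake_ckpt = {
    "epoch": 0,
    "global_step": 0,
    "state_dict": OrderedDict([("mel_layer.mel_basis", None), ("linearlayer.bias", None)]),
}
print_tree(fake_ckpt)
```

With the real checkpoint loaded as `weight` in the next cell, `print_tree(weight)` would list every key together with the shape of each stored tensor.
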

`torch.load('xxxx/checkpoints/xxxx.ckpt')` returns a dictionary; its keys can be checked as follows:

In [None]:
# map_location='cpu' lets the checkpoint load on a machine without a GPU
weight = torch.load('trained_weight_for_tutorial3/Linearmodel_nnAudio-speechcommand-mel=trainable-STFT=trainable/version_1/checkpoints/last.ckpt', map_location='cpu')

In [None]:
weight.keys()

`'state_dict'` is one of the keys in the checkpoint dictionary. It is an `OrderedDict` that contains the **trained weights of the basis functions (Mel bins, STFT) and of the layers (the linear layer in this case)**.\
The keys of `'state_dict'` can be checked as follows:



In [None]:
weight['state_dict'].keys()

## Visualizing the Mel bins
The shape of `mel_layer.mel_basis` should be `[n_mels, n_fft//2 + 1]`, where `n_mels` is the number of Mel bins and `n_fft` is the FFT size, i.e. the length of the windowed signal after zero-padding.\
In this tutorial (`n_mels=40`, `n_fft=480`), the shape of `mel_layer.mel_basis` is `[40, 241]`.

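Before training, each row of `mel_basis` is a triangular filter whose center is spaced evenly on the Mel scale. With `htk=False`, librosa-style filterbanks use the Slaney Mel scale, which is linear below 1 kHz and logarithmic above. The sketch below assumes nnAudio builds its initial basis the same way and computes the 40 center frequencies for `fmin=0` and `fmax=8000` (`sr/2`); all helper names are hypothetical. Since the basis is trainable, the filters in the checkpoint may have drifted away from these initial positions.

```python
import math
from typing import List

F_SP = 200.0 / 3                  # Hz per Mel in the linear region
MIN_LOG_HZ = 1000.0               # start of the logarithmic region
MIN_LOG_MEL = MIN_LOG_HZ / F_SP   # = 15 Mel
LOGSTEP = math.log(6.4) / 27.0    # step size in the logarithmic region


def hz_to_mel(f: float) -> float:
    if f < MIN_LOG_HZ:
        return f / F_SP
    return MIN_LOG_MEL + math.log(f / MIN_LOG_HZ) / LOGSTEP


def mel_to_hz(m: float) -> float:
    if m < MIN_LOG_MEL:
        return F_SP * m
    return MIN_LOG_HZ * math.exp(LOGSTEP * (m - MIN_LOG_MEL))


def mel_centers_hz(n_mels: int, fmin: float, fmax: float) -> List[float]:
    """Center frequencies of n_mels triangular filters (n_mels + 2 edge points, evenly spaced in Mel)."""
    lo, hi = hz_to_mel(fmin), hz_to_mel(fmax)
    points = [lo + i * (hi - lo) / (n_mels + 1) for i in range(n_mels + 2)]
    return [mel_to_hz(m) for m in points[1:-1]]


centers = mel_centers_hz(40, 0.0, 8000.0)
print([round(c) for c in centers[:5]], "...", [round(c) for c in centers[-3:]])
```

The centers are evenly spaced below 1 kHz and grow further apart above it, which is why the untrained Mel bins get wider toward the right of the plots below.
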

In [None]:
mel_bins = weight['state_dict']['mel_layer.mel_basis']

**An individual Mel bin can be plotted as follows:**
```python
plt.plot(mel_bins[i].cpu().detach().numpy())
```
Replace `i` with the index of the Mel bin (from `0` to `n_mels - 1`).

In [None]:
plt.plot(mel_bins[0].cpu().detach().numpy())
plt.title('Amplitude of an individual Mel bin')
plt.xlabel('Frequency bin index')
plt.ylabel('Amplitude')
plt.show()

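The x-axis of these plots is the STFT bin index `k`, which corresponds to `k * sr / n_fft` Hz. The sketch below turns one row of `mel_basis` (as a plain list) into the lowest, peak and highest frequency where the filter is above a threshold; `describe_filter` is a hypothetical helper, and `sr=16000`, `n_fft=480` are this tutorial's settings. For a row of the real checkpoint you could pass `mel_bins[0].cpu().tolist()`.

```python
from typing import List, Tuple


def describe_filter(values: List[float], sr: int = 16000, n_fft: int = 480,
                    threshold: float = 0.0) -> Tuple[float, float, float]:
    """Return (lowest, peak, highest) frequency in Hz where the filter exceeds `threshold`."""
    def to_hz(k: int) -> float:
        return k * sr / n_fft  # STFT bin index -> Hz

    support = [k for k, v in enumerate(values) if v > threshold]
    if not support:
        raise ValueError("no value above the threshold")
    peak = max(range(len(values)), key=lambda k: values[k])
    return to_hz(support[0]), to_hz(peak), to_hz(support[-1])


# toy triangular filter over 241 bins: rises from bin 27 to a peak at bin 30, falls to bin 33
toy = [0.0] * 241
for k, v in zip(range(27, 34), [0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25]):
    toy[k] = v
print(describe_filter(toy))  # (900.0, 1000.0, 1100.0)
```

After training, a bin may keep small non-zero values far from its peak, so a threshold above 0 can give a more meaningful bandwidth.
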

**All 40 Mel bins can be plotted together as follows:**

In [None]:
for i in mel_bins:
    plt.plot(i.cpu().detach().numpy()) 
    
plt.title('Amplitude of 40 Mel bins')
plt.xlabel('Frequency bin index')
plt.ylabel('Amplitude')
plt.show()

## Visualizing the STFT
The shape of `'mel_layer.stft.wsin'` and `'mel_layer.stft.wcos'` should be `[n_fft//2 + 1, 1, n_fft]`.\
For the linear keyword spotting model in this tutorial (`n_fft=480`), both have shape `[241, 1, 480]`.

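Each pair `wsin[k, 0]`, `wcos[k, 0]` is expected to start out as a windowed sine/cosine with `k` cycles per `n_fft` samples, i.e. a frequency of `k * sr / n_fft`. The sketch below rebuilds such kernels in plain Python under that assumption: the untrained kernels are taken to be `sin(2πkn/N)` and `cos(2πkn/N)` multiplied by a periodic Hann window (`window='hann'`). `hann_periodic` and `fourier_kernel` are hypothetical helpers; comparing their output with the trained kernels gives a rough idea of how far training moved them.

```python
import math
from typing import List, Tuple


def hann_periodic(n_fft: int) -> List[float]:
    # periodic Hann window: 0.5 - 0.5*cos(2*pi*n/N), n = 0..N-1
    return [0.5 - 0.5 * math.cos(2 * math.pi * n / n_fft) for n in range(n_fft)]


def fourier_kernel(k: int, n_fft: int = 480) -> Tuple[List[float], List[float]]:
    """Windowed sine and cosine kernels for frequency bin k (assumed initial form)."""
    w = hann_periodic(n_fft)
    sin_k = [math.sin(2 * math.pi * k * n / n_fft) * w[n] for n in range(n_fft)]
    cos_k = [math.cos(2 * math.pi * k * n / n_fft) * w[n] for n in range(n_fft)]
    return sin_k, cos_k


sr, n_fft = 16000, 480
for k in [2, 10, 20, 50]:  # the kernel indices used in the plots of this section
    print(f"kernel {k:2d}: {k * sr / n_fft:7.1f} Hz")
```

For example, `max(abs(a - b) for a, b in zip(wsin[20, 0].tolist(), fourier_kernel(20)[0]))` would give the largest deviation of the trained kernel 20 from this assumed initial form.
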
In [None]:
wsin = weight['state_dict']['mel_layer.stft.wsin']
wcos = weight['state_dict']['mel_layer.stft.wcos']

In [None]:
# Visualizing STFT_wsin
fig, axes = plt.subplots(2, 2)
for ax, kernel_num in zip(axes.flatten(), [2, 10, 20, 50]):
    ax.plot(wsin[kernel_num, 0].cpu().numpy())
    ax.set_ylim(-1, 1)
fig.suptitle('Visualizing STFT_wsin')  # one title for the whole figure, set once

plt.setp(axes[-1, :], xlabel='Sample index')
plt.setp(axes[:, 0], ylabel='Amplitude')
plt.show()

In [None]:
# Visualizing STFT_wcos
fig, axes = plt.subplots(2, 2)
for ax, kernel_num in zip(axes.flatten(), [2, 10, 20, 50]):
    ax.plot(wcos[kernel_num, 0].cpu().numpy())
    ax.set_ylim(-1, 1)
fig.suptitle('Visualizing STFT_wcos')  # one title for the whole figure, set once

plt.setp(axes[-1, :], xlabel='Sample index')
plt.setp(axes[:, 0], ylabel='Amplitude')
plt.show()

# Congratulations!  You have finished the Part 3 tutorial.
Feel free to move on to `Part 4: Using more complex non-linear models`