# Part 2: Evaluation with Trainable Basis Functions

In [None]:
# Libraries related to PyTorch
import torch
from torch import Tensor
import torchaudio 
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import WeightedRandomSampler,DataLoader
import torch.optim as optim

# Libraries related to hydra
import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig, OmegaConf

# custom packages
import models as Model 

# Libraries related to PyTorch Lightning
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule

# Libraries used in the Lightning module
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from sklearn.metrics import precision_recall_fscore_support
from dataset.speechcommands import idx2name, name2idx
from sklearn.metrics import confusion_matrix
import numpy as np
import re
import itertools

# Libraries related to the dataset
from AudioLoader.Speech import SPEECHCOMMANDS_12C #for 12 classes KWS task

# nnAudio Front-end
from nnAudio.features.mel import MelSpectrogram, STFT


# Configuration 

In [None]:
# Use the first CUDA device when one is present; otherwise fall back to CPU so
# the notebook still runs end-to-end on a machine without a GPU.
_cuda_available = torch.cuda.is_available()
device = 'cuda:0' if _cuda_available else 'cpu'
gpus = 1 if _cuda_available else 0
batch_size = 100
max_epochs = 200
check_val_every_n_epoch = 2
num_sanity_val_steps = 5

data_root = './'  # Download the data here
download_option = False

n_mels = 40
# number of Mel bases

# Flattened spectrogram size fed to the linear layer: n_mels bins x 101 frames.
# NOTE(review): 101 presumably corresponds to 1 s clips at 16 kHz with
# hop_length=160 and center=True (16000 // 160 + 1) -- confirm if the
# front-end settings change.
input_dim = (n_mels * 101)
output_dim = 12

random_mel = False
# To control random initial mel bases

# nnAudio guideline for trainable basis functions 
```
nnAudio.features.mel.MelSpectrogram(trainable_mel= ,trainable_STFT=) 
```
The two arguments above control whether the Mel bases and the STFT kernels are trainable:

* A. Both Mel and STFT are non-trainable: 
`trainable_mel=False, trainable_STFT=False`
* B. Mel is trainable while STFT is fixed: 
`trainable_mel=True, trainable_STFT=False`
* C. Mel is fixed while STFT is trainable: 
`trainable_mel=False, trainable_STFT=True`
* D. Both Mel and STFT are trainable:
`trainable_mel=True, trainable_STFT=True`

In [None]:
# nnAudio Mel-spectrogram front-end shared by the model defined later in this notebook.
mel_layer = MelSpectrogram(
    # --- framing / STFT settings ---
    sr=16000,
    n_fft=480,
    win_length=None,
    hop_length=160,
    window='hann',
    center=True,
    pad_mode='reflect',
    power=2.0,
    # --- Mel filter bank settings ---
    n_mels=n_mels,
    htk=False,
    fmin=0.0,
    fmax=None,
    norm=1,
    # --- trainability (setting B in the list above: Mel trainable, STFT fixed) ---
    trainable_mel=True,
    trainable_STFT=False,
    verbose=True,
)

# Setting up dataset

In [None]:
# Testing split of the 12-class Speech Commands dataset, read from `data_root`.
testset = SPEECHCOMMANDS_12C(
    root=data_root,
    url='speech_commands_v0.02',
    folder_in_archive='SpeechCommands',
    download=download_option,
    subset='testing',
)

# Data processing and loading

In [None]:
#Data processing
def data_processing(data):
    """Collate a list of dataset samples into one padded batch.

    Each sample is indexed positionally: element 0 is the waveform tensor
    (its leading dimension is squeezed away) and element 2 is the label.

    Parameters:
        data : iterable of samples as described above

    Returns:
        dict with
            'waveforms' : waveforms zero-padded to a common length, batch first
            'labels'    : tensor built from the collected labels
    """
    # Walk the samples once, keeping waveform and label side by side.
    pairs = [(sample[0].squeeze(0), sample[2]) for sample in data]
    audio = [wave for wave, _ in pairs]
    targets = [target for _, target in pairs]

    return {
        'waveforms': nn.utils.rnn.pad_sequence(audio, batch_first=True),
        'labels': torch.tensor(targets),
    }

# load data
# Pass the collate function directly: wrapping it in a lambda adds nothing and
# a lambda cannot be pickled if worker processes are started with "spawn".
testloader = DataLoader(testset,
                        collate_fn=data_processing,
                        batch_size=batch_size)

# Lightning module

In [None]:
class SpeechCommand(LightningModule):
    """Shared Lightning train / validation / test logic for the speech-command models.

    The methods below rely on the subclass providing:
        * ``self.mel_layer`` : front-end exposing ``mel_basis`` and ``stft.wsin`` / ``stft.wcos``
        * ``self.criterion`` : loss called as ``criterion(outputs, labels)``
        * ``forward(x)``     : returning ``(outputs, spec)``; ``outputs`` is used for the
                               loss / argmax and ``spec`` is only used for plotting
    """

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log accuracy, loss and figures.

        Figures (spectrograms, confusion matrix) are logged for the first batch
        of each epoch only.
        """
        outputs, spec = self(batch['waveforms'])
        # outputs [2D] is used for the loss, spec [3D] is used for visualisation
        loss = self.criterion(outputs, batch['labels'].long())

        # batch-wise accuracy
        acc = (outputs.argmax(-1) == batch['labels']).float().mean()

        self.log('Train/acc', acc, on_step=False, on_epoch=True)
        if batch_idx == 0:
            self.log_images(spec, 'Train/Spec')
            cm = plot_confusion_matrix(batch['labels'].cpu(),
                                       outputs.argmax(-1).cpu(),
                                       name2idx.keys(),
                                       title='Train: Confusion matrix',
                                       normalize=False)
            self.logger.experiment.add_figure('Train/confusion_maxtrix', cm, global_step=self.current_epoch)
        self.log('Train/Loss', loss, on_step=False, on_epoch=True)
        return loss

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        """Run the optimizer step, then clamp the Mel bases to the range [0, 1]."""
        optimizer.step(closure=optimizer_closure)
        # The clamp is applied outside autograd, directly after the update
        # (only applicable for nnAudio; FastAudio clamps internally).
        with torch.no_grad():
            torch.clamp_(self.mel_layer.mel_basis, 0, 1)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and, for the first batch, diagnostic figures.

        Returns:
            dict with the raw 'outputs' and the 'labels' of this batch, consumed
            by ``validation_epoch_end``.
        """
        outputs, spec = self(batch['waveforms'])
        loss = self.criterion(outputs, batch['labels'].long())

        self.log('Validation/Loss', loss, on_step=False, on_epoch=True)

        if batch_idx == 0:
            self._log_mel_filter_banks('Validation/MelFilterBanks')
            # The STFT is included in MelSpectrogram(), hence mel_layer.stft
            self._log_stft_kernels(self.mel_layer.stft.wsin, 'sin', 'Validation/sin')
            self._log_stft_kernels(self.mel_layer.stft.wcos, 'cos', 'Validation/cos')
            # Spectrograms are logged for the first batch only, as in the
            # training and test steps.
            self.log_images(spec, 'Validation/Spec')

        return {'outputs': outputs,
                'labels': batch['labels']}

    def validation_epoch_end(self, outputs):
        """Log the epoch-wise validation accuracy and confusion matrix.

        Parameters:
            outputs : list of the dicts returned by ``validation_step``
        """
        label, pred = self._gather_outputs(outputs)
        acc = (pred.argmax(-1) == label).float().mean()

        cm = plot_confusion_matrix(label.cpu(),
                                   pred.argmax(-1).cpu(),
                                   name2idx.keys(),
                                   title='Validation: Confusion matrix',
                                   normalize=False)
        self.logger.experiment.add_figure('Validation/confusion_maxtrix', cm, global_step=self.current_epoch)

        self.log('Validation/acc', acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Log the test loss and, for the first batch, the Mel bases and spectrograms.

        Returns:
            dict with the raw 'outputs' and the 'labels' of this batch, consumed
            by ``test_epoch_end``.
        """
        outputs, spec = self(batch['waveforms'])
        loss = self.criterion(outputs, batch['labels'].long())

        self.log('Test/Loss', loss, on_step=False, on_epoch=True)

        if batch_idx == 0:
            self._log_mel_filter_banks('Test/MelFilterBanks')
            self.log_images(spec, 'Test/Spec')

        return {'outputs': outputs,
                'labels': batch['labels']}

    def test_epoch_end(self, outputs):
        """Compute test metrics, log them, save bar plots and ``result_dict.pt``.

        Parameters:
            outputs : list of the dicts returned by ``test_step``

        Returns:
            dict keyed by averaging mode (None, 'micro', 'macro', 'weighted'),
            each holding 'precision', 'recall' and 'f1'.
        """
        label, pred = self._gather_outputs(outputs)
        label_cpu = label.cpu()
        pred_cpu = pred.argmax(-1).cpu()

        result_dict = {}
        for key in [None, 'micro', 'macro', 'weighted']:
            p, r, f1, _ = precision_recall_fscore_support(label_cpu, pred_cpu, average=key, zero_division=0)
            result_dict[key] = {'precision': p, 'recall': r, 'f1': f1}

        # barplot() saves each figure to disk; close it afterwards because
        # nothing else holds on to it.
        for metric_name in ('precision', 'recall', 'f1'):
            plt.close(barplot(result_dict, metric_name))

        acc = (pred.argmax(-1) == label).float().mean()
        self.log('Test/acc', acc, on_step=False, on_epoch=True)

        self.log('Test/micro_f1', result_dict['micro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/macro_f1', result_dict['macro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/weighted_f1', result_dict['weighted']['f1'], on_step=False, on_epoch=True)

        cm = plot_confusion_matrix(label_cpu,
                                   pred_cpu,
                                   name2idx.keys(),
                                   title='Test: Confusion matrix',
                                   normalize=False)
        self.logger.experiment.add_figure('Test/confusion_maxtrix', cm, global_step=self.current_epoch)

        torch.save(result_dict, "result_dict.pt")

        return result_dict

    def log_images(self, tensors, key):
        """Plot up to four entries of ``tensors`` as images and log the figure under ``key``."""
        fig, axes = plt.subplots(2, 2, figsize=(12, 5), dpi=100)
        # zip() stops at the shorter input, so at most four images are drawn
        for ax, tensor in zip(axes.flatten(), tensors):
            ax.imshow(tensor.cpu().detach(), aspect='auto', origin='lower', cmap='jet')
        fig.tight_layout()
        self.logger.experiment.add_figure(f"{key}", fig, global_step=self.current_epoch)
        plt.close(fig)

    def configure_optimizers(self):
        """Build the SGD optimizer with separate parameter groups for the front-end and the model.

        Returns:
            the optimizer; Lightning needs it returned in order to use it.
        """
        # Everything that does not belong to the Mel front-end
        model_param = [params for name, params in self.named_parameters()
                       if 'mel_layer.' not in name]
        # Two groups so that the model and the Mel bases can use different
        # learning rates.
        optimizer = optim.SGD([
                                {"params": list(self.mel_layer.parameters()),
                                 "lr": 1e-3,
                                 "momentum": 0.9,
                                 "weight_decay": 0.001},
                                {"params": model_param,
                                 "lr": 1e-3,
                                 "momentum": 0.9,
                                 "weight_decay": 0.001}
                              ])
        return optimizer

    @staticmethod
    def _gather_outputs(outputs):
        """Concatenate the per-batch step outputs into ``(labels, predictions)`` tensors."""
        label = torch.cat([output['labels'] for output in outputs], 0)
        pred = torch.cat([output['outputs'] for output in outputs], 0)
        return label, pred

    def _log_mel_filter_banks(self, tag):
        """Plot every row of ``mel_layer.mel_basis`` on one axis and log the figure under ``tag``."""
        fig, ax = plt.subplots(1, 1)
        # detach() so plotting never depends on the autograd state
        for basis in self.mel_layer.mel_basis.detach().cpu():
            ax.plot(basis)
        self.logger.experiment.add_figure(tag, fig, global_step=self.current_epoch)

    def _log_stft_kernels(self, kernels, title, tag):
        """Plot four fixed kernels of ``kernels`` in a 2x2 grid and log the figure under ``tag``."""
        fig, axes = plt.subplots(2, 2)
        fig.suptitle(title)
        for ax, kernel_num in zip(axes.flatten(), [2, 10, 20, 50]):
            ax.plot(kernels[kernel_num, 0].detach().cpu())
            ax.set_ylim(-1, 1)
        self.logger.experiment.add_figure(tag, fig, global_step=self.current_epoch)


def barplot(result_dict, title, figsize=(4,12), minor_interval=0.2, log=False):
    '''
    Draw a horizontal bar chart of one per-class metric and save it as '<title>.png'.

    Parameters:
        result_dict     : Mapping whose None entry holds the per-class scores;
                          result_dict[None][title] is iterated with one value per class index
        title           : Metric name; used as lookup key, plot title and output file stem
        figsize         : Figure size passed to plt.subplots
        minor_interval  : Spacing of the minor ticks on the x axis
        log             : If True, the natural log of the values is plotted

    Returns:
        fig: the matplotlib figure (the caller is responsible for closing it)
    '''
    fig, ax = plt.subplots(1,1, figsize=figsize)
    # Class index -> class name, in class-index order
    metric = {idx2name[idx]: item for idx, item in enumerate(result_dict[None][title])}
    xlabels = list(metric.keys())
    values = list(metric.values())
    if log:
        values = np.log(values)
    ax.barh(xlabels, values)
    ax.tick_params(labeltop=True, labelright=False)
    ax.xaxis.grid(True, which='minor')
    ax.xaxis.set_minor_locator(MultipleLocator(minor_interval))
    ax.set_ylim([-1,len(xlabels)])
    ax.set_title(title)
    ax.grid(axis='x')
    # Pass the on/off flag positionally: its keyword name differs between
    # Matplotlib versions (`b` in older releases, `visible` in newer ones).
    ax.grid(True, which='minor', linestyle='--')
    # Lay out before saving so the saved file benefits from it as well
    fig.tight_layout() # prevent edge from missing
    fig.savefig(f'{title}.png', bbox_inches='tight')
    return fig
          
    
def plot_confusion_matrix(correct_labels,
                          predict_labels,
                          labels,
                          title='Confusion matrix',
                          normalize=False):
    ''' 
    Draw a confusion matrix as an annotated heat map.

    Parameters:
        correct_labels                  : True class indices
        predict_labels                  : Predicted class indices
        labels                          : Iterable of class names used for the axis tick labels;
                                          its length fixes the number of classes (indices 0..len-1)
        title='Confusion matrix'        : Accepted for compatibility with existing callers;
                                          it is not drawn on the figure
        normalize=False                 : If True, each row is scaled to 0..10 and truncated to int

    Returns:
        fig: the matplotlib figure holding the confusion matrix

    Other items to note:
        - Depending on the number of categories and the data, you may have to modify the figsize, font sizes etc.
        - Zero cells are shown as '.'.
    '''
    # Materialise once so that dict views and one-shot iterables both work
    labels = list(labels)

    cm = confusion_matrix(correct_labels, predict_labels, labels=range(len(labels)))
    if normalize:
        # Rows without any sample give 0/0; silence the warning and map the
        # resulting NaN to 0 below.
        with np.errstate(divide='ignore', invalid='ignore'):
            cm = cm.astype('float')*10 / cm.sum(axis=1)[:, np.newaxis]
        cm = np.nan_to_num(cm, copy=True)
        cm = cm.astype('int')

    fig, ax = plt.subplots(1, 1, figsize=(7, 7), dpi=160, facecolor='w', edgecolor='k')
    ax.imshow(cm, cmap='Oranges')

    # Insert a space at camel-case word boundaries in the class names
    classes = [re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', x) for x in labels]

    tick_marks = np.arange(len(classes))

    ax.set_xlabel('Predicted', fontsize=7)
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(classes, fontsize=6, rotation=0,  ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=7)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, fontsize=6, va ='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    # Write the count into every cell
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "black")
    fig.set_tight_layout(True)

    return fig

# Setting up model

In [None]:
class Linearmodel_nnAudio(SpeechCommand):
    """Single linear layer on top of the log Mel spectrogram produced by the nnAudio front-end.

    Uses the notebook-level ``mel_layer``, ``input_dim``, ``output_dim`` and
    ``random_mel`` values.
    """

    def __init__(self): 
        super().__init__()
        self.mel_layer = mel_layer       
        self.criterion = nn.CrossEntropyLoss()
        self.linearlayer = nn.Linear(input_dim, output_dim)
   
        if random_mel:
            # Randomly initialise the Mel bases, then make them non-negative.
            nn.init.kaiming_uniform_(self.mel_layer.mel_basis, mode='fan_in')
            # The in-place relu is done with requires_grad switched off and
            # switched back on afterwards.
            self.mel_layer.mel_basis.requires_grad = False
            torch.relu_(self.mel_layer.mel_basis)
            self.mel_layer.mel_basis.requires_grad = True
    
    def forward(self, x): 
        """Classify a batch of waveforms.

        Parameters:
            x : waveform batch, 2D [B, 16000]

        Returns:
            out  : class scores, 2D [B, number of classes (12)]
            spec : log Mel spectrogram, 3D [B, F40, T101]
        """
        spec = self.mel_layer(x)  
        #spec: 3D [B, F40, T101]
        spec = torch.log(spec+1e-10)
        #small offset avoids log(0)
        flatten_spec = torch.flatten(spec, start_dim=1) 
        #flatten_spec: 2D [B, F*T(40*101)] 
        #start_dim: flattening starts from the 1st dimension
        
        out = self.linearlayer(flatten_spec) 
        #out: 2D [B,number of class(12)] 
        return out, spec 
                               


# Instantiate the model; its __init__ attaches the notebook-level nnAudio `mel_layer` as the front-end.
net = Linearmodel_nnAudio()

In [None]:
# Move the model to the device chosen in the configuration cell.
net = net.to(device)

In [None]:
# Display the classifier layer (bare last expression, shown as the cell output).
net.linearlayer

# Loading pretrained weight

In [None]:
# net.load_from_checkpoint('/workspace/helen/trainable-STFT-Mel/multirun/2022-04-02/23-01-35/1/SGD-n_mels=40-Linearmodel_nnAudio-mel=True-STFT=False-speechcommand/version_1/checkpoints/last.ckpt')

In [None]:
# `load_from_checkpoint` is a classmethod: it builds and RETURNS a new model with
# the checkpoint weights and leaves the object it was called on untouched.
# The result therefore has to be assigned, otherwise `net` keeps its random
# initial weights.
net = Linearmodel_nnAudio.load_from_checkpoint('pretrained_weight/SGD-n_mels=40-Linearmodel_nnAudio-mel=True-STFT=False-speechcommand/version_1/checkpoints/last.ckpt')
# Keep the reloaded model on the configured device, like the instance it replaces.
net = net.to(device)

# Testing the model performance

In [None]:
# Build the Lightning trainer from the values set in the configuration cell.
trainer = Trainer(
    gpus=gpus,
    max_epochs=max_epochs,
    check_val_every_n_epoch=check_val_every_n_epoch,
    num_sanity_val_steps=num_sanity_val_steps,
)

# Run the test loop of `net` over the test split.
trainer.test(net, testloader)