# Part 1: Training a Linear model with Trainable Basis Functions

[Step 1: import related libraries](#Step-1:-import-related-libraries)\
[Step 2: setting up configuration](#Step-2:-setting-up-configuration)\
[Step 3: setting up nnAudio basis functions](#Step-3:-setting-up-nnAudio-basis-functions)\
[Step 4: setting up dataset](#Step-4:-setting-up-dataset)\
[Step 5: data rebalancing](#Step-5:-data-rebalancing)\
[Step 6: data processing and loading](#Step-6:-data-processing-and-loading)\
[Step 7: setting up the lightning module](#Step-7:-setting-up-the-ligthning-module)\
[Step 8: setting up model and applying the lightning module](#Step-8:-setting-up-model-and-applying-the-ligthning-module)\
[Step 9: training model](#Step-9:-training-model)

## Step 1: import related libraries

In [None]:
# Libraries related to PyTorch
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import WeightedRandomSampler,DataLoader
import torch.optim as optim

# Libraries related to PyTorch Lightning
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule

# Libraries used in the Lightning module
from sklearn.metrics import precision_recall_fscore_support

# Libraries related to dataset
from AudioLoader.Speech import SPEECHCOMMANDS_12C #for 12 classes KWS task

# nnAudio Front-end
from nnAudio.features.mel import MelSpectrogram

## Step 2: setting up configuration

In [None]:
# --- Hardware and training-loop settings ---
device = "cuda:0"
gpus = 1
batch_size = 100
max_epochs = 200
check_val_every_n_epoch = 2
num_sanity_val_steps = 5

# --- Dataset location ---
data_root = "./"  # Download the data here
download_option = False

# --- Model dimensions ---
n_mels = 40  # number of Mel bases
input_dim = n_mels * 101  # size of the flattened spectrogram fed to the linear layer
output_dim = 12  # number of output classes

# Set to True to start from randomly initialised Mel bases.
random_mel = False

# nnAudio guideline for trainable basis functions 
```
nnAudio.features.mel.MelSpectrogram(trainable_mel= ,trainable_STFT=) 
```
The two arguments above control whether the Mel bases and the STFT kernels are trainable:

* A. Both Mel and STFT are non-trainable: 
`trainable_mel=False, trainable_STFT=False`
* B. Mel is trainable while STFT is fixed: 
`trainable_mel=True, trainable_STFT=False`
* C. Mel is fixed while STFT is trainable: 
`trainable_mel=False, trainable_STFT=True`
* D. Both Mel and STFT are trainable:
`trainable_mel=True, trainable_STFT=True`

## Step 3: setting up nnAudio basis functions

In [None]:
# Mel-spectrogram front-end; the model class defined later in this file
# picks up this module-level object as its `mel_layer`.
mel_layer = MelSpectrogram(
    # Input signal and STFT framing
    sr=16000,
    n_fft=480,
    win_length=None,
    hop_length=160,
    window='hann',
    center=True,
    pad_mode='reflect',
    power=2.0,
    # Mel filter bank
    n_mels=n_mels,
    htk=False,
    fmin=0.0,
    fmax=None,
    norm=1,
    # Trainability (option B): Mel bases are trainable, STFT is fixed
    trainable_mel=True,
    trainable_STFT=False,
    verbose=True,
)

## Step 4: setting up dataset

In [None]:
# The three splits differ only in `subset`; build them in the same order as
# before (training, validation, testing) from one shared argument list.
trainset, validset, testset = (
    SPEECHCOMMANDS_12C(
        root=data_root,
        url='speech_commands_v0.02',
        folder_in_archive='SpeechCommands',
        download=download_option,
        subset=split_name,
    )
    for split_name in ('training', 'validation', 'testing')
)

## Step 5: data rebalancing

Class weighting: rebalance the silence class (label index 10) and the unknown class (label index 11) in the training set by giving them sampling weights different from the other classes.

In [None]:
# Sampling weight per class index: every class is 1 except
# index 10 (weight 4.6) and index 11 (weight 1/17).
class_weights = [1] * 10 + [4.6, 1 / 17]

# One weight per training example, looked up from that example's label.
# Each dataset item unpacks into five fields; only the label is used here.
sample_weights = [
    class_weights[label]
    for (data, rate, label, speaker_id, _) in trainset
]

# Draw as many samples as there are weights, with replacement,
# using the per-example weights computed above.
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

## Step 6: data processing and loading

In [None]:
# Data processing
def data_processing(data):
    """Collate a list of dataset samples into a single padded batch.

    Each sample is read positionally: element 0 is the waveform tensor and
    element 2 is the label. Other elements are ignored.

    Args:
        data: iterable of samples as produced by the dataset.

    Returns:
        dict with
            'waveforms': waveforms padded to a common length, batch first;
            'labels': tensor built from the collected labels.
    """
    # squeeze(0) drops the leading dimension (when it has size 1) so that
    # each waveform is padded along its sample axis.
    waveforms = [sample[0].squeeze(0) for sample in data]
    labels = [sample[2] for sample in data]

    return {
        'waveforms': pad_sequence(waveforms, batch_first=True),
        'labels': torch.tensor(labels),
    }

# load data
# `data_processing` is passed directly as the collate function. Wrapping it in
# a lambda added nothing, and a lambda cannot be pickled, which breaks
# worker processes on platforms that start them with "spawn".
trainloader = DataLoader(
    trainset,
    collate_fn=data_processing,
    batch_size=batch_size,
    sampler=sampler,  # weighted sampling; see the data rebalancing step
    num_workers=1,
)

validloader = DataLoader(
    validset,
    collate_fn=data_processing,
    batch_size=batch_size,
    num_workers=1,
)

testloader = DataLoader(
    testset,
    collate_fn=data_processing,
    batch_size=batch_size,
    num_workers=1,
)

## Step 7: setting up the ligthning module

In [None]:
class SpeechCommand(LightningModule):
    """Shared training, validation and test logic for the speech-command models.

    This class defines no layers itself. Subclasses are expected to provide:
      * ``forward(x)`` returning ``(outputs, spec)``,
      * ``self.criterion`` taking ``(outputs, labels)``,
      * ``self.mel_layer`` exposing a ``mel_basis`` tensor.
    """

    @staticmethod
    def _batch_accuracy(outputs, labels):
        """Return the fraction of rows of ``outputs`` whose argmax equals ``labels``.

        Uses a vectorised tensor ``sum`` rather than Python's builtin ``sum``,
        which would iterate over the tensor one element at a time.
        """
        return (outputs.argmax(-1) == labels).sum() / labels.shape[0]

    @staticmethod
    def _gather_epoch_outputs(outputs):
        """Concatenate the per-batch dicts returned by the eval steps.

        Returns:
            ``(pred, label)`` tensors concatenated along dimension 0.
        """
        pred = torch.cat([step_out['outputs'] for step_out in outputs], 0)
        label = torch.cat([step_out['labels'] for step_out in outputs], 0)
        return pred, label

    def _evaluate_batch(self, batch, loss_name):
        """Run one validation/test batch and log its loss under ``loss_name``.

        Returns:
            dict with the raw model ``'outputs'`` and the batch ``'labels'``,
            consumed by the matching ``*_epoch_end`` hook.
        """
        outputs, _ = self(batch['waveforms'])
        loss = self.criterion(outputs, batch['labels'].long())
        self.log(loss_name, loss, on_step=False, on_epoch=True)
        return {'outputs': outputs, 'labels': batch['labels']}

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log loss and accuracy.

        Returns:
            The loss tensor to back-propagate.
        """
        # forward returns outputs [2D] for the loss and spec [3D] for
        # visualisation; the spectrogram is not needed here.
        outputs, _ = self(batch['waveforms'])
        loss = self.criterion(outputs, batch['labels'].long())

        acc = self._batch_accuracy(outputs, batch['labels'])  # batch wise

        # on_step=False / on_epoch=True: log one aggregated value per epoch.
        self.log('Train/acc', acc, on_step=False, on_epoch=True)
        self.log('Train/Loss', loss, on_step=False, on_epoch=True)
        return loss

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
        """Take the optimizer step, then clamp the Mel basis to [0, 1]."""
        optimizer.step(closure=optimizer_closure)
        # Clamp in place without recording the operation in autograd.
        with torch.no_grad():
            torch.clamp_(self.mel_layer.mel_basis, 0, 1)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch; the loss is logged as 'Validation/Loss'."""
        return self._evaluate_batch(batch, 'Validation/Loss')

    def validation_epoch_end(self, outputs):
        """Log the accuracy over all validation batches of the epoch."""
        pred, label = self._gather_epoch_outputs(outputs)
        acc = self._batch_accuracy(pred, label)  # epoch wise
        self.log('Validation/acc', acc, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; the loss is logged as 'Test/Loss'."""
        return self._evaluate_batch(batch, 'Test/Loss')

    def test_epoch_end(self, outputs):
        """Compute test metrics, log them and save them to ``result_dict.pt``.

        Returns:
            dict keyed by averaging mode (``None``, 'micro', 'macro',
            'weighted'); each value holds 'precision', 'recall' and 'f1' as
            returned by ``precision_recall_fscore_support``.
        """
        pred, label = self._gather_epoch_outputs(outputs)

        # Move to CPU once instead of once per averaging mode.
        y_true = label.cpu()
        y_pred = pred.argmax(-1).cpu()

        result_dict = {}
        for key in [None, 'micro', 'macro', 'weighted']:
            p, r, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average=key, zero_division=0)
            result_dict[key] = {'precision': p, 'recall': r, 'f1': f1}

        acc = self._batch_accuracy(pred, label)
        self.log('Test/acc', acc, on_step=False, on_epoch=True)

        self.log('Test/micro_f1', result_dict['micro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/macro_f1', result_dict['macro']['f1'], on_step=False, on_epoch=True)
        self.log('Test/weighted_f1', result_dict['weighted']['f1'], on_step=False, on_epoch=True)

        torch.save(result_dict, "result_dict.pt")

        return result_dict

    def configure_optimizers(self):
        """Build one SGD optimizer with separate parameter groups.

        The Mel layer and the rest of the model are kept in separate groups
        so that they can be given different learning rates; both currently
        use the same settings.
        """
        # Every parameter that does not belong to the Mel layer.
        model_param = [
            params
            for name, params in self.named_parameters()
            if 'mel_layer.' not in name
        ]
        optimizer = optim.SGD([
            {"params": self.mel_layer.parameters(),
             "lr": 1e-3,
             "momentum": 0.9,
             "weight_decay": 0.001},
            {"params": model_param,
             "lr": 1e-3,
             "momentum": 0.9,
             "weight_decay": 0.001},
        ])
        return [optimizer]


## Step 8: setting up model and applying the ligthning module

In [None]:
class Linearmodel_nnAudio(SpeechCommand):
    """Single linear layer on top of a log-Mel spectrogram front-end.

    Reads the module-level ``mel_layer``, ``input_dim``, ``output_dim`` and
    ``random_mel`` settings at construction time.
    """

    def __init__(self):
        super().__init__()
        self.mel_layer = mel_layer
        self.criterion = nn.CrossEntropyLoss()
        self.linearlayer = nn.Linear(input_dim, output_dim)

        if random_mel:
            # Randomly initialise the Mel bases, then zero out negative values.
            nn.init.kaiming_uniform_(self.mel_layer.mel_basis, mode='fan_in')
            # no_grad allows the in-place ReLU on the tensor and leaves its
            # requires_grad flag exactly as the Mel layer configured it
            # (the previous off/on toggle always forced it to True).
            with torch.no_grad():
                self.mel_layer.mel_basis.relu_()

    def forward(self, x):
        """Classify a batch of waveforms.

        Args:
            x: 2D waveform batch [B, 16000].

        Returns:
            ``(out, spec)``: class scores [B, number of classes (12)] and the
            log-Mel spectrogram [B, F40, T101].
        """
        spec = self.mel_layer(x)
        # spec: 3D [B, F40, T101]

        # Small offset keeps log() finite where the spectrogram is zero.
        spec = torch.log(spec + 1e-10)

        # Flatten from dimension 1 onwards: 2D [B, F*T (40*101)]
        flatten_spec = torch.flatten(spec, start_dim=1)

        out = self.linearlayer(flatten_spec)
        # out: 2D [B, number of classes (12)]

        return out, spec

# Build the model and move it to the configured device in one expression.
net = Linearmodel_nnAudio().to(device)

## Step 9: training model

Every time you train a model, the trained weights are saved in the `./lightning_logs/version_<XX>` folder. You can use `Part2_tutorial` to visualize the results.

In [None]:
# Trainer configured from the settings defined in Step 2.
trainer = Trainer(
    gpus=gpus,
    max_epochs=max_epochs,
    num_sanity_val_steps=num_sanity_val_steps,
    check_val_every_n_epoch=check_val_every_n_epoch,
)

# Train with the training and validation loaders.
trainer.fit(net, trainloader, validloader)

# Evaluate the trained model on the test loader.
trainer.test(net, testloader)

# This is the end of Part1_tutorial.
Feel free to move on to `Part 2: Evaluation with Trainable Basis Functions`.