# Import Packages

In [None]:
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
import seaborn as sn
import torch.nn as nn
from knowledge import KD
from torch.optim import Adam
import torch.nn.functional as F
import matplotlib.pyplot as plt
from dataset import SoundDataset
from models.MobileViT import MBViT
from torch.utils.data import DataLoader
from models.PANNs import Wavegram_Logmel_Cnn14

In [None]:
import random

# Seed every RNG the pipeline may draw from so that runs are reproducible.
# torch.manual_seed seeds the generators on all devices. Python's `random` is
# seeded too, in case the dataset / augmentation code uses it.
seed = 1
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

# Loading Models

## Teacher Model

In [None]:
# NOTE(review): no checkpoint is loaded here. Unless Wavegram_Logmel_Cnn14()
# loads pretrained weights itself, the teacher is randomly initialised — confirm.
teacher = Wavegram_Logmel_Cnn14()
# Freeze all teacher weights: no gradients are computed for them.
teacher.requires_grad_(False)
# requires_grad=False does not disable dropout or BatchNorm running-stat
# updates (if the model has such layers); eval() does.
# NOTE(review): confirm KD.fit_student does not switch the teacher back to train().
teacher.eval()

## Student Model

In [None]:
# Student network that will be optimised through distillation; built with
# MBViT's default constructor arguments.
student = MBViT()

## Model Hyperparameters

In [None]:
Lr = 1e-4  # learning rate for the student optimizer

# Adam over the student's parameters only; weight_decay=0 means no L2 penalty.
student_optim = Adam(
    student.parameters(),
    lr=Lr,
    weight_decay=0,
)

# Dataset

In [None]:
# Metadata tables for the training and validation splits.
train_df, val_df = map(pd.read_csv, ['ICBHI/Train.csv', 'ICBHI/Val.csv'])


### Creating DataLoaders

In [None]:
# Audio roots are relative to the working directory. This matches the
# 'ICBHI/*.csv' reads and the 'ICBHI/Test' root used for the test split.
# A leading '/' would resolve them from the filesystem root instead.
train_ds = SoundDataset(train_df, 'ICBHI/Train')
val_ds = SoundDataset(val_df, 'ICBHI/Val')

In [None]:
# Only the training data is shuffled; validation batches keep a fixed order.
train_loader = DataLoader(dataset=train_ds, shuffle=True, batch_size=32)
val_loader = DataLoader(dataset=val_ds, shuffle=False, batch_size=2)

# Knowledge Distillation

In [None]:
# Distillation hyperparameters, forwarded positionally to KD.
# Presumably the softmax temperature and the loss-mixing weight — see KD.
temperature = 3.0
alpha = 0.5

distiller = KD(
    teacher, student,
    train_loader, val_loader,
    student_optim,
    temperature, alpha,
)

## Training Student

In [None]:
# Learning-rate scheduler settings.
# NOTE(review): step_size and gamma are never passed to KD or to any scheduler,
# so they currently have no effect on training. Wire them up or remove them.
step_size = 100
gamma = 0.5

epochs = 100
# The plotting cell reads hist[0]..hist[3] as train loss, validation loss,
# train accuracy and validation accuracy.
hist = distiller.fit_student(epochs=epochs)

## Learning Curves

In [None]:
# Epoch numbers on the x-axis start at 1.
x_arr = np.arange(len(hist[0])) + 1

fig = plt.figure(figsize=(12, 4))
# One panel per metric: (y-axis label, train label, validation label, hist index of train curve).
# The validation curve sits at the next hist index.
panels = (
    ('Loss', 'Train loss', 'Validation loss', 0),
    ('Accuracy', 'Train acc.', 'Validation acc.', 2),
)
for position, (y_label, train_label, val_label, idx) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 2, position)
    ax.plot(x_arr, hist[idx], '-o', label=train_label)
    ax.plot(x_arr, hist[idx + 1], '--<', label=val_label)
    ax.set_xlabel('Epoch', size=15)
    ax.set_ylabel(y_label, size=15)
    ax.legend(fontsize=15)

#plt.savefig('figures/14_13.png')
plt.show()

## Evaluating on the Test Set

In [None]:
# Held-out test split, iterated one sample at a time in a fixed order.
test_df = pd.read_csv('ICBHI/Test.csv')
test_ds = SoundDataset(test_df, 'ICBHI/Test')
test_dataloader = DataLoader(dataset=test_ds, shuffle=False, batch_size=1)

In [None]:
# Run KD.evaluate on the test loader. The returned mapping is read later via the
# keys 'CM', 'Accuracy', 'Roc_AUC', 'Sensitivity', 'Specificity' and 'Score'.
test_metrics = distiller.evaluate(test_loader=test_dataloader)

## Confusion Matrix

In [None]:
# Axis labels for the confusion matrix.
# NOTE(review): assumes CM rows/columns follow this class-index order —
# confirm against SoundDataset's label encoding.
classes_names = ['Normal','Crackles','Wheezes','Both']
plt.figure(figsize=(8, 5))
ax = sn.heatmap(
    test_metrics['CM'],
    annot=True,
    fmt=".0f",
    cmap='Blues',
    xticklabels=classes_names,
    yticklabels=classes_names,
    # linecolor is only drawn when linewidths > 0 (seaborn's default is 0).
    linewidths=0.5,
    linecolor='k',
    cbar=False,
)
ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)
# heatmap() already set the tick labels; only rotate the x labels here.
ax.tick_params(axis='x', labelrotation=15)
#plt.savefig("cm.pdf", dpi=200), plt.show();

## Model Performance

In [None]:
print('Results on Test Set \n')

# (printed prefix, key in test_metrics), shown in this order with 4 decimals.
report_rows = [
    ("Accuracy: ", 'Accuracy'),
    ("Roc AUC: ", 'Roc_AUC'),
    ("Sensitivity : ", 'Sensitivity'),
    ("Specificity : ", 'Specificity'),
    ("Score : ", 'Score'),
]
for prefix, metric_key in report_rows:
    print(f"{prefix}{test_metrics[metric_key]:.4f}")