In [None]:
### https://pytorch.org/audio/stable/tutorials/speech_recognition_pipeline_tutorial.html
### https://colab.research.google.com/github/m3hrdadfi/soxan/blob/main/notebooks/Eating_Sound_Collection_using_Wav2Vec2.ipynb#scrollTo=Fv62ShDsH5DZ - classification with a pretrained transformer as the base model
### https://arxiv.org/abs/2006.11477 - wav2vec 2.0 paper

In [None]:
from torch import nn
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
import os
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Seed torch's global random number generator so repeated runs draw the same numbers.
torch.random.manual_seed(0)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device, torchaudio.list_audio_backends())
# Root folder handed to SoundDataset for the binary task; each sub-folder is one class.
TRAIN_DIR = './data/train/binary_classification/yes_no/'

In [None]:
# loading data to dataset:
class SoundDataset(Dataset):
    """Audio classification dataset backed by a ``<directory>/<class>/<file>`` tree.

    Every sub-directory of ``directory`` is one class and every entry inside it
    is one audio sample of that class.

    Args:
        directory: root folder that contains one sub-folder per class.
        gpu: when true, ``__getitem__`` moves the waveform to the module-level
            ``device`` before returning it.

    Attributes:
        classes: sorted class (sub-folder) names.
        class_to_num / num_to_class: mappings between class names and indices.
        paths: sample file paths, grouped class by class.
        labels: integer class index of every entry of ``paths``.
    """

    def __init__(self, directory, gpu):
        self.directory = directory
        self.gpu = gpu
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # class indices stable between runs and machines. Plain files lying
        # next to the class folders are ignored instead of crashing the scan.
        self.classes = sorted(
            entry for entry in os.listdir(directory)
            if os.path.isdir(os.path.join(directory, entry))
        )
        self.class_to_num = {cl : i for i, cl in enumerate(self.classes)}
        self.num_to_class = {i : cl for i, cl in enumerate(self.classes)}

        paths = []
        labels = []
        for cl in self.classes:
            # os.path.join works whether or not `directory` ends with a separator.
            class_dir = os.path.join(directory, cl)
            for file_name in sorted(os.listdir(class_dir)):
                paths.append(os.path.join(class_dir, file_name))
                # Record the label while the class is known, so it never has to be
                # parsed back out of the path (which depends on the OS separator).
                labels.append(self.class_to_num[cl])
        self.paths = paths
        self.labels = labels

    def __len__(self):
        """Return the number of audio samples."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(waveform, label_tensor)``.

        Only the first channel of the file is kept. The label is returned as a
        0-dim integer tensor and is never moved to the GPU.
        """
        audio_sample_path = self.paths[index]

        signal, _ = torchaudio.load(audio_sample_path, format = 'wav')
        signal = signal[0]
        if self.gpu:
            # Tensor.to() is not in-place: the result has to be assigned back.
            signal = signal.to(device)

        label_tensor = torch.tensor(self.labels[index])
        return signal, label_tensor
        
        

In [None]:
# Build the dataset and split it at random into 80% training / 20% validation.
dataset = SoundDataset(TRAIN_DIR, True)
train_dataset, validation_dataset = random_split(dataset, [0.8, 0.2])
train_dataloader = DataLoader(train_dataset, batch_size = 64, shuffle=True)
# NOTE(review): a later cell draws a second, independent random split and builds
# train_loader / val_loader from it. Samples in this validation loader may therefore
# belong to that later training subset - do not mix the two sets of loaders.
validation_dataloader = DataLoader(validation_dataset, batch_size = 64, shuffle=True)

In [None]:
# Pretrained wav2vec 2.0 base bundle (alternative tried: WAV2VEC2_ASR_BASE_960H).
bundle = torchaudio.pipelines.WAV2VEC2_BASE #WAV2VEC2_ASR_BASE_960H 
# Instantiate the model from the bundle and move it to the selected device.
model = bundle.get_model().to(device)
# Bare expression: displays the module structure as the cell output.
model

In [None]:
# Take the waveform of the first sample (element 0 of the (signal, label) pair).
waveform = dataset[0][0]
# Add a leading batch dimension of size 1 and move the tensor to the model's device.
waveform = waveform.reshape(1, -1).to(device)
# Shape of the first element of the model's output tuple for this single sample.
model(waveform)[0].cpu().detach().numpy().shape

In [None]:
# Keep the first element of what extract_features returns; later cells index it
# per layer (features[0], features[5], features[11], features[-1]).
features = model.extract_features(waveform)[0]

In [None]:
# Shape of the last entry of `features`.
features[-1].cpu().shape

In [None]:
# Three stacked panels, one per inspected entry of `features` (indices 0, 5 and 11).
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 4))
for ax, i in zip(axes, (0, 5, 11)):
    feats = features[i]
    # Drop the batch dimension and bring the data to the CPU for matplotlib.
    ax.imshow(feats[0].detach().cpu(), interpolation="nearest")
# A single title above the top panel covers the whole figure.
axes[0].set_title(label=f"features from transformer layers 1, 6, 12", fontsize=15)
fig.tight_layout()
plt.savefig('./media/features_encoder.png')

In [None]:
class Wav2Vec2ClassificationModel(nn.Module):
    """Pretrained wav2vec 2.0 encoder followed by a per-frame classification head.

    ``forward`` returns the head's output for every frame; the frame axis (dim 1)
    is then collapsed with the configured pooling to get one logit vector per
    sample.

    Args:
        num_labels: number of target classes.
        hidden_size: feature size produced by the base model (input of the head).
        final_dropout: dropout probability used inside the head.
        pooling: how frame logits are combined - ``'mean'``, ``'max'`` or ``'sum'``.

    Attributes:
        loss: mean training loss of every finished epoch.
        val_accuracy: validation accuracy of every finished epoch.

    Raises:
        ValueError: if ``pooling`` is not one of the supported modes.
    """

    POOLING_MODES = ('mean', 'max', 'sum')

    def __init__(self, num_labels, hidden_size, final_dropout = 0.1, pooling = 'mean'):
        super().__init__()
        # Fail early: a typo in `pooling` previously surfaced only during training.
        if pooling not in self.POOLING_MODES:
            raise ValueError(f"pooling must be one of {self.POOLING_MODES}, got {pooling!r}")
        self.base_model = torchaudio.pipelines.WAV2VEC2_BASE.get_model().to(device)
        self.classification_head = Wav2Vec2ClassificationHead(num_labels, hidden_size, final_dropout)
        self.pooling = pooling
        self.loss = []
        self.val_accuracy = []
        # The convolutional feature extractor stays frozen; the rest is fine-tuned.
        for param in self.base_model.feature_extractor.parameters():
            param.requires_grad = False

    def _module_device(self):
        """Return the device holding this module's parameters (CPU if it has none)."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def _pool(self, frame_logits):
        """Collapse dim 1 (frames) of ``frame_logits`` according to ``self.pooling``."""
        if self.pooling == 'mean':
            return torch.mean(frame_logits, 1)
        if self.pooling == 'max':
            return torch.max(frame_logits, 1)[0]
        if self.pooling == 'sum':
            return torch.sum(frame_logits, 1)
        raise ValueError(f"pooling must be one of {self.POOLING_MODES}, got {self.pooling!r}")

    def forward(self, inputs):
        """Return the un-pooled, per-frame logits for a batch of waveforms."""
        # The base model returns a tuple; only its first element (the features) is used.
        features = self.base_model(inputs)[0]

        logits = self.classification_head(features)

        return logits

    def predict(self, inputs):
        """Return the predicted class index for every sample of ``inputs``.

        Runs in eval mode without gradient tracking and restores the previous
        train/eval mode afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                pooled = self._pool(self(inputs.to(self._module_device())))
        finally:
            self.train(was_training)
        return torch.max(pooled, 1)[1]

    def train_model(self, train_loader, val_loader, criterion, optimizer, num_epochs=10):
        """Train for ``num_epochs`` epochs, validating after each one.

        After every epoch the mean training loss is appended to ``self.loss`` and
        the validation accuracy to ``self.val_accuracy``. An empty loader yields
        ``nan`` for the corresponding value instead of raising.
        """
        target_device = self._module_device()
        for epoch in range(num_epochs):
            self.train()
            batch_losses = []
            for i, (batch_inputs, batch_labels) in enumerate(train_loader):
                optimizer.zero_grad()
                outputs = self._pool(self(batch_inputs.to(target_device)))
                loss = criterion(outputs, batch_labels.to(target_device))
                loss.backward()
                optimizer.step()
                batch_losses.append(loss.item())
                print(f'batch number {i+1}/{len(train_loader)}, loss = { np.round(loss.item(), 4)}', end = '\r')

            epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
            self.loss.append(epoch_loss)
            # The epoch summary keeps reporting the loss of the last batch.
            last_loss = batch_losses[-1] if batch_losses else float('nan')

            # Validation
            self.eval()
            total_correct = 0
            total_samples = 0
            with torch.no_grad():
                for val_batch_inputs, val_batch_labels in val_loader:
                    val_outputs = self._pool(self(val_batch_inputs.to(target_device)))
                    predicted = torch.max(val_outputs, 1)[1]
                    total_correct += (predicted == val_batch_labels.to(target_device)).sum().item()
                    total_samples += val_batch_labels.size(0)
            accuracy = total_correct / total_samples if total_samples else float('nan')
            self.val_accuracy.append(accuracy)

            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {last_loss}, Validation Accuracy: {accuracy}')


class Wav2Vec2ClassificationHead(nn.Module):
    """Two-layer classification head: dropout, linear, tanh, dropout, linear.

    Maps features whose last dimension is ``hidden_size`` to ``num_labels``
    logits; all leading dimensions are preserved.
    """

    def __init__(self, num_labels, hidden_size, final_dropout):
        super().__init__()
        self.dense = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.dropout = nn.Dropout(p=final_dropout)
        self.out_proj = nn.Linear(in_features=hidden_size, out_features=num_labels)

    def forward(self, features, **kwargs):
        """Return the logits for ``features``; extra keyword arguments are ignored."""
        hidden = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(hidden))
        return self.out_proj(activated)

# Binary classification - yes/no

In [None]:
# Rebuild the binary dataset and draw a fresh random 80/20 split.
# This split is independent of the earlier train_dataloader / validation_dataloader one.
dataset = SoundDataset(TRAIN_DIR, True)
train_dataset, validation_dataset = random_split(dataset, [0.8, 0.2])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# No shuffling for validation: the order does not affect the accuracy.
val_loader = DataLoader(validation_dataset, batch_size=64)

In [None]:
# One classifier per pooling strategy; everything else (2 labels, hidden size 768,
# dropout 0.1) is identical, so the runs differ only in how frame logits are pooled.
model_mean = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'mean').to(device)
model_max = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'max').to(device)
model_sum = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'sum').to(device)

### learning rate = 1e-4

In [None]:
# Train the three pooling variants for 10 epochs each with Adam at lr = 1e-4.
# NOTE(review): the loop variable rebinds the global `model`, which earlier cells
# use for the bare pretrained bundle model; those cells must be re-run before reuse.
for model in [model_mean, model_max, model_sum]:
    criterion = nn.CrossEntropyLoss()
    # A fresh optimizer per model, covering all of that model's parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
    print(model.pooling)
    model.train_model(train_loader, val_loader, criterion, optimizer, num_epochs=10)    

In [None]:
def plot_confusion_matrix(model, dataloader, classes, ax = None):
    """Predict every batch of ``dataloader`` and draw the confusion matrix.

    Slow on large loaders: the model is run over every sample once.

    Args:
        model: object exposing ``predict(batch)`` that returns a tensor of class indices.
        dataloader: iterable of ``(inputs, labels)`` batches.
        classes: class names used as tick labels; class index ``i`` is assumed to
            correspond to ``classes[i]``.
        ax: matplotlib axes to draw on; a new figure is created when ``None``.

    Returns:
        The confusion matrix, always of shape ``(len(classes), len(classes))``.
    """
    true_labels = []
    predicted_labels = []
    # Inference only: no need to build autograd graphs.
    with torch.no_grad():
        for batch, labels in dataloader:
            predicted_labels.extend(model.predict(batch).cpu().numpy().tolist())
            true_labels.extend(labels.cpu().numpy().tolist())
    # Passing `labels` pins the matrix to one row/column per class even when a
    # class never occurs in this loader; otherwise its shape would not match
    # `display_labels` below.
    cm = confusion_matrix(true_labels, predicted_labels, labels=list(range(len(classes))))
    ConfusionMatrixDisplay(confusion_matrix=cm,
                            display_labels=classes).plot(ax = ax, colorbar = False)
    return cm

In [None]:
# Confusion matrices of the three pooling variants trained with lr = 1e-4.
fig, axes = plt.subplots(1,3,figsize = (12, 8))
for ax, model in zip(axes, [model_max, model_mean, model_sum]):
    # Evaluate on val_loader, the held-out part of the split these models were
    # trained on. validation_dataloader comes from an earlier, independent random
    # split and may contain samples the models saw during training.
    plot_confusion_matrix(model, val_loader, dataset.classes, ax = ax)
    ax.set_title(model.pooling)
    ax.set_ylabel('')
# Label the y axis once, on the left-most panel.
axes[0].set_ylabel('True label')
plt.savefig('./media/wav2vec2lr1e4.png', dpi = 200)

### learning rate = 1e-3

In [None]:
# Same experiment with a larger learning rate: start from freshly constructed models
# so that nothing carries over from the previous run.
model_mean = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'mean').to(device)
model_max = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'max').to(device)
model_sum = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'sum').to(device)
# Train each pooling variant for 10 epochs with Adam at lr = 1e-3.
for model in [model_mean, model_max, model_sum]:
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
    print(model.pooling)
    model.train_model(train_loader, val_loader, criterion, optimizer, num_epochs=10)

In [None]:
# Confusion matrices of the three pooling variants trained with lr = 1e-3.
fig, axes = plt.subplots(1,3,figsize = (12, 8))
for ax, model in zip(axes, [model_max, model_mean, model_sum]):
    # Evaluate on val_loader, the held-out part of the split these models were
    # trained on. validation_dataloader comes from an earlier, independent random
    # split and may contain samples the models saw during training.
    plot_confusion_matrix(model, val_loader, dataset.classes, ax = ax)
    ax.set_title(model.pooling)
    ax.set_ylabel('')
# Label the y axis once, on the left-most panel.
axes[0].set_ylabel('True label')
plt.savefig('./media/wav2vec2lr1e3.png', dpi = 200)

### learning rate = 1e-5

In [None]:
# Same experiment with a smaller learning rate: start from freshly constructed models
# so that nothing carries over from the previous run.
model_mean = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'mean').to(device)
model_max = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'max').to(device)
model_sum = Wav2Vec2ClassificationModel(num_labels=2, hidden_size=768, final_dropout=0.1, pooling = 'sum').to(device)
# Train each pooling variant for 10 epochs with Adam at lr = 1e-5.
for model in [model_mean, model_max, model_sum]:
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
    print(model.pooling)
    model.train_model(train_loader, val_loader, criterion, optimizer, num_epochs=10)

In [None]:
# Confusion matrices of the three pooling variants trained with lr = 1e-5,
# computed on val_loader (the held-out part of the split used for training).
fig, axes = plt.subplots(1,3,figsize = (12, 8))
for ax, model in zip(axes, [model_max, model_mean, model_sum]):
    plot_confusion_matrix(model, val_loader, dataset.classes, ax = ax)
    ax.set_title(model.pooling)
    # Clear the per-panel y label; it is set once on the first panel below.
    ax.set_ylabel('')
axes[0].set_ylabel('True label')
plt.savefig('./media/wav2vec2lr1e5.png', dpi = 200)

# whole dataset (30 classes)

In [None]:
# Rebind `dataset` to the multi-class data under ./data/train/padded/
# (one sub-folder per class); this replaces the binary yes/no dataset.
dataset = SoundDataset('./data/train/padded/', True)

In [None]:
# Display the class names and the total number of samples.
dataset.classes, len(dataset)

In [None]:
from torch.utils.data import Subset

# Class index of every sample, read from the name of the folder that contains the
# file. This avoids decoding every audio file (twice) just to obtain its label.
sample_labels = [
    dataset.class_to_num[os.path.basename(os.path.dirname(path))]
    for path in dataset.paths
]

# Number of samples per class.
class_counts = {}
for label in sample_labels:
    class_counts[label] = class_counts.get(label, 0) + 1

# Calculate the desired number of samples for each class in training and validation sets
total_samples = len(dataset)
train_ratio = 0.8  # Adjust as needed
train_class_counts = {label: int(train_ratio * count) for label, count in class_counts.items()}
val_class_counts = {label: count - train_class_counts[label] for label, count in class_counts.items()}

# Split per class: walking the samples in dataset order, the first
# train_class_counts[c] samples of class c go to training, the rest to validation.
# NOTE: the data is NOT shuffled. `DataLoader(dataset, shuffle=True).dataset` only
# returned the very same dataset object, so the name is kept purely because later
# cells refer to `shuffled_dataset`.
shuffled_dataset = dataset
train_indices = []
val_indices = []
remaining_train_slots = dict(train_class_counts)
for idx, label in enumerate(sample_labels):
    if remaining_train_slots[label] > 0:
        train_indices.append(idx)
        remaining_train_slots[label] -= 1
    else:
        val_indices.append(idx)

train_dataset = Subset(shuffled_dataset, train_indices)
validation_dataset = Subset(shuffled_dataset, val_indices)

In [None]:
# Sanity check: sizes of the two subsets and the number of classes.
len(train_dataset), len(validation_dataset), len(dataset.classes)

In [None]:
# Loaders for the multi-class run. The validation loader is deliberately not shuffled:
# the error-analysis cells below rely on predictions being in validation_dataset order.
validation_dataloader = DataLoader(validation_dataset, batch_size=64)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle = True)

In [None]:
# 30-class model with mean pooling, resumed from a saved checkpoint and trained for
# one more epoch at lr = 1e-5.
full_model = Wav2Vec2ClassificationModel(num_labels=30, hidden_size=768, final_dropout=0.1, pooling = 'mean').to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine too.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
full_model.load_state_dict(torch.load('./media/full_W2V_epoch3.pth', map_location=device))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(full_model.parameters(), lr = 1e-5)
full_model.train_model(train_dataloader, validation_dataloader, criterion, optimizer, num_epochs=1)

In [None]:
# Persist the current weights.
# NOTE(review): the previous cell loads the "epoch3" checkpoint and trains one epoch,
# yet this file is named "epoch5" - confirm the epoch numbering is intended.
torch.save(full_model.state_dict(), './media/full_W2V_epoch5.pth')

In [None]:
# Collect true and predicted class indices over the whole validation set, in the
# order the (unshuffled) validation loader yields them.
true_labels = []
predicted_labels = []
# Inference only: no need to build autograd graphs.
with torch.no_grad():
    for batch, labels in validation_dataloader:
        predicted_labels+=list(full_model.predict(batch).cpu().numpy())
        true_labels+=list(labels.cpu().numpy())
# Passing `labels` pins the matrix to one row/column per class even if a class is
# missing from the validation set; the plots below use dataset.classes as tick labels.
cm = confusion_matrix(true_labels, predicted_labels, labels=list(range(len(dataset.classes))))

In [None]:
# Full confusion matrix; the large figure keeps all class tick labels readable.
fig, ax = plt.subplots(1,1,figsize = (30,30))
ConfusionMatrixDisplay(cm, display_labels = dataset.classes).plot(ax = ax, colorbar = False)
plt.savefig('./media/wav2vec2full.png', dpi = 200)

In [None]:
# Keep the per-class counts of correct predictions. ndarray.diagonal() returns a
# view into `cm`, so it has to be copied before the diagonal is overwritten below;
# otherwise `diag` would end up all zeros as well.
diag = cm.diagonal().copy()
# Zero the diagonal in place so the next plot shows only the misclassifications.
# Note: this mutates `cm`; re-run the cell that computes it to get the full matrix back.
np.fill_diagonal(cm, 0)

In [None]:
# Confusion matrix with the diagonal zeroed by the previous cell, i.e. errors only.
fig, ax = plt.subplots(1,1, figsize = (30,30))
ConfusionMatrixDisplay(cm, display_labels = dataset.classes).plot(ax = ax, colorbar=False)
plt.savefig('./media/wav2vec2full_errors.png', dpi = 200)

# error analysis

In [None]:
# Display the class-name -> index mapping, to look up the indices used below.
shuffled_dataset.class_to_num

In [None]:
# Convert the collected label lists to arrays so they can be compared element-wise.
predicted_labels = np.array(predicted_labels)
true_labels = np.array(true_labels)
# np.where(condition) returns a 1-tuple holding an index array. The indices are
# positions in the order the validation loader yielded the samples - they are NOT
# indices into the underlying dataset.
# Classes 23 and 24 are presumably 'three' and 'tree' (see the heading below) -
# verify against the class_to_num mapping displayed above.
pred_24 = np.where(predicted_labels == 24)
pred_23 = np.where(predicted_labels == 23)
true_24 = np.where(true_labels == 24)
true_23 = np.where(true_labels == 23)

In [None]:
import librosa
from IPython.display import Audio

#### true three predicted as tree

In [None]:
# Validation positions whose true class is 23 AND whose predicted class is 24.
# (setdiff1d selected the opposite: true 23 samples NOT predicted as 24.)
confused_positions = np.intersect1d(true_23, pred_24)
# These are positions inside the validation subset; translate to an index of the
# underlying dataset before looking up the file path.
sample_index = validation_dataset.indices[confused_positions[0]]
# sr=None keeps the file's own sampling rate (librosa resamples to 22050 Hz by
# default), and playback uses that same rate so the audio is not slowed down.
y, sr = librosa.load(validation_dataset.dataset.paths[sample_index], sr=None)
Audio(data=y, rate=sr)

In [None]:
# Validation positions whose true class is 23 AND whose predicted class is 24.
# (setdiff1d selected the opposite: true 23 samples NOT predicted as 24.)
confused_positions = np.intersect1d(true_23, pred_24)
# These are positions inside the validation subset; translate to an index of the
# underlying dataset before looking up the file path.
sample_index = validation_dataset.indices[confused_positions[1]]
# sr=None keeps the file's own sampling rate (librosa resamples to 22050 Hz by
# default), and playback uses that same rate so the audio is not slowed down.
y, sr = librosa.load(validation_dataset.dataset.paths[sample_index], sr=None)
Audio(data=y, rate=sr)

In [None]:
# Validation positions whose true class is 23 AND whose predicted class is 24.
# (setdiff1d selected the opposite: true 23 samples NOT predicted as 24.)
confused_positions = np.intersect1d(true_23, pred_24)
# These are positions inside the validation subset; translate to an index of the
# underlying dataset before looking up the file path.
sample_index = validation_dataset.indices[confused_positions[2]]
# sr=None keeps the file's own sampling rate (librosa resamples to 22050 Hz by
# default), and playback uses that same rate so the audio is not slowed down.
y, sr = librosa.load(validation_dataset.dataset.paths[sample_index], sr=None)
Audio(data=y, rate=sr)

In [None]:
# Validation positions whose true class is 23 AND whose predicted class is 24.
# (setdiff1d selected the opposite: true 23 samples NOT predicted as 24.)
confused_positions = np.intersect1d(true_23, pred_24)
# These are positions inside the validation subset; translate to an index of the
# underlying dataset before looking up the file path.
sample_index = validation_dataset.indices[confused_positions[3]]
# sr=None keeps the file's own sampling rate (librosa resamples to 22050 Hz by
# default), and playback uses that same rate so the audio is not slowed down.
y, sr = librosa.load(validation_dataset.dataset.paths[sample_index], sr=None)
Audio(data=y, rate=sr)

In [None]:
# Validation positions whose true class is 23 AND whose predicted class is 24.
# (setdiff1d selected the opposite: true 23 samples NOT predicted as 24.)
confused_positions = np.intersect1d(true_23, pred_24)
# These are positions inside the validation subset; translate to an index of the
# underlying dataset before looking up the file path.
sample_index = validation_dataset.indices[confused_positions[4]]
# sr=None keeps the file's own sampling rate (librosa resamples to 22050 Hz by
# default), and playback uses that same rate so the audio is not slowed down.
y, sr = librosa.load(validation_dataset.dataset.paths[sample_index], sr=None)
Audio(data=y, rate=sr)