### DEPENDENCIES

In [104]:
import torch
import torchaudio
import numpy as np
import math
import pandas as pd
import os, logging, typing
import sox
import soundfile

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report            
from utils import resample_audio

# RUN

### Load Tokenizers

In [105]:
# from Tokenizers import TokenizersConfig, Tokenizers

# def infer_token(audio_path, checkpoint_path):
#     # load the tokenizer checkpoints
#     checkpoint = torch.load(checkpoint_path)

#     cfg = TokenizersConfig(checkpoint['cfg'])
#     BEATs_tokenizer = Tokenizers(cfg)
#     BEATs_tokenizer.load_state_dict(checkpoint['model'])
#     BEATs_tokenizer.eval()

#     audio_input_16khz = resample_audio(audio_path)
#     labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=None)
#     return labels

# labels = infer_token('audios/ex_baby.wav', 'checkpoints/Tokenizer_iter3_plus_AS2M.pt')

### Load Pre-trained

In [106]:
# from BEATs import BEATs, BEATsConfig

# def infer_pretrained(audio_path, checkpoint_path):
#     # load the pre-trained checkpoints
#     checkpoint = torch.load(checkpoint_path)

#     cfg = BEATsConfig(checkpoint['cfg'])
#     BEATs_model = BEATs(cfg)
#     BEATs_model.load_state_dict(checkpoint['model'])
#     BEATs_model.eval()

#     audio_input_16khz = resample_audio(audio_path)
#     representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=None)[0]
#     return representation
    
# representation = infer_pretrained('audios/ex_baby.wav', 'checkpoints/BEATs_iter3_plus_AS2M.pt')

### Load Fine-tuned Models

In [107]:
from BEATs import BEATs, BEATsConfig

def infer_finetuned(audio_path, checkpoint_path):
    """Run a fine-tuned BEATs checkpoint on a single audio file.

    Args:
        audio_path: Path of the audio file; handed to ``resample_audio``.
        checkpoint_path: Path of a checkpoint readable by ``torch.load`` that
            contains the ``'cfg'`` and ``'model'`` entries.

    Returns:
        A ``(probs, checkpoint)`` tuple. ``probs`` is the first element of
        what ``BEATs_model.extract_features`` returns; ``checkpoint`` is the
        loaded checkpoint object, returned so callers can read further
        entries from it.

    NOTE(review): the checkpoint is loaded and the model rebuilt on every
    call, which is costly when this is invoked once per file.
    """
    # Load onto the CPU: the model built below is never moved to another
    # device, and this lets checkpoints saved on a GPU machine load on a
    # CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    cfg = BEATsConfig(checkpoint['cfg'])
    BEATs_model = BEATs(cfg)
    BEATs_model.load_state_dict(checkpoint['model'])
    BEATs_model.eval()

    # presumably resampled to 16 kHz, going by the name — confirm in utils
    audio_input_16khz = resample_audio(audio_path)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=None)[0]
    return probs, checkpoint

In [108]:
def topk_labels_prob(probs, checkpoint, labels_csv='labels/class_labels_indices.csv'):
    """Return the top-1 class name and probability for every row of ``probs``.

    Args:
        probs: 2-D tensor; ``probs.topk(k=1)`` is iterated row by row.
            # assumes shape (batch, n_classes) — TODO confirm
        checkpoint: Mapping with a ``'label_dict'`` entry that maps a class
            index (int) to a class ``mid`` string.
        labels_csv: CSV with ``index``, ``mid`` and ``display_name`` columns
            used to translate a ``mid`` into a readable name. Defaults to
            the path that was previously hard-coded.

    Returns:
        A list with one ``[display_name, probability]`` pair per row of
        ``probs``; ``probability`` is a Python float.

    Raises:
        KeyError: If a predicted index is missing from ``label_dict`` or its
            ``mid`` is missing from ``labels_csv``.
    """
    classes = pd.read_csv(labels_csv, index_col='index')
    # Build the mid -> display_name lookup once instead of filtering the
    # whole frame for every prediction. The first row wins for a duplicated
    # mid, matching the previous `.values[0]` behaviour.
    unique_classes = classes.drop_duplicates('mid')
    mid_to_name = dict(zip(unique_classes['mid'], unique_classes['display_name']))
    label_dict = checkpoint['label_dict']

    results = []
    for top_label_prob, top_label_idx in zip(*probs.topk(k=1)):
        # k=1, so each row carries exactly one index and one probability.
        mid = label_dict[top_label_idx[0].item()]
        results.append([mid_to_name[mid], top_label_prob[0].item()])
    return results

In [109]:
def infer(audio_path, checkpoint_path):
    """Classify one audio file with a fine-tuned checkpoint.

    Feeds the ``(probs, checkpoint)`` pair produced by ``infer_finetuned``
    straight into ``topk_labels_prob`` and returns its result.
    """
    return topk_labels_prob(*infer_finetuned(audio_path, checkpoint_path))

In [110]:
def batch_infer(folder:str, checkpoint_path:str, output = None):
    """Recursively run ``infer`` on every ``.wav`` file under ``folder``.

    Args:
        folder: Directory to walk; sub-directories are visited recursively.
        checkpoint_path: Checkpoint path forwarded to ``infer``.
        output: Optional dict to fill in. A fresh dict is created when
            omitted, so results never leak between independent calls (the
            former ``{}`` default was shared across calls).

    Returns:
        The ``output`` dict, mapping each file name (without directory) to
        the first element of the list returned by ``infer``.

    NOTE(review): keys are bare file names, so files with the same name in
    different sub-folders overwrite each other.
    """
    if output is None:
        output = {}
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        if os.path.isfile(path) and path.endswith(".wav"):
            results = infer(path, checkpoint_path)
            # `name` comes from os.listdir, so it is already the basename.
            output[name] = results[0]
        elif os.path.isdir(path):
            # Share the same dict with the recursive call.
            batch_infer(path, checkpoint_path, output)
    return output

### INFERENCES - WARNING: the following command takes a while (approx. 11 min)

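You need to have all the audio files ready in the 'files' folder. Uncomment and run the next cell to create `beats_results`; the cell that builds `df_beats_results` depends on it.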

In [111]:
# beats_results = batch_infer('files/2023-07-18', 'checkpoints/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')

In [112]:
def class_filter(input):
    """Map a label name to an integer class code.

    Returns 1 when ``'Baby'`` occurs inside ``input``, 2 when ``input``
    equals ``'Cat'``, 3 when it equals ``'Dog'`` and 0 for anything else.
    The ``'Baby'`` test is a containment check; the other two are exact
    comparisons.
    """
    if 'Baby' in input:
        return 1
    # Exact-match classes, checked in the same order as before.
    for code, exact_name in ((2, 'Cat'), (3, 'Dog')):
        if exact_name == input:
            return code
    return 0

In [113]:
# Flatten {file name: [label, probability]} into one record per audio file.
arranged_dict = [
    {'audio_id': audio_id, 'col_1': prediction[0], 'col_2': prediction[1]}
    for audio_id, prediction in beats_results.items()
]

# Build the DataFrame from the records and give the columns their final names.
cols = {'col_1': 'BEATs_Labels', 'col_2': 'BEATs_Probs'}
df_beats_results = pd.DataFrame(arranged_dict).rename(columns=cols)
df_beats_results.head()

Unnamed: 0,audio_id,BEATs_Labels,BEATs_Probs
0,0adb647e97de4fc8881e4c5359d3fb12.wav,Silence,0.151737
1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,Music,0.093592
2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,Music,0.154701
3,46a32acc19f84410baa8c07ddaa6ac5a.wav,Dog,0.6441
4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,Animal,0.639572


In [114]:
# Map each predicted display name to its integer class code.
# Only string entries are converted, so this cell can be re-run safely: after
# the first run the column already holds integer codes, and class_filter's
# `'Baby' in input` check raises TypeError when given an int.
df_beats_results['BEATs_Labels'] = df_beats_results['BEATs_Labels'].apply(
    lambda label: class_filter(label) if isinstance(label, str) else label
)
df_beats_results

Unnamed: 0,audio_id,BEATs_Labels,BEATs_Probs
0,0adb647e97de4fc8881e4c5359d3fb12.wav,0,0.151737
1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,0,0.093592
2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,0,0.154701
3,46a32acc19f84410baa8c07ddaa6ac5a.wav,3,0.644100
4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,0,0.639572
...,...,...,...
178,first_5_seconds-004b0020439747cb8dfd74129d9fba...,0,0.844228
179,next_5_seconds-004b0020439747cb8dfd74129d9fbac...,0,0.555513
180,f49029fa9ae94fd59500f0d16f4b80e8.wav,1,0.603912
181,first_5_seconds-f49029fa9ae94fd59500f0d16f4b80...,1,0.747653


In [115]:
# Persist the mapped predictions to a CSV file in the working directory.
df_beats_results.to_csv('beats_results.csv')

In [116]:
# Reload the saved predictions; the first CSV column is used as the row index.
df_beats_results = pd.read_csv('beats_results.csv', index_col=0)

### Metrics

In [123]:
def metrics(y_true, y_pred, average='micro'):
    """Print classification metrics for predictions against ground truth.

    Prints accuracy, precision, recall, F1, the confusion matrix and the
    scikit-learn classification report. Nothing is returned.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.
        average: Averaging mode forwarded to ``precision_score``,
            ``recall_score`` and ``f1_score``. The default ``'micro'`` keeps
            the previous behaviour; for single-label multiclass data micro
            averaging makes all three equal to accuracy, so pass ``'macro'``
            or ``'weighted'`` to get distinct values.
    """
    accuracy = accuracy_score(y_true, y_pred)
    print("Precisión (Accuracy):", accuracy)

    # Compute the model's precision
    precision = precision_score(y_true, y_pred, average=average)
    print("Precisión (Precision):", precision)

    # Compute the model's recall
    recall = recall_score(y_true, y_pred, average=average)
    print("Recall:", recall)

    # Compute the model's F1 score
    f1 = f1_score(y_true, y_pred, average=average)
    print("Puntuación F1:", f1)

    # Build the confusion matrix
    confusion = confusion_matrix(y_true, y_pred)
    print("Matriz de Confusión:")
    print(confusion)

    # Build the detailed per-class classification report
    report = classification_report(y_true, y_pred)
    print("Informe de clasificación:")
    print(report)

In [118]:
# Load the labelled dataset; the first CSV column is used as the row index.
labeled_dataset = pd.read_csv('labels/labeled_dataset.csv', index_col=0)
labeled_dataset.head()

Unnamed: 0,audio_id,Label,HTS_Labels,ResNet_Labels
0,0adb647e97de4fc8881e4c5359d3fb12.wav,0,,0
1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,0,0.0,0
2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,0,0.0,0
3,46a32acc19f84410baa8c07ddaa6ac5a.wav,3,,3
4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,3,3.0,0


In [119]:
# Attach the BEATs predictions to the labelled dataset, matching on audio_id.
# pd.merge defaults to an inner join, so rows whose audio_id is present in
# only one of the two frames are dropped.
# NOTE(review): confirm the resulting row count is the expected one.
merged_inferences = pd.merge(labeled_dataset, df_beats_results[['audio_id', 'BEATs_Labels']], on='audio_id')
merged_inferences.head()

Unnamed: 0,audio_id,Label,HTS_Labels,ResNet_Labels,BEATs_Labels
0,0adb647e97de4fc8881e4c5359d3fb12.wav,0,,0,0
1,46a32acc19f84410baa8c07ddaa6ac5a.wav,3,,3,3
2,0dfc7cca4a104595b849f2852721620f.wav,3,,0,0
3,3b5793fe82214df1b8bc786888dd2b10.wav,3,,0,0
4,2b80e3e51bbe441f80d802c2858b3a33.wav,3,,3,3


In [126]:
# Ground-truth labels and BEATs predictions for the merged rows.
y_true = merged_inferences['Label']
y_pred = merged_inferences['BEATs_Labels']

# Print the evaluation metrics for the BEATs predictions.
metrics(y_true, y_pred)

Precisión (Accuracy): 0.32786885245901637
Precisión (Precision): 0.32786885245901637
Recall: 0.32786885245901637
Puntuación F1: 0.32786885245901637
Matriz de Confusión:
[[ 7  0  0  0]
 [ 1  8  0  0]
 [ 9  0  1  0]
 [31  0  0  4]]
Informe de clasificación:
              precision    recall  f1-score   support

           0       0.15      1.00      0.25         7
           1       1.00      0.89      0.94         9
           2       1.00      0.10      0.18        10
           3       1.00      0.11      0.21        35

    accuracy                           0.33        61
   macro avg       0.79      0.53      0.40        61
weighted avg       0.90      0.33      0.32        61

