### DEPENDENCIES

In [1]:
import torch
import torchaudio
import numpy as np
import math
import pandas as pd
import os, logging, typing
import sox
import soundfile

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report            
from utils import resample_audio

SoX could not be found!

    If you do not have SoX, proceed here:
     - - - http://sox.sourceforge.net/ - - -

    If you do (or think that you should) have SoX, double-check your
    path variables.
    


# RUN

### Load Tokenizers

In [2]:
# from Tokenizers import TokenizersConfig, Tokenizers

# def infer_token(audio_path, checkpoint_path):
#     # load the tokenizer checkpoints
#     checkpoint = torch.load(checkpoint_path)

#     cfg = TokenizersConfig(checkpoint['cfg'])
#     BEATs_tokenizer = Tokenizers(cfg)
#     BEATs_tokenizer.load_state_dict(checkpoint['model'])
#     BEATs_tokenizer.eval()

#     audio_input_16khz = resample_audio(audio_path)
#     labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=None)
#     return labels

# labels = infer_token('audios/ex_baby.wav', 'checkpoints/Tokenizer_iter3_plus_AS2M.pt')

### Load Pre-trained

In [3]:
# from BEATs import BEATs, BEATsConfig

# def infer_pretrained(audio_path, checkpoint_path):
#     # load the pre-trained checkpoints
#     checkpoint = torch.load(checkpoint_path)

#     cfg = BEATsConfig(checkpoint['cfg'])
#     BEATs_model = BEATs(cfg)
#     BEATs_model.load_state_dict(checkpoint['model'])
#     BEATs_model.eval()

#     audio_input_16khz = resample_audio(audio_path)
#     representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=None)[0]
#     return representation
    
# representation = infer_pretrained('audios/ex_baby.wav', 'checkpoints/BEATs_iter3_plus_AS2M.pt')

### Load Fine-tuned Models

In [4]:
from BEATs import BEATs, BEATsConfig

def infer_finetuned(audio_path, checkpoint_path):
    # load the fine-tuned checkpoints
    checkpoint = torch.load(checkpoint_path)

    cfg = BEATsConfig(checkpoint['cfg'])
    BEATs_model = BEATs(cfg)
    BEATs_model.load_state_dict(checkpoint['model'])
    BEATs_model.eval()

    audio_input_16khz = resample_audio(audio_path)
    probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=None)[0]
    return probs, checkpoint

In [5]:
def topk_labels_prob(probs, checkpoint):
    classes = pd.read_csv('labels/class_labels_indices.csv', index_col='index')

    results = []
    for (top_label_prob, top_label_idx) in zip(*probs.topk(k=1)):
        top_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top_label_idx]
        # Get classes from AudioSet class labels indices
        tags = []
        for c in top_label:
            tag = classes[classes['mid'] == c]['display_name'].values[0]
            tags.append(tag)
        results.append([tags[0], top_label_prob.tolist()[0], tags[1], top_label_prob.tolist()[1]])
    return results

In [6]:
def infer(audio_path, checkpoint_path):
    probs, checkpoint = infer_finetuned(audio_path, checkpoint_path)
    results = topk_labels_prob(probs, checkpoint)
    return results

### Compare to previous results...

In [7]:
import pandas as pd

In [8]:
labeled_dataset = pd.read_csv('labels/labeled_dataset.csv', index_col=0)
labeled_dataset.head()

Unnamed: 0,audio_id,Label,HTS_Labels,ResNet_Labels
0,0adb647e97de4fc8881e4c5359d3fb12.wav,0,,0
1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,0,0.0,0
2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,0,0.0,0
3,46a32acc19f84410baa8c07ddaa6ac5a.wav,3,,3
4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,3,3.0,0


In [9]:
def batch_infer(folder:str, checkpoint_path:str, output = {}):
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        if os.path.isfile(path) and path.endswith(".wav"):
            filename = os.path.basename(path)
            results = infer(path, checkpoint_path)
            output[filename] = results[0]
        elif os.path.isdir(path):
            batch_infer(path, checkpoint_path, output)
    return output

In [11]:
pepe = infer('2023-07-18\E4AAECA673A121689701159_1689701159000_2_2_0\0adb647e97de4fc8881e4c5359d3fb12.wav', 'checkpoints/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
pepe 

LibsndfileError: Error opening '2023-07-18\\E4AAECA673A121689701159_1689701159000_2_2_0\x00adb647e97de4fc8881e4c5359d3fb12.wav': System error.

### INFERENCES - WARNING: the following command takes a while (11 min aprox.)

You need to have all the audio files ready in the 'files' folder.

In [None]:
beats_results = batch_infer('2023-07-18', 'checkpoints/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')

In [None]:
arranged_dict = [{**{'audio_id': key}, **{f'col_{i+1}': value[i] for i in range(4)}} for key, value in beats_results.items()]

# Convert the list of dictionaries into a DataFrame
df_beats_results = pd.DataFrame(arranged_dict)
cols = {'col_1': 'BEATs_Labels','col_2': 'BEATs_Probs'}
df_beats_results.rename(columns=cols, inplace=True)
df_beats_results.head()

### Metrics

In [None]:
def metrics(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)
    print("Precisión (Accuracy):", accuracy)

    # Calcular la precisión del modelo
    precision = precision_score(y_true, y_pred)
    print("Precisión (Precision):", precision)

    # Calcular el recall del modelo
    recall = recall_score(y_true, y_pred)
    print("Recall:", recall)

    # Calcular la puntuación F1 del modelo
    f1 = f1_score(y_true, y_pred)
    print("Puntuación F1:", f1)

    # Obtener la matriz de confusión
    confusion = confusion_matrix(y_true, y_pred)
    print("Matriz de Confusión:")
    print(confusion)

    # Obtener un informe de clasificación detallado
    report = classification_report(y_true, y_pred)
    print("Informe de clasificación:")
    print(report)