### DEPENDENCIES

In [1]:
import torch
import pandas as pd
import os

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report            
from utils import resample_audio

# RUN

### Load Tokenizers

In [2]:
# from scripts.Tokenizers import TokenizersConfig, Tokenizers

# def infer_token(audio_path, checkpoint_path):
#     # load the tokenizer checkpoints
#     checkpoint = torch.load(checkpoint_path)

#     cfg = TokenizersConfig(checkpoint['cfg'])
#     BEATs_tokenizer = Tokenizers(cfg)
#     BEATs_tokenizer.load_state_dict(checkpoint['model'])
#     BEATs_tokenizer.eval()

#     audio_input_16khz = resample_audio(audio_path)
#     labels = BEATs_tokenizer.extract_labels(audio_input_16khz, padding_mask=None)
#     return labels

# labels = infer_token('audios/ex_baby.wav', 'checkpoints/Tokenizer_iter3_plus_AS2M.pt')

### Load Pre-trained

In [3]:
# from scripts.BEATs import BEATs, BEATsConfig

# def infer_pretrained(audio_path, checkpoint_path):
#     # load the pre-trained checkpoints
#     checkpoint = torch.load(checkpoint_path)

#     cfg = BEATsConfig(checkpoint['cfg'])
#     BEATs_model = BEATs(cfg)
#     BEATs_model.load_state_dict(checkpoint['model'])
#     BEATs_model.eval()

#     audio_input_16khz = resample_audio(audio_path)
#     representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=None)[0]
#     return representation
    
# representation = infer_pretrained('audios/ex_baby.wav', 'checkpoints/BEATs_iter3_plus_AS2M.pt')

### Load Fine-tuned Models

In [4]:
from scripts.BEATs import BEATs, BEATsConfig

def infer_finetuned(audio_path, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    cfg = BEATsConfig(checkpoint['cfg'])
    BEATs_model = BEATs(cfg)
    BEATs_model.load_state_dict(checkpoint['model'])
    BEATs_model.eval()

    audio_input_16khz = resample_audio(audio_path)
    probs = BEATs_model.extract_features(audio_input_16khz, padding_mask=None)[0]
    return probs, checkpoint

In [5]:
def topk_labels_prob(probs, checkpoint):
    classes = pd.read_csv('labels/class_labels_indices.csv', index_col='index')

    results = []
    for (top_label_prob, top_label_idx) in zip(*probs.topk(k=1)):
        top_label = [checkpoint['label_dict'][label_idx.item()] for label_idx in top_label_idx]
        # Get classes from AudioSet class labels indices
        tags = []
        for c in top_label:
            tag = classes[classes['mid'] == c]['display_name'].values[0]
            tags.append(tag)
        results.append([tags[0], top_label_prob.tolist()[0]])
    return results

In [6]:
def infer(audio_path, checkpoint_path):
    probs, checkpoint = infer_finetuned(audio_path, checkpoint_path)
    results = topk_labels_prob(probs, checkpoint)
    return results

In [7]:
def batch_infer(folder:str, checkpoint_path:str, output = {}):
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        if os.path.isfile(path) and path.endswith(".wav"):
            filename = os.path.basename(path)
            results = infer(path, checkpoint_path)
            output[filename] = results[0]
        elif os.path.isdir(path):
            batch_infer(path, checkpoint_path, output)
    return output

### INFERENCES - WARNING: the following command takes a while (11 min aprox.)

You need to have all the audio files ready in the 'files' folder.

In [8]:
beats_results = batch_infer('files/2023-07-18', 'checkpoints/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')

arranged_dict = [{**{'audio_id': key}, **{f'col_{i+1}': value[i] for i in range(2)}} for key, value in beats_results.items()]

# Convert the list of dictionaries into a DataFrame
df_beats_results_ASlabels = pd.DataFrame(arranged_dict)
cols = {'col_1': 'BEATs_Labels','col_2': 'BEATs_Probs'}
df_beats_results_ASlabels.rename(columns=cols, inplace=True)

df_beats_results_ASlabels.to_csv('output/df_beats_results_ASlabels.csv')
df_beats_results_ASlabels.head()

Unnamed: 0,audio_id,BEATs_Labels,BEATs_Probs
0,0adb647e97de4fc8881e4c5359d3fb12.wav,Silence,0.151737
1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,Music,0.093592
2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,Music,0.154701
3,46a32acc19f84410baa8c07ddaa6ac5a.wav,Dog,0.6441
4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,Animal,0.639572


In [9]:
tags = ['baby', 'cat', 'dog']


def list2dict(tags):
    if not tags:
        raise ValueError("Class list is empty.")
    classes_dict = {num + 1: tag.lower() for num, tag in enumerate(tags)}
    return classes_dict


def class_filter(input):
    input = input.lower()
    tags_dict = list2dict(tags)
    for i, tag in tags_dict.items():
        if (tag == input) or (tag in input):
            return i
    return 0


df_beats_results_ASlabels = pd.read_csv('output/df_beats_results_ASlabels.csv')
df_beats_results_numbered = df_beats_results_ASlabels.copy()

df_beats_results_numbered['BEATs_Labels'] = df_beats_results_ASlabels['BEATs_Labels'].apply(class_filter)

df_beats_results_numbered.to_csv('output/df_beats_results_numbered.csv')
df_beats_results_numbered.head()

Unnamed: 0.1,Unnamed: 0,audio_id,BEATs_Labels,BEATs_Probs
0,0,0adb647e97de4fc8881e4c5359d3fb12.wav,0,0.151737
1,1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,0,0.093592
2,2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,0,0.154701
3,3,46a32acc19f84410baa8c07ddaa6ac5a.wav,3,0.6441
4,4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,0,0.639572


In [10]:
df_beats_results_numbered = pd.read_csv('output/df_beats_results_numbered.csv', index_col=0)
df_beats_results_numbered.head()

Unnamed: 0.1,Unnamed: 0,audio_id,BEATs_Labels,BEATs_Probs
0,0,0adb647e97de4fc8881e4c5359d3fb12.wav,0,0.151737
1,1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,0,0.093592
2,2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,0,0.154701
3,3,46a32acc19f84410baa8c07ddaa6ac5a.wav,3,0.6441
4,4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,0,0.639572


### Metrics

In [15]:
def metrics(y_true, y_pred):
    # Accuracy
    accuracy = accuracy_score(y_true, y_pred)
    print("Precisión (Accuracy):", accuracy)

    # Precision
    precision = precision_score(y_true, y_pred, average='micro')
    print("Precisión (Precision):", precision)

    # Recall
    recall = recall_score(y_true, y_pred, average='micro')
    print("Recall:", recall)

    # F1 Score
    f1 = f1_score(y_true, y_pred, average='micro')
    print("Puntuación F1:", f1)

    # Confusion Matrix
    confusion = confusion_matrix(y_true, y_pred)
    print("Matriz de Confusión:")
    print(confusion)

    # Detailed Classification Report
    report = classification_report(y_true, y_pred)
    print("Informe de clasificación:")
    print(report)

In [16]:
labeled_dataset = pd.read_csv('labels/labeled_dataset.csv', index_col=0)
labeled_dataset.head()
labeled_dataset.shape

(183, 4)

In [17]:
merged_inferences = pd.merge(labeled_dataset, df_beats_results_numbered[['audio_id', 'BEATs_Labels']], on='audio_id')
merged_inferences.to_csv('merged_inferences.csv')
merged_inferences.head()

Unnamed: 0,audio_id,Label,HTS_Labels,ResNet_Labels,BEATs_Labels
0,0adb647e97de4fc8881e4c5359d3fb12.wav,0,,0,0
1,first_5_seconds-0adb647e97de4fc8881e4c5359d3fb...,0,0.0,0,0
2,next_5_seconds-0adb647e97de4fc8881e4c5359d3fb1...,0,0.0,0,0
3,46a32acc19f84410baa8c07ddaa6ac5a.wav,3,,3,3
4,first_5_seconds-46a32acc19f84410baa8c07ddaa6ac...,3,3.0,0,0


In [18]:
y_true = merged_inferences['Label']
y_pred = merged_inferences['BEATs_Labels']

metrics(y_true, y_pred)

Precisión (Accuracy): 0.2896174863387978
Precisión (Precision): 0.2896174863387978
Recall: 0.2896174863387978
Puntuación F1: 0.2896174863387978
Matriz de Confusión:
[[24  0  0  0]
 [ 5 17  1  0]
 [28  0  3  0]
 [96  0  0  9]]
Informe de clasificación:
              precision    recall  f1-score   support

           0       0.16      1.00      0.27        24
           1       1.00      0.74      0.85        23
           2       0.75      0.10      0.17        31
           3       1.00      0.09      0.16       105

    accuracy                           0.29       183
   macro avg       0.73      0.48      0.36       183
weighted avg       0.85      0.29      0.26       183

