In [19]:
import os
import pandas as pd
import numpy as np
from datasets import Dataset, Audio, ClassLabel, Features
import torch
from transformers import ASTConfig, ASTForAudioClassification, ASTFeatureExtractor, TrainingArguments, Trainer
import evaluate


In [2]:
import json

# Read the class-name -> class-id mapping from disk and derive the reverse
# (class-id -> class-name) lookup from it.
with open("label_map.json", "r") as map_file:
    label2id = json.load(map_file)
id2label = dict((class_id, name) for name, class_id in label2id.items())

In [3]:
# Dataset schema: an audio column plus a class-label column whose names are
# listed in class-id order (position k holds the name of class id k).
ordered_names = [id2label[class_id] for class_id in range(len(id2label))]
class_labels = ClassLabel(names=ordered_names)
features = Features({"audio": Audio(), "labels": class_labels})

def collect_data(root_dir="./Dataset/genres/", label_map_path="label_map.json"):
    """Gather the ``.wav`` file paths under ``root_dir`` with their class ids.

    ``root_dir`` is expected to contain one sub-directory per class; the
    sub-directory name is looked up in the label map to obtain the class id.
    Entries of ``root_dir`` that are not directories are ignored, as are files
    whose name does not end in ``.wav``.

    Args:
        root_dir: Directory holding one folder of ``.wav`` files per class.
        label_map_path: JSON file mapping class (folder) name to integer id.

    Returns:
        A dict ``{"audio": [path, ...], "labels": [class_id, ...]}``. The two
        lists are aligned and ordered by class folder name, then file name.

    Raises:
        KeyError: If a class folder that contains ``.wav`` files has no entry
            in the label map.
    """
    import json  # local import keeps the function usable on its own

    # Honour the label_map_path argument instead of silently depending on a
    # module-level mapping that happens to exist in the notebook.
    with open(label_map_path, "r") as map_file:
        name_to_id = json.load(map_file)

    data = {"audio": [], "labels": []}
    for label in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, label)
        if not os.path.isdir(class_dir):
            continue
        # os.listdir returns entries in arbitrary order; sorting makes the
        # dataset order (and any seeded split built on it) reproducible.
        for filename in sorted(os.listdir(class_dir)):
            if not filename.endswith(".wav"):
                continue
            if label not in name_to_id:
                raise KeyError(
                    f"class folder {label!r} has no entry in {label_map_path!r}"
                )
            data["audio"].append(os.path.join(class_dir, filename))
            data["labels"].append(name_to_id[label])
    return data

# Build the dataset from the collected paths/labels using the explicit schema,
# then print the first and last examples as a quick sanity check.
datadict = collect_data()
dataset = Dataset.from_dict(datadict, features=features)
print(dataset[0], dataset[-1])

{'audio': {'path': './Dataset/genres/blues\\blues.00000.wav', 'array': array([ 0.00732422,  0.01660156,  0.00762939, ..., -0.05560303,
       -0.06106567, -0.06417847]), 'sampling_rate': 22050}, 'labels': 0} {'audio': {'path': './Dataset/genres/rock\\rock.00099.wav', 'array': array([-0.02111816, -0.03451538, -0.03536987, ...,  0.00134277,
        0.00250244, -0.00186157]), 'sampling_rate': 22050}, 'labels': 9}


In [4]:
# Class names in id order: position i holds the name of class id i.
class_names = [id2label[i] for i in range(len(id2label))]

dataset = dataset.cast_column("labels", ClassLabel(names=class_names))
# Have the audio column deliver samples at 16 kHz.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
# Size the classifier from the label map rather than from the labels that
# happen to occur in the data: a class with no files would otherwise make
# num_labels smaller than the id2label / label2id mappings given to the model.
num_labels = len(class_names)
print(num_labels, class_names)

Casting the dataset: 100%|██████████| 999/999 [00:00<00:00, 483511.39 examples/s]

10 ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']





In [5]:
# Checkpoint identifier; the feature extractor is loaded from it here.
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

# Name of the first model input reported by the extractor, and the sampling
# rate the extractor is configured for; both are used when preprocessing audio.
model_input_name = feature_extractor.model_input_names[0]
SAMPLING_RATE = feature_extractor.sampling_rate

In [6]:
def preprocess_audio(batch):
    """Turn a batch of decoded audio into model inputs.

    Reads the waveform arrays from ``batch["input_values"]``, passes them
    through the module-level ``feature_extractor`` at ``SAMPLING_RATE`` and
    returns a dict holding the extracted features under ``model_input_name``
    together with the batch labels as a plain list.
    """
    waveforms = []
    for clip in batch["input_values"]:
        waveforms.append(clip["array"])

    encoded = feature_extractor(waveforms, sampling_rate=SAMPLING_RATE, return_tensors="pt")

    return {
        model_input_name: encoded.get(model_input_name),
        "labels": list(batch["labels"]),
    }

# preprocess_audio reads batch["input_values"], so the raw audio column is
# renamed to that key before the function is registered as the transform.
# NOTE(review): re-running this cell will fail on the rename, since no
# "audio" column is left after the first run.
dataset = dataset.rename_column("audio", "input_values")
dataset.set_transform(preprocess_audio, output_all_columns=False)

In [7]:
# Split once into stratified train/test subsets. The guard keeps the cell
# re-runnable: after the split, `dataset` is no longer a plain Dataset.
# A type check is used instead of `"test" not in dataset` because `in` on a
# plain Dataset is not a key lookup; Python falls back to iterating every
# example, which runs the audio transform over the whole dataset just to
# evaluate the condition.
if isinstance(dataset, Dataset):
    dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels")

In [8]:
# Estimate the normalisation statistics used by the feature extractor.
# Normalisation is switched off while measuring so that the statistics are
# taken over un-normalised features, then switched back on afterwards.
feature_extractor.do_normalize = False
mean = []
std = []

dataset["train"].set_transform(preprocess_audio, output_all_columns=False)
# Fetch each training example exactly once. Re-indexing dataset["train"][i]
# inside the loop would run the decode + feature-extraction transform again
# for every lookup.
for example in dataset["train"]:
    clip_features = example[model_input_name]
    mean.append(torch.mean(clip_features).item())
    std.append(torch.std(clip_features).item())
dataset["test"].set_transform(preprocess_audio, output_all_columns=False)

# Average the per-clip statistics and store them as built-in floats so the
# extractor configuration remains JSON-serialisable.
feature_extractor.mean = float(np.mean(mean))
feature_extractor.std = float(np.mean(std))
feature_extractor.do_normalize = True

In [9]:
# Start from the checkpoint's configuration and repoint it at this task's
# label set. num_labels is assigned before the two label mappings.
config = ASTConfig.from_pretrained(pretrained_model)

config.num_labels = num_labels
config.label2id = label2id
config.id2label = id2label

# ignore_mismatched_sizes allows the checkpoint to load although its
# classifier shape differs from the one configured above (see the load
# message printed below this cell).
model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)
# NOTE(review): the load message already reports the mismatched classifier
# weights as newly initialized; confirm this extra call is still required.
model.init_weights()

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([10]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([10, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Training configuration. Evaluation and checkpointing both run once per
# epoch, which lets load_best_model_at_end restore the checkpoint with the
# best "accuracy". Step-interval settings (eval_steps / save_steps) are not
# passed: they only apply to the "steps" strategies and would suggest a
# per-step schedule that is not actually in effect.
training_args = TrainingArguments(
    output_dir="./runs/ast_classifier",
    logging_dir="./logs/ast_classifier",
    report_to="tensorboard",
    learning_rate=5e-5,
    push_to_hub=False,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Training loss is logged every 20 optimisation steps.
    logging_strategy="steps",
    logging_steps=20,
)

In [11]:
# Load the four evaluation metrics by name.
accuracy, recall, precision, f1 = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "recall", "precision", "f1")
)

# Averaging mode for precision/recall/f1: "macro" when the model has more
# than two labels, "binary" otherwise.
if config.num_labels > 2:
    AVERAGE = "macro"
else:
    AVERAGE = "binary"

def compute_metrics(eval_pred):
    """Score one evaluation pass.

    The predicted class is the arg-max over axis 1 of
    ``eval_pred.predictions``; it is compared against ``eval_pred.label_ids``.

    Returns:
        A dict containing the accuracy result followed by the precision,
        recall and f1 results, the latter three computed with
        ``average=AVERAGE``.
    """
    predicted = np.argmax(eval_pred.predictions, axis=1)
    references = eval_pred.label_ids

    results = accuracy.compute(predictions=predicted, references=references)
    for averaged_metric in (precision, recall, f1):
        results.update(
            averaged_metric.compute(predictions=predicted, references=references, average=AVERAGE)
        )
    return results

In [12]:
# Wire the model, training configuration, train/test subsets and the metric
# function into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
)

In [13]:
# Run fine-tuning; the returned TrainOutput is displayed as the cell result.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.6609,0.544654,0.835,0.844033,0.835,0.832374
2,0.3607,0.637664,0.81,0.84525,0.81,0.818392
3,0.152,0.744722,0.805,0.873713,0.805,0.818974
4,0.0145,0.720271,0.865,0.882739,0.865,0.867575
5,0.0016,0.788361,0.86,0.879927,0.86,0.862045
6,0.0007,0.638245,0.86,0.872594,0.86,0.861608
7,0.0002,0.64617,0.865,0.876706,0.865,0.866513
8,0.0002,0.651236,0.865,0.876706,0.865,0.866513
9,0.0001,0.653229,0.865,0.876706,0.865,0.866513
10,0.0001,0.653944,0.865,0.876706,0.865,0.866513


TrainOutput(global_step=1000, training_loss=0.16025200420990587, metrics={'train_runtime': 545.1198, 'train_samples_per_second': 14.657, 'train_steps_per_second': 1.834, 'total_flos': 5.416235474092032e+17, 'train_loss': 0.16025200420990587, 'epoch': 10.0})

In [20]:
# Write the trainer's model to the export directory.
trainer.save_model("./saved_AST_model")

def convert_np_floats(obj):
    """Recursively replace NumPy values with built-in Python equivalents.

    Walks dicts, lists and tuples. Any NumPy scalar (float, int or bool of any
    width) becomes the matching Python scalar, and NumPy arrays become nested
    lists, so the result can be serialised as JSON.

    Args:
        obj: Arbitrary, possibly nested, value.

    Returns:
        A structure of the same shape with NumPy values converted; objects of
        any other type are returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: convert_np_floats(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_np_floats(i) for i in obj]
    if isinstance(obj, tuple):
        # Keep the container a tuple, but still clean its elements.
        return tuple(convert_np_floats(i) for i in obj)
    if isinstance(obj, np.ndarray):
        # tolist() already yields built-in scalars at every nesting level.
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Covers every NumPy scalar type, not only float32/float64.
        return obj.item()
    return obj
# Round-trip the extractor settings through a plain dict so that NumPy
# scalars in it are replaced by built-in floats, then rebuild the extractor
# and save it into the same directory as the model.
config_dict = feature_extractor.to_dict()
cleaned_dict = convert_np_floats(config_dict)
new_feature_extractor = ASTFeatureExtractor.from_dict(cleaned_dict)
new_feature_extractor.save_pretrained("./saved_AST_model")

['./saved_AST_model\\preprocessor_config.json']

In [None]:
# To view the training statistics, run this in a terminal:
#   tensorboard --logdir="./logs"