In [9]:
import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, Audio, ClassLabel, DatasetDict, Features
from transformers import (
    ASTFeatureExtractor,
    ASTConfig,
    ASTForAudioClassification,
    Trainer,
    TrainingArguments,
)


# Load the first 15% of rows of the ESC-50 "train" split; the train/test split
# used for fine-tuning is carved out of this subset further down.
# NOTE(review): a percent slice takes rows in stored order, not a random
# sample, so which classes end up in the subset depends on the dataset's row
# order.
esc50 = load_dataset("ashraq/esc50", split="train[:15%]")

Repo card metadata block was not found. Setting CardData to empty.


In [10]:
# Build the id -> category-name table: one entry per distinct target id, in
# ascending id order, taking the category of the first row with that id.
df = esc50.select_columns(["target", "category"]).to_pandas()
class_names = (
    df.drop_duplicates(subset="target", keep="first")
    .sort_values("target")["category"]
    .tolist()
)

# Turn the integer targets into a ClassLabel feature and have every clip
# decoded at 16 kHz.
esc50 = esc50.cast_column("target", ClassLabel(names=class_names))
esc50 = esc50.cast_column("audio", Audio(sampling_rate=16000))

# "labels" is the key preprocess_audio emits alongside the spectrograms.
esc50 = esc50.rename_column("target", "labels")
num_labels = np.unique(esc50["labels"]).size

# AudioSet-finetuned AST checkpoint; its feature extractor defines the input
# key and the sampling rate used throughout the notebook.
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)
SAMPLING_RATE = feature_extractor.sampling_rate
model_input_name = feature_extractor.model_input_names[0]

In [11]:
def preprocess_audio(batch):
    """Convert a batch of decoded clips into AST model inputs.

    Args:
        batch: batched dataset rows. ``batch["input_values"]`` holds decoded
            audio dicts (each with an ``"array"`` waveform), and
            ``batch["labels"]`` holds the matching class ids.

    Returns:
        dict with the feature extractor's output under ``model_input_name``
        (requested as PyTorch tensors) and the labels as a list.
    """
    waveforms = [clip["array"] for clip in batch["input_values"]]
    encoded = feature_extractor(
        waveforms, sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    return {
        model_input_name: encoded.get(model_input_name),
        "labels": list(batch["labels"]),
    }

In [12]:
# Prepare the train/test split. Guarding on the container type makes the cell
# safe to re-run: once esc50 is a DatasetDict, the rename/filter/split below
# has already happened. (A membership test such as `"test" in esc50` on a
# plain Dataset would iterate over every row, decoding all the audio.)
if not isinstance(esc50, DatasetDict):
    # Name the audio column after the model's input key so it survives the
    # Trainer's unused-column removal and matches what preprocess_audio reads.
    esc50 = esc50.rename_column("audio", "input_values")
    esc50 = esc50.cast_column(
        "input_values", Audio(sampling_rate=feature_extractor.sampling_rate)
    )

    # Drop classes with fewer than 2 examples, presumably so the stratified
    # split below can place every remaining class in both splits.
    label_counts = pd.Series(esc50["labels"]).value_counts()
    valid_labels = {label for label, count in label_counts.items() if count >= 2}
    # input_columns="labels" passes only the label to the predicate, so the
    # audio column is not decoded just to filter.
    esc50 = esc50.filter(lambda label: label in valid_labels, input_columns="labels")

    esc50 = esc50.train_test_split(
        test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels"
    )

In [13]:
# Dataset-specific normalization statistics for the AST input spectrograms.
# ASTFeatureExtractor normalizes its log-mel filterbank features with
# `mean`/`std`, so the statistics must be measured on those features (the
# extractor's output with normalization disabled), not on the raw waveform.
# The result is the average of per-clip means and of per-clip standard
# deviations over the training split.
# esc50["train"] must still yield raw audio dicts here, i.e. this runs before
# preprocess_audio is attached with set_transform.
feature_extractor.do_normalize = False
mean = []
std = []
try:
    for example in esc50["train"]:
        spectrogram = feature_extractor(
            example["input_values"]["array"],
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        )[model_input_name]
        mean.append(spectrogram.mean().item())
        std.append(spectrogram.std().item())
finally:
    # Re-enable normalization even if feature extraction fails part-way.
    feature_extractor.do_normalize = True

feature_extractor.mean = float(np.mean(mean))
feature_extractor.std = float(np.mean(std))


In [14]:
# Attach preprocess_audio as an on-the-fly transform on both splits. From now
# on, every access returns the extractor's spectrogram tensors plus labels
# instead of raw audio, normalized with whatever mean/std the feature
# extractor holds at access time.
for split_name in ("train", "test"):
    esc50[split_name].set_transform(preprocess_audio, output_all_columns=False)



In [15]:
# Start from the AudioSet-finetuned AST checkpoint, resized to our label set.
config = ASTConfig.from_pretrained(pretrained_model)
config.num_labels = num_labels
# Attach human-readable class names (class_names is indexed by label id) so
# predictions report ESC-50 categories instead of generic LABEL_<i> names.
config.id2label = dict(enumerate(class_names))
config.label2id = {name: label_id for label_id, name in config.id2label.items()}
# The config's num_labels follows id2label, so make sure the two agree.
assert config.num_labels == num_labels, (
    f"{len(class_names)} class names for {num_labels} labels"
)

# ignore_mismatched_sizes=True loads every checkpoint weight except the
# 527-way AudioSet classifier, which is freshly initialized at the new size.
# No further initialization call is needed: calling model.init_weights()
# afterwards is redundant and, depending on the transformers version, may
# re-run random initialization over the loaded weights.
model = ASTForAudioClassification.from_pretrained(
    pretrained_model, config=config, ignore_mismatched_sizes=True
)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Fine-tuning schedule. Evaluation and checkpointing both happen once per
# epoch (so no eval_steps/save_steps are set: those only apply to the "steps"
# strategies). The checkpoint with the best "accuracy" from compute_metrics is
# restored at the end. Training loss is logged every 20 steps to TensorBoard.
training_args = TrainingArguments(
    output_dir="./runs/ast_classifier",
    logging_dir="./logs/ast_classifier",
    report_to="tensorboard",
    learning_rate=5e-5,
    push_to_hub=False,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_strategy="steps",
    logging_steps=20,
)

In [17]:
# Metric objects from the `evaluate` library, loaded once and reused by
# compute_metrics. Precision/recall/F1 use binary averaging when there are at
# most two classes and macro averaging otherwise.
accuracy, recall, precision, f1 = (
    evaluate.load(metric_name)
    for metric_name in ("accuracy", "recall", "precision", "f1")
)
AVERAGE = "binary" if config.num_labels <= 2 else "macro"


def compute_metrics(eval_pred):
    """Score one evaluation pass for the Trainer.

    Args:
        eval_pred: object exposing ``predictions`` (per-class logits; the
            predicted class is the argmax over axis 1) and ``label_ids``.

    Returns:
        dict merging the accuracy result with precision, recall and F1. The
        latter three are averaged according to the module-level ``AVERAGE``.
    """
    predicted = np.argmax(eval_pred.predictions, axis=1)
    references = eval_pred.label_ids

    scores = accuracy.compute(predictions=predicted, references=references)
    for metric in (precision, recall, f1):
        scores.update(
            metric.compute(
                predictions=predicted, references=references, average=AVERAGE
            )
        )
    return scores


In [18]:
# Wire the model, training schedule, metric function and the transformed
# splits together.
# NOTE(review): the "test" split is also the evaluation set used to pick the
# best checkpoint, so metrics reported on it are optimistically biased.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=esc50["train"],
    eval_dataset=esc50["test"],
)

In [19]:
# Run fine-tuning with the schedule from training_args (per-epoch evaluation
# and checkpointing).
# NOTE(review): the saved run was interrupted (KeyboardInterrupt in the output
# below). Re-run to completion before relying on any result from this notebook.
trainer.train()


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 