## Importing libraries

In [None]:
!pip install torch torchmetrics torchaudio datasets transformers scikit-learn matplotlib wandb torchcodec

from datasets import load_dataset
from transformers import (
    AutoFeatureExtractor,
    AutoModelForAudioClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
import librosa
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
from google.colab import drive
import os
from google.colab import userdata
import torch
import wandb
import torchmetrics


## Setting the metadata paths and reading the W&B API key

In [None]:
# Folder that holds the metadata CSV files.
DATA_DIR = '/content/'

# Metadata tables: 'train.csv' for training, 'val.csv' for evaluation.
TRAIN_PATH, TEST_PATH = (
    os.path.join(DATA_DIR, csv_name) for csv_name in ('train.csv', 'val.csv')
)

# W&B API key, looked up under the name 'WANDB' via google.colab.userdata.
wandb_kay = userdata.get('WANDB')

## Log in to wandb

In [None]:
# Authenticate with Weights & Biases using the key read earlier.
wandb.login(key=wandb_kay)

# Start a run in the "audio_test1" project; the Trainer configured with
# report_to="wandb" further down is expected to log into it.
wandb.init(
    project="audio_test1",
)


## Mounting your Google Drive

In [None]:
# Mount Google Drive under /content/drive so its files are reachable as local paths.
drive.mount('/content/drive')

## Copying the dataset from Google Drive into the working directory

In [None]:
# Copy the contents of the Drive folder "audio_cls" into the current working
# directory (-a: archive mode, recursive and attribute-preserving).
!rsync -a "/content/drive/MyDrive/audio_cls/" .

## Loading the dataset metadata

In [None]:
# Load both metadata CSV files with the "csv" builder; the result exposes a
# "train" split (from TRAIN_PATH) and a "val" split (from TEST_PATH).
dataset = load_dataset(
    "csv",
    data_files={
        "train": TRAIN_PATH,
        "val": TEST_PATH
    }
)
# Bare expression: in a notebook cell this displays the loaded dataset summary.
dataset

## Fixing the file paths for Google Colab

In [None]:
DATA_DIR_TRAIN = "/content/train"
DATA_DIR_TEST  = "/content/test"

# Windows folders under which the paths in the metadata were originally recorded.
_WIN_TRAIN_PREFIX = 'D:\\audio_cls_coursework\\data\\train\\'
_WIN_TEST_PREFIX = 'D:\\audio_cls_coursework\\data\\test\\'


def _relocate_path(path, windows_prefix, new_root):
    """Re-root a recorded Windows path under *new_root*.

    Everything after the first occurrence of *windows_prefix* is kept as the
    relative part. Backslashes in that part are Windows directory separators,
    so they are converted to forward slashes; otherwise files stored in
    sub-folders would not resolve on a POSIX file system.

    Raises:
        IndexError: if *windows_prefix* does not occur in *path*.
    """
    # maxsplit=1 keeps the whole remainder even if the prefix text repeats.
    tokens = path.split(windows_prefix, 1)
    relative = tokens[1].replace('\\', '/')
    return f"{new_root}/{relative}"


def fix_path_train(example):
    """Point ``example["path"]`` at the training folder ``DATA_DIR_TRAIN``.

    The row is modified in place and returned, as expected by ``Dataset.map``.
    """
    example["path"] = _relocate_path(
        example["path"], _WIN_TRAIN_PREFIX, DATA_DIR_TRAIN
    )
    return example


def fix_path_test(example):
    """Point ``example["path"]`` at the evaluation folder ``DATA_DIR_TEST``.

    The row is modified in place and returned, as expected by ``Dataset.map``.
    """
    example["path"] = _relocate_path(
        example["path"], _WIN_TEST_PREFIX, DATA_DIR_TEST
    )
    return example

def fix_label(example):
    """Rename the ``target`` field of *example* to ``labels``.

    The value is removed from under ``"target"`` and stored under ``"labels"``.
    The row is modified in place and returned.
    """
    target_value = example.pop("target")
    example["labels"] = target_value
    return example

# For each split: re-root the audio paths, then rename "target" to "labels".
train_df = dataset['train'].map(fix_path_train).map(fix_label)
valid_df = dataset['val'].map(fix_path_test).map(fix_label)


## Initializing the model

In [None]:
# Hugging Face Hub checkpoint used as the starting point for fine-tuning.
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Size of the classification head requested below.
NUM_LABELS = 4

# Feature extractor matching the checkpoint; used later to turn waveforms
# into model inputs.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

# ignore_mismatched_sizes=True lets loading proceed when checkpoint weights
# do not match the requested shapes — presumably the original classifier head,
# which is replaced by one with NUM_LABELS outputs.
model = AutoModelForAudioClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    ignore_mismatched_sizes=True
)

## Loading and preprocessing audio samples

In [None]:
def preprocess(batch, sampling_rate=16000):
    """Load one audio file and attach the extracted model inputs to its row.

    Despite the parameter name, *batch* is used as a single row: its
    ``"path"`` entry is passed to ``librosa.load`` as one file path.

    Args:
        batch: mapping with a ``"path"`` entry pointing at an audio file.
        sampling_rate: rate (Hz) the audio is resampled to on load and that
            is declared to the feature extractor. Defaults to 16000, the
            value previously hard-coded in both places.

    Returns:
        The same mapping with ``"input_values"`` added as nested Python lists.
    """
    # The returned sample rate equals the requested one, so it is discarded.
    waveform, _ = librosa.load(batch["path"], sr=sampling_rate)
    inputs = feature_extractor(
        waveform,
        sampling_rate=sampling_rate,
        return_tensors="pt"
    )
    # squeeze(0) drops dimension 0 when it has size 1 — presumably the batch
    # axis added by the extractor; tolist() makes the result a plain list.
    batch["input_values"] = inputs["input_values"].squeeze(0).tolist()
    return batch

# Compute "input_values" for both splits; preprocess is passed to map
# without extra arguments and handles one audio path per call.
train_df = train_df.map(preprocess)
valid_df = valid_df.map(preprocess)

## Implementing the loss computation

In [None]:
# Registry mapping a loss name to the class that builds the loss module.
KEY2LOSSES = {'ce': torch.nn.CrossEntropyLoss}


def compute_loss(
    model,
    inputs,
    loss_name="ce",
    return_outputs=False,
    loss_kwargs=None,
    multilabel=False
):
    """Run *model* on *inputs* and compute the selected loss on its logits.

    Args:
        model: callable invoked as ``model(**inputs)``; its result must
            expose a ``logits`` attribute.
        inputs: mapping of model inputs that also contains ``"labels"``.
            The ``"labels"`` entry is popped, so the mapping is modified.
        loss_name: key into ``KEY2LOSSES`` selecting the loss class.
        return_outputs: when true, return ``(loss, outputs)``, not just
            the loss.
        loss_kwargs: optional dict of keyword arguments for the loss
            constructor. ``None`` (the default) or an empty dict builds the
            loss with its default settings.
        multilabel: when true and ``loss_name == "focal"``, labels are cast
            to float before the loss is applied.

    Returns:
        The loss tensor, or ``(loss, outputs)`` if *return_outputs* is true.

    Raises:
        TypeError: if *loss_kwargs* is neither ``None`` nor a dict.
        KeyError: if *loss_name* is not registered in ``KEY2LOSSES``, or if
            *inputs* has no ``"labels"`` entry.
    """
    # Validate before touching `inputs` or running the model, so a bad call
    # fails fast. The previous assert rejected the default `None`, which made
    # the function unusable with its own defaults.
    if loss_kwargs is None:
        loss_kwargs = {}
    elif not isinstance(loss_kwargs, dict):
        raise TypeError("`loss_kwargs` must be a dict or None.")
    if loss_name not in KEY2LOSSES:
        raise KeyError(
            f"Unknown loss {loss_name!r}; available: {sorted(KEY2LOSSES)}"
        )

    labels = inputs.pop("labels")
    outputs = model(**inputs)
    logits = outputs.logits

    loss_func = KEY2LOSSES[loss_name](**loss_kwargs)
    if loss_name == "focal" and multilabel:
        labels = labels.float()
    elif loss_name == "ce":
        # CrossEntropyLoss expects integer class indices here.
        labels = labels.long()
    loss = loss_func(logits, labels)
    return (loss, outputs) if return_outputs else loss



## Implementing the metrics computation

In [None]:
def compute_metrics_hf(eval_pred):
    """Compute classification metrics from a ``(logits, labels)`` pair.

    Logits are turned into probabilities with a softmax over the last axis,
    and the predicted class is the arg-max of those probabilities.

    Returns:
        dict with accuracy, macro-averaged F1 / precision / recall, and the
        multiclass ROC AUC computed on the probabilities.
    """
    raw_logits, raw_labels = eval_pred
    logits = torch.tensor(raw_logits)
    labels = torch.tensor(raw_labels).long()

    probs = torch.softmax(logits, dim=-1)
    preds = torch.argmax(probs, dim=-1)

    # scikit-learn works on NumPy arrays.
    y_true = labels.cpu().numpy()
    y_pred = preds.cpu().numpy()

    accuracy = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, average="macro")
    precision_macro = precision_score(y_true, y_pred, average="macro")
    recall_macro = recall_score(y_true, y_pred, average="macro")

    # The number of classes is taken from the width of the probability matrix.
    n_classes = probs.shape[-1]
    rocauc = torchmetrics.functional.auroc(
        probs, labels, task="multiclass", num_classes=n_classes
    ).item()

    return {
        "accuracy": accuracy,
        "f1_macro": f1_macro,
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "rocauc": rocauc,
    }

## Implementing a data collator

In [None]:
def collate_fn(features):
    """Assemble a list of dataset rows into one batch of tensors.

    Args:
        features: sequence of mappings, each with ``"input_values"``
            (nested numeric lists of equal shape) and ``"labels"``.

    Returns:
        dict with ``"input_values"`` (float32, rows stacked along a new
        leading axis) and ``"labels"`` (int64, one entry per row).
    """
    value_tensors = []
    for item in features:
        value_tensors.append(
            torch.tensor(item["input_values"], dtype=torch.float32)
        )
    stacked_values = torch.stack(value_tensors)

    label_values = []
    for item in features:
        label_values.append(item["labels"])
    label_tensor = torch.tensor(label_values, dtype=torch.long)

    return {"input_values": stacked_values, "labels": label_tensor}

## Initializing the training config

In [None]:

# Per-device batch sizes and the upper bound on training epochs.
train_batch_size = 48
val_batch_size = 24
EPOCHS = 200


# Stop once the monitored metric has failed to improve for 5 consecutive
# evaluations (threshold 0.0: any improvement counts).
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=5,
    early_stopping_threshold=0.0
)

# NOTE(review): eval_steps=1 evaluates after every optimisation step, so a
# patience of 5 evaluations corresponds to only 5 training steps — confirm
# this is intended.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=val_batch_size,
    num_train_epochs=EPOCHS,
    eval_strategy="steps",
    save_steps=10,
    eval_steps=1,
    logging_steps=1,
    report_to="wandb",
    fp16=False,
    gradient_checkpointing=True,
    # Reload the best checkpoint (highest "accuracy") when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_df,
    eval_dataset=valid_df,
    # `processing_class` replaces the deprecated `tokenizer` argument.
    processing_class=feature_extractor,
    compute_metrics=compute_metrics_hf,
    data_collator=collate_fn,
    callbacks=[early_stopping]
)

## Train a model

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()

In [None]:
# Folder that receives both the model and its feature extractor.
best_model_dir = "./results/best_model"

# Save locally.
trainer.save_model(best_model_dir)
# Keep the feature extractor next to the model weights.
feature_extractor.save_pretrained(best_model_dir)

# Log the saved folder to W&B as a model artifact.
artifact = wandb.Artifact("best_model", type="model")
artifact.add_dir(best_model_dir)
wandb.log_artifact(artifact)