In [1]:
# Point the Hugging Face caches to /tmp.
#
# The variables are set through os.environ rather than `!export ...`: every
# `!` command runs in its own short-lived subshell, so an `export` issued
# there never reaches the kernel process (or any later `!` command).
# This cell must run before any Hugging Face library is imported.
import warnings
import os

# Silence FutureWarning noise from the libraries imported further down.
warnings.simplefilter(action='ignore', category=FutureWarning)

HF_CACHE_DIR = "/tmp/.cache/huggingface"
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = f"{HF_CACHE_DIR}/datasets"
# TRANSFORMERS_CACHE is intentionally left unset (it was already disabled here).
#os.environ["TRANSFORMERS_CACHE"]="/tmp/.cache/huggingface/models"

# Set a flag to True to rebuild the corresponding dataset even when a copy
# already exists on disk.
RELOAD_TRAIN = False
RELOAD_VAL = False

In [2]:
import datasets
import torch
import numpy as np
import utils.EvalMetrics as EvalMetrics

from data_parsers.Parser import AudioCapsParser, ClothoParser, SplitType
from utils.DataTools import *
from utils.LossFunctions import contrastiveCE, create_contrastive_loss
from models.AudioTextRetriever import (AudioTextRetrieverCrossAtt, AudioTextRetrieverCrossAtt2, 
                                       AudioTextRetrieverCrossAtt3, TemporalAudioTextRetrieverCrossAtt,
                                       AudioTextRetrieverWithMLP, AudioTextRetrieverSelfAtt)
from models.AudioEncoders import ASTEncoder, TemporalASTEncoder
from models.TextEncoders import RoBERTaEncoder, TemporalRoBERTaEncoder
from pathlib import Path
from transformers import EvalPrediction, Trainer, TrainingArguments
from typing import Dict

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Optional in-memory size cap for `datasets`, currently disabled:
#datasets.config.IN_MEMORY_MAX_SIZE = 128 * 2**30

# Directory under which the prepared train/val datasets are stored.
HF_DATASETS_DIR = Path("/tmp/kokcz/datasets/huggingface")

In [3]:
# Dataset parsers, one per corpus.
clotho = ClothoParser("../datasets/clotho")
audiocaps = AudioCapsParser("/tmp/kokcz/datasets/audiocaps")

# Training configuration: (parser, positives, negatives) per corpus.
# Each dict is splatted into get_train(parser, pos_samples, neg_samples).
TRAIN_SETS = [
    {"parser": parser, "pos_samples": n_pos, "neg_samples": n_neg}
    for parser, n_pos, n_neg in (
        (clotho, 5, 5),
        (audiocaps, 1, 2),
    )
]

# Validation configuration: one positive and no negatives for every corpus.
# "name" is used as the key of the validation-set dict and as the on-disk
# sub-directory name.
VAL_SETS = [
    {"name": name, "parser": parser, "pos_samples": 1, "neg_samples": 0}
    for name, parser in (
        ("Clotho", clotho),
        ("AudioCaps", audiocaps),
    )
]

In [4]:
# Fixed shuffle seed so that rebuilding the cached training set reproduces the
# same sample order instead of depending on the current global RNG state.
SHUFFLE_SEED = 42

# Build the training set once and cache it on disk; later runs simply reload
# it unless RELOAD_TRAIN forces a rebuild.
if RELOAD_TRAIN or not (HF_DATASETS_DIR / "train").exists():
    def get_train(parser, pos_samples, neg_samples):
        """Return the prepared DEV split of one corpus.

        Args:
            parser: dataset parser exposing ``to_hf(split)``.
            pos_samples: forwarded to ``create_sample_generator`` as ``num_pos``.
            neg_samples: forwarded to ``create_sample_generator`` as ``num_neg``.

        Returns:
            The filtered and mapped dataset for this corpus.
        """
        train_set = parser.to_hf(SplitType.DEV)
        # Drop rows whose audio file fails the validity check
        # (0.1 is presumably a minimum duration; TODO confirm in utils.DataTools).
        train_set = train_set.filter(lambda row: is_valid_audio(row["path"], 0.1), num_proc=32)
        train_set = train_set.map(create_sample_generator(num_pos=pos_samples, num_neg=neg_samples), batched=True)
        return train_set

    train_set = datasets.concatenate_datasets([get_train(**args) for args in TRAIN_SETS])
    train_set = train_set.shuffle(seed=SHUFFLE_SEED)
    # Rewrite the rows in shuffled order; reading through the shuffle's
    # indices mapping was noted as roughly 10x slower.
    train_set = train_set.flatten_indices()
    train_set.save_to_disk(str(HF_DATASETS_DIR / "train"), num_proc=32)
else:
    train_set = datasets.load_from_disk(str(HF_DATASETS_DIR/"train"))

Loading dataset from disk:   0%|          | 0/32 [00:00<?, ?it/s]

In [5]:
# Build one validation dataset per entry of VAL_SETS and cache each on disk
# under val/<name>; later runs reload them unless RELOAD_VAL forces a rebuild.
if RELOAD_VAL or not (HF_DATASETS_DIR / "val").exists():
    def get_val(parser, pos_samples, neg_samples):
        """Return the prepared VAL split of one corpus.

        Args:
            parser: dataset parser exposing ``to_hf(split)``.
            pos_samples: forwarded to ``create_sample_generator`` as ``num_pos``.
            neg_samples: forwarded to ``create_sample_generator`` as ``num_neg``.

        Returns:
            The filtered and mapped dataset for this corpus.
        """
        val_set = parser.to_hf(SplitType.VAL)
        val_set = val_set.filter(lambda row: is_valid_audio(row["path"], 0.1), num_proc=32)
        val_set = val_set.map(create_sample_generator(num_pos=pos_samples, num_neg=neg_samples), batched=True)
        return val_set

    val_set = {}
    for metadata in VAL_SETS:
        # Read the name without popping it: mutating the shared VAL_SETS
        # entries would make a second run of this cell fail with KeyError.
        ds_name = metadata["name"]
        sampler_args = {key: value for key, value in metadata.items() if key != "name"}
        ds = get_val(**sampler_args)
        val_set[ds_name] = ds
        ds.save_to_disk(str(HF_DATASETS_DIR/"val"/ds_name), num_proc=32)
else:
    names = [metadata["name"] for metadata in VAL_SETS]
    val_set = {name : datasets.load_from_disk(str(HF_DATASETS_DIR/"val"/name)) for name in names}

Loading dataset from disk:   0%|          | 0/32 [00:00<?, ?it/s]

Loading dataset from disk:   0%|          | 0/32 [00:00<?, ?it/s]

In [6]:
# Build the components in the same order as the call arguments were evaluated:
# loss function first, then the text encoder, then the audio encoder.
contrastive_loss = create_contrastive_loss(0.7, True)
text_encoder = TemporalRoBERTaEncoder(2048, 512, seq_len=40)
audio_encoder = TemporalASTEncoder(2048, 512, pooling_kernel_size=8, pooling_stride=6, pooling_padding=0)

retriever = TemporalAudioTextRetrieverCrossAtt(
    loss_fn=contrastive_loss,
    text_enc=text_encoder,
    audio_enc=audio_encoder,
    num_heads=8,
    att_aggregation="mean",
)
retriever = retriever.to(device)

# The collator only needs each encoder's `preprocess` callable.
# NOTE(review): if the encoders are torch modules, `.cpu()` moves them back to
# the CPU in place after the `.to(device)` above - confirm this is intended.
audio_preprocess = retriever.AudioEncoder.cpu().preprocess
text_preprocess = retriever.TextEncoder.cpu().preprocess
collator = AudioTextDataCollator(audio_preprocess, text_preprocess)
# Alternative collator, currently disabled:
#collator = ProcessedAudioTextDataCollator(retriever.AudioEncoder.cpu().preprocess, retriever.TextEncoder.cpu().preprocess)

Created contrastive loss function for cross-attention model with temperature 0.7.


Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def capture_arg(is_cross_att):
    """Decorator factory that pins ``is_cross_att`` for a two-argument function.

    The decorated function is replaced by a one-argument callable which
    forwards its ``output`` together with the captured ``is_cross_att`` value,
    both passed positionally.
    """
    def decorator(metric_fn):
        def wrapper(output):
            return metric_fn(output, is_cross_att)
        return wrapper
    return decorator

@capture_arg(is_cross_att=True)
def compute_metrics(output, is_cross_att):
    """Compute text-to-audio retrieval metrics for one evaluation run.

    Args:
        output: evaluation output whose ``predictions`` unpack into three
            items (the first is unused, followed by the audio and the text
            embeddings) and whose ``label_ids`` holds the labels.
        is_cross_att: when true, ``retriever.do_cross_attention`` is handed
            to the metric helper; otherwise ``None`` is passed.

    Returns:
        Dict with R@1, R@5, mAP@10 and mean rank for the T->A direction.
        The audio-to-text (A->T) metrics are currently disabled.
    """
    _, audio_embed, text_embed = output.predictions
    labels = output.label_ids
    cross_att_fn = retriever.do_cross_attention if is_cross_att else None
    t2a_metrics = EvalMetrics.TextToAudioRetrieval(labels, audio_embed, text_embed, cross_att_fn)

    recall_1, recall_5 = t2a_metrics.recall_at_k([1, 5])
    results = {}
    results["R@1 (T->A)"] = recall_1
    results["R@5 (T->A)"] = recall_5
    results["mAP@10 (T->A)"] = t2a_metrics.mAP_at_k(10)
    results["MeanR (T->A)"] = t2a_metrics.mean_rank()
    return results

In [8]:
train_args = TrainingArguments(
    # --- run identity and output location ---
    run_name="crossatt4-mean-retrieval",
    output_dir="/tmp/kokcz/train_out",
    overwrite_output_dir=True,
    # --- optimisation ---
    num_train_epochs=7,
    learning_rate=5e-6,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # --- data loading ---
    dataloader_num_workers=48,
    group_by_length=False,
    # Presumably the custom collator needs the raw dataset columns - TODO confirm.
    remove_unused_columns=False,
    # --- logging / evaluation / checkpoint cadence (in optimizer steps) ---
    logging_steps=32,
    eval_strategy="steps",
    eval_steps=1024,
    save_steps=2048,
    # --- model selection: lowest AudioCaps mean rank wins ---
    load_best_model_at_end=True,
    metric_for_best_model="AudioCaps_MeanR (T->A)",
    greater_is_better=False,
)

In [9]:
# Trainer for the main run. `val_set` is a dict of datasets keyed by name.
trainer = Trainer(
    args=train_args,
    model=retriever,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run the main training loop configured by `train_args`.
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkokcz[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113109333544141, max=1.0…

Step,Training Loss,Validation Loss,Clotho Loss,Clotho R@1 (t->a),Clotho R@5 (t->a),Clotho Map@10 (t->a),Clotho Meanr (t->a),Audiocaps Loss,Audiocaps R@1 (t->a),Audiocaps R@5 (t->a),Audiocaps Map@10 (t->a),Audiocaps Meanr (t->a)
1024,0.5568,No log,2.4529,0.050718,0.161722,0.099144,81.305263,1.505775,0.045564,0.197122,0.113964,51.897362
2048,0.4667,No log,2.061093,0.0689,0.241148,0.143935,59.023923,1.193610,0.069544,0.265707,0.156055,39.722782
3072,0.3647,No log,1.958749,0.090909,0.270813,0.173435,47.801914,1.118516,0.075300,0.312710,0.173129,35.968825
4096,0.3148,No log,1.949999,0.083254,0.278469,0.165031,50.595215,1.019659,0.076259,0.313669,0.178495,30.121823
5120,0.2855,No log,1.830827,0.092823,0.310048,0.184138,43.972249,0.992769,0.085851,0.342446,0.195495,29.664748
6144,0.2788,No log,1.834458,0.10622,0.321531,0.197658,46.457416,0.968023,0.087290,0.349161,0.198546,28.871463
7168,0.2401,No log,1.8234,0.121531,0.325359,0.210267,43.690909,0.959870,0.100240,0.371223,0.212538,27.815348
8192,0.2371,No log,1.816666,0.114833,0.32823,0.20546,43.244976,0.930535,0.096403,0.384652,0.217321,25.882494
9216,0.2353,No log,1.81693,0.117703,0.333014,0.208768,43.833493,0.904991,0.093046,0.371223,0.210439,26.415348
10240,0.1991,No log,1.779964,0.112919,0.339713,0.211184,39.957895,No Log,No Log,No Log,No Log,No Log


In [12]:
# Persist the model currently held by `trainer` to a path relative to the notebook.
trainer.save_model("../saved_models/mean_downsample_202")

In [None]:
def wandb_hp_space(trial):
    """Return the W&B sweep configuration for the temperature grid search.

    Args:
        trial: unused; present because the hyperparameter-space callback
            receives a trial argument.

    Returns:
        Sweep config dict: a grid over 15 evenly spaced temperatures from
        0.1 to 1.5 inclusive.
    """
    temperature_grid = np.linspace(0.1, 1.5, 15).tolist()

    sweep_config = {"project": "audio-text-retrieval", "method": "grid"}
    # NOTE(review): goal is "maximize" here while the search is started with
    # direction="minimize" - confirm which one is intended.
    sweep_config["metric"] = {"name": "mAP@10 (T->A)", "goal": "maximize"}
    sweep_config["parameters"] = {"temperature": {"values": temperature_grid}}
    return sweep_config

In [None]:
def model_init(trial):
    """Build a fresh model for one hyperparameter-search trial.

    Args:
        trial: ``None``, or a mapping that is indexed with ``"temperature"``.

    Returns:
        The newly constructed model, moved to ``device``.
    """
    # NOTE(review): `AudioTextRetriever` is not among the names explicitly
    # imported in this notebook; unless `from utils.DataTools import *`
    # provides it, this raises NameError. Confirm the intended class.
    # NOTE(review): create_contrastive_loss receives one argument here but two
    # (0.7, True) where the main model is built - confirm this is intended.
    print("trial: ", trial)
    return AudioTextRetriever(contrastiveCE if trial is None else create_contrastive_loss(trial["temperature"])).to(device)

In [None]:
# Training arguments for the temperature sweep.
# NOTE(review): output_dir is the same directory as the main run and
# save_total_limit=2 is set - confirm that sharing the directory is intended.
hp_train_args = TrainingArguments(
    output_dir="/tmp/kokcz/train_out",
    overwrite_output_dir=True,
    group_by_length=False,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    eval_strategy="steps",
    learning_rate=1e-4,
    num_train_epochs=50,
    fp16=False,
    save_steps=128,
    eval_steps=64,
    logging_steps=32,
    optim="adamw_torch",
    save_total_limit=2,
    dataloader_num_workers=4,
    load_best_model_at_end=True,
    remove_unused_columns=False,
    # Name the selection metric explicitly, as the main run's arguments do.
    # With a dict of evaluation sets the metrics are reported per dataset
    # (this notebook's training log shows no plain validation loss), so the
    # default metric ("loss") would not be found when picking the best
    # checkpoint. Lower mean rank is better, hence greater_is_better=False.
    metric_for_best_model="AudioCaps_MeanR (T->A)",
    greater_is_better=False,
    run_name="audio-text-retrieval_temp-sweep"
)

In [None]:
# Trainer for the sweep: no fixed model is passed, `model_init` builds a new
# one for each trial.
# NOTE(review): `compute_metrics` reads the module-level `retriever`
# (retriever.do_cross_attention), not the model created by `model_init` -
# confirm this is intended.
hp_trainer = Trainer(
    args=hp_train_args,
    model=None,
    model_init=model_init,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Launch the sweep through the W&B backend using the space from wandb_hp_space.
# NOTE(review): direction is "minimize" here, while wandb_hp_space declares
# goal "maximize" for its metric - confirm which objective is intended.
hp_trainer.hyperparameter_search(direction="minimize", backend="wandb", hp_space=wandb_hp_space)

In [None]:
# Persist the model currently held by `hp_trainer`.
# NOTE(review): confirm that this is the model of the best trial and not
# simply the one from the last trial that ran.
hp_trainer.save_model("../saved_models/hp_temp")