In [None]:
from pathlib import Path

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from beat_this.dataset import BeatDataModule
from beat_this.model.pl_module import PLBeatThis

# Training hyperparameters and run options for the beat_this model.
# Optional keys (hung_data, fold, resume_checkpoint, resume_id) are deliberately
# absent; the code below looks them up with args.get(), which returns None.
args = dict(
    name="",
    gpu=0,
    force_flash_attention=False,
    compile=["frontend", "transformer_blocks", "task_heads"],
    n_layers=6,
    transformer_dim=512,
    frontend_dropout=0.1,
    transformer_dropout=0.2,
    lr=0.0008,
    weight_decay=0.01,
    logger="none",  # "wandb" enables Weights & Biases logging
    num_workers=8,
    n_heads=16,
    fps=50,
    # alternatives: "fast_shift_tolerant_weighted_bce", "weighted_bce", "bce"
    loss="shift_tolerant_weighted_bce",
    warmup_steps=1000,
    max_epochs=100,
    batch_size=8,
    accumulate_grad_batches=8,
    train_length=1500,
    dbn=False,
    eval_trim_beats=5,  # skip the first given seconds per piece in evaluating
    val_frequency=5,  # validate every N epochs
    tempo_augmentation=True,
    pitch_augmentation=True,
    mask_augmentation=True,
    sum_head=True,
    partial_transformers=True,
    length_based_oversampling_factor=0.65,
    val=True,  # whether to include the validation data in training
    seed=127,
)

seed_everything(args["seed"], workers=True)

# Compact run descriptor built from the configuration; it becomes part of the
# W&B run name when that logger is enabled.
name_prefix = ""
if not args.get("val"):
    name_prefix += "noval "
if args.get("hung_data"):
    name_prefix += "hung "
if args.get("fold") is not None:
    name_prefix += "fold" + str(args.get("fold")) + " "
name_core = (
    f"{args.get('loss')}-h{args.get('transformer_dim')}"
    f"-aug{args.get('tempo_augmentation')}{args.get('pitch_augmentation')}{args.get('mask_augmentation')}"
)
name_suffix = ""
if not args.get("sum_head"):
    name_suffix += " nosumH "
if not args.get("partial_transformers"):
    name_suffix += " nopartialT "
params_str = name_prefix + name_core + name_suffix

# Only Weights & Biases logging is wired up here; any other value disables logging.
logger = None
if args.get("logger") == "wandb":
    # Resume an existing W&B run only when both a checkpoint and a run id are given.
    if args.get("resume_checkpoint") and args.get("resume_id"):
        wandb_args = {"id": args.get("resume_id"), "resume": "must"}
    else:
        wandb_args = {}
    logger = WandbLogger(
        project="beat_this",
        name=f"{args.get('name')} {params_str}".strip(),
        **wandb_args,
    )

# Optionally force scaled-dot-product attention onto the flash kernel. The
# switches live under torch.backends.cuda, so this only matters on a CUDA GPU;
# leave "force_flash_attention" False when training on CPU.
if args.get("force_flash_attention"):
    print("Forcing the use of the flash attention.")
    # Enable flash and switch off the memory-efficient and math backends.
    for set_backend, enabled in (
        (torch.backends.cuda.enable_flash_sdp, True),
        (torch.backends.cuda.enable_mem_efficient_sdp, False),
        (torch.backends.cuda.enable_math_sdp, False),
    ):
        set_backend(enabled)

# Training-time data augmentations, keyed by augmentation type. Each entry is
# only added when the matching switch in `args` is on. The numeric ranges are
# passed through as-is; their units are interpreted by whatever consumes
# `augmentations` (presumably the data module) — TODO confirm.
augmentations = {}
# `args` is a plain dict, so switches are read by key (attribute access such as
# args.tempo_augmentation would raise AttributeError). .get() matches how the
# rest of this cell reads optional settings.
if args.get("tempo_augmentation"):
    augmentations["tempo"] = {"min": -20, "max": 20, "stride": 4}
if args.get("pitch_augmentation"):
    augmentations["pitch"] = {"min": -5, "max": 6}
if args.get("mask_augmentation"):
    # "permute" masking with count, segment-length and parts ranges.
    augmentations["mask"] = {
        "kind": "permute",
        "min_count": 1,
        "max_count": 6,
        "min_len": 0.1,
        "max_len": 2,
        "min_parts": 5,
        "max_parts": 9,
    }



### Downloading songs and verifying them against AcoustID fingerprints

In [None]:
import os
import time
import requests
import subprocess
import pandas as pd
from pathlib import Path
from dotenv import load_dotenv
import googleapiclient.discovery

# Populate the environment from keys.env; the AcoustID lookup below reads
# ACOUSTID_API_KEY from it.
load_dotenv(dotenv_path="keys.env")

# AcoustID lookup endpoint and score threshold (presumably the minimum
# acceptable match score — TODO confirm).
acoustidURL = "https://api.acoustid.org/v2/lookup"
THRESHOLD = 0.8

# Data locations, relative to the notebook's working directory.
dataDir = Path("data")
sourcesFile = dataDir / "songs" / "sources.csv"
songsDir = dataDir / "songs"
fpDir = dataDir / "fingerprints"
metadataFile = dataDir / "metadata.csv"

# Song to check and how many candidate downloads it has; candidates are read
# from songsDir / f"{i}_{j}.mp3" for j in range(N).
i = 2
N = 10
# NOTE(review): this is the AcoustID listed for id 1 in the metadata table shown
# in this notebook, while i = 2 — confirm both refer to the same song.
acoustID = "258db8e7-6136-4836-bf7a-41f8ab8c1aac"

# Fingerprint every candidate download of song `i` with fpcalc and ask AcoustID
# which recordings it matches. The candidate whose results contain the expected
# `acoustID` with the highest score (at least THRESHOLD) wins.
# Afterwards `bestID` is that candidate's index j (None if none qualified) and
# `score` is its AcoustID score (0.0 if none qualified).
fpDir.mkdir(parents=True, exist_ok=True)
score = 0.0
bestID = None

for j in range(N):
    song = songsDir / f"{i}_{j}.mp3"
    text = fpDir / f"{i}_{j}.txt"
    if not song.exists():
        print(f"Skipping {song}: file not found")
        continue

    # Run fpcalc without a shell so paths containing spaces or shell
    # metacharacters are passed verbatim; its output is kept in `text`.
    with open(text, "w") as out:
        subprocess.run(["fpcalc", str(song)], stdout=out, check=True)

    # fpcalc prints KEY=VALUE lines; split on the first '=' only and look the
    # fields up by name instead of relying on their line positions.
    with open(text, "r") as f:
        fields = dict(line.strip().split("=", 1) for line in f if "=" in line)
    duration = int(fields["DURATION"])
    fingerprint = fields["FINGERPRINT"]

    response = requests.get(
        acoustidURL,
        params={
            "client": os.getenv("ACOUSTID_API_KEY"),
            "format": "json",
            "duration": duration,
            "fingerprint": fingerprint,
        },
        timeout=30,  # never hang the notebook on a stalled connection
    )
    time.sleep(0.35)  # to avoid rate limiting
    if response.status_code != 200:
        print(f"AcoustID API request failed for file {i}_{j}.mp3! HTTP status {response.status_code}")
        continue
    results = response.json().get("results", [])
    if not results:
        print(f"AcoustID returned no results for file {i}_{j}.mp3")
        continue

    for d in results:
        # Keep the best-scoring candidate whose match is the expected recording.
        if d["id"] == acoustID and d["score"] >= THRESHOLD and d["score"] > score:
            score = d["score"]
            bestID = j  # candidate index; `i` is the same for every file

if bestID is None:
    print(f"No candidate for song {i} matched {acoustID} with score >= {THRESHOLD}")
else:
    print(f"Best candidate for song {i}: {i}_{bestID}.mp3 (score {score:.3f})")

id,title,artist,year,album,acoustID
1,A nuestro modo,Gabino Pampini,1986,Fuerza Noble,258db8e7-6136-4836-bf7a-41f8ab8c1aac
2,Aguzate,Richie Ray & Bobby Cruz,1970,Aguzate,b874c7c2-a702-4c70-8f74-1b7c56eee20d
3,Amor Traicionero,Guayacan Orquesta,1993,Con el Corazón Abíerto,cb7ea308-9264-4ff7-9b96-49bf42be0704
4,Arroz con habichuela,El Gran Combo de Puerto Rico,2006,Arroz con habichuela,d62f529a-969f-4696-b4d6-910a32bfef14
5,"Bang, Bang",Joe Cuba Sextet,1966,Wanted Dead or Alive,94c43fa0-4677-40af-8ef4-5ea98ef3a0eb
