## Pre-Processing into Spectrograms  

In [None]:
import numpy as np
import librosa
import os
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt

from utils import compute_all_spectrograms

In [None]:
# Folder containing the raw Ballroom audio files; replace with your local path.
input_folder = "BallroomData"
# Must match `spectrogram_dir` in the dataloader cell; otherwise BallroomDataset
# reads a different folder than the one these spectrograms are written to.
output_folder = "spectrograms"

compute_all_spectrograms(input_folder, output_folder)

## Creating Dataloaders

In [None]:
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from dataset import BallroomDataset
from model import BeatTrackingNet

In [None]:
# Locations of the precomputed spectrograms and of the beat annotation files.
spectrogram_dir = 'spectrograms'
annotation_dir = 'BallroomAnnotations'
training_dataset = BallroomDataset(spectrogram_dir, annotation_dir)

# Reproducible 80/20 split over item indices. Both subsets wrap the same
# dataset object, so attribute changes on `training_dataset` affect both.
indices = list(range(len(training_dataset)))
train_indices, test_indices = train_test_split(
    indices,
    test_size=0.2,
    random_state=42,
)

make_subset = torch.utils.data.Subset
train_sub_dataset = make_subset(training_dataset, train_indices)
test_sub_dataset = make_subset(training_dataset, test_indices)

print('train: {}, test: {}'.format(len(train_sub_dataset), len(test_sub_dataset)))

In [None]:
# Peek at the first training example to check tensor shapes. Indexing the
# Subset directly is clearer than building an iterator over it, which only
# works through Python's legacy __getitem__ sequence protocol.
x, y = train_sub_dataset[0]
print(f'x: {x.shape}, target: {y.shape}')

## Creating Models

In [None]:
# Both models share one set of hyperparameters so the fuzzy / non-fuzzy
# comparison differs only in the training targets.
model_kwargs = dict(input_dim=81, num_filters=16, kernel_size=5, num_layers=11)
model_fuz = BeatTrackingNet(**model_kwargs)
model_no_fuz = BeatTrackingNet(**model_kwargs)

# Shape sanity check on random input laid out like the training batches
# (batch, 1, frames, input_dim) -- the training loop adds the channel axis via
# unsqueeze(1); assumes stored spectrograms are (frames, 81), TODO confirm.
a = torch.randn(1, 1, 3000, model_kwargs["input_dim"])
with torch.no_grad():  # no autograd graph needed for a shape check
    for net in (model_fuz, model_no_fuz):
        # eval() keeps any dropout/batch-norm layers from reacting to the noise
        # input; the training cells switch back with .train() every epoch.
        net.eval()
        print(net(a).shape)

## Training

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f'on {device}')

In [None]:
# NOTE(review): batch_size=1 presumably because spectrogram lengths differ
# between items (default collation needs equal shapes) -- confirm before raising it.
batch_size = 1
train_loader = DataLoader(
    dataset=train_sub_dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
# BCELoss expects the model output to already be probabilities in [0, 1]
# (it applies no sigmoid itself), compared against the per-frame targets.
criterion = torch.nn.BCELoss()

# One independent Adam optimiser per model, same learning rate for both.
learning_rate = 0.001
optimizer_fuz, optimizer_no_fuz = (
    torch.optim.Adam(net.parameters(), lr=learning_rate)
    for net in (model_fuz, model_no_fuz)
)

In [None]:
num_epochs = 200

# NOTE(review): this cell relies on `training_dataset.fuzziness` still holding
# the value that enables fuzzy targets (presumably its default). A later cell
# sets it to False, so re-running this cell after that one would train the
# "fuzzy" model on non-fuzzy targets.
model_fuz = model_fuz.to(device)

for epoch in range(num_epochs):
    model_fuz.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for spectrogram, target in progress_bar:
        # unsqueeze(1) inserts the channel axis the model expects.
        spectrogram, target = spectrogram.to(device).unsqueeze(1), target.to(device)

        optimizer_fuz.zero_grad()
        output = model_fuz(spectrogram)

        loss = criterion(output, target)
        loss.backward()
        optimizer_fuz.step()

        # Read the scalar once: .item() copies to the host (a sync on CUDA).
        batch_loss = loss.item()
        running_loss += batch_loss
        progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")


# Derive the checkpoint name from num_epochs so it stays accurate if that changes.
torch.save(model_fuz.state_dict(), f'epoch{num_epochs}_fuz.pt')
print("complete")

In [None]:
# Switch off the dataset's `fuzziness` option before training the second model
# (presumably disables softened/widened beat targets -- see BallroomDataset).
# This mutates the single shared dataset object, so train_sub_dataset,
# test_sub_dataset and train_loader (all built on it) see the change from now on.
training_dataset.fuzziness = False

In [None]:
num_epochs = 200

# NOTE(review): relies on the previous cell having set
# `training_dataset.fuzziness = False`; running this cell without it trains
# the "non-fuzzy" model on whatever targets the dataset currently produces.
model_no_fuz = model_no_fuz.to(device)

for epoch in range(num_epochs):
    model_no_fuz.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

    for spectrogram, target in progress_bar:
        # unsqueeze(1) inserts the channel axis the model expects.
        spectrogram, target = spectrogram.to(device).unsqueeze(1), target.to(device)

        optimizer_no_fuz.zero_grad()
        output = model_no_fuz(spectrogram)

        loss = criterion(output, target)
        loss.backward()
        optimizer_no_fuz.step()

        # Read the scalar once: .item() copies to the host (a sync on CUDA).
        batch_loss = loss.item()
        running_loss += batch_loss
        progress_bar.set_postfix(loss=f"{batch_loss:.4f}")

    avg_loss = running_loss / len(train_loader)

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")


# Derive the checkpoint name from num_epochs so it stays accurate if that changes.
torch.save(model_no_fuz.state_dict(), f'epoch{num_epochs}_no_fuz.pt')
print("complete")

## Plotting  


In [None]:
from madmom.features.beats import DBNBeatTrackingProcessor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import mir_eval

In [None]:
# Activation frame rate. It must match the spectrogram hop used for plotting
# (sr=44100, hop_length=441 -> 100 frames per second).
fps = 100
dbn_processor = DBNBeatTrackingProcessor(min_bpm=55, max_bpm=215, threshold=0.4, fps=fps)

model_no_fuz.eval()

# Visualise one held-out item, loading only that sample. Going through a
# DataLoader keeps the same collation (tensors with a leading batch axis) as
# the training loop.
sample_idx = 2  # position within test_sub_dataset
test_loader = DataLoader(
    torch.utils.data.Subset(test_sub_dataset, [sample_idx]), batch_size=1, shuffle=False
)
spectrogram, ground_truth = next(iter(test_loader))

# Restrict the plot to a window of frames.
start_frame, end_frame = 800, 1200
spectrogram = spectrogram.to(device).unsqueeze(1)[:, :, start_frame:end_frame, :]
ground_truth = ground_truth.cpu().numpy().flatten()[start_frame:end_frame]

with torch.no_grad():
    activation = model_no_fuz(spectrogram).squeeze(1).cpu().numpy().flatten()

# Beat times in seconds, relative to start_frame.
detected_beats = dbn_processor(activation)

hop_size = 1 / fps  # seconds per frame
# Frames whose target exceeds 0.5 are treated as annotated beats.
ground_truth_beats = np.where(ground_truth > 0.5)[0] * hop_size


In [None]:
# Raw per-frame network activation over the selected frame window.
fig, ax = plt.subplots(figsize=(15, 3))
ax.plot(activation)
ax.set_xlabel("Time Steps (frames)")
ax.set_ylabel("Activation")

In [None]:
# `import librosa` alone does not guarantee the display submodule is loaded.
import librosa.display

plt.figure(figsize=(15, 3))
# NOTE(review): power_to_db assumes the stored spectrogram holds power values,
# and the mel axis assumes librosa's default mel fmin/fmax for sr=44100 --
# confirm against utils.compute_all_spectrograms.
mel_db = librosa.power_to_db(spectrogram[0, 0].T.cpu().numpy(), ref=np.max)
librosa.display.specshow(mel_db, sr=44100, hop_length=441, x_axis='time', y_axis='mel', cmap='magma')
# Apply the frequency limit after specshow so it targets the finished axes.
plt.ylim(0, 8192)

# Beat markers are drawn at a fixed height (4096 Hz) purely for visibility.
plt.scatter(detected_beats, [4096] * len(detected_beats), color='blue', label='Predicted Beats (DBN)', marker='o', edgecolors='white', s=60)

plt.scatter(ground_truth_beats, [4096] * len(ground_truth_beats), color='lime', label='Ground Truth Beats', marker='x', s=80)

plt.title("Predicted Beats (DBN) vs. Ground Truth on Mel Spectrogram")
plt.xlabel("Time (s)")
plt.ylabel("Mel Frequency Bands")
plt.legend()
plt.show()

## Metrics

In [None]:
# Activation frame rate; hop_size below must equal 1 / fps.
fps = 100
dbn_processor = DBNBeatTrackingProcessor(min_bpm=55, max_bpm=215, threshold=0.4, fps=fps)

test_loader = DataLoader(test_sub_dataset, batch_size=1, shuffle=False)
model_no_fuz.eval()

hop_size = 1 / fps  # seconds per frame: converts frame indices to beat times

# Running sums of selected mir_eval beat metrics; averaged over test items below.
metrics = {'F-measure': 0,
           'Correct Metric Level Continuous': 0,
           'Correct Metric Level Total': 0,
           'Any Metric Level Continuous': 0,
           'Any Metric Level Total': 0,
           }

count = 0
with torch.no_grad():
    for spectrogram, ground_truth in test_loader:
        spectrogram = spectrogram.to(device).unsqueeze(1)
        ground_truth = ground_truth.cpu().numpy().flatten()

        activation = model_no_fuz(spectrogram).squeeze(1).cpu().numpy().flatten()
        detected_beats = dbn_processor(activation)

        # Frames whose target exceeds 0.5 are treated as annotated beats.
        ground_truth_beats = np.where(ground_truth > 0.5)[0] * hop_size

        # mir_eval's signature is evaluate(reference_beats, estimated_beats);
        # the continuity metrics are not symmetric, so argument order matters.
        scores = mir_eval.beat.evaluate(ground_truth_beats, detected_beats)

        count += 1
        for metric in metrics:
            metrics[metric] += scores[metric]

if count == 0:
    raise RuntimeError("test_loader yielded no items; cannot average metrics")

for metric in metrics:
    metrics[metric] /= count

In [None]:
# Show the per-metric averages over the test split as the cell's output.
metrics