In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
from transformers import get_cosine_schedule_with_warmup

# DATA LOADING

In this section, I load a custom, preprocessed dataset stored as a PyTorch `.pt` file (a dictionary holding the feature tensor `X` and the target tensor `y`). The samples are shuffled with a fixed seed and split 80/20 into training and validation sets, which are then wrapped in `DataLoader`s for batched training and evaluation.


In [None]:


def load_preprocessed_dataset(pt_path='dataset_V2_extended.pt', shuffle=True, seed=42, train_frac=0.8):
    """Load a preprocessed dataset from a ``.pt`` file and split it into train/validation sets.

    Args:
        pt_path: Path to a file written with ``torch.save`` containing a mapping
            with keys ``'X'`` (features) and ``'y'`` (targets).
        shuffle: If True, randomly permute the samples (keeping X/y pairs aligned)
            before splitting.
        seed: Seed passed to ``torch.manual_seed`` when shuffling. This seeds the
            *global* torch RNG, so it also affects later random operations in the
            same process (e.g. model weight initialisation).
        train_frac: Fraction of samples, in ``[0, 1]``, assigned to the training
            split; the remainder goes to validation.

    Returns:
        ``(X_train, y_train, X_val, y_val)``.

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths, or ``train_frac``
            is outside ``[0, 1]``.
    """
    data = torch.load(pt_path)
    X, y = data['X'], data['y']

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(X) != len(y):
        raise ValueError(f"X and y must have the same length, got {len(X)} and {len(y)}")
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")

    if shuffle:
        torch.manual_seed(seed)
        indices = torch.randperm(len(X))
        X, y = X[indices], y[indices]

    train_size = int(train_frac * len(X))
    X_train, y_train = X[:train_size], y[:train_size]
    X_val, y_val = X[train_size:], y[train_size:]

    return X_train, y_train, X_val, y_val



def get_dataloader(X, y, batch_size=16, shuffle=True):
    """Build a ``DataLoader`` over paired tensors.

    Each iteration yields ``(X_batch, y_batch)`` with up to ``batch_size``
    samples; ``shuffle`` controls whether the sample order is re-randomised
    on every pass.
    """
    paired = TensorDataset(X, y)
    loader = DataLoader(paired, shuffle=shuffle, batch_size=batch_size)
    return loader


# Split once, then wrap each split in a DataLoader. The training loader must
# reshuffle every epoch (otherwise every epoch sees identical batches in
# identical order); the validation loader keeps a fixed order since it is
# only used for evaluation.
x_train, y_train, x_val, y_val = load_preprocessed_dataset()
train_data_loader = get_dataloader(x_train, y_train, batch_size=16, shuffle=True)
val_data_loader = get_dataloader(x_val, y_val, batch_size=16, shuffle=False)

# BUILD MY NN

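In this section, I build a simple multilayer perceptron: three hidden linear layers (128 → 256 → 128 units), each followed by a GELU activation and dropout (p = 0.1) for regularization, and a final linear layer that outputs 14 values. The forward pass also returns the last hidden activations alongside the predictions.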


In [None]:


class Net(nn.Module):
    """Feed-forward network: three GELU hidden layers with dropout, then a linear head.

    Layer widths are ``NB_FEATURES -> 128 -> 256 -> 128 -> nb_outputs``.
    Submodule names (``layer1``, ``layer2``, ``layer3``, ``output``) are kept
    stable so previously saved state_dicts still load.

    Args:
        NB_FEATURES: Number of input features (size of the last input dimension).
        nb_outputs: Number of values produced by the output layer.
        dropout: Dropout probability applied after every hidden layer.
    """

    def __init__(self, NB_FEATURES, nb_outputs=14, dropout=0.1):
        super().__init__()
        self.layer1 = nn.Linear(NB_FEATURES, 128)
        self.layer2 = nn.Linear(128, 256)
        self.layer3 = nn.Linear(256, 128)
        self.output = nn.Linear(128, nb_outputs)

        # One activation and one dropout module are shared by all hidden layers
        # (neither has learnable parameters).
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``(predictions, last_hidden_states)``.

        ``last_hidden_states`` is the 128-wide activation (after dropout) that
        is fed into the output layer.
        """
        for layer in (self.layer1, self.layer2, self.layer3):
            x = self.dropout(self.act(layer(x)))
        last_hidden_states = x
        out = self.output(last_hidden_states)
        return out, last_hidden_states

In [None]:



# ---- Model, loss, optimizer and LR schedule --------------------------------
lossfn = nn.MSELoss()
net = Net(NB_FEATURES=29)

total_epochs = 150
warmup_epochs = 5
optim = torch.optim.AdamW(net.parameters(), lr=1e-3)

# `scheduler.step()` is called once per batch below, so the warmup and total
# lengths must be counted in batches, not epochs. Counting them in epochs made
# the cosine schedule finish after only `total_epochs` batches. The transformers
# cosine lambda is not clamped past its end, so the LR then rose again for the
# rest of training.
steps_per_epoch = len(train_data_loader)
scheduler = get_cosine_schedule_with_warmup(
    optimizer=optim,
    num_warmup_steps=warmup_epochs * steps_per_epoch,
    num_training_steps=total_epochs * steps_per_epoch,
)

epochs = []
losses = []      # per-sample mean training loss for each epoch
val_losses = []  # per-sample mean validation loss for each epoch
lrs = []         # learning rate at the end of each epoch
best_loss = float('inf')
best_model_path = "best_model.pth"

for epoch in range(total_epochs):
    # ---- Train ----
    net.train()
    train_sum, train_count = 0.0, 0
    for X, Y in train_data_loader:
        optim.zero_grad(set_to_none=True)
        out, _ = net(X)
        loss = lossfn(out, Y)
        loss.backward()
        optim.step()
        scheduler.step()
        # MSELoss averages within the batch (reduction='mean'), so weight by
        # batch size; otherwise a small final batch counts as much as a full one.
        train_sum += loss.item() * X.size(0)
        train_count += X.size(0)
    train_loss = train_sum / max(train_count, 1)

    # ---- Validate: dropout disabled, no autograd graph ----
    net.eval()
    val_sum, val_count = 0.0, 0
    with torch.no_grad():
        for X, Y in val_data_loader:
            out, _ = net(X)
            val_sum += lossfn(out, Y).item() * X.size(0)
            val_count += X.size(0)
    val_loss = val_sum / val_count if val_count else float('nan')

    # Select the checkpoint on held-out loss (a single training mini-batch with
    # dropout active is too noisy and rewards overfitting). Fall back to the
    # epoch training loss only when there is no validation data.
    monitored = val_loss if val_count else train_loss
    if monitored < best_loss:
        best_loss = monitored
        torch.save(net.state_dict(), best_model_path)
        print(f"Saved new best model at epoch {epoch+1} with loss {best_loss:.4f}")

    epochs.append(epoch)
    losses.append(train_loss)
    val_losses.append(val_loss)
    lrs.append(scheduler.get_last_lr()[0])
    print(f"Epoch {epoch+1}, loss: {train_loss:.4f}, val loss: {val_loss:.4f}, lr: {lrs[-1]:.2e}")

# ---- Learning curves ----
fig, ax = plt.subplots()
ax.plot(epochs, losses, label='train')
ax.plot(epochs, val_losses, label='validation')
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.set_title('Training vs validation loss')
ax.legend()
ax.grid(True)
plt.show()

# INFERENCE & PLOT RESULT

In [None]:
def inference(audio, model_path="best_model.pth", save_path="pred_vs_target_sample.png"):
    """Predict track characteristics for one audio item and plot them against its labels.

    Loads weights from ``model_path`` into the global ``net`` (switching it to
    eval mode). Extracted features are normalised with the statistics that
    ``AudioDataLoader.load_stats()`` loads. The function prints target,
    prediction and MSE, then draws a grouped bar chart.

    Args:
        audio: Object exposing ``extract_features()`` (returning a mapping of
            feature values) and a ``labels`` mapping. Both are assumed to be in
            the same order as the training features/targets (TODO confirm).
        model_path: Checkpoint to load, as written by the training cell.
        save_path: File to save the figure to; pass ``None`` to skip saving.

    Raises:
        ValueError: If the number of targets or predictions does not match the
            number of characteristic names being plotted.
    """
    adl = AudioDataLoader()
    adl.load_stats()
    net.load_state_dict(torch.load(model_path))

    raw_features = audio.extract_features()
    X = torch.tensor(list(raw_features.values()), dtype=torch.float32)
    # Labels are multiplied by 10. Presumably they are stored on a 0–1 scale and
    # the model predicts on the 0–10 scale used in the plot (TODO confirm).
    y = torch.tensor(list(audio.labels.values()), dtype=torch.float32) * 10

    # Normalize input with the stored global statistics; epsilon avoids division by zero.
    X = (X - adl.global_mean) / (adl.global_std + 1e-8)

    net.eval()
    with torch.no_grad():
        res, _ = net(X)
        loss = nn.MSELoss()(res, y)

    y_true = y.detach().numpy()
    y_pred = res.detach().numpy()
    print(f"True: {y_true}")
    print(f"Pred: {y_pred}")
    print(f"Loss: {loss.item()}")

    feature_names = [
        "loudness", "harshness", "compression", "clarity", "bass", "muddiness",
        "noise_distortion", "stereo_width", "brightness", "warmth",
        "presence", "reverb_amount", "balance", "masking"
    ]
    if len(y_true) != len(feature_names) or len(y_pred) != len(feature_names):
        raise ValueError(
            f"Expected {len(feature_names)} values to plot, got "
            f"{len(y_true)} targets and {len(y_pred)} predictions"
        )

    positions = range(len(feature_names))
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(positions, y_true, width=bar_width, label='Target', alpha=0.7)
    ax.bar([i + bar_width for i in positions], y_pred, width=bar_width, label='Prediction', alpha=0.7)

    # Centre each tick label between its Target/Prediction bar pair.
    ax.set_xticks([i + bar_width / 2 for i in positions])
    ax.set_xticklabels(feature_names, rotation=45)
    ax.set_ylabel("Value (0–10 scale)")
    ax.set_title("Track Characteristics: Prediction vs Target")
    ax.legend()
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150)
    plt.show()