In [None]:
import librosa
import numpy as np
import os
from tqdm import tqdm
import numpy as np
import librosa
import librosa.display
import soundfile as sf
import datasets

# Sample rate (Hz) that audio2mel resamples every input file to via librosa.load.
TARGET_SR = 44100

def audio2mel(filepath: str, start: int = 0):
    """Compute a 128-band log-mel spectrogram of a fixed-length excerpt of an audio file.

    The file is loaded mono and resampled to ``TARGET_SR``. An excerpt of
    ``432 * 512 - 1`` samples starting at sample ``start`` is analysed with a
    2048-point STFT and hop 512. With librosa's default centred framing
    (``1 + len // hop`` frames), this gives exactly 432 frames.

    Args:
        filepath: Path to an audio file readable by ``librosa.load``.
        start: Offset of the excerpt, in samples at ``TARGET_SR``. The default 0
            takes the beginning of the file.

    Returns:
        The log-mel spectrogram as an array with 128 rows. It has 432 columns
        unless the excerpt runs past the end of the file, in which case it has fewer.

    Raises:
        ValueError: if ``start`` is negative.
    """
    if start < 0:
        raise ValueError(f"start must be non-negative, got {start}")

    n_fft, hop_length, n_frames = 2048, 512, 432
    n_samples = n_frames * hop_length - 1  # 1 + (432*512 - 1) // 512 == 432 frames

    samples, sr = librosa.load(filepath, sr=TARGET_SR, mono=True)
    # Previously `start` was overwritten with 0 here, so every call used the file's beginning.
    excerpt = samples[start:start + n_samples]

    magnitude = np.abs(librosa.stft(excerpt, n_fft=n_fft, hop_length=hop_length))
    mel = librosa.feature.melspectrogram(sr=sr, S=magnitude**2, n_mels=128)
    # NOTE(review): `mel` is a *power* spectrogram. amplitude_to_db treats its input
    # as amplitude (i.e. it behaves like power_to_db(S**2)), so these values are
    # 20*log10(power), not the usual 10*log10(power). mel_to_audio inverts them with
    # db_to_amplitude, so the pair is self-consistent. Changing this would change
    # the dB scale that _check_spec's 80 dB bound was written against.
    log_mel = librosa.amplitude_to_db(mel)

    return log_mel

def make_dataset(files_dir: str, *, count: int = -1):
    """Compute log-mel spectrograms for the audio files in a directory.

    Only files ending in ``.mp3`` or ``.wav`` (case-insensitive) are processed.
    They are visited in sorted name order, so a ``count`` limit selects the same
    files on every run. Files that fail to load, or whose spectrogram is not
    ``(128, 432)``, are reported on stdout and skipped.

    Args:
        files_dir: Directory to scan (not recursive).
        count: Maximum number of spectrograms to collect. ``-1`` (or any
            non-positive value) means no limit.

    Returns:
        Dict mapping ``"<filename>_<TARGET_SR>"`` to that file's log-mel array.
    """
    audio_files = sorted(
        name for name in os.listdir(files_dir)
        if name.lower().endswith((".mp3", ".wav"))
    )
    limit = count if count > 0 else None

    all_mels = {}
    for name in tqdm(audio_files, desc="Processing files..."):
        if limit is not None and len(all_mels) >= limit:
            break
        filepath = os.path.join(files_dir, name)

        try:
            # audio2mel returns only the spectrogram (not a (mel, sr) tuple) and always
            # resamples to TARGET_SR. Unpacking its result used to raise for every file.
            mel = audio2mel(filepath)
        except Exception as e:  # best effort: one unreadable file must not abort the run
            print(f"Skipping {name}: {e}")
            continue

        if mel.shape != (128, 432):
            print("Skipping shape {}".format(mel.shape))
            continue
        all_mels[f"{name}_{TARGET_SR}"] = mel

    return all_mels

def save_spec(dataset: dict, save_dir: str):
    """Write all spectrograms in ``dataset`` to ``<save_dir>/specs.npz``.

    Each dict key becomes an array name inside the compressed archive. The
    directory is created if it does not exist yet.
    """
    archive_path = os.path.join(save_dir, "specs.npz")
    os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(archive_path, **dataset)


def _check_spec(spec):
    """Assert that ``spec`` looks like an un-normalized log-mel spectrogram.

    Checks run in order and the first failure raises. For example, a wrong shape
    is reported before any value-based check is evaluated.

    Raises:
        AssertionError: if ``spec`` is not a float32 array of shape (128, 432)
            that contains only finite values and has a peak magnitude in (1, 80].
    """
    # To keep my sanity: (predicate, message) pairs, both lazy, so each check is
    # only computed when its turn comes, exactly like a plain sequence of asserts.
    checks = (
        (lambda: spec.shape == (128, 432),
         lambda: f"Shape is {spec.shape}, expected (128, 432)"),
        (lambda: spec.dtype == np.float32,
         lambda: f"Data type is {spec.dtype}, expected np.float32"),
        (lambda: np.isfinite(spec).all(),
         lambda: "Data contains non-finite values"),
        (lambda: np.abs(spec).max() <= 80,
         lambda: "Data contains values greater than 80 dB"),
        (lambda: np.abs(spec).max() > 1,
         lambda: "Empty data, or you probably forget to unnormalize it"),
    )
    for predicate, message in checks:
        assert predicate(), message()

def load_spec(spec_file: str):
    """Load spectrograms from an ``.npz`` archive, validate them, and normalize them.

    Every array is checked with ``_check_spec``. Arrays that fail are reported on
    stdout and dropped. The remaining arrays are divided by 80, so the dB values
    allowed by ``_check_spec`` (magnitude at most 80) map into [-1, 1].

    Args:
        spec_file: Path to an ``.npz`` archive, e.g. the one written by ``save_spec``.

    Returns:
        Dict mapping each archive key to its normalized array.
    """
    dsdict = {}
    # The context manager closes the zip handle that np.load keeps open for .npz files.
    with np.load(spec_file) as archive:
        for key in archive.files:
            # Each archive[key] access re-reads and decompresses the member, so read it once.
            spec = archive[key]
            try:
                _check_spec(spec)
            except Exception as e:
                print(f"Error in {key}: {e}")
                continue
            dsdict[key] = spec / 80.0  # Normalize to [-1, 1]
    return dsdict

def mel_to_audio(spec, sr: int, n_iter: int = 32):
    """Reconstruct a waveform from a normalized log-mel spectrogram.

    ``spec`` is expected in the normalized range produced by ``load_spec``
    (dB / 80). The function:

    1. Scales ``spec`` back to dB and validates it with ``_check_spec``.
    2. Maps it back to a mel power spectrogram.
    3. Inverts the mel filterbank to get a linear STFT magnitude.
    4. Estimates phase with Griffin-Lim.

    Args:
        spec: Normalized log-mel spectrogram, 128 mel bands, float32.
        sr: Sample rate the spectrogram was computed at (``TARGET_SR`` for
            spectrograms produced by ``audio2mel``).
        n_iter: Number of Griffin-Lim iterations.

    Returns:
        The reconstructed time-domain signal.
    """
    # Un-normalize *before* validating: _check_spec expects dB values (peak > 1),
    # so checking the normalized input first always failed.
    spec_db = spec * 80.0
    _check_spec(spec_db)

    # audio2mel applied amplitude_to_db to a power mel spectrogram.
    # db_to_amplitude is its exact inverse, so this recovers mel *power*.
    mel_power = librosa.db_to_amplitude(spec_db)

    # Non-negative least-squares inversion of the mel filterbank. power=2.0 declares
    # a power input, and the result is a linear STFT *magnitude*, which is what
    # griffinlim expects. This replaces pinv + squaring, which produced |X|**4 with
    # possible negative entries and passed `sr` positionally (rejected by librosa>=0.10).
    stft_magnitude = librosa.feature.inverse.mel_to_stft(mel_power, sr=sr, n_fft=2048, power=2.0)

    audio = librosa.griffinlim(stft_magnitude, hop_length=512, n_iter=n_iter)
    return audio

def convert_ds_to_hf_dataset(ds: dict, test_size: float = 0.1, *, return_split: bool = False):
    """Wrap a ``{key: mel}`` dict (e.g. from ``load_spec``) in a Hugging Face dataset.

    Args:
        ds: Mapping from spectrogram key to spectrogram array.
        test_size: Held-out portion passed to ``Dataset.train_test_split``. It is
            only used when ``return_split`` is True.
        return_split: If True, return the ``DatasetDict`` with "train" and "test"
            splits. Otherwise (the default, matching previous behaviour) return
            the full ``Dataset``.

    Returns:
        A ``datasets.Dataset`` with columns "filename" and "mel", or, when
        ``return_split`` is True, a ``datasets.DatasetDict`` of its splits.
    """
    # Import locally: a later cell runs `from torchvision import datasets`, which
    # rebinds the global name `datasets` to torchvision's module.
    import datasets as hf_datasets

    hf_ds = hf_datasets.Dataset.from_dict({
        "filename": list(ds.keys()),
        "mel": list(ds.values()),
    })

    if return_split:
        # train_test_split returns a DatasetDict. The old tuple-unpacking bound the
        # key strings "train"/"test" and then discarded the split.
        return hf_ds.train_test_split(test_size=test_size)
    return hf_ds

In [None]:
import huggingface_hub

# huggingface_hub.login("nope")

In [None]:
# Validate + normalize the saved spectrograms and wrap them as a Hugging Face dataset.
ds = convert_ds_to_hf_dataset(load_spec("./output/specs.npz"))
# load_spec only prints and drops spectrograms that fail validation, so an
# all-invalid archive would otherwise be uploaded silently as an empty dataset.
assert len(ds) > 0, "No valid spectrograms in ./output/specs.npz; nothing to upload"

# Upload to hf hub
ds.push_to_hub("mel-spectrogram-dataset-test", private=True)

In [None]:
import torch
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from diffusers import DDPMScheduler, UNet2DModel
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed weight init, shuffling, noise and timestep sampling so runs are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Configure the model
model = UNet2DModel(
    # Both dims must be multiples of 2 ** (len(block_out_channels) - 1) = 4.
    sample_size=(128, 432),
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D")
).to(device)

# Parameters
batch_size = 1
num_epochs = 10
learning_rate = 1e-4

# Load the dataset
# TODO(review): push_to_hub stored the private repo under your user namespace; a bare
# name may not resolve here. Consider "<username>/mel-spectrogram-dataset-test".
dataset = load_dataset("mel-spectrogram-dataset-test")
# load_dataset returns a DatasetDict keyed by split name. Iterating it directly would
# yield the split names, not rows, so select the "train" split.
train_split = dataset["train"]

# DataLoader expects a torch.Tensor, thus a conversion function is needed
def collate_fn(batch):
    """Stack dataset rows into a float32 tensor of shape [batch, 1, n_mels, n_frames]."""
    # Without a torch format set, HF rows come back as nested Python lists,
    # so convert explicitly rather than calling .unsqueeze on a list.
    mels = [torch.as_tensor(item["mel"], dtype=torch.float32).unsqueeze(0) for item in batch]
    return torch.stack(mels)

# DataLoader
train_loader = DataLoader(train_split, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Scheduler and Optimizer
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Loss function
loss_func = torch.nn.MSELoss()

# Training Loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Add noise: one timestep per example. Size by the actual batch, since the
        # final batch can be smaller than batch_size.
        noise = torch.randn_like(batch)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (batch.shape[0],),
            device=device, dtype=torch.int64,
        )
        noisy_batch = noise_scheduler.add_noise(batch, noise, timesteps)

        # Model forward pass: the UNet is trained to predict the added noise.
        noise_pred = model(noisy_batch, timesteps).sample

        # Loss calculation
        loss = loss_func(noise_pred, noise)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f'Epoch {epoch + 1} completed, Average Loss: {epoch_loss / len(train_loader)}')

print("Training completed.")

In [None]:
# import torch
# from torchvision import datasets, transforms
# from torchvision.utils import make_grid
# from diffusers import DDPMScheduler, UNet2DModel

# device = torch.device("cuda")

# model = UNet2DModel(
#     sample_size=(128, 432), # Dimensions must be a multiple of 2 ** (len(block_out_channels) - 1) = 4.
#     in_channels=1,
#     out_channels=1,
#     layers_per_block=2,
#     block_out_channels=(32, 64, 64),
#     down_block_types=(
#         "DownBlock2D",
#         "AttnDownBlock2D",
#         "DownBlock2D",
#     ),
#     up_block_types=(
#         "UpBlock2D",
#         "AttnUpBlock2D",
#         "UpBlock2D",
#     )
# ).to(device)

# batch_size = 1

# noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
# opt = torch.optim.Adam(model.parameters())

# x = torch.randn(1, 1, 128, 432).to(device)

# noise = torch.randn(x.shape, device=device)
# timesteps = torch.randint(
#     0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device, dtype=torch.int64
# ).to(device)
# noisy_x = noise_scheduler.add_noise(x, noise, timesteps).to(device)
# noise_pred = model(noisy_x, timesteps)[0]