In [1]:
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchaudio

from tqdm import tqdm
from typing import Optional
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from mir_eval.separation import bss_eval_sources
from sklearn.model_selection import train_test_split



from enrollment_model import MyModel
# from load_data import BEATS_path, ORIG_mixture, ORIG_target #, stems
# from dataset import MusicDataset
from loss import L1SNR_Recons_Loss, L1SNRDecibelMatchLoss
from utils import _load_config
from metrics import (
    AverageMeter, cal_metrics, safe_signal_noise_ratio, MetricHandler
)

from models.types import InputType, OperationMode, SimpleishNamespace
from data.moisesdb.datamodule import (
    MoisesTestDataModule,
    MoisesValidationDataModule,
    MoisesDataModule,
    MoisesBalancedTrainDataModule,
    MoisesVDBODataModule,
)


from dismix.dismix_model import DisMixModel

  from .autonotebook import tqdm as notebook_tqdm


In [39]:
# # Cropping
# def crop_sustain_phase(mel_spectrogram, crop_frames=10, start_frame=None):
#     """
#     Crop a 320ms segment (10 frames) from the sustain phase of the mel spectrogram.
    
#     Parameters:
#     - mel_spectrogram: Mel spectrogram to crop.
#     - crop_frames: Number of frames to crop (10 frames corresponds to 320ms).
#     - start_frame: Starting frame for cropping (if None, find from sustain phase).
    
#     Returns:
#     - Cropped mel spectrogram segment, start_frame used for alignment.
#     """
#     # Calculate energy for each frame
#     frame_energy = torch.sum(mel_spectrogram, dim=0)
    
#     # Find the maximum energy frame index (attack phase) if start_frame is not provided
#     if start_frame is None:
#         max_energy_frame = torch.argmax(frame_energy)
#         # Define the starting frame of the sustain phase, a few frames after the peak energy
#         start_frame = max_energy_frame + 5  # Shift 5 frames after peak to avoid attack phase
    
#     # Ensure the crop window does not exceed the spectrogram length
#     if start_frame + crop_frames > mel_spectrogram.size(1):
#         start_frame = max(0, mel_spectrogram.size(1) - crop_frames)
    
#     # Crop the mel spectrogram segment
#     cropped_segment = mel_spectrogram[:, start_frame:start_frame + crop_frames]
    
#     return cropped_segment, start_frame



# def processing(mel_spectrogram, start_frame=None):
    
#     # Convert complex-valued spectrogram to magnitude (real values)
#     mel_spectrogram_magnitude = torch.abs(mel_spectrogram)
    
#     # Convert amplitude to decibel scale
#     mel_spectrogram_db = torchaudio.transforms.AmplitudeToDB(top_db=80)(mel_spectrogram_magnitude)
    
#     # Crop a 320ms segment (10 frames) from the sustain phase
#     cropped_mel_spectrogram, start_frame = crop_sustain_phase(mel_spectrogram_db.squeeze(0), crop_frames=10, start_frame=start_frame)
    
#     return cropped_mel_spectrogram, start_frame

    

In [2]:
# Init settings
wandb_use = False # False
lr = 1e-3 # 1e-4
num_epochs = 1 # 500
batch_size = 4 # 8
n_srcs = 1
emb_dim = 768 # For BEATs
query_size = 512 # 512
mix_query_mode = "Hyper_FiLM" # "Transformer"
q_enc = "Passt"
config_path = "config/train.yml"
mask_type = "L1"
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
print("Training on device:", device)


def to_device(batch, device=device):
    batch.mixture.audio = batch.mixture.audio.to(device) # torch.Size([BS, 2, 294400])
    batch.sources.target.audio = batch.sources.target.audio.to(device) # torch.Size([BS, 2, 294400])
    batch.query.audio = batch.query.audio.to(device) # torch.Size([BS, 2, 441000])
    return batch


if wandb_use:
    wandb.init(
        project="Query_ss",
        config={
            "learning_rate": lr,
            "architecture": "Transformer_UNet Using 9 stems",
            "dataset": "MoisesDB",
            "epochs": num_epochs,
        },
        notes=f"{mix_query_mode} + {mask_type} Loss + 512 query size",
    )


config = _load_config(config_path)
stems = config.data.train_kwargs.allowed_stems
print("Training with stems: ", stems)

datamodule = MoisesDataModule(
    data_root=config.data.data_root,
    batch_size=batch_size, #config.data.batch_size,
    num_workers=config.data.num_workers,
    train_kwargs=config.data.get("train_kwargs", None),
    val_kwargs=config.data.get("val_kwargs", None),
    test_kwargs=config.data.get("test_kwargs", None), # Cannot use now
    datamodule_kwargs=config.data.get("datamodule_kwargs", None),
)



# Instantiate the enrollment model
model = DisMixModel(
    input_dim=128, 
    latent_dim=64, 
    hidden_dim=256, 
    gru_hidden_dim=256,
    num_frames=10,
    pitch_classes=52,
    output_dim=128,    
).to(device)


def window_fn(win_length):
    return torch.hann_window(win_length).to(device)

# Initialize the Spectrogram transform with the correct window function
stft = torchaudio.transforms.Spectrogram(
    n_fft=2048,
    win_length=2048,
    hop_length=512,
    pad_mode="constant",
    pad=0,
    window_fn=window_fn,  # Pass the callable window function
    wkwargs=None,
    power=None,
    normalized=True,
    center=True,
    onesided=True,
)
        

Training on device: cuda:2
Training with stems:  ['bass', 'drums', 'lead_male_singer', 'lead_female_singer', 'distorted_electric_guitar', 'clean_electric_guitar', 'acoustic_guitar', 'grand_piano', 'electric_piano']
Loading query tuples from /home/buffett/NAS_NTU/moisesdb/queries/4/chunk294400-hop264600/query441000-n1/queries-9stems:b:d:l:l:d:c:a:g:e.csv
Loading query tuples from /home/buffett/NAS_NTU/moisesdb/queries/5/chunk294400-hop264600/query441000-n1/queries-9stems:b:d:l:l:d:c:a:g:e.csv


In [46]:
for batch in tqdm(datamodule.train_dataloader()):
    batch = InputType.from_dict(batch)
    batch = to_device(batch)
    print(batch.mixture.audio.shape, batch.query.audio.shape)
    break

  0%|          | 0/8192 [00:32<?, ?it/s]

torch.Size([1, 2, 294400]) torch.Size([1, 2, 441000])





In [48]:
batch = InputType.from_dict(batch)
batch = to_device(batch)

batch.mixture.spectrogram = stft(batch.mixture.audio)
batch.query.spectrogram = stft(batch.query.audio)

mixture_orig = batch.mixture.spectrogram.to(device)
query_orig = batch.query.spectrogram.to(device)
print( query_orig.shape, mixture_orig.shape)


# Processing
query, start_frame = processing(query_orig)
mixture, _ = processing(mixture_orig, start_frame)
print(query.shape, mixture.shape)

# rec_mixture, pitch_latent, pitch_logits, timbre_latent, timbre_mean, timbre_logvar, eq = model(mixture, query)

# print(pitch_latent.shape, timbre_latent.shape)

# pitch_data.append(pitch_latent)
# timbre_data.append(timbre_latent)
# stem_data.append(batch.metadata.stem)


torch.Size([1, 2, 1025, 862]) torch.Size([1, 2, 1025, 576])
torch.Size([2, 10, 862]) torch.Size([2, 10, 576])


In [3]:
spec_store = []
stem_names = []
song_id = []

# Training loop
for batch in tqdm(datamodule.train_dataloader()):
    batch = InputType.from_dict(batch)
    batch = to_device(batch)
    
    # batch.mixture.spectrogram = stft(batch.mixture.audio)
    batch.query.spectrogram = stft(batch.query.audio)
    
    for i in range(batch_size):
        spec_store.append(batch.query.spectrogram[i])
        stem_names.append(batch.metadata.stem[i])
        song_id.append(batch.metadata.query.song_id[i])

    # break
    # mixture_orig = batch.mixture.spectrogram.to(device)
    # query_orig = batch.query.spectrogram.to(device)
    
    
    # # Processing
    # query, start_frame = processing(query_orig)
    # mixture, _ = processing(mixture_orig, start_frame)
    # print(query.shape, mixture.shape)
    
    # rec_mixture, pitch_latent, pitch_logits, timbre_latent, timbre_mean, timbre_logvar, eq = model(mixture, query)
    
    # print(pitch_latent.shape, timbre_latent.shape)
    
    # pitch_data.append(pitch_latent)
    # timbre_data.append(timbre_latent)
    # stem_data.append(batch.metadata.stem)
    # break
    

  2%|▏         | 48/2048 [12:26<5:30:47,  9.92s/it] 

In [4]:
batch.query.spectrogram.shape

torch.Size([4, 2, 1025, 862])