In [1]:
import os
import librosa
from IPython.display import Audio
os.environ['CUDA_VISIBLE_DEVICES'] = '0' 
import torch
import torchaudio

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
from data.moisesdb_contrastive_online import MoisesdbContrastivePreprocessed #TODO
from pathlib import Path

import lightning as L
from torch.utils.data import DataLoader, ConcatDataset
import torch

from contrastive_model import constants
from feature_extraction.feature_extraction import CoColaFeatureExtractor

In [4]:
chunk_duration = 5.0
target_sample_rate = 16000
generate_submixtures=True
preprocess_transform = None

In [5]:
feature_extractor_type = constants.ModelFeatureExtractorType.HPSS #TODO
#feature_extractor_type = constants.ModelFeatureExtractorType.MEL_SPECTROGRAM
feature_extractor = CoColaFeatureExtractor(feature_extractor_type=feature_extractor_type)
runtime_transform = feature_extractor

In [6]:
def _get_moisesdb_splits(stage: str):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        root_dir = "/disk1/demancum/hp_datasets/moisesdb_contrastive" #TODO
        train_dataset, val_dataset, test_dataset = None, None, None

        if stage == "fit":
            train_dataset = MoisesdbContrastivePreprocessed(
                root_dir=root_dir,
                split="train",
                #preprocess=True,
                chunk_duration=chunk_duration,
                target_sample_rate=target_sample_rate,
                generate_submixtures=generate_submixtures,
                device=device,
                preprocess_transform=preprocess_transform,
                runtime_transform=runtime_transform)

            val_dataset = MoisesdbContrastivePreprocessed(
                root_dir=root_dir,
                split="valid",
                #preprocess=True,
                chunk_duration=chunk_duration,
                target_sample_rate=target_sample_rate,
                generate_submixtures=generate_submixtures,
                device=device,
                preprocess_transform=preprocess_transform,
                runtime_transform=runtime_transform)
        elif stage == "test":
            test_dataset = MoisesdbContrastivePreprocessed(
                root_dir=root_dir,
                split="test",
                #preprocess=True,
                chunk_duration=chunk_duration,
                target_sample_rate=target_sample_rate,
                generate_submixtures=generate_submixtures,
                device=device,
                preprocess_transform=preprocess_transform,
                runtime_transform=runtime_transform)

        return train_dataset, val_dataset, test_dataset

In [7]:
train_dataset, val_dataset, test_dataset = _get_moisesdb_splits(stage="fit")

Building track index: 100%|██████████| 216/216 [00:00<00:00, 6622.20it/s]
Building track index: 100%|██████████| 12/12 [00:00<00:00, 5839.62it/s]


In [8]:
#import cProfile
#cProfile.run('train_dataset[0]')

import cProfile, pstats, io
from pstats import SortKey
pr = cProfile.Profile()
pr.enable()
train_dataset[0]
pr.disable()
s = io.StringIO()
sortby = SortKey.TIME
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())

         2280297 function calls (2243953 primitive calls) in 1.752 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
6036/6004    0.215    0.000    0.256    0.000 /home/demancum/miniconda3/envs/cocola_test/lib/python3.11/site-packages/llvmlite/binding/ffi.py:190(__call__)
      156    0.146    0.001    0.443    0.003 /home/demancum/miniconda3/envs/cocola_test/lib/python3.11/inspect.py:969(getmodule)
3402/3277    0.122    0.000    0.124    0.000 {built-in method __new__ of type object at 0x866f00}
       17    0.108    0.006    0.108    0.006 {built-in method torch._ops.torchaudio_sox.load_audio_file}
     3222    0.077    0.000    0.077    0.000 {built-in method posix.stat}
      330    0.069    0.000    0.069    0.000 {built-in method io.open_code}
    18767    0.069    0.000    0.069    0.000 {built-in method posix.lstat}
      437    0.061    0.000    0.061    0.000 {method '__exit__' of '_io._IOBase' objects}
   292036   

In [28]:
import time
import random
import numpy as np

# Assuming train_dataset is already loaded or defined
n_samples = len(train_dataset)
n_trials = 100
times = []

# Measure the time it takes to compute train_dataset[i] for random indices
for _ in range(n_trials):
    i = random.randint(0, n_samples - 1)
    start_time = time.time()
    _ = train_dataset[i]
    end_time = time.time()
    times.append(end_time - start_time)

# Convert times to a numpy array for statistical computation
times = np.array(times)

# Compute statistics
mean_time = np.mean(times)
std_time = np.std(times)
median_time = np.median(times)
min_time = np.min(times)
max_time = np.max(times)

# Print results
print(f"Mean time: {mean_time:.6f} seconds")
print(f"Standard deviation: {std_time:.6f} seconds")
print(f"Median time: {median_time:.6f} seconds")
print(f"Minimum time: {min_time:.6f} seconds")
print(f"Maximum time: {max_time:.6f} seconds")


Mean time: 0.508575 seconds
Standard deviation: 0.066567 seconds
Median time: 0.519793 seconds
Minimum time: 0.297949 seconds
Maximum time: 0.764302 seconds


## NO HPSS

Mean time: 0.030078 seconds 

Standard deviation: 0.008805 seconds

Median time: 0.028174 seconds

Minimum time: 0.014488 seconds

Maximum time: 0.050782 seconds


## HPSS

Mean time: 0.513520 seconds

Standard deviation: 0.062310 seconds

Median time: 0.532635 seconds

Minimum time: 0.339727 seconds

Maximum time: 0.616706 seconds


In [27]:
audio_path = "/speech/dbwork/mul/spielwiese4/students/demancum/musdb18hq/test/Al James - Schoolboy Facination/mixture.wav"

waveform, sr = torchaudio.load(str(audio_path))

audio = waveform.squeeze(0).cpu().numpy()

stft = librosa.stft(audio)#,
                    #n_fft=1024,
                    #win_length=400,
                    #hop_length=160)

harmonic_stft, percussive_stft = librosa.decompose.hpss(stft)

In [28]:
y_harm = librosa.istft(harmonic_stft)#,                     
                    #n_fft=1024,
                    #win_length=400,
                    #hop_length=160)
y_perc = librosa.istft(percussive_stft)#,                     
                    #n_fft=1024,
                    #win_length=400,
                    #hop_length=160)

In [32]:
import numpy as np
stereo = np.vstack((y_harm, y_perc))

    # Save as flac
torchaudio.save(
        "test.flac",
        torch.tensor(stereo),
        sample_rate=sr,
        format="flac"
    )

In [29]:
harmonic = torch.from_numpy(y_harm)
percussive = torch.from_numpy(y_perc)

In [30]:
harmonic_path = "harmonic.wav"
torchaudio.save(harmonic_path, harmonic, sr)

percussive_path = "percussive.wav"
torchaudio.save(percussive_path, percussive, sr)