In [100]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import train_test_split
import h5py
import librosa
import mido

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cuda


In [101]:
class Tokenizer:
    """Convert ``(time, note, velocity)`` event triples to token ids and back.

    Ids are assigned contiguously: ``0`` is ``<pad>``, followed by
    ``vel_count`` velocity ids, ``note_count`` note ids and ``time_count``
    time ids, and finally ``<bos>`` and ``<eos>``.
    """

    def __init__(self, time_count: int = 26, note_count: int = 110, vel_count: int = 2):
        """Build the value<->id and id<->token lookup tables.

        Args:
            time_count: Number of time values (``0 .. time_count - 1``).
            note_count: Number of note values (``0 .. note_count - 1``).
            vel_count: Number of velocity values (``0 .. vel_count - 1``).
        """
        # First id of each contiguous range; id 0 is reserved for <pad>.
        velo_offset = 1
        note_offset = velo_offset + vel_count
        time_offset = note_offset + note_count
        special_offset = time_offset + time_count

        self.val_to_velo_id: dict = {
            i: i + velo_offset for i in range(vel_count)}
        self.val_to_note_id: dict = {
            i: i + note_offset for i in range(note_count)}
        self.val_to_time_id: dict = {
            i: i + time_offset for i in range(time_count)}

        self.velo_id_to_val: dict = {
            v: k for k, v in self.val_to_velo_id.items()}
        self.note_id_to_val: dict = {
            v: k for k, v in self.val_to_note_id.items()}
        self.time_id_to_val: dict = {
            v: k for k, v in self.val_to_time_id.items()}

        self.id_to_token: dict = {
            **{token_id: f'velo_{val}' for val, token_id in self.val_to_velo_id.items()},
            **{token_id: f'note_{val}' for val, token_id in self.val_to_note_id.items()},
            **{token_id: f'time_{val}' for val, token_id in self.val_to_time_id.items()},
            0: '<pad>',
            special_offset: '<bos>',
            special_offset + 1: '<eos>'
        }

        self.token_to_id: dict = {v: k for k, v in self.id_to_token.items()}

    @staticmethod
    def _lookup(table: dict, token_id):
        """Return the value stored for ``token_id`` in ``table``, or ``-1`` if absent."""
        # Objects exposing __index__ are converted to a plain int so that they
        # match the integer keys of the table.
        key = int(token_id) if hasattr(token_id, '__index__') else token_id
        return table.get(key, -1)

    def tuple_to_ids(self, tuple: tuple):
        """Encode one ``(time, note, velocity)`` triple as ``[time_id, note_id, velo_id]``.

        Raises:
            KeyError: If a component is outside the configured vocabulary.
        """
        # NOTE: the parameter name shadows the builtin ``tuple``; it is kept
        # so that keyword callers keep working.
        return [
            self.val_to_time_id[tuple[0]],
            self.val_to_note_id[tuple[1]],
            self.val_to_velo_id[tuple[2]],
        ]

    def tuple_list_to_ids(self, tuple_list: list[tuple]):
        """Encode a list of triples as one flat list of ids (three per triple)."""
        ids = []
        for event in tuple_list:
            ids.extend(self.tuple_to_ids(event))
        return ids

    def id_list_to_tuple_list(self, id_list: list[int]):
        """Decode a flat id list into ``(time, note, velocity)`` triples.

        Ids are consumed three at a time in time, note, velocity order. An id
        that does not belong to the vocabulary expected at its position
        (including ``<pad>``, ``<bos>`` and ``<eos>``) decodes to ``-1``.
        A trailing group of fewer than three ids is ignored.
        """
        tables = (self.time_id_to_val, self.note_id_to_val, self.velo_id_to_val)
        # Only complete triples are decoded.
        usable = len(id_list) - len(id_list) % 3
        return [
            tuple(
                self._lookup(table, id_list[start + offset])
                for offset, table in enumerate(tables)
            )
            for start in range(0, usable, 3)
        ]

In [102]:
class ProgressBar:
    """Single-line text progress bar that redraws itself in place.

    The line starts with a carriage return and is printed without a trailing
    newline, so successive redraws overwrite each other.
    """

    def __init__(self, total, length=40):
        """Start timing a bar of ``length`` cells that completes after ``total`` steps."""
        self.total, self.length = total, length
        self.current = 0
        self.start_time = time.time()

    def update(self, step=1):
        """Advance by ``step`` and redraw when the visible bar would change.

        A redraw happens on the first step, on the last step, or whenever the
        number of filled cells has grown since the previous position.
        """
        self.current += step
        fraction = self.current / self.total
        cells_now = int(self.length * fraction)
        cells_before = int(self.length * ((self.current - step) / self.total))

        is_first = self.current == 1
        is_last = self.current == self.total
        if not (is_first or is_last or cells_now > cells_before):
            return

        bar = '=' * cells_now + '-' * (self.length - cells_now)
        elapsed = time.time() - self.start_time
        print(f'\r|{bar}| {self.current}/{self.total} ({fraction:.2%})  {elapsed:.1f}s', end='')

    # def finish(self):
    #     self.update(self.total - self.current)
    #     print()

In [103]:
class DatasetGenerator():
    """Cut paired audio/MIDI recordings into overlapping chunks stored in one HDF5 file.

    Each recording is turned into a mel spectrogram that is sliced into
    windows of ``chunk_size`` frames taken every ``step_size`` frames. Every
    window is saved together with the ``note_on`` events whose timestamps fall
    inside it.

    HDF5 layout: ``/<piece_name>/chunk_<k>/{image, midi, meta}``.
    """

    def __init__(
        self,
        wav_mid_df: pd.DataFrame,
        output_path: str,
        chunk_size: int = 128,
        step_size: int = 64,
        sample_rate: int = 12_800,
        n_fft: int = 2048,
        hop_length: int = 256,
        n_mels: int = 512,
        override: bool = False,
        # transform: transforms.Compose = None
    ):
        """Store the configuration and optionally delete an existing output file.

        Args:
            wav_mid_df: Frame with a ``wav`` and a ``midi`` column holding file paths.
            output_path: Path of the HDF5 file to write.
            chunk_size: Window width in spectrogram frames.
            step_size: Distance between window starts in spectrogram frames.
            sample_rate: Sample rate the audio is loaded at.
            n_fft: FFT size passed to the mel spectrogram.
            hop_length: Hop length (in samples) between spectrogram frames.
            n_mels: Number of mel bands.
            override: When true, an existing file at ``output_path`` is removed
                immediately, i.e. already at construction time.
        """
        self.wav_mid_df = wav_mid_df
        self.output_path = output_path
        self.chunk_size = chunk_size
        self.step_size = step_size
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        # self.transform = transform
        # Duration of one spectrogram frame in seconds.
        self.time_per_frame = hop_length / sample_rate

        if override and os.path.exists(self.output_path):
            os.remove(self.output_path)

    def _frame_to_ms(self, frame: int) -> float:
        """Return the start time of spectrogram frame ``frame`` in milliseconds."""
        # A single division keeps whole-millisecond results exact; chaining
        # ``frame * time_per_frame * 1000`` can land just below an integer,
        # which int() would then truncate to the previous millisecond.
        return frame * self.hop_length * 1000 / self.sample_rate

    def _get_spectogram(self, wave_path: str) -> np.ndarray:
        """Load ``wave_path`` at ``sample_rate`` and return its mel spectrogram in dB.

        The dB values are relative to the spectrogram maximum (``ref=np.max``).
        """
        samples, sr = librosa.load(wave_path, sr=self.sample_rate)

        mel_spectrogram = librosa.feature.melspectrogram(
            y=samples, sr=sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels)
        return librosa.power_to_db(mel_spectrogram, ref=np.max)

    def _get_message_df(self, midi_path) -> pd.DataFrame:
        """Return the ``note_on`` events of the last MIDI track.

        Returns:
            Frame with columns ``time`` (running sum of the ``time`` attribute
            of every message in the track up to and including the event),
            ``note`` and ``velocity``.
        """
        # NOTE(review): only the last track is read - presumably that is where
        # the notes live in this dataset; verify.
        # NOTE(review): mido documents track message times as deltas in ticks,
        # while _get_midi_chunk compares this column against milliseconds.
        # Confirm one tick is one millisecond for these files, or convert.
        mid = mido.MidiFile(midi_path)

        elapsed = 0
        events = []
        for message in mid.tracks[-1]:
            # Every message advances the clock, not only the note events.
            elapsed += message.time
            if message.type == 'note_on':
                events.append((elapsed, message.note, message.velocity))

        note_df = pd.DataFrame(events, columns=['time', 'note', 'velocity'])
        if note_df.empty:
            # An empty frame would otherwise carry object-dtype columns.
            note_df = note_df.astype('int64')
        return note_df

    def _get_spec_chunk(self, spec: np.ndarray, i: int) -> np.ndarray:
        """Return window ``i`` of ``spec``: ``chunk_size`` columns starting at ``i * step_size``."""
        start = i * self.step_size
        return spec[:, start:start + self.chunk_size]

    def _get_midi_chunk(self, midi: pd.DataFrame, i: int) -> np.ndarray:
        """Return the events of window ``i`` with times relative to the window start.

        Events with ``start <= time < end`` (milliseconds) are kept; the
        integer part of the window start is subtracted from their ``time``.
        """
        start_frame = i * self.step_size
        start_time = self._frame_to_ms(start_frame)
        end_time = self._frame_to_ms(start_frame + self.chunk_size)

        in_window = (midi['time'] >= start_time) & (midi['time'] < end_time)
        midi_chunk = midi[in_window].copy()
        midi_chunk['time'] = midi_chunk['time'] - int(start_time)
        return midi_chunk.to_numpy()

    def _get_chunk_meta(self, midi_chunk: np.ndarray):
        """Return ``[min, max, count]`` summary values for a MIDI chunk as an array.

        An empty chunk yields ``[0, 0, 0]``.
        """
        if midi_chunk.shape[0] == 0:
            return np.array([0, 0, 0])
        # NOTE(review): min/max are taken over the FIRST ROW (one event's
        # time, note and velocity), not over a column. This looks unintended -
        # confirm which column the summary should describe before relying on it.
        first_event = midi_chunk[0]
        return np.array([first_event.min(), first_event.max(), len(midi_chunk)])

    def _get_chunks(self, wave_path: str, midi_path: str) -> pd.DataFrame:
        """Return every full window of a recording as a frame with ``spec`` and ``midi`` columns."""
        spec = self._get_spectogram(wave_path)
        midi = self._get_message_df(midi_path)

        # Number of windows that fit completely; the +1 counts the window
        # starting at frame 0 (the last full window used to be dropped).
        n_frames = spec.shape[1]
        n_windows = max(0, (n_frames - self.chunk_size) // self.step_size + 1)

        chunks = [
            (self._get_spec_chunk(spec, i), self._get_midi_chunk(midi, i))
            for i in range(n_windows)
        ]
        return pd.DataFrame(chunks, columns=['spec', 'midi'])

    def _write_chunk(self, h5, spec_chunk: np.ndarray, midi_chunk: np.ndarray, piece_name: str):
        """Append one chunk group below ``piece_name`` in the open HDF5 file ``h5``."""
        piece_group = h5.require_group(piece_name)
        # Chunks are numbered by how many the piece already holds.
        chunk_group = piece_group.create_group(f'chunk_{len(piece_group)}')

        chunk_group.create_dataset('image', data=spec_chunk, compression='gzip')
        chunk_group.create_dataset('midi', data=midi_chunk, compression='gzip')
        chunk_group.create_dataset('meta', data=self._get_chunk_meta(midi_chunk), compression='gzip')

    def _save_chunk(self, spec_chunk: np.ndarray, midi_chunk: np.ndarray, piece_name: str, h5=None):
        """Save one chunk under ``piece_name``.

        Args:
            spec_chunk: Spectrogram window.
            midi_chunk: Events belonging to the window.
            piece_name: Name of the HDF5 group of the recording.
            h5: Optional already-open HDF5 file. When omitted, the file at
                ``output_path`` is opened in append mode for this one call.
        """
        if h5 is not None:
            self._write_chunk(h5, spec_chunk, midi_chunk, piece_name)
            return
        with h5py.File(self.output_path, 'a') as handle:
            self._write_chunk(handle, spec_chunk, midi_chunk, piece_name)

    def generate(self):
        """Process every row of ``wav_mid_df`` and write its chunks to ``output_path``.

        Rows whose wav or midi file does not exist are reported and skipped.
        """
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        progress_bar = ProgressBar(len(self.wav_mid_df))
        for _, row in self.wav_mid_df.iterrows():
            wave_path = row['wav']
            midi_path = row['midi']
            # Strip only the final extension so dotted names stay distinct.
            piece_name = os.path.splitext(os.path.basename(wave_path))[0]

            if os.path.exists(wave_path) and os.path.exists(midi_path):
                chunks = self._get_chunks(wave_path, midi_path)
                if len(chunks) > 0:
                    # One open/close per piece instead of one per chunk.
                    with h5py.File(self.output_path, 'a') as h5:
                        for spec_chunk, midi_chunk in zip(chunks['spec'], chunks['midi']):
                            self._save_chunk(spec_chunk, midi_chunk, piece_name, h5=h5)
            else:
                print(f"File not found: {wave_path} or {midi_path}")

            progress_bar.update()


    # def _getSimpleSpectogram(self, wave_file: str, notes: list[int]):
    #     spec = self.getSpectogram(wave_file)
    #     mel_frequencies = librosa.mel_frequencies(n_mels=spec.shape[0], fmin=0, fmax=self.sample_rate / 2)

    #     note_freqs = pd.read_csv('note_freqs.csv').values
    #     indexes = np.array([mel_frequencies[np.abs(mel_frequencies - val).argmin()] for val in note_freqs])
    #     indexes = np.array([np.where(mel_frequencies == val)[0][0] for val in indexes])
    #     indexes = indexes[notes]

    #     spec = spec[indexes]
    #     return spec, mel_frequencies[indexes]

In [104]:
# Input listing, output dataset and the folder the listed files live in.
csv_path = 'wav_midi.csv'
h5_path = 'dataset_transformed.h5'
folder_path = 'waves'

# Number of pieces converted in this run and the seed that makes the
# selection repeatable across kernel restarts.
N_PIECES = 10
SAMPLE_SEED = 42

df = pd.read_csv(csv_path)
df['wav'] = df['wav'].apply(lambda x: os.path.join(folder_path, x))
df['midi'] = df['midi'].apply(lambda x: os.path.join(folder_path, x))
# Never request more rows than the listing has (sample() would raise).
df = df.sample(min(N_PIECES, len(df)), random_state=SAMPLE_SEED)

# Image pipeline applied to each stored spectrogram chunk when it is read back.
transform = transforms.Compose([
    transforms.Resize((256, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# override=True removes an existing file at h5_path before generating.
dataset_gen = DatasetGenerator(df, h5_path, override=True)
dataset_gen.generate()



In [105]:
# Read every stored chunk back, run the image transform on the spectrogram and
# turn the MIDI events into a padded token sequence. Nothing is collected:
# after the loop only the values of the last chunk remain bound.
tokenizer = Tokenizer()

# Special-token ids do not change inside the loop, so look them up once.
bos_id = tokenizer.token_to_id['<bos>']
eos_id = tokenizer.token_to_id['<eos>']
pad_id = tokenizer.token_to_id['<pad>']

with h5py.File(h5_path, 'r') as h5:
    # One progress step per piece (top-level group), not per chunk.
    progress_bar = ProgressBar(len(h5))
    for piece_name, piece_group in h5.items():
        for chunk_name, chunk_group in piece_group.items():
            spec = chunk_group['image'][:]
            midi = chunk_group['midi'][:]
            meta = chunk_group['meta'][:]

            # NOTE(review): spec is presumably a float dB array; if so,
            # Normalize(0.5, 0.5) is not applied to values in [0, 1] -
            # confirm the intended input range.
            spec = transform(Image.fromarray(spec))

            # Column 0: time, quantised into bins of 100 units.
            midi[:, 0] //= 100
            # Column 2: velocity, reduced to 0 (zero) or 1 (positive).
            midi[:, 2] = (midi[:, 2] > 0).astype(np.uint8)

            midi = [bos_id, *tokenizer.tuple_list_to_ids(midi.tolist()), eos_id]
            # Pads up to 1100 tokens; a longer sequence is left as is because
            # a negative repeat count produces an empty list.
            midi += [pad_id] * (1100 - len(midi))

        progress_bar.update()

