# Note Placement Model Test

In [19]:
from sys import platform
from pathlib import Path
import torch
import torch.nn as nn
import torchaudio
from tqdm.auto import tqdm

# Print tensors in plain decimal notation rather than scientific notation.
torch.set_printoptions(sci_mode=False)

# True when running on macOS; the audio loader picks its backend from this flag.
is_macos = (platform == 'darwin')

In [30]:
class BeatmapConverter:
    """Convert ``.osu`` beatmaps and their audio into tensors saved as ``.pt`` files.

    For every beatmap found under ``osu_dir`` the converter parses the notes,
    builds a multi-resolution log-mel spectrogram (with context windows) of the
    matching audio file and stores both with ``torch.save`` under ``data_dir``.
    """

    SAMPLE_RATE = 44100      # Hz; every waveform is brought to this rate before the STFT
    BEAT_PHASE_STEPS = 48    # beat phase is quantised to multiples of 1/48 of a beat
    HOLD_NOTE_FLAG = 128     # bit of the hit-object type field that marks a hold (long) note
    AUDIO_SUFFIXES = frozenset({'.mp3', '.wav', '.ogg'})
    PARSE_ERROR = (-1, -1, -1, -1, -1)  # sentinel returned by parse_beatmap for skipped files

    def __init__(self,
                 audio_dir: Path,
                 osu_dir: Path,
                 data_dir: Path,
                 n_fft_list: 'list | None' = None,
                 hop_ms: int = 10,
                 context_window_size=7):
        """Collect the input files and build one mel-spectrogram transform per FFT size.

        Args:
            audio_dir: directory searched recursively for .mp3/.wav/.ogg files.
            osu_dir: directory searched recursively for .osu files.
            data_dir: directory under which converted tensors are written.
            n_fft_list: FFT sizes, one spectrogram channel each
                (defaults to ``[1024, 2048, 4096]``).
            hop_ms: hop between spectrogram frames, in milliseconds.
            context_window_size: number of frames of context kept on each side
                of every frame.
        """
        # A fresh list per instance instead of a shared mutable default.
        if n_fft_list is None:
            n_fft_list = [1024, 2048, 4096]

        # Suffix match is case-insensitive so that e.g. ``song.MP3`` is found too.
        self.audio_fns = sorted(p for p in Path(audio_dir).glob('**/*')
                                if p.suffix.lower() in self.AUDIO_SUFFIXES)
        self.osu_fns = sorted(Path(osu_dir).rglob('*.osu'))
        # Wrapped in Path (like the other two directories) so a plain str works.
        self.data_dir = Path(data_dir)
        self.hop_ms = hop_ms
        # round() rather than int(): truncation would lose a sample whenever the
        # float product lands just below an integer.
        self.hop_length = int(round(self.SAMPLE_RATE * hop_ms / 1000))
        self.melspec_converters = [
            torchaudio.transforms.MelSpectrogram(sample_rate=self.SAMPLE_RATE,
                                                 n_fft=n_fft,
                                                 hop_length=self.hop_length,
                                                 f_max=11000,
                                                 n_mels=80,
                                                 power=2)
            for n_fft in n_fft_list
        ]
        self.context_window_size = context_window_size

    def round_base(self, input, base):
        """Round *input* (int, float or ``torch.Tensor``) to the nearest multiple of *base*.

        Raises:
            NotImplementedError: if *input* is of any other type.
        """
        if isinstance(input, (int, float)):
            return base * round(input / base)
        if isinstance(input, torch.Tensor):
            return (input / base).round() * base
        raise NotImplementedError

    def get_beat_phase(self, note_time, offset: float, beat_length: float):
        """Return the position of *note_time* inside its beat, quantised to 1/48.

        *note_time* may be a number or a tensor of times; *offset* and
        *beat_length* are expressed in the same unit. The result lies in
        ``[0, 1)``: a time that rounds up to the next beat wraps back to 0.
        """
        beat_phase = ((note_time - offset) % beat_length) / beat_length
        # Quantise on an integer grid so results are exact multiples of 1/48
        # (a rounded decimal constant such as 0.02083 drifts off the grid).
        steps = self.round_base(beat_phase * self.BEAT_PHASE_STEPS, 1)
        return (steps % self.BEAT_PHASE_STEPS) / self.BEAT_PHASE_STEPS

    @staticmethod
    def _read_section(lines, header):
        """Return the non-blank, non-comment lines that follow *header* up to the next section."""
        try:
            start = lines.index(header) + 1
        except ValueError:
            return []
        section = []
        for line in lines[start:]:
            if line.startswith('['):
                break
            if line and not line.startswith('//'):
                section.append(line)
        return section

    def parse_beatmap(self, fn):
        """Extract note and timing information from an .osu file.

        The difficulty is read from the second '-'-separated part of the file
        stem: its first digit is the integer part and the remaining digits the
        fraction ('345' -> 3.45), then rounded to a multiple of 0.2.

        Returns:
            ``(num_keys, notes, offset, beat_length, difficulty)`` where
            *notes* is a float32 tensor of shape ``num_notes x 4`` holding
            ``(time, key_number, note_type, beat_phase)`` sorted by time.
            note_type is 1 for a tap, 2 for the start of a hold and 3 for its
            release. *offset* and *beat_length* come from the first timing
            point with a positive beat length.
            ``(-1, -1, -1, -1, -1)`` is returned (after printing the reason)
            when the file cannot be converted.
        """
        fn = Path(fn)
        with open(fn, mode='r', encoding='utf-8') as f:
            raw_content = [line.strip() for line in f.read().splitlines()]

        # Use the stem, not the name: with the name a two-part file such as
        # 'song-345.osu' would feed '3.45.osu' to float().
        try:
            digits = fn.stem.split('-')[1]
            difficulty = self.round_base(float(digits[:1] + '.' + digits[1:]), 0.2)
        except (IndexError, ValueError):
            print(f'Cannot read difficulty from file name {fn.name}: skipping conversion.')
            return self.PARSE_ERROR

        # Timing points with a positive beat length define the tempo; the
        # others (negative values) are ignored here.
        tempo_points = []
        for tp in self._read_section(raw_content, '[TimingPoints]'):
            fields = tp.split(',')
            point_time, point_beat_length = float(fields[0]), float(fields[1])
            if point_beat_length > 0:
                tempo_points.append((point_time, point_beat_length))

        if not tempo_points:
            print(f'No timing information in file {fn.name}: skipping conversion.')
            return self.PARSE_ERROR
        if len({beat_len for _, beat_len in tempo_points}) > 1:
            print(f'Multiple BPMs in file {fn.name}: skipping conversion.')
            return self.PARSE_ERROR
        offset, beat_length = tempo_points[0]

        obj_list = []
        for obj in self._read_section(raw_content, '[HitObjects]'):
            fields = obj.split(',')
            xpos, time = int(fields[0]), int(fields[2])
            # The type field is a bitfield, so test the hold bit instead of
            # comparing with '1' (a tap that also starts a new combo has type 5).
            if int(fields[3]) & self.HOLD_NOTE_FLAG:
                end_time = int(fields[5].split(':', 1)[0])
                obj_list.append([time, xpos, 2, self.get_beat_phase(time, offset, beat_length)])
                # The release gets the phase of its own time, not of the start.
                obj_list.append([end_time, xpos, 3, self.get_beat_phase(end_time, offset, beat_length)])
            else:
                obj_list.append([time, xpos, 1, self.get_beat_phase(time, offset, beat_length)])

        if not obj_list:
            print(f'No hit objects in file {fn.name}: skipping conversion.')
            return self.PARSE_ERROR

        # Key numbers are the ranks of the distinct x positions used by the map.
        xpos2num = {xpos: num for num, xpos in enumerate(sorted({obj[1] for obj in obj_list}))}

        # list.sort is stable, so notes sharing a time keep their file order.
        obj_list.sort(key=lambda obj: obj[0])
        obj_tensor = torch.tensor([[obj[0], xpos2num[obj[1]], obj[2], obj[3]] for obj in obj_list],
                                  dtype=torch.float32)

        return len(xpos2num), obj_tensor, offset, beat_length, difficulty

    def _frame_beat_phase(self, num_frames, offset, beat_length):
        """Return the quantised beat phase of each of *num_frames* spectrogram frames."""
        # Real spacing between frames in ms (equals hop_ms whenever hop_ms maps
        # to a whole number of samples).
        frame_ms = 1000 * self.hop_length / self.SAMPLE_RATE
        return self.get_beat_phase(torch.arange(num_frames) * frame_ms, offset, beat_length)

    def convert_audio(self, y, offset, beat_length):
        """Convert a waveform into log-mel spectrograms with context windows.

        Args:
            y: waveform sampled at ``SAMPLE_RATE`` (``convert`` passes a 1-D
                mono tensor).
            offset: time of the first beat, in ms.
            beat_length: duration of one beat, in ms.

        Returns:
            ``(windows, beat_phase)``. *windows* has shape
            ``[num_frames, 2 * context_window_size + 1, n_mels, num_fft_sizes]``;
            along the second axis index 0 is the oldest frame and the last
            index the newest, with positions outside the audio filled with the
            smallest log-mel value. *beat_phase* holds one value per frame.
        """
        # Multiple-timescale STFT, one channel per FFT size.
        specs = []
        for converter in self.melspec_converters:
            melspec = converter(y)
            # Clamp before the log: an exactly-zero bin (digital silence) would
            # give -inf and poison both the data and the padding value.
            floor = torch.finfo(melspec.dtype).tiny
            specs.append(torch.log(melspec.clamp(min=floor).T))
        specs = torch.stack(specs, dim=-1)  # frames X mels X fft sizes
        min_value = specs.min().item()

        # Pad both ends with the minimum, then stack the shifted views.
        window = self.context_window_size
        num_frames = specs.shape[0]
        padding = specs.new_full((window, *specs.shape[1:]), min_value)
        padded = torch.cat([padding, specs, padding], dim=0)
        shifted_specs = torch.stack(
            [padded[shift:shift + num_frames] for shift in range(2 * window + 1)],
            dim=1)

        beat_phase = self._frame_beat_phase(num_frames, offset, beat_length)

        return shifted_specs, beat_phase

    def find_file_by_stem(self, fn_list, stem):
        """Return the first path in *fn_list* whose stem equals *stem*, or -1 if none does."""
        return next((fn for fn in fn_list if fn.stem == stem), -1)

    def convert(self):
        """Convert every beatmap and write one ``.pt`` file per beatmap.

        Converted audio is cached in ``data_dir/converted_audio`` so that
        beatmaps sharing an audio file reuse the spectrogram. Each beatmap is
        saved to ``data_dir/<num_keys>keys/`` together with its audio tensor,
        beat phase and difficulty. Beatmaps that cannot be parsed, or whose
        audio file is missing, are skipped.
        """
        converted_audio_dir = self.data_dir / 'converted_audio/'
        converted_audio_dir.mkdir(parents=True, exist_ok=True)
        converted_dirs = set()

        for osu_fn in tqdm(self.osu_fns):
            # Parse beatmap notes
            num_keys, beatmap, offset, beat_length, difficulty = self.parse_beatmap(osu_fn)
            if num_keys == -1:
                continue

            converted_dir = self.data_dir / f'{num_keys}keys/'
            if converted_dir not in converted_dirs:
                converted_dir.mkdir(parents=True, exist_ok=True)
                converted_dirs.add(converted_dir)

            # The audio belonging to a beatmap shares the first '-'-separated
            # part of its stem; the cache file is named after that stem.
            audio_stem = osu_fn.stem.split('-')[0]
            converted_audio_fn = converted_audio_dir / f'{audio_stem}.pt'

            if converted_audio_fn.is_file():
                converted_audio = torch.load(converted_audio_fn)['audio']
                # The cached beat phase was computed with the timing of the
                # beatmap that created the cache; recompute it for this one.
                beat_phase = self._frame_beat_phase(len(converted_audio), offset, beat_length)
            else:
                audio_fn = self.find_file_by_stem(self.audio_fns, audio_stem)
                if audio_fn == -1:
                    print(f'Audio file not found: {audio_stem}')
                    continue

                # Load audio with OS-specific backend
                if is_macos:
                    y, sr = torchaudio.load(audio_fn, backend='ffmpeg')
                else:
                    y, sr = torchaudio.load(audio_fn)

                # Down-mix to mono.
                y = y.mean(dim=0)
                if sr != self.SAMPLE_RATE:
                    print(f'Sampling rate of file {audio_fn.name} is {sr}: Resampling to 44100.')
                    y = torchaudio.functional.resample(y, sr, self.SAMPLE_RATE)

                converted_audio, beat_phase = self.convert_audio(y, offset, beat_length)
                torch.save({'audio': converted_audio, 'beat_phase': beat_phase}, converted_audio_fn)

            torch.save({'audio': converted_audio, 'beat_phase': beat_phase, 'beatmap': beatmap, 'difficulty': difficulty}, (converted_dir / osu_fn.name).with_suffix('.pt'))

# TODO: define dataset

In [31]:
# Beatmaps and their audio are read from the same source folder; the
# converted tensors are written under the dataset root.
data_dir = Path('../osu_dataset/')
osu_dir = Path('../osu_dataset/original/')
audio_dir = Path('../osu_dataset/original/')

converter = BeatmapConverter(audio_dir=audio_dir, osu_dir=osu_dir, data_dir=data_dir)
converter.convert()

  0%|          | 0/261 [00:00<?, ?it/s]

tensor([[[[-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          ...,
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583]],

         [[-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          ...,
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583]],

         [[-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          ...,
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583],
          [-37.8583, -37.8583, -37.8583]],

         ...,

         [[-14.9272, -11.8856,  -9.8394],
          [-14.0139, -12.9589,  -8.8404],
          [-15.2072, -14.5076,  -8.2949],
          ...,
          [-15.0016, -13.5348, -10.95