# Note Placement Model Test

In [17]:
from sys import platform
from pathlib import Path
import warnings
import torch
import torch.nn as nn
import torchaudio
# True when running on macOS; read later to choose the torchaudio loading backend.
is_macos = platform == 'darwin'

# Print tensors in plain decimal notation instead of scientific notation.
torch.set_printoptions(sci_mode=False)

In [26]:
class BeatmapDataset:
    """Pair audio files with .osu beatmaps and convert them to tensors.

    Audio files and beatmaps are discovered recursively and paired by their
    position in the two sorted file lists, so both directories are expected
    to sort into the same order.
    """

    # Sample rate (Hz) all audio is brought to before the mel transforms.
    SAMPLE_RATE = 44100
    # File suffixes recognised as audio.
    AUDIO_SUFFIXES = {'.mp3', '.wav', '.ogg'}
    # Bit of the hit-object type field that marks an osu!mania hold (long) note.
    HOLD_NOTE_FLAG = 128

    def __init__(self, audio_dir, osu_dir, data_dir):
        """Collect (audio, beatmap) path pairs.

        Args:
            audio_dir: Directory searched recursively for audio files.
            osu_dir: Directory searched recursively for ``*.osu`` files.
            data_dir: Directory that ``convert_spec`` writes ``.pt`` files to.
        """
        audio_fns = sorted(p for p in Path(audio_dir).glob('**/*') if p.suffix in self.AUDIO_SUFFIXES)
        osu_fns = sorted(Path(osu_dir).rglob('*.osu'))

        # zip() truncates silently, which would hide a mis-paired dataset.
        if len(audio_fns) != len(osu_fns):
            warnings.warn(
                f'Found {len(audio_fns)} audio files but {len(osu_fns)} beatmaps: '
                'extra files are ignored and pairs may be misaligned.'
            )

        self.data_fns = list(zip(audio_fns, osu_fns))
        self.data_dir = data_dir

    def get_beat_phase(self, note_time, beat_length:float, offset:int):
        """Return the position of ``note_time`` inside its beat, in [0, 1)."""
        return ((note_time - offset) % beat_length) / beat_length

    @staticmethod
    def _section_lines(lines, header):
        """Return the non-blank lines that follow ``header`` up to the next ``[Section]``.

        Raises:
            ValueError: If ``header`` is not one of ``lines``.
        """
        section = []
        for line in lines[lines.index(header) + 1:]:
            line = line.strip()
            if line.startswith('['):
                break
            if line:
                section.append(line)
        return section

    def parse_beatmap(self, fn):
        """Extract note information from an .osu file.

        Args:
            fn: Path (``str`` or ``Path``) of the .osu file.

        Returns:
            FloatTensor of shape num_notes X 4 holding
            (time, key_number, note_type, beat_phase), sorted by time.
            note_type is 1 for a normal note, 2 for a long-note start and
            3 for a long-note end. Returns -1 when the file has several
            BPMs or no BPM information at all.

        Raises:
            ValueError: If the file lacks a ``[TimingPoints]`` or
                ``[HitObjects]`` section, or a field cannot be parsed.
        """
        fn = Path(fn)  # accept plain strings as well as Path objects

        with open(fn, mode='r', encoding='utf-8') as f:
            raw_content = f.read().splitlines()

        timing_points = self._section_lines(raw_content, '[TimingPoints]')
        hit_objects = self._section_lines(raw_content, '[HitObjects]')

        # Only timing points with a positive beat length define a tempo;
        # the others are skipped, as in the original BPM check.
        tempo_points = []
        for tp in timing_points:
            fields = tp.split(',')
            tp_beat_length = float(fields[1])
            if tp_beat_length > 0:
                # float() tolerates fractional timing-point times.
                tempo_points.append((float(fields[0]), tp_beat_length))

        # Check if multiple BPMs exist
        if len({tp_beat_length for _, tp_beat_length in tempo_points}) > 1:
            warnings.warn(f'Multiple BPMs in file {fn.name}: skipping conversion.')
            return -1
        if not tempo_points:
            warnings.warn(f'No BPM information in file {fn.name}: skipping conversion.')
            return -1

        offset, beat_length = tempo_points[0]

        obj_list = []
        xpos_set = set()

        for obj in hit_objects:
            obj_split = obj.split(',')
            xpos = int(obj_split[0])
            time = int(obj_split[2])
            xpos_set.add(xpos)

            # The type field is a bit field (e.g. 5 = normal note + new combo),
            # so test the hold bit instead of comparing against '1'.
            if int(obj_split[3]) & self.HOLD_NOTE_FLAG:
                end_time = int(obj_split[5].split(':', 1)[0])
                obj_list.append([time, xpos, 2, self.get_beat_phase(time, beat_length, offset)])
                # The release event carries the phase of its own time.
                obj_list.append([end_time, xpos, 3, self.get_beat_phase(end_time, beat_length, offset)])
            else:
                obj_list.append([time, xpos, 1, self.get_beat_phase(time, beat_length, offset)])

        if not obj_list:
            # Keep the documented num_notes X 4 layout for an empty beatmap.
            return torch.empty((0, 4), dtype=torch.float32)

        # Key number = rank of the x position among the distinct positions used.
        xpos2num = {xpos: num for num, xpos in enumerate(sorted(xpos_set))}

        # Sort by note time in ascending order (stable, so ties keep file order).
        obj_list.sort(key=lambda note: note[0])

        return torch.tensor(
            [[time, xpos2num[xpos], note_type, phase] for time, xpos, note_type, phase in obj_list],
            dtype=torch.float32,
        )

    # TODO: shift_cat_audio(y, shift_amount) -- unfinished helper, not implemented yet.

    def convert_spec(self, n_fft_list=(1024, 2048, 4096), hop_ms:int=10, n_mels:int=80):
        """Convert every (audio, beatmap) pair into a saved ``.pt`` file.

        For each pair, the audio is averaged over dim 0 of the loaded
        waveform, resampled to 44100 Hz if needed, and turned into one dB-scaled mel
        spectrogram per FFT size; the spectrograms are stacked along a new
        last axis. The result is saved with the parsed beatmap as
        ``{'specs': ..., 'beatmap': ...}`` under ``data_dir``, named after
        the .osu file. Beatmaps that ``parse_beatmap`` rejects are skipped.

        Args:
            n_fft_list: FFT sizes, one spectrogram per entry.
            hop_ms: Hop length in milliseconds.
            n_mels: Number of mel bands.
        """

        # TODO: split melspec into 15 sample windows
        # TODO: include beat phase

        hop_length = int(self.SAMPLE_RATE*(hop_ms/1000))
        melspec_converters = [
            torchaudio.transforms.MelSpectrogram(
                sample_rate=self.SAMPLE_RATE,
                n_fft=n_fft,
                hop_length=hop_length,
                f_max=11000,
                n_mels=n_mels,
            )
            for n_fft in n_fft_list
        ]
        db_converter = torchaudio.transforms.AmplitudeToDB()

        out_dir = Path(self.data_dir)  # data_dir may have been given as a str
        out_dir.mkdir(parents=True, exist_ok=True)

        for audio_fn, osu_fn in self.data_fns:
            # Parse beatmap first so unusable maps skip the audio work.
            print(osu_fn)
            beatmap = self.parse_beatmap(osu_fn)
            if not torch.is_tensor(beatmap):
                continue

            if is_macos:
                y, sr = torchaudio.load(audio_fn, backend='ffmpeg')
            else:
                y, sr = torchaudio.load(audio_fn)
            y = y.mean(dim=0)

            if sr != self.SAMPLE_RATE:
                print(f'Sampling rate of file {audio_fn.name} is {sr}: Resampling to 44100.')
                y = torchaudio.functional.resample(y, sr, self.SAMPLE_RATE)

            # Multiple-timescale STFT
            specs = torch.stack([db_converter(converter(y)) for converter in melspec_converters], dim=-1)

            torch.save({'specs': specs, 'beatmap': beatmap}, (out_dir / osu_fn.name).with_suffix('.pt'))

    # TODO: getitem

In [27]:
# Load one converted sample; the saved dict's values are unpacked in order as (specs, beatmap).
specs, beatmap = torch.load('data/converted/er.pt').values()
specs.shape

FileNotFoundError: [Errno 2] No such file or directory: 'data/converted/er.pt'

In [29]:
# Smoke test: parse a single beatmap and inspect the shape of the resulting tensor.
dataset = BeatmapDataset('data/', 'data/', 'data/')
dataset.parse_beatmap('data/er.osu').shape

torch.Size([388, 4])