# Note Placement Model Test

In [7]:
from sys import platform
from pathlib import Path
import torch
import torch.nn as nn
import torchaudio
# Platform flag: True only when sys.platform reports macOS ('darwin').
is_macos = (platform == 'darwin')

In [9]:
class BeatmapDataset:
    """Pair audio files with osu!mania beatmaps and convert them to tensors.

    Pairing works as follows:
    - Audio files (``.mp3`` / ``.wav`` / ``.ogg``) are found recursively under
      ``audio_dir``.
    - ``.osu`` files are found recursively under ``osu_dir``.
    - Each list is sorted by path, and the two are paired positionally. Both
      directories must therefore hold matching files in the same sort order.

    ``convert_spec`` writes one ``.pt`` file per pair into ``data_dir``.
    """

    def __init__(self, audio_dir, osu_dir, data_dir):
        audio_fns = sorted(
            p for p in Path(audio_dir).rglob('*')
            if p.suffix.lower() in {'.mp3', '.wav', '.ogg'}
        )
        osu_fns = sorted(Path(osu_dir).rglob('*.osu'))
        if len(audio_fns) != len(osu_fns):
            # zip() below silently drops the surplus; make the mismatch visible.
            import warnings
            warnings.warn(
                f'Found {len(audio_fns)} audio file(s) but {len(osu_fns)} .osu file(s); '
                'files without a positional partner are ignored.',
                stacklevel=2,
            )
        self.data_fns = list(zip(audio_fns, osu_fns))
        # Normalise to Path so `data_dir / name` also works when a str is passed.
        self.data_dir = Path(data_dir)

    @staticmethod
    def _read_key_count(lines):
        """Return the osu!mania key count declared before ``[HitObjects]``, or None.

        The key count is only trusted when ``[General]`` declares ``Mode: 3``
        (osu!mania). In that mode, ``CircleSize`` in ``[Difficulty]`` is the
        number of columns.
        """
        header = {}
        for line in lines:
            if line == '[HitObjects]':
                break
            key, sep, value = line.partition(':')
            if sep:
                header[key.strip()] = value.strip()
        if header.get('Mode') != '3':
            return None
        try:
            key_count = int(float(header['CircleSize']))
        except (KeyError, ValueError):
            return None
        return key_count if key_count > 0 else None

    def parse_beatmap(self, fn):
        """Parse the ``[HitObjects]`` section of an osu!mania ``.osu`` file.

        Args:
            fn: path to the ``.osu`` file.

        Returns:
            A ``torch.long`` tensor of shape ``(num_objects, 4)``. Its columns
            are ``[time, column, is_longnote, end_time]``, where ``end_time``
            is 0 for normal notes. The shape is ``(0, 4)`` when the section
            is empty.

        Raises:
            ValueError: if the file has no ``[HitObjects]`` section.
        """

        # TODO: calculate beat phase

        # 'utf-8-sig' also strips a leading byte-order mark if one is present.
        with open(fn, mode='r', encoding='utf-8-sig') as f:
            lines = [line.strip() for line in f.read().splitlines()]

        try:
            start_index = lines.index('[HitObjects]')
        except ValueError:
            raise ValueError(f'{fn}: no [HitObjects] section found') from None

        obj_list = []
        for line in lines[start_index + 1:]:
            if not line or line.startswith('//'):
                continue  # blank lines (e.g. trailing at EOF) and comments
            if line.startswith('['):
                break  # another section follows
            fields = line.split(',')
            xpos = int(fields[0])
            time = int(fields[2])
            # `type` is a bit field: bit 7 (128) marks an osu!mania hold note.
            # Other bits (e.g. 4 = new combo) may also be set, so test the bit
            # rather than comparing the whole value with 1.
            is_longnote = bool(int(fields[3]) & 128)
            # Hold notes store "endTime:hitSample" in their objectParams field.
            end_time = int(fields[5].split(':', 1)[0]) if is_longnote else 0
            obj_list.append([time, xpos, int(is_longnote), end_time])

        if not obj_list:
            return torch.empty((0, 4), dtype=torch.long)

        key_count = self._read_key_count(lines)
        if key_count is not None:
            # osu!mania column formula floor(x * keys / 512), clamped into range.
            # Unlike ranking distinct x values, unused lanes do not shift indices.
            def to_column(x):
                return min(max(x * key_count // 512, 0), key_count - 1)
        else:
            # No trustworthy key count: number the distinct x positions in order.
            xpos2num = {x: n for n, x in enumerate(sorted({obj[1] for obj in obj_list}))}
            to_column = xpos2num.__getitem__

        obj_list = [[t, to_column(x), ln, end] for t, x, ln, end in obj_list]
        return torch.tensor(obj_list, dtype=torch.long)

    def shift_cat_audio(self, y, shift_amount):
        """Unfinished helper: intended to combine time-shifted copies of ``y``.

        NOTE(review): the original body stopped mid-expression
        (``shifted = y.``). That syntax error prevented the whole class from
        being defined. The skeleton looped ``shift`` from ``shift_amount`` down
        to 1 and collected results in a list. That presumably means stacking
        context frames — TODO confirm the intended shift/pad semantics and
        implement.
        """
        raise NotImplementedError('shift_cat_audio is not implemented yet')

    def convert_spec(self, n_fft_list=(1024, 2048, 4096), hop_ms: int = 10, n_mels: int = 80):
        """Convert every (audio, beatmap) pair into a ``.pt`` file in ``data_dir``.

        For each pair:
        - The audio is averaged across channels to mono and resampled to
          44.1 kHz if needed.
        - It is turned into dB-scaled mel spectrograms, one per entry in
          ``n_fft_list``, all sharing a hop of ``hop_ms`` milliseconds.
        - The spectrograms are stacked on a new last axis.
        - They are saved with the parsed beatmap as
          ``{'specs': ..., 'beatmap': ...}`` in ``<osu file stem>.pt``.

        Args:
            n_fft_list: STFT window sizes; one spectrogram is computed per entry.
            hop_ms: hop between frames, in milliseconds.
            n_mels: number of mel bands per spectrogram.
        """

        # TODO: split melspec into 15 sample windows
        # TODO: include beat phase

        target_sr = 44100
        # Multiply before dividing so integer hop_ms values give an exact hop.
        hop_length = int(target_sr * hop_ms / 1000)
        melspec_converters = [
            torchaudio.transforms.MelSpectrogram(
                sample_rate=target_sr,
                n_fft=n_fft,
                hop_length=hop_length,
                f_max=11000,
                n_mels=n_mels,
            )
            for n_fft in n_fft_list
        ]
        db_converter = torchaudio.transforms.AmplitudeToDB()

        out_dir = Path(self.data_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

        for audio_fn, osu_fn in self.data_fns:
            if is_macos:
                y, sr = torchaudio.load(audio_fn, backend='ffmpeg')
            else:
                y, sr = torchaudio.load(audio_fn)
            y = y.mean(dim=0)  # average over the channel axis -> mono

            if sr != target_sr:
                print(f'Sampling rate of file {audio_fn.name} is {sr}: Resampling to 44100.')
                y = torchaudio.functional.resample(y, sr, target_sr)

            # Multiple-timescale STFT. Every converter uses the same hop_length,
            # so the frame counts match and the specs can be stacked.
            specs = torch.stack(
                [db_converter(converter(y)) for converter in melspec_converters],
                dim=-1,
            )

            # Parse beatmap
            print(osu_fn)
            beatmap = self.parse_beatmap(osu_fn)

            torch.save({'specs': specs, 'beatmap': beatmap}, (out_dir / osu_fn.name).with_suffix('.pt'))

    # TODO: getitem

In [11]:
# Convert every (audio, .osu) pair found under data/ into .pt files in data/converted.
dataset = BeatmapDataset(
    audio_dir=Path('data/'),
    osu_dir=Path('data/'),
    data_dir=Path('data/converted'),
)
dataset.convert_spec()

data/er.osu


In [20]:
# Index the saved dict by key rather than unpacking .values() positionally,
# so specs/beatmap cannot be silently swapped if the key order ever changes.
converted = torch.load('data/converted/er.pt')
specs, beatmap = converted['specs'], converted['beatmap']
specs.shape

torch.Size([80, 8778, 3])