# Note Placement Model Test

In [10]:
from sys import platform
from pathlib import Path
import torch
import torch.nn as nn
import torchaudio
from time import localtime, strftime

# True on macOS; BeatmapConverter then loads audio with torchaudio's 'ffmpeg' backend.
is_macos = platform == 'darwin'
# Print tensors in fixed-point rather than scientific notation.
torch.set_printoptions(sci_mode=False)
# Seed for the torch.Generator used by the train/validation split.
RANDOM_SEED = 1

In [126]:
class BeatmapConverter:
    """Convert osu!mania beatmaps and their audio into ``.pt`` training samples.

    Every ``*.osu`` file in ``osu_path`` is paired with the audio file in
    ``audio_path`` (searched recursively) whose stem equals the first
    ``-``-separated field of the beatmap's file stem. The audio becomes a
    context-windowed, multi-resolution log-mel spectrogram and the beatmap a
    note tensor; both are saved per beatmap under
    ``output_path / '<num_keys>keys/'``.
    """

    SAMPLE_RATE = 44100  # audio is resampled to, and analysed at, 44.1 kHz

    def __init__(self,
                 audio_path: Path,
                 osu_path: Path,
                 output_path: Path,
                 n_fft_list=(1024, 2048, 4096),
                 hop_ms: int = 10,
                 context_window_size=7,
                 beat_division=48):
        """
        Args:
            audio_path: directory searched recursively for .mp3/.wav/.ogg files.
            osu_path: directory containing the ``*.osu`` files (not recursive).
            output_path: directory receiving converted audio, samples and logs.
            n_fft_list: FFT sizes; one mel-spectrogram channel per size.
            hop_ms: spectrogram hop in milliseconds.
            context_window_size: frames of context on each side of a frame.
            beat_division: number of grid steps per beat for beat phases.
        """
        self.audio_path = audio_path
        self.osu_path = osu_path
        self.output_path = output_path
        self.hop_ms = hop_ms
        # Samples per spectrogram frame (441 for the default 10 ms).
        self.hop_length = int(self.SAMPLE_RATE * (hop_ms / 1000))
        # Same frame grid at several time/frequency resolutions; the results
        # are stacked as channels in convert_audio.
        self.melspec_converters = [
            torchaudio.transforms.MelSpectrogram(sample_rate=self.SAMPLE_RATE,
                                                 n_fft=n_fft,
                                                 hop_length=self.hop_length,
                                                 f_max=11000,
                                                 n_mels=80,
                                                 power=2)
            for n_fft in n_fft_list
        ]
        self.context_window_size = context_window_size
        # NOTE(review): rounding 1/48 to 5 decimals gives 0.02083, so the top of
        # the grid is 48 * 0.02083 = 0.99984 instead of wrapping to 0.0 (the
        # same musical position). Kept so already-converted data stays valid.
        self.beat_division_length = round(1 / beat_division, 5)

    def round_base(self, input, base):
        """Round ``input`` to the nearest multiple of ``base``.

        Accepts a Python int/float (returns a Python number) or a tensor
        (rounded element-wise).

        Raises:
            NotImplementedError: for any other input type.
        """
        if isinstance(input, (float, int)):
            return base * round(input / base)
        if isinstance(input, torch.Tensor):
            return (input / base).round() * base
        raise NotImplementedError

    def get_beat_phase(self, note_time, offset: float, beat_length: float):
        """Return the fractional position of ``note_time`` within its beat.

        ``note_time`` (number or tensor), ``offset`` and ``beat_length`` share
        one time unit (milliseconds in this class). The phase
        ``((t - offset) mod beat_length) / beat_length`` is snapped to a
        multiple of ``beat_division_length``.
        """
        beat_phase = ((note_time - offset) % beat_length) / beat_length
        return self.round_base(beat_phase, self.beat_division_length)

    def _frame_beat_phase(self, num_frames, offset, beat_length):
        """Beat phase of each spectrogram frame, taking frame k to be at k * hop_ms ms."""
        frame_times = torch.arange(num_frames) * self.hop_ms
        return self.get_beat_phase(frame_times, offset, beat_length)

    def parse_beatmap(self, fn):
        """Extract note information from an ``.osu`` file.

        Args:
            fn: Path of the ``.osu`` file. The second ``-``-separated field of
                its name is read as the difficulty with an implied decimal
                point after the first digit, rounded to a multiple of 0.2
                (e.g. ``189`` -> ~1.8).

        Returns:
            ``(num_keys, notes, offset, beat_length, difficulty)`` where
            ``notes`` is a float32 tensor of shape ``(num_notes, 4)`` with
            columns ``(time, key_index, note_type, beat_phase)`` sorted by time;
            note_type is 1 for a tap note, 2 for a hold start, 3 for a hold end.
            ``(-1, -1, -1, -1, -1)`` if the map has several distinct positive
            beat lengths (multiple BPMs) or none at all.
        """
        with open(fn, mode='r', encoding='utf-8') as f:
            raw_content = f.read().splitlines()

        difficulty = fn.name.split('-')[1]
        difficulty = self.round_base(float(difficulty[:1] + '.' + difficulty[1:]), 0.2)

        # [TimingPoints] runs until a blank line, the next "[Section]" header,
        # or the end of the file.
        timing_points = []
        i = raw_content.index('[TimingPoints]') + 1
        while i < len(raw_content) and raw_content[i] != '' and raw_content[i][0] != '[':
            timing_points.append(raw_content[i].split(','))
            i += 1

        # Only points with a positive beat length (uninherited points in the
        # .osu format) define the BPM grid.
        uninherited = [tp for tp in timing_points if float(tp[1]) > 0]
        beat_lengths = {float(tp[1]) for tp in uninherited}
        if len(beat_lengths) != 1:
            return -1, -1, -1, -1, -1

        offset = float(uninherited[0][0])
        beat_length = beat_lengths.pop()

        hit_objects = raw_content[raw_content.index('[HitObjects]') + 1:]

        obj_list = []
        xpos_set = set()
        for line in hit_objects:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            fields = line.split(',')
            xpos = int(fields[0])
            time = int(fields[2])
            obj_type = int(fields[3])
            xpos_set.add(xpos)

            # The type field is a bit set: 128 marks an osu!mania hold note.
            # Other flags (e.g. new combo, 4) may accompany a plain note, so
            # test the bit instead of comparing the whole field to '1'.
            if obj_type & 128:
                end_time = int(fields[5].split(':', 1)[0])
                obj_list.append([time, xpos, 2, self.get_beat_phase(time, offset, beat_length)])
                obj_list.append([end_time, xpos, 3, self.get_beat_phase(end_time, offset, beat_length)])
            else:
                obj_list.append([time, xpos, 1, self.get_beat_phase(time, offset, beat_length)])

        # NOTE(review): num_keys is the number of distinct x positions actually
        # used, so a map that never touches a column is under-counted.
        xpos2num = {xpos: num for num, xpos in enumerate(sorted(xpos_set))}
        obj_list = [[t, xpos2num[x], kind, phase] for t, x, kind, phase in obj_list]
        # reshape keeps a (0, 4) shape when there are no hit objects.
        obj_tensor = torch.tensor(obj_list, dtype=torch.float32).reshape(-1, 4)

        # Sort by note time in ascending order.
        obj_tensor = obj_tensor[obj_tensor[:, 0].argsort()]

        return len(xpos2num), obj_tensor, offset, beat_length, difficulty

    def convert_audio(self, y, offset, beat_length, eps=1e-9):
        """Convert a waveform into stacked log-mel context windows.

        Args:
            y: mono waveform at 44.1 kHz; expected 1-D so each spectrogram is a
               2-D (n_mels, frames) tensor for the ``.T`` below.
            offset, beat_length: beat grid (ms) used for the frame beat phases.
            eps: constant added to the waveform before the transform.

        Returns:
            ``(windows, beat_phase)``. ``windows`` has shape
            ``(num_frames, 2 * context_window_size + 1, 80, len(n_fft_list))``;
            slot ``j`` of frame ``t`` holds frame ``t - context_window_size + j``,
            and slots outside the clip hold the clip's minimum log-mel value.
            ``beat_phase`` has shape ``(num_frames,)``.
        """
        specs = []
        for converter in self.melspec_converters:
            melspec = converter(y + eps)
            # An exact zero would give log() == -inf, which would also become
            # min_value and fill every padded slot; clamp to the dtype's
            # smallest normal value so the features stay finite.
            melspec = melspec.clamp_min(torch.finfo(melspec.dtype).tiny)
            specs.append(torch.log(melspec.T))
        specs = torch.stack(specs, dim=-1)  # frames x 80 x num_ffts
        min_value = torch.min(specs)

        shifted_specs = []
        for shift in range(self.context_window_size, -self.context_window_size - 1, -1):
            shifted = torch.roll(specs, shift, 0)
            # torch.roll wraps around; overwrite the wrapped rows with padding.
            if shift > 0:
                shifted[:shift] = min_value
            elif shift < 0:
                shifted[shift:] = min_value
            shifted_specs.append(shifted)
        shifted_specs = torch.stack(shifted_specs, dim=1)  # frames x window x 80 x num_ffts

        beat_phase = self._frame_beat_phase(len(shifted_specs), offset, beat_length)
        return shifted_specs, beat_phase

    def find_file_by_stem(self, fn_list, stem):
        """Return the first path in ``fn_list`` whose ``.stem`` equals ``stem``, or -1."""
        return next((fn for fn in fn_list if fn.stem == stem), -1)

    def _load_waveform(self, audio_fn, log):
        """Load ``audio_fn`` as a mono waveform at SAMPLE_RATE, logging any resampling."""
        # Load audio with OS-specific backend
        if is_macos:
            y, sr = torchaudio.load(audio_fn, backend='ffmpeg')
        else:
            y, sr = torchaudio.load(audio_fn)

        y = y.mean(dim=0)  # average channels into one mono signal
        if sr != self.SAMPLE_RATE:
            log.write(f'Sampling rate of file {audio_fn.name} is {sr}: Resampling to 44100.\n')
            y = torchaudio.functional.resample(y, sr, self.SAMPLE_RATE)
        return y

    def convert(self):
        """Convert every ``*.osu`` file in ``osu_path`` into a training sample.

        For each beatmap:
          1. Parse its notes.
          2. Obtain the log-mel features of its audio. These are converted
             once per audio file and cached in
             ``output_path/converted_audio/<audio_stem>.pt``.
          3. Save ``{'audio', 'beat_phase', 'beatmap', 'difficulty'}`` to
             ``output_path/<num_keys>keys/<osu stem>.pt``.

        Beatmaps rejected by ``parse_beatmap`` are moved, with their audio,
        into ``excluded_osu`` / ``excluded_audio``. Beatmaps without matching
        audio are moved into ``excluded_osu``. Skips and resampling are
        recorded in a timestamped log file in ``output_path``.
        """
        converted_audio_path = self.output_path / 'converted_audio/'
        converted_audio_path.mkdir(exist_ok=True)
        excluded_audio_path = self.audio_path / 'excluded_audio/'
        excluded_osu_path = self.osu_path / 'excluded_osu'
        excluded_audio_path.mkdir(exist_ok=True)
        excluded_osu_path.mkdir(exist_ok=True)

        # Audio is searched recursively, so skip anything an earlier run moved
        # into excluded_audio/ (which lives inside audio_path).
        audio_fns = sorted(p for p in self.audio_path.glob('**/*')
                           if p.suffix in {'.mp3', '.wav', '.ogg'}
                           and excluded_audio_path not in p.parents)
        osu_fns = sorted(self.osu_path.glob('*.osu'))
        num_keys_paths = set()

        log_fn = self.output_path / (strftime('conversion-log-%Y-%m-%d-%H-%M-%S', localtime()) + '.txt')
        with open(log_fn, 'w', encoding='utf-8') as log:
            # osu_fns is never mutated inside the loop: removing the current
            # item from a list being iterated would skip the next beatmap.
            for i, osu_fn in enumerate(osu_fns, start=1):
                print(f'Converting: {osu_fn.name} ({i} out of {len(osu_fns)})', end='\r', flush=True)
                # Beatmaps are named "<audio_stem>-...": that prefix selects the audio.
                audio_stem = osu_fn.stem.split('-')[0]

                num_keys, beatmap, offset, beat_length, difficulty = self.parse_beatmap(osu_fn)
                if num_keys == -1:  # multiple BPMs or no BPM
                    log.write(f'Multiple BPMs or no BPM in file {osu_fn.name}: skipping conversion.\n')
                    osu_fn.rename(excluded_osu_path / osu_fn.name)
                    audio_fn = self.find_file_by_stem(audio_fns, audio_stem)
                    if audio_fn != -1:
                        audio_fn.rename(excluded_audio_path / audio_fn.name)
                        audio_fns.remove(audio_fn)
                    continue

                # Categorize created data samples by number of keys.
                num_keys_path = self.output_path / f'{num_keys}keys/'
                if num_keys_path not in num_keys_paths:
                    num_keys_path.mkdir(exist_ok=True)
                    num_keys_paths.add(num_keys_path)

                # Converted audio is cached as "<audio_stem>.pt".
                cached_audio_fn = converted_audio_path / f'{audio_stem}.pt'
                if cached_audio_fn.exists():
                    converted_audio = torch.load(cached_audio_fn)['audio']
                    # The cached beat phase came from whichever beatmap first
                    # converted this audio; recompute it from this map's timing.
                    beat_phase = self._frame_beat_phase(len(converted_audio), offset, beat_length)
                else:
                    audio_fn = self.find_file_by_stem(audio_fns, audio_stem)
                    if audio_fn == -1:
                        log.write(f'Audio file not found: {audio_stem}\n')
                        osu_fn.rename(excluded_osu_path / osu_fn.name)
                        continue

                    y = self._load_waveform(audio_fn, log)
                    converted_audio, beat_phase = self.convert_audio(y, offset, beat_length)
                    torch.save({'audio': converted_audio, 'beat_phase': beat_phase}, cached_audio_fn)

                torch.save({'audio': converted_audio, 'beat_phase': beat_phase,
                            'beatmap': beatmap, 'difficulty': difficulty},
                           (num_keys_path / osu_fn.name).with_suffix('.pt'))

        # Leading newline so the message doesn't overwrite the '\r' progress line.
        print('\nConversion finished.')

In [127]:
# Raw beatmaps and their audio share one folder; converted samples go one level up.
audio_path = Path('../osu_dataset/original/')
osu_path = Path('../osu_dataset/original/')
output_path = Path('../osu_dataset/')

converter = BeatmapConverter(
    audio_path=audio_path,
    osu_path=osu_path,
    output_path=output_path,
)
converter.convert()

Conversion finished.08-7.osu (191 out of 192)


In [128]:
class OsuDataset:
    """Map-style dataset over the per-beatmap ``.pt`` samples in one directory.

    Each file is expected to hold a dict with the keys ``'audio'``,
    ``'beat_phase'``, ``'beatmap'`` and ``'difficulty'`` (the keys written by
    ``BeatmapConverter.convert``). Files are sorted by path so indices are
    stable across runs.
    """

    def __init__(self, data_path):
        self.fns = sorted(data_path.glob('*.pt'))

    def __len__(self):
        return len(self.fns)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(audio, beat_phase, beatmap, difficulty)``."""
        # Look the values up by key rather than unpacking .values(), so the
        # result does not depend on the order in which the dict was saved.
        sample = torch.load(self.fns[idx])
        return sample['audio'], sample['beat_phase'], sample['beatmap'], sample['difficulty']

In [129]:
# Deterministic 80/20 train/validation split of the 4-key samples.
generator = torch.Generator().manual_seed(RANDOM_SEED)
base_set = OsuDataset(Path('../osu_dataset/4keys/'))
train_set, valid_set = torch.utils.data.random_split(
    base_set,
    [0.8, 0.2],
    generator=generator,
)

(tensor([[[[-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           ...,
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257]],
 
          [[-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           ...,
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257]],
 
          [[-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           ...,
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257],
           [-69.9257, -69.9257, -69.9257]],
 
          ...,
 
          [[-31.6068, -29.1098, -28.4166],
           [-30.6935, -64.4477, -65.1039],
           [-66.0696, -65.5805, -65.5557],
           ...,
       

In [133]:
# Load one converted 4-key sample and display the context window of frame 1000.
obj = torch.load('../osu_dataset/4keys/1019836-189-4.pt')
dummy_sample = obj['audio'][1000, ...]
dummy_sample

tensor([[[ 9.3334, 10.8949, 11.5670],
         [10.2467, 11.9233, 13.5350],
         [ 8.8077, 10.8777, 11.6971],
         ...,
         [-0.4921,  0.8950,  2.3700],
         [ 0.6974,  1.9928,  3.1390],
         [-0.5809,  0.8524,  2.4934]],

        [[ 9.4073, 10.8571, 11.5506],
         [10.3205, 11.9594, 13.5993],
         [ 9.6719, 11.0073, 11.7434],
         ...,
         [-0.5775,  0.7027,  2.0301],
         [ 0.2384,  1.4593,  2.8345],
         [-0.7299,  0.5360,  1.9387]],

        [[ 9.4109, 10.8644, 11.4607],
         [10.3242, 11.9373, 13.5545],
         [ 9.3284, 10.9826, 11.7277],
         ...,
         [-1.2293,  0.2476,  1.5936],
         [-1.3848,  0.5875,  2.2645],
         [-1.5987, -0.0726,  1.3843]],

        ...,

        [[ 5.1333,  6.2458,  8.3451],
         [ 6.0466,  7.5554,  9.7704],
         [ 5.2473,  7.0267,  8.2157],
         ...,
         [-4.3466, -3.0635,  3.5511],
         [-4.3352, -2.8544,  2.4479],
         [-4.3879, -2.9491,  2.5068]],

        [[

In [134]:
# Shape of one frame's context window.
dummy_sample.size()

torch.Size([15, 80, 3])

In [None]:
class AudioEncoder(nn.Module):
    """Audio encoder module (work in progress).

    NOTE(review): ``self.layers`` is an empty ``nn.Sequential`` so far.
    Presumably it will consume context windows shaped like ``dummy_sample``
    above, i.e. (15, 80, 3) = (context frames, mel bins, FFT sizes) — TODO
    confirm once the layers are defined.
    """
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            
        )