# DataSet Creation

## encoder Test

In [1]:
import zipfile
import os
import json
import librosa
import numpy as np

In [2]:
from audiodiffusion.audio_encoder import AudioEncoder

audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")

In [3]:
!ls

1701171760.ogg                       [1m[36mbeatmaps[m[m
1701776382.mc                        [1m[36mbeatmaps_unzip[m[m
1701849827.mc                        c5234ac9383bb69ebe8bd619a5fed4b9.mcz
1701864148.mc                        tmp.wav
1_dataset_creation.ipynb


In [4]:
import soundfile

x , sr = librosa.load("1701171760.ogg", sr=20000)
soundfile.write("tmp.wav", x, sr)

In [5]:
def get_audio_features(audio_file, bpm, position, offset):
    x , sr = librosa.load(audio_file, sr=20000)
    one_beat = 60 / bpm
    beat = position * one_beat / 4 + offset/1000
    
    start = beat - one_beat / 8
    end = beat + one_beat / 8
    
    if start < 0:
        start = 0
    
    start_index = int(sr * start)
    end_index = int(sr * end)
    
    
    soundfile.write("tmp.wav", x[start_index:end_index], sr)
    
    
    return audio_encoder.encode(["tmp.wav"]).numpy()[0].tolist()

In [6]:
get_audio_features("1701171760.ogg", 200, 300, 150)

[22.706905364990234,
 115.16878509521484,
 -96.89672088623047,
 96.65142059326172,
 -107.2148666381836,
 10.837530136108398,
 -1.0614581108093262,
 46.774322509765625,
 -108.610595703125,
 10.865479469299316,
 26.016265869140625,
 81.38760375976562,
 -18.64441680908203,
 159.84898376464844,
 53.63759994506836,
 128.12562561035156,
 50.83043670654297,
 24.8842716217041,
 -71.442138671875,
 -78.92030334472656,
 -4.975412845611572,
 12.416231155395508,
 -6.519041538238525,
 168.9194793701172,
 -144.8435516357422,
 289.6753845214844,
 -137.9021759033203,
 156.244384765625,
 172.641845703125,
 74.20262145996094,
 33.09909439086914,
 9.810715675354004,
 11.081274032592773,
 112.89115142822266,
 3.034332275390625,
 33.815467834472656,
 292.14361572265625,
 -125.14810180664062,
 -23.307891845703125,
 159.2989959716797,
 50.04045867919922,
 -81.59968566894531,
 64.35515594482422,
 -52.457332611083984,
 6.11663818359375,
 20.719179153442383,
 110.8714370727539,
 -144.26754760742188,
 196.5712738

## beatmaps filter

In [7]:
mcz_files = os.listdir("beatmaps")

In [8]:
count = 0
for mcz_file in mcz_files:
    if ".mcz" not in mcz_file:
        continue
    print(count, mcz_file)
    zFile = zipfile.ZipFile("beatmaps/" + mcz_file, "r")
    audio_file = ""
    mc_file = ""
    mc_data = {}
    for fileM in zFile.namelist():
        zFile.extract(fileM, './beatmaps_unzip')
    count += 1

0 miles away.mcz
1 fb98dd5bc1967db702ece05a84168d10.mcz
2 c5234ac9383bb69ebe8bd619a5fed4b9.mcz
3 初无改.mcz
4 Lose Control.mcz
5 4c9ddc941b482828c06751c6130fb394.mcz
6 NO.4 究极158秒.mcz
7 Lyrith -迷宮リリス-.mcz
8 Liv1ng 1n The F4st Lan3.mcz
9 efb3f1c6f7ab736db4f01e136b3f074a.mcz
10 d86c199af5abdaa72d2b0ae7a2043395.mcz
11 Cosmos Part.1.mcz
12 1de08fd5a808e09089530d115bb80ee9.mcz
13 Oshama Scramble! .mcz
14 ad1a2cf5b7c7da7039403faf80193bb0.mcz
15 7cd0467ce242192a6dd46b231f717799.mcz
16 望影の方舟Six.mcz
17 d7674e4e71b81539b2febfdd00dc5ba1.mcz
18 巴别塔.mcz
19 Ops：Limone.mcz
20 Loli Bomb (Speed up&Cut ver.).mcz
21 49cb1143d185f28e11e26ed546b03152.mcz
22 Cross the Edge.mcz
23 79eed254a0a09e48e436597d557910b6.mcz
24 The Multiverse(ft.oa).mcz


In [18]:
def check_dance3_beatmap(json_data):
    # 解析JSON数据
    data = json.loads(json_data)
    
    # 获取note列表
    notes = data['note']
    # 初始化最大column值
    max_column = -1
    
    # 遍历note，找到最大的column值
    for note in notes:
        if 'column' in note:
            if note['column'] > max_column:
                max_column = note['column']
    # 判断最大column值是否为6
    if max_column == 5:
        return True
    else:
        return False

In [19]:
mc_list = []
for fname in os.listdir("beatmaps_unzip"):
    if fname.endswith(".mc"):
        print(fname)
        with open(f"beatmaps_unzip/{fname}") as f:
            if check_dance3_beatmap(f.read()):
                mc_list.append(f"beatmaps_unzip/{fname}")

1701864148.mc
1652750857.mc
1696635166.mc
1695401434.mc
1663213760.mc
1664454762.mc
1652754796.mc
1688283329.mc
Σ╕¡τ║º.mc
1697989502.mc
1692538137.mc
1701776382.mc
1659541622.mc
1701849827.mc
1696637398.mc
1697869040.mc
1695401180.mc
Θ½ÿτ║º.mc
σê¥τ║º.mc
1703703071.mc
1688309494.mc
1700696948.mc
1669440136.mc
1688308799.mc
1669621668.mc
1694879415.mc
1669944775.mc
1696483497.mc
1648042355.mc
1698898550.mc
1669439231.mc
1694955584.mc
1692883511.mc
1703913735.mc


In [21]:
for fname in os.listdir("beatmaps_unzip/0"):
    if fname.endswith(".mc"):
        print(fname)
        with open(f"beatmaps_unzip/0/{fname}") as f:
            if check_dance3_beatmap(f.read()):
                mc_list.append(f"beatmaps_unzip/0/{fname}")

1644412671.mc
1587962568.mc
1689340595.mc
1659720056.mc
1588339363.mc
1689746687.mc
1658244850.mc
Mujinku-Vacuum Track#ADD8E6- (Dance Cube Hard Lv.18).mc
1587884900.mc
1667803736.mc


In [22]:
mc_list

['beatmaps_unzip/1701864148.mc',
 'beatmaps_unzip/1652750857.mc',
 'beatmaps_unzip/1696635166.mc',
 'beatmaps_unzip/1695401434.mc',
 'beatmaps_unzip/1663213760.mc',
 'beatmaps_unzip/1664454762.mc',
 'beatmaps_unzip/1652754796.mc',
 'beatmaps_unzip/1688283329.mc',
 'beatmaps_unzip/Σ╕¡τ║º.mc',
 'beatmaps_unzip/1697989502.mc',
 'beatmaps_unzip/1692538137.mc',
 'beatmaps_unzip/1701776382.mc',
 'beatmaps_unzip/1659541622.mc',
 'beatmaps_unzip/1701849827.mc',
 'beatmaps_unzip/1696637398.mc',
 'beatmaps_unzip/1697869040.mc',
 'beatmaps_unzip/1695401180.mc',
 'beatmaps_unzip/Θ½ÿτ║º.mc',
 'beatmaps_unzip/σê¥τ║º.mc',
 'beatmaps_unzip/1703703071.mc',
 'beatmaps_unzip/1688309494.mc',
 'beatmaps_unzip/1700696948.mc',
 'beatmaps_unzip/1688308799.mc',
 'beatmaps_unzip/1669621668.mc',
 'beatmaps_unzip/1694879415.mc',
 'beatmaps_unzip/1669944775.mc',
 'beatmaps_unzip/1696483497.mc',
 'beatmaps_unzip/1648042355.mc',
 'beatmaps_unzip/1698898550.mc',
 'beatmaps_unzip/1669439231.mc',
 'beatmaps_unzip/16949

In [23]:
len(mc_list)

43

In [None]:
# 将
def get_one_song_

In [4]:
def hcf(x, y):
   """该函数返回两个数的最大公约数"""
 
   # 获取最小值
   if x > y:
       smaller = y
   else:
       smaller = x
   _hcf = 1
   for i in range(1,smaller + 1):
       if((x % i == 0) and (y % i == 0)):
           _hcf = i
 
   return _hcf

In [6]:
def get_columns_list(notes):
    columns_list = []
    columns = {
        0: {},
        1: {},
        2: {},
        3: {},
        4: {},
        5: {},
    }

    for note in notes:
        if 'column' in note:
            beat = note['beat'][0]
            sub_beat = note['beat'][1]
            split_count = note['beat'][2]
            if split_count == 8:
                if (len(columns[0]) != 0) and (len(columns[1]) != 0) \
                    and (len(columns[2]) != 0)and (len(columns[3]) != 0):
                    columns_list.append(columns)
                    columns = {0: {}, 1: {}, 2: {}, 3: {}}
                continue
            if split_count != 4:
                if sub_beat == 0:
                    split_count = 4
                else:
                    _hcf = hcf(sub_beat, split_count)
                    sub_beat = int(sub_beat / _hcf)
                    split_count = int(split_count / _hcf)
                if split_count == 2:
                    sub_beat *= 2
                    split_count *= 2
                if split_count == 1:
                    sub_beat *= 4
                    split_count *= 4
                elif split_count != 4:
                    if (len(columns[0]) != 0) and (len(columns[1]) != 0) \
                        and (len(columns[2]) != 0)and (len(columns[3]) != 0):
                        columns_list.append(columns)
                        columns = {0: {}, 1: {}, 2: {}, 3: {}}
                    continue

            position = beat * 4 + sub_beat
            which_col = note["column"]
            if "endbeat" in note:
                end_position = note["endbeat"][0] * 4 + int(note["endbeat"][1] / note["endbeat"][2] * 4)
                if end_position == position:
                    columns[which_col][position] = 1
                else:
                    for i in range(position, end_position+1):
                        columns[which_col][i] = 2
            else:
                columns[which_col][position] = 1
    return columns_list

In [7]:
def get_columns_min_max(columns):
    _min = 10000000000
    _max = 0
    for col in columns.keys():
        column = columns[col]
        if max(column.keys()) > _max:
            _max = max(column.keys())

        if min(column.keys()) < _min:
            _min = min(column.keys())
    return _min, _max
    

In [8]:
def get_one_data(start, end, columns, bpm, x_, sr, offset):
    # 判断是否有beat
    x0 = []
    y0 = []
    
    # 判断note的键型
    x1 = []
    y1 = []
    
    
     # 判断是否有long_note
    x2 = []
    y2 = []
    
    # 判断long_note的键型
    x3 = []
    y3 = []
    for i in range(start, end):
        audio_features = get_audio_features(x_, sr, bpm, i, offset)
        x0.append(audio_features)
        x2.append(audio_features)
        beat_count = 0
        has_beat = False
        has_ln = False
        long_note_count = 0
        # column 0
        if i in columns[0]:
            if columns[0][i] == 1:
                has_beat = True
                beat_count += 1
            else:
                has_ln = True
                long_note_count += 1
            
        
        # column 1
        if i in columns[1]:
            if columns[1][i] == 1:
                has_beat = True
                beat_count += 2
            else:
                has_ln = True
                long_note_count += 2
            
        # column 2
        if i in columns[2]:
            if columns[2][i] == 1:
                has_beat = True
                beat_count += 2*2 
            else:
                has_ln = True
                long_note_count += 2*2
        
        # column 3
        if i in columns[3]:
            if columns[3][i] == 1:
                has_beat = True
                beat_count += 2*2*2
            else:
                has_ln = True
                long_note_count += 2*2*2
        
        y0.append(int(has_beat))
        
        if has_beat:
            x1.append(audio_features)
            y1.append(beat_count)
            
        y2.append(int(has_ln))
        
        if has_ln:
            x3.append(audio_features)
            y3.append(long_note_count)
        
    return x0, y0, x1, y1, x2, y2, x3, y3

In [14]:
count = 0
X0 = []
Y0 = []
X1 = []
Y1 = []
X2 = []
Y2 = []
X3 = []
Y3 = []

x_ = []
sr = 0
for mcz_file in mcz_files:
    if ".mcz" not in mcz_file:
        continue
    print(count, mcz_file)
    zFile = zipfile.ZipFile("original_beatmaps/" + mcz_file, "r")
    audio_file = ""
    mc_file = ""
    mc_data = {}
    for fileM in zFile.namelist():
        zFile.extract(fileM, './')
        if ".mc" in fileM:
            mc_file = fileM
            data = zFile.read(fileM).decode("utf-8")
            mc_data = json.loads(data)
            print("\t", mc_data["meta"]["version"], mc_data["time"], mc_data["note"][-1], "\n")
        elif ".ogg" in fileM:
            audio_file = fileM
        elif ".mp3" in fileM:
            audio_file = fileM
    # 解析的代码并不适应于变速歌曲
    if len(mc_data["time"]) != 1:
        print("此谱面有变速，暂不支持。\n\n")
        continue
    notes = mc_data["note"]
    notes = notes[:len(notes)-1]
    columns_list = get_columns_list(notes)
    bpm = mc_data["time"][0]['bpm']
    if "offset" not in mc_data["note"][-1]:
        offset = 0 
    else:
        offset = mc_data["note"][-1]["offset"]
    
    print(audio_file, bpm, offset)
    
    x_, sr = load_audio(audio_file)
    
    print(len(x_), sr, "\n")
    no_ln_count = 0
    for columns in columns_list:
        _min, _max = get_columns_min_max(columns)
        if (_max - _min) > 40:
            _now = _min
#             print(_now)
            while (_now + 40) < _max:
                x0, y0, x1, y1, x2, y2, x3, y3,  = get_one_data(_now, _now+40, columns, bpm, x_, sr, offset)
                X0.append(x0)
                Y0.append(y0)
                if len(y1) >= 1:    
                    X1.append(x1)
                    Y1.append(y1)
                
                
                if len(y3) > 0:
                    X2.append(x2)
                    Y2.append(y2)
                    X3.append(x3)
                    Y3.append(y3)
                elif (len(y1) >= 1) and (no_ln_count < 15):
                    X2.append(x2)
                    Y2.append(y2)
                    no_ln_count += 1
                
                _now += 38
            
#     break
    count += 1

0 怒槌.mcz
	 4K    Lv24 [{'beat': [0, 0, 1], 'bpm': 200.0}] {'beat': [0, 0, 1], 'sound': '怒槌_.ogg', 'vol': 100, 'offset': 280, 'type': 1} 

0/µÇÆµºî_.ogg 200.0 280
2999903 20000 

1 Halloween Party.mcz
	 4K Jack-o'-lantern Lv.24 [{'beat': [0, 0, 1], 'bpm': 155.0}] {'beat': [0, 0, 1], 'sound': '1507024911.ogg', 'vol': 100, 'offset': 360, 'type': 1} 

0/1507024911.ogg 155.0 360
2303333 20000 

2 Prayer.mcz
	 4K EXHAUST Lv.26 [{'beat': [0, 0, 1], 'bpm': 144.0}] {'beat': [0, 0, 1], 'sound': '1499579976.ogg', 'vol': 100, 'offset': 220, 'type': 1} 

0/1499579976.ogg 144.0 220
2290736 20000 

3 PUPA(1).mcz
	 4K Advanced Lv.22 [{'beat': [0, 0, 1], 'bpm': 202.0}] {'beat': [0, 0, 1], 'sound': '1488631483.ogg', 'vol': 100, 'offset': 1753, 'type': 1} 

0/1488631483.ogg 202.0 1753
2530772 20000 

4 Justified.mcz
	 4K Another Lv.25 [{'beat': [0, 0, 1], 'bpm': 185.0}] {'beat': [0, 0, 1], 'sound': '1528992388.ogg', 'vol': 100, 'offset': 257, 'type': 1} 

0/1528992388.ogg 185.0 257
2650239 20000 

5 snow

In [15]:
print(len(X0[0][0]))

64


In [16]:
print(len(X0))

419


In [17]:
print(len(X1))

414


In [18]:
print(len(X2))

363


In [19]:
print(len(X3))

198


In [20]:
with open("dataset.json", "w") as f:
    json.dump({
        "X0": X0,
        "Y0": Y0,
        "X1": X1,
        "Y1": Y1,
        "X2": X2,
        "Y2": Y2,
        "X3": X3,
        "Y3": Y3,
    }, f)

In [21]:
with open("glove/malody.txt", "w") as f:
    for y1 in Y1:
        strs = [str(i) for i in y1]
        line = " ".join(strs)
        print(line)
        f.write(line + "\n")

3 12 3 8 4 2 1 9 2 4 2 12 3 12 8 4 1 2 4 2 3 12 3
3 1 2 8 15 15 15 15 15 15 15 2 4 8 4 2 4 2 4 3
3 4 2 4 15 6 1 8 1 8 1 5 10 4 2 6 11 4 2 4 2 4 3 8 9 2 4 8
4 8 1 2 4 8 1 4 8 1 4 2 8 4 6 6 9 6 9 4 10 8 2
2 2 8 14 3 12 6 8 1 8 2 3 4 2 1 15 3 12 8 5 8 1 2 6
6 2 1 2 2 1 1 2 1 1 4 4 15 3 12 3
3 9 2 8 4 9 3 4 2 6 2 4 2 1 10 4 1 8 2 1 8 4 2 8 2 4 3 4 2 8
2 8 2 4 8 4 1 2 4 8 4 2 1 2 4 8 4 10 1 2 4 8 2 4 2 4 9 2 4 2 6 6 13
13 11 2 4 2 8 1 8 2 4 8 4 3 8 2 4 8 1 2 4 2 1 8 4 2 2 5 4
4 2 1 8 4 2 4 2 1 2 8 4 1 6 9 4 9 6 15 15 2 8 1 2 4 2 4 2 4
2 4 2 4 2 5 5 2 9
1 2 4 8
4 4 9
2 8 6 6 6 3 12 3 9 2 4 1 2 8 5 2 4 2
2 4 2 8 4 8 7 3 10 5 6 4 2 4 2 4 2 8 5 2 5 2 1
8 2 4 4 4 2 2 2 12 4 9 6 4 4 1
4 1 4 9 2 9 4 9 2 8 4 2 1 8 4 2 1 8 4 2 1 8 4 2 1 4 15 1 1
1 2 4 12 15 4 4 4 2 4 2 2 2 2 4 2 2 5 2 15
15 4 9 4 8 1 1 2 4 8 3 8 3 8 4 6 4 2
2 4 4 2 4 2 2 2 4 2 4 2 4 4 3 12 3
1 8 8 9 6 9 1 4 2 8 2 4 1 4 1 4 2 1 2 4 8 4 2 1 2 4 8 4 2 1
2 1 8 2 1 4 8 4 2 1 2 4 8 4 2 1 9 2 4 2 6 2 4 2 12 3 1 3 2 4
2 4 2 1 2 4 8 3 12 12 

In [22]:
with open("glove/malody2.txt", "w") as f:
    for y3 in Y3:
        strs = [str(i) for i in y3]
        if len(strs) > 0:
            line = " ".join(strs)
            print(line)
            f.write(line + "\n")

1 1 1 1 1
9 9 9 9 9 1 1 1 1 8 8
8 8 8 8
8 8 8 8 8 8 8 8 4 4 4 2 2 2 2 2 2
6 6 6 6
10 10 10 10 10 10 15 5 5 5 5 5 5 5 15 10 10 10 10 10 10 10 10 10 10 10 10 9 9
9 9 9 9 9 9 15 6 6 6 6 6 6 6 14 8 8 8 8 8 8 8 8 8 8 8 8 10 10 10 10 10 10 15 5
15 5 5 5 5 5 5 5 15 10 10 10 10 10 10 10 15 5 5 5 1 5 5 5 5 5 5 15 10 10 10 10 10 10 10
10 10 15 5 5 5 5 5 5 5 5
8 12 6 3
6 3 1 9 9 9 9 9 9 9 9 9 9 9 9 9 2 2 2 2 6 4 4 4 4 9 9 9 9 9
9
9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9
8 12 6 3 1 1 3 6 12 8 8 8 8 8
8 8 8 8 8 8 8 8 8 8 8 8 8 1 1 1 1 1 1 1 1 1 1 1 1 1 1 3 6 12
6 12 8 8 12 6 3 1
9 9
9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 15 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 7 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 9 8 8 8 8 8 8 8 8 8 8 8 8 8 8 8 10 2 2 1 1 1 2 2
2 2 2 8 8 8 4 4 4 4 4 4 2 2 2 2 2 6 4 4 4 4 4 8 8 8 8
8 8 8 8 8 8 12 12 12 12 12 12 12 12 14 14 14 14 14 14 14 14 15 15 15 15 15 15 15 15 15 1 1
1 1 1 1 1 1 1 1 3 3 3 3 3 3 3 3 7 6 6 6 6 6 6 6 14 8 8 8 8 8 8 8 12 12 12 12 12 4 4 4
4 4 6 6 6 6 6 6 6 6 15 9 9 9 9 9 9 9 