/
create_drumgizmo_kit.py
executable file
·355 lines (309 loc) · 18.5 KB
/
create_drumgizmo_kit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
#!/usr/bin/env python3
#*******************************************************************************
# Copyright (c) 2022-2024
# Author(s): Volker Fischer
#*******************************************************************************
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
#*******************************************************************************
import os
import gc
import wave
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from scipy.io import wavfile
# conversion settings
raspi_optimized_drumkit = False  # True: build a reduced kit for devices with limited RAM
disable_positional_sensing_support = False  # True: instruments with several positions only use position 0
only_master_channels_per_instrument = False  # True: wave files only contain the main channels of the instrument
do_shorten_samples = False  # True: cut every sample after the last value within 40 dB of its peak

################################################################################
# CONFIGURATION AND INITIALIZATIONS ############################################
################################################################################
# example file names of source files: source_samples/snare/snare_rimshot_0_channel1.wav <- position 0, channel 1
#                                     source_samples/snare/snare_rimshot_1_channel1.wav <- position 1, channel 1
#                                     source_samples/ride/ride_bell_channel7.wav <- no positional sensing, channel 7
kit_name = "PearlMMX"  # avoid spaces
kit_description = "Pearl MMX drum set with positional sensing support"
channel_names = ["KDrum", "Snare", "Hihat", "Tom1", "Tom2", "Tom3", "OHLeft", "OHRight"]

# instruments: [instrument_name, master_channel(s), MIDI_note(s), group, hi-hat threshold, min_strike_len, threshold]
#   master_channel(s): main channels of the instrument, the first one is used for the strike detection
#   group:             written as "group" attribute to the drum kit XML file (empty: no group)
#   hi-hat threshold:  written as "controlthresh" attribute to the MIDI map (empty: no threshold)
#   min_strike_len:    minimum strike length in seconds
#   threshold:         detection threshold in dB, compared to the squared sample values of the master channel
instruments = [
    ["kick", ["KDrum", "OHLeft", "OHRight"], [36], "", "", 0.1, 15],
    ["snare", ["Snare", "OHLeft", "OHRight"], [38], "", "", 0.08, 16],
    ["snare_rimshot", ["Snare", "OHLeft", "OHRight"], [40], "", "", 0.3, 15],
    ["hihat_closed", ["Hihat", "OHLeft", "OHRight"], [22, 26], "hihat", "80", 0.18, 20],
    ["hihat_closedtop", ["Hihat", "OHLeft", "OHRight"], [42, 46], "hihat", "80", 0.2, 20],
    ["hihat_open", ["Hihat", "OHLeft", "OHRight"], [26], "hihat", "0", 0.7, 23],
    ["hihat_open1", ["Hihat", "OHLeft", "OHRight"], [26], "hihat", "55", 0.7, 23],
    ["hihat_open2", ["Hihat", "OHLeft", "OHRight"], [26], "hihat", "27", 0.7, 23],
    ["hihat_opentop", ["Hihat", "OHLeft", "OHRight"], [46], "hihat", "0", 0.7, 24],
    ["hihat_open1top", ["Hihat", "OHLeft", "OHRight"], [46], "hihat", "55", 0.7, 21],
    ["hihat_open2top", ["Hihat", "OHLeft", "OHRight"], [46], "hihat", "27", 0.7, 23],
    ["hihat_foot", ["Hihat", "OHLeft", "OHRight"], [44], "hihat", "", 0.1, 23],
    ["tom1", ["Tom1", "OHLeft", "OHRight"], [48, 50], "", "", 0.2, 15],
    ["tom2", ["Tom2", "OHLeft", "OHRight"], [45, 47], "", "", 0.2, 15],
    ["tom3", ["Tom3", "OHLeft", "OHRight"], [43, 58], "", "", 0.4, 15],
    ["crash", ["OHLeft", "OHRight"], [55, 52], "", "", 0.5, 15],
    ["crash_top", ["OHLeft", "OHRight"], [49, 57], "", "", 0.4, 15],
    ["ride", ["OHRight", "OHLeft"], [51], "", "", 1.0, 15],
    ["ride_bell", ["OHRight", "OHLeft"], [53], "", "", 1.0, 16],
    ["ride_side", ["OHRight", "OHLeft"], [59], "", "", 1.0, 15],
]
#channel_names = ["SnareL"] # for calibrating dynamic in Drumgizmo
#instruments = [["rolandsnare", ["SnareL"], [38], "", "", 0.03, 23]]

source_samples_dir_name = "source_samples"  # root directory of recorded source samples
fade_out_percent = 30  # % of sample at the end is faded out
thresh_from_max_for_start = 20  # dB
add_samples_at_start = 20  # additional samples considered at strike start (also defines the fade-in time period)
min_time_next_strike_s = 0.5  # minimum time in seconds between two different strikes

# TEST for optimizing the algorithms, only use one instrument
#instruments = [instruments[9]]


def reduce_instruments_for_raspi(instrument_list):
    """Shrink an instrument table in place for the Raspberry Pi kit.

    Drops the instruments which are left out for the lowest possible memory
    requirement and assigns the MIDI notes which would otherwise be unmapped
    to the remaining instruments.

    The entries to drop are selected by building a new list. Calling
    ``list.remove()`` while iterating the same list skips the entry which
    follows each removed one, which left "hihat_open1top" in the kit.

    Args:
        instrument_list: list of instrument entries in the layout of
            ``instruments`` (name at index 0, MIDI note list at index 2).

    Returns:
        The same list object, for convenience.
    """
    removed_names = {"tom2", "ride_side", "crash_top", "hihat_opentop",
                     "hihat_open1top", "hihat_open2top", "hihat_closedtop"}
    instrument_list[:] = [entry for entry in instrument_list if entry[0] not in removed_names]

    # MIDI notes which are appended to the remaining instruments
    additional_midi_notes = {"ride": [59],
                             "crash": [49],
                             "hihat_open": [46],
                             "hihat_closed": [42, 46],
                             "hihat_open1": [46],
                             "hihat_open2": [46]}
    for entry in instrument_list:
        entry[2].extend(additional_midi_notes.get(entry[0], []))
    return instrument_list


# settings for optimized drum kit for Raspberry Pi (with limited RAM)
if raspi_optimized_drumkit:
    disable_positional_sensing_support = True
    do_shorten_samples = True
    only_master_channels_per_instrument = True
    reduce_instruments_for_raspi(instruments)
for instrument in instruments:
    ##############################################################################
    # FILE NAME HANDLING #########################################################
    ##############################################################################
    samples_dir_name = "samples"  # compatible to other Drumgizmo kits
    instrument_name = instrument[0]
    instrument_path = kit_name + "/" + instrument_name + "/"
    instrument_sample_path = instrument_path + samples_dir_name + "/"
    base_instrument_name = instrument_name.split("_")[0]  # name of the source sub-directory
    instrument_source_path = source_samples_dir_name + "/" + base_instrument_name
    print(instrument_name)

    # check if instrument has positional sensing support and extract position indexes;
    # only files named "<instrument_name>_<position>_channel<n>.wav" are considered so
    # that e.g. "snare" does not pick up the positions of "snare_rimshot"
    found_positions = set()
    file_name_prefix = instrument_name + "_"
    for file_name in os.listdir(instrument_source_path):
        if not file_name.startswith(file_name_prefix):
            continue
        remaining_parts = file_name[len(file_name_prefix):].split(".")[0].split("_")
        # position information is one digit in front of the channel part
        if (len(remaining_parts) == 2 and len(remaining_parts[0]) == 1 and
                remaining_parts[0] in "0123456789" and remaining_parts[1].startswith("channel")):
            found_positions.add(int(remaining_parts[0]))
    positions = sorted(found_positions) if found_positions else [-1]  # -1: no positional support

    # optionally, disable positional sensing support
    if disable_positional_sensing_support and len(positions) > 1:
        positions = [0]

    # results per position, keyed by the position value so that any set of
    # position numbers (and the -1 marker) can be stored
    sample_powers = {}
    sample_strikes = {}

    # values which are the same for all positions of this instrument
    num_channels = len(channel_names)
    master_channel = channel_names.index(instrument[1][0])  # first main channel is master
    threshold = np.power(10, instrument[6] / 10)
    below_max_thresh = np.power(10, -thresh_from_max_for_start / 10)  # -[20] dB from maximum peak
    fade_in_ramp = np.arange(1, add_samples_at_start + 1, 1) / add_samples_at_start

    for p in positions:
        ##############################################################################
        # READ WAVE FORMS ############################################################
        ##############################################################################
        pos_str = "_" + str(p) if p >= 0 else ""
        sample = []
        for i in range(num_channels):
            with wave.open(instrument_source_path + "/" + instrument_name + pos_str +
                           "_channel" + str(i + 1) + ".wav", "r") as file:
                sample_rate = file.getframerate()  # assuming all wave have the same rate
                sample.append(np.frombuffer(file.readframes(-1), np.int16))  # assuming 16 bit

        ##############################################################################
        # WAVE FORM ANALYSIS #########################################################
        ##############################################################################
        min_strike_len = int(instrument[5] * sample_rate)  # minimum strike length in samples
        min_time_next_strike = int(min_time_next_strike_s * sample_rate)

        # find samples which are above the threshold
        x = np.square(sample[master_channel].astype(float))
        above_thresh = x > threshold

        # remove oscillating by filling short gaps: a gap between a falling and the
        # next rising edge which is shorter than the minimum strike distance is closed
        first_below_idx = -100 * sample_rate  # far in the past: no falling edge seen yet
        for i in range(1, len(above_thresh)):
            if not above_thresh[i] and above_thresh[i - 1]:
                first_below_idx = i
            if above_thresh[i] and not above_thresh[i - 1]:
                if i - first_below_idx + 1 < min_time_next_strike:
                    above_thresh[first_below_idx:i + 1] = True

        # remove very short on periods
        first_above_idx = -100 * sample_rate  # far in the past: no rising edge seen yet
        for i in range(1, len(above_thresh)):
            if above_thresh[i] and not above_thresh[i - 1]:
                first_above_idx = i
            if not above_thresh[i] and above_thresh[i - 1]:
                if i - first_above_idx < min_strike_len:
                    above_thresh[first_above_idx:i + 1] = False

        # rising/falling edges of the cleaned-up detection mark start/end of the strikes
        thresh_edges = np.diff(above_thresh.astype(float))
        strike_start = np.argwhere(thresh_edges > 0)[:, 0]
        strike_end = np.argwhere(thresh_edges < 0)[:, 0]

        # extract individual samples from long sample vector and analyze/process
        sample_powers[p] = []
        sample_strikes[p] = []
        strike_cut_pos = np.full(len(x), False)  # marks the extracted regions for the debug plot
        for i in range(len(strike_start)):
            cur_start = strike_start[i]
            cur_end = strike_end[i]

            # fix start of strike: find first sample going left of the maximum peak which
            # is below a threshold which is defined 20 dB below the maximum
            strike_max = np.max(x[cur_start:cur_end + 1])
            index = 0
            while x[cur_start + index] < strike_max * below_max_thresh:
                index += 1
            # add some offset in front of the strike but never start in front of the recording
            cur_start = max(0, cur_start + index - add_samples_at_start)

            # fix end position: compare regions of min strike length and extend the strike
            # as long as the power of the next region is below the previous region
            while cur_end + 2 * min_strike_len < len(x):
                region_power = np.sum(x[cur_end:cur_end + min_strike_len])
                next_region_power = np.sum(x[cur_end + min_strike_len:cur_end + 2 * min_strike_len])
                # dB difference in power of the current two regions
                if not (10 * np.log10(region_power) - 10 * np.log10(next_region_power) > 0.0):
                    break
                cur_end += min_strike_len

            # keep the corrected borders (the start positions are shown in the debug plot)
            strike_start[i] = cur_start
            strike_end[i] = cur_end

            # estimate power from master channel using the maximum value
            sample_powers[p].append(strike_max / 32768 / 32768)  # assuming 16 bit

            # optionally, shorten the samples to save some memory, i.e., modify strike_end
            if do_shorten_samples:
                x_cur_strike_master = x[cur_start:cur_end + 1]
                shorten_thresh = np.max(x_cur_strike_master) / np.power(10, 40 / 10)  # 40 dB below max
                last_index = np.max(np.argwhere(x_cur_strike_master > shorten_thresh))
                mod_strike_end = cur_start + last_index
            else:
                mod_strike_end = cur_end

            # extract sample data of current strike, one column per channel
            strike = np.zeros((mod_strike_end - cur_start + 1, num_channels), np.int16)
            strike_cut_pos[cur_start:mod_strike_end + 1].fill(True)  # for debugging
            for c in range(num_channels):
                strike[:, c] = sample[c][cur_start:mod_strike_end + 1]

            # audio fade-in at the beginning (same ramp for all channels)
            strike[:add_samples_at_start, :] = np.int16(
                strike[:add_samples_at_start, :].astype(float) * fade_in_ramp[:, None])

            # audio fade-out at the end
            # NOTE(review): the ramp runs from (fade_len + 1) / fade_len down to 2 / fade_len,
            # i.e. it starts slightly above 1 -- looks like an off-by-one, values are kept
            # as they were; confirm if arange(fade_len, 0, -1) was intended
            sample_len = len(strike)
            fade_start = int(sample_len * (1 - fade_out_percent / 100))
            fade_len = sample_len - fade_start
            fade_out_ramp = np.arange(fade_len + 1, 1, -1) / fade_len
            strike[fade_start:, :] = np.int16(
                strike[fade_start:, :].astype(float) * fade_out_ramp[:, None])

            sample_strikes[p].append(strike)

        if len(instruments) == 1:  # if only one instrument is selected, we assume we want to debug plot
            mpl.rcParams['agg.path.chunksize'] = 10000  # needed for long wave forms to avoid Exceeded cell block limit in Agg
            plt.plot(10 * np.log10(np.abs(x)))
            plt.plot([0, len(x)], 10 * np.log10([threshold, threshold]))
            plt.plot(10 * np.log10(np.max(x)) * strike_cut_pos)
            plt.plot(strike_start, [10 * np.log10(np.max(x))] * len(strike_start), 'o', color='tab:brown')
            plt.title(instrument_name + pos_str)
            plt.show()

        # clean up after every position; closing is a no-op if no figure was opened
        plt.close("all")  # to prevent a memory leak
        plt.close()  # to prevent a memory leak
        gc.collect()  # to prevent a memory leak

    ##############################################################################
    # WRITE WAVE FORMS AND INSTRUMENT XML FILE ###################################
    ##############################################################################
    # create the output directory once per instrument so that the instrument XML
    # file can be written even if no strike was detected
    os.makedirs(instrument_sample_path, exist_ok=True)

    instrument_xml = ET.Element("instrument")
    instrument_xml.set("version", "2.0")
    instrument_xml.set("name", instrument_name)
    samples_xml = ET.SubElement(instrument_xml, "samples")

    # get indexes of main channels of this instrument (in channel_names order)
    instrument_master_channel_indexes = [idx for idx, channel_name in enumerate(channel_names)
                                         if channel_name in instrument[1]]

    # (kit channel name, 1-based channel number inside the wave file) for every channel
    # which is stored in the wave files; if only the main channels are written, the
    # file channels follow the column order of instrument_master_channel_indexes and
    # not the order in which the main channels are listed in the instrument table
    if only_master_channels_per_instrument:
        file_channels = [(channel_names[idx], file_idx + 1)
                         for file_idx, idx in enumerate(instrument_master_channel_indexes)]
    else:
        file_channels = [(channel_name, idx + 1) for idx, channel_name in enumerate(channel_names)]

    for p in positions:
        power_sort_indexes = np.argsort(sample_powers[p])
        for i, strike_index in enumerate(power_sort_indexes):  # sort waves by power
            print(str(i) + ": " + str(10 * np.log10(sample_powers[p][strike_index])))

            # write multi-channel wave file
            sample_file_name = str(i + 1) + "-" + instrument_name
            if len(positions) > 1:
                sample_file_name += "-" + str(p)
            strike = sample_strikes[p][strike_index]
            if only_master_channels_per_instrument:
                strike = strike[:, instrument_master_channel_indexes]
            wavfile.write(instrument_sample_path + sample_file_name + ".wav", sample_rate, strike)

            # write XML content for current sample
            sample_xml = ET.SubElement(samples_xml, "sample")
            if len(positions) > 1:
                sample_xml.set("position", str(p))
            sample_xml.set("name", instrument_name + "-" + str(i + 1))
            # make sure result is positive by adding 100 dB (max. assumed dynamic)
            sample_xml.set("power", "{:.19f}".format(10 * np.log10(sample_powers[p][strike_index]) + 100))
            for channel_name, file_channel in file_channels:
                audiofile_xml = ET.SubElement(sample_xml, "audiofile")
                audiofile_xml.set("channel", channel_name)
                audiofile_xml.set("file", samples_dir_name + "/" + sample_file_name + ".wav")
                audiofile_xml.set("filechannel", str(file_channel))

    # write instrument XML file
    tree_xml = ET.ElementTree(instrument_xml)
    ET.indent(instrument_xml, space="\t", level=0)
    tree_xml.write(instrument_path + instrument_name + ".xml", encoding="utf-8", xml_declaration=True)
################################################################################
# CREATE DRUM KIT XML FILE #####################################################
################################################################################
# kit header; the sample rate is the one of the last wave file which was read
drumkit_xml = ET.Element("drumkit")
drumkit_xml.set("name", kit_name)
drumkit_xml.set("description", kit_description)
drumkit_xml.set("samplerate", str(sample_rate))
drumkit_xml.set("islogpower", "true")  # the sample powers in the instrument files are in dB

# one <channel> entry per kit channel
channels_xml = ET.SubElement(drumkit_xml, "channels")
for channel_name in channel_names:
    ET.SubElement(channels_xml, "channel").set("name", channel_name)

# one <instrument> entry per instrument which refers to its own XML file
instruments_xml = ET.SubElement(drumkit_xml, "instruments")
for instrument in instruments:
    kit_instrument_name = instrument[0]
    instrument_xml = ET.SubElement(instruments_xml, "instrument")
    instrument_xml.set("name", kit_instrument_name)
    if instrument[3]:  # the group is optional
        instrument_xml.set("group", instrument[3])
    instrument_xml.set("file", kit_instrument_name + "/" + kit_instrument_name + ".xml")

    # every channel is mapped to the kit channel of the same name, the main
    # channels of the instrument are marked
    for channel_name in channel_names:
        channelmap_xml = ET.SubElement(instrument_xml, "channelmap")
        channelmap_xml.set("in", channel_name)
        channelmap_xml.set("out", channel_name)
        if channel_name in instrument[1]:
            channelmap_xml.set("main", "true")

tree_xml = ET.ElementTree(drumkit_xml)
ET.indent(drumkit_xml, space="\t", level=0)
os.makedirs(kit_name, exist_ok=True)
# xml_declaration is a flag: pass the boolean instead of a (truthy) string
tree_xml.write(kit_name + "/" + kit_name + ".xml", encoding="utf-8", xml_declaration=True)
################################################################################
# CREATE MIDI MAP XML FILE #####################################################
################################################################################
# one <map> entry per MIDI note of every instrument
midimap_xml = ET.Element("midimap")
for instrument in instruments:
    control_thresh = instrument[4]  # hi-hat threshold, empty if not used
    for midi_note in instrument[2]:
        map_xml = ET.SubElement(midimap_xml, "map")
        map_xml.set("note", str(midi_note))
        map_xml.set("instr", instrument[0])
        if control_thresh:
            map_xml.set("controlthresh", control_thresh)

tree_xml = ET.ElementTree(midimap_xml)
ET.indent(midimap_xml, space="\t", level=0)
# the file is written to the kit directory, which has to exist already;
# xml_declaration is a flag: pass the boolean instead of a (truthy) string
tree_xml.write(kit_name + "/Midimap.xml", encoding="utf-8", xml_declaration=True)