In [1]:
import utils
import drumset
import tokeniser
import model

# silence mido and pretty_midi warnings (note: the filter below suppresses ALL warnings, not only theirs)
import warnings
warnings.filterwarnings("ignore")

import os

# --- Notebook-wide configuration ---------------------------------------------

# The MIDI corpus is looked up in ./midi_files, relative to the directory the
# kernel was started from.
MIDI_DIR = os.path.join(os.getcwd(), "midi_files")

# Window length handed to the tokeniser, dataset and model below.
# NOTE(review): the model banner reports 8 input steps + 1 output step for a
# value of 9, so this presumably counts inputs plus the predicted step - confirm.
BUFFER_SIZE = 9

# Earlier, finer-grained mapping (18 classes), kept for reference only:
#   35, 36 -> 0  bass drum     37 -> 1  stick        38, 40 -> 2  snare
#   39 -> 3  clap              41 -> 4  tom 0        43 -> 5  tom 1
#   45 -> 6  tom 2             47 -> 7  tom 3        48 -> 8  tom 4
#   50 -> 9  tom 5             42 -> 10 hh closed    44 -> 11 hh pedal
#   46 -> 12 hh open           49, 57 -> 13 crash    51, 59 -> 14 ride 1 / ride 2
#   53 -> 15 ride bell         55 -> 16 splash
#   -999 -> 17 intensity (ONLY USED FOR TOKEN COUNT CALC)

# Active, reduced mapping from drum note number to token index.  Notes that
# share a token are grouped; the groups are listed in the order the keys must
# appear in the resulting dict (insertion order is preserved on purpose).
DRUM_LOOKUP_TABLE = {
    note: token
    for notes, token in (
        ((35, 36), 0),    # bass drum
        ((38, 40), 1),    # snare
        ((42, 44), 2),    # hh closed / hh pedal
        ((46,), 3),       # hh open
        ((49, 57), 4),    # crash
        ((51, 59), 5),    # ride 1 / ride 2
        ((55,), 4),       # splash (shares the crash token)
        ((-999,), 6),     # intensity (ONLY USED FOR TOKEN COUNT CALC)
    )
    for note in notes
}

# Number of distinct token indices, including the intensity pseudo-entry.
TOKEN_COUNT = len({*DRUM_LOOKUP_TABLE.values()})

print(f"There are {TOKEN_COUNT} tokens.")

There are 7 tokens.


In [2]:
# Tokenisation settings for this run.
FILES_TO_PROCESS = -1   # NOTE(review): presumably "no limit" - confirm in tokeniser.tokenize_drums
USE_SAVED_FILE = True   # reuse the previously tokenised file instead of re-parsing the MIDI corpus
FILE_PATH = "saved.txt"

# Only re-tokenise the corpus when the cached file must not be reused.
# NOTE(review): FILE_PATH is not passed here, so tokenize_drums presumably
# writes to its own default location - verify it matches FILE_PATH.
if not USE_SAVED_FILE:
    tokeniser.tokenize_drums(FILES_TO_PROCESS, MIDI_DIR)

# Load the tokenised data for the configured buffer size.
inputs = tokeniser.load_inputs_from_file(FILE_PATH, BUFFER_SIZE)

print(f"Total size: {len(inputs)}")

Total size: 15577


In [3]:
# Number of samples per batch; also handed to the model constructor later on.
batch_size = 50

# Wrap the loaded inputs in the project's sequence/dataset object.
dataset = drumset.DrumSetSequence(
    inputs,
    batch_size,
    BUFFER_SIZE,
    TOKEN_COUNT,
    DRUM_LOOKUP_TABLE,
)

In [4]:
# Build the network wrapper; constructing it prints the input/output shape
# banner and the layer summary shown in this cell's output.
nn = model.Model(
    dataset,
    BUFFER_SIZE,
    TOKEN_COUNT,
    batch_size,
    DRUM_LOOKUP_TABLE,
)

Creating model with 8 x 7 inputs and 1 x 6 outputs.
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm (LSTM)                 (None, 8, 512)            1064960   
                                                                 
 lstm_1 (LSTM)               (None, 8, 512)            2099200   
                                                                 
 lstm_2 (LSTM)               (None, 512)               2099200   
                                                                 
 flatten (Flatten)           (None, 512)               0         
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 batch_normalization (BatchN  (None, 512)              2048      
 ormalization)                                                   
    

In [5]:
# Restore model state via model.Model.load().
# NOTE(review): presumably reloads what nn.save() wrote earlier and fails if
# no saved model exists yet - confirm in model.py.
nn.load()

In [None]:
# Train the model.
# NOTE(review): the argument 10 is presumably the number of epochs - confirm
# against model.Model.train before changing it.
nn.train(10)

In [6]:
# Persist the current model via model.Model.save().
# NOTE(review): presumably overwrites the previously saved model - confirm.
nn.save()

In [None]:
# Plot via model.Model.plot().
# NOTE(review): presumably shows the training history, in which case
# nn.train(...) must have run in this session first - confirm.
nn.plot()

In [38]:
# Song that receives the generated drum part.  The path components are passed
# to os.path.join separately so that no "/" separator is hard-coded and the
# result uses the platform's separator throughout.
INPUT_PATH = os.path.join(MIDI_DIR, "Electric_Light_Orchestra", "Telephone_Line.mid")

# Generate the drum token chain for the song.
# NOTE(review): the sixth positional argument is presumably the seed buffer
# the generation starts from; the commented-out lines are alternative seeds.
chain = tokeniser.add_drum_track(
    nn,
    BUFFER_SIZE,
    TOKEN_COUNT,
    DRUM_LOOKUP_TABLE,
    INPUT_PATH,
    # utils.generate_random_input(BUFFER_SIZE, TOKEN_COUNT),
    utils.generate_basic_drums(BUFFER_SIZE, TOKEN_COUNT),
    # utils.generate_count_in(BUFFER_SIZE, TOKEN_COUNT),
    # utils.generate_blank_input(BUFFER_SIZE, TOKEN_COUNT),
    cutoff=0.2,  # NOTE(review): presumably the activation threshold for emitting a hit - confirm
)

# Post-processing.  Statement order is kept exactly as in the original cell.
# NOTE(review): remove_drum_track presumably writes the drum-less song to
# "temp.mid", which append_drum_track then extends - confirm in tokeniser.py.
tokeniser.remove_drum_track(INPUT_PATH)
tokeniser.append_drum_track("temp.mid", chain)
utils.clean_temp()

Length: 662
Latest input of [[0. 0. 1. 0. 0. 0. 0.]]
Full input: 
[[[1. 0. 1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0.]
  [0. 1. 1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0.]
  [1. 0. 1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0.]
  [0. 1. 1. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0.]]]
Index 0 : ['35', '42']
Latest input of [[ 1.   0.   1.   0.   0.   0.  -1.8]]
Full input: 
[[[ 0.   0.   1.   0.   0.   0.   0. ]
  [ 0.   1.   1.   0.   0.   0.   0. ]
  [ 0.   0.   1.   0.   0.   0.   0. ]
  [ 1.   0.   1.   0.   0.   0.   0. ]
  [ 0.   0.   1.   0.   0.   0.   0. ]
  [ 0.   1.   1.   0.   0.   0.   0. ]
  [ 0.   0.   1.   0.   0.   0.   0. ]
  [ 1.   0.   1.   0.   0.   0.  -1.8]]]
Index 1 : ['42']
Latest input of [[ 0.   0.   1.   0.   0.   0.  -1.8]]
Full input: 
[[[ 0.   1.   1.   0.   0.   0.   0. ]
  [ 0.   0.   1.   0.   0.   0.   0. ]
  [ 1.   0.   1.   0.   0.   0.   0. ]
  [ 0.   0.   1.   0.   0.   0.   0. ]
  [ 0.   1.   1.   0.   0.   0.   0. ]
  [ 0.   0.   1.   0.   0.   0.   0. ]
  [ 1.  