<a href="https://colab.research.google.com/github/jmhuer/shift_invariant_dictionary_learning/blob/main/dictionarylearning_mididrums.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up environment

In [1]:
#@title Magenta & MIDI & Audio

print('Installing dependencies...')

!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
!pip install -q pyfluidsynth
!pip install -U -q magenta

import tensorflow_datasets as tfds
import tensorflow as tf

# Colab installs fluidsynth as `libfluidsynth.so.1`; redirect ctypes' lookup of
# 'fluidsynth' to that file so pyfluidsynth can load it. Every other library
# name is resolved by the original ctypes.util.find_library.
# This is only needed for the hosted Colab environment.
import ctypes.util
orig_ctypes_util_find_library = ctypes.util.find_library
def proxy_find_library(lib):
  """Return the Colab fluidsynth soname for 'fluidsynth', else defer to ctypes."""
  if lib != 'fluidsynth':
    return orig_ctypes_util_find_library(lib)
  return 'libfluidsynth.so.1'
ctypes.util.find_library = proxy_find_library
  
print('Importing software libraries...')

import copy, warnings, librosa, numpy as np
warnings.filterwarnings("ignore", category=DeprecationWarning)


# Colab/Notebook specific stuff
import IPython.display
from IPython.display import Audio
from google.colab import files

# Magenta specific stuff
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
from magenta.models.music_vae import data
import note_seq
from note_seq import midi_synth
from note_seq.sequences_lib import concatenate_sequences
from note_seq.protobuf import music_pb2

# Define some functions

# If a sequence has notes at time before 0.0, scootch them up to 0
def start_notes_at_0(s):
  for n in s.notes:
    if n.start_time < 0:
      n.end_time -= n.start_time
      n.start_time = 0
  return s

def play(note_sequence, sf2_path='Standard_Drum_Kit.sf2'):
  """Play a NoteSequence inline in the notebook.

  With a soundfont path, the sequence is rendered at 44.1 kHz with FluidSynth and
  shown as an IPython audio widget. With a falsy `sf2_path`, it is played via
  note_seq.play_sequence instead. Note: negative note start times in
  `note_sequence` are shifted to 0 in place (see start_notes_at_0).
  """
  shifted = start_notes_at_0(note_sequence)
  if not sf2_path:
    note_seq.play_sequence(shifted, synth=note_seq.fluidsynth)
    return
  sample_rate = 44100
  audio_seq = midi_synth.fluidsynth(shifted, sample_rate=sample_rate, sf2_path=sf2_path)
  IPython.display.display(IPython.display.Audio(audio_seq, rate=sample_rate))

# Some midi files come by default from different instrument channels
# Quick and dirty way to set midi files to be recognized as drums
def set_to_drums(ns):
  """Mark every note of `ns` as a drum hit on instrument 9.

  Modifies `ns` in place and returns it, mirroring unset_to_drums so both
  helpers can be used inside expressions.
  """
  for note in ns.notes:
    note.instrument = 9
    note.is_drum = True
  return ns
    
def unset_to_drums(ns):
  """Clear the drum flag on every note of `ns` and reset its instrument to 0.

  Modifies `ns` in place and returns it.
  """
  for n in ns.notes:
    n.instrument, n.is_drum = 0, False
  return ns

# quickly change the tempo of a midi sequence and adjust all notes
def change_tempo(note_sequence, new_tempo):
  """Return a copy of `note_sequence` re-timed to play at `new_tempo` QPM.

  Every note's start/end time is scaled by old_qpm / new_tempo. Control-change
  times and `total_time` are scaled the same way, so the copy stays internally
  consistent. The first tempo event is set to `new_tempo`. The input sequence
  is not modified.

  Only `tempos[0]` is read and updated, so this assumes a single-tempo
  sequence. An IndexError is raised if there is no tempo event.
  """
  new_sequence = copy.deepcopy(note_sequence)
  ratio = note_sequence.tempos[0].qpm / new_tempo
  for note in new_sequence.notes:
    note.start_time *= ratio
    note.end_time *= ratio
  # Keep non-note events and the overall length on the same (new) time axis.
  for control_change in getattr(new_sequence, 'control_changes', ()):
    control_change.time *= ratio
  if hasattr(new_sequence, 'total_time'):
    new_sequence.total_time = note_sequence.total_time * ratio
  new_sequence.tempos[0].qpm = new_tempo
  return new_sequence

def download(note_sequence, filename):
  """Write `note_sequence` as a MIDI file at `filename`, then download it via Colab."""
  midi_path = filename
  note_seq.sequence_proto_to_midi_file(note_sequence, midi_path)
  files.download(midi_path)
  
def download_audio(audio_sequence, filename, sr):
  """Save a waveform as a WAV file at `filename` (sample rate `sr`), then download it via Colab.

  Floating-point audio is peak-normalised to [-1, 1] first; all-zero audio is
  left unchanged. This matches the old `librosa.output.write_wav(..., norm=True)`.
  That function was removed in librosa 0.8, so the file is written with
  scipy.io.wavfile instead. A channels-first stereo array of shape (2, n) is
  transposed to (n, 2), as librosa did.
  """
  import scipy.io.wavfile

  wav = np.asarray(audio_sequence)
  if np.issubdtype(wav.dtype, np.floating) and wav.size:
    peak = np.max(np.abs(wav))
    if peak > 0:
      wav = wav / peak
  if wav.ndim > 1 and wav.shape[0] == 2:
    wav = wav.T
  scipy.io.wavfile.write(filename, sr, wav)
  files.download(filename)
 
# GrooVAE data converters whose to_tensors/from_tensors round-trips are reused
# by the helper functions below.
_CONFIG_MAP = configs.CONFIG_MAP
dc_quantize = _CONFIG_MAP['groovae_2bar_humanize'].data_converter        # 2-bar: strip microtiming/velocity
dc_tap = _CONFIG_MAP['groovae_2bar_tap_fixed_velocity'].data_converter   # 2-bar: tapped rhythm
dc_hihat = _CONFIG_MAP['groovae_2bar_add_closed_hh'].data_converter      # 2-bar: hi-hat removal
dc_4bar = _CONFIG_MAP['groovae_4bar'].data_converter                     # 4-bar: dataset encoding

def get_quantized_2bar(s, velocity=0):
  """Strip microtiming (and optionally velocity) from a 2-bar drum sequence.

  Round-trips `s` through the humanize converter's input representation and
  restores the original tempo. If `velocity` is non-zero, every note's velocity
  is overwritten with it.
  """
  encoded_inputs = dc_quantize.to_tensors(s).inputs
  quantized = change_tempo(dc_quantize.from_tensors(encoded_inputs)[0], s.tempos[0].qpm)
  if velocity != 0:
    for note in quantized.notes:
      note.velocity = velocity
  return quantized

def get_tapped_2bar(s, velocity=85, ride=False):
  """Collapse a 2-bar drumbeat into a single tapped rhythm at the original tempo.

  Args:
    s: source drum NoteSequence.
    velocity: if non-zero, every tapped note gets this velocity.
    ride: if True, every tapped note is re-pitched to MIDI pitch 42.
  """
  encoded_inputs = dc_tap.to_tensors(s).inputs
  tapped = change_tempo(dc_tap.from_tensors(encoded_inputs)[0], s.tempos[0].qpm)
  for note in tapped.notes:
    if velocity != 0:
      note.velocity = velocity
    if ride:
      # NOTE(review): in the General MIDI drum map 42 is Closed Hi-Hat
      # (Ride Cymbal 1 is 51) -- confirm which sound is intended.
      note.pitch = 42
  return tapped

def get_hh_2bar(s):
  """Remove hi-hats from a 2-bar drum sequence, keeping its original tempo.

  Decodes the input side of the 'groovae_2bar_add_closed_hh' converter,
  presumably the hi-hat-free representation that model completes.
  """
  stripped = dc_hihat.from_tensors(dc_hihat.to_tensors(s).inputs)[0]
  return change_tempo(stripped, s.tempos[0].qpm)


def quantize(s, steps_per_quarter=4):
  """Return a quantized copy of `s` with `steps_per_quarter` grid steps per quarter note.

  Note times are kept as-is (microtiming survives); only the quantized step
  fields are computed.
  """
  quantize_fn = note_seq.sequences_lib.quantize_note_sequence
  return quantize_fn(s, steps_per_quarter)

# Destructively quantize a midi sequence
def flatten_quantization(s):
  """Return a copy of quantized sequence `s` with every note snapped onto the grid.

  Each note's start/end time becomes quantized_start_step / quantized_end_step
  times the step length. The step length is a quarter note at `s.tempos[0].qpm`
  divided by the sequence's own steps_per_quarter. When that is missing or 0,
  4 steps per quarter is assumed (the grid `quantize` uses by default).
  Microtiming is discarded. `s` itself is not modified.
  """
  beat_length = 60. / s.tempos[0].qpm
  quantization_info = getattr(s, 'quantization_info', None)
  steps_per_quarter = getattr(quantization_info, 'steps_per_quarter', 0) or 4
  step_length = beat_length / steps_per_quarter
  new_s = copy.deepcopy(s)
  for note in new_s.notes:
    note.start_time = step_length * note.quantized_start_step
    note.end_time = step_length * note.quantized_end_step
  return new_s

# Calculate how far off the beat a note is
def get_offset(s, note_index):
  """Signed distance, in grid steps, between a note's onset and its quantized onset.

  `s` is quantized with `quantize` and snapped to the grid. The result is
  (quantized_onset - true_onset) / step_length for `s.notes[note_index]`.
  Positive means the note was played before its grid position; negative means after.
  The step length uses the quantized sequence's steps_per_quarter, falling
  back to 4 if it is 0.
  """
  q_s = flatten_quantization(quantize(s))
  diff = q_s.notes[note_index].start_time - s.notes[note_index].start_time
  beat_length = 60. / s.tempos[0].qpm
  step_length = beat_length / (q_s.quantization_info.steps_per_quarter or 4)
  return diff / step_length

def is_4_4(s):
  """Return True if the first time signature of `s` is 4/4.

  A sequence with no time-signature events returns False, so dataset filters
  drop it instead of failing with an IndexError.
  """
  if not s.time_signatures:
    return False
  ts = s.time_signatures[0]
  return ts.numerator == 4 and ts.denominator == 4

def preprocess_4bar(s):
  """Round-trip `s` through the 4-bar converter's output representation and return the first decoded sequence."""
  output_tensors = dc_4bar.to_tensors(s).outputs
  return dc_4bar.from_tensors(output_tensors)[0]

def preprocess_2bar(s):
  """Round-trip `s` through the 2-bar humanize converter's output representation and return the first decoded sequence."""
  output_tensors = dc_quantize.to_tensors(s).outputs
  return dc_quantize.from_tensors(output_tensors)[0]

def _slerp(p0, p1, t):
  """Spherical linear interpolation."""
  omega = np.arccos(np.dot(np.squeeze(p0/np.linalg.norm(p0)),
    np.squeeze(p1/np.linalg.norm(p1))))
  so = np.sin(omega)
  return np.sin((1.0-t)*omega) / so * p0 + np.sin(t*omega)/so * p1

def tensor_4bar(s):
  """Encode `s` with the 4-bar GrooVAE converter and return its first output tensor.

  Returns None when encoding fails for `s`. Presumably this happens when the
  converter yields no 4-bar window, so `outputs[0]` fails; callers filter these
  out. Only ordinary exceptions are caught, so KeyboardInterrupt and
  SystemExit still propagate.
  """
  try:
    return dc_4bar.to_tensors(s).outputs[0]
  except Exception:  # best-effort: skip sequences the converter cannot encode
    return None
   
def to_ProtoMidi_4bar(s):
  """Decode one time-major GrooVAE tensor (one row per step) back into a NoteSequence.

  NOTE(review): decoding uses `dc_tap` (the 2-bar tap converter) even though the
  name says 4-bar. Longer inputs appear to decode fine, but confirm whether
  `dc_4bar` was intended.
  """
  decoded_sequences = dc_tap.from_tensors([s])
  return decoded_sequences[0]

print('Downloading drum samples...')
# Download a drum kit for playing drum midi
!gsutil -q -m cp gs://magentadata/soundfonts/Standard_Drum_Kit.sf2 .



Installing dependencies...
Selecting previously unselected package fluid-soundfont-gm.
(Reading database ... 160772 files and directories currently installed.)
Preparing to unpack .../fluid-soundfont-gm_3.1-5.1_all.deb ...
Unpacking fluid-soundfont-gm (3.1-5.1) ...
Selecting previously unselected package libfluidsynth1:amd64.
Preparing to unpack .../libfluidsynth1_1.1.9-1_amd64.deb ...
Unpacking libfluidsynth1:amd64 (1.1.9-1) ...
Setting up fluid-soundfont-gm (3.1-5.1) ...
Setting up libfluidsynth1:amd64 (1.1.9-1) ...
Processing triggers for libc-bin (2.27-3ubuntu1.2) ...
/sbin/ldconfig.real: /usr/local/lib/python3.7/dist-packages/ideep4py/lib/libmkldnn.so.0 is not a symbolic link

[K     |████████████████████████████████| 1.4MB 9.2MB/s 
[K     |████████████████████████████████| 256kB 40.4MB/s 
[K     |████████████████████████████████| 215kB 52.0MB/s 
[K     |████████████████████████████████| 71kB 10.4MB/s 
[K     |████████████████████████████████| 3.6MB 51.1MB/s 
[K     |███████

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit


Downloading drum samples...


In [2]:
#@title Pytorch for DL

import torch.nn.functional as F
import torch.optim as optim
from torch import nn
import torch
from torch.nn.utils import weight_norm
import numpy as np

def get_model_parameters(model):
    """Return the number of trainable (requires_grad=True) scalar parameters in `model`."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return sum(np.prod(p.size()) for p in trainable)

# Fetch data

We load 4-bar drum grooves in 4/4 time from the Groove MIDI Dataset.

In [3]:
#@title Preprocess groove MIDI dataset

print("Download MIDI data...")

# Load the 4-bar, MIDI-only split of the Groove MIDI Dataset as numpy examples.
# A 2-bar split ("groove/2bar-midionly") is also available for shorter windows.
dataset_4bar = tfds.as_numpy(tfds.load(
    name="groove/4bar-midionly",
    split=tfds.Split.TRAIN,
    try_gcs=True))

# Parse each MIDI blob and quantize it (default grid of `quantize`).
dev_sequences_4bar = [quantize(note_seq.midi_to_note_sequence(features["midi"])) for features in dataset_4bar]

# Some files put drums on non-drum instruments; mark every note as a drum hit.
for sequence in dev_sequences_4bar:
  set_to_drums(sequence)

def _is_usable_groove(s):
  """Keep non-empty 4/4 sequences whose notes extend past the first bar."""
  if not (is_4_4(s) and len(s.notes) > 0):
    return False
  # Notes are not guaranteed to be ordered by end time (e.g. files with several
  # instruments), so take the latest end step rather than notes[-1].
  last_end_step = max(n.quantized_end_step for n in s.notes)
  return last_end_step > note_seq.steps_per_bar_in_quantized_sequence(s)

dev_sequences_4bar = [s for s in dev_sequences_4bar if _is_usable_groove(s)]


Download MIDI data...


In [4]:
#@title Preview MIDI data
tempo = 120
# Re-time every groove to a common tempo. start_notes_at_0 shifts negative
# onsets in place (mutating dev_sequences_4bar); change_tempo returns copies.
# NOTE(review): `data` shadows the `magenta.models.music_vae.data` module
# imported in the first cell.
data = [change_tempo(start_notes_at_0(seq), tempo) for seq in dev_sequences_4bar]

# Listen to a single example (index 1).
for preview_seq in data[1:2]:
  play(preview_seq)

# Model Definitions


In [5]:
#@title KWTA


class SparsifyBase(nn.Module):
    """Base class for sparsifying activation layers that can record their last input/output.

    Args:
        sparse_ratio: stored as `self.sr`; not used by Sparsify1D_kactive.
    """
    def __init__(self, sparse_ratio=0.5):
        super(SparsifyBase, self).__init__()
        self.sr = sparse_ratio
        self.preact = None  # last input[0] seen by the recording hook (detached CPU copy)
        self.act = None     # last output seen by the recording hook (detached CPU copy)

    def get_activation(self):
        """Return a forward hook that stores detached CPU copies of input[0] and the output."""
        def hook(model, input, output):
            self.preact = input[0].cpu().detach().clone()
            self.act = output.cpu().detach().clone()
        return hook

    def record_activation(self):
        """Record activations on every forward pass from now on.

        Returns the hook handle; call `.remove()` on it to stop recording.
        Calling this method again without removing the handle stacks another hook.
        """
        return self.register_forward_hook(self.get_activation())


class Sparsify1D_kactive(SparsifyBase):
    """k-winners-take-all along dim 1: keep the k largest entries, zero the rest.

    For an input of shape (N, C, ...) each index along the other dims keeps
    its top-k values over the C entries of dim 1 unchanged; all other entries
    become 0. The result is returned as float64.
    """
    def __init__(self, k=1):
        super(Sparsify1D_kactive, self).__init__()
        self.k = k  # number of winners; may be changed between forward passes

    def forward(self, x):
        # Build the 0/1 mask on x's own device and dtype. This avoids depending on
        # a global `device` and avoids a float32 round-trip of float64 input.
        winners = x.topk(self.k, dim=1)[1]
        mask = torch.zeros_like(x).scatter(1, winners, 1.0)
        return (mask * x).double()

In [6]:
#@title TCN 

class Chomp1d(nn.Module):
    """Drop the last `chomp_size` time steps of an (N, C, L) tensor.

    Used after a Conv1d padded by `chomp_size` on both sides, to remove the
    right-hand padding. A chomp_size of 0 is a no-op; the plain slice
    `x[:, :, :-0]` would return an empty tensor.
    """
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block of two dilated, weight-normalised Conv1d layers.

    Each conv is padded by `padding` on both sides and followed by Chomp1d,
    which drops the right-hand padding, then ReLU and dropout. With stride 1 and
    padding = (kernel_size - 1) * dilation (as TemporalConvNet passes), the
    output length equals the input length. When n_inputs != n_outputs, a 1x1
    conv (`downsample`) projects the residual path. The sum goes through a final ReLU.
    """
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Draw all conv weights from N(0, 1).

        conv1/conv2 are wrapped in weight_norm, which recomputes `.weight` from
        `weight_g` and `weight_v` before every forward pass. Writing to
        `.weight.data` directly would therefore be discarded. Instead `weight_v`
        is initialised and `weight_g` is set to its norm, so the effective
        weight equals `weight_v`.
        """
        for conv in (self.conv1, self.conv2):
            self._init_normal_(conv)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 1)

    @staticmethod
    def _init_normal_(conv, std=1.0):
        """Initialise a (possibly weight-normalised) conv so its effective weight is N(0, std)."""
        if hasattr(conv, 'weight_v'):
            conv.weight_v.data.normal_(0, std)
            # weight_norm's default dim=0 keeps one norm per output channel.
            norm_dims = tuple(range(1, conv.weight_v.dim()))
            conv.weight_g.data.copy_(conv.weight_v.data.norm(dim=norm_dims, keepdim=True))
        else:
            conv.weight.data.normal_(0, std)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


class TemporalConvNet(nn.Module):
    """Stack of TemporalBlocks whose dilation doubles at each level (1, 2, 4, ...).

    Args:
        num_inputs: channels of the input tensor.
        num_channels: output channels of each block; its length sets the depth.
        kernel_size: conv kernel size used in every block.
        dropout: dropout probability inside each block.
    """
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        blocks = []
        in_channels = num_inputs
        for level, out_channels in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation,
                                        padding=(kernel_size - 1) * dilation, dropout=dropout))
            in_channels = out_channels
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        return self.network(x)


In [36]:
#@title TCN - Autoencoder

class TCNAutoencoder(nn.Module):
    """Convolutional autoencoder with a k-winners-take-all sparse code.

    - The encoder is a strided Conv1d. It maps non-overlapping windows of
      `atom_size` time steps onto `code_size` code channels.
    - Sparsify1D_kactive keeps the top-k code channels at each code position.
    - A matching ConvTranspose1d decodes the sparse code back to `input_size`
      channels. Its kernels presumably act as the learned dictionary atoms.

    Args:
        input_size: number of input/output channels.
        num_channels, kernel_size, dropout: configure `self.feature`, a
            TemporalConvNet that is built but currently NOT used in forward().
            Its parameters still appear in model.parameters().
        wta_k: initial number of active code channels per position (`self.wta.k`).
        code_size: number of code channels / dictionary atoms (default 100).
        atom_size: kernel size and stride of encoder and decoder (default 4).
    """
    def __init__(self, input_size, num_channels, kernel_size, dropout, wta_k, code_size=100, atom_size=4):
        super(TCNAutoencoder, self).__init__()
        self.wta = Sparsify1D_kactive(k = wta_k)
        self.feature = TemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout).double()
        self.encoder = torch.nn.Conv1d(in_channels=input_size, out_channels=code_size, kernel_size=atom_size,
                                       padding=0, bias=False, stride=atom_size)
        self.decoder = torch.nn.ConvTranspose1d(in_channels=code_size, out_channels=input_size, kernel_size=atom_size,
                                                padding=0, bias=False, stride=atom_size)
        self.relu1 = nn.ReLU()  # currently unused
        self.code = None        # sparse code from the most recent forward()

    def get_kernels(self):
        """Return decoder weights for output channel 0 only: shape (code_size, atom_size)."""
        return self.decoder.weight.data[:,0,:]

    def feature_map(self, x):
        """Return the sparse code stored by the last forward() (`x` is ignored)."""
        code = self.code
        return code

    def forward(self, x):
        # x needs to have dimension (N, C, L) in order to be passed into CNN
        self.code = self.wta(self.encoder(x))
        output = self.decoder(self.code)
        return output


# Model training 

In [8]:
#@title Prepare batch data

print("len of MIDI data: ", len(data))
# Encode each sequence once. tensor_4bar returns None for sequences the 4-bar
# converter cannot encode, and those are dropped. Each kept tensor is
# transposed to channels-first for Conv1d; np.stack builds one array instead
# of converting a Python list of ndarrays element by element.
encoded_4bar = [tensor_4bar(s) for s in data]
train_data = torch.tensor(np.stack([t.T for t in encoded_4bar if t is not None]))
print("Train Dataset size: ", train_data.size())


len of MIDI data:  15637
Train Dataset size:  torch.Size([15591, 27, 64])


In [None]:
#@title GO

SEED = 0              # seeds torch so weight initialisation is reproducible
WTA_DECAY_EVERY = 30  # lower the WTA k by one every this many epochs (floor 1)
LOG_EVERY = 1         # print training stats every this many epochs

torch.manual_seed(SEED)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: ", device)

model = TCNAutoencoder(input_size=train_data.shape[1], 
                       num_channels=[train_data.shape[1]*2], 
                       kernel_size=4, 
                       dropout=0.2, 
                       wta_k = 4).to(device).double()
print("TCNAutoencoder trainable parameters: ", get_model_parameters(model))

train_data = train_data.to(device).double()


loss_fn = torch.nn.MSELoss().to(device)
# Adam (AMSGrad variant) with NO weight decay (weight_decay=0); gradients are
# clipped to total norm 1 in the loop below.
optimizer = optim.Adam(model.parameters(), lr=.005,  betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=True)
epochs = 1000
history = {"loss": []}
for i in range(epochs):
  # Full-batch step: the whole training set is passed through the model each epoch.
  optimizer.zero_grad()
  output = model(train_data)

  # Decaying WTA: every WTA_DECAY_EVERY epochs, shrink k (the number of code
  # channels kept per position) by one, never below 1. This happens after this
  # epoch's forward pass, so the new k applies from the next epoch.
  if i % WTA_DECAY_EVERY == 0 and i != 0:
    model.wta.k = max(1, model.wta.k - 1)
    print("model.wta.k: ", model.wta.k)

  loss = loss_fn(output, train_data)
  loss.backward()
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
  optimizer.step()
  history["loss"].append(float(loss))
  if i % LOG_EVERY == 0:
    # "Code_Sparsity" is the count of non-zero code entries (lower = sparser).
    print("Epoch : {} \t Loss : {} \t Code_Sparsity: {} ".format(i, round(float(loss),7), torch.count_nonzero(model.code)))







Using device:  cuda
TCNAutoencoder trainable parameters:  40824
Epoch : 0 	 Loss : 0.0692632 	 Code_Sparsity: 990912 
Epoch : 1 	 Loss : 0.0671644 	 Code_Sparsity: 990912 
Epoch : 2 	 Loss : 0.0652052 	 Code_Sparsity: 990912 
Epoch : 3 	 Loss : 0.0634611 	 Code_Sparsity: 990912 
Epoch : 4 	 Loss : 0.0618783 	 Code_Sparsity: 990912 
Epoch : 5 	 Loss : 0.0603059 	 Code_Sparsity: 990912 
Epoch : 6 	 Loss : 0.0586681 	 Code_Sparsity: 990912 
Epoch : 7 	 Loss : 0.0569638 	 Code_Sparsity: 990912 
Epoch : 8 	 Loss : 0.0552187 	 Code_Sparsity: 990912 
Epoch : 9 	 Loss : 0.0534638 	 Code_Sparsity: 990912 
Epoch : 10 	 Loss : 0.0517364 	 Code_Sparsity: 990912 
Epoch : 11 	 Loss : 0.0500593 	 Code_Sparsity: 990912 
Epoch : 12 	 Loss : 0.0484638 	 Code_Sparsity: 990912 
Epoch : 13 	 Loss : 0.0469675 	 Code_Sparsity: 990912 
Epoch : 14 	 Loss : 0.0455781 	 Code_Sparsity: 990912 
Epoch : 15 	 Loss : 0.0442919 	 Code_Sparsity: 990912 
Epoch : 16 	 Loss : 0.0430828 	 Code_Sparsity: 990912 
Epoch : 17 

# Model Evaluation 

In [17]:
#@title Test Reconstruction

index_example = 1

# Two consecutive training grooves, A (= input) and B (= input2), each (1, C, T).
input = train_data[index_example:index_example+1,:,:]
input2 = train_data[index_example+1:index_example+2,:,:]

# Build two longer test inputs by concatenating along time:
#   inputb = A A B    and    input = B A A
# NOTE(review): `input` shadows the Python builtin of the same name.
inputb = torch.cat([input, input, input2], dim=-1)
input = torch.cat([input2, input, input], dim=-1)

print(inputb.shape)
print("Original: ")
numpy_input = input[0,:,:].detach().cpu().numpy().T
original_midi = to_ProtoMidi_4bar(numpy_input)
play(original_midi)
print("\n")

print("interpolate: ")
# Inference only: no autograd graph is needed.
with torch.no_grad():
  # Sparse codes of both inputs (model.code is replaced on every forward).
  output1 = model(input)
  code1 = model.code
  output2 = model(inputb)
  code2 = model.code
  # Decode the average of the two sparse codes.
  new_code = (code1 + code2) / 2.0
  output = model.decoder(new_code)

output = output[0,:,:].detach().cpu().numpy().T
midi = to_ProtoMidi_4bar(output)
play(midi)
print("\n")
print("\n")


torch.Size([1, 27, 192])
Original: 




interpolate: 




