<a href="https://colab.research.google.com/github/jmhuer/shift_invariant_dictionary_learning/blob/main/dictionarylearning_midipiano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up environment


### Setup Environment and Dependencies

In [1]:
#@title Clone/Install all dependencies
!git clone https://github.com/jmhuer/shift_invariant_dictionary_learning
!pip install tqdm
!pip install progress
!pip install pretty-midi
!pip install pypianoroll
!pip install matplotlib
!pip install librosa
!pip install scipy
!pip install pillow
!apt install fluidsynth #Pip does not work for some reason. Only apt works
!pip install midi2audio
!pip install mir_eval
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 /content/font.sf2

%cd /content/shift_invariant_dictionary_learning/maestro

Cloning into 'shift_invariant_dictionary_learning'...
remote: Enumerating objects: 293, done.[K
remote: Counting objects: 100% (293/293), done.[K
remote: Compressing objects: 100% (280/280), done.[K
remote: Total 293 (delta 122), reused 27 (delta 9), pack-reused 0[K
Receiving objects: 100% (293/293), 90.65 MiB | 7.36 MiB/s, done.
Resolving deltas: 100% (122/122), done.
Collecting progress
  Downloading progress-1.5.tar.gz (5.8 kB)
Building wheels for collected packages: progress
  Building wheel for progress (setup.py) ... [?25l[?25hdone
  Created wheel for progress: filename=progress-1.5-py3-none-any.whl size=8086 sha256=570cbb6b608cf4053d2eb4dc6005f3872a02989c1b0629b0f41dd5c8f9974f09
  Stored in directory: /root/.cache/pip/wheels/4c/ff/85/0cabf2cb317421028ef98853ae5c8d84c31f3e4e11862ea977
Successfully built progress
Installing collected packages: progress
Successfully installed progress-1.5
Collecting pretty-midi
  Downloading pretty_midi-0.2.9.tar.gz (5.6 MB)
[K     |████████

In [2]:
#@title Import all needed modules
import numpy as np
import pickle
import os
import sys
import math
import random
# For plotting
import pypianoroll
from pypianoroll import Multitrack, Track
import matplotlib
import matplotlib.pyplot as plt
#matplotlib.use('SVG')
#%matplotlib inline
#matplotlib.get_backend()
import mir_eval.display
import librosa
import librosa.display
# For rendering output audio
import pretty_midi
from midi2audio import FluidSynth
from google.colab import output
from IPython.display import display, Javascript, HTML, Audio

# Option 1: MAESTRO Dataset

In [3]:
#@title Download Google Magenta MAESTRO v.2.0.0 Piano MIDI Dataset (~1300 MIDIs)
!wget 'https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip' -P "/content/shift_invariant_dictionary_learning/maestro/dataset/"
!unzip "/content/shift_invariant_dictionary_learning/maestro/dataset/maestro-v2.0.0-midi.zip" -d "/content/shift_invariant_dictionary_learning/maestro/dataset/"


--2021-07-19 20:46:07--  https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.142.128, 74.125.195.128, 142.250.99.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.142.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 59243107 (56M) [application/zip]
Saving to: ‘/content/shift_invariant_dictionary_learning/maestro/dataset/maestro-v2.0.0-midi.zip’


2021-07-19 20:46:09 (32.9 MB/s) - ‘/content/shift_invariant_dictionary_learning/maestro/dataset/maestro-v2.0.0-midi.zip’ saved [59243107/59243107]

Archive:  /content/shift_invariant_dictionary_learning/maestro/dataset/maestro-v2.0.0-midi.zip
   creating: /content/shift_invariant_dictionary_learning/maestro/dataset/maestro-v2.0.0/
  inflating: /content/shift_invariant_dictionary_learning/maestro/dataset/maestro-v2.0.0/maestro-v2.0.0.csv  
   creating: /content/shift_invariant_dictio

In [None]:
#@title Process MAESTRO MIDI DataSet
# Run the repository's preprocessing script on the unzipped MAESTRO folder.
# The script path is relative, so this relies on the working directory set by the `%cd` at the
# end of the setup cell. According to the script's own log it writes to ./dataset/e_piano.
!python3 midi/preprocess_midi.py '/content/shift_invariant_dictionary_learning/maestro/dataset/maestro-v2.0.0'

Preprocessing midi files and saving to ./dataset/e_piano
Found 1282 pieces
Preprocessing...
50 / 1282
100 / 1282


In [None]:
# from processor import encode_midi, decode_midi

                  
# from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy


# train_dataset, val_dataset, test_dataset = create_epiano_datasets("/content/MusicTransformer-Pytorch/dataset/e_piano", 2048)

# example1 = list(train_dataset)[1].numpy()
# print(example1)
# # print("torch size ", train_dataset.size())

# # print(len(list(train_dataset)[9][0]))
# # tmp = []
# name = "test111"
# # for point in train_dataset:
# #     # isthis = decode_midi(point[0].numpy(), name + ".mid")
# #     isthis = decode_midi(point[0].numpy())
# #     tmp.append(isthis.estimate_tempo())


# # print("tempo:" , tmp )

# decode_midi(example1[0:2048], name + ".mid")
# FluidSynth("/content/font.sf2").midi_to_audio(name + ".mid", name + ".wav")
# Audio(name + ".wav")


In [None]:
from midi.processor import encode_midi, decode_midi
import torch
                  
from dataset.e_piano import create_epiano_datasets, compute_epiano_accuracy


# Build the train / validation / test splits from the preprocessed e_piano folder.
# NOTE(review): the second argument (2048) looks like a maximum sequence length — confirm
# against create_epiano_datasets.
train_dataset, val_dataset, test_dataset = create_epiano_datasets("/content/shift_invariant_dictionary_learning/maestro/dataset/e_piano", 2048)

# Print a short summary instead of one line per piece, which floods the cell output.
print("number of training pieces:", len(train_dataset))
for i in range(min(5, len(train_dataset))):
    print("train_dataset size", train_dataset[i].size())



# Model definitions

In [None]:
#@title Pytorch for DL

import torch.nn.functional as F
import torch.optim as optim
from torch import nn
import torch
from torch.nn.utils import weight_norm
import numpy as np

def get_model_parameters(model):
    """Count the trainable parameters of ``model``.

    Args:
        model: Object whose ``parameters()`` yields tensors (e.g. an ``nn.Module``).

    Returns:
        int: Total number of elements over all parameters with ``requires_grad=True``;
        0 when there are none.
    """
    # numel() is an exact Python int. np.prod over the empty shape of a 0-dim parameter is the
    # float 1.0, which silently turned the whole count into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
#@title KWTA


class SparsifyBase(nn.Module):
    """Base class for sparsifying layers, with optional activation recording.

    Subclasses implement ``forward``. After ``record_activation()`` has been called, every
    forward pass stores detached CPU copies of the layer's first positional input
    (``self.preact``) and of its output (``self.act``).

    Args:
        sparse_ratio: Stored as ``self.sr``; not used by this class itself.
    """

    def __init__(self, sparse_ratio=0.5):
        super().__init__()
        self.sr = sparse_ratio
        self.preact = None
        self.act = None
        # Handle of the recording hook, kept so that it is registered at most once.
        self._activation_hook = None

    def get_activation(self):
        """Return a forward hook that snapshots this layer's input and output."""
        def hook(model, input, output):
            # Detach first so the CPU copies are never part of the autograd graph.
            self.preact = input[0].detach().cpu().clone()
            self.act = output.detach().cpu().clone()
        return hook

    def record_activation(self):
        """Start recording activations on every forward pass.

        Calling this repeatedly no longer stacks duplicate hooks: the hook is registered only
        when none is currently active.

        Returns:
            The handle of the registered hook; call ``.remove()`` on it to stop recording.
        """
        handle = getattr(self, "_activation_hook", None)
        # Re-register when there is no handle yet or the caller removed the previous hook.
        if handle is None or handle.id not in self._forward_hooks:
            handle = self.register_forward_hook(self.get_activation())
            self._activation_hook = handle
        return handle


class Sparsify1D_kactive(SparsifyBase):
    """k-winners-take-all along dim 1.

    For every position of the remaining dimensions, the ``k`` largest entries along dim 1 keep
    their value and all other entries are set to zero.

    Args:
        k: Number of entries to keep along dim 1. It is read on every forward pass, so it can
            be changed between calls (the training loop lowers it over time).
    """

    def __init__(self, k=1):
        super(Sparsify1D_kactive, self).__init__()
        self.k = k

    def forward(self, x):
        """Return ``x`` with everything but the top-``k`` entries along dim 1 zeroed.

        Args:
            x: Tensor with at least two dimensions.

        Returns:
            A float64 tensor of the same shape as ``x``, on the same device as ``x``.
        """
        # Keep at most as many winners as dim 1 has entries; k <= 0 keeps nothing.
        k = max(0, min(self.k, x.size(1)))
        if k == 0:
            return torch.zeros_like(x).double()
        # A single top-k call yields all winner indices. The mask is built on x's own device
        # and dtype, so the layer needs no notebook-level `device` variable and does not pass
        # the values through a lower-precision buffer.
        winners = x.topk(k, dim=1)[1]
        mask = torch.zeros_like(x).scatter_(1, winners, 1)
        return (mask * x).double()

In [None]:
#@title TCN 

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` entries along the last dimension of a 3-D tensor.

    Args:
        chomp_size: Number of trailing entries to remove; 0 leaves the input unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its last ``chomp_size`` steps, as a contiguous tensor."""
        # x[:, :, :-0] is x[:, :, :0], i.e. an EMPTY tensor, so zero has to be special-cased
        # (it occurs when kernel_size == 1 makes the padding 0).
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()


class TemporalBlock(nn.Module):
    """Residual block built from two weight-normalised dilated 1-D convolutions.

    Each convolution is followed by ``Chomp1d(padding)``, a ReLU and dropout. The block input is
    added to the output of that stack (through a 1x1 convolution when the channel counts
    differ) and the sum goes through a final ReLU.

    Args:
        n_inputs: Number of input channels.
        n_outputs: Number of output channels.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of both convolutions.
        dilation: Dilation of both convolutions.
        padding: Padding given to both convolutions; the same number of trailing steps is
            removed again by the ``Chomp1d`` layers.
        dropout: Dropout probability applied after each ReLU.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, dilation=dilation)

        # First convolution stage.
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_options))
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # Second convolution stage.
        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_options))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        stages = [
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        ]
        self.net = nn.Sequential(*stages)

        # The skip path needs a 1x1 convolution only when the channel count changes.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Fill the convolution weights with samples from a normal(0, 1) distribution."""
        convs = [self.conv1, self.conv2]
        if self.downsample is not None:
            convs.append(self.downsample)
        for conv in convs:
            conv.weight.data.normal_(0, 1)

    def forward(self, x):
        """Apply both convolution stages and add the (possibly projected) input."""
        stacked = self.net(x)
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)
        return self.relu(stacked + shortcut)


class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock`` layers whose dilation doubles at every level.

    Args:
        num_inputs: Number of channels of the network input.
        num_channels: Sequence with the output channel count of each level, first level first.
        kernel_size: Kernel size shared by all blocks.
        dropout: Dropout probability shared by all blocks.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        previous_width = num_inputs
        for level, width in enumerate(num_channels):
            dilation = 2 ** level
            blocks.append(
                TemporalBlock(previous_width, width, kernel_size, stride=1, dilation=dilation,
                              padding=(kernel_size - 1) * dilation, dropout=dropout)
            )
            # The next level consumes what this level produces.
            previous_width = width
        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through all blocks in order."""
        return self.network(x)


In [None]:
#@title TCN - Autoencoder

class TCNAutoencoder(nn.Module):
    """Autoencoder: TCN features -> strided conv encoder -> k-WTA code -> transposed-conv decoder.

    The decoder's kernels act as the learned dictionary (see ``get_kernels``).

    Args:
        kernel_size: Kernel size of the TCN blocks, the encoder and the decoder.
        dropout: Dropout probability inside the TCN.
        wta_k: Initial number of active code channels kept by the k-WTA layer.
        n_atoms: Number of code channels, i.e. decoder kernels (previously hard-coded to 1000).
        feature_channels: Output widths of the TCN levels (previously hard-coded to 8, 16, 24).
        stride: Stride of the encoder and decoder (previously hard-coded to 4).
    """

    def __init__(self, kernel_size, dropout, wta_k, n_atoms=1000, feature_channels=(8, 16, 24), stride=4):
        super(TCNAutoencoder, self).__init__()
        feature_channels = list(feature_channels)
        # Sub-modules are created in the same order as before, so parameter initialisation and
        # state_dict keys are unchanged for the default arguments.
        self.wta = Sparsify1D_kactive(k=wta_k)
        self.feature = TemporalConvNet(1, feature_channels, kernel_size, dropout=dropout).double()
        self.encoder = torch.nn.Conv1d(in_channels=feature_channels[-1], out_channels=n_atoms,
                                       kernel_size=kernel_size, padding=0, bias=True, stride=stride)
        self.decoder = torch.nn.ConvTranspose1d(in_channels=n_atoms, out_channels=1,
                                                kernel_size=kernel_size, padding=0, bias=True, stride=stride)
        # Not used in forward(); kept so existing attribute access keeps working.
        self.relu1 = nn.ReLU()
        # Sparse code of the most recent forward pass (None before the first one).
        self.code = None

    def get_kernels(self):
        """Return the decoder weights for output channel 0, one row per code channel."""
        return self.decoder.weight.data[:, 0, :]

    def feature_map(self, x):
        """Return the code cached by the most recent ``forward`` call.

        NOTE(review): ``x`` is ignored — this does not run the model on ``x``. Call the model
        on ``x`` first if the code for a specific input is needed.
        """
        code = self.code
        return code

    def forward(self, x):
        """Reconstruct ``x`` and cache its sparse code in ``self.code``.

        Args:
            x: Input of shape (N, C, L), as required by the 1-D convolutions.

        Returns:
            The decoder output for the sparsified code.
        """
        output = self.feature(x)
        self.code = self.wta(self.encoder(output))
        output = self.decoder(self.code)
        return output


# Model training 

In [None]:
#@title GO

# ---- configuration (everything a reader might tune) ----
SEED = 0
EPOCHS = 60
LEARNING_RATE = 0.005
KERNEL_SIZE = 4
DROPOUT = 0.2
WTA_K_START = 100       # initial number of active code channels
WTA_DECAY_EVERY = 10    # lower k every this many epochs ...
WTA_DECAY_STEP = 5      # ... by this amount, never going below 1
SEQ_MULTIPLE = 4        # inputs are trimmed to a multiple of the encoder/decoder stride (4)
# Normalisation statistics; presumably the mean / std of the training events, as computed by
# the commented-out `calc` code kept in this cell — re-derive them if the data changes.
DATA_MEAN = 224.15541543187527
DATA_STD = 111.14747885919755

# Seed every RNG in play so that model initialisation and dropout are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: ", device)

model = TCNAutoencoder(kernel_size=KERNEL_SIZE,
                       dropout=DROPOUT,
                       wta_k=WTA_K_START).to(device).double()
print("TCNAutoencoder trainable parameters: ", get_model_parameters(model))

# model.load_state_dict(torch.load("model.pth"))

loss_fn = torch.nn.MSELoss().to(device)
# optimizer = optim.SGD(model.parameters(), lr=.01, weight_decay = 0.00001, momentum=0.05)
# AdamW with AMSGrad; weight decay is disabled here (weight_decay=0).
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=True)
epochs = EPOCHS
# history["loss"][e] is the mean training loss of epoch e. It is kept after training rather
# than being wiped at the end of every epoch.
history = {"loss": []}
calc = []  # only filled by the commented-out statistics line below

model.train()
for i in range(epochs):
    # decaying WTA
    if i % WTA_DECAY_EVERY == 0 and i != 0:
        model.wta.k = max(1, model.wta.k - WTA_DECAY_STEP)
        print("model.wta.k: ", model.wta.k)

    epoch_losses = []
    for train_data in train_dataset:
        # calc.extend(train_data.flatten().numpy())
        # Trim to a multiple of the stride so that the strided encoder / transposed decoder
        # pair returns a sequence of the same length as its input.
        usable_len = (len(train_data) // SEQ_MULTIPLE) * SEQ_MULTIPLE
        if usable_len == 0:
            # Shorter than one encoder window: nothing to reconstruct.
            continue

        # normalize, then shape to (N=1, C=1, L)
        train_data = (train_data - DATA_MEAN) / DATA_STD
        train_data = train_data[None, None, 0:usable_len].to(device).double()

        optimizer.zero_grad()
        output = model(train_data)

        loss = loss_fn(output, train_data)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        epoch_losses.append(loss.item())

    # Guard against an epoch in which every sequence was skipped.
    epoch_mean = float(np.mean(epoch_losses)) if epoch_losses else float("nan")
    history["loss"].append(epoch_mean)
    print("Epoch : {} \t Loss : {} ".format(i, round(epoch_mean, 7)))

# print(len(calc))
# print(np.mean(calc, axis=0))
# print(np.std(calc, axis=0))


In [None]:
#@title Test Reconstruction

def get_code(model, input):
    """Run ``model`` on ``input`` and return the code it cached during that call.

    Args:
        model: Callable that stores its latest code in ``model.code`` when called.
        input: Value passed to ``model`` unchanged.

    Returns:
        Whatever ``model.code`` holds after the call.
    """
    _ = model(input)  # the forward pass is run only for its side effect on model.code
    latest_code = model.code
    return latest_code


def play_example(input, basename=None):
    """Render the first 2048 entries of ``input`` to audio and return a playable widget.

    Writes ``<basename>.mid`` and ``<basename>.wav`` into the working directory.

    Args:
        input: Event sequence accepted by ``decode_midi`` (presumably the integer event
            encoding produced by the preprocessing step — confirm against ``midi.processor``).
        basename: File name stem for the generated files. Defaults to the notebook-level
            variable ``name``, which is what this function always used before.

    Returns:
        IPython.display.Audio: The audio player. It is returned because an ``Audio`` object
        created inside a function is never displayed on its own; make the call the last
        expression of a cell, or wrap it in ``display(...)``.
    """
    if basename is None:
        basename = name  # notebook-level variable set by the playback cells
    midi_path = basename + ".mid"
    wav_path = basename + ".wav"
    decode_midi(input[0:2048], midi_path)
    FluidSynth("/content/font.sf2").midi_to_audio(midi_path, wav_path)
    return Audio(wav_path)


#make it a keep top n 
def exchange_max_rows(A, B):
    """Copy each tensor's largest-sum row into the other tensor, in place.

    The row of ``B`` with the largest sum overwrites the row of ``A`` at that same index, and
    the row of ``A`` with the largest sum overwrites the row of ``B`` at that same index. Both
    rows are captured before anything is written.

    Args:
        A: Tensor that is modified in place.
        B: Tensor that is modified in place.

    Returns:
        tuple: The same ``A`` and ``B`` objects after modification.
    """
    top_a = int(torch.argmax(A.sum(1)))
    top_b = int(torch.argmax(B.sum(1)))
    # Clone first so the writes below cannot affect the rows being copied.
    strongest_a = A[top_a:top_a + 1].clone()
    strongest_b = B[top_b:top_b + 1].clone()
    A[top_b:top_b + 1] = strongest_b
    B[top_a:top_a + 1] = strongest_a
    return A, B

#make it a keep top n 
def keep_topk(A, k):
    """Zero every row of ``A`` except the ``k`` rows with the largest sum.

    Args:
        A: 2-D tensor; it is not modified.
        k: Number of rows to keep. Values larger than the number of rows keep all rows.

    Returns:
        A new tensor shaped like ``A`` in which the rows that were not selected are zero.
    """
    row_scores = A.sum(1)
    # topk raises when k exceeds the number of candidates, so clamp it.
    k = min(k, row_scores.shape[-1])
    # Build the mask on A's own device instead of relying on a notebook-level `device`.
    mask = torch.zeros(A.shape, device=A.device)
    v, i = torch.topk(row_scores, k)
    print("\n index is", i)
    mask[i] = 1
    return mask * A




In [None]:
    
# Pick one training piece and listen to the start of it, unmodified.
index_example = 207
raw_input = train_dataset[index_example]

# Quick look at the sequence: its shape and its largest value.
for summary in (raw_input.shape, raw_input.max()):
    print(summary)

print("orginal 1")
name = "music"
midi_path, wav_path = name + ".mid", name + ".wav"

# Only the first 300 entries are rendered.
preview_events = raw_input.numpy()[:300]
decode_midi(preview_events, midi_path)
synth = FluidSynth("/content/font.sf2")
synth.midi_to_audio(midi_path, wav_path)
Audio(wav_path)





In [None]:
print("reconstructed")
name = "music_rec"

# Inference has to undo exactly the normalisation the model was trained with, so these are the
# constants from the training loop. This cell previously used slightly different values.
# NOTE(review): if the weights come from a checkpoint trained with other statistics, use those.
DATA_MEAN = 224.15541543187527
DATA_STD = 111.14747885919755
PREVIEW_EVENTS = 300  # number of reconstructed entries rendered to audio

# Keep raw_input untouched so that re-running this cell does not normalise it twice.
normalized_input = (raw_input - DATA_MEAN) / DATA_STD

# Trim to a multiple of 4, as in training, and shape to (N=1, C=1, L).
train_data = normalized_input[None, None, 0:(len(raw_input) // 4)*4].to(device).double()
print("train_data size", train_data.shape)

# Evaluate with dropout switched off and without building an autograd graph, then restore the
# mode the model was in.
was_training = model.training
model.eval()
with torch.no_grad():
    model_out = model(train_data)[0, 0, :]
model.train(was_training)
model_out = (model_out * DATA_STD) + DATA_MEAN

print("model_out size", model_out.shape)
print("model_out max", model_out.max())

# Round to the nearest integer; a plain astype(int) truncates towards zero and would bias every
# reconstructed value downwards.
reconstructed_events = np.rint(model_out.cpu().numpy()).astype(int)
print(reconstructed_events)
decode_midi(reconstructed_events[0:PREVIEW_EVENTS], name + ".mid")
FluidSynth("/content/font.sf2").midi_to_audio(name + ".mid", name + ".wav")
Audio(name + ".wav")

In [None]:
# Save only the model's state_dict to "model.pth" in the current working directory. This is the
# file name used by the commented-out `model.load_state_dict(torch.load("model.pth"))` line in
# the training cell.
torch.save(model.state_dict(), "model.pth")


In [None]:
# Colab-only helper: hand the checkpoint written by the previous cell to the browser as a
# download. Requires that "model.pth" exists in the current working directory.
from google.colab import files
files.download('model.pth') 