In [376]:
''' imports '''

# set auto reload imported modules
%load_ext autoreload
%autoreload 2

# general imports
#import os, shutil

# numpy for array handling
import numpy as np

# import pytorch core libs
import torch

# write audio to file
#from librosa.output import write_wav

import pickle

# audio playback widget
import IPython.display as ipd

%matplotlib widget
import matplotlib.pyplot as plt

# add sample-rnn libs directory to path
import sys
sys.path.append('../melodyrnn/')

# import core sample-rnn model (inc. frame-lvl rnn and sample-lvl mlp)
from model import SampleRNN
from model import Predictor
from model import Generator

# wrapper for optimiser
from optim import gradient_clipping

# training criterion
from nn import sequence_nll_loss_bits

# import audio dataset management
#from dataset import SampleRNNDataset
from dataset import MelodyDataset

#from dataset import DataLoader
from dataset import MelodyDataLoader

from utils import build_audio


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
''' init dataset '''

# directory path
dir_path = '../data/MIDI/'

# init dataset
dataset = MelodyDataset(
    dir_path,
)


In [384]:
''' initialise models components '''

# model parameters
#_frame_sizes = (16, 8, 4)
_frame_sizes = (4, 8, 16)
#_frame_sizes = (4, 4, 4)
_n_rnn = 1
_dim = 1024
_learn_h0 = False
#_learn_h0 = True
#_q_levels = 256 # 8 bit depth
#_q_levels = 2**10
_q_levels = 2**7
_weight_norm = True

# initialise sample-rnn model
model = SampleRNN(
    frame_sizes = _frame_sizes,
    n_rnn = _n_rnn,
    dim = _dim,
    learn_h0 = _learn_h0,
    q_levels = _q_levels,
    weight_norm = _weight_norm
)

# intitialise predictor model
predictor = Predictor(model)

generator = Generator(model)


In [385]:
''' push to device '''

# get computing device
device = ("cuda" if torch.cuda.is_available() else "cpu")

# push models to device
model = model.to(device)
predictor = predictor.to(device)


In [386]:
''' init optimiser '''

# get model parameters
params = predictor.parameters()

# initialise optimiser
optimizer = gradient_clipping( torch.optim.Adam(params, lr = 1e-4) )
#optimizer = torch.optim.Adam(params)


In [382]:
''' init dataloader '''

def collate(samples: list):
    
    ''' pad and collate list tracks for batch '''
    
    #print(samples[0].shape)
    
    # get max length sample in batch
    s = max([ sample.shape[0] for sample in samples ])
    
    # pad each sample to max length on time axis
    samples = [ np.pad(sample, (0, s - sample.shape[0]))
               for sample in samples ]
    
    # stack mini-batch samples, adjust batch to index zero
    samples = np.stack(samples, axis = 1).transpose(1, 0)
    
    # return mini-batch as tensors
    return torch.LongTensor(samples)


_overlap_len = model.lookback
print('overlap_len ', _overlap_len)

#_seq_len = 1024
_seq_len = 512

_batch_size = 32

# intitialise dataloader
train_data_loader = MelodyDataLoader(
    dataset,
    batch_size = _batch_size,
    seq_len = _seq_len,
    overlap_len = _overlap_len,
    
    num_workers = 4,
    shuffle = True,
    
    collate_fn = collate,
    drop_last = True,
)


overlap_len  512


In [359]:
''' training loop '''

# set training epochs
epochs = 5

# perform training model over epochs, iterate over range epoch limit
for _epoch in range(epochs):

    print('epoch: ', _epoch)
    
    ## model training, given dataset compute loss and update model parameters
    
    # set model to training mode (gradients stored)
    predictor.train()
    
    # iterate over dataset
    for (_iteration, data) in enumerate(train_data_loader):

        #print('iteration: ', _iteration)
        
        # zero gradients and step optimiser
        optimizer.zero_grad()

        # unpack dataset
        batch_inputs = data[0].to(device)
        batch_target = data[-1].to(device)
        
        #print(batch_inputs.shape)
        #print(batch_target.shape)
        
        # reevaluate the function multiple times; clear the gradients, compute and return loss
        def closure():

            # pass inputs through model, return output
            batch_output = predictor(batch_inputs, reset = data[1])
            #batch_output = predictor(batch_inputs, reset = False)

            # calculate loss for inputs to outputs
            loss = sequence_nll_loss_bits(batch_output, batch_target)

            #print(loss.item())

            # calculate gradients and return loss
            loss.backward()

            return loss

        # step optimiser with closure
        optimizer.step(closure)


epoch:  0


  return F.log_softmax(x.view(-1, self.q_levels)) \


epoch:  1
epoch:  2


KeyboardInterrupt: 

In [388]:
''' perform sample generation '''

_n_samples = 128
_sample_length = int(100000) # ~30s

# generate samples
samples = generator(_n_samples, _sample_length).cpu().float().numpy()


  prev_samples = torch.autograd.Variable(
  prev_samples = torch.autograd.Variable(
  return F.log_softmax(x.view(-1, self.q_levels)) \


In [389]:
''' save some generated samples '''
with open('../data/output/midi/smpls-4816-05', 'wb') as file:
    pickle.dump(samples, file)
    

In [395]:
''' convert output to note matrix '''

i = 2

m = samples[i]
_pad = 1024
M = np.zeros((128, m.shape[0]+_pad))
for i in range(m.shape[0]):
    if m[i] != 0:
        M[int(m[i]), _pad+i] = 1

''' display melody note matrix '''
fig = plt.figure(figsize=(7,3))
ax = fig.add_subplot(111)

# set downsample time axis [int]
ds = 128

# display track note matrix
plt.imshow(M[::-1,::ds], cmap = 'bone_r', alpha = 1.)

# format figure and display
#ax.set_ylim(35, 95)

ax.set_ylabel('midi note')
ax.set_xlabel('time [smpl]')

plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [396]:
''' build and play audio sample '''

# set sample rate
sr = 16000

# set frame rate of midi track, adjust for playback speed
fr = 2**18 // 2

#_,__ = 140, -1

# generate audio each track
#audio = build_audio(M[:, ds*_:ds*__], sr = sr, fr = fr)
audio = build_audio(M[:, :], sr = sr, fr = fr, soft = 1e-4)

# normalise audio
audio /= audio.max()

# display playback widget
ipd.Audio(audio, rate = sr)


In [11]:
''' convert output to note matrix '''

fig = plt.figure(figsize=(7,3))
ax = fig.add_subplot(111)

for i in range(len(dataset))[:50]:
    #i = 1
    m = dataset[i]
    print(dataset.file_names[i])

    M = np.zeros((128, m.shape[0]+512))

    for i in range(512, M.shape[1]-512):
        if m[i] != 0:
            M[int(m[i]), i] = 1

    # set downsample time axis [int]
    ds = 256

    # display track note matrix
    plt.imshow(M[:,::ds], cmap = 'bone_r', alpha = 0.1)

# format figure and display
#ax.set_ylim(0, 127)
ax.set_ylim(35, 95)

ax.set_ylabel('midi note')
ax.set_xlabel('time [smpl]')

#plt.tight_layout()
plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

086_ljubav .midi
117_my star.midi
153_hunter of stars.midi
109_lost and found .midi
145_warrior.midi
016_la chica que yo quiero .midi
019_sound of silence.midi
077_la la love.midi
139_playing with fire.midi
101_qele qele .midi
166_rock bottom.midi
174_net als toen.midi
198_made of stars.midi
080_i can.midi
014_we could be the same.midi
049_shady lady.midi
058_sing little birdie.midi
180_minn hinsti dans.midi
012_fairytale.midi
079_you let me walk alone .midi
159_ein bi.midi
024_occidentali_s karma .midi
189_whats another .midi
009_everyway that I can.midi
092_im alive.midi
154_silencio.midi
104_horehronie.midi
045_a million voices.midi
178_nous les.midi
125_visionary dream.midi
169_lusitana.midi
111_silent storm .midi
093_we are the winners .midi
157_take me to your heaven.midi
039_lejla.midi
144_hold me.midi
143_no no never.midi
147_save your kisses.midi
146_something.midi
022_if love was a crime.midi
056_nije ljubav.midi
099_in your eyes.midi
163_making up your mind.midi
199_ole ole.

In [12]:
''' build and play audio sample '''

# set sample rate
sr = 16000

# set frame rate of midi track, adjust for playback speed
fr = 2**17 // 1

# generate audio each track
#audio = build_audio(M[:, ds*110:ds*200], sr = sr, fr = fr)
audio = build_audio(M[:, :], sr = sr, fr = fr)

# normalise audio
audio /= audio.max()

# display playback widget
ipd.Audio(audio, rate = sr)


In [362]:
''' save checkpoint '''

torch.save(model.state_dict(), '../data/melody-state-444-02')


In [387]:
''' load checkpoint '''

#_state_dict = torch.load('../data/chkpt')
_state_dict = torch.load('../data/melody-state-05')

model.load_state_dict(_state_dict)
    

<All keys matched successfully>