In [1]:
''' imports '''

# set auto reload imported modules
%load_ext autoreload
%autoreload 2

# general imports
import os, shutil

# numpy for array handling
import numpy as np

# import pytorch core libs
import torch

# write audio to file (NOTE: `librosa.output` was removed in librosa 0.8.0, so this import needs librosa < 0.8)
from librosa.output import write_wav


''' sample-rnn components '''
# add sample-rnn libs directory to path
import sys
sys.path.append('../libs/samplernn/')

# import core sample-rnn model (inc. frame-lvl rnn and sample-lvl mlp)
from model import SampleRNN
from model import Predictor
from model import Generator

# wrapper for optimiser
from optim import gradient_clipping

# training criterion
from nn import sequence_nll_loss_bits

# import audio dataset management
from dataset import FolderDataset
from dataset import DataLoader


In [2]:
''' initialise models components '''

# model hyper-parameters
_frame_sizes = (16, 4)
_n_rnn = 1
_dim = 1024
_learn_h0 = True
_q_levels = 256 # 8 bit depth
_weight_norm = True

# gather the constructor arguments in one place so they are easy to inspect
_model_kwargs = dict(
    frame_sizes = _frame_sizes,
    n_rnn = _n_rnn,
    dim = _dim,
    learn_h0 = _learn_h0,
    q_levels = _q_levels,
    weight_norm = _weight_norm,
)

# build the core sample-rnn model
model = SampleRNN(**_model_kwargs)

# wrap the same model twice: `predictor` is used by the training loop,
# `generator` by the sample generation cell
predictor = Predictor(model)
generator = Generator(model)


  init(chunk)


In [3]:
''' push to device '''

# pick the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# move the model and the predictor built around it onto the chosen device
# NOTE(review): `generator` is not moved explicitly here -- confirm it follows
# the device of the `model` it wraps
model = model.to(device)
predictor = predictor.to(device)


In [4]:
''' init optimiser '''

# parameters to optimise: everything registered on the predictor
params = predictor.parameters()

# plain Adam with its default settings ...
_base_optimizer = torch.optim.Adam(params)

# ... passed through the project's `gradient_clipping` optimiser wrapper
# (presumably clips gradients at step time -- see libs/samplernn/optim.py)
optimizer = gradient_clipping(_base_optimizer)


In [5]:
''' initialise dataset and dataloader '''

# batching hyper-parameters
_batch_size = 64
_seq_len = 1024
_train_frac = 0.9

# location of the audio dataset on disk
_datasets_path = '../data/'
_dataset = 'piano-small'
_path = os.path.join(_datasets_path, _dataset)

# overlap between consecutive sequences, taken from the model's lookback
_overlap_len = model.lookback

# initialise dataset
# NOTE(review): the trailing positional arguments (0, _train_frac) presumably
# select the first 90% of the data for training -- confirm against the
# FolderDataset signature
train_dataset = FolderDataset(_path, _overlap_len, _q_levels, 0, _train_frac)

# initialise dataloader
_loader_options = dict(
    batch_size = _batch_size,
    seq_len = _seq_len,
    overlap_len = _overlap_len,
)
train_data_loader = DataLoader(train_dataset, **_loader_options)


In [6]:
''' training loop '''

# number of passes over the training data
epochs = 1

# train the model: one full pass over the data loader per epoch
for _epoch in range(epochs):

    # put the predictor into training mode
    predictor.train()

    # iterate over the batches yielded by the data loader
    for (_iteration, data) in enumerate(train_data_loader):

        # unpack batch: the first element is the model input, the last
        # element is the target sequence
        batch_inputs = data[0].to(device)
        batch_target = data[-1].to(device)

        def closure():
            ''' clear gradients, run forward + backward pass, return the loss '''

            # gradients are cleared inside the closure so that every
            # evaluation starts clean, even if the optimiser calls the
            # closure more than once per step
            optimizer.zero_grad()

            # forward pass
            # NOTE(review): `reset` is hard-coded to False although the
            # loader also yields `data[1]` (previously passed as `reset`);
            # confirm that never resetting the predictor here is intended
            batch_output = predictor(batch_inputs, reset = False)

            # loss between predictions and targets
            loss = sequence_nll_loss_bits(batch_output, batch_target)

            # backward pass: populate gradients for the optimiser step
            loss.backward()

            return loss

        # the optimiser evaluates the closure and then updates the parameters
        optimizer.step(closure)


  return F.log_softmax(x.view(-1, self.q_levels)) \


In [7]:
''' perform sample generation '''

# directory the generated audio is written to
_output_path = '../data/'

_sample_rate = 16000
_n_samples = 1
_sample_length = int(_sample_rate * 3) # 3 seconds of audio at _sample_rate

# run the generator; sampling needs no gradients, so skip autograd bookkeeping
with torch.no_grad():
    samples = generator(_n_samples, _sample_length).cpu().float().numpy()

# write each generated sample to its own wav file
for i in range(_n_samples):

    # a single sample keeps the original file name; several samples get an
    # index suffix so later ones do not overwrite earlier ones
    if _n_samples == 1:
        _file_name = 'test-out.wav'
    else:
        _file_name = 'test-out-{}.wav'.format(i)

    write_wav(
        os.path.join(_output_path, _file_name),
        samples[i, :], sr = _sample_rate, norm = True)
    

  prev_samples = torch.autograd.Variable(
  prev_samples = torch.autograd.Variable(


In [27]:
''' save checkpoint '''

# snapshot the predictor's weights (state dict only, not the module object)
_chkpt_state = predictor.state_dict()

# serialise the snapshot to disk
torch.save(_chkpt_state, '../data/chkpt-sml')


In [21]:
''' load checkpoint '''

# read the saved state dict from disk; `map_location` remaps the stored
# tensors onto the current `device`, so a checkpoint written on a GPU machine
# still loads on a CPU-only one
# NOTE: torch.load unpickles the file -- only load checkpoints you trust
#_state_dict = torch.load('../data/chkpt')
_state_dict = torch.load('../data/chkpt-sml', map_location = device)

# copy the loaded weights into the predictor (the result reports key matching)
predictor.load_state_dict(_state_dict)
    

<All keys matched successfully>