In [None]:
import os
import numpy as np
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.autograd import Variable
import torch.nn.functional as F

from lib.parts_dataset import PartsDataset,DatasetSplit
from lib.parts_model import PartsModel
from lib.opt import optimize

import lib.media as media
from IPython.display import Audio
from scipy.io import wavfile

In [None]:
# GPU selection: enumerate devices in PCI bus order (see issue #152) and
# expose only device 0 to this process.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
})
#os.environ['CUDA_LAUNCH_BLOCKING']='1'

# Notebook-wide settings: first positional argument handed to Model(...)
# and the `context` / `context_length` value shared by the datasets and the model.
checkpoint_dir = '_singlepart16'
context = 10

In [None]:
# Build the train and test datasets with the same `context`; only the split
# differs. The train set is constructed first, then the test set.
train_set, test_set = (
    PartsDataset(context=context, split=split)
    for split in (DatasetSplit.train, DatasetSplit.test)
)

In [None]:
class Model(PartsModel):
    """Linear single-part model with a rhythm head and a notes head.

    `define_graph` allocates the weights and `forward` returns a pair of
    score tensors: `maxdur` rhythm scores and `m` note scores per example.
    How these scores are normalised or turned into a loss is decided by
    `PartsModel`, which is not visible from this notebook.
    """

    def __init__(self, *args, **kwargs):
        """Pass every argument straight through to `PartsModel`."""
        super().__init__(*args, **kwargs)

    def define_graph(self, debug=False):
        """Allocate the fixed reference signal and all trainable weights.

        Reads `maxdur` from the notebook-level `train_set`. Weights are
        created with `Tensor(shape)`, i.e. without initial values; they are
        presumably filled in by `initialize()` -- confirm in PartsModel.
        `debug` is accepted but not used by this method.
        """
        self.maxdur = train_set.maxdur
        m, maxdur, ctx = self.m, self.maxdur, self.context

        # reference meter for pitch class: a one-hot signal of length 2m-1
        # that predict_notes convolves with notes_wref.
        # NOTE(review): the 1 sits at index m, not at the centre m-1, so
        # notes_wref[0] never contributes and the first entry of the
        # resulting term is always 0 -- confirm this offset is intended.
        onehot = np.zeros([1, 1, 2*m - 1], dtype=np.float32)
        onehot[0, 0, m] = 1
        self.noteref = Variable(torch.from_numpy(onehot).cuda(), requires_grad=False)

        # rhythm prediction
        self.rhythm_wt = Parameter(Tensor(ctx*maxdur, maxdur))
        self.rhythm_wn = Parameter(Tensor(ctx*m, maxdur))
        self.rhythm_wloc = Parameter(Tensor(48, maxdur))
        self.rhythm_bias = Parameter(Tensor(maxdur))

        # notes prediction
        self.notes_wt = Parameter(Tensor(ctx*maxdur, m))
        self.notes_wn = Parameter(Tensor(1, 1, ctx, 2*m))
        self.notes_wyt = Parameter(Tensor(maxdur, m))
        self.notes_wyn = Parameter(Tensor(1, 1, m - 1))
        self.notes_wref = Parameter(Tensor(1, 1, m))
        self.notes_bias = Parameter(Tensor(m))

    def predict_rhythm(self, e, t, loc):
        """Return rhythm scores with `maxdur` columns.

        The result is the sum of three linear terms and a bias:
        `t` flattened to rows of `context*maxdur`, `e` flattened to rows of
        `context*m`, and `loc`, which must have 48 columns to match
        `rhythm_wloc`.
        """
        flat_t = t.view(-1, self.context*self.maxdur)
        flat_e = e.view(-1, self.context*self.m)

        scores = torch.mm(flat_t, self.rhythm_wt)
        scores = scores + torch.mm(flat_e, self.rhythm_wn)
        scores = scores + torch.mm(loc, self.rhythm_wloc)
        return scores + self.rhythm_bias

    def predict_notes(self, e, t, y, yt):
        """Return note scores with `m` columns.

        Terms, summed in this order:
          1. linear in the flattened `t` history;
          2. `e` padded by (m, m-1) along its last axis and convolved with
             the (context, 2m) kernel `notes_wn`;
          3. linear in `yt`;
          4. `y` left-padded by m-1 and convolved with the length-(m-1)
             kernel `notes_wyn`; the last output is dropped, so output
             position i only sees `y` at positions below i;
          5. the fixed reference signal convolved with `notes_wref`
             (a single row, broadcast over the batch);
          6. `notes_bias`.

        Assumes `e` is 3-D (batch, context, m) and `y` is 2-D (batch, m) --
        TODO confirm against PartsDataset.
        """
        m = self.m
        flat_t = t.view(-1, self.context*self.maxdur)
        # Insert a singleton channel axis so the conv ops accept the inputs.
        e_padded = F.pad(e, (m, m - 1, 0, 0, 0, 0)).unsqueeze(1)
        y_padded = F.pad(y, (m - 1, 0, 0, 0)).unsqueeze(1)

        scores = torch.mm(flat_t, self.notes_wt)
        scores = scores + F.conv2d(e_padded, self.notes_wn)[:, 0, 0, :]
        scores = scores + torch.mm(yt, self.notes_wyt)
        scores = scores + F.conv1d(y_padded, self.notes_wyn)[:, 0, :-1]
        scores = scores + F.conv1d(self.noteref, self.notes_wref)[:, 0]
        return scores + self.notes_bias

    def forward(self, x):
        """Unpack `x` as (e, t, y, yt, loc) and return (rhythm scores, note scores)."""
        e, t, y, yt, loc = x
        rhythm_scores = self.predict_rhythm(e, t, loc)
        note_scores = self.predict_notes(e, t, y, yt)
        return rhythm_scores, note_scores

In [None]:
# Instantiate the model from the training set's metadata, then initialise it.
# `weight_scale=0.` is forwarded as-is; its effect is defined by PartsModel.
model = Model(
    checkpoint_dir,
    avg=.999,
    context_length=context,
    offset=train_set.offset,
    m=train_set.m,
    dur_map=train_set.dur_map,
    weight_scale=0.,
)
model.initialize()

In [None]:
# First optimisation run on `model`: learning rate 0.1.
optimize(
    model,
    train_set,
    test_set,
    learning_rate=.1,
    batch_size=300,
    workers=4,
    update_rate=1000,
)

In [None]:
# Second optimisation run on the same `model` object: learning rate 0.01.
optimize(
    model,
    train_set,
    test_set,
    learning_rate=.01,
    batch_size=300,
    workers=4,
    update_rate=1000,
)

In [None]:
# Reload the saved checkpoint, then plot the recorded training statistics.
# Blue curves are `_tr` statistics, green curves are the held-out ones;
# `share=True` is passed for the second curve of each pair.
model.restore_checkpoint()

# Every other held-out statistic in this cell is stored under a `_ts` key
# (`apn_ts`, `llt_ts`, `lln_ts`), matching the `_tr` train keys, but the
# overall log loss was read from 'll_test'. Keep 'll_test' when the model
# really provides it; otherwise fall back to 'll_ts' instead of raising
# KeyError. NOTE(review): confirm the registered key name in PartsModel.
ll_heldout_key = 'll_test' if 'll_test' in model.stats else 'll_ts'

plt = media.PlotFormatter()
plt.plot('log loss',model.stats['ll_tr'][2],color='b')
plt.plot('log loss',model.stats[ll_heldout_key][2],color='g',share=True)
plt.plot('avp notes',model.stats['apn_tr'][2],color='b')
plt.plot('avp notes',model.stats['apn_ts'][2],color='g',share=True)
plt.plot('rhythm weights',model.sum_weights('rhythm'),color='g')
plt.plot('notes weights',model.sum_weights('notes'),color='g')
plt.plot('rhythm log loss',model.stats['llt_tr'][2],color='b')
plt.plot('rhythm log loss',model.stats['llt_ts'][2],color='g',share=True)
plt.plot('notes log loss',model.stats['lln_tr'][2],color='b')
plt.plot('notes log loss',model.stats['lln_ts'][2],color='g',share=True)
plt.show()