In [None]:
import os
import numpy as np
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.autograd import Variable
import torch.nn.functional as F

from lib.voices_dataset import VoicesDataset,DatasetSplit
from lib.voices_model import VoicesModel
from lib.opt import optimize
from lib.config import corpora

import lib.media as media
from IPython.display import Audio,clear_output
from scipy.io import wavfile

In [None]:
# GPU selection: order devices by PCI bus id (see issue #152) and expose only device '0'.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
})
# Debugging toggle, left disabled:
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# checkpoint_dir is handed to Model(...); context is handed to the datasets
# (as `context`) and to the model (as `context_length`) in the cells below.
checkpoint_dir, context = '_multipart2', 5

In [None]:
# Build the train split first, then the test split, with the same context setting.
train_set, test_set = (
    VoicesDataset(context=context, split=which_split)
    for which_split in (DatasetSplit.train, DatasetSplit.test)
)

In [None]:
class Model(VoicesModel):
    """VoicesModel subclass with two prediction heads sharing the same inputs.

    `predict_rhythm` returns one row of `self.maxdur` scores per example and
    `predict_notes` returns one row of `self.m` scores per example; `forward`
    returns both as a tuple. Neither head applies a softmax/sigmoid here.

    Notes:
      * `define_graph` reads sizes from the notebook-level `train_set` global
        (not from `self`), so `train_set` must exist before the graph is defined.
      * Constant tensors and initial states are created with `.cuda()`, so a
        CUDA device is required.
      * `self.m` and `self.context` are not assigned in this class; presumably
        they are set by `VoicesModel` from the constructor arguments
        (`m=`, `context_length=`) — TODO confirm against lib/voices_model.py.
    """

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the base-class constructor.
        super().__init__(*args, **kwargs)
    
    def define_graph(self, debug=False):
        """Create the constant tensors and learnable parameters used by both heads.

        Args:
            debug: unused in this method.

        Parameters are allocated with `Tensor(shape)`, i.e. uninitialized memory;
        presumably `VoicesModel.initialize()` fills them in — TODO confirm.
        """
        # Sizes taken from the module-level `train_set` global.
        self.parts = train_set.max_parts
        self.maxdur = train_set.maxdur
        self.dur_features = len(train_set.dur_map)
        # Widths: knote = channels of the note embedding conv, kt = hidden width of
        # the rhythm head, kn = hidden width of the notes head.
        self.knote = 16
        self.kt = 64
        self.kn = 128
                
        # reference meter for pitch class
        # One-hot signal of length 2*m-1; predict_notes convolves it with
        # notes_wref (kernel length m) to obtain one kn-vector per output position.
        # NOTE(review): the impulse is at index m, while the centre of a length
        # 2*m-1 array is m-1. With a length-m kernel this leaves output position 0
        # at zero and kernel tap 0 unused — possible off-by-one, confirm intent
        # before changing (would alter the meaning of trained checkpoints).
        self.noteref = np.zeros([1,1,2*self.m-1],dtype=np.float32)
        self.noteref[0,0,self.m] = 1
        self.noteref = Variable(torch.from_numpy(self.noteref).cuda(), requires_grad=False)
        
        # fold[n, n % 12] = 1: right-multiplying an (N, m) matrix by `fold` sums
        # its columns into 12 residue classes (used by predict_rhythm).
        fold = np.zeros([self.m,12],dtype=np.float32)
        for n in range(self.m):
            fold[n,n%12] = 1
        self.fold = Variable(torch.from_numpy(fold).cuda(),requires_grad=False)
        
        # rhythm prediction
        # Per-part inputs: duration features (wt), folded notes (wn), recurrent (wh).
        self.rhythmpart_wt = Parameter(Tensor(self.dur_features,self.kt))
        self.rhythmpart_wn = Parameter(Tensor(12,self.kt))
        self.rhythmpart_wh = Parameter(Tensor(self.kt,self.kt))
        
        # Global state: input from the summed part states (wpart), recurrent (wh).
        self.rhythm_wpart = Parameter(Tensor(self.kt,self.kt))
        self.rhythm_wh = Parameter(Tensor(self.kt,self.kt))
        
        # rhythm_wloc has 48 rows, so `loc` must have 48 columns (see torch.mm in
        # predict_rhythm). wthis/wall combine part 0's state with the global state.
        self.rhythm_wloc = Parameter(Tensor(48,self.kt))
        self.rhythm_wthis = Parameter(Tensor(self.kt,self.kt))
        self.rhythm_wall = Parameter(Tensor(self.kt,self.kt))
        
        # Output projection to maxdur scores.
        self.rhythm_wtop = Parameter(Tensor(self.kt,self.maxdur))
        self.rhythm_bias = Parameter(Tensor(self.maxdur))
        
        # notes prediction
        # conv1d kernel: knote output channels, 1 input channel, width 2*m.
        self.notespart_wnote = Parameter(Tensor(self.knote,1,2*self.m))
        
        self.notespart_wt = Parameter(Tensor(self.dur_features,self.kn))   
        self.notespart_wn = Parameter(Tensor(self.knote,self.kn))
        # conv1d kernel of width m-1 applied to the left-padded y[:,0] in predict_notes.
        self.notespart_wyn = Parameter(Tensor(self.kn,1,self.m-1))
        self.notespart_wh = Parameter(Tensor(self.kn,self.kn))
        
        self.notes_wpart = Parameter(Tensor(self.kn,self.kn))
        self.notes_wh = Parameter(Tensor(self.kn,self.kn))
        
        # conv1d kernel of width m applied to self.noteref.
        self.notes_wref = Parameter(Tensor(self.kn,1,self.m))
        self.notes_wthis = Parameter(Tensor(self.kn,self.kn))
        self.notes_wall = Parameter(Tensor(self.kn,self.kn))
        
        # Output projection: one scalar per position, plus a per-position bias.
        self.notes_wtop = Parameter(Tensor(self.kn,1))
        self.notes_bias = Parameter(Tensor(self.m))
    
    def predict_rhythm(self, e, t, f, y, yt, yf, loc, corpus):
        """Compute the rhythm head's scores.

        Shapes below are implied by the views/matmuls in this method, not
        declared anywhere visible — TODO confirm against the dataset:
            e:  (batch, context, parts, m)            y:  (batch, parts, m)
            t:  (batch, context, parts, dur_features) yt: (batch, parts, dur_features)
            f:  (batch, context, parts, parts)        yf: (batch, parts, parts)
            loc: (batch, 48)
            corpus: unused in this method.

        Returns:
            Tensor with `self.maxdur` columns (one row per example): a linear
            layer's output plus bias, with no normalization applied.
        """
        # Append the current-step inputs (y*) as a final, (context+1)-th step.
        t = torch.cat([t,yt[:,None]],dim=1)
        ft = torch.mm(t.view(-1,self.dur_features),self.rhythmpart_wt).view(-1,self.context+1,self.parts,self.kt)
        # Zero the duration features of part 0 at the final step (in-place write
        # into the view); presumably this is the quantity being predicted.
        ft[:,-1,0] = 0*ft[:,-1,0]
        
        e = torch.cat([e,y[:,None]],dim=1)
        # Collapse the m positions into 12 residue classes, then project to kt.
        ef = torch.mm(e.view(-1,self.m),self.fold)
        fn = torch.mm(ef,self.rhythmpart_wn).view(-1,self.context+1,self.parts,self.kt)
        # Likewise zero the note features of part 0 at the final step.
        fn[:,-1,0] = 0*fn[:,-1,0]
        
        f = torch.cat([f,yf[:,None]],dim=1)
        
        # Location features, added to every part at every step below.
        floc = torch.mm(loc,self.rhythm_wloc)
        
        # hpart: per-part state (batch, parts, kt). h: global state, starts as
        # (1, kt) and broadcasts to the batch on the first addition.
        hpart = Variable(torch.zeros(e.shape[0],self.parts,self.kt).cuda())
        h = Variable(torch.zeros(1,self.kt).cuda())
        for k in range(self.context+1):
            # Mix the part states through the transposed (parts x parts) matrix f[:,k].
            hpart = torch.bmm(f[:,k].transpose(1,2),hpart)
            
            fparth = torch.mm(hpart.view(-1,self.kt),self.rhythmpart_wh).view(-1,self.parts,self.kt)
            hpart = F.relu(fparth + ft[:,k] + fn[:,k] + floc[:,None])
            
            # Global state update from its previous value and the sum over parts.
            fh = torch.mm(h.view(-1,self.kt),self.rhythm_wh).view(-1,self.kt)
            fpart = torch.mm(hpart.sum(dim=1),self.rhythm_wpart)
            h = F.relu(fh + fpart)
        
        # Combine the final global state with part 0's final state.
        fall = torch.mm(h,self.rhythm_wall)
        fthis = torch.mm(hpart[:,0],self.rhythm_wthis)
        zx = F.relu(fall + fthis)

        return torch.mm(zx,self.rhythm_wtop) + self.rhythm_bias
    
    def predict_notes(self, e, t, f, y, yt, yf, loc, corpus):
        """Compute the notes head's scores.

        Takes the same arguments as `predict_rhythm` (same assumed shapes —
        TODO confirm). `loc` and `corpus` are unused in this method.

        Returns:
            Tensor reshaped to (-1, self.m): one score per position, plus
            `notes_bias`, with no normalization applied.
        """
        e = torch.cat([e,y[:,None]],dim=1)
        # Pad each length-m row to 3*m-1 and convolve with a width-2*m kernel,
        # giving knote channels at each of the m positions.
        fembed = F.conv1d(F.pad(e.view(-1,self.m),(self.m,self.m-1))[:,None],self.notespart_wnote)
        # Channels last, then flatten to rows of knote features.
        fembed = fembed.transpose(1,2).contiguous().view(-1,self.knote)
        
        t = torch.cat([t,yt[:,None]],dim=1)
        ft = torch.mm(t.view(-1,self.dur_features),self.notespart_wt)
        # Trailing None adds a singleton axis that broadcasts over the m positions.
        # Unlike predict_rhythm, the final step of part 0 is not zeroed here.
        ft = ft.view(-1,self.context+1,self.parts,self.kn)[:,:,:,None]
        
        fn = torch.mm(fembed,self.notespart_wn)
        fn = fn.view(-1,self.context+1,self.parts,self.m,self.kn)
        # Overwrite part 0's final-step features (in place) with a one-sided
        # convolution of y[:,0]: left-padding by m-1 with a width-(m-1) kernel and
        # dropping the last output makes position j depend only on y[:,0] at
        # indices below j.
        fn[:,-1,0] = F.conv1d(F.pad(y[:,0],(self.m-1,0))[:,None],self.notespart_wyn)[:,:,:-1].transpose(1,2)
        
        f = torch.cat([f,yf[:,None]],dim=1)
        
        # Per-position reference features of shape (1, m, kn), added at every step.
        fref = F.conv1d(self.noteref,self.notes_wref).transpose(1,2)
        
        # Same recurrence as predict_rhythm, with an extra axis of m positions.
        hpart = Variable(torch.zeros(e.shape[0],self.parts,self.m,self.kn).cuda())
        h = Variable(torch.zeros(1,self.m,self.kn).cuda())
        for k in range(self.context+1):
            # Flatten (m, kn) so bmm can mix the part states, then restore the shape.
            hpart = torch.bmm(f[:,k].transpose(1,2),hpart.view(hpart.shape[0],self.parts,-1)).view(-1,self.parts,self.m,self.kn)
            
            fparth = torch.mm(hpart.view(-1,self.kn),self.notespart_wh).view(-1,self.parts,self.m,self.kn)
            hpart = F.relu(fparth + ft[:,k] + fn[:,k] + fref)
            
            fh = torch.mm(h.view(-1,self.kn),self.notes_wh).view(-1,self.m,self.kn)
            fpart = torch.mm(hpart.sum(dim=1).view(-1,self.kn),self.notes_wpart).view(-1,self.m,self.kn)
            h = F.relu(fh + fpart)
        
        # Combine the final global state with part 0's final state, per position.
        fall = torch.mm(h.view(-1,self.kn),self.notes_wall).view(-1,self.m,self.kn)
        fthis = torch.mm(hpart[:,0].contiguous().view(-1,self.kn),self.notes_wthis).view(-1,self.m,self.kn)
        zx = F.relu(fall + fthis)
        
        return torch.mm(zx.view(-1,self.kn),self.notes_wtop).view(-1,self.m) + self.notes_bias
    
    def forward(self, x):
        """Unpack the 8-element input `x` and return `(rhythm_scores, notes_scores)`."""
        e,t,f,y,yt,yf,loc,corpus = x
        return self.predict_rhythm(e,t,f,y,yt,yf,loc,corpus), self.predict_notes(e,t,f,y,yt,yf,loc,corpus)

In [None]:
# Construct the model from the training set's dimensions, then initialize it.
model = Model(
    checkpoint_dir,
    avg=.999,
    context_length=context,
    offset=train_set.offset,
    m=train_set.m,
    dataset=train_set,
    weight_scale=.003,
)
model.initialize()

In [None]:
# Training stage with learning_rate=.003 (the largest of the four stages in this notebook).
optimize(
    model, train_set, test_set,
    learning_rate=.003,
    batch_size=100,
    workers=4,
    update_rate=5000,
    l2=.0001,
)

In [None]:
# Training stage with learning_rate=.001; all other settings match the other stages.
optimize(
    model, train_set, test_set,
    learning_rate=.001,
    batch_size=100,
    workers=4,
    update_rate=5000,
    l2=.0001,
)

In [None]:
# Training stage with learning_rate=.0003; all other settings match the other stages.
optimize(
    model, train_set, test_set,
    learning_rate=.0003,
    batch_size=100,
    workers=4,
    update_rate=5000,
    l2=.0001,
)

In [None]:
# Training stage with learning_rate=.0001 (the smallest of the four stages in this notebook).
optimize(
    model, train_set, test_set,
    learning_rate=.0001,
    batch_size=100,
    workers=4,
    update_rate=5000,
    l2=.0001,
)

In [None]:
# Reload the saved checkpoint, then chart the statistics recorded on the model.
model.restore_checkpoint()
plt = media.PlotFormatter(burnin=20000)

# Each entry is (plot title, first stats key, second stats key). The first series
# is drawn in blue; the second is drawn in green on the same axes (share=True).
# NOTE(review): 'll_test' uses a `_test` suffix while every other second key
# uses `_ts` — confirm this is the key the model actually records.
for title, first_key, second_key in (
    ('log loss', 'll_tr', 'll_test'),
    ('avp notes', 'apn_tr', 'apn_ts'),
):
    plt.plot(title, model.stats[first_key][2], color='b')
    plt.plot(title, model.stats[second_key][2], color='g', share=True)

# Summed weights for each head, one plot apiece.
for title, group in (('rhythm weights', 'rhythm'), ('notes weights', 'notes')):
    plt.plot(title, model.sum_weights(group), color='g')

for title, first_key, second_key in (
    ('rhythm log loss', 'llt_tr', 'llt_ts'),
    ('notes log loss', 'lln_tr', 'lln_ts'),
):
    plt.plot(title, model.stats[first_key][2], color='b')
    plt.plot(title, model.stats[second_key][2], color='g', share=True)

plt.show()

In [None]:
from lib.config import piano_corpora
from lib.opt import terminal_error

# Evaluate on the test split restricted to corpora that are not listed in piano_corpora.
non_piano_corpora = tuple(
    name for name in corpora.keys() if name not in piano_corpora
)
non_piano_test_set = VoicesDataset(
    context=context,
    split=DatasetSplit.test,
    corpora=non_piano_corpora,
)
# The error is computed and printed inside model.iterate_averaging().
with model.iterate_averaging():
    non_piano_error = terminal_error(model, non_piano_test_set, batch_size=100)
    print(non_piano_error)