In [None]:
import os
import numpy as np
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.autograd import Variable
import torch.nn.functional as F

from lib.voices_dataset import VoicesDataset,DatasetSplit
from lib.voices_model import VoicesModel
from lib.opt import optimize
from lib.config import corpora,piano_corpora

import lib.media as media
from IPython.display import Audio,clear_output
from scipy.io import wavfile

In [None]:
# GPU selection: enumerate devices in PCI bus order (see issue #152) and
# expose only device 0. These variables only take effect if they are set
# before CUDA is first initialized in this kernel.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES='0')
# Uncomment while debugging to make CUDA kernel launches synchronous:
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

checkpoint_dir = '_multipart8'  # passed to Model; presumably where checkpoints are saved/restored
context = 10                    # history length shared by the datasets and the model's recurrences

In [None]:
# Restrict training and evaluation to the corpora that are not piano corpora,
# then build the train and test splits over that same corpus selection.
non_piano_corpora = tuple(name for name in corpora.keys() if name not in piano_corpora)
train_set, test_set = (
    VoicesDataset(context=context, split=which_split, corpora=non_piano_corpora)
    for which_split in (DatasetSplit.train, DatasetSplit.test)
)

In [None]:
class Model(VoicesModel):
    """Two-headed recurrent model over one part (index 0) of each example.

    * ``predict_rhythm``: a ReLU RNN over the ``context`` history steps that
      produces ``maxdur`` duration scores.
    * ``predict_notes``: a ReLU RNN with a separate hidden state per note
      position (weights shared across positions), run over ``context + 1``
      steps -- the history plus the step being predicted -- producing one
      score per note position.

    Parameter attribute names are prefixed ``rhythm_`` / ``notes_``; a later
    cell calls ``model.sum_weights('rhythm')`` / ``sum_weights('notes')``,
    which presumably selects parameters by that prefix, so keep the prefixes.

    NOTE(review): ``self.m`` and ``self.context`` are not set in this class;
    presumably ``VoicesModel`` sets them from the ``m=`` / ``context_length=``
    constructor kwargs -- confirm.
    """
    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``VoicesModel.__init__``."""
        super().__init__(*args, **kwargs)

    def define_graph(self, debug=False):
        """Create hyperparameters, constant tensors and learnable parameters.

        Dataset geometry is read from the notebook-global ``train_set`` rather
        than from an argument, so that cell must have run first. ``debug`` is
        accepted but not used here.

        NOTE(review): parameters are allocated with uninitialized
        ``Tensor(...)`` storage; they must be filled before use (presumably by
        ``VoicesModel.initialize()``, called in a later cell).
        """
        self.knote = 16                 # output channels of the note-embedding conv (notes_wnote)
        self.kt = self.kn = 300         # hidden sizes of the rhythm (kt) and notes (kn) recurrences
        self.parts = train_set.max_parts  # not used by this class's methods
        self.maxdur = train_set.maxdur    # number of duration classes scored by predict_rhythm
        self.dur_features = len(train_set.dur_map)  # width of the per-step duration features t / yt

        # reference meter for pitch class: a one-hot vector of length 2m-1 that
        # predict_notes convolves with notes_wref (width m) to get a learned
        # per-note-position embedding of shape (1, m, kn).
        # NOTE(review): index m is one past the centre (m-1). With a width-m
        # kernel, output position 0 therefore gets an all-zero embedding and
        # notes_wref[:, :, 0] is never used. Confirm this is intended before
        # changing it -- trained checkpoints depend on it.
        # NOTE(review): noteref and fold are plain tensors pinned to CUDA via
        # .cuda(); they are not registered as parameters/buffers, so
        # .to()/.cpu() on the module will not move them.
        self.noteref = np.zeros([1,1,2*self.m-1],dtype=np.float32)
        self.noteref[0,0,self.m] = 1
        self.noteref = Variable(torch.from_numpy(self.noteref).cuda(), requires_grad=False)

        # (m, 12) 0/1 matrix with fold[n, n % 12] = 1: multiplying by it sums
        # note activations into 12 pitch classes.
        fold = np.zeros([self.m,12],dtype=np.float32)
        for n in range(self.m):
            fold[n,n%12] = 1
        self.fold = Variable(torch.from_numpy(fold).cuda(),requires_grad=False)

        # rhythm prediction
        self.rhythm_wt = Parameter(Tensor(self.dur_features,self.kt))  # duration features -> hidden
        self.rhythm_wn = Parameter(Tensor(12,self.kt))                 # pitch-class histogram -> hidden
        self.rhythm_wh = Parameter(Tensor(self.kt,self.kt))            # recurrent weights
        # 48 must match the width of `loc` supplied by the dataset (presumably fixed there)
        self.rhythm_wloc = Parameter(Tensor(48,self.kt))

        self.rhythm_wtop = Parameter(Tensor(self.kt,self.maxdur))      # final hidden -> duration scores
        self.rhythm_bias = Parameter(Tensor(self.maxdur))

        # notes prediction
        # conv2d kernel (out=knote, in=1, height 1, width 2m): relative-position filter along the note axis
        self.notes_wnote = Parameter(Tensor(self.knote,1,1,2*self.m))

        self.notes_wn = Parameter(Tensor(self.knote,self.kn))          # note embedding -> hidden
        self.notes_wyn = Parameter(Tensor(self.kn,1,self.m-1))         # causal conv kernel over target notes y
        self.notes_wt = Parameter(Tensor(self.dur_features,self.kn))   # duration features -> hidden
        self.notes_wref = Parameter(Tensor(self.kn,1,self.m))          # kernel applied to noteref
        self.notes_wh = Parameter(Tensor(self.kn,self.kn))             # recurrent weights

        self.notes_wtop = Parameter(Tensor(self.kn,1))                 # hidden -> one score per position
        self.notes_bias = Parameter(Tensor(self.m))                    # per-position bias

    def predict_rhythm(self, e, t, loc):
        """Score the duration of the step being predicted from the history.

        Shapes as implied by the view/mm calls below (TODO confirm against the
        dataset): e (batch, context, m), t (batch, context, dur_features),
        loc (batch, 48).

        Returns:
            (batch, maxdur) scores (presumably logits over duration classes).
        """
        # per-step duration features projected -> (batch, context, kt)
        ht = torch.mm(t.view(-1,self.dur_features),self.rhythm_wt).view(-1,self.context,self.kt)
        # fold each step's note vector into 12 pitch classes -> (batch, context, 12)
        ef = torch.mm(e.view(-1,self.m),self.fold).view(-1,self.context,12)
        # project pitch classes -> (batch, context, kt)
        hn = torch.mm(ef.view(-1,12),self.rhythm_wn).view(-1,self.context,self.kt)
        # location features -> (batch, kt); the same input is added at every step
        floc = torch.mm(loc,self.rhythm_wloc)

        # zero initial state of shape (1, kt), broadcast across the batch on the first step
        h = Variable(torch.zeros(1,self.kt).cuda())
        for k in range(self.context):
            hh = torch.mm(h,self.rhythm_wh)
            h = F.relu(hh + ht[:,k,:] + hn[:,k,:] + floc)

        return torch.mm(h,self.rhythm_wtop) + self.rhythm_bias

    def predict_notes(self, e, t, y, yt, loc):
        """Score each of the ``m`` note positions for the step being predicted.

        Shapes as implied by the ops below (TODO confirm against the dataset):
        e (batch, context, m), t (batch, context, dur_features),
        y (batch, m), yt (batch, dur_features).

        The final recurrence step sees the target step's duration features
        ``yt`` and, at note position j, only the target notes ``y`` at
        positions < j. ``loc`` is accepted but not used.

        Returns:
            (batch, m) scores, one per note position.
        """
        # Pad the note axis (m left, m-1 right) to length 3m-1 and convolve with the
        # width-2m kernel: the output keeps length m, and every output position sees
        # every input position at its relative offset. -> (batch, m, context, knote)
        fembed = F.conv2d(F.pad(e,(self.m,self.m-1,0,0,0,0))[:,None],self.notes_wnote).transpose(1,3).contiguous()

        # append the target step's duration features -> (batch, context+1, dur_features)
        t = torch.cat([t,yt[:,None,:]],dim=1)
        # project; the singleton axis broadcasts across note positions -> (batch, 1, context+1, kn)
        ht = torch.mm(t.view(-1,self.dur_features),self.notes_wt).view(-1,self.context+1,self.kn)[:,None,:,:]

        # history note embedding -> (batch, m, context, kn)
        hn = torch.mm(fembed.view(-1,self.knote),self.notes_wn).view(-1,self.m,self.context,self.kn)
        # Causal conv over the target notes: left-pad by m-1, width m-1 kernel, drop the
        # last output, so position j depends only on y[:j]. -> (batch, m, 1, kn)
        hnc = F.conv1d(F.pad(y,(self.m-1,0,0,0))[:,None],self.notes_wyn)[:,:,None,:-1].transpose(1,3).contiguous()
        # target step appended as the last time step -> (batch, m, context+1, kn)
        hn = torch.cat([hn,hnc],dim=2)
        # per-position reference embedding -> (1, m, kn)
        href = F.conv1d(self.noteref,self.notes_wref).transpose(1,2)

        # independent hidden state per note position (shared weights), broadcast over the batch
        h = Variable(torch.zeros(1,self.m,self.kn).cuda())
        for k in range(self.context+1):
            hh = torch.mm(h.view(-1,self.kn),self.notes_wh).view(-1,self.m,self.kn)
            h = F.relu(hh + ht[:,:,k,:] + hn[:,:,k,:] + href)

        return torch.mm(h.view(-1,self.kn),self.notes_wtop).view(-1,self.m) + self.notes_bias

    def forward(self, x):
        """Unpack a batch 8-tuple and run both prediction heads.

        Only index 0 of the part axis is used (presumably the voice being
        modeled); ``f``, ``yf`` and ``corpus`` are unpacked but unused.

        Returns:
            (rhythm scores, note scores) from predict_rhythm / predict_notes.
        """
        e,t,f,y,yt,yf,loc,corpus = x

        # select part 0: e, t lose their third axis; y, yt lose their second axis
        e = e[:,:,0,:].contiguous(); t = t[:,:,0,:].contiguous()
        y = y[:,0,:].contiguous(); yt = yt[:,0,:].contiguous()
        return self.predict_rhythm(e,t,loc), self.predict_notes(e,t,y,yt,loc)

In [None]:
# Build the model around the training set's geometry and initialize its
# parameters. `avg` presumably sets the decay used by `model.iterate_averaging()`
# and `weight_scale` the initialization scale -- both are defined in VoicesModel.
model = Model(
    checkpoint_dir,
    avg=.999,
    context_length=context,
    offset=train_set.offset,
    m=train_set.m,
    dataset=train_set,
    weight_scale=.01,
)
model.initialize()

In [None]:
# Training stage 1: learning rate 1e-3.
optimize(model, train_set, test_set,
         learning_rate=.001, batch_size=300, workers=4, update_rate=1000)

In [None]:
# Training stage 2: continue from stage 1 with a lower learning rate (3e-4).
optimize(model, train_set, test_set,
         learning_rate=.0003, batch_size=300, workers=4, update_rate=1000)

In [None]:
model.restore_checkpoint()
plt = media.PlotFormatter(burnin=5000)  # burnin presumably skips the first 5000 points

# Each pair draws the training curve (blue) and then the test curve (green) on
# the same panel. `stats[key][2]` is the series VoicesModel records for that
# metric; the meaning of index 2 is defined there.
for title, train_key, test_key in (('log loss', 'll_tr', 'll_test'),
                                   ('avp notes', 'apn_tr', 'apn_ts')):
    plt.plot(title, model.stats[train_key][2], color='b')
    plt.plot(title, model.stats[test_key][2], color='g', share=True)

# Weight summaries per head (sum_weights presumably selects parameters by name prefix).
for title, prefix in (('rhythm weights', 'rhythm'),
                      ('notes weights', 'notes')):
    plt.plot(title, model.sum_weights(prefix), color='g')

# Per-head log loss, train vs. test.
for title, train_key, test_key in (('rhythm log loss', 'llt_tr', 'llt_ts'),
                                   ('notes log loss', 'lln_tr', 'lln_ts')):
    plt.plot(title, model.stats[train_key][2], color='b')
    plt.plot(title, model.stats[test_key][2], color='g', share=True)

plt.show()

In [None]:
# Final evaluation of the averaged weights on the non-piano test split.
# `piano_corpora` (import cell) and `non_piano_corpora` (dataset cell) are
# already defined above, so they are reused instead of being re-imported and
# recomputed here; only the evaluation helper needs importing.
from lib.opt import terminal_error

# Same arguments as `test_set`; built fresh so this evaluation does not depend
# on any state the training loop may have left in that object.
non_piano_test_set = VoicesDataset(context=context, split=DatasetSplit.test, corpora=non_piano_corpora)
with model.iterate_averaging():
    non_piano_terminal_error = terminal_error(model, non_piano_test_set, batch_size=100)

# Shown as the cell's output (rich display) rather than via print().
non_piano_terminal_error