In [None]:
import os
import numpy as np
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.autograd import Variable
import torch.nn.functional as F

from lib.parts_dataset import PartsDataset,DatasetSplit
from lib.parts_model import PartsModel
from lib.opt import optimize

import lib.media as media
from IPython.display import Audio
from scipy.io import wavfile

In [None]:
# Restrict this process to GPU 0, with devices enumerated in PCI bus order.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',   # see issue #152
    'CUDA_VISIBLE_DEVICES': '0',
})
# Debugging aid (synchronous CUDA launches), left disabled:
#os.environ['CUDA_LAUNCH_BLOCKING']='1'

# Passed to Model as its first positional argument in the model cell below.
checkpoint_dir = '_singlepart19'
# Context length shared by both dataset splits and the model (context_length).
context = 10

In [None]:
# Both splits are built with the same context length; only `split` differs.
dataset_args = dict(context=context)
train_set = PartsDataset(split=DatasetSplit.train, **dataset_args)
test_set = PartsDataset(split=DatasetSplit.test, **dataset_args)

In [None]:
class Model(PartsModel):
    """Convolutional model with separate rhythm and notes prediction heads.

    Built on PartsModel (lib.parts_model, not shown in this notebook).
    ``self.m`` and ``self.context`` are read here but never assigned in this
    class, so they are presumably set by the base class from the constructor
    arguments -- TODO confirm.
    """

    def __init__(self, *args, **kwargs):
        """Pass every constructor argument through to PartsModel unchanged."""
        super().__init__(*args, **kwargs)

    def define_graph(self, debug=False):
        """Create layer sizes, constant tensors and learnable parameters.

        ``debug`` is accepted but not used by this method.

        NOTE: ``maxdur`` is read from the notebook-level ``train_set``, so
        that variable has to exist before this method runs.
        """
        # Layer sizes.
        self.knote = 16
        self.kt = 300
        self.kn = 300
        self.d = 5                                  # first-layer kernel width
        self.regions = self.context - self.d + 1    # positions left after a width-d conv
        self.d2 = 3                                 # second-layer kernel width
        self.regions2 = self.regions - self.d2 + 1  # positions left after a width-d2 conv
        self.maxdur = train_set.maxdur

        # Short local aliases for the shape arithmetic below.
        m = self.m
        kt, kn, knote = self.kt, self.kn, self.knote
        d, d2, dur = self.d, self.d2, self.maxdur

        # reference meter for pitch class: a (1, 1, 2m-1) one-hot row whose
        # single 1 sits at index m.
        # NOTE(review): convolved with the width-m kernel notes_wref, an impulse
        # at index m leaves output position 0 at zero and never touches
        # notes_wref[..., 0]; index m-1 would be the centred choice. Confirm
        # this offset is intended before changing it (checkpoints depend on it).
        onehot = np.zeros([1, 1, 2 * m - 1], dtype=np.float32)
        onehot[0, 0, m] = 1
        self.noteref = Variable(torch.from_numpy(onehot).cuda(), requires_grad=False)

        # Binary (m, 12) matrix with a 1 at (n, n % 12) for every note index n;
        # multiplying by it sums note columns that share the same n % 12.
        note_idx = np.arange(m)
        fold = np.zeros([m, 12], dtype=np.float32)
        fold[note_idx, note_idx % 12] = 1
        self.fold = Variable(torch.from_numpy(fold).cuda(), requires_grad=False)

        # Parameters are only allocated here (Tensor(...) does not set values);
        # presumably they are filled in by the base class -- TODO confirm.
        # Declaration order and names are kept stable on purpose.

        # rhythm prediction
        self.rhythm_wt = Parameter(Tensor(kt, dur, d))
        self.rhythm_wn = Parameter(Tensor(kt, 12, d))
        self.rhythm_wloc = Parameter(Tensor(48, kt))

        self.rhythm_w2 = Parameter(Tensor(kt, kt, d2))

        self.rhythm_wtop = Parameter(Tensor(kt * self.regions2, dur))
        self.rhythm_bias = Parameter(Tensor(dur))

        # notes prediction
        self.notes_wnote = Parameter(Tensor(knote, 1, 1, 2 * m))

        # NOTE(review): notes_wn is not referenced by any method in this class.
        self.notes_wn = Parameter(Tensor(knote, 1, 1, m))
        self.notes_wyn = Parameter(Tensor(knote, 1, m - 1))
        self.notes_wt = Parameter(Tensor(kn, dur, d))
        self.notes_wc = Parameter(Tensor(kn, knote, d, 1))
        self.notes_wref = Parameter(Tensor(kn, 1, m))

        self.notes_w2 = Parameter(Tensor(kn, kn, d2, 1))

        self.notes_wtop = Parameter(Tensor(kn * (self.regions2 + 1), 1))
        self.notes_bias = Parameter(Tensor(m))

    def predict_rhythm(self, e, t, loc):
        """Compute one score per duration class (``maxdur`` columns).

        Assumes ``e`` is (batch, context, m), ``t`` is (batch, context, maxdur)
        and ``loc`` is (batch, 48) -- TODO confirm against the dataset.
        """
        # Convolve over the time axis of the duration history.
        dur_feats = F.conv1d(t.transpose(1, 2), self.rhythm_wt)

        # Collapse the m note columns to 12 via the fold matrix, then convolve.
        n_batch, n_steps = e.shape[0], e.shape[1]
        folded = torch.mm(e.view(-1, self.m), self.fold).view(n_batch, n_steps, -1)
        note_feats = F.conv1d(folded.transpose(1, 2), self.rhythm_wn)

        # Location term, broadcast across all convolution positions.
        loc_feats = torch.mm(loc, self.rhythm_wloc)[:, :, None]

        hidden1 = F.relu(dur_feats + note_feats + loc_feats)
        hidden2 = F.relu(F.conv1d(hidden1, self.rhythm_w2))
        flat = hidden2.view(-1, self.kt * self.regions2)
        return torch.mm(flat, self.rhythm_wtop) + self.rhythm_bias

    def predict_notes(self, e, t, y, yt):
        """Compute one score per note index (``m`` columns).

        Assumes ``e`` is (batch, context, m), ``t`` is (batch, context, maxdur),
        ``y`` is (batch, m) and ``yt`` is (batch, maxdur) -- TODO confirm.
        """
        # History features: pad the note axis (m left, m-1 right) so the
        # width-2m kernel yields exactly m output positions per time step.
        padded_hist = F.pad(e, (self.m, self.m - 1, 0, 0, 0, 0))
        fembed = F.conv2d(padded_hist[:, None], self.notes_wnote)

        # Append yt as one extra step after the duration history.
        durations = torch.cat([t, yt[:, None, :]], dim=1)
        frhythm = F.conv1d(durations.transpose(1, 2), self.notes_wt)[:, :, :, None]

        # Current-step features: left-pad by m-1 and drop the last output so
        # position j only sees entries of y with index below j.
        padded_cur = F.pad(y, (self.m - 1, 0, 0, 0))
        fcurc = F.conv1d(padded_cur[:, None], self.notes_wyn)[:, :, None, :-1]

        # Stack history and current step along the time axis, then convolve.
        fchords = F.relu(torch.cat([fembed, fcurc], dim=2))
        fharmony = F.conv2d(fchords, self.notes_wc)
        fref = F.conv1d(self.noteref, self.notes_wref)[:, :, None, :]

        hidden1 = F.relu(frhythm + fharmony + fref)
        hidden2 = F.relu(F.conv2d(hidden1, self.notes_w2))
        # Move the note axis forward so each row of the flattened matrix holds
        # the features of a single (example, note) pair.
        per_note = hidden2.transpose(1, 3).contiguous()
        flat = per_note.view(-1, self.kn * (self.regions2 + 1))
        return torch.mm(flat, self.notes_wtop).view(-1, self.m) + self.notes_bias

    def forward(self, x):
        """Unpack ``x`` as (e, t, y, yt, loc) and return (rhythm, notes) outputs."""
        e, t, y, yt, loc = x
        rhythm_out = self.predict_rhythm(e, t, loc)
        notes_out = self.predict_notes(e, t, y, yt)
        return rhythm_out, notes_out

In [None]:
# Keyword settings for the model; dataset-derived values come from train_set.
model_config = dict(
    avg=.999,
    context_length=context,
    offset=train_set.offset,
    m=train_set.m,
    dur_map=train_set.dur_map,
    weight_scale=.01,
)
model = Model(checkpoint_dir, **model_config)
model.initialize()

In [None]:
# First optimisation run: learning_rate .01, update_rate 1000.
optimize(
    model, train_set, test_set,
    learning_rate=.01,
    batch_size=300,
    workers=4,
    update_rate=1000,
)

In [None]:
# Second optimisation run on the same model object: learning_rate .003,
# update_rate 5000.
optimize(
    model, train_set, test_set,
    learning_rate=.003,
    batch_size=300,
    workers=4,
    update_rate=5000,
)

In [None]:
model.restore_checkpoint()
plt = media.PlotFormatter(burnin=5000)

# Each entry is (panel title, train-stat key, test-stat key). The train curve
# is drawn in blue; the test curve is drawn in green on the same panel.
# NOTE(review): 'll_test' is the only test key not using the '_ts' suffix --
# confirm it matches the statistic name registered by the model.
for panel, train_key, test_key in (
    ('log loss', 'll_tr', 'll_test'),
    ('avp notes', 'apn_tr', 'apn_ts'),
):
    plt.plot(panel, model.stats[train_key][2], color='b')
    plt.plot(panel, model.stats[test_key][2], color='g', share=True)

# One panel per weight group, as reported by model.sum_weights.
for panel, group in (
    ('rhythm weights', 'rhythm'),
    ('notes weights', 'notes'),
):
    plt.plot(panel, model.sum_weights(group), color='g')

for panel, train_key, test_key in (
    ('rhythm log loss', 'llt_tr', 'llt_ts'),
    ('notes log loss', 'lln_tr', 'lln_ts'),
):
    plt.plot(panel, model.stats[train_key][2], color='b')
    plt.plot(panel, model.stats[test_key][2], color='g', share=True)

plt.show()

In [None]:
# Draw one sample while the model's iterate_averaging() context is active.
averaging = model.iterate_averaging()
with averaging:
    x = model.sample()

In [None]:
media.visualize(x)

# Render the sample to audio, save it in the model's checkpoint directory
# (model.cp) and embed a player as the cell output.
sample_rate = 44100
note_args = media.sample_to_notes(x)
wav = media.render_notes(*note_args, tempo=1.0)
out_path = os.path.join(model.cp, 'out.wav')
wavfile.write(out_path, sample_rate, wav)
Audio(wav, rate=sample_rate)