In [5]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [6]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

In [7]:
#This file is heavily based on Daniel Johnson's midi manipulation code in https://github.com/hexahedria/biaxial-rnn-music-composition
import msgpack
import glob
from tqdm import tqdm

In [8]:
###################################################
# In order for this code to work, you need to place this file in the same 
# directory as the midi_manipulation.py file and the Pop_Music_Midi directory

import midi_manipulation

In [12]:
def get_songs(path):
    files = glob.glob('{}/*.mid*'.format(path))
    songs = []
    for f in tqdm(files):
        try:
            song = np.array(midi_manipulation.midiToNoteStateMatrix(f))
            if np.array(song).shape[0] > 50:
                songs.append(song)
        except Exception as e:
            raise e           
    return songs

songs = get_songs('data/Pop_Music_Midi') #These songs have already been converted from midi to msgpack
print("{len(songs)} songs processed")
###################################################

100%|██████████| 126/126 [00:04<00:00, 31.02it/s]

{len(songs)} songs processed





In [50]:
vocab_size = songs[0].shape[-1]

class MusicDataset(torch.utils.data.Dataset):
    """Face Landmarks dataset."""

    def __init__(self, songs, sequence_length):
        self.songs = songs
        self.dataset = np.vstack(songs)
        self.sequence_length = sequence_length
        self.vocab_size = self.dataset.shape[-1]

    def __len__(self):
#         print((self.data_length - self.sequence_length)//10)
        return (self.dataset.shape[0] // self.sequence_length) - 1

    def __getitem__(self, idx):
#         print(np.arange(10)[0:8]) # example
#         print(np.arange(10)[8])
        start = idx * self.sequence_length
        x = self.dataset[start:start+self.sequence_length]
        y = self.dataset[start+self.sequence_length]
#         x_hot = one_hot(x, self.vocab_size)
        return x, y


In [51]:
import matplotlib.pyplot as plt

def show_adn_save(file_name,img):
    npimg = np.transpose(img.numpy(),(1,2,0))
    f = "./%s.png" % file_name
    plt.imshow(npimg)
    plt.imsave(f,npimg)

In [52]:
class RBM(nn.Module):
    def __init__(self,
                 n_vis=784,
                 n_hin=500,
                 k=5):
        super(RBM, self).__init__()
        self.W = nn.Parameter(torch.randn(n_hin,n_vis)*1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k
    
    def sample_from_p(self,p):
        return F.relu(torch.sign(p - Variable(torch.rand(p.size()))))
    
    def v_to_h(self,v):
        p_h = F.sigmoid(F.linear(v,self.W,self.h_bias))
        sample_h = self.sample_from_p(p_h)
        return p_h,sample_h
    
    def h_to_v(self,h):
        p_v = F.sigmoid(F.linear(h,self.W.t(),self.v_bias))
        sample_v = self.sample_from_p(p_v)
        return p_v,sample_v
        
    def forward(self,v):
        pre_h1,h1 = self.v_to_h(v)
        
        h_ = h1
        for _ in range(self.k):
            pre_v_,v_ = self.h_to_v(h_)
            pre_h_,h_ = self.v_to_h(v_)
        
        return v,v_
    
    def free_energy(self,v):
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v,self.W,self.h_bias)
        hidden_term = wx_b.exp().add(1).log().sum(1)
        return (-hidden_term - vbias_term).mean()

In [53]:
batch_size = 64
md = MusicDataset(songs, sequence_length=10)
train_loader = torch.utils.data.DataLoader(md,
    batch_size=batch_size)



In [54]:
test = enumerate(train_loader)

In [55]:
a, (x, y) = next(test)

In [56]:
x.shape

torch.Size([64, 10, 156])

In [57]:
rbm = RBM(n_vis=md.vocab_size*md.sequence_length, k=3)

In [58]:
train_op = optim.SGD(rbm.parameters(),0.1)

In [72]:
for epoch in range(100):
    loss_ = []
    for i, (data,target) in enumerate(train_loader):
#         data = Variable(data.view(-1,784))
        if data.shape[0] != train_loader.batch_size:
            print('Mismatched shapes:', i, data.shape)
            break
        data = Variable(data.float().view(data.shape[0], -1))
        sample_data = data.bernoulli()
        
        v,v1 = rbm(sample_data)
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        loss_.append(loss.data[0])
        train_op.zero_grad()
        loss.backward()
        train_op.step()
        
        if (i % 100 == 0):
            print(f'i: {i}, loss: {np.mean(loss_)}')
    
    print(np.mean(loss_))

i: 0, loss: 3.89801025390625
Mismatched shapes: 31 torch.Size([41, 10, 156])
-4.530200589087702
i: 0, loss: 10.162200927734375
Mismatched shapes: 31 torch.Size([41, 10, 156])
-3.62249263640373
i: 0, loss: 2.460418701171875
Mismatched shapes: 31 torch.Size([41, 10, 156])
-5.487013293850806
i: 0, loss: -0.408843994140625
Mismatched shapes: 31 torch.Size([41, 10, 156])
-5.891303277784778
i: 0, loss: 2.919464111328125
Mismatched shapes: 31 torch.Size([41, 10, 156])
-5.870359359248992
i: 0, loss: 0.5150146484375
Mismatched shapes: 31 torch.Size([41, 10, 156])
-6.507988714402722
i: 0, loss: -5.123382568359375
Mismatched shapes: 31 torch.Size([41, 10, 156])
-6.829882221837198
i: 0, loss: -2.365081787109375
Mismatched shapes: 31 torch.Size([41, 10, 156])
-6.443546418220766
i: 0, loss: -8.63446044921875
Mismatched shapes: 31 torch.Size([41, 10, 156])
-6.702360091670867
i: 0, loss: -4.990264892578125
Mismatched shapes: 31 torch.Size([41, 10, 156])
-6.646216607862903
i: 0, loss: -5.58480834960937

KeyboardInterrupt: 

In [73]:
v1.view(64, md.sequence_length, -1).shape

torch.Size([64, 10, 156])

In [74]:
v_out = v1.view(64, md.sequence_length, -1).data.numpy()

In [75]:

# S = np.reshape(sample[i,:], (num_timesteps, 2*note_range))
for i,s in enumerate(v_out):
#     S = v_out.data.numpy()
    midi_manipulation.noteStateMatrixToMidi(s, "data/rbm_mnist_output/generated_chord_{}".format(i))

In [None]:
show_adn_save("real",make_grid(v.view(32,1,28,28).data))

In [None]:
show_adn_save("generate",make_grid(v1.view(32,1,28,28).data))

In [70]:
from IPython.display import FileLink

In [71]:
FileLink('data/rbm_mnist_output/generated_chord_2.mid')