In [54]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import data_loader
import numpy as np
import sample_to_chords as s2c


### VAE definition

In [55]:
# Prefer the first CUDA GPU when one is available, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
print('using', dev)
device = torch.device(dev)
class VAE(nn.Module):
    """Fully-connected variational autoencoder over fixed-length chord sequences.

    An input holds N_CHORDS time steps, each scored over N_PITCH * N_QUALITY
    (pitch, quality) classes. The encoder maps the flattened input to the mean
    and log-variance of a SIZE_LATENT-dimensional diagonal Gaussian; the decoder
    maps a latent vector back to per-class probabilities shaped
    (batch, N_CHORDS, N_PITCH * N_QUALITY).
    """

    N_CHORDS = 16
    N_PITCH = 12
    N_QUALITY = 7  # must also be changed in data_loader

    SIZE_HIDDEN = 400
    SIZE_LATENT = 40

    def __init__(self):
        super().__init__()
        n_input = self.N_CHORDS * self.N_PITCH * self.N_QUALITY

        # Encoder: flattened input -> hidden -> (mu, logvar).
        # Attribute names and creation order are kept so that saved state_dicts
        # (fc1/fc21/fc22/fc3/fc4) keep loading.
        self.fc1 = nn.Linear(n_input, self.SIZE_HIDDEN)
        self.fc21 = nn.Linear(self.SIZE_HIDDEN, self.SIZE_LATENT)
        self.fc22 = nn.Linear(self.SIZE_HIDDEN, self.SIZE_LATENT)

        # Decoder: latent -> hidden -> flattened output.
        self.fc3 = nn.Linear(self.SIZE_LATENT, self.SIZE_HIDDEN)
        self.fc4 = nn.Linear(self.SIZE_HIDDEN, n_input)

    def encode(self, x):
        """Map a flattened input batch to ``(mu, logvar)`` of the posterior."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors to probabilities shaped (batch, N_CHORDS, N_PITCH * N_QUALITY)."""
        h3 = F.relu(self.fc3(z))
        logits = self.fc4(h3).view(-1, self.N_CHORDS, self.N_PITCH * self.N_QUALITY)
        # Functional sigmoid: no need to instantiate an nn.Sigmoid module per call.
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode, sample a latent vector, decode. Returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x.view(-1, self.N_CHORDS * self.N_PITCH * self.N_QUALITY))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

using cuda:0


In [56]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, beta):
    """Beta-weighted VAE loss: reconstruction BCE plus KL divergence.

    Both terms are summed over all elements and over the batch (not averaged).

    Args:
        recon_x: decoder output, probabilities in [0, 1].
        x: target tensor with the same number of elements as ``recon_x``.
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.
        beta: weight applied to the KL term.

    Returns:
        Scalar tensor ``BCE + beta * KLD``.
    """
    # Flatten both tensors rather than hard-coding 16*12*7, so the loss cannot
    # drift out of sync with the model's N_CHORDS / N_PITCH / N_QUALITY. With
    # reduction='sum' only the element-wise pairing matters, which flattening
    # preserves (and a size mismatch still raises).
    BCE = F.binary_cross_entropy(recon_x.reshape(-1), x.reshape(-1), reduction='sum')

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + beta * KLD

In [57]:
def train(epoch):
    """Run one training epoch over the pre-split batches in ``realbook_dataset``.

    The KL weight is annealed linearly with the epoch number (beta = epoch / epochs).
    The printed loss therefore includes a changing KL weight and is not directly
    comparable across epochs.

    Relies on the notebook globals ``model``, ``optimizer``, ``realbook_dataset``,
    ``device``, ``epochs``, ``log_interval`` and ``Nchunks``.

    Args:
        epoch: 1-based epoch number, used for beta and in the log messages.
    """
    model.train()
    beta = epoch / epochs
    epoch_loss = 0
    for step, batch in enumerate(realbook_dataset):
        batch = batch.to(device)
        optimizer.zero_grad()

        recon, mu, logvar = model(batch)
        loss = loss_function(recon, batch, mu, logvar, beta)
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()

        if step % log_interval == 0:
            n_seen = step * len(batch)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, n_seen, Nchunks,
                100. * n_seen / Nchunks,
                loss.item() / len(batch)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, epoch_loss / Nchunks))

In [58]:
epochs: int = 10  # number of training epochs; also the denominator of the KL weight beta in train()
batch_size: int = 128  # rows per batch when the dataset tensor is split
log_interval: int = 100  # print a progress line every `log_interval` batches

In [59]:
# Load the full dataset tensor once, then split it along dim 0 into batches of
# `batch_size` rows (the last batch may be smaller). The raw tensor keeps its
# own name so that `realbook_dataset` has a single meaning for later cells:
# the tuple of batches.
realbook_tensor = data_loader.import_dataset()
Nchunks = len(realbook_tensor)  # total number of rows, used to average the loss
realbook_dataset = torch.split(realbook_tensor, batch_size, dim=0)
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

Dataset loaded !


### Load a saved model (optional: resume from a checkpoint)

In [49]:
# Resume from a previously saved checkpoint. Requires `model` and `device`
# from the cells above, so run it only after the model has been created.
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint saved from a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load("./model_realbook.pt", map_location=device))

<All keys matched successfully>

### Train model

In [60]:
# Epochs are numbered from 1, so beta = epoch / epochs reaches 1.0 on the last epoch.
for epoch in range(1, 1 + epochs):
    train(epoch=epoch)

====> Epoch: 1 Average loss: 40.3773
====> Epoch: 2 Average loss: 22.4668
====> Epoch: 3 Average loss: 23.1513
====> Epoch: 4 Average loss: 25.2772
====> Epoch: 5 Average loss: 27.6773
====> Epoch: 6 Average loss: 29.9735
====> Epoch: 7 Average loss: 32.1881
====> Epoch: 8 Average loss: 34.2807
====> Epoch: 9 Average loss: 36.2142
====> Epoch: 10 Average loss: 38.0100


In [61]:
# Persist only the learned parameters (the state_dict), not the module object.
model_state = model.state_dict()
torch.save(model_state, "./model_realbook.pt")

### Test

In [65]:
PITCH_LIST = ["A", "A#", "B", "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#"]
QUALITY_LIST = ["maj", "min", "dim", "maj7", "min7", "7", "dim7"]

def sample_to_chords(sample, batch_index=0):
    """Decode one sequence of a batched sample into chord labels.

    Each time step is decoded to its highest-scoring class. The flat class index
    is split as ``idx = pitch * len(QUALITY_LIST) + quality``.

    Args:
        sample: array of shape (batch, n_steps, len(PITCH_LIST) * len(QUALITY_LIST)).
        batch_index: which item of the batch to decode (default: the first).

    Returns:
        List of ``"<pitch>:<quality>"`` strings, one per time step.

    Raises:
        ValueError: if the last dimension does not match the pitch/quality tables.
    """
    n_quality = len(QUALITY_LIST)
    n_classes = len(PITCH_LIST) * n_quality
    if sample.shape[-1] != n_classes:
        # Catches a model/data encoding that drifted from the lookup tables.
        raise ValueError(
            "expected last dimension {}, got {}".format(n_classes, sample.shape[-1]))

    idx_chords = np.argmax(sample[batch_index, :, :], 1)
    chords = []
    for idx in idx_chords:
        pitch_idx, quality_idx = divmod(int(idx), n_quality)
        chords.append(PITCH_LIST[pitch_idx] + ":" + QUALITY_LIST[quality_idx])
    return chords

index_test = 16
test_sample = realbook_dataset[0][index_test]

print("Vérité")
true_sample = test_sample.view(1, 16, -1).numpy()
print(sample_to_chords(true_sample))

print()
print("Par VAE")
# test_sample is converted with .numpy() above, so it lives on the CPU; move the
# model there to run inference. Note this leaves the model on the CPU for
# any later cells.
model.to(torch.device("cpu"))
# Inference only: no autograd graph needed. The reconstruction samples z via
# reparameterize(), so it can vary from run to run.
with torch.no_grad():
    recons_test, _, _ = model(test_sample)
print(sample_to_chords(recons_test.numpy()))

Vérité
['C:maj', 'C:maj', 'C:maj', 'C:maj', 'C:7', 'C:7', 'C:7', 'C:7', 'F:maj', 'F:maj', 'F:maj', 'F:maj', 'F:min7', 'F:min7', 'F:min7', 'F:min7']

Par VAE
['C:maj', 'C:maj', 'C:maj', 'C:maj', 'C:7', 'C:7', 'C:7', 'C:7', 'F:maj', 'F:maj', 'F:maj', 'F:maj', 'F:maj', 'F:min', 'F:min7', 'F:min7']


### Create new progressions (sample from the prior)

In [66]:
# The latent size must match the decoder's input, so take it from the model class.
N_LATENT = VAE.SIZE_LATENT
with torch.no_grad():
    # Draw z ~ N(0, I) on whatever device the model currently lives on, so this
    # cell works whether or not the model was moved to the CPU earlier.
    model_device = next(model.parameters()).device
    z = torch.randn(1, N_LATENT, device=model_device)
    sample = model.decode(z).cpu().numpy()
print(sample.shape)
sample_to_chords(sample)

(1, 16, 84)


['G:maj',
 'G:maj',
 'G:maj',
 'A#:maj',
 'G:maj',
 'A#:maj',
 'C:maj',
 'A:min',
 'A:min',
 'A:min',
 'A:min',
 'A:7',
 'A:7',
 'D:min7',
 'D:min7',
 'D:min7']

3085