In [61]:
import os, time
import torch
import pandas as pd
from skimage import io, transform
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import cv2

In [56]:
# Root folder of the piano-roll dataset (frames live in its sub-directories).
# Raw string: "\m" and "\d" are invalid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+). Set the PIANOROLL_DATA environment
# variable to point at the data on another machine.
dataset_path = os.environ.get("PIANOROLL_DATA", r"C:\midi\data")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [83]:
class PianorollDataset(Dataset):
    """Pairs of consecutive piano-roll frames stored as numbered PNG files.

    Every sub-directory of ``folder_path`` (the root itself is skipped) is
    expected to hold frames named ``0.png, 1.png, ...``; other files in the
    directory (e.g. the source MIDI file) are ignored. Item ``k`` is the pair
    (frame i, frame i+1) as two tensors.
    """

    def __init__(self, folder_path):
        self.folder = folder_path
        # Walk the folder we were given (previously this read the global
        # `dataset_path`). The first os.walk entry is the root directory itself.
        self.oswalk = list(os.walk(folder_path))[1:]
        self.dataset = []
        print("Dataset: Building metadata")
        start = time.time()
        for root, _, files in self.oswalk:
            # Count PNG frames rather than assuming exactly one extra file per
            # directory; N frames yield N-1 consecutive pairs.
            n_frames = sum(1 for f in files if f.lower().endswith(".png"))
            self.dataset += [(os.path.join(root, f"{i}.png"),
                              os.path.join(root, f"{i + 1}.png"))
                             for i in range(n_frames - 1)]
        print(f"Dataset: Built in {np.round(time.time() - start, 3)} sec")

    @staticmethod
    def _load_frame(path):
        """Read one frame, divide by 255 (assumes 8-bit images), and add a leading
        axis of size 1 so it matches the model's single input channel."""
        return torch.unsqueeze(torch.Tensor(io.imread(path) / 255), 0)

    def __getitem__(self, idx):
        """Return (frame i, frame i+1) for the idx-th pair."""
        first, second = self.dataset[idx]
        return self._load_frame(first), self._load_frame(second)

    def __len__(self):
        return len(self.dataset)
    

In [146]:
class CoNNVAEr(nn.Module):
    """Convolutional VAE mapping a 1-channel image to a 1-channel image in [0, 1].

    The encoder must shrink the input to a 1x1 spatial map with 256 channels,
    because ``fc_mu``/``fc_s`` expect 256 features after ``Flatten``. A 106x100
    input satisfies this. The decoder expands a ``latent_dim`` vector (viewed as
    a 1x1 map) back to a 106x100 single-channel output squashed by Sigmoid.

    Args:
        latent_dim: size of the latent vector (default 128, as before).
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=2),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_s = nn.Linear(256, latent_dim)  # log-variance head
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 256, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            # Non-square kernel: grows height and width differently.
            nn.ConvTranspose2d(64, 32, kernel_size=(4, 1), stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, s):
        """Sample z = mu + exp(s / 2) * eps with eps ~ N(0, I); s is the log-variance."""
        std = torch.exp(0.5 * s)
        # randn_like follows mu's device and dtype; torch.randn(*size) always
        # produced a CPU tensor, which fails once the model is on CUDA.
        eps = torch.randn_like(mu)
        return mu + std * eps

    def forward(self, x):
        """Return (reconstruction, mu, log-variance) for a batch x of shape (B, 1, H, W)."""
        features = self.encoder(x)
        mu, s = self.fc_mu(features), self.fc_s(features)
        z = self.reparameterize(mu, s)
        # Present the latent vector as a 1x1 "image" to the transposed convs.
        z = z.view(x.size(0), self.latent_dim, 1, 1)
        return self.decoder(z), mu, s

In [157]:
# Fix the RNG so weight init and shuffling are reproducible across runs.
torch.manual_seed(0)
dataset = PianorollDataset(dataset_path)
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it
# is useless and recent PyTorch emits a warning, so enable it only with CUDA.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True,
                        pin_memory=torch.cuda.is_available(), drop_last=True)
net = CoNNVAEr()
adam = torch.optim.Adam(net.parameters())
mse = nn.MSELoss()
epochs = 50
print(f"{len(dataloader)} batches prepared")

Dataset: Building metadata
Dataset: Built in 0.198 sec
154 batches prepared


In [158]:
# Move each batch to wherever the model's parameters live, so inputs and
# weights always share a device (CPU or CUDA).
train_device = next(net.parameters()).device
net.train()
for ep in range(epochs):
    for i, (x, y) in enumerate(dataloader):
        x = x.to(train_device, non_blocking=True)
        y = y.to(train_device, non_blocking=True)
        # Call the module rather than .forward() so registered hooks run.
        pred, mu, s = net(x)
        # KL(q(z|x) || N(0, I)) with s = log-variance.
        # NOTE(review): averaged over batch AND latent dims (and MSELoss averages
        # over pixels), which reweights the two terms relative to the sum-based
        # ELBO; confirm this weighting is intended.
        KLD = -0.5 * torch.mean(1 + s - mu.pow(2) - s.exp())
        MSE = mse(pred, y)
        loss = KLD + MSE
        adam.zero_grad()
        loss.backward()
        adam.step()
        if i % 50 == 0:
            # Implicit string concatenation: the previous backslash continuation
            # inside the f-string copied the next line's indentation into the log.
            print(f"Epoch {ep}/{epochs}: loss = {loss.item():.4f}; "
                  f"kld = {KLD.item():.4f}; mse = {MSE.item():.4f}")

Epoch 0/50: loss = 0.2582;                  kld = 0.0023; mse = 0.2559


KeyboardInterrupt: 