In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary

from pushover import notify
from utils import makegif
from random import randint
import numpy as np

import IPython
from IPython.display import Image
from IPython.core.display import Image, display
import PIL

import piano_roll_utils
import librosa
import librosa.display

%load_ext autoreload
%autoreload 2

In [None]:
# Device configuration: use the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
bs = 32 # mini-batch size handed to the DataLoader

In [None]:
# Load Data

def load_img(img_path):
    """Load the image at ``img_path`` as a single-channel (mode "L") PIL image.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A new PIL image in mode "L", independent of the file on disk.
    """
    # convert() produces a separate, fully loaded image, so the source file
    # can be closed right away instead of staying open until it is garbage
    # collected (which matters when an ImageFolder loads thousands of files).
    with PIL.Image.open(img_path) as img:
        return img.convert(mode="L")
IMRANGE = 256  # number of distinct values of a uint8 pixel

# Every image is read with load_img, turned into a tensor and binarised:
# entries greater than zero become 1.0, everything else 0.0.
dataset = datasets.ImageFolder(
    root='trainings/rolls_gray',
    loader=load_img,
    transform=transforms.Compose([
        transforms.ToTensor(),
        lambda x: (x > 0).type(torch.FloatTensor),
    ]),
)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=bs)

In [None]:
# Fixed input for debugging: take the first batch the (shuffled) loader yields
# and discard its labels.
fixed_x, _ = next(iter(dataloader))
# Write the batch as an image grid.
# NOTE(review): assumes the 'tmp' directory already exists -- confirm.
save_image(fixed_x, 'tmp/real_image.png')

# Last expression of the cell: show the saved grid inline.
Image('tmp/real_image.png')

In [None]:
# Shape of one transformed sample (index 1), as a sanity check of the input size.
print(dataset[1][0].shape)
# Number of features in the flattened encoder output; used as the default for
# VAE's h_dim and for UnFlatten's size.
HSIZE = 2048 #9216 # 1024
# Size of the latent vector; used as the default for VAE's z_dim.
ZDIM =  16 # 32

In [None]:
# def get_player(piano_roll):
#     midi= piano_roll_utils.piano_roll_to_pretty_midi(piano_roll, fs=50, program=11)
#     IPython.display.Audio(midi.fluidsynth(fs=44100),rate=44100)

In [None]:
# Index of the dataset sample to audition; pinned to 0 (the random draw is disabled).
rand_i = 0 #randint(1, 100)
print(rand_i)

In [None]:
# First slice along dim 0 of the selected sample's image tensor
# (unsqueeze(0)[0] adds and immediately removes a leading dimension).
# Assumes samples are (channels, height, width) -- TODO confirm.
sample = dataset[rand_i][0].unsqueeze(0)[0][0]
# Cast to int and scale by 64 before conversion
# (presumably 64 serves as the note velocity -- TODO confirm).
for_play = np.array(sample).astype(int) * 64
# NOTE(review): fs=50 and program=11 are passed through to piano_roll_utils;
# their meaning (likely roll sampling rate and MIDI program) should be confirmed there.
midi= piano_roll_utils.piano_roll_to_pretty_midi(for_play,fs=50,program=11)
# Last expression of the cell: audio player for the waveform rendered at 44100 Hz.
IPython.display.Audio(midi.fluidsynth(fs=44100),rate=44100)

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into one: [N, a, b, c] -> [N, a*b*c]."""

    def forward(self, input):
        batch = input.size(0)
        flat = input.view(batch, -1)
        return flat

In [None]:
class UnFlatten(nn.Module):
    """Reshape a flat feature batch [N, size] into [N, size, 1, 1]."""

    def forward(self, input, size=None):
        """Return ``input`` viewed as ``[N, size, 1, 1]``.

        Args:
            input: Tensor whose first dimension is the batch dimension.
            size: Number of channels of the result. When omitted it is taken
                from the input itself (product of all non-batch dimensions),
                so the layer follows whatever width the preceding layer
                produces instead of a value frozen when the class was defined.

        Returns:
            A view of ``input`` with shape ``[N, size, 1, 1]``.
        """
        if size is None:
            # Multiply the dimensions out by hand rather than passing -1 to
            # view(), which cannot be resolved for an empty batch.
            size = 1
            for dim in input.shape[1:]:
                size *= dim
        return input.view(input.size(0), size, 1, 1)

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is five stride-2 convolutions followed by ``Flatten``. Two
    linear heads turn the flattened features into the latent mean and
    log-variance; a third linear layer maps a latent sample back to ``h_dim``
    features, which the transposed-convolution decoder expands into an image
    whose values are squashed by a sigmoid.

    Args:
        image_channels: Channels of the input and of the reconstruction.
        h_dim: Number of features in the flattened encoder output.
        z_dim: Size of the latent vector.
        verbose: When True, ``encode`` prints the tensor shapes it sees
            (shape-debugging aid). Off by default so inference stays quiet.

    NOTE(review): the decoder's ``UnFlatten`` layer is used with its default
    ``size``, and the spatial sizes quoted in the layer comments assume
    128x128 inputs -- confirm against the dataset.
    """

    def __init__(self, image_channels=1, h_dim=HSIZE, z_dim=ZDIM, verbose=False):
        super(VAE, self).__init__()
        self.verbose = verbose

        # kernel 4 / stride 2 / no padding: n -> floor((n - 4) / 2) + 1
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),  # 128 -> 63
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # 63 -> 30
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),  # 30 -> 14
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),  # 14 -> 6
            nn.ReLU(),
            nn.Conv2d(256, 512, kernel_size=4, stride=2),  # 6 -> 2
            nn.ReLU(),
            Flatten(),  # [N, 512, 2, 2] -> [N, 2048]
        )

        self.fc1 = nn.Linear(h_dim, z_dim)  # latent mean
        self.fc2 = nn.Linear(h_dim, z_dim)  # latent log-variance
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent sample -> decoder features

        # stride 2 / no padding: n -> (n - 1) * 2 + kernel
        self.decoder = nn.Sequential(
            UnFlatten(),  # [N, h_dim] -> [N, h_dim, 1, 1]
            nn.ConvTranspose2d(h_dim, 256, kernel_size=5, stride=2),  # 1 -> 5
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2),  # 5 -> 13
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),  # 13 -> 29
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),  # 29 -> 62
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),  # 62 -> 128
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps`` drawn from N(0, I).

        Args:
            mu: Latent mean.
            logvar: Latent log-variance, same shape as ``mu``.

        Returns:
            A latent sample with the shape, dtype and device of ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like creates the noise on the same device (and dtype) as std;
        # a plain torch.randn(...) is always on the CPU and makes the addition
        # fail as soon as the model runs on CUDA.
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h):
        """Map encoder features to ``(z, mu, logvar)``."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        """Encode an image batch.

        Args:
            x: Image batch accepted by the encoder convolutions.

        Returns:
            Tuple ``(z, mu, logvar)``: a latent sample and the parameters it
            was drawn from.
        """
        if self.verbose:
            print("======== Encode ========", x.shape)
        h = self.encoder(x)
        if self.verbose:
            print("enc(x): ", h.shape)
        z, mu, logvar = self.bottleneck(h)
        if self.verbose:
            print("z.shape: ", z.shape)
        return z, mu, logvar

    def decode(self, z):
        """Decode latent vectors ``z`` into images with values in (0, 1)."""
        features = self.fc3(z)
        return self.decoder(features)

    def forward(self, x):
        """Run a full encode/decode pass.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar

In [None]:
# Number of image channels, read from dimension 1 of the debug batch.
image_channels = fixed_x.size(1)

In [None]:
# Build the VAE on the selected device.
model = VAE(image_channels=image_channels).to(device)
# Checkpoint identifier; also reused for the sample image filename further down.
# The trailing comment keeps other checkpoint names that can be swapped in.
model_version = "graybin_bce_batch32-imgs_2297-epch_75-100" #"graybin_bce_d16-imgs_2297-epch_100"  # "graybin_bce_d16-imgs_2297-epch_100" 
# map_location='cpu' deserialises the stored tensors on the CPU; load_state_dict
# then copies them into the model's existing parameters.
model.load_state_dict(torch.load('models/cvae.' + model_version, map_location='cpu'))

In [None]:
def print_gt_zero_elem(matrix):
    """Print the entries of ``matrix`` that ``get_elems`` selects by default.

    Despite the name, the cut-off is ``get_elems``' default ``thres`` (0.1),
    not zero.
    """
    selected = get_elems(matrix)
    print(selected)
    
def get_elems(matrix, thres=0.1):
    """Return the entries of ``matrix`` that are greater than or equal to ``thres``.

    Args:
        matrix: Array-like supporting element-wise comparison and boolean-mask
            indexing (the notebook passes torch tensors).
        thres: Inclusive lower bound for an entry to be kept.

    Returns:
        The selected entries, as produced by boolean-mask indexing.
    """
    keep = matrix >= thres
    return matrix[keep]

In [None]:
# Draw a sample index in [1, 100] (randint includes both ends). No seed is set,
# so the choice differs between runs.
# NOTE(review): assumes the dataset holds at least 101 images -- confirm.
rand_i = randint(1, 100)
print(rand_i)

In [None]:
# Replace the debug batch with the single selected sample, adding a leading
# batch dimension of size 1.
fixed_x = dataset[rand_i][0].unsqueeze(0)
# Independent copy, used below for playback and for the encode/decode round trip.
fixed_x2 = fixed_x.clone()

In [None]:
def compare(x, threshold=0.4):
    """Reconstruct ``x`` with the global ``model`` and pair it with the input.

    Also prints the entries of the raw reconstruction and of the input that
    are at least 0.01.

    Args:
        x: Image batch to reconstruct.
        threshold: Reconstruction values strictly above this become 1.0, all
            others 0.0.

    Returns:
        Tuple ``(stacked, recon_pass)``: ``stacked`` is the input and the
        binarised reconstruction, both scaled by 64.0 and concatenated along
        dim 0; ``recon_pass`` is the binarised reconstruction (0.0 / 1.0) on
        the same device as ``x``.
    """
    # The model may live on the GPU while dataset samples are CPU tensors, so
    # run the forward pass on the model's own device.
    model_device = next(model.parameters()).device
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        recon_x, _, _ = model(x.to(model_device))

    # Bring the result back next to x so the two can be concatenated.
    recon_pass = (recon_x > threshold).float().to(x.device)
    print("recon:", get_elems(recon_x, thres=0.01))
    print("original:", get_elems(x, thres=0.01))
    return torch.cat([x * 64.0, recon_pass * 64.0]), recon_pass

In [None]:
# Play the original sample: first batch entry, first channel, cast to int and
# scaled by 64 (presumably the note velocity -- TODO confirm).
for_play = np.array(fixed_x2[0][0]).astype(int) * 64
midi= piano_roll_utils.piano_roll_to_pretty_midi(for_play,fs=50,program=11)
# Last expression of the cell: audio player for the waveform rendered at 44100 Hz.
IPython.display.Audio(midi.fluidsynth(fs=44100),rate=44100)

In [None]:
# compare_x: input and thresholded reconstruction concatenated (both scaled by 64).
# recon_x: the thresholded (0/1) reconstruction, not the raw sigmoid output.
compare_x, recon_x = compare(fixed_x)

In [None]:
# Play the reconstruction: first batch entry, first channel of the thresholded
# output, scaled by 64 (presumably the note velocity -- TODO confirm).
for_play = np.array(recon_x[0][0].data.cpu()) * 64
# The float array is cast to int only here, at the call.
midi= piano_roll_utils.piano_roll_to_pretty_midi(for_play.astype(int) ,fs=50,program=11)
# Last expression of the cell: audio player for the waveform rendered at 44100 Hz.
IPython.display.Audio(midi.fluidsynth(fs=44100),rate=44100)

In [None]:
# Save the input/reconstruction pair under a name derived from the checkpoint
# identifier, then show it inline.
# NOTE(review): assumes the 'outputs' directory already exists -- confirm.
sample_filename = 'outputs/sample_image-{}.png'.format(model_version)
save_image(compare_x.data.cpu(), sample_filename)
display(Image(sample_filename, width=300, unconfined=True))

In [None]:
# Encode/decode round trip for the selected sample, showing the image before
# and after. Both images go through the same scratch file; each is displayed
# before the file is overwritten.
save_image(fixed_x2 * 256, 'tmp/test.png', padding=0)
display(Image('tmp/test.png'))
# The model was moved to `device`, so its input has to live there as well;
# feeding the CPU tensor directly fails when the model is on CUDA.
z, mu, log_var = model.encode(fixed_x2.to(device))
print(z)
x = model.decode(z)
print(x)
# NOTE(review): the values are multiplied by 256 before saving -- confirm this
# scaling is intended for save_image.
save_image(x[0][0] * 256, 'tmp/test.png', padding=0)
display(Image('tmp/test.png'))

In [None]:
# Scratch experiment: apply BatchNorm2d to a small hand-written tensor.
# NOTE(review): this first layer is replaced on the next assignment and never used.
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(1, affine=False)
# NOTE(review): `input` is not used in this cell and shadows the builtin input().
input = torch.randn(20, 100, 35, 45)
# 16 values nested four levels deep, i.e. shape (1, 1, 1, 16).
test = torch.Tensor([[[[-5.7798,  2.1175,  1.8994, -1.1060,  2.9570,  1.2670,  1.6456,  1.1319,
          0.3539, -3.7390, -7.2136, -0.0291, -1.3198,  0.5766,  1.1158,  2.6609]]]])
output = m(test)

In [None]:
# Show the tensor that was fed to the BatchNorm layer.
test

In [None]:
# Show the BatchNorm result for comparison with the input above.
output

In [4]:
# Standard-normal random tensor of shape (128, 8, 32, 32); unseeded.
xs = torch.randn((128,8,32,32))

In [6]:
# NOTE(review): this displays the repr of the whole tensor; xs.shape or a small
# slice would keep the notebook output short.
xs

tensor([[[[ 1.6497e+00, -1.2354e+00,  1.0321e+00,  ..., -2.3203e-01,
            1.1211e+00, -3.1046e-01],
          [-4.1823e-01,  1.2201e-01,  1.6051e-02,  ..., -2.2568e-01,
           -1.6504e+00,  1.0754e+00],
          [-3.5300e-01, -6.7170e-01,  4.2309e-01,  ..., -1.6493e+00,
            1.5065e+00, -6.7356e-01],
          ...,
          [-6.2478e-01, -8.3208e-01, -1.6183e+00,  ..., -1.9204e+00,
           -6.5826e-01, -1.0545e+00],
          [-2.0849e+00,  1.1849e+00,  6.8645e-01,  ...,  3.8623e-01,
            7.9311e-01, -8.9476e-02],
          [-3.2642e-01,  1.4181e+00, -2.5301e+00,  ...,  1.0128e+00,
           -1.6913e+00, -6.9789e-01]],

         [[-8.9811e-01,  3.5447e-01, -5.1330e-01,  ..., -2.0926e-01,
            4.6303e-01,  7.2819e-01],
          [-7.7812e-01, -5.1563e-01,  1.9099e-01,  ...,  7.9822e-02,
            1.0617e+00,  1.3884e+00],
          [ 1.2734e-01, -6.2976e-01,  1.0528e+00,  ..., -1.5813e-01,
           -9.8564e-01, -5.0896e-01],
          ...,
     