In [2]:
import torch 
from torchvision import transforms 

In [3]:
from VAE import * 
from utils import * 
from dataloader import * 

In [4]:
# Notebook-wide configuration: training device and a single seed for every RNG.
import random

SEED = 0
device = "cpu"  # Set training device

# Seed every RNG the notebook draws from (clip choice, weight init,
# DataLoader shuffling, reparameterization noise) so that
# Restart Kernel -> Run All reproduces the same run.
random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Sanity check: push one spectrogram through a freshly built encoder and make
# sure no layer raises a dimension mismatch.
import random

# Local seeded RNG so re-running the notebook checks the same clip every time.
rng = random.Random(0)
randn = rng.randint(1000, 7000)  # clip index (an int, not a tensor)
wav_to_spec = WavToSpec()
mel = wav_to_spec(f"../data_parsed_five_sec/Medieval_Celtic_Music_{randn}.wav")
print(mel.shape)  # Check data shape

encoder = VAEEncoder(1, 16, 16)
# Shape check only: no gradients are needed, so skip building an autograd graph.
with torch.no_grad():
    output = encoder(mel.unsqueeze(0))  # add a leading batch dimension
print(output[0].shape)  # Check output shape

torch.Size([1, 80, 512])
torch.Size([1, 16, 20, 128])


In [10]:
import os 
import glob

# Collect every .wav clip in the parsed dataset folder.
root = "../data_parsed_five_sec"
# glob's result order depends on the filesystem; sort so the file list (and
# anything indexed by it) is identical across machines and runs.
files = sorted(glob.glob(os.path.join(root, "*.wav")))
# An empty list would otherwise yield a DataLoader that trains on nothing,
# silently; fail fast with a clear message instead.
if not files:
    raise FileNotFoundError(f"No .wav files found in {root!r}")

In [7]:
# Per-sample transform: WavToSpec maps each .wav path to a spectrogram tensor.
transform = transforms.Compose([WavToSpec()])

# NOTE(review): this cell reads `files`, whose defining cell has a later
# execution count (In [10] vs In [7]); confirm it works on a fresh kernel.
dset = Spectro(files, transform)

# Batch the dataset and reshuffle the sample order on every pass.
loader = DataLoader(
    dset,
    batch_size=10,
    shuffle=True,
)

In [8]:
# Reparameterization trick: sample the latent z from the encoder's (mu, logvar).
def reparameterize(mu, logvar):
    """Draw z ~ N(mu, exp(logvar)) in a differentiable way.

    The sample is written as a deterministic function of ``mu`` and ``logvar``
    plus independent standard-normal noise, so gradients flow back into both.

    Args:
        mu: mean tensor.
        logvar: log-variance tensor (expected to match ``mu``'s shape).

    Returns:
        ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from N(0, 1)
        elementwise, shaped like ``exp(0.5 * logvar)``.
    """
    sigma = torch.exp(logvar * 0.5)  # log-variance -> standard deviation
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

In [11]:
# Setup encoder and decoder. Move both onto the training device so their
# parameters live where the batches are sent below (nn.Module.to returns self).
enc = VAEEncoder(in_channels=1, C=16, r=16).to(device)
dec = VAEDecoder(out_channels=1, C=16).to(device)
enc.train()
dec.train()

# One parameter list shared by the optimizer and gradient clipping.
params = list(enc.parameters()) + list(dec.parameters())

# Set optimizer
opt = torch.optim.AdamW(params, lr=2e-4, betas=(0.9, 0.999), weight_decay=0.0)
beta_kl = 1e-6  # weight of the KL term relative to the reconstruction loss

for x in loader:
    x = x.to(device)  # Move batch to training device
    mu, logvar = enc(x)  # mean and log-variance of the latent distribution
    z = reparameterize(mu, logvar)

    x_hat = dec(z)  # Get decoder output (reconstruction of x)

    # l1_loss only warns (then broadcasts) when shapes differ, which would make
    # the reconstruction loss meaningless; fail loudly instead.
    assert x_hat.shape == x.shape, (
        f"decoder output {tuple(x_hat.shape)} != input {tuple(x.shape)}"
    )

    # Calculate loss
    rec = torch.nn.functional.l1_loss(x_hat, x)
    # KL(N(mu, exp(logvar)) || N(0, I)), averaged over every element
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    loss = rec + beta_kl * kld

    # Backpropagation with gradient-norm clipping
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, 1.0)
    opt.step()

KeyboardInterrupt: 