In [None]:
from vae import MELVAE
import numpy as np
import random
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer, AdamW
import torch
import math


In [None]:
# Build the mel-spectrogram VAE. The list arguments presumably give per-stage
# channel widths for encoder / decoder / discriminator, and `encoder_downs` the
# per-stage downsampling mode — confirm the exact semantics in vae.MELVAE.
model = MELVAE(
    encoder_channels=[1, 16, 64, 256, 32],
    encoder_downs=[None, 'freq', 'freq', 'full', 'full'],
    decoder_channels=[1, 16, 64, 256, 64, 4],
    discr_channels=[1, 16, 64, 192],
)

In [None]:
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer, AdamW
import torch
import math
class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Linear warm-up followed by cosine annealing, stepped once per optimizer step.

    With ``t = self.last_epoch``:

    * ``t < warmup_steps``: ``lr = base_lr * t / warmup_steps`` (so it starts at 0).
    * ``t >= warmup_steps``: cosine decay from ``base_lr`` down to ``min_lr``,
      reached at ``t == total_steps``. For ``t > total_steps`` the rate is held
      at ``min_lr`` instead of following the cosine back up.

    Args:
        optimizer: wrapped optimizer; each param group's initial ``lr`` is its ``base_lr``.
        warmup_steps: number of linear warm-up steps.
        total_steps: step at which the cosine reaches ``min_lr``.
        min_lr: learning-rate floor at the end of annealing (shared by all groups).
        last_epoch: index of the last completed step; ``-1`` starts fresh.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0, last_epoch=-1):
        # Must be set before the base __init__, which already computes the first LR.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step_idx = self.last_epoch
        if step_idx < self.warmup_steps:
            return [base_lr * step_idx / self.warmup_steps for base_lr in self.base_lrs]

        # max(1, ...) avoids ZeroDivisionError when total_steps <= warmup_steps;
        # min(1.0, ...) keeps the LR at min_lr after total_steps instead of letting
        # the cosine wrap around and rise again.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (step_idx - self.warmup_steps) / decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cosine for base_lr in self.base_lrs]


# AdamW with decoupled weight decay: parameters whose name contains any of the
# `no_decay` substrings are put in a group with weight_decay=0, all others in a
# group with weight_decay=0.003.
# NOTE(review): this relies on MELVAE naming its biases/normalisation params with
# 'bias'/'norm'; it also puts the discriminator in the same optimizer as the VAE —
# confirm both are intended.
no_decay = ['bias', 'norm']
params = list(model.named_parameters())

decayed, undecayed = [], []
for param_name, tensor in params:
    target = undecayed if any(tag in param_name for tag in no_decay) else decayed
    target.append(tensor)

optimizer_grouped_parameters = [
    {'params': decayed, 'weight_decay': 0.003},
    {'params': undecayed, 'weight_decay': 0.0},
]
optimizer = AdamW(params=optimizer_grouped_parameters, lr=5e-4)
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer=optimizer,
    warmup_steps=2000,
    total_steps=200000,
    min_lr=5e-6,
)

In [None]:
import json
# Index of the preprocessed training data. Later cells read each entry's
# 'audio' list (presumably paths to .npy mel files — see the loading cell).
DATASET_JSON = r"F:\datasets\AVSpeech\processed_hifigan_22khz\dataset.json"

# JSON is UTF-8 by specification; pass the encoding explicitly because open()'s
# default on Windows is the locale code page, which fails on non-ASCII content.
with open(DATASET_JSON, 'r', encoding='utf-8') as f:
    dataset = json.load(f)

In [None]:
# Load each entry's spectrogram array in place and z-normalise every loaded
# array with one global mean/std. `mean` and `std` stay as globals because the
# reconstruction cell uses them to undo the normalisation.
# NOTE(review): only ``entry['audio'][0]`` is loaded, while assemble_batch samples
# from the whole ``'audio'`` list — presumably each list has exactly one element;
# confirm against dataset.json.
N_STAT_SAMPLES = 1000  # files drawn (with replacement) to estimate mean/std
STATS_SEED = 0         # fixed so mean/std are identical across kernel restarts


def _pooled_mean_std(arrays):
    """Return ``(mean, std)`` of all elements of ``arrays`` taken together.

    Combines per-array float64 moments with the law of total variance, so arrays
    are weighted by element count and the spread *between* array means is
    included (simply averaging per-array stds underestimates the global std).
    Values are Python floats so that ``(arr - mean) / std`` keeps ``arr``'s dtype.
    """
    sizes = np.array([a.size for a in arrays], dtype=np.float64)
    means = np.array([a.mean(dtype=np.float64) for a in arrays])
    variances = np.array([a.var(dtype=np.float64) for a in arrays])
    pooled_mean = np.average(means, weights=sizes)
    pooled_var = np.average(variances + (means - pooled_mean) ** 2, weights=sizes)
    return float(pooled_mean), float(np.sqrt(pooled_var))


# Re-running this cell would re-normalise already normalised data and overwrite
# mean/std, so fail loudly instead.
if not all(isinstance(entry['audio'][0], str) for entry in dataset):
    raise RuntimeError("dataset entries are already loaded; re-run the JSON loading cell first")

for entry in dataset:
    entry['audio'][0] = np.load(entry['audio'][0])

stats_rng = random.Random(STATS_SEED)
mean, std = _pooled_mean_std(
    [stats_rng.choice(dataset)['audio'][0] for _ in range(N_STAT_SAMPLES)]
)
if std == 0:
    raise ValueError("sampled spectrograms have zero variance; cannot normalise")

for entry in dataset:
    entry['audio'][0] = (entry['audio'][0] - mean) / std


In [None]:
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import copy

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir='F:/TENSORBOARD/melvae')

# Figure/axes the training loop redraws its live loss curves into.
plt.ion()
fig, ax = plt.subplots()

# Per-step loss histories appended to by the training loop.
losses, kl_losses, disc_losses, gen_disc_losses = [], [], [], []

# `i` is the global optimisation step used by the training loop (KL annealing,
# logging). `idx` and `best_val_loss` are not read by any cell in this section.
idx = 0
best_val_loss = 999
i = 0

In [None]:

def kl_weight(step, max_weight=0.015, warmup_steps=50000):
    """Linear KL-annealing weight for the VAE loss.

    Ramps linearly from 0 at ``step == 0`` to ``max_weight`` at
    ``step == warmup_steps`` and stays at ``max_weight`` afterwards.

    Args:
        step: current optimisation step (the training loop passes ``i``).
            Negative values are treated as 0.
        max_weight: final KL weight.
        warmup_steps: number of steps for the ramp; ``<= 0`` means no ramp.

    Returns:
        The weight as a float in ``[0, max_weight]``.
    """
    if warmup_steps <= 0:
        return max_weight
    clamped = min(max(step, 0), warmup_steps)
    return max_weight * clamped / warmup_steps
    
# Put the model on the GPU and in training mode before the loop starts.
model.to('cuda')
model.train()

N = 8  # default number of crops per training batch


def assemble_batch(mel_cutoff=256, batch_size=None, device='cuda', max_attempts=100_000,
                   entries=None):
    """Sample a batch of random fixed-length crops from random spectrograms.

    Each draw picks a random entry, a random element of its ``'audio'`` list, and
    takes ``[0]`` of that element as the array. Arrays with fewer than
    ``mel_cutoff`` frames (last axis) are rejected and another draw is made;
    otherwise a random window of ``mel_cutoff`` frames is cut with ``[:, s:e]``,
    which presumes each array is 2-D (e.g. (n_mels, frames)) — confirm.

    Args:
        mel_cutoff: crop length in frames.
        batch_size: number of crops; defaults to the module-level ``N``.
        device: torch device the batch is moved to.
        max_attempts: cap on the total number of draws, so the call raises
            instead of looping forever when (almost) nothing is long enough.
        entries: list of dataset entries to sample from; defaults to the
            module-level ``dataset``.

    Returns:
        Tensor of shape ``(batch_size, 1, *crop.shape)`` with the arrays' dtype.

    Raises:
        ValueError: if ``batch_size`` is not positive.
        RuntimeError: if ``max_attempts`` draws do not yield a full batch.
    """
    if batch_size is None:
        batch_size = N
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if entries is None:
        entries = dataset

    crops = []
    attempts = 0
    while len(crops) < batch_size:
        if attempts >= max_attempts:
            raise RuntimeError(
                f"only {len(crops)}/{batch_size} spectrograms with >= {mel_cutoff} "
                f"frames found in {max_attempts} draws"
            )
        attempts += 1

        artist = random.choice(entries)
        mel = random.choice(artist['audio'])[0]
        n_frames = mel.shape[-1]
        if n_frames < mel_cutoff:
            continue  # too short for this crop length; draw again

        # randint is inclusive, so the window may end exactly at the last frame.
        start_idx = random.randint(0, n_frames - mel_cutoff)
        crops.append(mel[:, start_idx:start_idx + mel_cutoff])

    # np.stack + from_numpy: torch.tensor(list_of_ndarrays) is extremely slow.
    return torch.from_numpy(np.stack(crops)).to(device).unsqueeze(1)

# Training loop: runs until the kernel is interrupted. All state (`i`, the loss
# histories, optimizer/scheduler) lives in notebook globals, so re-running this
# cell resumes from the current step.
while True:
    mel = assemble_batch(random.randint(224, 400))  # random crop length per batch

    disc_loss, rec_loss, kl_loss, gen_adv_loss, (mu, logvar) = model.get_loss(mel)
    # Scale out-of-place: an in-place `*=` mutates the tensor returned by
    # get_loss, which autograd rejects if that tensor was saved for backward.
    kl_loss = kl_loss * kl_weight(i)  # KL annealing
    gen_adv_loss = gen_adv_loss * 0.5
    gen_loss = rec_loss + kl_loss + gen_adv_loss

    # NOTE(review): generator and discriminator share one optimizer and both
    # losses backpropagate through every parameter. Unless MELVAE.get_loss
    # detaches / freezes the right sub-network, the discriminator also receives
    # the generator's adversarial gradient (and vice versa) — confirm in vae.py.
    optimizer.zero_grad()
    # One backward on the sum accumulates the same gradients (up to float
    # rounding) as two backward() calls, without retain_graph=True keeping the
    # whole graph alive.
    (disc_loss + gen_loss).backward()
    optimizer.step()
    scheduler.step()

    losses.append(rec_loss.item())
    kl_losses.append(kl_loss.item())
    gen_disc_losses.append(gen_adv_loss.item())
    disc_losses.append(disc_loss.item())

    if i % 25 == 0:
        torch.cuda.empty_cache()
        writer.add_histogram("mu", mu, global_step=i, bins=50)
        writer.add_histogram("logvar", logvar, global_step=i, bins=50)
        # Also persist losses/LR to TensorBoard; the inline plot is lost on restart.
        writer.add_scalar("loss/rec", losses[-1], global_step=i)
        writer.add_scalar("loss/kl", kl_losses[-1], global_step=i)
        writer.add_scalar("loss/adv_gen", gen_disc_losses[-1], global_step=i)
        writer.add_scalar("loss/adv_disc", disc_losses[-1], global_step=i)
        writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step=i)

        ax.clear()
        clear_output(wait=True)

        # Plot every 8th recorded value of each loss history.
        ax.plot(np.array(kl_losses[::8]), label='kl')
        ax.plot(np.array(disc_losses[::8]), label='adv_disc')
        ax.plot(np.array(gen_disc_losses[::8]), label='adv_gen')
        ax.plot(np.array(losses[::8]), label='Loss')
        ax.set_ylim(0, 2)
        ax.legend()
        display(fig)
    i += 1

In [None]:
from scipy.ndimage import sobel

spec = None
# Draw one random crop (224-400 frames) for the reconstruction check below;
# assemble_batch returns a full batch, of which only the first item is kept.
mel = assemble_batch(random.randint(224, 400))[:1]

In [None]:
# Qualitative check: show the sampled mel, reconstruct it from the posterior
# mean `mu` (no sampling), show the reconstruction and save it de-normalised.
model.eval()
try:
    # Inference only: no_grad avoids building (and holding) an autograd graph.
    with torch.no_grad():
        # [::-1] flips the first axis so row 0 is drawn at the bottom.
        plt.imshow(mel.squeeze(0).squeeze(0).cpu().detach().numpy()[::-1], cmap='turbo')
        plt.show()
        mu, logvar, shapes = model.encode(mel.to('cuda'))
        out = model.decode(mu)
    out_ = out.squeeze(0).squeeze(0).cpu().detach().numpy()
    plt.imshow(out_[::-1], cmap='turbo')
    plt.show()
    # Undo the global z-normalisation applied when the dataset was loaded.
    np.save('out.npy', out.squeeze(0).cpu().detach().numpy() * std + mean)
finally:
    # Always return to training mode, even if encoding/plotting raised.
    model.train()