In [None]:
from model import GuitarToneCloning
import torch.nn as nn
import torch.nn.functional as F
import torch

In [None]:
# Build the full tone-cloning model and move all of its parameters/buffers to the GPU.
model = GuitarToneCloning()
model = model.cuda()

In [None]:
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import AdamW
import math

class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Linear warmup followed by cosine annealing down to a floor learning rate.

    For step ``t`` (``self.last_epoch``) and each param group's base LR ``b``:

    * ``t < warmup_steps``: ``lr = b * t / warmup_steps`` (starts at 0).
    * otherwise: ``lr = m + (b - m) * 0.5 * (1 + cos(pi * p))`` with
      ``m = b * decay_factor`` and ``p = (t - warmup_steps) / (total_steps - warmup_steps)``
      clamped to ``[0, 1]``. The LR therefore stays at ``m`` after ``total_steps``.

    Args:
        optimizer: the wrapped optimizer.
        warmup_steps: number of linear-warmup steps.
        total_steps: step at which the cosine reaches the floor LR.
        decay_factor: floor LR as a fraction of each group's base LR.
        last_epoch: index of the last step when resuming; -1 starts fresh.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, decay_factor=0.1, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.decay_factor = decay_factor

        # The base class fills self.base_lrs from each group's 'initial_lr' (the
        # undecayed LR, even when resuming) and then calls get_lr() once.
        super().__init__(optimizer, last_epoch)

        # Public attribute kept for callers; derived from the base class's base_lrs so
        # it is also correct when resuming from a partially decayed optimizer.
        self.min_lrs = [base_lr * decay_factor for base_lr in self.base_lrs]

    def get_lr(self):
        """Return the learning rate of every param group for ``self.last_epoch``."""
        step = self.last_epoch
        if step < self.warmup_steps:
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        # max(1, ...) avoids ZeroDivisionError when total_steps <= warmup_steps.
        decay_span = max(1, self.total_steps - self.warmup_steps)
        progress = min(max((step - self.warmup_steps) / decay_span, 0.0), 1.0)
        lrs = []
        for base_lr in self.base_lrs:
            # Floor computed here (not read from self.min_lrs) because get_lr() is
            # already called inside the base-class __init__.
            min_lr = base_lr * self.decay_factor
            lrs.append(min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress)))
        return lrs

# Parameter names containing any of these substrings are excluded from weight decay.
no_decay = ['bias', 'norm']
params = list(model.named_parameters())
# NOTE(review): the two lists below are not used in this cell; all of `model`'s
# parameters (presumably both vocoder and vocoder_disc) go into the single optimizer below.
params_generator = list(model.vocoder.named_parameters())
params_discriminator = list(model.vocoder_disc.named_parameters())

lr = 7e-4
# Per-term loss weights forwarded to model.train_vocoder.
loss_weight = {'adv': 1, 'disc': 1, 'spec': 1, 'sn': 0.0001}

decay_params, no_decay_params = [], []
for param_name, param_tensor in params:
    if any(token in param_name for token in no_decay):
        no_decay_params.append(param_tensor)
    else:
        decay_params.append(param_tensor)

optimizer_grouped_parameters = [
    dict(params=decay_params, weight_decay=0.001, lr=lr),
    dict(params=no_decay_params, weight_decay=0.0, lr=lr),
]
optimizer = AdamW(params=optimizer_grouped_parameters, betas=(0.8, 0.99))
# Linear warmup over 2k steps, then cosine decay to 1% of `lr` at step 100k.
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer=optimizer,
    warmup_steps=2000,
    total_steps=100000,
    decay_factor=0.01,
)

In [None]:
import torchaudio
import os

from itertools import product
from utils import yin_pitch_sequence, amplitude_sequence, wave_delta, mel_spectrogram
sr = 24000  # working sample rate (Hz) for all audio below
y_path = "violin.wav"
waveform_raw, sr_raw = torchaudio.load(y_path)
# Resample to `sr`, then keep only the first channel -> 1-D tensor.
y = torchaudio.functional.resample(waveform_raw, sr_raw, sr)[0]

def augment_waveform(waveform, sample_rate, time_stretches, pitch_shifts):
    """Concatenate every (speed, pitch) variant of ``waveform`` into one 1-D tensor.

    For each ``(ts, ps)`` in ``itertools.product(time_stretches, pitch_shifts)``:

    * ``ts != 1.0``: resample from ``sample_rate`` to ``int(sample_rate * ts)``. The
      result is still used as ``sample_rate`` audio, so this changes duration and
      pitch together (a speed change), not a pure time stretch.
    * ``ps != 0``: shift pitch by ``ps`` semitones via ``torchaudio.transforms.PitchShift``.

    Variants appear in ``product`` order, separated by 0.5 s of silence (no trailing silence).

    Args:
        waveform: 1-D ``(time,)`` tensor, or 2-D ``(channels, time)`` tensor that is
            averaged down to mono.
        sample_rate: sample rate of ``waveform`` in Hz.
        time_stretches: iterable of resampling factors.
        pitch_shifts: iterable of semitone shifts.

    Returns:
        1-D tensor containing all variants.

    Raises:
        ValueError: if ``waveform`` is not 1-D/2-D, or no combination is requested.
    """
    if waveform.dim() == 2:
        # Downmix (channels, time) -> (time,) so every piece is 1-D like the silence
        # spacer; a (1, time) tensor cannot be concatenated with it.
        waveform = waveform.mean(dim=0)
    elif waveform.dim() != 1:
        raise ValueError(f"expected a 1-D or 2-D waveform, got shape {tuple(waveform.shape)}")

    silence = torch.zeros(int(0.5 * sample_rate), dtype=waveform.dtype, device=waveform.device)
    pieces = []
    for ts, ps in product(time_stretches, pitch_shifts):
        augmented = waveform
        if ts != 1.0:
            new_sr = int(sample_rate * ts)
            resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sr)
            augmented = resample(augmented)
        if ps != 0:
            pitch_shift = torchaudio.transforms.PitchShift(sample_rate=sample_rate, n_steps=ps)
            augmented = pitch_shift(augmented)
        pieces.append(augmented.detach())
        pieces.append(silence)

    if not pieces:
        raise ValueError("time_stretches and pitch_shifts must both be non-empty")
    # Drop the trailing silence; torch.cat copies, so the input is never aliased.
    return torch.cat(pieces[:-1])

import pickle

# Augmenting every combination (resampling + PitchShift) can be slow, so reuse the
# cached result when it exists. Set REBUILD_AUGMENT_CACHE = True after changing the
# source audio or the augmentation settings. When the cache is reused, `y` here is
# still the raw waveform; the next cell loads the augmented one from the cache.
REBUILD_AUGMENT_CACHE = False
AUGMENT_CACHE_PATH = 'cached_augmented_vocoder.pkl'

if REBUILD_AUGMENT_CACHE or not os.path.exists(AUGMENT_CACHE_PATH):
    # 3 speed factors x 7 semitone shifts = 21 variants.
    y = augment_waveform(y, sr, [1.1, 1.0, 0.9], [-3, -2, -1, 0, 1, 2, 3])
    with open(AUGMENT_CACHE_PATH, 'wb') as f:
        pickle.dump(y, f)

# Audio samples per mel frame, used below to align mel windows with audio windows
# (assumed to equal mel_spectrogram's hop length — confirm in utils).
frame_size = 256

In [None]:
import pickle
# Reload the augmented training waveform written by the augmentation cell.
# NOTE(review): pickle.load can execute arbitrary code — only load cache files this notebook produced.
with open('cached_augmented_vocoder.pkl', 'rb') as cache_file:
    y = pickle.load(cache_file)

In [None]:
# Compute the mel spectrogram of the whole waveform as a batch of one, then drop the
# batch dim and transpose so the first axis indexes frames (the training loop slices it by frame).
mel_batched = mel_spectrogram(y.unsqueeze(0))
mel = mel_batched.squeeze(0).T

In [None]:
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import copy

# Interactive mode plus one persistent figure that the training loop redraws.
plt.ion()
fig, ax = plt.subplots()

# Per-step loss histories appended by the training loop.
adv, disc, spec = [], [], []

In [None]:
import random
import numpy as np
# Train the vocoder GAN indefinitely on random mel/audio windows; interrupt the kernel to stop.
N = 8                                     # examples per batch
training_window = 128                     # mel frames per example
y_window = training_window * frame_size   # audio samples per example

model.train()
# Make sure the discriminator's parameters receive gradients.
for param in model.vocoder_disc.parameters():
    param.requires_grad_(True)

# Largest start frame for which BOTH the mel window and its matching audio window fit.
# mel_spectrogram may emit a few more frames than len(y) // frame_size (e.g. with centre
# padding — not visible here), so bounding by the mel length alone could slice past the
# end of `y` and produce a short example that cannot be batched.
max_x = min(mel.shape[0] - training_window, (y.shape[0] - y_window) // frame_size)
if max_x < 0:
    raise ValueError("need at least one full training window of mel frames and audio samples")

i = 0
while True:
    i += 1
    x_starts = [random.randint(0, max_x) for _ in range(N)]
    # (N, mel channels, training_window): each mel slice transposed to channels-first.
    x_batch = torch.stack([mel[s:s + training_window].T for s in x_starts]).cuda()
    # (N, 1, y_window): the audio aligned with each mel window.
    y_batch = torch.stack(
        [y[s * frame_size:s * frame_size + y_window] for s in x_starts]
    ).cuda().unsqueeze(1)

    loss_spectral, loss_adv, loss_disc = model.train_vocoder(
        x_batch, y_batch, optimizer, loss_weight=loss_weight
    )
    scheduler.step()
    adv.append(loss_adv)
    disc.append(loss_disc)
    spec.append(loss_spectral)

    if i % 10 == 0:
        # Redraw the live loss plot in place.
        clear_output(wait=True)
        ax.clear()
        ax.plot(np.array(adv), label='adv')
        ax.plot(np.array(disc), label='disc')
        ax.plot(np.array(spec), label='spectral')
        ax.set(xlabel='step', ylabel='loss')
        ax.legend()
        display(fig)
        torch.cuda.empty_cache()
    if i % 1000 == 0:
        torch.save(model, 'ckpt' + str(i))
    

In [None]:
def _clip_history(history, ceiling=10):
    """Return a new list with every value of ``history`` capped at ``ceiling``."""
    return [min(ceiling, value) for value in history]


# Cap each loss history at 10 (presumably so spikes don't dominate the plot's y-axis).
# This rebinds the names, so the training loop keeps appending to the clipped lists.
disc = _clip_history(disc)
adv = _clip_history(adv)
spec = _clip_history(spec)

In [None]:
import gc

# Collect unreachable Python objects first so any CUDA tensors they hold are freed,
# then hand the now-unused cached blocks back to the driver. In the reverse order,
# memory released by gc.collect() would stay in PyTorch's cache.
gc.collect()
torch.cuda.empty_cache()

In [None]:
import sounddevice as sd
model.eval()
window = 364000     # excerpt length in samples (~15 s at sr=24000)
start = 215000000   # requested first sample of the excerpt in `y`

# Work in whole mel frames and cut `real` from the same frame-aligned span, so the
# reference audio lines up with the frames the vocoder is given (assumes mel frame k
# starts at sample k * frame_size — confirm against mel_spectrogram).
mel_start = start // frame_size
mel_end = (start + window) // frame_size
if not 0 <= mel_start < mel_end <= mel.shape[0]:
    raise ValueError(
        f"excerpt frames [{mel_start}, {mel_end}) fall outside mel ({mel.shape[0]} frames)"
    )

mel_ = mel[mel_start:mel_end].unsqueeze(0).transpose(-1, -2).clone()
print(mel_.shape)
real = y[mel_start * frame_size:mel_end * frame_size]
with torch.no_grad():
    fake = model.vocoder(mel_.cuda()).squeeze(0).squeeze(0)
print(fake.shape)

In [None]:
# Play the ground-truth excerpt (non-blocking).
real_audio = real.cpu().numpy()
sd.play(real_audio, samplerate=sr)

In [None]:
# Play the vocoder's offline render of the same excerpt (non-blocking).
fake_audio = fake.detach().cpu().numpy()
sd.play(fake_audio, samplerate=sr)

In [None]:
# Plot the conditioning mel on its own figure. plt.imshow would draw on pyplot's current
# axes, which — depending on the backend — can still be the live loss plot's `ax`.
fig_mel, ax_mel = plt.subplots()
ax_mel.imshow(mel_.cpu().detach().numpy()[0], cmap='grey', origin='lower', aspect='auto')
ax_mel.set(title='Input mel spectrogram', xlabel='frame', ylabel='mel bin')
plt.show()

In [None]:
# Render the same excerpt frame by frame with a saved checkpoint's vocoder, with its
# caching turned on (presumably streaming state carried between calls — confirm in model).
# NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
vocoder = torch.load('test2', weights_only=False).vocoder.cuda()
vocoder.eval()
vocoder.turn_on_caching()

streamed_chunks = []
with torch.no_grad():
    for chunk in mel_[0].T:
        # One mel frame -> (1, mel channels, 1): a batch of one, one frame long.
        out = vocoder(chunk.cuda().unsqueeze(0).unsqueeze(2)).flatten()
        streamed_chunks.append(out)

# Join every streamed chunk once at the end (repeated torch.cat is quadratic).
fake2 = torch.cat(streamed_chunks, dim=-1) if streamed_chunks else torch.tensor([], device='cuda')

In [None]:
# Play the frame-by-frame (streamed) render (non-blocking).
streamed_audio = fake2.detach().cpu().numpy()
sd.play(streamed_audio, samplerate=sr)

In [None]:
# Release cached-but-unoccupied CUDA memory held by PyTorch's allocator back to the driver.
torch.cuda.empty_cache()