In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from encodec import EncodecModel
from encodec.msstftd import MultiScaleSTFTDiscriminator
import encodec.customAudioDataset as data
from encodec.losses import total_loss, disc_loss
from encodec.utils import convert_audio, save_audio
from encodec.customAudioDataset import collate_fn

In [2]:
# Training configuration: everything a reader might want to tune lives here.
num_epochs = 50
batch_size = 2
sample_rate = 24000  # Hz; passed to both the dataset and the model
# NOTE(review): 0.01 looks high for Adam in adversarial training (the EnCodec
# paper reports 3e-4) — confirm this is intentional.
learning_rate = 0.01
dataset_path = './dataset'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's RNGs (CPU and CUDA) so weight initialisation and DataLoader
# shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [3]:
# Build the training dataset from the FLAC files under `dataset_path`.
# NOTE(review): n_samples=20 presumably caps how many clips are used, and
# tensor_cut=48000 presumably crops each clip to 48000 samples (2 s at
# 24 kHz) — confirm against CustomAudioDataset.
dataset = data.CustomAudioDataset(
    dataset_folder=dataset_path,
    n_samples=20,
    sample_rate=sample_rate,
    tensor_cut=48000,
    extension='flac',
)

# Shuffled mini-batches; the project's collate_fn assembles each batch.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [4]:
# Codec under training: mono, non-causal, no input normalisation.
# NOTE(review): _get_model is a private EncodecModel helper; target_bandwidths
# are presumably in kbps — confirm against the encodec package.
model = EncodecModel._get_model(
    target_bandwidths=[1.5, 3, 6, 12, 24],
    sample_rate=sample_rate,
    channels=1,
    causal=False,
    audio_normalize=False,
    segment=1.0,
).to(device)

# Adversary: multi-scale STFT discriminator.
disc = MultiScaleSTFTDiscriminator(filters=32).to(device)

In [5]:
# Separate Adam optimisers for the codec (generator) and the discriminator,
# sharing the same learning rate and betas.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.8, 0.99),
)
optimizer_disc = optim.Adam(
    disc.parameters(),
    lr=learning_rate,
    betas=(0.8, 0.99),
)

In [6]:
def train_one_step(epoch, optimizer, optimizer_disc, model, disc, dataset_loader, losses):
    """Train the codec and its discriminator for one full pass over ``dataset_loader``.

    Despite the name, this runs one *epoch*. Each batch gets one generator
    (codec) update, followed by one discriminator update.

    Args:
        epoch: Epoch index; used only in the printed summary.
        optimizer: Optimizer over ``model``'s parameters.
        optimizer_disc: Optimizer over ``disc``'s parameters.
        model: Codec, called as ``model(wav)``. It must return
            ``(reconstruction, loss_w)`` with ``loss_w`` a scalar tensor.
        disc: Discriminator, called as ``disc(wav)``. It must return
            ``(logits, feature_maps)``.
        dataset_loader: Iterable of waveform batches that supports ``len()``.
        losses: Dict whose ``'loss_g'``, ``'loss_w'`` and ``'loss_disc'`` lists
            each receive this epoch's per-batch mean loss.

    Reads the module-level ``device`` and ``sample_rate``.
    """
    model.train()
    disc.train()
    n_batches = len(dataset_loader)
    epoch_loss_g = 0.0
    epoch_loss_w = 0.0
    epoch_loss_disc = 0.0
    for input_wav in dataset_loader:
        input_wav = input_wav.to(device)

        # --- Generator (codec) update ---
        optimizer.zero_grad()
        output_wav, loss_w = model(input_wav)
        # Only the real feature maps feed the generator loss; the real logits
        # are recomputed below for the discriminator loss.
        _, fmap_real = disc(input_wav)
        logits_fake, fmap_fake = disc(output_wav)
        losses_g = total_loss(fmap_real, logits_fake, fmap_fake, input_wav, output_wav, sample_rate, device)
        loss_g = 3 * losses_g['l_g'] + 3 * losses_g['l_feat'] + losses_g['l_t'] / 10 + losses_g['l_f']
        # A single backward of the sum yields the same parameter gradients as
        # two separate backward calls, without needing retain_graph=True.
        (loss_g + loss_w).backward()
        optimizer.step()

        # --- Discriminator update ---
        # The backward above also deposited gradients in disc's parameters,
        # because loss_g depends on disc's outputs. Those gradients come from
        # the generator's objective, so clear them before the
        # discriminator computes and applies its own loss.
        optimizer_disc.zero_grad()
        logits_real, _ = disc(input_wav)
        logits_fake, _ = disc(output_wav.detach())  # no gradient into the codec
        loss_disc = disc_loss(logits_real, logits_fake)
        loss_disc.backward()
        optimizer_disc.step()

        epoch_loss_g += loss_g.item()
        epoch_loss_w += loss_w.item()
        epoch_loss_disc += loss_disc.item()

    losses['loss_g'].append(epoch_loss_g / n_batches)
    losses['loss_w'].append(epoch_loss_w / n_batches)
    losses['loss_disc'].append(epoch_loss_disc / n_batches)

    print(f'Epoch {epoch} | Loss_g: {losses["loss_g"][-1]:.4f} | Loss_w: {losses["loss_w"][-1]:.4f} | Loss_disc: {losses["loss_disc"][-1]:.4f}')

In [7]:
# Per-epoch mean losses, appended to by every training epoch. This is kept at
# module level so the history can be inspected after `train` returns.
losses = {'loss_g': [], 'loss_w': [], 'loss_disc': []}


def train(n_epochs, optimizer, optimizer_disc, model, disc, dataset_loader):
    """Run ``n_epochs`` epochs of codec/discriminator training.

    Each epoch is delegated to ``train_one_step``. Results are recorded in the
    module-level ``losses`` dict, so repeated calls keep extending it.
    """
    epochs = range(n_epochs)
    for epoch_idx in epochs:
        train_one_step(epoch_idx, optimizer, optimizer_disc, model, disc, dataset_loader, losses)