In [11]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from encodec import EncodecModel
import encodec.customAudioDataset as data
from encodec.losses import total_loss
from encodec.utils import convert_audio, save_audio
from encodec.customAudioDataset import collate_fn

In [12]:
# Training hyperparameters and runtime configuration.
# Seed torch's global RNG first so the DataLoader shuffle order and the model's
# weight initialisation (both created in later cells) are reproducible on
# Restart & Run All.
seed = 42
torch.manual_seed(seed)

num_epochs = 100
batch_size = 4
sample_rate = 24000  # shared by the dataset and the model built below
learning_rate = 1e-2  # NOTE(review): high for Adam on a conv autoencoder; lower it if training diverges
dataset_path = './data'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [13]:
# Training set: up to `n_samples` clips read from `dataset_path` at `sample_rate`,
# batched and shuffled each epoch; `collate_fn` assembles the batch tensors.
dataset = data.CustomAudioDataset(
    dataset_folder=dataset_path,
    n_samples=100,
    sample_rate=sample_rate,
)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [14]:
# Mono (channels=1), non-causal EnCodec with a single target bandwidth of 3
# (presumably kbps — confirm against the encodec package), placed on `device`.
model = (
    EncodecModel._get_model(
        target_bandwidths=[3],
        sample_rate=sample_rate,
        channels=1,
        causal=False,
    )
    .to(device)
)

In [15]:
# A single Adam optimizer over all of the model's parameters.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [16]:
losses_g = []  # per-epoch mean reconstruction loss, appended to by `train`
losses_w = []  # per-epoch mean of the model's `loss_w` output, appended to by `train`


def train(n_epocs, model, data_loader, optimizer, log_every=10):
    """Train `model` for `n_epocs` epochs and record the per-epoch mean losses.

    Each step runs the model on one batch. It then computes the reconstruction
    loss ``l_t / 10 + l_f`` from ``total_loss``, adds the ``loss_w`` term
    returned by the model, and takes one optimizer step on that sum. Epoch
    means are appended to the module-level ``losses_g`` and ``losses_w`` lists.

    Args:
        n_epocs: number of passes over ``data_loader``.
        model: module whose forward returns ``(output_wav, loss_w)``.
        data_loader: iterable of input batches (presumably waveform tensors);
            each batch is moved to the global ``device``.
        optimizer: optimizer over ``model``'s parameters.
        log_every: print a progress line every ``log_every`` epochs, and
            always on the final epoch. Defaults to 10.

    Raises:
        ValueError: if ``data_loader`` yields no batches in an epoch.
    """
    model.train()
    for epoch in range(n_epocs):
        accum_loss_g = 0.0
        accum_loss_w = 0.0
        n_batches = 0
        for input_wav in data_loader:
            input_wav = input_wav.to(device)
            output_wav, loss_w = model(input_wav)
            loss_terms = total_loss(input_wav, output_wav, sample_rate, device)
            loss_g = loss_terms['l_t'] / 10 + loss_terms['l_f']

            # Backpropagating the sum yields the same accumulated gradients as
            # two separate backward passes, without needing retain_graph=True.
            optimizer.zero_grad()
            (loss_g + loss_w).backward()
            optimizer.step()

            accum_loss_g += loss_g.item()
            accum_loss_w += loss_w.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError('data_loader yielded no batches; cannot compute epoch loss')

        epoch_loss_g = accum_loss_g / n_batches
        epoch_loss_w = accum_loss_w / n_batches
        losses_g.append(epoch_loss_g)
        losses_w.append(epoch_loss_w)

        if epoch % log_every == 0 or epoch == n_epocs - 1:
            print(f'Epoch {epoch} | Loss_g: {epoch_loss_g} | Loss_w: {epoch_loss_w}')