In [12]:
import torch
import torch.nn as nn

from disk_utils import save_model, load_model
from audio_dataset import build_audio_data_loaders
from models.hifi_dil import Generator_HiFi_Dil, Discriminator_HiFi_Dil

In [8]:
# Value forwarded as `test_size` to the loader builder.
# NOTE(review): presumably the fraction of samples held out for testing —
# confirm against audio_dataset.build_audio_data_loaders.
test_size = 0.05
# The builder returns two loaders; `train_dl` is the one iterated by the training loop.
train_dl, test_dl = build_audio_data_loaders(test_size=test_size)

In [14]:
# Devices: `cpu` always names the host; `device` is CUDA when torch reports it
# as available and falls back to the CPU otherwise.
cpu = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Networks are instantiated generator-first and moved to `device` right away.
generator = Generator_HiFi_Dil().to(device)
discriminator = Discriminator_HiFi_Dil().to(device)

# One Adam optimizer per network, both with a learning rate of 1e-4.
optimizer_g = torch.optim.Adam(params=generator.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)

# Binary cross-entropy evaluated on raw logits (the sigmoid is applied inside the loss).
adversarial_criterion = nn.BCEWithLogitsLoss()

In [10]:
# Number of full passes over `train_dl` performed by the training loop.
num_epochs = 1

In [None]:
# Adversarial training loop.
#
# For every batch the discriminator is updated first, then the generator.
# `history` collects the per-epoch mean of each loss under "gen" and "disc".
history = {
    "gen": [],
    "disc": []
}

for epoch in range(num_epochs):
    # Put both networks in training mode explicitly, so re-running this cell
    # after some other cell called .eval() does not silently train in eval mode.
    generator.train()
    discriminator.train()

    epoch_gen_loss = 0.0
    epoch_disc_loss = 0.0
    num_batches = 0

    # Each batch is a 4-tuple; only its first two items are used here.
    for X, Y, _, _ in train_dl:
        X = X.to(device)
        Y = Y.to(device)

        # ---- Discriminator update ----
        optimizer_d.zero_grad()

        generated_audio = generator(X)
        real_out = discriminator(Y)
        # detach() stops this loss from back-propagating into the generator.
        fake_out = discriminator(generated_audio.detach())

        # The loss is taken on the difference of the two discriminator outputs
        # with an all-ones target, i.e. it pushes real_out above fake_out.
        d_loss = adversarial_criterion(
            real_out - fake_out, torch.ones_like(real_out))

        # Read the scalar once; it is reused for the running sum and the log line.
        d_loss_value = d_loss.item()
        epoch_disc_loss += d_loss_value

        d_loss.backward()
        optimizer_d.step()

        # ---- Generator update ----
        optimizer_g.zero_grad()

        # Not detached this time: gradients must reach the generator.
        # NOTE(review): this loss uses gen_out alone, whereas the discriminator
        # loss uses a real-minus-fake difference — confirm the asymmetry is intended.
        gen_out = discriminator(generated_audio)
        g_loss = adversarial_criterion(gen_out, torch.ones_like(gen_out))

        g_loss_value = g_loss.item()
        epoch_gen_loss += g_loss_value

        g_loss.backward()
        optimizer_g.step()

        num_batches += 1
        if num_batches % 100 == 0:
            print(f" - {num_batches}\t D: {d_loss_value:.6f}\t G: {g_loss_value:.6f}")

    # An empty loader would otherwise surface as a bare ZeroDivisionError below.
    if num_batches == 0:
        raise RuntimeError(
            "train_dl yielded no batches; cannot compute epoch-average losses")

    epoch_disc_loss /= num_batches
    epoch_gen_loss /= num_batches
    history["disc"].append(epoch_disc_loss)
    history["gen"].append(epoch_gen_loss)
    print(f"E [{epoch}]\t D: {epoch_disc_loss:.6f}\t G: {epoch_gen_loss:.6f}")

In [13]:
# Bring both networks back to the host device before handing them to save_model.
for _net in (generator, discriminator):
    _net.to(cpu)

# NOTE(review): the "higi_" prefix may be a typo for "hifi_" — kept unchanged
# because other code may load the models under these exact names.
save_model(generator, "higi_gen")
save_model(discriminator, "higi_disc")