In [1]:
import torch
import torch.nn as nn
from disk_utils import save_model, load_model
from audio_dataset import build_audio_data_loaders
from models.hifi_dil import Generator_HiFi_Dil, Discriminator_HiFi_Dil

In [2]:
test_size = 0.05
train_dl, test_dl = build_audio_data_loaders(test_size=test_size)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
adversarial_criterion = nn.BCEWithLogitsLoss()

generator = Generator_HiFi_Dil().to(device)
discriminator = Discriminator_HiFi_Dil().to(device)

optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

In [4]:
num_epochs = 15

In [5]:
history = {
    "gen": [],
    "disc": []
}

for epoch in range(num_epochs):
    epoch_gen_loss = 0.0
    epoch_disc_loss = 0.0
    num_batches = 0

    for X, Y, _, _ in train_dl:
        X = X.to(device)
        Y = Y.to(device)

        optimizer_d.zero_grad()

        generated_audio = generator(X)
        real_out = discriminator(Y)
        fake_out = discriminator(generated_audio.detach())

        d_loss = adversarial_criterion(
            real_out - fake_out, torch.ones_like(real_out))

        epoch_disc_loss += d_loss.item()

        d_loss.backward()
        optimizer_d.step()

        optimizer_g.zero_grad()

        gen_out = discriminator(generated_audio)
        g_loss = adversarial_criterion(gen_out, torch.ones_like(gen_out))

        epoch_gen_loss += g_loss.item()

        g_loss.backward()
        optimizer_g.step()
        num_batches += 1
        if num_batches % 100 == 0:
            print(f" - {num_batches}\t D: {d_loss.item():.6f}\t G: {g_loss.item():.6f}")

    epoch_disc_loss /= num_batches
    epoch_gen_loss /= num_batches
    history["disc"].append(epoch_disc_loss)
    history["gen"].append(epoch_gen_loss)
    print(f"E {epoch}\t D: {epoch_disc_loss:.6f}\t G: {epoch_gen_loss:.6f}")

 - 100	 D: 0.675163	 G: 0.477180
 - 200	 D: 0.704602	 G: 0.398389
 - 300	 D: 0.697937	 G: 0.380485
 - 400	 D: 0.624354	 G: 0.373141
E [0]	 D: 0.675302	 G: 0.454749
 - 100	 D: 0.668123	 G: 0.263440
 - 200	 D: 0.689036	 G: 0.254612
 - 300	 D: 0.650361	 G: 0.261404
 - 400	 D: 0.606445	 G: 0.241900
E [1]	 D: 0.656839	 G: 0.269001
 - 100	 D: 0.740719	 G: 0.207589
 - 200	 D: 0.676526	 G: 0.204214
 - 300	 D: 0.687507	 G: 0.155489
 - 400	 D: 0.533653	 G: 0.158992
E [2]	 D: 0.645246	 G: 0.191223
 - 100	 D: 0.626964	 G: 0.128523
 - 200	 D: 0.686953	 G: 0.129936
 - 300	 D: 0.658291	 G: 0.135620
 - 400	 D: 0.693316	 G: 0.136303
E [3]	 D: 0.658794	 G: 0.140987
 - 100	 D: 0.712068	 G: 0.116351
 - 200	 D: 0.658760	 G: 0.121089
 - 300	 D: 0.737299	 G: 0.132768
 - 400	 D: 0.641139	 G: 0.120771
E [4]	 D: 0.667965	 G: 0.129427
 - 100	 D: 0.669484	 G: 0.117586
 - 200	 D: 0.590304	 G: 0.152789
 - 300	 D: 0.638506	 G: 0.136068
 - 400	 D: 0.669199	 G: 0.141506
E [5]	 D: 0.656061	 G: 0.129759
 - 100	 D: 0.614

In [6]:
generator.to(cpu)
discriminator.to(cpu)

save_model(generator, "hifi_gen")
save_model(discriminator, "hifi_disc")