In [1]:
import torch
import torch.nn as nn
from disk_utils import save_model, load_model
from audio_dataset import build_audio_data_loaders
from models.hifi_dil import Generator_HiFi_Dil, Discriminator_HiFi_Dil

In [2]:
test_size = 0.0
train_dl, _ = build_audio_data_loaders(test_size=test_size)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
adversarial_criterion = nn.BCEWithLogitsLoss()

generator = Generator_HiFi_Dil().to(device)
discriminator = Discriminator_HiFi_Dil().to(device)

optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

In [4]:
num_epochs = 30

In [5]:
history = {
    "gen": [],
    "disc": []
}

for epoch in range(num_epochs):
    epoch_gen_loss = 0.0
    epoch_disc_loss = 0.0
    num_batches = 0

    for X, Y, _, _ in train_dl:
        X = X.to(device)
        Y = Y.to(device)

        optimizer_d.zero_grad()

        generated_audio = generator(X)
        real_out = discriminator(Y)
        fake_out = discriminator(generated_audio.detach())

        d_loss = adversarial_criterion(
            real_out - fake_out, torch.ones_like(real_out))

        epoch_disc_loss += d_loss.item()

        d_loss.backward()
        optimizer_d.step()

        optimizer_g.zero_grad()

        gen_out = discriminator(generated_audio)
        g_loss = adversarial_criterion(gen_out, torch.ones_like(gen_out))

        epoch_gen_loss += g_loss.item()

        g_loss.backward()
        optimizer_g.step()
        num_batches += 1
        if num_batches % 100 == 0:
            print(f" - {num_batches}\t D: {d_loss.item():.6f}\t G: {g_loss.item():.6f}")

    epoch_disc_loss /= num_batches
    epoch_gen_loss /= num_batches
    history["disc"].append(epoch_disc_loss)
    history["gen"].append(epoch_gen_loss)
    print(f"E {epoch}\t D: {epoch_disc_loss:.6f}\t G: {epoch_gen_loss:.6f}")

 - 100	 D: 0.668653	 G: 0.510407
 - 200	 D: 0.673445	 G: 0.416101
 - 300	 D: 0.720917	 G: 0.313038
 - 400	 D: 0.676064	 G: 0.298328
E 0	 D: 0.676327	 G: 0.418594
 - 100	 D: 0.680671	 G: 0.226967
 - 200	 D: 0.666313	 G: 0.288115
 - 300	 D: 0.683087	 G: 0.230108
 - 400	 D: 0.665094	 G: 0.259232
E 1	 D: 0.674174	 G: 0.260336
 - 100	 D: 0.647496	 G: 0.266947
 - 200	 D: 0.665049	 G: 0.236946
 - 300	 D: 0.602991	 G: 0.202694
 - 400	 D: 0.718824	 G: 0.147766
E 2	 D: 0.655599	 G: 0.216885
 - 100	 D: 0.666233	 G: 0.166279
 - 200	 D: 0.763044	 G: 0.182509
 - 300	 D: 0.673203	 G: 0.173096
 - 400	 D: 0.659507	 G: 0.153023
E 3	 D: 0.659020	 G: 0.180054
 - 100	 D: 0.647249	 G: 0.155971
 - 200	 D: 0.687514	 G: 0.181727
 - 300	 D: 0.678294	 G: 0.148245
 - 400	 D: 0.649668	 G: 0.141122
E 4	 D: 0.653690	 G: 0.155267
 - 100	 D: 0.635044	 G: 0.143090
 - 200	 D: 0.571643	 G: 0.128092
 - 300	 D: 0.662045	 G: 0.129668
 - 400	 D: 0.630561	 G: 0.102053
E 5	 D: 0.661182	 G: 0.129311
 - 100	 D: 0.668842	 G: 0.11

In [6]:
generator.to(cpu)
discriminator.to(cpu)

save_model(generator, "hifi_gen_30_f")
save_model(discriminator, "hifi_disc_30_f")