In [1]:
import torch
import torch.nn as nn
from disk_utils import save_model, load_model
from audio_dataset import build_audio_data_loaders
from models.hifi_dil import Generator_HiFi_Dil, Discriminator_HiFi_Dil

In [2]:
test_size = 0.0
train_dl, _ = build_audio_data_loaders(test_size=test_size)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")
adversarial_criterion = nn.BCEWithLogitsLoss()

generator = Generator_HiFi_Dil().to(device)
discriminator = Discriminator_HiFi_Dil().to(device)

optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

In [4]:
num_epochs = 30

In [5]:
history = {
    "gen": [],
    "disc": []
}

for epoch in range(num_epochs):
    epoch_gen_loss = 0.0
    epoch_disc_loss = 0.0
    num_batches = 0

    for X, Y, _, _ in train_dl:
        X = X.to(device)
        Y = Y.to(device)

        optimizer_d.zero_grad()

        generated_audio = generator(X)
        real_out = discriminator(Y)
        fake_out = discriminator(generated_audio.detach())

        d_loss = adversarial_criterion(
            real_out - fake_out, torch.ones_like(real_out))

        epoch_disc_loss += d_loss.item()

        d_loss.backward()
        optimizer_d.step()

        optimizer_g.zero_grad()

        gen_out = discriminator(generated_audio)
        g_loss = adversarial_criterion(gen_out, torch.ones_like(gen_out))

        epoch_gen_loss += g_loss.item()

        g_loss.backward()
        optimizer_g.step()
        num_batches += 1
        if num_batches % 100 == 0:
            print(f" - {num_batches}\t D: {d_loss.item():.6f}\t G: {g_loss.item():.6f}")

    epoch_disc_loss /= num_batches
    epoch_gen_loss /= num_batches
    history["disc"].append(epoch_disc_loss)
    history["gen"].append(epoch_gen_loss)
    print(f"E {epoch}\t D: {epoch_disc_loss:.6f}\t G: {epoch_gen_loss:.6f}")

 - 100	 D: 0.684643	 G: 0.450203
 - 200	 D: 0.691863	 G: 0.451808
 - 300	 D: 0.691072	 G: 0.456235
 - 400	 D: 0.677572	 G: 0.465861
E 0	 D: 0.689794	 G: 0.452395
 - 100	 D: 0.678191	 G: 0.450968
 - 200	 D: 0.688373	 G: 0.452413
 - 300	 D: 0.696380	 G: 0.445759
 - 400	 D: 0.689317	 G: 0.469983
E 1	 D: 0.688828	 G: 0.459859
 - 100	 D: 0.691327	 G: 0.465695
 - 200	 D: 0.692886	 G: 0.468562
 - 300	 D: 0.692839	 G: 0.454151
 - 400	 D: 0.686196	 G: 0.466932
E 2	 D: 0.688831	 G: 0.464925
 - 100	 D: 0.692200	 G: 0.487910
 - 200	 D: 0.683933	 G: 0.481027
 - 300	 D: 0.665511	 G: 0.498896
 - 400	 D: 0.698466	 G: 0.499669
E 3	 D: 0.685628	 G: 0.496425
 - 100	 D: 0.691478	 G: 0.531600
 - 200	 D: 0.687303	 G: 0.618509
 - 300	 D: 0.699276	 G: 0.620326
 - 400	 D: 0.692008	 G: 0.628341
E 4	 D: 0.687902	 G: 0.597544
 - 100	 D: 0.700114	 G: 0.644500
 - 200	 D: 0.680774	 G: 0.643776
 - 300	 D: 0.690149	 G: 0.677298
 - 400	 D: 0.699860	 G: 0.667843
E 5	 D: 0.687374	 G: 0.655054
 - 100	 D: 0.669788	 G: 0.66

In [6]:
generator.to(cpu)
discriminator.to(cpu)

save_model(generator, "hifi_gen_30_f")
save_model(discriminator, "hifi_disc_30_f")