In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from models.new_wave import WaveNet
from disk_utils import save_model, load_model
from audio_dataset import build_audio_data_loaders

In [2]:
# Ask the loader builder for no held-out split and keep only the first
# returned loader; the second return value (presumably the test loader)
# is discarded.
test_size = 0.0
(train_dl, _) = build_audio_data_loaders(test_size=test_size)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu = torch.device("cpu")

# Seed torch's global RNG (all devices) so that model construction -- presumably
# WaveNet's weight init draws from it -- and any later draws from torch's default
# generator are repeatable across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

model = WaveNet(residual_channels=64,
                skip_channels=128,
                num_blocks=2,
                num_layers=10,
                kernel_size=2)
# Place the model on its target device *before* building the optimizer, as the
# torch.optim docs recommend, so the optimizer is constructed over the
# parameters in their final placement.
model = model.to(device)
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=2e-4)

In [4]:
# Number of full passes over train_dl performed by the training loop.
num_epochs: int = 40

In [5]:
# Train for `num_epochs` passes over `train_dl`, printing the batch loss every
# `log_every` batches and the mean per-batch loss at the end of each epoch.
# Re-running this cell continues from the current weights and optimizer state;
# nothing here re-initialises them.
log_every = 100

model = model.to(device)
model.train()
for epoch in range(num_epochs):
    num_batches = 0
    running_loss = 0.0
    # Each batch unpacks into four items; only the first two (input, target)
    # are used for training.
    for num_batches, (X, Y, _, _) in enumerate(train_dl, start=1):
        X = X.to(device)
        Y = Y.to(device)

        # Forward pass
        outputs = model(X)
        loss = criterion(outputs, Y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() copies the scalar back to the host (a device sync); read it
        # once per step and reuse it for both the running sum and the log line.
        batch_loss = loss.item()
        running_loss += batch_loss
        if num_batches % log_every == 0:
            print(f"\t{num_batches}\t B.L: {batch_loss:.6f}")

    # An empty loader would otherwise surface as an opaque ZeroDivisionError.
    if num_batches == 0:
        raise RuntimeError("train_dl yielded no batches; cannot compute the epoch loss")
    print(f"E {epoch + 1}\t E.L: {running_loss / num_batches:.6f}")

	100	 B.L: 0.202646
	200	 B.L: 0.189578
	300	 B.L: 0.211914
	400	 B.L: 0.194932
E 1	 E.L: 0.180184
	100	 B.L: 0.204060
	200	 B.L: 0.184506
	300	 B.L: 0.213561
	400	 B.L: 0.152655
E 2	 E.L: 0.176907
	100	 B.L: 0.173467
	200	 B.L: 0.186961
	300	 B.L: 0.176622
	400	 B.L: 0.160191
E 3	 E.L: 0.174720
	100	 B.L: 0.176798
	200	 B.L: 0.168423
	300	 B.L: 0.196584
	400	 B.L: 0.162889
E 4	 E.L: 0.170511
	100	 B.L: 0.103505
	200	 B.L: 0.121356
	300	 B.L: 0.168574
	400	 B.L: 0.231297
E 5	 E.L: 0.161578
	100	 B.L: 0.173487
	200	 B.L: 0.109389
	300	 B.L: 0.117765
	400	 B.L: 0.127875
E 6	 E.L: 0.148796
	100	 B.L: 0.169228
	200	 B.L: 0.106238
	300	 B.L: 0.103225
	400	 B.L: 0.161509
E 7	 E.L: 0.136507
	100	 B.L: 0.118525
	200	 B.L: 0.109898
	300	 B.L: 0.126439
	400	 B.L: 0.147214
E 8	 E.L: 0.123738
	100	 B.L: 0.099292
	200	 B.L: 0.080249
	300	 B.L: 0.126972
	400	 B.L: 0.118881
E 9	 E.L: 0.112150
	100	 B.L: 0.117366
	200	 B.L: 0.125143
	300	 B.L: 0.113023
	400	 B.L: 0.118348
E 10	 E.L: 0.101974
	100	 B.L

In [6]:
# Module.to moves the parameters in place and returns the same module, so the
# model handed to save_model (and `model` itself afterwards) lives on the CPU.
# NOTE(review): the captured training log above appears to stop during epoch 11
# of num_epochs=40 -- confirm this checkpoint comes from the intended run.
save_model(model.to(cpu), "wavenet")