In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from models.new_wave import WaveNet
from disk_utils import save_model, load_model
from audio_dataset import build_audio_data_loaders

In [2]:
# Value forwarded to build_audio_data_loaders as its `test_size` argument.
# NOTE(review): presumably the held-out fraction, so 0.0 would keep every sample
# for training — confirm against audio_dataset.build_audio_data_loaders.
test_size = 0.0
# The helper returns two objects; only the first (used as the training loader)
# is kept, the second is discarded.
train_dl, _ = build_audio_data_loaders(test_size=test_size)

In [3]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Explicit CPU handle, used later to bring the model back to the CPU before saving.
cpu = torch.device("cpu")

# Build the network and place it on the training device *before* the optimizer
# is constructed, so the optimizer is created against parameters that already
# live where they will be optimized. The later `model = model.to(device)` in the
# training cell remains valid and is then a no-op.
# NOTE(review): the meaning of each WaveNet argument is defined in
# models.new_wave, which is not visible here.
model = WaveNet(residual_channels=64,
                skip_channels=128,
                num_blocks=2,
                num_layers=10,
                kernel_size=2).to(device)
# Mean-squared-error objective between the model output and the target batch.
criterion = nn.MSELoss()
# Adam over all model parameters with a learning rate of 2e-4.
optimizer = Adam(model.parameters(), lr=2e-4)

In [4]:
# Number of full passes over train_dl performed by the training loop.
num_epochs = 20

In [5]:
# Training loop: runs `num_epochs` passes over `train_dl`, optimizing `model`
# against `criterion` with `optimizer`. Prints the current batch loss every
# 100 batches ("B.L") and the mean batch loss at the end of each epoch ("E.L").
model = model.to(device)
model.train()  # switch the model to training mode
for epoch in range(num_epochs):
    num_batches = 0
    running_loss = 0.0
    # Each batch is a 4-tuple; only the first two elements (input, target) are used.
    for X, Y, _, _ in train_dl:
        X = X.to(device)
        Y = Y.to(device)

        # Forward pass
        outputs = model(X)
        loss = criterion(outputs, Y)

        # Backward pass: clear stale gradients, backpropagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar loss once; it is reused for both the running total
        # and the periodic log line.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        if num_batches % 100 == 0:
            print(f"\t{num_batches}\t B.L: {batch_loss:.6f}")

    # An empty loader would otherwise surface as a bare ZeroDivisionError in the
    # epoch average below; fail with a message that names the actual problem.
    if num_batches == 0:
        raise RuntimeError("train_dl yielded no batches; cannot compute an epoch loss")
    print(f"E {epoch + 1}\t E.L: {running_loss / num_batches:.6f}")

	100	 B.L: 0.155943
	200	 B.L: 0.168215
	300	 B.L: 0.213482
	400	 B.L: 0.190029
E 1	 E.L: 0.180382
	100	 B.L: 0.148048
	200	 B.L: 0.182386
	300	 B.L: 0.142340
	400	 B.L: 0.174808
E 2	 E.L: 0.175866
	100	 B.L: 0.220137
	200	 B.L: 0.133015
	300	 B.L: 0.158399
	400	 B.L: 0.162753
E 3	 E.L: 0.171406
	100	 B.L: 0.192080
	200	 B.L: 0.146756
	300	 B.L: 0.163001
	400	 B.L: 0.153057
E 4	 E.L: 0.159761
	100	 B.L: 0.129977
	200	 B.L: 0.125644
	300	 B.L: 0.172950
	400	 B.L: 0.110830
E 5	 E.L: 0.144986
	100	 B.L: 0.112927
	200	 B.L: 0.143576
	300	 B.L: 0.202145
	400	 B.L: 0.141173
E 6	 E.L: 0.130875
	100	 B.L: 0.130962
	200	 B.L: 0.095541
	300	 B.L: 0.140166
	400	 B.L: 0.127088
E 7	 E.L: 0.117361
	100	 B.L: 0.085567
	200	 B.L: 0.112944
	300	 B.L: 0.106263
	400	 B.L: 0.054536
E 8	 E.L: 0.106806
	100	 B.L: 0.094692
	200	 B.L: 0.120001
	300	 B.L: 0.106398
	400	 B.L: 0.075586
E 9	 E.L: 0.097919
	100	 B.L: 0.076012
	200	 B.L: 0.113247
	300	 B.L: 0.079806
	400	 B.L: 0.098658
E 10	 E.L: 0.091458
	100	 B.L

In [6]:
# Move the model to the CPU device before handing it to save_model. The return
# value is not reassigned, so this relies on .to() modifying the model in place
# (true for torch.nn.Module; WaveNet's base class is not visible here).
# NOTE(review): presumably done so the saved weights load on machines without
# CUDA — confirm against disk_utils.save_model.
model.to(cpu)
# Persist the model under the name "wavenet"; where and in what format it is
# stored is decided by save_model, which is not visible here.
save_model(model, "wavenet")