In [1]:
cd ..

/Users/bdboy/Desktop/Projects/Music-Generation


In [2]:
from melGAN import Generator, Discriminator
from dataset import AudioDataset

In [3]:
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [4]:
import matplotlib.pyplot as plt
import numpy as np

In [5]:
# --- Data configuration ---------------------------------------------------
AUDIO_DIR = "/Users/bdboy/Desktop/Projects/Music-Generation/data/drums/train"
SAMPLE_RATE = 16000
NUM_SAMPLES = 8000  # presumably the per-clip length in samples -- confirm in dataset.py

# Use the GPU whenever PyTorch can see one, otherwise stay on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {DEVICE}")

# Mel-spectrogram front end that is handed to the dataset.
MEL_PARAMS = {
    "sample_rate": SAMPLE_RATE,
    "n_fft": 1024,
    "hop_length": 512,
    "n_mels": 64,
}
TRANSFORM = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)

# --- Dataset and loader ---------------------------------------------------
drums = AudioDataset(AUDIO_DIR, TRANSFORM, SAMPLE_RATE, NUM_SAMPLES, DEVICE)
dataloader = DataLoader(drums, batch_size=128, shuffle=True)

# Quick sanity check: dataset size and the shape of a single item.
print(f"There are {len(drums)} samples in the dataset.")
signal = drums[0]
print(f"Shape of signal: {signal.shape}")

Using device cpu
There are 2350 samples in the dataset.
Shape of signal: torch.Size([1, 64, 16])


In [6]:
def weights_init(m):
    """Initialise conv and batch-norm layers in place; intended for ``Module.apply``.

    Layers are matched by class name, so every class whose name contains
    ``Conv`` (including ``ConvTranspose*``) or ``BatchNorm`` is handled:

    * Conv layers: weight drawn from N(mean=0.0, std=0.02); bias is left untouched.
    * BatchNorm layers: weight drawn from N(mean=1.0, std=0.02); bias set to 0.

    A matching layer that lacks the relevant parameter (for example
    ``BatchNorm2d(affine=False)``, whose ``weight`` and ``bias`` are ``None``)
    is skipped instead of raising ``AttributeError``.

    Args:
        m: the module visited by ``apply``; any other module type is ignored.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

In [7]:
# Instantiate both networks on the selected device.
netG = Generator().to(DEVICE)
netD = Discriminator().to(DEVICE)

# Re-initialise the layers with weights_init: generator first, then discriminator.
for net in (netG, netD):
    net.apply(weights_init)

# Print both architectures, generator first.
print(netG, netD, sep="\n")

Generator(
  (activ): SELU()
  (conv1): ConvTranspose2d(20, 256, kernel_size=(3, 2), stride=(2, 2))
  (conv2): ConvTranspose2d(256, 128, kernel_size=(3, 2), stride=(2, 2))
  (conv3): ConvTranspose2d(128, 32, kernel_size=(3, 2), stride=(2, 2))
  (conv4): ConvTranspose2d(32, 8, kernel_size=(3, 2), stride=(2, 2), output_padding=(1, 1))
  (conv5): ConvTranspose2d(8, 1, kernel_size=(3, 2), stride=(2, 1), padding=(1, 1), output_padding=(1, 0))
)
Discriminator(
  (activ): SELU()
  (sigmoid): Sigmoid()
  (conv1): Conv2d(1, 4, kernel_size=(3, 2), stride=(2, 1), padding=(1, 1))
  (conv2): Conv2d(4, 16, kernel_size=(3, 2), stride=(2, 2))
  (conv3): Conv2d(16, 32, kernel_size=(3, 2), stride=(2, 2))
  (conv4): Conv2d(32, 64, kernel_size=(3, 2), stride=(2, 2))
  (conv5): Conv2d(64, 128, kernel_size=(3, 2), stride=(2, 2))
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=1, bias=True)
)


In [8]:
# Adam hyper-parameters shared by the generator and discriminator optimizers.
lr = 0.0002
beta1 = 0.5
adam_betas = (beta1, 0.999)

# Binary cross-entropy; it expects probabilities, so the discriminator output
# must already be squashed to [0, 1] (presumably by its Sigmoid module -- confirm in melGAN.py).
loss = nn.BCELoss()

# One optimizer per network, discriminator first.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

In [9]:
# --- Training configuration ------------------------------------------------
num_epochs = 5000
real_label = 1   # BCE target for real samples
fake_label = 0   # BCE target for generated samples
latent_dim = 20  # noise channels; the printed Generator's first layer takes 20 input channels
log_every = 50   # print one status line every `log_every` epochs

# Bookkeeping. G_losses / D_losses receive one entry per iteration.
img_list = []    # not used in this cell; kept in case other cells rely on it
G_losses = []
D_losses = []
iters = 0        # not used in this cell; kept in case other cells rely on it

# Index of the final batch of an epoch. Derived from the loader instead of being
# hard-coded, so the status line still fires when dataset or batch size changes.
last_batch = len(dataloader) - 1

print("Starting Training Loop...")

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        # ---- (1) Discriminator step: real batch, then generated batch ----
        netD.zero_grad()

        # Real batch, target = real_label.
        real_cpu = data.to(DEVICE)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=DEVICE)
        output = netD(real_cpu).view(-1)
        errD_real = loss(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean discriminator output on real data

        # Generated batch, target = fake_label. detach() keeps this backward
        # pass from propagating gradients into the generator.
        noise = torch.randn(b_size, latent_dim, 1, 1, device=DEVICE)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = loss(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean D output on fakes, before D's update
        errD = errD_real + errD_fake   # reported only; gradients were accumulated above

        optimizerD.step()

        # ---- (2) Generator step: make the updated D label fakes as real ----
        netG.zero_grad()

        label.fill_(real_label)
        output = netD(fake).view(-1)   # same fake batch, scored by the just-updated D
        errG = loss(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean D output on fakes, after D's update

        optimizerG.step()

        # ---- Status line on the last batch of every `log_every`-th epoch ----
        if epoch % log_every == 0 and i == last_batch:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        G_losses.append(errG.item())
        D_losses.append(errD.item())

Starting Training Loop...
[0/5000][18/19]	Loss_D: 0.7335	Loss_G: 1.2703	D(x): 0.7237	D(G(z)): 0.3173 / 0.2807
[50/5000][18/19]	Loss_D: 0.1624	Loss_G: 3.2273	D(x): 0.9600	D(G(z)): 0.0492 / 0.0421
[100/5000][18/19]	Loss_D: 0.1596	Loss_G: 3.3888	D(x): 0.9624	D(G(z)): 0.0595 / 0.0404
[150/5000][18/19]	Loss_D: 0.1264	Loss_G: 3.9758	D(x): 0.9781	D(G(z)): 0.0396 / 0.0214
[200/5000][18/19]	Loss_D: 0.0883	Loss_G: 3.3538	D(x): 0.9691	D(G(z)): 0.0395 / 0.0368
[250/5000][18/19]	Loss_D: 0.1278	Loss_G: 3.8742	D(x): 0.9787	D(G(z)): 0.0381 / 0.0258
[300/5000][18/19]	Loss_D: 0.1433	Loss_G: 3.8906	D(x): 0.9582	D(G(z)): 0.0209 / 0.0246
[350/5000][18/19]	Loss_D: 0.0956	Loss_G: 4.2079	D(x): 0.9714	D(G(z)): 0.0219 / 0.0164
[400/5000][18/19]	Loss_D: 0.1605	Loss_G: 2.8155	D(x): 0.9457	D(G(z)): 0.0337 / 0.0792
[450/5000][18/19]	Loss_D: 0.2979	Loss_G: 3.8888	D(x): 0.9348	D(G(z)): 0.0187 / 0.0243
[500/5000][18/19]	Loss_D: 0.1835	Loss_G: 3.3623	D(x): 0.9579	D(G(z)): 0.0371 / 0.0393
[550/5000][18/19]	Loss_D: 0.053

KeyboardInterrupt: 