In [3]:
# Generator V2 Notes

import torch
import torch.nn as nn
from tqdm import tqdm
from tqdm.auto import trange
import pandas as pd
from time import sleep

from model import AudioClassifier
from util import SoundDS, PlotSpectrogram

# Based on https://towardsdatascience.com/build-a-super-simple-gan-in-pytorch-54ba349920e4

# We need a generator that takes in noise
# and generates torch.Size([2, 64, 344])
# 2 = num_channels (this is `inputs` which the model is run on)
# 64 = Mel freq_bands
# 344 = time_steps

class generator(nn.Module):
    """Fully connected generator: noise vector -> spectrogram-shaped tensor.

    Three linear layers with ReLU between them; the final layer's output is
    reshaped to ``(batch, *out_shape)`` and squashed with tanh, so every
    output value lies in (-1, 1).

    Args:
        latent_dim: Width of the input noise vector.
        hidden_dims: Widths of the two hidden layers, ``(first, second)``.
        out_shape: Per-sample output shape. The default ``(2, 64, 344)``
            follows the notes above this class: channels, Mel bands,
            time steps.

    The defaults reproduce the original hard-coded 128 -> 1024 -> 4096 ->
    44032 network, and the layer attribute names (``fc1``/``fc2``/``fc3``)
    are unchanged so existing state dicts keep loading.
    """

    def __init__(self, latent_dim=128, hidden_dims=(1024, 4096), out_shape=(2, 64, 344)):
        super().__init__()
        first_hidden, second_hidden = hidden_dims
        self.latent_dim = latent_dim
        self.out_shape = tuple(out_shape)
        self.fc1 = nn.Linear(latent_dim, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        # Flattened output size is derived from out_shape (2*64*344 = 44032
        # for the default) instead of being a magic number.
        self.fc3 = nn.Linear(second_hidden, torch.Size(self.out_shape).numel())
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map noise ``x`` of shape ``(batch, latent_dim)`` to ``(batch, *out_shape)``."""
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, *self.out_shape)
        # Functional tanh avoids building a fresh nn.Tanh module on every call.
        return torch.tanh(x)
    
"""
Network training procedure
On every step, both the discriminator and the generator are updated from their losses.
The discriminator aims to classify real and fake samples.
The generator aims to generate spectrograms that are as realistic as possible.
"""

# --- Hyperparameters -------------------------------------------------------
num_epochs = 5
batch_size = 64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
D_lr, G_lr = 0.002, 0.001

# --- Discriminator ---------------------------------------------------------
# The model is moved to its device *before* its optimizer is constructed,
# as recommended by the torch.optim documentation.
D = AudioClassifier()
# Optional: resume the discriminator from a checkpoint.
# cp = torch.load('D:/Development/generative-beatpack/models/zesty-salad-125_e512_f3_b128.pt')
# D.load_state_dict(cp['model_state_dict'])
D.to(device)
D_optimizer = torch.optim.Adam(D.parameters(), lr=D_lr)
# D_optimizer.load_state_dict(cp['optimizer_state_dict'])

# --- Generator -------------------------------------------------------------
G = generator()
# Optional: resume the generator from a checkpoint.
# gcp = torch.load('./Generator_epoch_19.pth') # e 410
# G.load_state_dict(gcp, strict=False)
G.to(device)
G_optimizer = torch.optim.Adam(G.parameters(), lr=G_lr)

# --- Losses ----------------------------------------------------------------
# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits.
loss = nn.BCEWithLogitsLoss()
# BCELoss expects probabilities in [0, 1]; only apply it to
# sigmoid-activated outputs.
gloss = nn.BCELoss()

# --- Data ------------------------------------------------------------------
# Only the 'path' and 'class' columns of the CSV are passed to the dataset.
df = pd.read_csv('data/edm_no_loops.csv')
df = df[['path', 'class']]
myds = SoundDS(df)
train_dl = torch.utils.data.DataLoader(myds, batch_size=batch_size, shuffle=True)

# GAN training loop: per batch, one generator update followed by one
# discriminator update. Both use the logits-based `loss`.
with trange(num_epochs, unit="epoch") as tepoch:
    num_batches = len(train_dl)
    # Noise width is read from the generator's first layer so the two
    # cannot drift apart.
    latent_dim = G.fc1.in_features
    for epoch in tepoch:
        for idx, data in enumerate(train_dl, start=1):
            # Real batch, standardised with its own mean/std. The clamp
            # keeps a constant batch from dividing by zero and producing NaNs.
            true_data = data[0].to(device)
            inputs_m, inputs_s = true_data.mean(), true_data.std()
            true_data = (true_data - inputs_m) / inputs_s.clamp_min(1e-8)

            # Uniform noise in [-1, 1), created directly on the target device.
            noise = (torch.rand(true_data.shape[0], latent_dim, device=device) - 0.5) / 0.5

            # ---- Train the generator ----
            # Labels are flipped: the generator is rewarded when the
            # discriminator scores its samples as real. Only G's optimizer
            # steps here, so the discriminator is not trained by this loss.
            G_optimizer.zero_grad()
            generated_data = G(noise)
            generator_discriminator_out = D(generated_data)
            # Targets take their shape from D's output rather than a
            # hard-coded width.
            true_labels = torch.ones_like(generator_discriminator_out)
            generator_loss = loss(generator_discriminator_out, true_labels)
            generator_loss.backward()
            G_optimizer.step()

            # ---- Train the discriminator on real and generated data ----
            # zero_grad also discards the gradients that the generator's
            # backward pass accumulated in D.
            D_optimizer.zero_grad()
            true_discriminator_out = D(true_data)
            true_discriminator_loss = loss(
                true_discriminator_out, torch.ones_like(true_discriminator_out)
            )

            # detach(): the discriminator update must not backpropagate
            # into the generator.
            generator_discriminator_out = D(generated_data.detach())
            fake_label = torch.zeros_like(generator_discriminator_out)
            generator_discriminator_loss = loss(generator_discriminator_out, fake_label)
            discriminator_loss = (true_discriminator_loss + generator_discriminator_loss) / 2
            discriminator_loss.backward()
            D_optimizer.step()

            # On the last batch of each epoch, plot the first channel of
            # the first real and the first generated sample.
            if idx == num_batches:
                print("Random Sample from Batch")
                real_sg = true_data.view(-1, 2, 64, 344)[0].detach().cpu().numpy()
                PlotSpectrogram(real_sg[0], title=f"Real Sample: {discriminator_loss.item()} loss")
                gen_sg = generated_data.view(-1, 2, 64, 344)[0].detach().cpu().numpy()
                PlotSpectrogram(gen_sg[0], title=f"Generated Sample: {generator_loss.item()}")

            tepoch.set_postfix(
                b=f"{idx}/{num_batches}",
                e=f"{epoch}/{num_epochs}",
                D_loss=discriminator_loss.item(),
                G_loss=generator_loss.item(),
            )

    # Saved once, after the final epoch; `epoch` is the last 0-based index.
    # NOTE(review): this pickles the whole module. Saving G.state_dict()
    # is the more portable format - confirm what downstream loaders expect.
    torch.save(G, 'firstTest_e{}.pth'.format(epoch))
    print('Model saved.')

  0%|          | 0/5 [00:00<?, ?epoch/s]

KeyboardInterrupt: 

In [None]:
# Generator V1 Notes

# Based on https://towardsdatascience.com/building-a-gan-with-pytorch-237b4b07ca9a

# We need a generator that takes in noise
# and generates torch.Size([2, 64, 344])
# 2 = num_channels (this is `inputs` which the model is run on)
# 64 = Mel freq_bands
# 344 = time_steps

class generator(nn.Module):
    """Fully connected generator: noise vector -> spectrogram-shaped tensor.

    Three linear layers with ReLU between them; the final layer's output is
    reshaped to ``(batch, *out_shape)`` and squashed with tanh, so every
    output value lies in (-1, 1).

    Args:
        latent_dim: Width of the input noise vector.
        hidden_dims: Widths of the two hidden layers, ``(first, second)``.
        out_shape: Per-sample output shape. The default ``(2, 64, 344)``
            follows the notes above this class: channels, Mel bands,
            time steps.

    The defaults reproduce the original hard-coded 128 -> 1024 -> 4096 ->
    44032 network, and the layer attribute names (``fc1``/``fc2``/``fc3``)
    are unchanged so existing state dicts keep loading.
    """

    def __init__(self, latent_dim=128, hidden_dims=(1024, 4096), out_shape=(2, 64, 344)):
        super().__init__()
        first_hidden, second_hidden = hidden_dims
        self.latent_dim = latent_dim
        self.out_shape = tuple(out_shape)
        self.fc1 = nn.Linear(latent_dim, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        # Flattened output size is derived from out_shape (2*64*344 = 44032
        # for the default) instead of being a magic number.
        self.fc3 = nn.Linear(second_hidden, torch.Size(self.out_shape).numel())
        self.activation = nn.ReLU()

    def forward(self, x):
        """Map noise ``x`` of shape ``(batch, latent_dim)`` to ``(batch, *out_shape)``."""
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, *self.out_shape)
        # Functional tanh avoids building a fresh nn.Tanh module on every call.
        return torch.tanh(x)
    
"""
Network training procedure
On every step, both the discriminator and the generator are updated from their losses.
The discriminator aims to classify real and fake samples.
The generator aims to generate spectrograms that are as realistic as possible.
"""

from tqdm import tqdm
from tqdm.auto import trange
from time import sleep

# --- Discriminator ---------------------------------------------------------
# The model is moved to its device *before* its optimizer is constructed,
# as recommended by the torch.optim documentation.
D = AudioClassifier()
# Optional: resume the discriminator from a checkpoint.
# cp = torch.load('D:/Development/generative-beatpack/models/zesty-salad-125_e512_f3_b128.pt')
# D.load_state_dict(cp['model_state_dict'])
D.to(device)
D_optimizer = torch.optim.Adam(D.parameters(), lr=0.002)
# D_optimizer.load_state_dict(cp['optimizer_state_dict'])

# --- Generator -------------------------------------------------------------
G = generator()
# Optional: resume the generator from a checkpoint.
# gcp = torch.load('./Generator_epoch_19.pth') # e 410
# G.load_state_dict(gcp, strict=False)
G.to(device)
G_optimizer = torch.optim.Adam(G.parameters(), lr=0.001)

# --- Losses ----------------------------------------------------------------
# BCEWithLogitsLoss applies the sigmoid internally, so it takes raw logits.
loss = nn.BCEWithLogitsLoss()
# BCELoss expects probabilities in [0, 1]; only apply it to
# sigmoid-activated outputs.
gloss = nn.BCELoss()

num_epochs = 100

# GAN training loop: per batch, one discriminator update followed by one
# generator update on a fresh noise batch.
with trange(num_epochs, unit="epoch") as tepoch:
    num_batches = len(train_dl)
    # Noise width is read from the generator's first layer so the two
    # cannot drift apart.
    latent_dim = G.fc1.in_features
    for epoch in tepoch:
        for idx, data in enumerate(train_dl, start=1):
            # Real batch, standardised with its own mean/std. The clamp
            # keeps a constant batch from dividing by zero and producing NaNs.
            real_inputs = data[0].to(device)
            inputs_m, inputs_s = real_inputs.mean(), real_inputs.std()
            real_inputs = (real_inputs - inputs_m) / inputs_s.clamp_min(1e-8)
            batch = real_inputs.shape[0]

            # ---- Train the discriminator (classifier) ----
            # Real inputs (from train_dl) are labelled 1, generated ones 0.
            real_outputs = D(real_inputs)

            noise = (torch.rand(batch, latent_dim, device=device) - 0.5) / 0.5
            # detach(): the discriminator update must not backpropagate
            # into the generator.
            fake_inputs = G(noise).detach()
            fake_outputs = D(fake_inputs)

            outputs = torch.cat((real_outputs, fake_outputs), 0)
            # Targets take their shape from D's outputs rather than a
            # hard-coded width.
            targets = torch.cat(
                (torch.ones_like(real_outputs), torch.zeros_like(fake_outputs)), 0
            )

            D_loss = loss(outputs, targets)
            D_optimizer.zero_grad()
            D_loss.backward()
            D_optimizer.step()

            # ---- Train the generator ----
            # The generator wants the discriminator to score fakes as 1.
            noise = (torch.rand(batch, latent_dim, device=device) - 0.5) / 0.5

            fake_inputs = G(noise)
            fake_outputs = D(fake_inputs)
            fake_targets = torch.ones_like(fake_outputs)
            # D's outputs are treated as logits in the discriminator step,
            # so the same logits-based loss is used here. nn.BCELoss would
            # require probabilities in [0, 1].
            G_loss = loss(fake_outputs, fake_targets)
            G_optimizer.zero_grad()
            G_loss.backward()
            G_optimizer.step()

            # On the last batch of each epoch, plot the first channel of
            # the first real and the first generated sample.
            if idx == num_batches:
                print("Random Sample from Batch")
                real_sg = real_inputs.view(-1, 2, 64, 344)[0].detach().cpu().numpy()
                PlotSpectrogram(real_sg[0], title=f"Real Sample: {D_loss.item()} loss")
                gen_sg = fake_inputs.view(-1, 2, 64, 344)[0].detach().cpu().numpy()
                PlotSpectrogram(gen_sg[0], title=f"Generated Noise Sample: {G_loss.item()}")

            tepoch.set_postfix(
                b=f"{idx}/{num_batches}",
                e=f"{epoch}/{num_epochs}",
                D_loss=D_loss.item(),
                G_loss=G_loss.item(),
            )

    # Saved once, after the final epoch; `epoch` is the last 0-based index.
    # NOTE(review): this pickles the whole module. Saving G.state_dict()
    # is the more portable format - confirm what downstream loaders expect.
    torch.save(G, 'firstTest_e{}.pth'.format(epoch))
    print('Model saved.')