In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
%cd '/content/drive/MyDrive/Hackathon - Sound of AI/Code/vqvae'

/content/drive/.shortcut-targets-by-id/1xPTjrOvC9-VMw67aBMG_XOD53WhDymqb/Hackathon - Sound of AI/Code/vqvae


In [3]:
%%capture
!pip install natsort

In [5]:
!nvidia-smi

Sun Jul 10 01:11:06 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    34W / 250W |    933MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import utils
from models.vqvae import VQVAE

import os
import librosa
from natsort import natsorted
import matplotlib.pyplot as plt
import librosa.display
from IPython.display import Audio

parser = argparse.ArgumentParser()

"""
Hyperparameters
"""
timestamp = utils.readable_timestamp()

parser.add_argument("-f")

# DATASET
# Reminder to load audios to local disk , and thus path will be local
parser.add_argument("--audio_folder_path",  type=str, default='/content/drive/MyDrive/Hackathon - Sound of AI/Code/Dataset/Augmented-ESC/audios/')
parser.add_argument("--sample_rate", type=int, default=22050)
parser.add_argument("--train_val_split", type=float, default=0.8)
parser.add_argument("--lazy", type=bool, default=True)
parser.add_argument("--n_fft", type=int, default=2048)
parser.add_argument("--win_len", type=int, default=2000)
parser.add_argument("--hop_len", type=int, default=500)

# DATALOADER
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_workers", type=int, default=1)

# MODEL
parser.add_argument("--n_hiddens", type=int, default=128)
parser.add_argument("--n_residual_layers", type=int, default=2)
parser.add_argument("--n_residual_hiddens", type=int, default=32)
parser.add_argument("--embedding_dim", type=int, default=64)
parser.add_argument("--beta", type=float, default=.25)
parser.add_argument("--n_embeddings", type=int, default=2048)

# TRAINING
# whether or not to save model

parser.add_argument("-save", action="store_true", default=True)
parser.add_argument("--filename",  type=str, default=timestamp)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--log_interval", type=int, default=250)
parser.add_argument("--n_updates", type=int, default=500000)


args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.save:
    print('Results will be saved in ./results/vqvae_' + args.filename + '.pth')

# Load data and define batch data loaders TODO: xtraining var
training_data, validation_data, training_loader, validation_loader = utils.load_data_and_data_loaders(
    args.audio_folder_path, args.sample_rate, args.train_val_split, args.lazy, args.n_fft, args.win_len,
    args.hop_len, args.num_workers, args.batch_size)
x_train_var = 0.22

#Set up VQ-VAE model with components defined in ./models/ folder
model = VQVAE(args.n_hiddens, args.n_residual_hiddens,
              args.n_residual_layers, args.n_embeddings, args.embedding_dim, args.beta).to(device)

#Set up optimizer and training loop
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, amsgrad=True)

model.train()

results = {
    'n_updates': 0,
    'recon_errors': [],
    'loss_vals': [],
    'perplexities': [],
}

savepath = os.path.join('/content/drive/MyDrive/Hackathon - Sound of AI/Code/vqvae', 'results')

Results will be saved in ./results/vqvae_sun_jul_10_01_10_20_2022.pth


Train

In [6]:
for i in range(args.n_updates):
    x = next(iter(training_loader)).unsqueeze(1)
    x = x.to(device)
    optimizer.zero_grad()

    embedding_loss, x_hat, perplexity = model(x)
    recon_loss = torch.mean((x_hat - x)**2) / x_train_var
    loss = recon_loss + embedding_loss

    loss.backward()
    optimizer.step()

    results["recon_errors"].append(recon_loss.cpu().detach().numpy())
    results["perplexities"].append(perplexity.cpu().detach().numpy())
    results["loss_vals"].append(loss.cpu().detach().numpy())
    results["n_updates"] = i

    if i % args.log_interval == 0:
        """
        save model and print values
        """
        if args.save:
            hyperparameters = args.__dict__
            utils.save_model_and_results(
                model, results, hyperparameters, args.filename, savepath, i)

        print('Update #', i, 'Recon Error:',
                np.mean(results["recon_errors"][-args.log_interval:]),
                'Loss', np.mean(results["loss_vals"][-args.log_interval:]),
                'Perplexity:', np.mean(results["perplexities"][-args.log_interval:]))


Update # 0 Recon Error: 1.095593 Loss 1.09899 Perplexity: 1.8725001
Update # 250 Recon Error: 0.94663787 Loss 73.351906 Perplexity: 1.2571812
Update # 500 Recon Error: 0.8261293 Loss 14.903319 Perplexity: 1.7110134
Update # 750 Recon Error: 0.64459056 Loss 18.588219 Perplexity: 1.5227239
Update # 1000 Recon Error: 0.38096532 Loss 6.5311294 Perplexity: 2.2091994
Update # 1250 Recon Error: 0.33929217 Loss 3.0227876 Perplexity: 2.5803344
Update # 1500 Recon Error: 0.3155814 Loss 1.6457694 Perplexity: 2.439433
Update # 1750 Recon Error: 0.31481805 Loss 1.1644542 Perplexity: 2.4015005
Update # 2000 Recon Error: 0.439081 Loss 1.8528091 Perplexity: 2.3736944
Update # 2250 Recon Error: 0.30823046 Loss 0.6505829 Perplexity: 2.7416327
Update # 2500 Recon Error: 0.29908055 Loss 0.5708589 Perplexity: 2.8844912
Update # 2750 Recon Error: 0.29632983 Loss 0.62103 Perplexity: 3.0563867
Update # 3000 Recon Error: 0.2867398 Loss 0.61815786 Perplexity: 3.145484
Update # 3250 Recon Error: 0.28067365 Loss 

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feaf8801680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1322, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


Update # 8750 Recon Error: 0.21114136 Loss 0.2231825 Perplexity: 16.510643


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7feaf8801680>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1322, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/usr/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


KeyboardInterrupt: ignored