# Audio Modelling VST
Build a VST-style instrument model using only MIDI note information and the corresponding audio recordings of an instrument.
Dataset: [here](https://magenta.tensorflow.org/datasets/nsynth#instrument-sources).


# 1. **Import**

In [None]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from torchsummary import summary
from IPython.display import HTML, Audio
from torch.utils.data import Dataset, DataLoader
import math
import itertools

import glob
import json
from typing import TypedDict
import torchaudio
from torchaudio.transforms import Spectrogram, InverseSpectrogram
from torchaudio.functional import amplitude_to_DB, DB_to_amplitude
from tqdm import tqdm
import shutil
import requests
import tarfile
import gc

# Set random seeds for reproducibility.
# NumPy is seeded as well because later cells pick samples with np.random.randint.
manualSeed = 999
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)
torch.manual_seed(manualSeed)

In [None]:
# !pip install lightning-bolts
!pip install git+https://github.com/PytorchLightning/lightning-bolts.git@master --upgrade
!pip install pytorch-lightning

In [None]:
import pytorch_lightning as pl
from pl_bolts.models.gans import Pix2Pix

# 2. **Dataset Setup**

## **1.** Connect to Drive

In [None]:
# Colab-only helper used to mount Google Drive into the runtime's filesystem
from google.colab import drive
import os

# Mount point of the drive, the user's "MyDrive" folder inside it, and the
# project folder where the downloaded dataset archives are stored.
GDRIVE_DIR = "/content/drive/"
GDRIVE_HOME_DIR = os.path.join(GDRIVE_DIR, "MyDrive")
GDRIVE_DATA_DIR = os.path.join(GDRIVE_HOME_DIR, "AudioModelling")

# Mount drive (force_remount=True remounts even if a mount already exists)
drive.mount(GDRIVE_DIR, force_remount=True)

## **2.** Downloader Utilities

In [None]:
def get_data(dataset_url, dest, chunk_size=8192):
  """Download a file from a URL directly to ``dest`` (e.g. a Google Drive path).

  The download is skipped when ``dest`` already exists. Data is streamed to a
  temporary ``<dest>.part`` file that is renamed once the transfer completes,
  so an interrupted download is never mistaken for a finished one.

  Args:
    dataset_url: URL of the file to download.
    dest: Destination path of the downloaded file.
    chunk_size: Number of bytes fetched per iteration. The default matches the
      8192 bytes that were previously hard-coded (the parameter used to be ignored).

  Raises:
    requests.HTTPError: If the server answers with an error status.
  """
  # Check if file already exists
  if os.path.exists(dest):
    print(f"{dest} already exists.")
    return
  # Make sure the destination folder exists before opening the output file
  os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
  partial_dest = f"{dest}.part"
  # Downloading file (the timeout avoids hanging forever on a stalled connection)
  with requests.get(dataset_url, stream=True, timeout=30) as r:
    r.raise_for_status()
    # Content-Length is optional; without it the bar only shows the byte count
    content_length = r.headers.get("Content-Length")
    total = int(content_length) if content_length is not None else None
    with open(partial_dest, "wb") as f, \
         tqdm(total=total, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
      for chunk in r.iter_content(chunk_size=chunk_size):
        if chunk:  # filter out keep-alive new chunks
          f.write(chunk)
          pbar.update(len(chunk))
  # Publish the file only after the whole body has been written
  os.replace(partial_dest, dest)

def extract_file(tar_path, dest_path, dataset_type):
  """Extract a tar archive into ``dest_path``, at most once per dataset type.

  A marker file ``extract_status_<dataset_type>.txt`` is written in
  ``dest_path`` after a successful extraction; when the marker is already
  present the function returns without touching the archive.

  Args:
    tar_path: Path of the (optionally compressed) tar archive.
    dest_path: Folder in which the archive is extracted.
    dataset_type: Label used to name the marker file (e.g. "train", "valid", "test").
  """
  status_file = os.path.join(dest_path, f"extract_status_{dataset_type}.txt")
  # Check if already extracted
  if os.path.exists(status_file):
    print("Data already extracted.")
    return
  # The marker file is written into dest_path, so the folder must exist
  os.makedirs(dest_path, exist_ok=True)
  # Extract the tar file
  with tarfile.open(tar_path) as archive:
    if hasattr(tarfile, "data_filter"):
      # The archive comes from the network: refuse members that would escape
      # dest_path (absolute paths, "..", links pointing outside).
      archive.extractall(dest_path, filter="data")
    else:
      # Older Python without extraction filters
      archive.extractall(dest_path)
  # Write the marker file to confirm the extraction went well
  with open(status_file, "w") as f:
    f.write("OK")


## **3.** Download Train Dataset

In [None]:
# NSynth split used as the training set; the value is interpolated into the
# download URL and into the "nsynth-<split>" folder names used below.
TRAIN_TYPE="valid" # train, valid, test
DATASET_URL = f"http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-{TRAIN_TYPE}.jsonwav.tar.gz"
# Archive location on Google Drive (same file name as in the URL)
GDRIVE_DATA_TRAIN_ZIP = os.path.join(GDRIVE_DATA_DIR, DATASET_URL.split("/")[-1])

# Archive location on the local Colab disk
LOCAL_DATA_DIR = "/content"
LOCAL_DATA_TRAIN_ZIP = os.path.join(LOCAL_DATA_DIR, DATASET_URL.split("/")[-1])
# True: copy the archive to the local disk and extract it there;
# False: extract it inside Google Drive.
DO_LOCAL_TRAIN = True

# Download the dataset in Google Drive
get_data(DATASET_URL, GDRIVE_DATA_TRAIN_ZIP)

if DO_LOCAL_TRAIN:
  # Clone dataset to local colab and extract it
  if not os.path.exists(LOCAL_DATA_TRAIN_ZIP):
    shutil.copyfile(GDRIVE_DATA_TRAIN_ZIP, LOCAL_DATA_TRAIN_ZIP)
  extract_file(LOCAL_DATA_TRAIN_ZIP, LOCAL_DATA_DIR, TRAIN_TYPE)
else:
  # Extract the dataset in Google Drive
  extract_file(GDRIVE_DATA_TRAIN_ZIP, GDRIVE_DATA_DIR, TRAIN_TYPE)
  # NOTE(review): only examples.json is copied next to an empty local audio
  # folder; the audio files themselves stay under GDRIVE_DATA_DIR. Confirm that
  # the loaders are pointed there when this branch is used.
  os.makedirs(f"nsynth-{TRAIN_TYPE}/audio", exist_ok=True)
  shutil.copy(GDRIVE_DATA_DIR + f"/nsynth-{TRAIN_TYPE}/examples.json", f"nsynth-{TRAIN_TYPE}/examples.json")

## **4.** Download Test Dataset

In [None]:
# NSynth split used as the test set; the value is interpolated into the
# download URL and into the "nsynth-<split>" folder names.
TEST_TYPE="test" # train, valid, test
# NOTE: DATASET_URL and LOCAL_DATA_DIR are reassigned here and overwrite the
# values set by the training-set cell.
DATASET_URL = f"http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-{TEST_TYPE}.jsonwav.tar.gz"
# Archive location on Google Drive (same file name as in the URL)
GDRIVE_DATA_TEST_ZIP = os.path.join(GDRIVE_DATA_DIR, DATASET_URL.split("/")[-1])

# Archive location on the local Colab disk
LOCAL_DATA_DIR = "/content"
LOCAL_DATA_TEST_ZIP = os.path.join(LOCAL_DATA_DIR, DATASET_URL.split("/")[-1])
# True: copy the archive to the local disk and extract it there;
# False: extract it inside Google Drive.
DO_LOCAL_TEST = True

# Download the dataset in Google Drive
get_data(DATASET_URL, GDRIVE_DATA_TEST_ZIP)

if DO_LOCAL_TEST:
  # Clone dataset to local colab and extract it
  if not os.path.exists(LOCAL_DATA_TEST_ZIP):
    shutil.copyfile(GDRIVE_DATA_TEST_ZIP, LOCAL_DATA_TEST_ZIP)
  extract_file(LOCAL_DATA_TEST_ZIP, LOCAL_DATA_DIR, TEST_TYPE)
else:
  # Extract the dataset in Google Drive
  # NOTE(review): unlike the training cell, nothing is copied to the local
  # "nsynth-<split>" folder here; confirm the later relative paths still resolve.
  extract_file(GDRIVE_DATA_TEST_ZIP, GDRIVE_DATA_DIR, TEST_TYPE)

## **5.** Dataset Structure
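The dataset groups its notes into 11 instrument families × 3 source types. The two tables below list the index and identifier of each family (left) and each source type (right).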
<table>
  <tr>
    <td>
      <table>
        <thead>
          <tr>
            <th>Index</th>
            <th>ID</th>
          </tr>
        </thead>
        <tbody>
          <tr>
            <td>0</td>
            <td>bass</td>
          </tr>
          <tr>
            <td>1</td>
            <td>brass</td>
          </tr>
          <tr>
            <td>2</td>
            <td>flute</td>
          </tr>
          <tr>
            <td>3</td>
            <td>guitar</td>
          </tr>
          <tr>
            <td>4</td>
            <td>keyboard</td>
          </tr>
          <tr>
            <td>5</td>
            <td>mallet</td>
          </tr>
          <tr>
            <td>6</td>
            <td>organ</td>
          </tr>
          <tr>
            <td>7</td>
            <td>reed</td>
          </tr>
          <tr>
            <td>8</td>
            <td>string</td>
          </tr>
          <tr>
            <td>9</td>
            <td>synth_lead</td>
          </tr>
          <tr>
            <td>10</td>
            <td>vocal</td>
          </tr>
        </tbody>
      </table>
    </td>
    <td width="40%">
    </td>
    <td>
      <table>
        <thead>
          <tr>
            <th>Index</th>
            <th>ID</th>
          </tr>
        </thead>
        <tbody>
          <tr>
            <td>0</td>
            <td>acoustic</td>
          </tr>
          <tr>
            <td>1</td>
            <td>electronic</td>
          </tr>
          <tr>
            <td>2</td>
            <td>synthetic</td>
          </tr>
        </tbody>
      </table>
    </td>
  </tr>
</table>

In [None]:
instruments = [ "bass", "brass", "flute", "guitar", "keyboard", "mallet", "organ", "reed", "string", "synth_lead", "vocal" ]
types = [ "acoustic", "electronic", "synthetic" ]

# Substring matched against each example's "instrument_str"
# selected_instrument = "flute_synthetic_000"
selected_instrument = "guitar_acoustic"

class DataJson(TypedDict):
  """Fields of one examples.json entry that this notebook reads."""
  note: int
  sample_rate: int
  qualities: list[bool]
  pitch: int
  velocity: int
  note_str: str
  instrument_str: str
  instrument_family_str: str

# Relative folders holding the extracted .wav files of each split
train_dataroot = f"nsynth-{TRAIN_TYPE}/audio/"
test_dataroot = f"nsynth-{TEST_TYPE}/audio/"

def _load_examples(json_path, instrument):
  """Return the examples.json entries whose instrument_str contains ``instrument``."""
  # The context manager closes the file handle once the JSON is parsed
  with open(json_path) as f:
    examples = json.load(f)
  return [x for x in examples.values() if instrument in x["instrument_str"]]

# Load training dataset
json_data: list[DataJson] = _load_examples(f"nsynth-{TRAIN_TYPE}/examples.json", selected_instrument)
sr = [x["sample_rate"] for x in json_data]
print(f"For {selected_instrument} we have {len(json_data)} audio files in training set")

# Load test dataset
json_test_data: list[DataJson] = _load_examples(f"nsynth-{TEST_TYPE}/examples.json", selected_instrument)
print(f"For {selected_instrument} we have {len(json_test_data)} audio files in test set")

# 3. **PySpark + Colab Setup**

## **1.** Install PySpark and related dependencies

In [None]:
!pip install pyspark

## **2.** Import useful PySpark packages

In [None]:
import pyspark
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf

## **3.** Create Spark context

In [None]:
# Create the session configuration: Spark UI on port 4050 (the port tunnelled
# through ngrok below), executor/driver memory limits, application name, and a
# local master ("local[*]").
# NOTE(review): 45G of driver memory presumably targets a high-RAM runtime - confirm.
conf = SparkConf().\
                set('spark.ui.port', "4050").\
                set('spark.executor.memory', '4G').\
                set('spark.driver.memory', '45G').\
                set('spark.driver.maxResultSize', '10G').\
                setAppName("AudioModelling").\
                setMaster("local[*]")

# Create the context
sc = pyspark.SparkContext(conf=conf)
# Session obtained through the builder's getOrCreate()
spark = SparkSession.builder.getOrCreate()

## **4.** Create Web UI Console

In [None]:
# Install ngrok
!pip install pyngrok

In [None]:
# Auth on ngrok
!ngrok authtoken 2MyVytznaXPBeUvYCOpAzAXYMrH_6iBW9P8793yjNfBAoemCH

In [None]:
from pyngrok import ngrok

# Open a ngrok tunnel on the port 4050 where Spark is running
# (same value as 'spark.ui.port' in the Spark configuration)
port = '4050'
# Public address of the tunnel, printed in the next cell
public_url = ngrok.connect(port).public_url

In [None]:
# Show the public ngrok address together with the local Spark UI address it forwards to
print(f"To access the Spark Web UI console, please click on the following link to the ngrok tunnel \n\"{public_url}\" -> \"http://127.0.0.1:{port}\"")

## **5.** Check everything is ok

In [None]:
# Display the Spark session object as the cell output
spark

In [None]:
# List the configuration entries held by the Spark context
sc._conf.getAll()

# 4. **Audio Utilities**

## 1. Load Audio

In [None]:
def load_audio(audio_path, normalize=True):
  """Load an audio file with torchaudio and return only its waveform.

  Args:
    audio_path: Path of the audio file.
    normalize: Forwarded unchanged to ``torchaudio.load``.

  Returns:
    The first element returned by ``torchaudio.load`` (the waveform); the
    second element is discarded.
  """
  waveform, _unused = torchaudio.load(audio_path, normalize=normalize)
  return waveform

## 2. Show Spectrogram and Audio

In [None]:
def show_spec_and_audio(spec, phase, title=None, ylabel="freq_bin", aspect="auto", xmax=None, sample_rate=None):
  """Plot a spectrogram and display an audio player for its reconstruction.

  Args:
    spec: 2D spectrogram image to plot and invert.
    phase: Phase passed to ``audioT.spec_to_wav`` together with ``spec``.
    title: Plot title; defaults to "Spectrogram (db)".
    ylabel: Label of the y axis.
    aspect: Aspect argument forwarded to ``imshow``.
    xmax: Optional upper limit of the x axis.
    sample_rate: Playback rate of the reconstructed audio. When omitted, the
      rate of the notebook-level ``midi_item`` is used (the previous behaviour).
  """
  # Spectrogram plot
  fig, axs = plt.subplots(1, 1)
  axs.set_title(title or "Spectrogram (db)")
  axs.set_ylabel(ylabel)
  axs.set_xlabel("frame")
  im = axs.imshow(spec, origin="lower", aspect=aspect)
  if xmax:
    axs.set_xlim((0, xmax))
  fig.colorbar(im, ax=axs)
  plt.show(block=False)

  # Audio reconstructed from the spectrogram and the given phase
  if sample_rate is None:
    sample_rate = midi_item["sample_rate"]
  wav_reconstructed = audioT.spec_to_wav(spec, phase)
  display(Audio(wav_reconstructed, rate=sample_rate))

## 3. Generate Audio from MIDI (sine synth)

In [None]:
from scipy import signal

def midi_to_audio(pitch, velocity, duration=2, sample_rate=16000):
  """Synthesize a sine tone for a MIDI note.

  Args:
    pitch: MIDI note number (69 maps to 440 Hz).
    velocity: MIDI velocity; divided by 127 to obtain the amplitude.
    duration: Length of the tone in seconds.
    sample_rate: Number of samples per second.

  Returns:
    A float64 tensor of shape (1, int(sample_rate * duration)).
  """
  # Equal-temperament conversion from MIDI note number to frequency
  freq = 440 * (2 ** ((pitch - 69) / 12))
  amplitude = velocity / 127
  n_samples = int(sample_rate * duration)
  # Sample k lies at k / sample_rate seconds. (np.linspace with its default
  # endpoint stretched the spacing to duration / (n - 1), detuning the tone.)
  t = np.arange(n_samples) / sample_rate
  # A sawtooth alternative is scipy.signal.sawtooth(2 * pi * freq * t)
  sample = np.sin(2. * np.pi * freq * t) * amplitude
  # from_numpy avoids the slow "tensor from a list of ndarrays" construction
  return torch.from_numpy(sample).unsqueeze(0)

# Quick audible check: MIDI pitch 70 at velocity 127, played at 16 kHz
Audio(midi_to_audio(70, 127), rate=16000, normalize=False)

## 4. Audio-Spectrogram conversion

In [None]:
class AudioTransform():
  """Convert between waveforms and (dB magnitude, phase) spectrogram pairs.

  ``wav_to_spec`` yields a magnitude image in decibels, shifted so that the
  loudest bin is ``db_range`` (80) and everything quieter than 80 dB below the
  peak is 0, plus the STFT phase. ``spec_to_wav`` inverts that representation.
  """
  def __init__(self):
    super().__init__()
    self.n_fft = 1024
    self.hop_length = self.n_fft//4
    # Dynamic range (in dB) kept in the spectrogram images
    self.db_range = 80
    # power=None returns the complex STFT. With the default power mode the
    # output is real and non-negative, so its angle (the "phase") is always 0.
    self.spectrogram = Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=None)
    self.inverse_spectrogram = InverseSpectrogram(n_fft=self.n_fft, hop_length=self.hop_length)

  def wav_to_spec(self, wav, sample_rate=16000):
    """Return ``(spec_db, phase)`` for a waveform.

    Args:
      wav: Waveform tensor (the notebook passes shape (1, samples)).
      sample_rate: Unused; kept for interface compatibility.

    Returns:
      spec_db: Magnitude in dB, in the range [0, 80], with n_fft//2 rows.
      phase: STFT phase in radians, same shape as ``spec_db``.
    """
    # Synthesized audio is float64; use float32 like the audio loaded from disk
    wav = torch.as_tensor(wav, dtype=torch.float32)
    # Transform audio to a complex spectrogram
    stft = self.spectrogram(wav).squeeze()
    magnitude, phase = torch.abs(stft), torch.angle(stft)
    # Create a (h=2x, w=x) size array: append 5 blank frames and drop the last frequency row.
    # NOTE(review): assumes the STFT has n_fft//4 - 5 frames (4 s at 16 kHz) - TODO confirm for other lengths.
    blank_column = torch.zeros((magnitude.shape[0], 5), dtype=magnitude.dtype, device=magnitude.device)
    magnitude = torch.cat((magnitude, blank_column), axis=1)[:self.n_fft//2]
    phase = torch.cat((phase, blank_column), axis=1)[:self.n_fft//2]
    # Magnitude in dB relative to the peak. 20*log10|X| equals the former
    # 10*log10(|X|^2); a positive amin keeps the blank frames away from log10(0) = -inf.
    amin = 1e-10
    peak = torch.clamp(torch.max(magnitude), min=amin)
    S = amplitude_to_DB(magnitude, multiplier=20, amin=amin, db_multiplier=torch.log10(peak))
    # Shift to [.., 80] and floor at 0 so the image stays in [0, db_range]
    S = torch.clamp(S + self.db_range, min=0.0)
    return S, phase

  def spec_to_wav(self, spec, phase, sample_rate=16000):
    """Rebuild a waveform from a dB spectrogram and its phase.

    Args:
      spec: dB magnitude as produced by ``wav_to_spec`` (tensor or numpy array).
      phase: Phase in radians with the same shape as ``spec``.
      sample_rate: Unused; kept for interface compatibility.

    Returns:
      The waveform given by the inverse STFT. The absolute level of the
      original audio is not stored in ``spec``, so the loudest bin is rebuilt
      with magnitude 1.
    """
    spec = torch.as_tensor(spec, dtype=torch.float32)
    phase = torch.as_tensor(phase, dtype=torch.float32, device=spec.device)
    # dB -> linear magnitude (power=0.5 turns the 10*log10 power scale into amplitude)
    magnitude = DB_to_amplitude(spec - self.db_range, ref=1.0, power=0.5)
    # Bins at the 0 dB floor (e.g. the blank padding frames) are treated as silence
    magnitude = torch.where(spec > 0, magnitude, torch.zeros_like(magnitude))
    # Combine magnitude and phase into the complex spectrogram
    A = torch.polar(magnitude, phase)
    # Restore the frequency row dropped by wav_to_spec
    blank_row = torch.zeros((1, A.shape[1]), dtype=A.dtype, device=A.device)
    A = torch.cat((A, blank_row), axis=0)
    A = self.inverse_spectrogram(A)
    return A

# Load audio from file (the sixth-to-last entry of the filtered training metadata)
midi_item = json_data[-6]
print(midi_item["note_str"])
filename = midi_item["note_str"] + ".wav"
wav_orig = load_audio(train_dataroot + filename)

# Load midi info from file and synthesize a tone of the same duration as the recording
duration = wav_orig.shape[1]/midi_item["sample_rate"]
wav_midi = midi_to_audio(midi_item["pitch"], midi_item["velocity"], duration, midi_item["sample_rate"])

print("Original audio and midi audio synthesized")
# Original audio
display(Audio(wav_orig, rate=midi_item["sample_rate"]))
# Sine audio by midi
display(Audio(wav_midi, rate=midi_item["sample_rate"]))

# Notebook-level transform instance; show_spec_and_audio and CustomDataset read this global
audioT = AudioTransform()

print("Spectrogram and inverse spectrogram to get original audio and midi audio synthesized back")
# Original audio spec
spec_orig, phase = audioT.wav_to_spec(wav_orig, midi_item["sample_rate"])
show_spec_and_audio(spec_orig, phase)
# Sine audio by midi spec (note: `phase` is overwritten with the phase of the synthesized tone)
spec_midi, phase = audioT.wav_to_spec(wav_midi, midi_item["sample_rate"])
show_spec_and_audio(spec_midi, phase)

In [None]:
# wav_midi = midi_to_audio(127, midi_item["velocity"], duration, midi_item["sample_rate"])
# spec_midi, phase = audioT.wav_to_spec(wav_midi, midi_item["sample_rate"])
# split = split_low_high_freq(torch.tensor(spec_midi)).numpy()
# split_phase = split_low_high_freq(torch.tensor(phase)).numpy()
# show_spec_and_audio(split[0], split_phase[0])
# show_spec_and_audio(split[1], split_phase[1])

# 5. **Model Utilities**

## 1. Patch Splitting

In [None]:
def split_low_high_freq(spectrum):
    """
    Split the 2d spectrum in half along its height and stack the halves.

    Args:
      spectrum: shape (1, h, w) where h=2*w

    Return:
      spectrum_3d: shape (2, w, w); index 0 holds the first w rows,
      index 1 the last w rows

    Raises:
      ValueError: if the input is not of shape (1, 2*w, w)
    """
    if spectrum.dim() != 3:
        raise ValueError(f"Invalid array of shape {spectrum.shape}")
    x, h, w = spectrum.shape
    # Exactly one image whose height is two times its width is expected
    if x != 1 or h != 2 * w:
        raise ValueError(f"Invalid array of shape {spectrum.shape}")
    spectrum_3d = torch.reshape(spectrum, (2, w, w))
    return spectrum_3d

def merge_low_high_freq(spectrum_3d):
    """
    Stack the two halves of a split spectrum back on top of each other.

    Args:
      spectrum_3d: shape (2, w, w)

    Return:
      spectrum: shape (1, h, w) where h=2*w (the leading dimension is kept)
    """
    _, _, width = spectrum_3d.shape
    return spectrum_3d.reshape(1, 2 * width, width)

def split_remove_high_freq(spectrum):
    """
    Keep only the first half (low frequencies) of a 2d spectrum.

    Args:
      spectrum: shape (1, h, w) where h=2*w

    Return:
      spectrum_3d: shape (1, w, w), the first w rows of the input
    """
    halves = split_low_high_freq(spectrum)
    low_half = halves[0]
    # Restore the leading dimension dropped by the indexing
    return low_half.unsqueeze(0)

def recreate_high_freq(spectrum_3d):
    """
    Append an all-zero high frequency half to a low frequency spectrum.

    Args:
      spectrum_3d: shape (1, w, w)

    Return:
      spectrum: shape (2, w, w) where index 1 is full of zeros; pass it to
      merge_low_high_freq to obtain the (1, 2*w, w) image
    """
    _, h, w = spectrum_3d.shape
    # new_zeros keeps the dtype and device of the input, so the concatenation
    # neither changes the dtype nor fails for tensors living on the GPU
    zeros_tensor = spectrum_3d.new_zeros((1, h, w))
    spectrum = torch.cat((spectrum_3d, zeros_tensor), dim=0)
    return spectrum

# Sanity check of the helpers above on a small (1, 8, 4) integer tensor
n = 4
values = list(range(-16, 16))
arr = [values[start:start + n] for start in range(0, n * (len(values) // n), n)]
x = torch.tensor([arr])
print("Original:\n", x, x.shape)
# Apply every transformation in turn, printing the tensor and its shape after each step
for label, step in (
    ("After Splitting:\n", split_low_high_freq),
    ("After Merging:\n", merge_low_high_freq),
    ("After Removing High Frequencies:\n", split_remove_high_freq),
    ("After Recreating (Zeros) High Frequencies:\n", recreate_high_freq),
):
    x = step(x)
    print(label, x, x.shape)

## 2. Dataloaders

In [None]:
class CustomDataset(Dataset):
    """In-memory dataset of (recorded, synthesized) spectrogram pairs.

    For every metadata entry whose .wav file exists, the recording and a sine
    tone synthesized from the entry's pitch and velocity are converted to
    spectrograms; only the low-frequency half of each image is kept.
    Relies on the notebook-level ``audioT``, ``load_audio``, ``midi_to_audio``
    and ``split_remove_high_freq``.

    Args:
        dataroot: Folder containing the .wav files.
        json: Iterable of metadata dicts with "note_str", "pitch", "velocity"
            and "sample_rate" keys (the name shadows the ``json`` module
            inside ``__init__``; kept for interface compatibility).
    """
    def __init__(self, dataroot, json):
        self.audio_path = dataroot
        self.data = [] # list of (spec_orig, spec_midi)
        self.phases = [] # list of (phase_orig, phase_midi)
        for entry in json:
            # Get data
            filename = f"{entry['note_str']}.wav"
            pitch, velocity, sr = entry["pitch"], entry["velocity"], entry["sample_rate"]
            # os.path.join also works when dataroot has no trailing separator
            audio_file = os.path.join(self.audio_path, filename)
            # Entries without an audio file are skipped silently
            if not os.path.exists(audio_file):
                continue
            # Load audio
            wav_orig = load_audio(audio_file)
            # Get duration
            duration = wav_orig.shape[1]/sr
            # Synth midi audio, in the same dtype as the recording so both
            # spectrograms of a pair share one dtype
            wav_midi = midi_to_audio(pitch, velocity, duration, sr).to(wav_orig.dtype)
            # Convert to spectrum
            spec_orig, phase_orig = audioT.wav_to_spec(wav_orig)
            spec_midi, phase_midi = audioT.wav_to_spec(wav_midi)
            # Keep only the low-frequency half, which gives a square (1, w, w) image
            spec_orig = split_remove_high_freq(spec_orig[None])
            spec_midi = split_remove_high_freq(spec_midi[None])
            # spec_midi = split_low_high_freq(spec_midi)
            self.data.append((spec_orig, spec_midi))
            # Phases (of the full, unsplit spectrograms)
            self.phases.append((phase_orig, phase_midi))

    def __len__(self):
        """Number of pairs that were actually loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(spec_orig, spec_midi)`` for the pair at ``idx``."""
        spec_orig, spec_midi = self.data[idx]
        return spec_orig, spec_midi

    def get_phase(self, idx):
        """Return ``(phase_orig, phase_midi)`` for the pair at ``idx``."""
        phase_orig, phase_midi = self.phases[idx]
        return phase_orig, phase_midi

train_dataset = CustomDataset(train_dataroot, json_data)
print(len(train_dataset))
test_dataset = CustomDataset(test_dataroot, json_test_data)
print(len(test_dataset))

# Batch size during training
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

train_sample = next(iter(train_dataloader))
# test_sample = test_dataset[0]
# test_phase = test_dataset.get_phase(0)

# Value range of the model input (synthesized MIDI spectrograms) and of the
# model output/target (recorded spectrograms) over the whole training set.
# The dataset yields (spec_orig, spec_midi): the SECOND element is the input
# fed to the generator, the first one is the target.
min_in = float('inf')
max_in = -float('inf')
min_out = float('inf')
max_out = -float('inf')
for target_batch, condition_batch in train_dataloader:
    # For input
    min_in = min(min_in, condition_batch.min())
    max_in = max(max_in, condition_batch.max())
    # For output
    min_out = min(min_out, target_batch.min())
    max_out = max(max_out, target_batch.max())
print(f"Input shape ({condition_batch.shape}) | min {min_in} | max {max_in}")
print(f"Output shape ({target_batch.shape}) | min {min_out} | max {max_out}")

os.makedirs("checkpoints", exist_ok=True)

## 3. Custom Activation Function

In [None]:
class CustomAct(nn.Module):
    """Min-max rescale a tensor to the range [0, scale].

    The minimum and maximum are taken over the WHOLE tensor (all batch
    elements and channels together), not per sample.

    Args:
        scale: Upper bound of the output range (80 by default).
        eps: Lower bound of the denominator; a constant input then maps to
            zeros instead of NaN (0/0).
    """
    def __init__(self, scale=80, eps=1e-12):
        super().__init__()
        self.scale = scale
        self.eps = eps

    def forward(self, x):
        # Calculate the current minimum and maximum values of the tensor
        x_min = torch.min(x)
        x_max = torch.max(x)

        # Guard against a zero range (constant input)
        span = torch.clamp(x_max - x_min, min=self.eps)

        # Normalize to the range [0, 1], then stretch to [0, scale]
        x = (x - x_min) / span
        x = x * self.scale

        return x

# Smoke test of the activation on random data: the printed max/min should be 80 and 0
customAct = CustomAct()
if torch.cuda.is_available():
  m = torch.rand((8,2,256,256)).to("cuda")
else:
  m = torch.rand((1,2,256,256))
# Call the module itself rather than .forward(), which is the regular nn.Module entry point
m = customAct(m)
print(m.max(), m.min())

# 6. Pix2Pix

## 1. LR Finder adjustment

In [None]:
# from pytorch_lightning.tuner.lr_finder import _LRFinder, _LinearLR, _ExponentialLR, LRScheduler
# from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT
# from typing import cast


# def new_exchange_scheduler(self, trainer: "pl.Trainer") -> None:
#         # TODO: update docs here
#         """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
#         optimizer together with a new scheduler that takes care of the learning rate search."""

#         optimizers = trainer.strategy.optimizers

#         optimizer = optimizers[0]

#         new_lrs = [self.lr_min] * len(optimizer.param_groups)
#         for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
#             param_group["lr"] = new_lr
#             param_group["initial_lr"] = new_lr

#         args = (optimizer, self.lr_max, self.num_training)
#         scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
#         scheduler = cast(LRScheduler, scheduler)

#         trainer.strategy.optimizers = [optimizer]
#         trainer.strategy.lr_scheduler_configs = [LRSchedulerConfig(scheduler, interval="step")]

# _LRFinder._exchange_scheduler = new_exchange_scheduler

# def new_training_step(self, batch, batch_idx, optimizer_idx=0):
#     real, condition = batch

#     loss = None
#     if optimizer_idx == 0:
#         loss = self._disc_step(real, condition)
#         self.log("PatchGAN Loss", loss)
#     elif optimizer_idx == 1:
#         loss = self._gen_step(real, condition)
#         self.log("Generator Loss", loss)

#     return loss

# Pix2Pix.training_step = new_training_step

In [None]:
# from pytorch_lightning.tuner.lr_finder import _LRFinder, _LinearLR, _ExponentialLR, LRScheduler
# from pytorch_lightning.utilities.types import LRScheduler, LRSchedulerConfig, STEP_OUTPUT
# from typing import cast


# def new_exchange_scheduler(self, trainer: "pl.Trainer") -> None:
#         """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
#         optimizer together with a new scheduler that takes care of the learning rate search."""

#         optimizers = trainer.strategy.optimizers
#         trainer.strategy.lr_scheduler_configs = []

#         for optimizer in optimizers:
#             new_lrs = [self.lr_min] * len(optimizer.param_groups)
#             for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
#                 param_group["lr"] = new_lr
#                 param_group["initial_lr"] = new_lr

#             args = (optimizer, self.lr_max, self.num_training)
#             scheduler = _LinearLR(*args) if self.mode == "linear" else _ExponentialLR(*args)
#             scheduler = cast(LRScheduler, scheduler)
#             trainer.strategy.lr_scheduler_configs.append(LRSchedulerConfig(scheduler, interval="step"))

# _LRFinder._exchange_scheduler = new_exchange_scheduler

In [None]:
# # Run learning rate finder
# lr_finder = trainer.tuner.lr_find(pix2pix, train_dataloader, min_lr=1e-100, max_lr=1e-5, early_stop_threshold=None)

# # Plot with
# fig = lr_finder.plot(suggest=True)
# fig.show()

# # Pick point based on plot, or get suggestion
# new_lr = lr_finder.suggestion()

# # update hparams of the model
# pix2pix.hparams.lr = new_lr
# pix2pix.hparams.learning_rate = new_lr
# pix2pix.lr = new_lr
# pix2pix.learning_rate = new_lr
# print(new_lr)

## 2. Model creation

In [None]:
# Free memory left over from previous runs of this cell
gc.collect()
torch.cuda.empty_cache()

use_gpu = torch.cuda.is_available()

# One input channel (synthesized spectrogram) and one output channel (generated spectrogram)
pix2pix = Pix2Pix(in_channels=1, out_channels=1, learning_rate=1e-3)
if use_gpu:
  pix2pix = pix2pix.to("cuda")

# Replace the generator's `tanh` module with a ReLU.
# pix2pix.gen.tanh = CustomAct() #TODO TEST WITH THIS
pix2pix.gen.tanh = nn.ReLU()

# print(pix2pix.gen)

# summary(pix2pix.gen, (1,256,256))
# summary(pix2pix.patch_gan, [(2,256,256), (2,256,256)])

# Start from a clean checkpoint folder. This removes lightning_logs/version_1
# as well; ignore_errors avoids a failure when the folder does not exist yet.
shutil.rmtree("checkpoints/pix2pix", ignore_errors=True)
# !cp -R "checkpoints/pix2pix/lightning_logs/version_1/checkpoints/epoch=249-step=2000.ckpt" "drive/MyDrive/AudioModelling/dst.ckpt"

# NOTE(review): auto_lr_find presumably only has an effect when trainer.tune(...)
# is called, and no such call is visible in this notebook - confirm against the
# installed pytorch-lightning version.
trainer_kwargs = dict(max_epochs=200, default_root_dir="checkpoints/pix2pix", auto_lr_find=True)
if use_gpu:
  trainer_kwargs["accelerator"] = "gpu"
trainer = pl.Trainer(**trainer_kwargs)


## 3. Training

In [None]:
# Load checkpoints
# checkpoint_folder = "checkpoints/pix2pix/lightning_logs"
# if os.path.exists(checkpoint_folder):
#   last_version = f"/version_{len(os.listdir(checkpoint_folder))-1}/checkpoints/"
#   last_checkpoint = os.listdir(checkpoint_folder+last_version)[-1]
#   print(f"Loaded checkpoint {last_version + last_checkpoint}")
#   trainer.fit(pix2pix, train_dataloader, ckpt_path=checkpoint_folder+last_version+last_checkpoint)
# else:
#   trainer.fit(pix2pix, train_dataloader)

# Train on the training dataloader only (no validation dataloader is passed)
trainer.fit(pix2pix, train_dataloader)

# !cp -R checkpoints/pix2pix dst

## 4. Testing

In [None]:
idx = np.random.randint(0,len(train_dataset))
print(idx)
spec_orig, spec_midi = train_dataset[idx] #TO REMOVE
phase_orig, phase_midi = train_dataset.get_phase(idx) #TO REMOVE
norm = CustomAct()

def _to_full_spec(spec_3d):
  """Rebuild the 2D (2*w, w) spectrogram image from a (1, w, w) low-frequency half."""
  # merge_low_high_freq returns (1, 2*w, w); [0] drops the leading dimension,
  # which neither imshow nor spec_to_wav can handle
  return merge_low_high_freq(recreate_high_freq(spec_3d))[0]

# Same code path on GPU and CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Generate from data: batch of size one, in float32 like the model weights
condition = spec_midi.unsqueeze(dim=0).float().to(device)
with torch.no_grad():
  generated_spec = pix2pix.to(device).gen(condition).cpu()[0] # Take first element of the batch
# Range of the raw generator output
print(generated_spec.max(), generated_spec.min())
# Normalize it
generated_spec = norm(generated_spec)

# Generated sample (played back with the phase of the original recording)
generated_spec = _to_full_spec(generated_spec)
show_spec_and_audio(generated_spec, phase_orig)
# Original source sample
spec_midi = _to_full_spec(spec_midi)
show_spec_and_audio(spec_midi, phase_midi)
# Original dest sample
spec_orig = _to_full_spec(spec_orig)
show_spec_and_audio(spec_orig, phase_orig)

# Evaluation

In [None]:
# Show and play five random recordings from the training set
for _ in range(5):
  idx = np.random.randint(0,len(train_dataset))
  print(idx)
  spec_orig, spec_midi = train_dataset[idx] #TO REMOVE
  phase_orig, phase_midi = train_dataset.get_phase(idx) #TO REMOVE
  # merge_low_high_freq returns (1, 2*w, w); [0] yields the 2D image expected
  # by show_spec_and_audio
  sample = merge_low_high_freq(recreate_high_freq(spec_orig.cpu().detach()))[0]
  # dest_sample = merge_low_high_freq(recreate_high_freq(spec_midi)).numpy()
  show_spec_and_audio(sample, phase_orig)