In [None]:
import matplotlib.pyplot as plt
import numpy as np
import librosa
from librosa import display
from pathlib import Path
from multiprocessing import Pool
import torch
import lightning as L
from IPython.display import Audio

%matplotlib inline

In [None]:
# Folder that holds the source recordings.
# NOTE: this name shadows the builtin `dir`; it is kept because a later cell
# globs this variable by that name.
dir = Path("/home/kureta/Music/Chorale Samples/")

In [None]:
def prepare(fname):
    """Load one audio file and return its magnitude spectrogram in decibels.

    The file is decoded as mono at 44.1 kHz and transformed with a
    Hann-windowed STFT (``n_fft=1024``, ``hop_length=512``). The first row of
    the magnitude STFT is discarded before the conversion to dB.

    Parameters
    ----------
    fname : path-like
        Audio file to analyse; forwarded unchanged to ``librosa.load``.

    Returns
    -------
    np.ndarray
        ``librosa.amplitude_to_db`` applied to the magnitude STFT without its
        first row.
    """
    samples, _sample_rate = librosa.load(fname, sr=44100, mono=True)
    magnitude = np.abs(
        librosa.stft(samples, n_fft=1024, hop_length=512, window='hann')
    )
    # Row 0 is dropped here; later cells re-insert it as an all-zero row when
    # they rebuild a 513-row array for Griffin-Lim.
    return librosa.amplitude_to_db(magnitude[1:])

In [None]:
# Sort the paths: directory iteration order depends on the filesystem, and
# later cells slice `data[:1024]`, so a stable file order is what makes those
# slices (and the layout of the whole dataset) reproducible between runs.
mp3_files = sorted(dir.glob("*.mp3"))
if not mp3_files:
    # Fail here with a clear message instead of later inside np.concatenate.
    raise FileNotFoundError(f"no .mp3 files found in {dir}")

# Analyse the files in parallel; Pool.map returns results in input order, so
# `spectra[i]` belongs to `mp3_files[i]`.
with Pool(12) as p:
    spectra = p.map(prepare, mp3_files)

In [None]:
# Quick look at the first file: show only its first 1024 frames.
first_spectrum = spectra[0]
display.specshow(first_spectrum[:, :1024])

In [None]:
# Join every file's spectrogram along axis 1, then transpose so that each row
# of `data` is a single spectrum (one training example for the autoencoder).
stacked = np.concatenate(spectra, axis=1)
data = torch.from_numpy(stacked.T)
data.shape

In [None]:
class AE(torch.nn.Module):
    """Fully connected autoencoder with an 8-unit bottleneck.

    Encoder: 512 -> 256 -> 128 -> 64 -> 32 -> 8, ReLU between the linear
    layers and a Tanh after the last one.
    Decoder: 8 -> 32 -> 64 -> 128 -> 256 -> 512, ReLU between the linear
    layers and no activation on the output.
    """

    def __init__(self):
        super().__init__()

        widths = (512, 256, 128, 64, 32, 8)
        # The encoder is built first, then the decoder, each layer in order,
        # so module indices (and therefore state_dict keys) are unchanged.
        self.encoder = self._mlp(widths, last=torch.nn.Tanh())
        self.decoder = self._mlp(widths[::-1], last=None)

    @staticmethod
    def _mlp(widths, last):
        """Stack Linear layers for consecutive widths with ReLU in between.

        ``last`` (a module or ``None``) is appended after the final Linear.
        """
        pairs = list(zip(widths[:-1], widths[1:]))
        final = len(pairs) - 1
        layers = []
        for position, (n_in, n_out) in enumerate(pairs):
            layers.append(torch.nn.Linear(n_in, n_out))
            if position != final:
                layers.append(torch.nn.ReLU())
        if last is not None:
            layers.append(last)
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back."""
        return self.decoder(self.encoder(x))


class AutoEncoder(L.LightningModule):
    """Lightning wrapper that trains :class:`AE` to reconstruct its input.

    Training minimises the mean-squared error between a batch of frames and
    its reconstruction. At the end of every training epoch an audio preview
    of the current reconstruction is written to ``test_<global_step>.wav``.
    """

    def __init__(self):
        super().__init__()
        self.ae = AE()

    def training_step(self, batch, _):
        """Return the MSE reconstruction loss for one batch and log it."""
        pred = self.ae(batch)
        loss = torch.nn.functional.mse_loss(pred, batch)

        self.log("train_loss", loss, prog_bar=True)

        return loss

    def on_train_epoch_end(self) -> None:
        """Reconstruct the first frames of the notebook-level ``data`` tensor
        and write the Griffin-Lim resynthesis to a WAV file."""
        # Imported locally because the notebook's own `soundfile` import cell
        # sits after the training cell, so the global name would not exist yet
        # on a fresh Restart & Run All.
        from soundfile import write as write_audio

        dur = 2048
        with torch.inference_mode():
            # `data` lives on the CPU while the Trainer may have moved this
            # module to the GPU; put the slice on the module's device.
            pred = self.ae(data[:dur].to(self.device))

        pred_db = pred.T.cpu().numpy()
        # Size the buffer from the actual frame count so fewer than `dur`
        # available frames does not cause a shape mismatch.
        result = np.zeros((513, pred_db.shape[1]))
        # The network is trained on dB values, but Griffin-Lim needs linear
        # magnitudes. Convert before inserting so the re-added first row stays
        # at amplitude 0 instead of becoming 1.0 (0 dB).
        result[1:, :] = librosa.db_to_amplitude(pred_db)
        y = librosa.griffinlim(result, n_iter=100, hop_length=512, n_fft=1024, window='hann')
        write_audio(f'test_{self.global_step}.wav', y, 44100)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-4)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [None]:
# Restore the weights saved by an earlier training run.
checkpoint_path = './lightning_logs/version_0/checkpoints/epoch=249-step=278500.ckpt'
model = AutoEncoder.load_from_checkpoint(checkpoint_path)

In [None]:
# Reconstruct the first 1024 frames. The input is moved to whatever device the
# restored model is on, so the cell works with or without a GPU instead of
# hard-coding `.cuda()`.
with torch.inference_mode():
    pred = model.ae(data[:1024].to(model.device))

In [None]:
# Predicted dB spectrogram, transposed so each column is one frame.
pred_db = pred.T.cpu().numpy()

# Rebuild the 513-row magnitude spectrogram expected for n_fft=1024. Convert
# dB -> amplitude *before* inserting into the zero buffer: converting the
# padded array would turn the zero first row into amplitude 1.0 (0 dB).
# The frame count is taken from the prediction rather than hard-coded.
result = np.zeros((513, pred_db.shape[1]))
result[1:] = librosa.db_to_amplitude(pred_db)
y = librosa.griffinlim(result, n_iter=1024, hop_length=512, n_fft=1024, window='hann')

In [None]:
# Plot the reconstructed magnitude spectrogram, then end the cell with an
# audio player for the resynthesised signal.
display.specshow(data=result)
Audio(data=y, rate=44100)

In [None]:
class Spec(torch.utils.data.Dataset):
    """Map-style dataset that serves the items of an indexable container.

    Parameters
    ----------
    data : sequence or tensor
        Any object supporting ``len()`` and integer indexing; item ``idx`` of
        the dataset is ``data[idx]``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        # Use the wrapped container, not the notebook-level `data` global:
        # the global may be missing or may differ from what was passed in.
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
# Wrap the frame tensor in a dataset and serve it in shuffled mini-batches.
dataset = Spec(data)

loader_options = dict(
    batch_size=32,
    shuffle=True,
    num_workers=8,
    prefetch_factor=4,
    persistent_workers=True,
)
loader = torch.utils.data.DataLoader(dataset=dataset, **loader_options)

In [None]:
# Start from a freshly initialised AutoEncoder. This rebinds `model`, replacing
# the checkpoint-restored instance created in an earlier cell.
model = AutoEncoder()

In [None]:
# Train on a single GPU for at most 10,000 optimisation steps.
trainer = L.Trainer(
    accelerator="gpu",
    devices=1,
    max_steps=10000,
    limit_val_batches=1,
)
trainer.fit(model, train_dataloaders=loader)

In [None]:
# Build `seed`: 8 channels of random sinusoids sampled at `length` time steps.
# NOTE(review): no later cell visible in this notebook consumes `seed`.

# A seeded generator makes the random curves identical on every run.
rng = np.random.default_rng(0)

dt = 0.06
length = 512
# Integer arange scaled by dt always yields exactly `length` samples; a
# float-step np.arange can be off by one element because of rounding.
t = np.arange(length) * dt

# One phase, frequency and amplitude per channel.
phase = rng.random(8) * 2 * np.pi
freq = rng.random(8)
amp = rng.random(8) * 50.

# Broadcast time (rows) against channels (columns) -> shape (length, 8).
seed = amp * np.sin(t[:, None] * freq + phase)

In [None]:
# Reconstruct the first 1024 frames with the freshly trained model.
# The previous `* dmax` factor is removed: `dmax` is not defined in any cell of
# this notebook, and the network was trained on the unscaled `data` frames.
# The input follows the model's device and the result is brought back to the
# CPU so the plotting cells can call `.numpy()` on it directly.
with torch.inference_mode():
    pred = model.ae(data[:1024].to(model.device)).cpu()

In [None]:
# Reconstruction first, then the input frames it was computed from.
reconstruction = pred.T.numpy()
reference = data[:1024].T.numpy()

display.specshow(reconstruction)
plt.show()
display.specshow(reference)

In [None]:
# Predicted dB spectrogram, transposed so each column is one frame.
pred_db = pred.T.cpu().numpy()

# Size the buffer from the prediction itself: `length` (512) belongs to the
# sinusoid-seed cell and does not match the 1024 frames reconstructed above.
result = np.zeros((513, pred_db.shape[1]))
# Griffin-Lim works on linear magnitudes, while the model outputs dB values.
# Convert before inserting so the re-added first row stays at amplitude 0.
result[1:] = librosa.db_to_amplitude(pred_db)
y = librosa.griffinlim(result, n_iter=100, hop_length=512, n_fft=1024, window='hann')

In [None]:
from soundfile import write as write_audio

In [None]:
# Save the resynthesised signal as a 44.1 kHz WAV file.
write_audio(file='test.wav', data=y, samplerate=44100)