# DENOISING AUTOENCODER FOR MIDI

### **1. Load Libraries**

The imported modules include:

* `torch.nn`: neural network layers such as `Linear()`, `Conv2d()` and `ConvTranspose2d()`, plus loss functions such as `MSELoss()`.
* `optim`: optimizer classes such as `Adam()`, and learning-rate schedulers.
* `functional`: functional forms of activations and tensor ops; we use it for `silu()`, `mse_loss()` and `pad()`.
* `SummaryWriter`: logs the training loss and weight/gradient histograms to TensorBoard.
* `Dataset`: base class for the custom MIDI loops dataset.
* `torchinfo.summary`: prints a layer-by-layer summary of a model.
* `lightning`: PyTorch Lightning, used by the deep autoencoder adapted from the Lightning tutorial.
* `torchvision` (computer vision datasets and architectures), `transforms` (image transforms and normalizations) and `DataLoader` (iterable training and testing sets) are not currently imported.

In [None]:
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard.writer import SummaryWriter
from torch.utils.data import Dataset
from torchinfo import summary

import lightning as L

# other
import os
import time
import random
import numpy as np
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
from tqdm.notebook import trange, tqdm

#### Dark theme

In [None]:
# Apply matplotlib's built-in "dark_background" style to every figure drawn afterwards.
plt.style.use("dark_background")

%%html
<!-- Give ipywidget outputs (e.g. tqdm.notebook progress bars) a transparent
     background and the VS Code editor's foreground colour and font size. -->
<style>
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}  
</style>

### **2. Define Constants**

In [None]:
# Default number of passes over the training set used by `train`.
NUM_EPOCHS: int = 1
# Step size for the Adam optimizer.
LEARNING_RATE: float = 5e-3
# Batch size; not referenced by the code visible here (TODO confirm usage).
BATCH_SIZE: int = 32
# Multiplier on standard-normal noise added to inputs; 0.0 disables corruption.
NOISE_FACTOR: float = 0.0
# Number of shifted / noisy variants generated per image by `augment_data`.
NUM_PERMUTATIONS: int = 1

Move to GPU if available

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

### **3. Helper Functions**

In [None]:
def compare_plot(a1, a2, t1="noisy", t2="denoised", set_axis="off", video=False):
    """Draw two images stacked vertically for visual comparison.

    Args:
        a1, a2: array-likes to display; singleton dimensions are squeezed
            before `imshow` (origin at the bottom, nearest-neighbour, "magma").
        t1, t2: titles of the top and bottom panels.
        set_axis: value passed to `plt.axis` for both panels ("off" hides axes).
        video: if True, save the figure as a PNG frame in ``video/`` and close
            it; otherwise show it.
    """
    plt.figure(figsize=(10, 5))
    for row, (data, title) in enumerate(((a1, t1), (a2, t2)), start=1):
        plt.subplot(2, 1, row)
        plt.title(title)
        plt.imshow(
            np.squeeze(data),
            aspect="auto",
            origin="lower",
            cmap="magma",
            interpolation="nearest",
        )
        plt.axis(set_axis)

    if video:
        dirname = "video"
        os.makedirs(dirname, exist_ok=True)
        # Frames can be produced on every training step, i.e. several per
        # second, so the name includes microseconds. A per-second timestamp
        # would make consecutive frames overwrite each other. The zero-padded
        # format still sorts chronologically by file name.
        timestamp = datetime.now().strftime("%y-%m-%d_%H%M%S_%f")
        plt.savefig(
            f"{dirname}/plot_{timestamp}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close()
    else:
        plt.show()

In [None]:
def format_image(image, remove_time=False):
    """Convert an (H, W) image into a zero-padded (1, H + 2, W + 12) float32 tensor.

    Args:
        image: 2-D array-like. It goes through ``np.expand_dims``, so NumPy
            arrays and CPU tensors (as passed by ``augment_data``) are accepted.
        remove_time: if True, drop column 0 before converting.

    Returns:
        Float32 tensor with a leading channel axis. It is divided by its
        maximum only when some value exceeds 1.0, then padded with one zero row
        before and after along dim -2 and twelve zero columns at the end of
        dim -1.
    """
    if remove_time:
        image = np.delete(image, 0, axis=1)
    tensor = torch.from_numpy(np.expand_dims(image, 0)).to(torch.float32)

    needs_rescale = (tensor > 1.0).any()
    if needs_rescale:
        tensor = tensor / tensor.max()

    # F.pad order for the last two dims: (left, right, top, bottom).
    return F.pad(tensor, (0, 12, 1, 1), mode="constant", value=0.0)

### **4. Load Data**

In [None]:
# Load every loop into a plain dict and close the archive right away.
# `np.load` on an .npz returns a lazy NpzFile. That object keeps the zip file
# open, and each `clean_images[key]` access re-reads and decompresses the array.
# The dict trades that for holding all arrays in memory up front; augmentation
# reads them all anyway.
with np.load("data/all_loops.npz") as archive:
    clean_images = {name: archive[name] for name in archive.files}

In [None]:
def shift_image_vertically(name, array, num_iterations):
    """Return up to `num_iterations` distinct, randomly chosen vertical shifts of `array`.

    Candidate shifts keep every non-zero row inside the array:
    * "u" shifts move content toward row 0 by 1..(index of the first non-zero row);
    * "d" shifts move it toward the last row by 1..(rows below the last non-zero row).
    Vacated rows are zero-filled.

    Candidates are sampled without replacement, in random order. Fewer than
    `num_iterations` are returned when there are not enough candidates, and
    none for an all-zero array.

    Args:
        name: source name; `Path(name).stem` prefixes every key.
        array: 2-D array shifted along axis 0. It is not modified.
        num_iterations: maximum number of shifted copies to return.

    Returns:
        List of (key, shifted_array) pairs with keys like "<stem>_u03" / "<stem>_d12".
    """

    def find_non_zero_bounds(arr):
        # First and last row index containing a non-zero element.
        rows_with_non_zero = np.where(arr.any(axis=1))[0]
        if len(rows_with_non_zero) == 0:
            return (0, arr.shape[0] - 1)  # no content -> no valid shift
        return rows_with_non_zero[0], rows_with_non_zero[-1]

    def shift_array(arr, up=0, down=0):
        # np.roll returns a new array, so zero-filling never touches the caller's array.
        if up > 0:
            arr = np.roll(arr, -up, axis=0)
            arr[-up:] = 0
        elif down > 0:
            arr = np.roll(arr, down, axis=0)
            arr[:down] = 0
        return arr

    highest, lowest = find_non_zero_bounds(array)
    maximum_up = highest
    maximum_down = array.shape[0] - lowest - 1
    stem = Path(name).stem

    # Enumerate each distinct shift exactly once, then sample without
    # replacement. This avoids duplicate shifts in the result.
    candidates = [("u", i) for i in range(maximum_up, 0, -1)]
    candidates += [("d", i) for i in range(maximum_down, 0, -1)]
    count = max(0, min(num_iterations, len(candidates)))
    chosen = random.sample(candidates, count)

    # Only the selected copies are materialised, not every possible shift.
    return [
        (
            f"{stem}_{direction}{amount:02d}",
            shift_array(array, up=amount)
            if direction == "u"
            else shift_array(array, down=amount),
        )
        for direction, amount in chosen
    ]

### **5. Prepare the Data**

In [None]:
def augment_data(
    clean_images, num_permutations=NUM_PERMUTATIONS, vshift=True, noise_factor=None
):
    """Build a shuffled training set of (optionally shifted) noise-corrupted images.

    For each clean image, column 0 (a time factor, not used here) is dropped.
    Then:
    * with `vshift`, up to `num_permutations` vertically shifted copies are made
      via `shift_image_vertically`; otherwise the image itself is used;
    * every resulting image is normalised by its maximum and corrupted
      `num_permutations` times with standard-normal noise scaled by
      `noise_factor`. Each corrupted copy is converted with `format_image`.

    Each clean image therefore yields up to num_permutations**2 noisy images
    with `vshift`, or num_permutations without.

    Args:
        clean_images: mapping of name -> 2-D array (e.g. the loaded .npz).
        num_permutations: shifts per image and noisy copies per shift.
        vshift: whether to generate vertically shifted variants.
        noise_factor: noise multiplier; defaults to the notebook-level
            `NOISE_FACTOR`, read at call time.

    Returns:
        (shifted_images, noisy_images). `shifted_images` holds one list of
        (key, array) pairs per clean image. `noisy_images` is a shuffled flat
        list of (key, tensor) pairs.
    """
    if noise_factor is None:
        noise_factor = NOISE_FACTOR
    shifted_images = []
    noisy_images = []

    for name, image in tqdm(
        list(clean_images.items()), unit="images", dynamic_ncols=True
    ):
        image = np.delete(image, 0, axis=1)  # drop the time-factor column
        if vshift:
            variants = shift_image_vertically(name, image, num_permutations)
        else:
            variants = [(name, image)]
        shifted_images.append(variants)

        for new_key, shifted_image in variants:
            # Normalise once per variant; it does not change between noisy
            # copies. An all-zero image stays zeros instead of becoming NaN (0/0).
            peak = np.max(shifted_image)
            if peak != 0:
                normalized = shifted_image / peak
            else:
                normalized = shifted_image.astype(np.float64)
            clean = torch.from_numpy(normalized)

            for _ in range(num_permutations):
                # corrupt, then pad / convert to float32
                noisy_image = clean + noise_factor * torch.randn(clean.shape)
                noisy_images.append((new_key, format_image(noisy_image)))

    random.shuffle(noisy_images)

    return shifted_images, noisy_images

#### Data augmentation
**WARNING: this keeps every augmented image in memory at once and may use a lot of RAM.**

In [None]:
shifted_images, training_data = augment_data(clean_images, vshift=True)

# Optional: cache the augmented set on disk. Tensors are converted to NumPy
# arrays so `np.savez_compressed` can store them; duplicate keys overwrite.
# output_file = f"augmented_data_{datetime.now().strftime('%y-%m-%d_%H%M%S')}"
# np.savez_compressed(
#     os.path.join("data", output_file),
#     **{name: arr.numpy() for name, arr in training_data},
# )
if training_data:
    print(
        f"used {len(clean_images)} clean images to generate {len(training_data)} noisy images of shape {training_data[0][1].size()}"
    )
else:
    print(f"used {len(clean_images)} clean images but generated no noisy images")

#### View Results

In [None]:
# Show the first two training samples next to their source loops. Keys have a
# trailing "_<shift>" suffix; cutting at the last "_" and appending ".mid"
# recovers the key in `clean_images`.
for key, noisy in training_data[:2]:
    source_key = key[: key.rfind("_")] + ".mid"
    compare_plot(clean_images[source_key], noisy, key, f"noise ({NOISE_FACTOR})")

#### Batch Data

In [None]:
class MIDILoopsDataset(Dataset):
    """Map-style dataset over piano-roll images.

    Each entry of ``image_paths`` may be:
    * a ``torch.Tensor`` or ``np.ndarray`` holding the image itself;
    * a ``(name, image)`` pair, as produced by ``augment_data``;
    * a path (``str`` / ``os.PathLike``) to a ``.npy`` file, loaded with ``np.load``.

    Items are returned as tensors, passed through ``transforms`` when given.
    """

    def __init__(self, image_paths, transforms=None):
        self.image_paths = image_paths
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _load(item):
        """Resolve one entry of ``image_paths`` to a tensor."""
        if isinstance(item, tuple) and len(item) == 2:
            item = item[1]  # (name, image) pair -> image
        if isinstance(item, (str, os.PathLike)):
            item = np.load(item)
        return torch.as_tensor(item)

    def __getitem__(self, idx):
        image = self._load(self.image_paths[idx])
        if self.transforms:
            image = self.transforms(image)
        return image

### **6. Define the AutoEncoder**

In [None]:
def initialize_weights(model, a=0.0, b=1.0):
    """Re-draw every Conv2d / ConvTranspose2d weight and bias from U(a, b).

    Layers of other types are left untouched, as are absent biases (bias=False).
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for layer in filter(lambda m: isinstance(m, conv_types), model.modules()):
        nn.init.uniform_(layer.weight, a, b)
        if layer.bias is not None:
            nn.init.uniform_(layer.bias, a, b)

#### Bad AutoEncoder

In [None]:
class BadAutoEncoder(nn.Module):
    """Plain convolutional autoencoder with four conv+pool stages down and four transposed convs up.

    For a (1, 60, 412) input, the spatial size shrinks to 3x25 with 8 channels
    and is then restored to (1, 60, 412). The output has no final activation.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3x3 "same" convolutions; the shared 2x2 max-pool halves
        # each spatial dim (flooring odd sizes).
        self.enc1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.enc3 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
        self.enc4 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: stride-2 transposed convolutions. The 3x3 kernels in the
        # first two layers make up for the sizes lost to flooring while pooling.
        self.dec1 = nn.ConvTranspose2d(in_channels=8, out_channels=8, kernel_size=3, stride=2)
        self.dec2 = nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3, stride=2)
        self.dec3 = nn.ConvTranspose2d(in_channels=16, out_channels=32, kernel_size=2, stride=2)
        self.dec4 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=2, stride=2)
        self.out = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """Encode then decode `x`; returns the unbounded reconstruction."""
        for conv in (self.enc1, self.enc2, self.enc3, self.enc4):
            x = self.pool(F.silu(conv(x)))
        # `x` now holds the latent representation.
        for deconv in (self.dec1, self.dec2, self.dec3, self.dec4):
            x = F.silu(deconv(x))
        return self.out(x)


model = BadAutoEncoder()
model = model.to(device)
# Small positive uniform init for every (transposed) convolution layer.
initialize_weights(model, a=0.01, b=0.1)

#### Deep AutoEncoder

From the [Lightning tutorial](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/course_UvA-DL/08-deep-autoencoders.ipynb#scrollTo=cf346a22)

In [None]:
class Encoder(nn.Module):
    def __init__(
        self,
        num_input_channels: int,
        base_channel_size: int,
        latent_dim: int,
        act_fn=nn.GELU,
    ):
        """Convolutional encoder mapping an image to a latent vector.

        Three stride-2 convolutions reduce the spatial size by 8x. The final
        Linear layer expects a 4x4 grid, e.g. from a 32x32 input.

        Args:
           num_input_channels : Channels of the input image (e.g. 3 for RGB).
           base_channel_size : Channels of the first convolutions; deeper layers use twice as many.
           latent_dim : Size of the latent vector z.
           act_fn : Activation class, instantiated after every convolution.
        """
        super().__init__()
        hidden = base_channel_size
        wide = 2 * hidden
        layers = [
            # 32x32 -> 16x16
            nn.Conv2d(num_input_channels, hidden, kernel_size=3, padding=1, stride=2),
            act_fn(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            act_fn(),
            # 16x16 -> 8x8
            nn.Conv2d(hidden, wide, kernel_size=3, padding=1, stride=2),
            act_fn(),
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            act_fn(),
            # 8x8 -> 4x4
            nn.Conv2d(wide, wide, kernel_size=3, padding=1, stride=2),
            act_fn(),
            # flatten the 4x4 grid into one feature vector
            nn.Flatten(),
            nn.Linear(wide * 4 * 4, latent_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a batch of images into latent vectors."""
        return self.net(x)


class Decoder(nn.Module):
    def __init__(
        self,
        num_input_channels: int,
        base_channel_size: int,
        latent_dim: int,
        act_fn=nn.GELU,
    ):
        """Convolutional decoder mapping a latent vector back to an image.

        A Linear layer produces a 4x4 grid. Three stride-2 transposed
        convolutions upsample it to 32x32, and Tanh bounds the output to [-1, 1].

        Args:
           num_input_channels : Channels of the reconstructed image (e.g. 3 for RGB).
           base_channel_size : Channels of the last convolutions; earlier layers use twice as many.
           latent_dim : Size of the latent vector z.
           act_fn : Activation class, instantiated after the linear layer and each hidden convolution.
        """
        super().__init__()
        hidden = base_channel_size
        wide = 2 * hidden
        self.linear = nn.Sequential(nn.Linear(latent_dim, wide * 4 * 4), act_fn())
        # Shared settings of every upsampling layer: exactly doubles H and W.
        upsample = dict(kernel_size=3, output_padding=1, padding=1, stride=2)
        self.net = nn.Sequential(
            nn.ConvTranspose2d(wide, wide, **upsample),  # 4x4 -> 8x8
            act_fn(),
            nn.Conv2d(wide, wide, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(wide, hidden, **upsample),  # 8x8 -> 16x16
            act_fn(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(hidden, num_input_channels, **upsample),  # 16x16 -> 32x32
            nn.Tanh(),  # output bounded to [-1, 1]
        )

    def forward(self, x):
        """Decode a batch of latent vectors into images."""
        features = self.linear(x)
        # Unflatten into a 4x4 grid; the channel count is inferred.
        grid = features.reshape(features.shape[0], -1, 4, 4)
        return self.net(grid)


class Autoencoder(L.LightningModule):
    """Lightning module combining an encoder and a decoder for reconstruction.

    The loss is the batch mean of each image's summed squared error.
    Optimisation uses Adam with ReduceLROnPlateau on "val_loss".
    """

    def __init__(
        self,
        base_channel_size: int,
        latent_dim: int,
        encoder_class = Encoder,
        decoder_class = Decoder,
        num_input_channels: int = 3,
        width: int = 32,
        height: int = 32,
    ):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network.
        # Conv layers expect (N, C, H, W), so height precedes width.
        self.example_input_array = torch.zeros(2, num_input_channels, height, width)

    def forward(self, x):
        """The forward function takes in an image and returns the reconstructed image."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """Return the batch mean of the per-image summed squared error.

        `batch` may be an ``(images, labels)`` pair / list, with labels ignored,
        or a bare image tensor, e.g. from a dataset without labels.
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.2, patience=20, min_lr=5e-5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss",
        }

    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("val_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)

In [None]:
# Print the module tree (layer types and their constructor arguments) of `model`.
print(model)

In [None]:
# Summarise with the input size the network actually receives during training.
# `format_image` adds 1 row before and after and 12 columns at the end, so a
# 58x400 loop (assumed size, TODO confirm) is fed as a (1, 60, 412) tensor.
summary(model, (1, 60, 412))

### **7. Optimizer and Loss Function**

In [None]:
# the loss function
loss_fn = nn.MSELoss()
# the optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

### **8. Train**


#### Function

In [None]:
def train(net: nn.Module, training_data, epochs=NUM_EPOCHS, plot=False, video=False, tb=False):
    """Train `net` to reconstruct its inputs, one image per optimisation step.

    Uses the notebook-level `optimizer`, `loss_fn` and `device`. `optimizer`
    must have been built over `net`'s parameters.

    Args:
        net: model trained in place.
        training_data: sized sequence of (name, image_tensor) pairs; each image
            is fed to `net` as-is (no batching).
        epochs: number of passes over `training_data`.
        plot: show an input/output comparison after every step.
        video: save that comparison as a PNG frame instead of showing it.
        tb: log loss plus weight/gradient histograms to TensorBoard.

    Returns:
        List with the mean step loss of each epoch (NaN for empty data).
    """
    train_loss = []
    num_steps = len(training_data)
    writer = (
        SummaryWriter(f"runs/{datetime.now().strftime('%y-%m-%d_%H%M%S')}")
        if tb
        else None
    )
    net.train()
    try:
        for epoch in trange(epochs, desc="Total "):
            running_loss = 0.0
            with tqdm(training_data, unit="images", dynamic_ncols=False) as tepoch:
                tepoch.set_description(f"Epoch {epoch + 1}")
                for i, (name, input_image) in enumerate(tepoch):
                    # train
                    input_image = input_image.to(device)  # e.g. (1, 60, 412)
                    optimizer.zero_grad()
                    predicted_image = net(input_image)  # compute prediction
                    loss = loss_fn(predicted_image, input_image)  # compute loss
                    loss.backward()  # backprop
                    optimizer.step()  # update parameters
                    loss_value = loss.item()

                    if plot or video:
                        compare_plot(
                            input_image.detach().cpu(),
                            predicted_image.detach().cpu(),
                            name,
                            f"output (loss={loss_value})",
                            video=video,
                        )

                    # tensorboard logging of the network being trained
                    if writer is not None:
                        global_step = epoch * num_steps + i
                        writer.add_scalar("training/loss", loss_value, global_step)
                        for p_name, param in net.named_parameters():
                            writer.add_histogram(
                                f"weights/{p_name}", param.data, global_step
                            )
                            # Frozen or unused parameters have no gradient.
                            if param.grad is not None:
                                writer.add_histogram(
                                    f"gradients/{p_name}.grad", param.grad, global_step
                                )

                    running_loss += loss_value
                    tepoch.set_postfix(loss=f"{loss_value:.02f}")

            epoch_loss = running_loss / num_steps if num_steps else float("nan")
            train_loss.append(epoch_loss)
    finally:
        # Flush TensorBoard events even if training is interrupted.
        if writer is not None:
            writer.close()

    return train_loss

#### Run

In [None]:
# Single-epoch run on the augmented data; returns the per-epoch mean loss.
train_loss = train(net=model, training_data=training_data, epochs=1)

#### Plot Loss

In [None]:
# One point per epoch, numbered from 1. Markers keep a single-epoch run
# visible; a one-point line draws nothing.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_loss) + 1), train_loss, marker="o")
ax.set_title("Train Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
plt.show()

#### Test Reconstruction

In [None]:
# Reconstruct the first training sample, with extra noise when NOISE_FACTOR > 0.
# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    img_noisy = training_data[0][1] + NOISE_FACTOR * torch.randn(training_data[0][1].shape)
    img_noisy = img_noisy.to(device)
    outputs = model(img_noisy)

compare_plot(img_noisy.cpu(), outputs.cpu(), training_data[0][0])

### Overfit Test

In [None]:
# Format every clean loop (dropping its first column) and add a batch axis:
# each entry becomes (name, float32 tensor of shape (1, 1, H + 2, W + 12)).
si_tensor = [(k, format_image(v, True).unsqueeze(0)) for k, v in clean_images.items()]
random.shuffle(si_tensor)
# Overfit sanity check: 1000 repetitions of one randomly chosen loop.
overfit_set = [si_tensor[0]] * 1000

initialize_weights(model, 0.01, 0.1)
# Fresh optimizer: Adam's moment estimates from the previous run belong to the
# old weights and would distort the first steps after re-initialisation.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
train(model, overfit_set, epochs=3, tb=True)

In [None]:
# Reconstruct the memorised loop. Inference only, so no autograd graph.
with torch.no_grad():
    img_noisy = overfit_set[0][1] + NOISE_FACTOR * torch.randn(overfit_set[0][1].shape)
    img_noisy = img_noisy.to(device)
    outputs = model(img_noisy)

compare_plot(img_noisy.cpu(), outputs.cpu(), overfit_set[0][0])

In [None]:
import cv2
import shutil

image_folder = "video"
video_name = "ugh.mp4"
fps = 20


def _frames_to_video(folder, output_path, frames_per_second):
    """Encode the ``.png`` files in *folder*, sorted by name, into an MP4.

    Returns the number of frames written; 0 means nothing was encoded. That
    happens for a missing or empty folder, an unreadable first frame, or a
    writer that failed to open.
    """
    if not os.path.isdir(folder):
        return 0
    names = sorted(f for f in os.listdir(folder) if f.endswith(".png"))
    if not names:
        return 0

    # The first frame fixes the video size.
    first = cv2.imread(os.path.join(folder, names[0]))
    if first is None:
        return 0
    height, width = first.shape[:2]

    writer = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), frames_per_second, (width, height)
    )
    if not writer.isOpened():
        return 0

    written = 0
    try:
        for name in names:
            frame = cv2.imread(os.path.join(folder, name))
            if frame is None:
                continue  # skip unreadable files instead of crashing the writer
            # A frame whose size differs from the writer's may be silently
            # dropped, so bring every frame to the first frame's size.
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            writer.write(frame)
            written += 1
    finally:
        writer.release()
    return written


frames_written = _frames_to_video(image_folder, video_name, fps)
print(f"wrote {frames_written} frames to {video_name}")
# Only discard the frames once they have actually been encoded.
if frames_written:
    shutil.rmtree(image_folder)