This notebook was used to retrain the discriminator on Google Colab.

# Environment setup

In [None]:
!git clone https://github.com/darinchau/project-remucs.git

In [None]:
cd project-remucs

In [None]:
!git submodule update --init --recursive

In [None]:
# Mount Google Drive at /content/drive. Other cells in this notebook read the
# backup dataset, the key file and the VQVAE checkpoint from /content/drive/MyDrive.
# force_remount=True asks Colab to mount again even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
!pip install -r requirements.txt

# Discriminator retraining

In [None]:
# Identifies which VQVAE checkpoint to load: the value is interpolated into the
# checkpoint filename (vqvae_{vqvae_epoch}_vqvae_autoencoder_ckpt.pth) in a later cell.
# NOTE(review): named "epoch", but a value this large may be a step count -- confirm.
vqvae_epoch = 247808

In [None]:
cd project-remucs

In [None]:
mkdir -p ./resources/key && [ ! -f ./resources/key/key.json ] && cp /content/drive/MyDrive/dataset/key.json ./resources/key/key.json

In [None]:
# Pre-flight checks: stop here if the working directory, the key file or the
# VQVAE checkpoint is not where the rest of the notebook expects it.
import os
# The relative ./resources/... paths used in this notebook only resolve from the repo root.
assert os.getcwd() == "/content/project-remucs"

# Upload your key now if it is not already at this path.
assert os.path.isfile("./resources/key/key.json")

# Upload your VQVAE checkpoint to Google Drive now if it is not already there.
# The filename is derived from vqvae_epoch, which is set in an earlier cell.
vqvae_path = f"/content/drive/MyDrive/vqvae_{vqvae_epoch}_vqvae_autoencoder_ckpt.pth"
assert os.path.isfile(vqvae_path)

In [None]:
# Log in to Weights & Biases. The API key is read from the Colab user secret
# named 'WANDB', so it never has to appear in the notebook itself.
from google.colab import userdata
import wandb
wandb.login(key=userdata.get('WANDB'))

In [None]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

TARGET_FEATURES = 512

class SpectrogramPatchModel(nn.Module):
    """PatchGAN-style model that scores every bar of a spectrogram separately.

    The input is cut along dim 2 (the T axis) into consecutive patches of
    ``PATCH_FRAMES`` (128) frames. Each patch keeps all 4 input channels, i.e.
    has shape (4, 128, F), and is run through the same small CNN, which ends
    in 4 output values per patch.

    Shapes:
        input:  (B, 4, T, F) with T a multiple of 128, typically (B, 4, 512, 512)
        output: (B, T // 128, 4), i.e. (B, 4, 4) for the typical input

    Args:
        target_features: Stored on the instance as ``self.target_features``;
            nothing in this class reads it.
    """

    # Number of frames along the T axis that make up one patch (one bar).
    PATCH_FRAMES = 128

    def __init__(self, target_features: int = TARGET_FEATURES):
        super().__init__()
        # One CNN shared by all patches. N below is batch * number of patches.
        # The adaptive pools pin the spatial size after every stage, so the
        # final (8, 32) kernel always matches regardless of the input width F.
        self.conv1 = nn.Conv2d(4, 16, kernel_size=3, padding=1)    # -> (N, 16, 128, F)
        self.pool11 = nn.AdaptiveMaxPool2d((128, 256))             # -> (N, 16, 128, 256)
        self.pool12 = nn.AdaptiveAvgPool2d((64, 256))              # -> (N, 16, 64, 256)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)   # -> (N, 32, 64, 256)
        self.pool21 = nn.AdaptiveMaxPool2d((64, 128))              # -> (N, 32, 64, 128)
        self.pool22 = nn.AdaptiveAvgPool2d((32, 128))              # -> (N, 32, 32, 128)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)   # -> (N, 64, 32, 128)
        self.pool31 = nn.AdaptiveMaxPool2d((32, 32))               # -> (N, 64, 32, 32)
        self.pool32 = nn.AdaptiveAvgPool2d((8, 32))                # -> (N, 64, 8, 32)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # -> (N, 128, 8, 32)
        # The kernel spans the whole 8x32 feature map, so this behaves like a
        # fully connected layer with 4 outputs.                      -> (N, 4, 1, 1)
        self.fc = nn.Conv2d(128, 4, (8, 32))
        self.target_features = target_features

    def forward(self, x: Tensor) -> Tensor:
        """Score each 128-frame patch of ``x``.

        Args:
            x: Tensor of shape (B, 4, T, F) with T divisible by 128.

        Returns:
            Tensor of shape (B, T // 128, 4): 4 values for every patch.

        Raises:
            RuntimeError: Raised by ``unflatten`` when T is not a multiple of 128.
        """
        batch_size = x.size(0)

        # (B, C, T, F) -> (B, C, P, 128, F): split the T axis into P patches.
        patches = x.unflatten(2, (-1, self.PATCH_FRAMES))
        # Move the patch axis next to batch: (B, P, C, 128, F). Without this
        # swap the reshape below would fold the *channel* axis into the batch
        # and hand the CNN the P patches stacked as if they were channels.
        patches = patches.transpose(1, 2)
        num_patches = patches.size(1)

        # Fold patches into the batch so every (C, 128, F) patch is one sample.
        patches = patches.reshape(batch_size * num_patches, *patches.shape[2:])

        # Apply the shared CNN.
        x = self.conv1(patches)
        x = F.relu(x)
        x = self.pool11(x)
        x = self.pool12(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool21(x)
        x = self.pool22(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.pool31(x)
        x = self.pool32(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.fc(x)  # (B * P, 4, 1, 1)

        # Un-fold the batch: one row of 4 scores per patch.
        return x.view(batch_size, num_patches, -1)

# Smoke test: push a random batch of 3 inputs through a freshly constructed model.
# The bare last expression makes the notebook display the output shape
# (the class docstring gives (B, 4, 4) for this input size).
x = torch.randn((3, 4, 512, 512))
model = SpectrogramPatchModel()
model(x).shape

In [None]:
# Training entry point; presumably provided by the cloned project-remucs repository.
from scripts.retrain_discriminator import train

# Start discriminator retraining.
# NOTE(review): the per-argument notes below are inferred from the argument names
# and values only -- confirm against scripts/retrain_discriminator.py.
train(
    discriminator = SpectrogramPatchModel(),  # freshly constructed model from the cell above
    vae_ckpt_path=vqvae_path,  # VQVAE checkpoint on Drive, checked to exist in an earlier cell
    vae_config_path="./resources/config/vqvae.yaml",  # relative to the repo root (the asserted cwd)
    local_dataset_dir="/content/drive/MyDrive/dataset/test_specs",  # dataset directory on Drive
    base_dir = "./resources/models",  # presumably where training outputs are written -- TODO confirm
    start_from_iter=0,  # presumably 0 means start from scratch rather than resume -- TODO confirm
    dataset_params={
        # Lookup tables for the train and validation splits, stored on Drive.
        "train_lookup_table_path": "/content/drive/MyDrive/dataset/lookup_table_train.json",
        "val_lookup_table_path": "/content/drive/MyDrive/dataset/lookup_table_val.json",
    },
    train_params={
        # Key name is taken as-is from the training script's config -- TODO confirm what it controls here.
        "autoencoder_batch_size": 16
    }
)