In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('./src')

import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader

from chatgpt_util_imports import (
    generate_dataset,
    generate_sysex_with_corrected_checksum,
    pack_patch,
    write_sysex_to_file,
)
from data_loader import PatchDataset
from models import VAE
from train import load_checkpoint, print_loss_metrics, save_checkpoint, train_one_epoch, validate
from util import get_date_and_time

In [2]:
# Run on Apple's MPS backend when PyTorch reports it as available; fall back to CPU otherwise.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [3]:
# Load the ~31k-patch corpus compiled by `bwhitman` for `learnfm` into a DataFrame
# (one row per patch, one column per feature).
fpath = 'data/compact.bin'
df = pd.DataFrame(generate_dataset(fpath))

print(f"Num patches: {df.shape[0]}")
print(f"Num features: {df.shape[1]}")

Num patches: 31380
Num features: 146


In [4]:
# Split the patch dataset into train/validation subsets and wrap each in a DataLoader.
p_train = 0.8
batch_size = 32
# Fixed seed so the train/validation split is identical on every Restart & Run All.
# Without it each run draws a new split, validation numbers are not comparable across
# runs, and a reloaded checkpoint may be evaluated on samples it was trained on.
# NOTE(review): checkpoints saved before this seed existed used an unseeded split,
# so their exact train/validation membership cannot be reproduced.
split_seed = 42

dataset = PatchDataset(df)

n_train = int(p_train * len(dataset))
n_val = len(dataset) - n_train  # remainder, so n_train + n_val == len(dataset)

train_dataset, val_dataset = random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [5]:
# Build the VAE and its optimizer.
# The input width is the number of patch parameters the dataset exposes
# (read from a private attribute of PatchDataset).
n_features = len(dataset._parameter_names)
n_latent = 16
n_hidden = 64

# VAE(n_features, n_hidden, n_latent): presumably hidden-layer and latent widths;
# see models.VAE.
model = VAE(n_features, n_hidden, n_latent).to(device)

# Adam over all model parameters with lr=1e-6.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

In [6]:
# Training loop. Set n_epochs > 0 to train; with n_epochs = 0 the loop body never
# runs and this cell only defines the configuration below.
n_epochs = 0
checkpoint_dir = 'checkpoints/'
save_interval_epochs = 10

# Lowest validation loss seen so far on a checkpoint-eligible epoch.
best_loss = np.inf
for i in range(n_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)

    print(f"Epoch {i}")
    print_loss_metrics(train_loss, val_loss)

    # Only epochs 0, save_interval_epochs, 2 * save_interval_epochs, ... may be saved.
    if i % save_interval_epochs != 0:
        continue

    # Save only when validation loss beats the best checkpointed loss.
    if val_loss['loss'] < best_loss:
        # Ensure the target directory exists so the first save on a fresh
        # checkout does not fail on a missing folder.
        os.makedirs(checkpoint_dir, exist_ok=True)

        fname = f"chkpnt_{get_date_and_time()}_epoch{i}.pt"
        # Distinct name so the data-file path (`fpath`) is not silently overwritten.
        ckpt_fpath = os.path.join(checkpoint_dir, fname)

        save_checkpoint(model, optimizer, ckpt_fpath)

        print("=" * 80)
        print(f"Saved checkpoint at: {ckpt_fpath}")
        print("=" * 80)

        best_loss = val_loss['loss']

In [7]:
# Load saved weights into `model`/`optimizer`, then reconstruct one patch with the VAE
# so the original (`x`) and reconstruction (`y`) can be compared.
fpath = 'checkpoints/chkpnt_20231001_235220_epoch990.pt'
load_checkpoint(fpath, model, optimizer)

# Callable mapping model-space tensors back to patch-parameter values;
# presumably inverts PatchDataset's preprocessing. TODO confirm in data_loader.
restore = dataset.get_restorer()

model.eval()
with torch.no_grad():
    # First sample yielded by the full dataset (not restricted to the validation split).
    x = next(iter(dataset)).to(device)
    y, mu, log_var = model(x)  # mu / log_var are unused here

    # Move both back to CPU before restoring; x is restored first, then y.
    x, y = restore(x.to('cpu')), restore(y.to('cpu'))

In [8]:
# Export a 32-patch SysEx bank: the first half repeats the original patch `x`,
# the second half repeats its VAE reconstruction `y`, so the two can be compared.
# (pack_patch / generate_sysex_with_corrected_checksum / write_sysex_to_file are
# imported in the import cell at the top of the notebook.)
n_copies = 16

# Pack each patch once and repeat the bytes instead of re-packing 16 times.
# bytes(...) copies the packed data via the buffer protocol, matching how
# `bytearray +=` consumes it, so `* n_copies` repeats bytes rather than scaling
# values if pack_patch happens to return an array type.
packed_patches = bytearray(
    bytes(pack_patch(x)) * n_copies
    + bytes(pack_patch(y)) * n_copies
)

# Generate the SysEx message for the selected patches
sysex_data_32_patches = generate_sysex_with_corrected_checksum(packed_patches)

# Write the SysEx message to a file
write_sysex_to_file(sysex_data_32_patches, "test.syx")
