# LSTM Model Training

## Imports

In [None]:
from pathlib import Path
import warnings

warnings.filterwarnings('ignore', category=UserWarning, module='pygame')

from funcs import *
from model_classes import *

### Defining Constants

In [None]:
# File Paths
# Assumes the notebook is run from the project root.
PROJECT_PATH = Path.cwd()
# Must contain 'train', 'test', and 'valid' subfolders.
# The explicit '/' after PROJECT_PATH matters: Path.cwd() has no trailing
# separator, so without it the path would collapse to '<root>music/...'.
DATA_PATH = f"{PROJECT_PATH}/music/cleaned_data/"

# Other variables
# Pitch classes ordered by successive perfect fifths (+7 semitones, mod 12).
CIRCLE_OF_FIFTHS = [0, 7, 2, 9, 4, 11, 6, 1, 8, 3, 10, 5]

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Embedding MIDI Files

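This method builds on the work of Ching-Hua Chuan and Dorien Herremans, "Modeling Temporal Tonal Relations in Polyphonic Music Through Deep Networks With a Novel Image-Based Representation" (AAAI 2018).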

They proposed encoding polyphonic music geometrically, so that a deep learning model can capture both the harmonic (polyphonic) and temporal structure of an arrangement. Their representation is built on the Tonnetz, a lattice diagram that arranges tones in two-dimensional space so that harmonically related tones sit next to each other.

---

I used Torchvision's library and modules for working with image data in deep learning, and the pretty_midi and music21 libraries to parse MIDI files.

The model combines a convolutional neural network (CNN) autoencoder, which learns a compact encoding of each Tonnetz frame, with an LSTM that predicts the next frame in the sequence.

---

Features:
1. Converts MIDI files into Tonnetz images
2. Trains a CNN Autoencoder to learn tonal/chord features
3. Trains an LSTM sequence predictor to predict the next frame in the sequence

### Requirements

```
pip install torch torchvision pretty_midi numpy music21 tqdm pandas matplotlib
```

## CNN & LSTM Model Creation

### Data Loading: Parsing MIDI Files into Tonnetz Sequences

In [None]:
# 1) Parse MIDI files into Tonnetz sequences.
#    Uses the pre-partitioned subfolders DATA_PATH/{train,valid,test}.
#    A missing or empty partition yields an empty list. Files that fail to
#    parse, or whose sequences are too short, are reported and skipped
#    (best effort) rather than aborting the whole load.
partitions = ["train", "valid", "test"]
partition_sequences = {}

# A sequence is kept only if it has MORE than this many frames (axis 0).
# NOTE(review): the training windows below use seq_len=16; whether sequences
# of 9-16 frames yield any windows depends on TonnetzSequenceDataset -- confirm.
MIN_SEQUENCE_FRAMES = 8

print("\n=== DEBUG: STARTING DATA LOAD ===\n")
print(f"Base midi_folder = {DATA_PATH}")
print(f"Expected subfolders = {partitions}\n")

for part in partitions:

    folder = os.path.join(DATA_PATH, part)
    print(f"\n--- Checking partition: {part} ---")
    print(f"Looking for folder: {folder}")

    if not os.path.isdir(folder):
        print(f"WARNING: Folder does NOT exist: {folder}")
        partition_sequences[part] = []
        continue

    # List MIDI files. glob returns them in arbitrary filesystem order, so sort
    # to keep the dataset order reproducible across runs and machines.
    midi_paths = sorted(
        glob.glob(os.path.join(folder, "*.mid")) +
        glob.glob(os.path.join(folder, "*.midi"))
    )

    print(f"Found {len(midi_paths)} MIDI files in {folder}")

    if not midi_paths:
        print("WARNING: Folder contains NO MIDI FILES.")
        partition_sequences[part] = []
        continue

    sequences = []
    print(f"Converting MIDI -> tonnetz sequences [{part}]")

    for idx, p in enumerate(midi_paths, start=1):
        print(f"  [{idx}/{len(midi_paths)}] Processing: {p}")
        try:
            seq = midi_to_tonnetz_sequence(
                p,
                rows=24,
                cols=12,
                quantize_beat=1.0
            )
            print(f"      ✓ Loaded. Shape = {seq.shape}")

            if seq.shape[0] > MIN_SEQUENCE_FRAMES:
                sequences.append(seq)
                print("      ✓ Added to dataset.")
            else:
                print(f"\033[91m      ✗ Skipped: sequence too short (<= {MIN_SEQUENCE_FRAMES} frames).\033[0m")

        except Exception as e:
            # Best effort: one unreadable file should not abort the whole load.
            # The exception type is printed so failures can be told apart.
            print("\033[91m      ✗ ERROR processing file:\033[0m")
            print(f"\033[91m        {type(e).__name__}: {e}\033[0m")

    partition_sequences[part] = sequences
    print(f"Finished partition '{part}'. Sequences loaded: {len(sequences)}")

# Assign explicitly to train/valid/test variables
train_seqs = partition_sequences["train"]
val_seqs   = partition_sequences["valid"]
test_seqs  = partition_sequences["test"]

print("\n=== FINAL COUNTS ===")
print(f"Train sequences: {len(train_seqs)}")
print(f"Valid sequences: {len(val_seqs)}")
print(f"Test sequences:  {len(test_seqs)}")
print("======================\n")

## Training

### Autoencoder

In [None]:
# 2) Create the windowed datasets and their dataloaders.
seq_len = 16
batch_size = 8
train_ds = TonnetzSequenceDataset(train_seqs, seq_len=seq_len)
val_ds = TonnetzSequenceDataset(val_seqs, seq_len=seq_len)

# Fail fast with a readable message. A shuffled DataLoader over an empty
# dataset raises an opaque sampler error, and with drop_last=True a dataset
# smaller than one batch would silently produce zero training batches.
if len(train_ds) < batch_size:
    raise ValueError(
        f"Training dataset has {len(train_ds)} samples, fewer than one batch "
        f"of {batch_size}; check DATA_PATH and the partition load above."
    )
if len(val_ds) == 0:
    print("WARNING: validation dataset is empty; validation metrics may be undefined.")

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

#-- Count exact 0 and 1 values across the training inputs to gauge how sparse
#   the frames are. Values other than exactly 0 or 1 are not counted, and with
#   drop_last=True any samples in a final partial batch are excluded.
count_0 = 0
count_1 = 0
for X, _ in train_loader:
    # Comparisons are elementwise over the whole batch; no flatten needed.
    count_0 += (X == 0).sum().item()
    count_1 += (X == 1).sum().item()

print("Total count of 0 in train_loader:", count_0)
print("Total count of 1 in train_loader:", count_1)

In [None]:
# 3) Build the convolutional autoencoder.
#    Configured for single-channel 24x12 frames (matching the rows/cols used
#    when parsing the MIDI files) with a 128-dimensional latent size.
autoenc = ConvAutoencoder(
    in_channels=1,
    feat_maps=(20, 10),
    rows=24,
    cols=12,
    latent_dim=128,
)

In [None]:
# 4) Pretrain the autoencoder on the batches produced by train_loader.
#    The returned model replaces the freshly initialised one.
print("Pretraining autoencoder")
autoenc = pretrain_autoencoder(
    autoenc,
    train_loader,
    device=device,
    criterion="l1",
    epochs=8,
    lr=1e-2,
    verbose=True,
    debug=False,
)

In [None]:
# Save the autoencoder weights (state_dict only, not the full module).
# torch.save does not create missing parent directories, so make sure
# <project>/models exists first; this is a no-op when it already does.
models_dir = Path(PROJECT_PATH) / "models"
models_dir.mkdir(parents=True, exist_ok=True)
torch.save(autoenc.state_dict(), models_dir / "autoencoder_final.pt")

### Sequence Predictor

In [None]:
# 5) Build the LSTM sequence predictor.
#    latent_dim=128 matches the autoencoder's latent size, and out_size is
#    one 1x24x12 frame, matching the autoencoder's input layout.
seq_model = SequencePredictor(
    latent_dim=128,
    hidden_dim=256,
    num_layers=2,
    out_size=(1, 24, 12),
)

In [None]:
# 6) Train the sequence model on top of the pretrained autoencoder.
#    freeze_encoder=True asks train_sequence_model to keep the encoder fixed;
#    patience/min_delta configure its early stopping (see that helper).
print("Training sequence model")
seq_model, train_hist, valid_hist = train_sequence_model(
    autoenc,
    seq_model,
    train_loader,
    val_loader,
    device=device,
    epochs=25,
    lr=2e-4,
    freeze_encoder=True,
    verbose=True,
    patience=5,
    min_delta=1e-4,
)

# Keep the per-epoch losses as a two-column table for plotting/inspection.
df_history = pd.DataFrame(dict(train_loss=train_hist, val_loss=valid_hist))

In [None]:
# Save the sequence model weights (state_dict only, not the full module).
# torch.save does not create missing parent directories, so make sure
# <project>/models exists first; this is a no-op when it already does.
models_dir = Path(PROJECT_PATH) / "models"
models_dir.mkdir(parents=True, exist_ok=True)
torch.save(seq_model.state_dict(), models_dir / "sequence_model_final.pt")

#### Plotting History

In [None]:
def plot_training_history(train_hist, val_hist, title="Training History"):
    """Plot per-epoch training and validation loss curves on one figure.

    Epochs are numbered from 1 on the x-axis. The two histories may have
    different lengths; each curve is plotted against its own epoch range.

    Args:
        train_hist: Sequence of training losses, one entry per epoch.
        val_hist: Sequence of validation losses, one entry per epoch.
        title: Figure title.
    """
    train_epochs = np.arange(1, len(train_hist) + 1)
    val_epochs = np.arange(1, len(val_hist) + 1)

    plt.figure(figsize=(10, 5))
    plt.plot(train_epochs, train_hist, label="Training Loss", color="green", marker="o")
    plt.plot(val_epochs, val_hist, label="Validation Loss", color="deeppink", marker="o")

    # One tick per epoch. Epochs start at 1, so there is no epoch-0 tick.
    n_epochs = max(len(train_hist), len(val_hist))
    plt.xticks(np.arange(1, n_epochs + 1))

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Pull the loss columns back out of the history table as plain lists and plot.
train_hist, val_hist = (df_history[column].to_list() for column in ("train_loss", "val_loss"))

plot_training_history(
    train_hist,
    val_hist,
    title="Training History - Music Prediction LSTM",
)