# LSTM Model Training

## Imports

In [None]:
from pathlib import Path
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='pygame')

from funcs import *

### Defining Constants

In [None]:
# File paths
PROJECT_PATH = Path.cwd()  # Assumes the notebook is run from the project root

# Folder holding the cleaned MIDI files; must have 'train', 'test', and 'valid'
# subfolders. Kept as a string with a trailing separator. The path is built with
# pathlib so that a separator is always placed between the project root and
# 'music' (plain string concatenation produced '<root>music/cleaned_data/').
DATA_PATH = f"{PROJECT_PATH / 'music' / 'cleaned_data'}/"

# Other variables
# Pitch classes 0-11 ordered so that each entry is the previous one plus 7 (mod 12).
CIRCLE_OF_FIFTHS = [0, 7, 2, 9, 4, 11, 6, 1, 8, 3, 10, 5]

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Embedding MIDI Files

This method builds on the work of Ching-Hua Chuan and Dorien Herremans.

They proposed encoding polyphonic music tracks geometrically, allowing a deep learning algorithm to capture the nuances of both polyphony and time in a musical arrangement. Their representation is based on the Tonnetz, a lattice diagram that relates tones in two-dimensional space.

---

I decided to make use of Torchvision's library and modules for encoding imagery for deep learning tasks. I also used the pretty_midi and music21 libraries to assist in deciphering MIDI files.

Finally, the encoding process will make use of a Convolutional Neural Network for auto-encoding, and an LSTM for sequence prediction.

---

Features:
1. Converts MIDI files into Tonnetz images
2. Trains a CNN Autoencoder to learn tonal/chord features
3. Trains an LSTM sequence predictor to predict the next frame in the sequence

### Requirements

```
pip install torch torchvision pretty_midi numpy music21 tqdm pandas matplotlib
```

## CNN & LSTM Model Creation

### Data Loader: Custom Dataset Class

In [None]:
class TonnetzSequenceDataset(Dataset):
    """
    Sliding-window dataset over precomputed tonnetz sequences.

    Every input sequence is a numpy array of shape (T, R, C). For each start
    index i with i + seq_len < T, one sample is produced:
      X: seq[i : i + seq_len]  -> (seq_len, R, C)
      y: seq[i + seq_len]      -> (R, C), the frame right after the window
    Sequences with T <= seq_len contribute no samples.
    """
    def __init__(self, sequences: List[np.ndarray], seq_len: int = 16):
        self.seq_len = seq_len
        self.samples = []
        for frames in sequences:
            n_windows = frames.shape[0] - seq_len
            # range() over a non-positive count is empty, so sequences that are
            # too short are skipped without a separate check.
            self.samples.extend(
                (frames[start:start + seq_len], frames[start + seq_len])
                for start in range(n_windows)
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        window, target = self.samples[idx]
        # Insert a singleton channel axis:
        #   window -> (seq_len, 1, H, W), target -> (1, H, W)
        window_t = torch.tensor(window).unsqueeze(dim=1)
        target_t = torch.tensor(target).unsqueeze(dim=0)
        return window_t, target_t

In [None]:
# 1) parse MIDIs to tonnetz sequences
#    -- each partition is read from its own pre-split subfolder of DATA_PATH
partitions = ["train", "valid", "test"]
partition_sequences = {}

print("\n=== DEBUG: STARTING DATA LOAD ===\n")
print(f"Base midi_folder = {DATA_PATH}")
print(f"Expected subfolders = {partitions}\n")

for part in partitions:

    folder = os.path.join(DATA_PATH, part)
    print(f"\n--- Checking partition: {part} ---")
    print(f"Looking for folder: {folder}")

    # A missing folder is not fatal: the partition is recorded as empty.
    if not os.path.isdir(folder):
        print(f"WARNING: Folder does NOT exist: {folder}")
        partition_sequences[part] = []
        continue

    # Collect files for both MIDI extensions ('.mid' first, then '.midi').
    midi_paths = []
    for pattern in ("*.mid", "*.midi"):
        midi_paths.extend(glob.glob(os.path.join(folder, pattern)))
    n_files = len(midi_paths)

    print(f"Found {n_files} MIDI files in {folder}")

    if n_files == 0:
        print("WARNING: Folder contains NO MIDI FILES.")
        partition_sequences[part] = []
        continue

    sequences = []
    print(f"Converting MIDI -> tonnetz sequences [{part}]")

    for file_no, midi_file in enumerate(midi_paths, start=1):
        print(f"  [{file_no}/{n_files}] Processing: {midi_file}")
        # Best-effort loading: a file that fails to convert is reported (in red)
        # and skipped so the remaining files are still processed.
        try:
            seq = midi_to_tonnetz_sequence(
                midi_file,
                rows=24,
                cols=12,
                quantize_beat=1.0
            )
            print(f"      ✓ Loaded. Shape = {seq.shape}")

            # Only sequences with more than 8 frames are kept.
            if seq.shape[0] <= 8:
                print(f"\033[91m      ✗ Skipped: sequence too short (< 8 frames).\033[0m")
            else:
                sequences.append(seq)
                print("      ✓ Added to dataset.")

        except Exception as e:
            print(f"\033[91m      ✗ ERROR processing file:\033[0m")
            print(f"\033[91m        {e}\033[0m")

    partition_sequences[part] = sequences
    print(f"Finished partition '{part}'. Sequences loaded: {len(sequences)}")

# Expose each partition under its own name for the cells that follow
train_seqs, val_seqs, test_seqs = (
    partition_sequences[name] for name in ("train", "valid", "test")
)

print("\n=== FINAL COUNTS ===")
print(f"Train sequences: {len(train_seqs)}")
print(f"Valid sequences: {len(val_seqs)}")
print(f"Test sequences:  {len(test_seqs)}")
print("======================\n")

### Models

#### Autoencoder (CNN)

In [None]:
class ConvAutoencoder(nn.Module):
    """
    Convolutional autoencoder for single tonnetz images.

    Encoder: conv(3x3) -> ReLU -> max-pool (2,2) -> conv(3x3) -> ReLU
             -> max-pool (2,1) -> flatten -> linear to `latent_dim`.
    Decoder: linear back to the encoder feature-map size -> transposed conv
             -> ReLU -> upsample (2,1) -> transposed conv -> upsample (2,2).

    The decoder returns raw, unbounded logits (no sigmoid and no ReLU on the
    last layer), so it can be paired with nn.BCEWithLogitsLoss.

    Args:
        in_channels: number of channels of the input image.
        feat_maps: output channels of the first and second encoder convolution.
        rows, cols: height and width of the input image. The reconstruction
            has the same height/width as the input only when `rows` is
            divisible by 4 and `cols` by 2, because the pooling layers round
            down while the upsampling layers multiply by exact factors.
        latent_dim: size of the latent vector produced by `encode`.
    """
    def __init__(self, in_channels=1, feat_maps=(20,10), rows=24, cols=12, latent_dim=128):
        super().__init__()
        # Encoder
        self.enc_conv1 = nn.Conv2d(in_channels, feat_maps[0], kernel_size=3, padding=1)  # preserve shape
        self.enc_pool1 = nn.MaxPool2d((2,2))   # halve height and width
        self.enc_conv2 = nn.Conv2d(feat_maps[0], feat_maps[1], kernel_size=3, padding=1)
        # second pooling (as paper: 2x1) halves the height only
        self.enc_pool2 = nn.MaxPool2d((2,1))
        # Run a dummy input through the conv/pool stack once to find the
        # feature-map shape for the given rows/cols; it sizes both FC layers.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, rows, cols)
            x = self.enc_pool1(self.enc_conv1(dummy))
            x = self.enc_pool2(self.enc_conv2(x))
            self.enc_out_shape = x.shape  # (1, C, H, W)
            enc_flat = int(np.prod(self.enc_out_shape[1:]))
        self.fc_enc = nn.Linear(enc_flat, latent_dim)
        # Decoder: mirrors the encoder in reverse order
        self.fc_dec = nn.Linear(latent_dim, enc_flat)
        self.dec_convT1 = nn.ConvTranspose2d(feat_maps[1], feat_maps[0], kernel_size=3, padding=1)
        self.unpool1 = nn.Upsample(scale_factor=(2,1), mode='nearest')
        self.dec_convT2 = nn.ConvTranspose2d(feat_maps[0], in_channels, kernel_size=3, padding=1)
        self.unpool2 = nn.Upsample(scale_factor=(2,2), mode='nearest')
        self.activation = nn.ReLU()

    def encode(self, x):
        """Map images of shape (B, in_channels, rows, cols) to latents (B, latent_dim)."""
        x = self.activation(self.enc_conv1(x))
        x = self.enc_pool1(x)
        x = self.activation(self.enc_conv2(x))
        x = self.enc_pool2(x)
        flat = x.reshape(x.shape[0], -1)
        return self.fc_enc(flat)

    def decode(self, z):
        """Map latents of shape (B, latent_dim) to reconstruction logits."""
        batch = z.shape[0]
        x = self.fc_dec(z)
        x = x.view(batch, *tuple(self.enc_out_shape[1:]))  # (B, C, H, W)
        x = self.activation(self.dec_convT1(x))
        x = self.unpool1(x)
        # No activation on the last layer: these are logits. A ReLU here would
        # clamp them to >= 0, so sigmoid(logit) could never drop below 0.5 and
        # "off" cells could not be represented under a BCE-with-logits loss.
        x = self.dec_convT2(x)
        x = self.unpool2(x)
        # final reconstruction logits (no sigmoid here)
        return x

    def forward(self, x):
        """Return (reconstruction logits, latent vector) for a batch of images."""
        z = self.encode(x)
        recon_logits = self.decode(z)
        return recon_logits, z


#### Sequence Predictor (LSTM)

In [None]:
class SequencePredictor(nn.Module):
    """
    LSTM sequence model that takes a sequence of latent vectors and predicts
    the next frame (tonnetz) as raw logits.

    Args:
        latent_dim: size of each latent vector in the input sequence.
        hidden_dim: hidden size of the LSTM.
        num_layers: number of stacked LSTM layers.
        out_size: shape of one predicted frame; the linear head outputs
            prod(out_size) values which are reshaped to (B, *out_size).
    """
    def __init__(self, latent_dim=128, hidden_dim=256, num_layers=2, out_size=(1,24,12)):
        super().__init__()
        self.lstm = nn.LSTM(input_size=latent_dim, hidden_size=hidden_dim,
                            num_layers=num_layers, batch_first=True)
        # np.prod returns a NumPy integer; cast to a plain int so that
        # nn.Linear stores an ordinary Python int as `out_features`.
        self.fc = nn.Linear(hidden_dim, int(np.prod(out_size)))  # predict entire tonnetz image flattened
        self.out_size = out_size

    def forward(self, z_seq):
        """Predict next-frame logits of shape (B, *out_size) from z_seq (B, seq_len, latent_dim)."""
        output, _ = self.lstm(z_seq)  # output: (B, seq_len, hidden_dim)
        last = output[:, -1, :]  # LSTM output at the final time step
        logits = self.fc(last)
        return logits.view(-1, *self.out_size)  # (B, 1, H, W) for the default out_size


#### Training Helper Functions

In [None]:
def pretrain_autoencoder(autoenc: nn.Module, dataloader, device,
                         criterion = "logit", epochs=10, lr=1e-3,
                         verbose=False, debug=False):
    """
    Train an autoencoder to reconstruct individual frames.

    Every batch from `dataloader` is unpacked as (X, _) where X has shape
    (B, seq_len, ch, H, W); the targets are ignored. The time axis is folded
    into the batch axis so each frame is reconstructed independently.

    Args:
        autoenc: module whose forward returns (reconstruction, latent), e.g.
            a ConvAutoencoder. It is moved to `device` and trained in place.
        dataloader: iterable of (X, y) batches supporting len().
        device: device to train on.
        criterion: "logit" (BCEWithLogitsLoss), "mse" (MSELoss), "l1" (L1Loss),
            or an already-constructed callable loss taking (recon, target).
        epochs: number of passes over `dataloader`.
        lr: Adam learning rate.
        verbose: print a progress line roughly every 1% of the batches.
        debug: print the min/max of the reconstruction before each step.

    Returns:
        The same `autoenc` instance, after training.

    Raises:
        ValueError: if `criterion` is a string that names no known loss.
        TypeError: if `criterion` is neither a string nor callable.
    """
    # Resolve the loss up front so a typo fails immediately with a clear
    # message instead of "'str' object is not callable" inside the loop.
    if isinstance(criterion, str):
        loss_classes = {
            "logit": nn.BCEWithLogitsLoss,
            "mse": nn.MSELoss,
            "l1": nn.L1Loss,
        }
        if criterion not in loss_classes:
            raise ValueError(
                f"Unknown criterion {criterion!r}; expected one of "
                f"{sorted(loss_classes)} or a callable loss."
            )
        criterion = loss_classes[criterion]()
    elif not callable(criterion):
        raise TypeError("criterion must be a loss name or a callable loss")

    autoenc.to(device)
    optim = torch.optim.Adam(autoenc.parameters(), lr=lr)
    autoenc.train()

    for epoch in range(epochs):
        total_loss = 0.0
        total_images = 0
        num_batches = len(dataloader)

        # Progress interval: every 1% of total batches (at least every batch)
        interval = max(1, num_batches // 100)
        if verbose and num_batches > 0:
            print(f"\n=== Autoencoder Pretraining Epoch {epoch+1}/{epochs} ===")
            print(f"Total batches = {num_batches} | Printing every {interval} batches (~1%)")

        for batch_idx, (X, _) in enumerate(dataloader):
            B, seq_len, ch, H, W = X.shape
            # Fold the time axis into the batch axis: one image per frame
            images = X.reshape(B * seq_len, ch, H, W).to(device)

            if debug:
                with torch.no_grad():
                    recon, _ = autoenc(images)
                    print("RECON min/max:", recon.min().item(), recon.max().item())

            optim.zero_grad()
            recon_logits, _ = autoenc(images)
            loss = criterion(recon_logits, images)
            loss.backward()
            optim.step()

            # Weight the batch loss by the number of images it covers so the
            # epoch average below is a true per-image average.
            n_images = images.size(0)
            total_loss += loss.item() * n_images
            total_images += n_images

            # Print progress every 1% of total batches
            if verbose and (batch_idx % interval == 0):
                pct = int((batch_idx / num_batches) * 100)
                print(f"  Training: {pct}% Complete. "
                      f"(Batch {batch_idx+1}/{num_batches}: loss={loss.item():.6f})")

        # Average over the images actually processed this epoch. Counting them
        # directly stays correct when the loader drops incomplete batches.
        avg = total_loss / total_images if total_images > 0 else total_loss

        print(f"[AE] Epoch {epoch+1}/{epochs} avg_loss={avg:.6f}")

    return autoenc

In [None]:
def train_sequence_model(
        autoenc: nn.Module,
        seq_model: nn.Module,
        train_loader,
        val_loader,
        device,
        epochs=30,
        lr=1e-3,
        freeze_encoder=True,
        verbose=False,
        patience=5,
        min_delta=1e-4):
    """
    Train the sequence model on latents produced by the autoencoder, with
    early stopping on the validation loss.

    Each batch is (X, y) with X of shape (B, seq_len, ch, H, W). The frames of
    X are encoded by `autoenc` under torch.no_grad(), regrouped to
    (B, seq_len, latent_dim) and fed to `seq_model`, whose output is compared
    with y using BCEWithLogitsLoss.

    Args:
        autoenc: module whose forward returns (reconstruction, latent), e.g.
            a ConvAutoencoder. Used only to extract latents.
        seq_model: module mapping (B, seq_len, latent_dim) to logits shaped
            like y, e.g. a SequencePredictor.
        train_loader, val_loader: loaders yielding (X, y) batches.
        device: device to train on; both models are moved to it.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        freeze_encoder: if True, set requires_grad=False on all `autoenc`
            parameters. Latents are always computed without gradients and the
            optimizer only holds `seq_model` parameters, so `autoenc` is not
            updated by this function in either case.
        verbose: print per-epoch headers, ~1% progress lines and
            early-stopping status.
        patience: stop after this many consecutive epochs without improvement.
        min_delta: minimum decrease of the validation loss that counts as an
            improvement.

    Returns:
        (seq_model, history_train, history_val): the model with the weights
        of the best validation epoch restored (when one was recorded), and
        the per-epoch average training and validation losses. The same tuple
        is returned whether or not early stopping triggered.
    """
    history_train = []
    history_val   = []

    # freeze encoder if requested
    if freeze_encoder:
        for p in autoenc.parameters():
            p.requires_grad = False

    autoenc.to(device)
    seq_model.to(device)

    optim = torch.optim.Adam(
        filter(lambda p: p.requires_grad, seq_model.parameters()),
        lr=lr
    )
    criterion = nn.BCEWithLogitsLoss()

    best_val_loss = float('inf')
    best_state = None
    epochs_without_improvement = 0

    # ---------------------------
    # MAIN TRAINING LOOP
    # ---------------------------
    for epoch in range(epochs):
        seq_model.train()
        total_loss = 0.0
        train_seen = 0
        num_batches = len(train_loader)

        # progress interval (~1%)
        interval = max(1, num_batches // 100)

        if verbose:
            print(f"\n=== Sequence Model Epoch {epoch+1}/{epochs} ===")
            print(f"Total training batches = {num_batches}")

        # ---------------------------
        # TRAINING EPOCH
        # ---------------------------
        for batch_idx, (X, y) in enumerate(train_loader):
            B, seq_len, ch, H, W = X.shape

            # flatten time frames for encoder
            X_flat = X.reshape(B * seq_len, ch, H, W).to(device)

            # extract latents (no gradients flow into the autoencoder)
            with torch.no_grad():
                _, z_flat = autoenc(X_flat)

            # restore time structure: (B, seq_len, latent_dim)
            z_seq = z_flat.reshape(B, seq_len, -1)
            y = y.to(device)

            optim.zero_grad()
            logits = seq_model(z_seq)        # output: (B, 1, H, W)
            loss = criterion(logits, y)
            loss.backward()
            optim.step()

            total_loss += loss.item() * B
            train_seen += B

            # 1% progress print
            if verbose and (batch_idx % interval == 0):
                pct = int((batch_idx / num_batches) * 100)
                print(f"  Training: {pct}% Complete. "
                      f"(Batch {batch_idx+1}/{num_batches}: loss={loss.item():.6f})")

        # Average over the samples actually seen: with drop_last=True the
        # loader yields fewer samples than len(dataset).
        train_avg = total_loss / train_seen if train_seen else float('nan')

        # ---------------------------
        # VALIDATION EPOCH
        # ---------------------------
        seq_model.eval()
        val_loss = 0.0
        val_seen = 0

        with torch.no_grad():
            for X, y in val_loader:
                B, seq_len, ch, H, W = X.shape
                X_flat = X.reshape(B * seq_len, ch, H, W).to(device)

                # encode
                _, z_flat = autoenc(X_flat)
                z_seq = z_flat.reshape(B, seq_len, -1)
                y = y.to(device)

                logits = seq_model(z_seq)
                loss = criterion(logits, y)
                val_loss += loss.item() * B
                val_seen += B

        val_avg = val_loss / val_seen if val_seen else float('nan')

        history_train.append(train_avg)
        history_val.append(val_avg)

        print(f"[SEQ] Epoch {epoch+1}/{epochs} train={train_avg:.6f} val={val_avg:.6f}")

        # ---------------------------
        # EARLY STOPPING CHECK
        # ---------------------------
        if val_seen == 0:
            # Without validation data there is nothing to compare against, so
            # early stopping is skipped and all epochs are run.
            continue

        if val_avg + min_delta < best_val_loss:
            best_val_loss = val_avg
            # state_dict() returns references to the live parameter tensors;
            # clone them so later optimizer steps cannot overwrite the snapshot.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in seq_model.state_dict().items()
            }
            epochs_without_improvement = 0
            if verbose:
                print(f"  ✓ New best validation loss: {best_val_loss:.6f}")
        else:
            epochs_without_improvement += 1
            if verbose:
                print(f"  ✗ No improvement ({epochs_without_improvement}/{patience})")

            if epochs_without_improvement >= patience:
                print("Early stopping triggered — restoring best sequence model.")
                break

    # restore best model (both after early stopping and after the last epoch)
    if best_state is not None:
        seq_model.load_state_dict(best_state)

    return seq_model, history_train, history_val

## Training

### Autoencoder

In [None]:
# 2) create dataset and dataloader
#    -- training batches are shuffled and incomplete final batches are dropped;
#       validation batches keep their order
seq_len = 16
train_ds = TonnetzSequenceDataset(train_seqs, seq_len=seq_len)
val_ds = TonnetzSequenceDataset(val_seqs, seq_len=seq_len)
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=False)

#-- Tally the cells equal to 0 and to 1 over one full pass of train_loader.
#   Windows overlap, so a frame is counted once per window that contains it.
count_0 = 0
count_1 = 0
for X, _ in train_loader:
    count_0 += int((X == 0).sum())
    count_1 += int((X == 1).sum())

print("Total count of 0 in train_loader:", count_0)
print("Total count of 1 in train_loader:", count_1)

In [None]:
# 3) initialize autoencoder model
#    -- rows/cols match the tonnetz image size used when parsing the MIDI files
autoenc = ConvAutoencoder(
    in_channels=1,
    feat_maps=(20, 10),
    rows=24,
    cols=12,
    latent_dim=128,
)

In [None]:
# 4) pre-train autoencoder using all frames from training sequences
#    -- L1 reconstruction loss, 8 epochs, with progress output enabled
print("Pretraining autoencoder")
autoenc = pretrain_autoencoder(
    autoenc,
    train_loader,
    device=device,
    criterion="l1",
    epochs=8,
    lr=1e-2,
    verbose=True,
    debug=False,
)

In [None]:
# Saving autoencoder model
# Make sure the target folder exists first: torch.save does not create missing
# directories, and failing here would lose the freshly trained weights.
Path(f"{PROJECT_PATH}/models").mkdir(parents=True, exist_ok=True)
torch.save(autoenc.state_dict(), f"{PROJECT_PATH}/models/autoencoder_final.pt")

### Sequence Predictor

In [None]:
# 5) initialize sequence model
#    -- latent_dim matches the autoencoder; out_size is one (1, 24, 12) frame
seq_model = SequencePredictor(
    latent_dim=128,
    hidden_dim=256,
    num_layers=2,
    out_size=(1, 24, 12),
)

In [None]:
# 6) train sequence model (freeze encoder)
print("Training sequence model")
seq_model, train_hist, valid_hist = train_sequence_model(
    autoenc,
    seq_model,
    train_loader,
    val_loader,
    device=device,
    epochs=25,
    lr=2e-4,
    freeze_encoder=True,
    verbose=True,
    patience=5,
    min_delta=1e-4,
)

# Keep the per-epoch losses in a table, one row per epoch
df_history = pd.DataFrame(dict(train_loss=train_hist, val_loss=valid_hist))

In [None]:
# Saving sequence model
# Make sure the target folder exists first: torch.save does not create missing
# directories, and failing here would lose the freshly trained weights.
Path(f"{PROJECT_PATH}/models").mkdir(parents=True, exist_ok=True)
torch.save(seq_model.state_dict(), f"{PROJECT_PATH}/models/sequence_model_final.pt")

#### Plotting History

In [None]:
def plot_training_history(train_hist, val_hist, title="Training History"):
    """
    Plot training and validation loss per epoch on one figure and show it.

    Epochs are numbered from 1 along the x axis; both series are drawn with
    circular markers against the same epoch axis.
    """
    n_epochs = len(train_hist)
    epoch_axis = np.arange(1, n_epochs + 1)

    plt.figure(figsize=(10, 5))
    curves = (
        (train_hist, "Training Loss", "green"),
        (val_hist, "Validation Loss", "deeppink"),
    )
    for losses, label, color in curves:
        plt.plot(epoch_axis, losses, label=label, color=color, marker="o")

    # One tick per integer from 0 up to the number of epochs
    plt.xticks(np.arange(n_epochs + 1))

    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Read the loss curves back from the history table and plot them
train_hist = df_history['train_loss'].tolist()
val_hist = df_history['val_loss'].tolist()

plot_training_history(
    train_hist,
    val_hist,
    title="Training History - Music Prediction LSTM",
)