In [1]:
import json
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras import models, layers, callbacks
from sklearn.model_selection import train_test_split

# 1. Load vocabulary
# Tokens that are bookkeeping markers rather than chords; they are excluded
# from both lookup tables below.
SPECIAL_TOKENS = {'unk', 'intro', 'end'}

with open('chord_vocab.json') as f:
    vocab = json.load(f)

# Re-index the surviving chords contiguously, ordered by their index in the
# file. Simply dropping the special tokens can leave holes in the index range,
# in which case some chord indices would be >= vocab_size (out of range for the
# Embedding / softmax layers sized with vocab_size) and some sampled class ids
# would have no entry in idx_to_chord. When the file's chord indices are
# already 0..N-1 after the special tokens are removed, this is the identity.
# NOTE(review): if the special tokens are not the highest indices in the file,
# the indices produced here differ from the raw ones stored in the JSON.
_kept_chords = [
    str(chord)
    for chord, _ in sorted(vocab['chord_to_idx'].items(),
                           key=lambda item: int(item[1]))
    if str(chord) not in SPECIAL_TOKENS
]
chord_to_idx = {chord: idx for idx, chord in enumerate(_kept_chords)}
# Built as the exact inverse of chord_to_idx so the two tables cannot disagree.
idx_to_chord = {idx: chord for chord, idx in chord_to_idx.items()}
vocab_size = len(chord_to_idx)

# 2. Simplified data loading
def load_and_preprocess_data(filepath, sample_size=50000):
    """Read chord sequences from a parquet file and encode them as indices.

    Args:
        filepath: Path to a parquet file containing the columns
            ``input_chord``, ``chord_2``, ``chord_3`` and ``chord_4``.
        sample_size: Maximum number of rows to draw at random (seeded with
            ``random_state=42``). It is clamped to the number of rows in the
            file, so a small file no longer raises. ``None`` skips sampling
            and uses every row.

    Returns:
        A tuple ``(X, y)`` of integer arrays: ``X`` has shape ``(n,)`` and
        holds the index of each input chord; ``y`` has shape ``(n, 3)`` and
        holds the indices of the three following chords. Rows in which any of
        the four chords is missing from ``chord_to_idx`` are dropped.
    """
    df = pd.read_parquet(filepath)
    if sample_size is not None:
        # DataFrame.sample raises when n exceeds the row count (replace=False).
        df = df.sample(n=min(sample_size, len(df)), random_state=42)

    columns = ['input_chord'] + [f'chord_{i}' for i in range(2, 5)]

    X, y = [], []
    # itertuples over just the needed columns avoids building a Series per
    # row, which is what makes iterrows slow.
    for row in df[columns].itertuples(index=False, name=None):
        chords = [str(value) for value in row]
        # Skip the row if any chord is invalid
        if all(chord in chord_to_idx for chord in chords):
            X.append(chord_to_idx[chords[0]])
            y.append([chord_to_idx[chord] for chord in chords[1:]])

    # Explicit dtype/shape keep the result well-formed even when no row is valid.
    return np.array(X, dtype=int), np.array(y, dtype=int).reshape(-1, 3)

# Encode the sampled chord sequences and report what was loaded.
X, y = load_and_preprocess_data(filepath='all_chord_sequences.parquet.gzip')
print(
    f"Loaded {len(X)} valid sequences",
    f"Vocabulary size: {vocab_size}",
    sep="\n",
)

# Hold out 20% of the sequences for validation (fixed seed for repeatability).
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=42, test_size=0.2)

# 3. Model architecture
# One chord index in -> a distribution over the vocabulary for each of the
# three following positions.
model = models.Sequential([
    # Explicit input spec: one integer chord index per sample. This replaces
    # the Embedding `input_length` argument, which is deprecated in Keras 3.
    layers.Input(shape=(1,), dtype='int32'),
    # Output is (batch, 1, 128); the former Reshape((1, 128)) that followed
    # this layer reshaped to that same shape and was therefore a no-op.
    layers.Embedding(vocab_size, 128),
    layers.LSTM(256, return_sequences=False),
    layers.Dropout(0.3),
    # Repeat the encoded chord once per output position.
    layers.RepeatVector(3),
    layers.LSTM(256, return_sequences=True),
    layers.TimeDistributed(layers.Dense(vocab_size, activation='softmax'))
])
model.compile(
    optimizer='adam',
    # Targets are integer class ids, not one-hot vectors.
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# 4. Training
history = model.fit(
    # The model expects inputs of shape (batch, 1).
    X_train.reshape(-1, 1),
    y_train,
    validation_data=(X_val.reshape(-1, 1), y_val),
    batch_size=128,
    epochs=15,
    callbacks=[
        callbacks.ModelCheckpoint(
            'model.keras',
            monitor='val_accuracy',
            mode='max',
            save_best_only=True
        ),
        # Track the same metric as the checkpoint and restore the best weights,
        # so the in-memory `model` used for the predictions below corresponds
        # to the best epoch rather than to whichever epoch happened to run last.
        callbacks.EarlyStopping(
            monitor='val_accuracy',
            mode='max',
            patience=5,
            restore_best_weights=True
        )
    ]
)

# 5. Prediction function
def predict_sequence(start_chord, model, max_attempts=20):
    """Sample three distinct follow-up chords for ``start_chord``.

    Args:
        start_chord: Chord name to continue from; looked up in
            ``chord_to_idx``.
        model: Object whose ``predict(x, verbose=0)`` returns, for an input of
            shape ``(1, 1)``, an array indexed as ``[0][step]`` giving a score
            per class for each of the three output positions.
        max_attempts: How many times to retry sampling before giving up.

    Returns:
        A list of three chord names, all different from each other and from
        ``start_chord``, or ``[]`` when ``start_chord`` is unknown or no valid
        sequence could be sampled within ``max_attempts``.
    """
    if start_chord not in chord_to_idx:
        return []

    input_idx = chord_to_idx[start_chord]

    # The model maps ONE input chord to the distributions of all three output
    # positions, so a single forward pass on the start chord yields everything
    # needed. The previous code fed each sampled chord back in as a new input
    # while still reading output position len(sequence), which does not match
    # how the model is trained, and it ran up to 3 * max_attempts predictions.
    # float64 keeps the renormalised probabilities within the tolerance that
    # np.random.choice enforces on `p`.
    step_probs = np.asarray(
        model.predict(np.array([[input_idx]]), verbose=0)[0], dtype=np.float64
    )
    n_classes = step_probs.shape[-1]

    # Only class ids that have a chord name may be emitted; the start chord
    # itself is never emitted.
    allowed = np.zeros(n_classes, dtype=bool)
    named = [idx for idx in idx_to_chord if 0 <= idx < n_classes]
    allowed[np.array(named, dtype=int)] = True
    if 0 <= input_idx < n_classes:
        allowed[input_idx] = False

    for _ in range(max_attempts):
        available = allowed.copy()
        sequence = []

        for step in range(3):
            # Remove input and existing chords
            probs = np.where(available, step_probs[step], 0.0)
            total = probs.sum()
            if not total > 0:  # also catches NaN
                break

            chosen_idx = int(np.random.choice(n_classes, p=probs / total))
            sequence.append(idx_to_chord[chosen_idx])
            # Mask by index directly instead of a name -> index round trip,
            # whose -1 fallback would have masked the last class.
            available[chosen_idx] = False

        if len(sequence) == 3 and len(set(sequence)) == 3:
            return sequence

    return []

# Smoke-test the sampler on a few start chords, skipping any that are not in
# the vocabulary.
for chord in (name for name in ["C", "G", "Dm", "F#"] if name in chord_to_idx):
    seq = predict_sequence(chord, model)
    print(f"{chord} → {' - '.join(seq)}")

Loaded 50000 valid sequences
Vocabulary size: 159
Epoch 1/15




313/313 ━━━━━━━━━━━━━━━━━━━━ 13s 32ms/step - accuracy: 0.1490 - loss: 3.4263 - val_accuracy: 0.2004 - val_loss: 2.6445
Epoch 2/15
313/313 ━━━━━━━━━━━━━━━━━━━━ 10s 33ms/step - accuracy: 0.2099 - loss: 2.6179 - val_accuracy: 0.2176 - val_loss: 2.5954
Epoch 3/15
313/313 ━━━━━━━━━━━━━━━━━━━━ 11s 35ms/step - accuracy: 0.2203 - loss: 2.5768 - val_accuracy: 0.2213 - val_loss: 2.5721
Epoch 4/15
313/313 ━━━━━━━━━━━━━━━━━━━━ 11s 37ms/step - accuracy: 0.2245 - loss: 2.5597 - val_accuracy: 0.2212 - val_loss: 2.5564
Epoch 5/15
313/313 ━━━━━━━━━━━━━━━━━━━━ 14s 46ms/step - accuracy: 0.2257 - loss: 2.5383 - val_accuracy: 0.2247 - val_loss: 2.5422
Epoch 6/15
313/313 ━━━━━━━━━━━━━━━━━━━━ 11s 36ms/step - accuracy: 0.2302 - loss: 2.5169 - val_accuracy: 0.2288 - val_loss: 2.5302
Epoch 7/15
[1m313/313[0m 