In [None]:
import glob
import pickle
import numpy as np
from music21 import converter, instrument, note, chord
from keras.models import Model
from keras.layers import Input, Dense, Dropout, LSTM, Embedding, MultiHeadAttention, LayerNormalization
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam
import tensorflow as tf
import os

# Make sure the folders for the pickled notes and the checkpoints exist
for _output_dir in (
    "/kaggle/working/hybrid_outputs/model_notes",
    "/kaggle/working/hybrid_outputs/weights",
):
    os.makedirs(_output_dir, exist_ok=True)

# Report how many GPUs TensorFlow can see
_gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(_gpu_devices))

def _select_resume_weights(weights_files, start_epoch):
    """Pick the checkpoint file to resume from.

    train_model() names its checkpoints
    ``weights_hybrid-epoch{epoch:02d}-loss{loss:.4f}.keras``, so the epoch
    stored in last_epoch.txt can be matched against the file name. Files whose
    name carries that epoch are preferred; creation time is only used to break
    ties, or as a fallback when no name matches (e.g. renamed files).

    Args:
        weights_files: Non-empty list of candidate ``.keras`` paths.
        start_epoch: Number of epochs already completed.

    Returns:
        The path of the checkpoint to load.
    """
    epoch_tag = f"-epoch{start_epoch:02d}-loss"
    matching = [path for path in weights_files if epoch_tag in os.path.basename(path)]
    return max(matching or weights_files, key=os.path.getctime)


def train_hybrid():
    """Train the hybrid LSTM-Transformer model, resuming when possible.

    Notes are read from a cached pickle when it exists and can be loaded;
    otherwise the MIDI files are parsed with get_notes(). If both a
    last_epoch.txt file and at least one ``.keras`` file are present in
    /kaggle/input/utiles/, the weights are loaded and training continues from
    the recorded epoch; any failure while doing so restarts from epoch 0.

    Raises:
        ValueError: If there are too few notes to build a single training
            sequence.
    """
    sequence_length = 25  # window length used by prepare_sequences()

    # Check for saved notes file
    notes_file = "/kaggle/input/utiles/notes.pkl"
    if os.path.exists(notes_file):
        try:
            # NOTE: unpickling can execute arbitrary code; only load a
            # notes.pkl that comes from a trusted source.
            with open(notes_file, 'rb') as filepath:
                notes = pickle.load(filepath)
            print(f"Loaded {len(notes)} notes from {notes_file}")
        except Exception as e:
            print(f"Error loading notes file: {e}. Parsing MIDI files instead.")
            notes = get_notes()
    else:
        print(f"No notes file found at {notes_file}. Parsing MIDI files.")
        notes = get_notes()

    print(f"Total notes extracted: {len(notes)}")
    # Each window of `sequence_length` notes needs one more note as its
    # prediction target, so exactly `sequence_length` notes is still too few.
    if len(notes) <= sequence_length:
        raise ValueError(
            f"Not enough notes ({len(notes)}) to create sequences of length {sequence_length}"
        )
    vocab_size = len(set(notes))  # Vocabulary size
    print(f"Vocabulary size: {vocab_size}")
    network_input, network_output = prepare_sequences(notes, vocab_size)
    model = create_hybrid_model(network_input, vocab_size)

    # Check for last epoch and weights
    start_epoch = 0
    last_epoch_file = "/kaggle/input/utiles/last_epoch.txt"
    weights_dir = "/kaggle/input/utiles/*.keras"
    weights_files = glob.glob(weights_dir)

    if os.path.exists(last_epoch_file) and weights_files:
        try:
            with open(last_epoch_file, 'r') as f:
                start_epoch = int(f.read().strip())
            latest_weights = _select_resume_weights(weights_files, start_epoch)
            print(f"Resuming training from epoch {start_epoch}, loading weights from {latest_weights}")
            model.load_weights(latest_weights)
        except Exception as e:
            # Best effort: a bad epoch file or incompatible weights must not
            # prevent training from scratch.
            print(f"Error loading last epoch or weights: {e}. Starting training from epoch 0.")
            start_epoch = 0
    else:
        print("No last epoch file or weights found in /kaggle/input/utiles/. Starting training from epoch 0.")

    train_model(model, network_input, network_output, start_epoch)

def _duration_class(quarter_length):
    """Bucket a duration (in quarter notes) into 'short', 'medium' or 'long'."""
    if quarter_length < 0.75:
        return 'short'
    if quarter_length < 1.5:
        return 'medium'
    return 'long'


def get_notes():
    """Extract note/chord tokens with a duration class from the MIDI dataset.

    Every note becomes ``"<pitch>_<class>"`` and every chord becomes
    ``"<normal order joined by '.'>_<class>"``, where ``<class>`` is
    'short', 'medium' or 'long'. Files that fail to parse are reported and
    skipped (tokens collected from them before the failure are kept). The
    resulting list is pickled to the working directory.

    Returns:
        list[str]: The tokens in the order they were encountered.

    Raises:
        FileNotFoundError: If the glob pattern matches no MIDI file.
    """
    notes = []
    midi_path = "/kaggle/input/transposed-4artists-dataset/*.mid"
    files = glob.glob(midi_path)
    print(f"Found {len(files)} MIDI files")
    if not files:
        raise FileNotFoundError(f"No MIDI files found at {midi_path}")
    for file in files:
        try:
            midi = converter.parse(file)
            print(f"Parsing {file}")
            try:
                # Prefer the first instrument part when the score can be split
                s2 = instrument.partitionByInstrument(midi)
                notes_to_parse = s2.parts[0].recurse()
            except Exception:
                # No usable parts: fall back to the flattened note stream.
                # NOTE(review): `.flat` is deprecated in newer music21
                # releases in favour of `.flatten()` - confirm the installed
                # version before switching.
                notes_to_parse = midi.flat.notes
            for element in notes_to_parse:
                if isinstance(element, note.Note):
                    token = str(element.pitch)
                elif isinstance(element, chord.Chord):
                    token = '.'.join(str(n) for n in element.normalOrder)
                else:
                    # Not a note or chord: contributes no token
                    continue
                duration_class = _duration_class(element.duration.quarterLength)
                notes.append(f"{token}_{duration_class}")
        except Exception as e:
            # One unreadable file must not abort the whole extraction
            print(f"Error parsing {file}: {e}")
    # Save notes to Kaggle's working directory
    notes_path = '/kaggle/working/hybrid_outputs/model_notes/notes.pkl'
    os.makedirs(os.path.dirname(notes_path), exist_ok=True)
    with open(notes_path, 'wb') as filepath:
        pickle.dump(notes, filepath)
    print(f"Saved {len(notes)} notes to /kaggle/working/hybrid_outputs/model_notes/notes.pkl")
    return notes

def prepare_sequences(notes, vocab_size, sequence_length=25):
    """Build the integer input windows and one-hot targets for training.

    Each distinct token is mapped to its index in the sorted vocabulary. Every
    run of ``sequence_length`` consecutive tokens becomes one input row, and
    the token that follows it becomes the (one-hot encoded) target.

    Args:
        notes: Sequence of note/chord tokens.
        vocab_size: Number of classes for the one-hot targets.
        sequence_length: Number of tokens per input window (default 25).

    Returns:
        tuple: ``(network_input, network_output)`` where ``network_input`` is
        an integer array with one row per window and ``network_output`` is the
        result of ``to_categorical`` on the target indices.

    Raises:
        ValueError: If ``sequence_length`` is not positive, or if ``notes`` is
            too short to produce at least one window plus target.
    """
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be at least 1, got {sequence_length}")

    note_to_int = {name: number for number, name in enumerate(sorted(set(notes)))}
    # Encode once instead of re-looking-up every token for each window
    encoded = [note_to_int[name] for name in notes]

    window_count = len(encoded) - sequence_length
    if window_count <= 0:
        raise ValueError("No sequences generated. Ensure notes list is long enough.")

    network_input = np.array(
        [encoded[start:start + sequence_length] for start in range(window_count)]
    )
    # The target of window i is the token at position i + sequence_length
    network_output = to_categorical(encoded[sequence_length:], num_classes=vocab_size)
    print(f"Generated {len(network_input)} sequences")
    return network_input, network_output

def create_hybrid_model(network_input, vocab_size):
    """Build and compile the hybrid LSTM-Transformer next-token model.

    The network embeds integer token ids, runs them through two stacked LSTM
    layers, applies one self-attention block with residual connections and
    layer normalisation, and predicts a distribution over the vocabulary from
    the last timestep.

    Args:
        network_input: Array whose second dimension is the sequence length;
            only ``network_input.shape[1]`` is read here.
        vocab_size: Number of distinct tokens; used for both the embedding
            input size and the softmax output size.

    Returns:
        A Keras ``Model`` compiled with Adam (learning rate 0.001) and
        categorical cross-entropy loss.
    """
    sequence_length = network_input.shape[1]
    # d_model and lstm_units must stay equal: the residual sums below add the
    # LSTM output to tensors whose width is derived from d_model.
    d_model = 512  # Embedding dimension, matched to lstm_units
    num_heads = 8  # Number of attention heads
    dff = 512      # Feed-forward layer dimension
    lstm_units = 512  # LSTM units

    # Input layer: one integer token id per timestep
    inputs = Input(shape=(sequence_length,))
    
    # Embedding layer
    x = Embedding(input_dim=vocab_size, output_dim=d_model)(inputs)
    
    # LSTM layers (both return the full sequence so attention can see every timestep)
    # NOTE(review): a non-zero recurrent_dropout is reported to rule out the
    # fast cuDNN LSTM kernel in Keras - confirm training speed is acceptable.
    x = LSTM(lstm_units, return_sequences=True, recurrent_dropout=0.3)(x)
    x = Dropout(0.3)(x)
    x = LSTM(lstm_units, return_sequences=True, recurrent_dropout=0.3)(x)
    x = Dropout(0.3)(x)
    
    # Transformer block: self-attention (query and value are both x),
    # then residual + layer norm, then a feed-forward sub-layer with its own
    # residual + layer norm.
    attn_output = MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)(x, x)
    attn_output = Dropout(0.1)(attn_output)
    out1 = LayerNormalization(epsilon=1e-6)(x + attn_output)
    ffn_output = Dense(dff, activation='relu')(out1)
    ffn_output = Dense(d_model)(ffn_output)  # Output d_model to match out1
    ffn_output = Dropout(0.1)(ffn_output)
    x = LayerNormalization(epsilon=1e-6)(out1 + ffn_output)
    
    # Output layer: only the final timestep feeds the classifier head
    x = Dense(256, activation='relu')(x[:, -1, :])  # Take last timestep
    x = Dropout(0.3)(x)
    outputs = Dense(vocab_size, activation='softmax')(x)
    
    model = Model(inputs=inputs, outputs=outputs)
    # Targets are one-hot vectors, hence categorical (not sparse) cross-entropy
    model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy')
    return model

def train_model(model, network_input, network_output, start_epoch):
    """Fit the hybrid model, saving a checkpoint and an epoch marker per epoch.

    Args:
        model: Compiled Keras model to train.
        network_input: Training inputs passed to ``model.fit``.
        network_output: Training targets passed to ``model.fit``.
        start_epoch: Value forwarded as ``initial_epoch``; training runs up to
            epoch 150.
    """
    # The checkpoint callback writes into this folder, so it has to exist
    os.makedirs('/kaggle/working/hybrid_outputs/weights', exist_ok=True)

    checkpoint_options = {
        'monitor': 'loss',
        'verbose': 1,
        'save_best_only': False,   # keep a file for every epoch
        'save_weights_only': False,
        'mode': 'min',
    }
    epoch_checkpoint = ModelCheckpoint(
        "/kaggle/working/hybrid_outputs/weights/weights_hybrid-epoch{epoch:02d}-loss{loss:.4f}.keras",
        **checkpoint_options,
    )

    class SaveLastEpoch(tf.keras.callbacks.Callback):
        """Record the number of finished epochs in last_epoch.txt."""

        def on_epoch_end(self, epoch, logs=None):
            # `epoch` is zero-based, so epoch + 1 epochs are complete
            finished_epochs = str(epoch + 1)
            with open('/kaggle/working/hybrid_outputs/last_epoch.txt', 'w') as marker:
                marker.write(finished_epochs)

    # The marker callback comes after the checkpoint callback in the list
    fit_options = {
        'epochs': 150,
        'batch_size': 64,
        'callbacks': [epoch_checkpoint, SaveLastEpoch()],
        'verbose': 1,
        'initial_epoch': start_epoch,
    }
    model.fit(network_input, network_output, **fit_options)

def generate_music(model, network_input, note_names, sequence_length=25, generation_length=100):
    """Generate a token sequence by repeatedly predicting the next token.

    A random row of ``network_input`` is used as the seed window. At every
    step the model's most probable token (argmax) is appended to the output
    and the window slides forward by one position.

    Args:
        model: Object with a ``predict(batch, verbose=0)`` method returning
            per-class scores for a batch of one window.
        network_input: Indexable collection of integer seed windows.
        note_names: Vocabulary in index order (position i is token i).
        sequence_length: Length of each window fed to the model.
        generation_length: Number of tokens to generate.

    Returns:
        list: The generated tokens, ``generation_length`` items long.

    Raises:
        ValueError: If ``network_input`` is empty.
    """
    if len(network_input) == 0:
        raise ValueError("network_input is empty; no seed sequence to start from.")

    # randint's upper bound is exclusive, so this covers every seed row
    # (and works when there is only one).
    start = np.random.randint(0, len(network_input))
    pattern = np.array(network_input[start])  # copy, never mutate the caller's data
    int_to_note = dict(enumerate(note_names))
    prediction_output = []

    for _ in range(generation_length):
        prediction_input = np.reshape(pattern, (1, sequence_length))
        prediction = model.predict(prediction_input, verbose=0)
        index = int(np.argmax(prediction))
        prediction_output.append(int_to_note[index])
        # Slide the window: drop the oldest token, append the new one
        pattern = np.append(pattern[1:], index)

    return prediction_output

if __name__ == '__main__':
    import traceback

    try:
        train_hybrid()
    except Exception as e:
        # Keep the notebook cell from crashing, but do not hide where the
        # failure came from: the one-line summary alone loses the traceback.
        print(f"Training failed: {e}")
        traceback.print_exc()
    # Optional: Generate music after training (uncomment after training).
    # `latest_weights_file` is not defined anywhere in this script; set it to
    # the path of a saved .keras checkpoint before running these lines.
    # notes = get_notes()
    # vocab_size = len(set(notes))
    # network_input, _ = prepare_sequences(notes, vocab_size)
    # model = create_hybrid_model(network_input, vocab_size)
    # model.load_weights(latest_weights_file)  # Load your trained weights
    # generated = generate_music(model, network_input, sorted(set(notes)))
    # print("Generated music:", generated)