# In [1]:  (Jupyter notebook cell marker left over from the export)
import os
import zipfile
import pickle
import numpy as np
import music21 as m21
import tensorflow as tf
from tensorflow.keras import layers, callbacks

from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QVBoxLayout, QHBoxLayout, QWidget,
    QPushButton, QLabel, QFileDialog, QProgressBar, QSpinBox,
    QDoubleSpinBox, QTextEdit, QGroupBox
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QFont, QPalette, QColor

# ---------------------------
# Constants and Configuration
# ---------------------------
# Initial value of the "Sequence Length" spin box (notes per training window).
DEFAULT_SEQ_LENGTH = 50
# Initial value of the "Temperature" spin box used when sampling notes.
DEFAULT_TEMPERATURE = 1.0
# Initial value of the "Notes to Generate" spin box.
DEFAULT_GEN_NOTES = 200
# Initial value of the "Epochs" spin box.
DEFAULT_EPOCHS = 10
# Initial value of the "Batch Size" spin box.
DEFAULT_BATCH_SIZE = 64


# =======================================================================
#                                APP UI
# =======================================================================
class MusicGeneratorApp(QMainWindow):
    def __init__(self):
        """Start with no model, dataset or workers, then build the widgets."""
        super().__init__()

        # Worker threads; each is created when its action is triggered.
        self.load_thread = None
        self.train_thread = None
        self.gen_thread = None

        # Keras model (trained here or loaded from disk).
        self.model = None

        # Note tokens extracted from the dataset and the vocabulary lookups.
        self.notes, self.unique_notes = [], []
        self.note_to_int, self.int_to_note = {}, {}

        # Integer-encoded training arrays filled in by prepare_sequences().
        self.X, self.y = None, None

        # Output of the most recent generation run.
        self.generated_notes = []

        self.init_ui()

    def init_ui(self):
        """Build the window: title, settings group, generation group, progress, log, train."""

        def make_button(caption, slot):
            # Create a push button already wired to its click handler.
            button = QPushButton(caption)
            button.clicked.connect(slot)
            return button

        def add_labelled_spin(row, caption, spin, low, high, initial, step=None):
            # Append "<caption> <spin box>" to row and return the configured spin box.
            row.addWidget(QLabel(caption))
            spin.setRange(low, high)
            if step is not None:
                spin.setSingleStep(step)
            spin.setValue(initial)
            row.addWidget(spin)
            return spin

        self.setWindowTitle("🎵 AI Music Generator")
        self.setGeometry(100, 100, 900, 700)

        # Modern dark theme
        self.set_dark_theme()

        # Central widget hosting one vertical stack of sections
        central = QWidget()
        self.setCentralWidget(central)
        root = QVBoxLayout()
        central.setLayout(root)

        # Title
        heading = QLabel("🎼 AI Music Generator")
        heading.setFont(QFont('Arial', 20, QFont.Bold))
        heading.setAlignment(Qt.AlignCenter)
        root.addWidget(heading)

        # ---------------- Settings Group ----------------
        settings_box = QGroupBox("⚙️ Settings")
        settings_rows = QVBoxLayout()

        # Dataset row: status label stretches, button keeps its natural width
        dataset_row = QHBoxLayout()
        self.dataset_label = QLabel("No dataset loaded")
        self.load_dataset_btn = make_button(
            "📂 Load MIDI Dataset (ZIP or single MIDI)", self.load_dataset
        )
        dataset_row.addWidget(self.dataset_label, stretch=1)
        dataset_row.addWidget(self.load_dataset_btn, stretch=0)
        settings_rows.addLayout(dataset_row)

        # Model load/save row
        model_row = QHBoxLayout()
        self.load_model_btn = make_button("📥 Load Model", self.load_model)
        self.save_model_btn = make_button("💾 Save Model", self.save_model)
        model_row.addWidget(self.load_model_btn)
        model_row.addWidget(self.save_model_btn)
        settings_rows.addLayout(model_row)

        # Training parameter row
        training_row = QHBoxLayout()
        self.seq_length_spin = add_labelled_spin(
            training_row, "Sequence Length:", QSpinBox(), 10, 200, DEFAULT_SEQ_LENGTH
        )
        self.epochs_spin = add_labelled_spin(
            training_row, "Epochs:", QSpinBox(), 1, 100, DEFAULT_EPOCHS
        )
        self.batch_spin = add_labelled_spin(
            training_row, "Batch Size:", QSpinBox(), 8, 256, DEFAULT_BATCH_SIZE
        )
        settings_rows.addLayout(training_row)

        settings_box.setLayout(settings_rows)
        root.addWidget(settings_box)

        # ---------------- Generation Group ----------------
        generation_box = QGroupBox("🎹 Music Generation")
        generation_rows = QVBoxLayout()

        sampling_row = QHBoxLayout()
        self.gen_notes_spin = add_labelled_spin(
            sampling_row, "Notes to Generate:", QSpinBox(), 10, 2000, DEFAULT_GEN_NOTES
        )
        self.temp_spin = add_labelled_spin(
            sampling_row, "Temperature:", QDoubleSpinBox(), 0.1, 2.0,
            DEFAULT_TEMPERATURE, step=0.1
        )
        generation_rows.addLayout(sampling_row)

        self.generate_btn = make_button("🎶 Generate Music", self.generate_music)
        self.generate_btn.setEnabled(False)
        generation_rows.addWidget(self.generate_btn)

        self.save_midi_btn = make_button("💾 Save MIDI File", self.save_midi)
        self.save_midi_btn.setEnabled(False)
        generation_rows.addWidget(self.save_midi_btn)

        generation_box.setLayout(generation_rows)
        root.addWidget(generation_box)

        # Progress + status
        self.progress_bar = QProgressBar()
        self.progress_bar.setRange(0, 100)
        self.progress_bar.setValue(0)
        root.addWidget(self.progress_bar)

        self.status_label = QLabel("✅ Ready")
        self.status_label.setAlignment(Qt.AlignCenter)
        root.addWidget(self.status_label)

        # Log pane takes whatever vertical space is left
        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        self.log_text.setPlaceholderText("Logs will appear here...")
        root.addWidget(self.log_text, stretch=1)

        # Train button
        self.train_btn = make_button("🚀 Train Model", self.train_model)
        self.train_btn.setEnabled(False)
        root.addWidget(self.train_btn)

        # Enable/disable controls based on state
        self.update_ui_state()

    def set_dark_theme(self):
        """Install the dark colour palette and the widget stylesheet on this window."""
        # (palette role, colour) pairs applied in order to a fresh palette.
        role_colours = (
            (QPalette.Window, QColor(30, 30, 30)),
            (QPalette.WindowText, Qt.white),
            (QPalette.Base, QColor(20, 20, 20)),
            (QPalette.AlternateBase, QColor(45, 45, 45)),
            (QPalette.ToolTipBase, Qt.white),
            (QPalette.ToolTipText, Qt.white),
            (QPalette.Text, Qt.white),
            (QPalette.Button, QColor(40, 40, 40)),
            (QPalette.ButtonText, Qt.white),
            (QPalette.BrightText, Qt.red),
            (QPalette.Link, QColor(120, 180, 255)),
            (QPalette.Highlight, QColor(70, 120, 200)),
            (QPalette.HighlightedText, Qt.black),
        )
        theme = QPalette()
        for role, colour in role_colours:
            theme.setColor(role, colour)
        self.setPalette(theme)

        # Stylesheet: gradient window background plus per-widget styling
        self.setStyleSheet("""
            QMainWindow {
                background-color: qlineargradient(
                    x1:0, y1:0, x2:1, y2:1,
                    stop:0 #2b2b2b, stop:1 #1e1e1e
                );
            }
            QLabel {
                font-size: 14px;
            }
            QPushButton {
                background-color: #3a3f44;
                color: white;
                border: 1px solid #5a5a5a;
                border-radius: 6px;
                padding: 8px 12px;
                font-size: 13px;
            }
            QPushButton:hover {
                background-color: #50565c;
            }
            QPushButton:disabled {
                background-color: #2b2b2b;
                color: #777;
                border: 1px solid #444;
            }
            QSpinBox, QDoubleSpinBox, QComboBox {
                background-color: #2a2a2a;
                color: white;
                border: 1px solid #555;
                border-radius: 4px;
                padding: 3px 6px;
            }
            QTextEdit {
                background-color: #121212;
                color: #aef1c1;
                border: 1px solid #444;
                border-radius: 6px;
                font-family: Consolas, monospace;
                font-size: 13px;
                padding: 8px;
            }
            QGroupBox {
                border: 1px solid #666;
                border-radius: 8px;
                margin-top: 10px;
                padding: 8px;
                font-weight: bold;
                font-size: 14px;
                color: #ddd;
            }
            QGroupBox::title {
                subcontrol-origin: margin;
                subcontrol-position: top left;
                padding: 0 5px;
            }
            QProgressBar {
                border: 1px solid #555;
                border-radius: 6px;
                text-align: center;
                height: 20px;
                font-size: 12px;
                background-color: #222;
            }
            QProgressBar::chunk {
                background-color: #00bfff;
                border-radius: 6px;
            }
        """)

    # ---------------- Utility/UI helpers ----------------
    def update_ui_state(self):
        """Enable Train when notes are present; enable Generate/Save Model when a model exists."""
        model_ready = self.model is not None

        self.train_btn.setEnabled(len(self.notes) > 0)
        for button in (self.generate_btn, self.save_model_btn):
            button.setEnabled(model_ready)

    def log_message(self, message: str):
        """Append message to the log pane, show it in the status label, and process pending events."""
        for sink in (self.log_text.append, self.status_label.setText):
            sink(message)
        QApplication.processEvents()

    def log_model_summary(self):
        """Write the Keras summary of the current model to the log (no-op without a model)."""
        if self.model is not None:
            # summary() hands each line to print_fn; collect them instead of printing.
            collected = []
            self.model.summary(print_fn=collected.append)
            self.log_message("🧠 Model Summary:\n" + "\n".join(collected))

    # ---------------- Dataset loading ----------------
    def load_dataset(self):
        """Prompt for a ZIP archive or single MIDI file and parse it in the background.

        Parsing runs on a ``DatasetLoader`` thread; its progress is routed to
        ``update_progress`` and the extracted notes to ``dataset_loaded``.
        Nothing happens if the dialog is cancelled or an earlier load is still
        running.
        """
        # Rebinding self.load_thread while the old worker is still running would
        # drop the last reference to a live QThread, which Qt treats as fatal.
        if self.load_thread is not None and self.load_thread.isRunning():
            self.log_message("⚠️ A dataset is already being loaded. Please wait.")
            return

        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select MIDI Dataset (ZIP) or single MIDI",
            "",
            "ZIP Files (*.zip);;MIDI Files (*.mid *.midi);;All Files (*)",
            options=options
        )
        if not file_path:
            return

        self.progress_bar.setValue(0)
        self.log_message(f"📂 Loading dataset from: {file_path}")

        self.load_thread = DatasetLoader(file_path)
        self.load_thread.progress_signal.connect(self.update_progress)
        self.load_thread.finished_signal.connect(self.dataset_loaded)
        self.load_thread.start()

    def update_progress(self, value: int, message: str):
        """Move the progress bar to value and, when message is non-empty, log it."""
        self.progress_bar.setValue(value)
        if not message:
            return
        self.log_message(message)

    def dataset_loaded(self, notes):
        """Install a freshly parsed dataset and rebuild the note vocabulary.

        Args:
            notes: List of note/chord tokens emitted by the loader thread.

        Datasets with fewer than 100 notes are rejected and clear the current
        dataset. Training sequences cached from an earlier dataset are always
        dropped, and an existing model is discarded when the vocabulary
        changes, because its layer sizes and index mapping no longer match.
        """
        previous_vocab = self.unique_notes

        # X/y were encoded from the previous notes and mapping; train_model()
        # only re-prepares them when they are None.
        self.X = None
        self.y = None

        if len(notes) < 100:
            self.log_message(f"⚠️ Error: Only {len(notes)} notes extracted. Need at least 100.")
            self.notes = []
            self.unique_notes = []
            self.note_to_int = {}
            self.int_to_note = {}
            self.update_ui_state()
            return

        self.notes = notes
        self.unique_notes = sorted(set(notes))
        self.note_to_int = {note: number for number, note in enumerate(self.unique_notes)}
        self.int_to_note = {i: n for n, i in self.note_to_int.items()}

        self.log_message(f"✅ Dataset loaded: {len(self.notes)} notes ({len(self.unique_notes)} unique)")
        self.dataset_label.setText(f"Dataset: {len(self.notes)} notes • {len(self.unique_notes)} unique")

        # A model built for another vocabulary has the wrong embedding/output
        # sizes and would map predictions onto the wrong notes.
        if self.model is not None and self.unique_notes != previous_vocab:
            self.model = None
            self.log_message(
                "⚠️ Vocabulary changed: the previous model was discarded. "
                "A new model will be built on the next training run."
            )

        self.update_ui_state()

    # ---------------- Sequences + Model ----------------
    def prepare_sequences(self):
        """Encode the notes as sliding windows (self.X) and their next-note labels (self.y).

        Each row of X holds the integer ids of ``sequence_length`` consecutive
        notes; the matching entry of y is the id of the note that follows.
        Both arrays are int32 and the labels stay sparse (not one-hot).
        """
        window = self.seq_length_spin.value()
        encode = self.note_to_int
        tokens = self.notes
        starts = range(len(tokens) - window)

        # Build both arrays before touching self so a lookup failure leaves
        # the previous X/y in place.
        encoded_windows = [
            [encode[tok] for tok in tokens[begin:begin + window]]
            for begin in starts
        ]
        next_ids = [encode[tokens[begin + window]] for begin in starts]

        self.X = np.array(encoded_windows, dtype=np.int32)
        self.y = np.array(next_ids, dtype=np.int32)

        self.log_message(f"✅ Prepared sequences: X={self.X.shape}, y={self.y.shape}")

    def build_model(self):
        """Create and compile the embedding + stacked-LSTM next-note classifier.

        Returns:
            A compiled ``tf.keras.Sequential`` whose softmax output has one
            unit per entry of ``self.unique_notes``.
        """
        vocab_size = len(self.unique_notes)
        window = self.seq_length_spin.value()

        stack = [
            layers.Embedding(input_dim=vocab_size,
                             output_dim=100,
                             input_length=window),
            layers.LSTM(256, return_sequences=True),
            layers.Dropout(0.3),
            layers.LSTM(256),
            layers.Dense(256, activation="relu"),
            layers.Dense(vocab_size, activation="softmax"),
        ]
        network = tf.keras.Sequential(stack)

        # Targets are integer class ids, hence the sparse loss.
        network.compile(
            loss="sparse_categorical_crossentropy",
            optimizer="adam",
            metrics=["accuracy"],
        )
        return network

    # ---------------- Training ----------------
    def train_model(self):
        """Validate the dataset, (re)build sequences and the model, and start training.

        Training runs on a ``ModelTrainer`` thread. Cached sequences are
        rebuilt whenever their shape no longer matches the current notes and
        sequence length. The Train, Load Dataset and Generate buttons are
        disabled until ``training_complete`` runs.
        """
        if self.train_thread is not None and self.train_thread.isRunning():
            self.log_message("⚠️ Training is already in progress.")
            return

        if len(self.notes) < 100:
            self.log_message("⚠️ Not enough notes to train. Load a larger dataset.")
            return

        sequence_length = self.seq_length_spin.value()
        sample_count = len(self.notes) - sequence_length
        if sample_count <= 0:
            # The spin box allows up to 200 while only 100 notes are required,
            # so the window can be longer than the dataset.
            self.log_message(
                f"⚠️ Sequence length ({sequence_length}) must be smaller than "
                f"the number of notes ({len(self.notes)})."
            )
            return

        # Prepare data (main thread). Rebuild when missing or when the cached
        # arrays were made for another sequence length / dataset size.
        stale = (
            self.X is None
            or self.y is None
            or self.X.shape != (sample_count, sequence_length)
        )
        if stale:
            try:
                self.prepare_sequences()
            except KeyError as e:
                self.log_message(
                    f"⚠️ Error: Note {e} not found in vocabulary. Load matching dataset or model."
                )
                return

        if self.model is None:
            self.model = self.build_model()
            self.log_model_summary()

        epochs = self.epochs_spin.value()
        batch_size = self.batch_spin.value()

        self.log_message(f"🚀 Starting training for {epochs} epochs...")
        self.progress_bar.setValue(0)

        # Disable buttons during training; generation would use the same
        # model object from another thread while it is being fitted.
        self.train_btn.setEnabled(False)
        self.load_dataset_btn.setEnabled(False)
        self.generate_btn.setEnabled(False)

        # Cleared here, set by _on_training_error, read by training_complete.
        self._training_failed = False

        # Train in worker thread with thread-safe progress updates
        self.train_thread = ModelTrainer(self.model, self.X, self.y, epochs, batch_size)
        self.train_thread.progress_signal.connect(self.update_progress)   # safe UI updates
        self.train_thread.error_signal.connect(self._on_training_error)
        self.train_thread.finished_signal.connect(self.training_complete)
        self.train_thread.start()

    def _on_training_error(self, msg):
        """Record that the current training run failed and log the worker's message."""
        self._training_failed = True
        self.log_message(f"⚠️ Training error: {msg}")

    def training_complete(self):
        """Report the outcome of the training run and re-enable the controls.

        The worker emits its finished signal after both success and failure,
        so the flag set by ``_on_training_error`` decides which message is shown.
        """
        if getattr(self, "_training_failed", False):
            self.log_message("⚠️ Training stopped because of an error.")
        else:
            self.log_message("✅ Training completed!")
        self.train_btn.setEnabled(True)
        self.load_dataset_btn.setEnabled(True)
        self.update_ui_state()

    # ---------------- Generation ----------------
    def generate_music(self):
        """Build a seed sequence and start note generation on a worker thread.

        The seed is the first ``sequence_length`` dataset notes, or the
        vocabulary when no dataset is loaded; shorter sources are repeated to
        fill the window. Generation is refused while another generation or a
        training run is still active, and when a seed note is missing from
        the vocabulary.
        """
        if self.model is None:
            self.log_message("⚠️ Error: No model loaded or trained.")
            return

        # Replacing self.gen_thread while it runs would destroy a live QThread.
        if self.gen_thread is not None and self.gen_thread.isRunning():
            self.log_message("⚠️ Generation is already in progress.")
            return

        # The trainer thread is fitting the same model object.
        if self.train_thread is not None and self.train_thread.isRunning():
            self.log_message("⚠️ Cannot generate while training is in progress.")
            return

        sequence_length = self.seq_length_spin.value()
        num_notes = self.gen_notes_spin.value()
        temperature = self.temp_spin.value()

        # Prefer the loaded dataset; fall back to the model's vocabulary.
        source = self.notes if self.notes else self.unique_notes
        if not source:
            self.log_message("⚠️ Error: No notes available for generation.")
            return

        if len(source) >= sequence_length:
            seed_seq = list(source[:sequence_length])
        else:
            # Repeat the short source until it covers the whole window.
            repeats = sequence_length // len(source) + 1
            seed_seq = (list(source) * repeats)[:sequence_length]

        # Verify we can map all seed notes to integers
        missing = next((n for n in seed_seq if n not in self.note_to_int), None)
        if missing is not None:
            self.log_message(
                f"⚠️ Error: Note {missing!r} not found in vocabulary. Load matching dataset or model."
            )
            return

        self.log_message(f"🎼 Generating {num_notes} notes (temperature={temperature})...")
        self.log_message(f"🔠 Using seed sequence: {seed_seq}")

        self.gen_thread = MusicGenerator(
            self.model, seed_seq, self.note_to_int, self.unique_notes,
            sequence_length, num_notes, temperature
        )
        self.gen_thread.progress_signal.connect(self.update_progress)
        self.gen_thread.finished_signal.connect(self.generation_complete)
        self.gen_thread.start()

    def generation_complete(self, generated_notes):
        """Store the generated notes and enable MIDI export; an empty list means failure."""
        succeeded = bool(generated_notes)
        if succeeded:
            self.generated_notes = generated_notes
            self.log_message(f"🎶 Generation complete! First 20 notes: {generated_notes[:20]}")
        else:
            self.log_message("⚠️ Generation failed.")
        self.save_midi_btn.setEnabled(succeeded)

    # ---------------- Save/Load Model ----------------
    def save_model(self):
        """Save the model weights plus a sidecar metadata pickle.

        The weights go to ``<name>.weights.h5``. The vocabulary, both note
        mappings, the sequence length and the model config go to
        ``<name>.metadata.pkl`` beside it. Errors are logged, not raised.
        """
        if self.model is None:
            self.log_message("⚠️ Error: No model to save.")
            return

        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save Model", "music_model.weights.h5",
            "Model Weights (*.weights.h5);;All Files (*)", options=options
        )
        if not file_path:
            return

        weights_suffix = '.weights.h5'
        try:
            # Ensure proper file extension
            if not file_path.endswith(weights_suffix):
                file_path += weights_suffix

            # Save model weights
            self.model.save_weights(file_path)

            # Swap only the trailing suffix; str.replace would also rewrite a
            # '.weights.h5' that appears earlier in the path (e.g. a folder name).
            metadata_path = file_path[:-len(weights_suffix)] + '.metadata.pkl'
            metadata = {
                'unique_notes': self.unique_notes,
                'note_to_int': self.note_to_int,
                'int_to_note': self.int_to_note,
                'seq_length': self.seq_length_spin.value(),
                'model_config': self.model.get_config()  # Save model architecture
            }

            with open(metadata_path, 'wb') as f:
                pickle.dump(metadata, f)

            self.log_message(f"💾 Saved model to:\n  {file_path}\n  {metadata_path}")
        except Exception as e:
            self.log_message(f"⚠️ Error saving model: {e}")

    def load_model(self):
        """Load model weights and their metadata, replacing the current model.

        The metadata file is looked up next to the weights
        (``<name>.metadata.pkl``); if it is not there the user is asked for
        it. The model is rebuilt from the stored config, compiled, and given
        the weights before any application state is changed, so a failed
        load leaves the previous model and vocabulary intact. Errors are
        logged, not raised.
        """
        options = QFileDialog.Options()
        weights_path, _ = QFileDialog.getOpenFileName(
            self, "Select Model Weights", "",
            "Weights (*.weights.h5);;All Files (*)", options=options
        )
        if not weights_path:
            return

        weights_suffix = '.weights.h5'
        try:
            # Derive the sidecar path only when the suffix is really present;
            # otherwise the "metadata" path would be the weights file itself.
            metadata_path = ""
            if weights_path.endswith(weights_suffix):
                metadata_path = weights_path[:-len(weights_suffix)] + '.metadata.pkl'

            if not metadata_path or not os.path.exists(metadata_path):
                # If not found, prompt user to select it
                metadata_path, _ = QFileDialog.getOpenFileName(
                    self, "Select Metadata File", "",
                    "Metadata (*.metadata.pkl);;All Files (*)", options=options
                )
                if not metadata_path:
                    return

            # SECURITY: unpickling can execute arbitrary code. Only load
            # metadata files that come from a trusted source.
            with open(metadata_path, 'rb') as f:
                metadata = pickle.load(f)

            unique_notes = metadata['unique_notes']
            note_to_int = metadata['note_to_int']
            int_to_note = metadata['int_to_note']
            seq_length = metadata['seq_length']

            # Rebuild the model from saved config, then load the weights.
            model = tf.keras.Sequential.from_config(metadata['model_config'])
            model.compile(loss="sparse_categorical_crossentropy",
                          optimizer="adam",
                          metrics=["accuracy"])
            model.load_weights(weights_path)

            # Everything loaded: commit the new state in one go.
            self.model = model
            self.unique_notes = unique_notes
            self.note_to_int = note_to_int
            self.int_to_note = int_to_note
            self.seq_length_spin.setValue(seq_length)

            # Cached training arrays were encoded with the previous mapping.
            self.X = None
            self.y = None

            # Initialize notes for generation (use unique notes as fallback)
            if not self.notes:
                self.notes = self.unique_notes.copy()

            self.log_message(f"✅ Loaded model from: {weights_path}")
            self.log_message(f"✅ Loaded metadata from: {metadata_path}")
            self.log_model_summary()
            self.dataset_label.setText(f"Model loaded • {len(self.unique_notes)} unique notes")
            self.update_ui_state()
        except Exception as e:
            self.log_message(f"⚠️ Error loading model: {e}")

    # ---------------- Save MIDI ----------------
    def save_midi(self):
        """Write the generated notes to a MIDI file chosen by the user.

        Tokens made only of dot-separated integers are chords stored as pitch
        classes; each class is offset by 60 to get a MIDI pitch. A chord with
        a single pitch class has no dot (e.g. ``"0"``) and is handled the same
        way. Every other token is passed to music21 as a note name. Tokens
        that cannot be converted are logged and skipped.
        """
        if not self.generated_notes:
            self.log_message("⚠️ Error: No generated music to save.")
            return

        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getSaveFileName(
            self, "Save MIDI File", "generated_music.mid",
            "MIDI Files (*.mid);;All Files (*)", options=options
        )
        if not file_path:
            return

        try:
            score = m21.stream.Stream()
            for token in self.generated_notes:
                try:
                    parts = token.split(".")
                    # Note names always contain a letter, so an all-integer
                    # token is a chord even when it has only one part.
                    if all(p.strip("-").isdigit() for p in parts):
                        chord_notes = [m21.note.Note(int(num) + 60) for num in parts]
                        score.append(m21.chord.Chord(chord_notes))
                    else:
                        score.append(m21.note.Note(token))
                except Exception as e:
                    self.log_message(f"⚠️ Skipped note {token}: {e}")

            score.write("midi", fp=file_path)
            self.log_message(f"💾 Saved MIDI file: {file_path}")
        except Exception as e:
            self.log_message(f"⚠️ Error saving MIDI: {e}")


# =======================================================================
#                           WORKER THREADS
# =======================================================================
class DatasetLoader(QThread):
    """Worker thread that extracts note tokens from a MIDI file or a ZIP of MIDI files.

    Signals:
        progress_signal(int, str): progress percentage and a status message.
        finished_signal(list): the extracted note tokens; an empty list on
            failure or when no MIDI file was found.
    """

    progress_signal = pyqtSignal(int, str)
    finished_signal = pyqtSignal(list)

    # Upper bound on the number of MIDI files parsed from one archive (cap for speed).
    MAX_FILES = 500

    def __init__(self, path):
        """Remember the path of the ``.mid``/``.midi`` file or ZIP archive to load."""
        super().__init__()
        self.path = path

    def run(self):
        """Parse the dataset and emit the notes through ``finished_signal``.

        A ZIP archive is extracted into a fresh temporary directory that is
        removed afterwards, so every load reads the archive that was actually
        selected instead of files left behind by an earlier run.
        """
        import tempfile

        try:
            # Handle single MIDI directly
            if self.path.lower().endswith(('.mid', '.midi')):
                self.progress_signal.emit(5, "Parsing single MIDI file...")
                notes = self.parse_midi_file(self.path)
                self.progress_signal.emit(100, f"🎶 Parsed single MIDI. Notes: {len(notes)}")
                self.finished_signal.emit(notes)
                return

            notes = []
            skipped = 0

            # Else treat as zip
            with tempfile.TemporaryDirectory(prefix="midi_extract_") as extract_path:
                with zipfile.ZipFile(self.path, 'r') as zip_ref:
                    zip_ref.extractall(extract_path)
                self.progress_signal.emit(10, "✅ Dataset extracted. Processing MIDI files...")

                all_files = self._find_midi_files(extract_path)
                total_files = len(all_files)
                if total_files == 0:
                    self.progress_signal.emit(0, "⚠️ No MIDI files found in dataset.")
                    self.finished_signal.emit([])
                    return

                self.progress_signal.emit(15, f"Found {total_files} MIDI files to process...")

                limit = min(self.MAX_FILES, total_files)
                for i, filepath in enumerate(all_files[:limit]):
                    try:
                        notes.extend(self.parse_midi_file(filepath))
                    except Exception:
                        # Best effort: an unreadable file is counted, not fatal.
                        skipped += 1

                    # Update progress (15% -> 95%)
                    progress = 15 + int(80 * (i + 1) / limit)
                    self.progress_signal.emit(progress, f"Processed {i+1}/{limit} files...")

            self.progress_signal.emit(97, "Finalizing...")
            self.progress_signal.emit(100, f"🎶 Extraction complete. Total notes: {len(notes)} | Skipped files: {skipped}")
            self.finished_signal.emit(notes)

        except Exception as e:
            self.progress_signal.emit(0, f"⚠️ Error loading dataset: {e}")
            self.finished_signal.emit([])

    @staticmethod
    def _find_midi_files(root_dir):
        """Return the sorted paths of all ``.mid``/``.midi`` files under root_dir."""
        found = [
            os.path.join(folder, name)
            for folder, _, names in os.walk(root_dir)
            for name in names
            if name.lower().endswith((".mid", ".midi"))
        ]
        # os.walk order is filesystem dependent; sorting makes the MAX_FILES
        # cap select the same files on every run.
        found.sort()
        return found

    @staticmethod
    def parse_midi_file(filepath):
        """Return the note tokens of one MIDI file.

        Single notes become their pitch string; chords become their normal
        order joined by dots (e.g. "0.4.7").
        """
        notes = []
        midi = m21.converter.parse(filepath)
        for element in midi.flatten().notes:
            if isinstance(element, m21.note.Note):
                notes.append(str(element.pitch))
            elif isinstance(element, m21.chord.Chord):
                notes.append(".".join(str(p) for p in element.normalOrder))
        return notes


class ModelTrainer(QThread):
    """Worker thread that fits a Keras model and reports per-epoch progress.

    Signals:
        progress_signal(int, str): progress percentage and an epoch message.
        finished_signal(): emitted when run() ends, after success or failure.
        error_signal(str): text of the exception raised while training.
    """

    progress_signal = pyqtSignal(int, str)  # (progress %, message)
    finished_signal = pyqtSignal()
    error_signal = pyqtSignal(str)

    def __init__(self, model, X, y, epochs, batch_size):
        """Store the model, training arrays and fit parameters for run()."""
        super().__init__()
        self.model = model
        self.X, self.y = X, y
        self.epochs = epochs
        self.batch_size = batch_size

    def run(self):
        """Call model.fit, forwarding each finished epoch as a progress signal."""
        try:
            report = self.progress_signal.emit
            epoch_count = self.epochs

            class EpochReporter(callbacks.Callback):
                # Bridges the Keras epoch-end hook to the Qt progress signal.
                def on_epoch_end(reporter, epoch, logs=None):
                    finished = epoch + 1
                    loss = (logs or {}).get('loss', None)
                    if loss is None:
                        text = f"Epoch {finished}/{epoch_count}"
                    else:
                        text = f"Epoch {finished}/{epoch_count} - loss: {loss:.4f}"
                    report(int(100 * finished / epoch_count), text)

            self.model.fit(
                self.X,
                self.y,
                epochs=self.epochs,
                batch_size=self.batch_size,
                callbacks=[EpochReporter()],
                verbose=0,
            )
        except Exception as exc:
            self.error_signal.emit(str(exc))
        finally:
            self.finished_signal.emit()


class MusicGenerator(QThread):
    """Worker thread that samples new notes from the model one at a time.

    Signals:
        progress_signal(int, str): progress percentage and a status message.
        finished_signal(list): the seed sequence followed by the generated
            notes; an empty list if generation failed.
    """

    progress_signal = pyqtSignal(int, str)
    finished_signal = pyqtSignal(list)

    def __init__(self, model, start_seq, note_to_int, unique_notes, seq_length, num_notes, temp):
        """Store the model, seed, vocabulary lookups and sampling settings for run()."""
        super().__init__()
        self.model = model
        self.start_seq = start_seq
        self.note_to_int = note_to_int
        self.unique_notes = unique_notes
        self.seq_length = seq_length
        self.num_notes = num_notes
        self.temp = temp

    @staticmethod
    def _apply_temperature(probabilities, temperature):
        """Return probabilities re-weighted by temperature as a float64 distribution."""
        # float64 plus max-subtraction keeps exp() from underflowing to an
        # all-zero vector at low temperatures with a large vocabulary.
        logits = np.log(np.asarray(probabilities, dtype=np.float64) + 1e-8) / temperature
        logits -= logits.max()
        weights = np.exp(logits)
        return weights / weights.sum()

    def run(self):
        """Extend the seed by ``num_notes`` sampled notes and emit the result."""
        try:
            generated = list(self.start_seq)

            for i in range(self.num_notes):
                window = generated[-self.seq_length:]
                seq_input = np.array(
                    [self.note_to_int[n] for n in window], dtype=np.int32
                ).reshape(1, -1)

                # One small batch per step: call the model directly instead
                # of predict(), which is intended for large batched inputs.
                prediction = np.asarray(self.model(seq_input, training=False))[0]

                # Temperature scaling
                distribution = self._apply_temperature(prediction, self.temp)
                idx = np.random.choice(len(distribution), p=distribution)

                generated.append(self.unique_notes[idx])

                if i % 10 == 0 or i == self.num_notes - 1:
                    progress = int(100 * (i + 1) / self.num_notes)
                    self.progress_signal.emit(progress, f"Generating note {i+1}/{self.num_notes}")

            self.progress_signal.emit(100, "Generation complete!")
            self.finished_signal.emit(generated)
        except Exception as e:
            self.progress_signal.emit(0, f"⚠️ Generation error: {e}")
            self.finished_signal.emit([])


# =======================================================================
#                                MAIN
# =======================================================================
if __name__ == "__main__":
    import sys

    # (Optional) reduce TF verbosity.
    # NOTE(review): tensorflow is already imported at the top of the file by
    # the time this runs; confirm the variable still takes effect this late.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

    # Create the Qt application, show the main window, and exit with the
    # event loop's return code.
    qt_app = QApplication(sys.argv)
    main_window = MusicGeneratorApp()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)



# Captured notebook output from running the cell above:
# SystemExit: 0
#
#   warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
