In [None]:
# MIDI Style Transfer UI - Jupyter Notebook Cell
# Run this cell to launch the interactive UI

import torch
import numpy as np
import os
import json
import pretty_midi
from fluidsynth import Synth
import wave
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import threading
from pathlib import Path
import tempfile
import shutil

# Import your model classes (make sure these are available)
try:
    from GAN import Generator
    from TransformerVAE import TransformerVAE
    print("Model classes imported successfully")
except ImportError as e:
    print(f"Warning: Could not import model classes - {e}")
    print("Make sure GAN.py and TransformerVAE.py are in your Python path")

class MidiStyleTransferApp:
    def __init__(self):
        """Initialise application state, create the main window and build its widgets."""
        # Application state: nothing is loaded or selected yet
        self.config = self.create_config()
        self.models = {}
        self.selected_model = None
        self.input_midi_path = None

        # Top-level window
        window = tk.Tk()
        window.title("MIDI Style Transfer")
        window.geometry("600x500")
        self.root = window

        # Populate the window with widgets
        self.setup_gui()
        
    def create_config(self):
        """Create configuration object with default paths - update these for your system"""
        class Config:
            # Update these paths for your system
            SOUNDFONT_PATH = r"C:\Users\User\Desktop\college\fyp\other\soundfont\GeneralUser GS v1.471.sf2"
            OUTPUT_DIR = r"C:\Users\User\Desktop\college\fyp\converted_samples"
            
            # Model paths - update these
            GAN_MODEL_PATH = r"C:\Users\User\Desktop\college\fyp\models\PPO_Tuned_GAN\ppo_tuned_GAN_best.pth"
            RAGAN_MODEL_PATH = r"C:\Users\User\Desktop\college\fyp\models\PPO_Tuned_RaGAN\ppo_tuned_RaGAN_best.pth"
            VAE_MODEL_PATH = r"C:\Users\User\Desktop\college\fyp\models\PPO_Tuned_VAE\ppo_tuned_VAE_epoch_50.pth"
            
            # Model parameters
            DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            PITCHES = 88
            TIMESTEPS = 128
            CHANNELS = 3
            NOISE_DIM = 100
            
            # VAE parameters
            LATENT_DIM = 256
            EMBED_DIM = 512
            NHEAD = 8
            NUM_ENCODER_LAYERS = 6
            NUM_DECODER_LAYERS = 6
            DROPOUT = 0.1
            
        return Config()
    
    def setup_gui(self):
        """Create all widgets of the main window and place them on a grid.

        Layout (rows of the inner frame): 0 title, 1 model selector,
        2 load button, 3 model status, 4 file picker, 5 chosen file,
        6 process button, 7 progress bar, 8 status log with scrollbar.
        """
        label_font = ("Arial", 12)
        stretch_horizontal = (tk.W, tk.E)
        stretch_both = (tk.W, tk.E, tk.N, tk.S)

        # Frame that contains every other widget
        container = ttk.Frame(self.root, padding="10")
        container.grid(row=0, column=0, sticky=stretch_both)

        # Row 0: heading
        heading = ttk.Label(container, text="MIDI Style Transfer", font=("Arial", 16, "bold"))
        heading.grid(row=0, column=0, columnspan=2, pady=(0, 20))

        # Row 1: model selector (read-only drop-down)
        model_caption = ttk.Label(container, text="Select Model:", font=label_font)
        model_caption.grid(row=1, column=0, sticky=tk.W, pady=5)
        self.model_var = tk.StringVar(value="Select a model...")
        model_choice = ttk.Combobox(
            container,
            textvariable=self.model_var,
            values=["GAN", "RaGAN", "VAE"],
            state="readonly",
            width=30,
        )
        model_choice.grid(row=1, column=1, sticky=stretch_horizontal, pady=5)
        model_choice.bind("<<ComboboxSelected>>", self.on_model_selected)

        # Row 2: load button, disabled until a model has been picked
        self.load_model_btn = ttk.Button(
            container, text="Load Model", command=self.load_model, state=tk.DISABLED
        )
        self.load_model_btn.grid(row=2, column=1, sticky=tk.W, pady=5)

        # Row 3: load status
        self.model_status_label = ttk.Label(container, text="No model loaded", foreground="red")
        self.model_status_label.grid(row=3, column=0, columnspan=2, pady=5)

        # Row 4: MIDI file picker
        file_caption = ttk.Label(container, text="Select MIDI File:", font=label_font)
        file_caption.grid(row=4, column=0, sticky=tk.W, pady=(20, 5))
        self.file_btn = ttk.Button(container, text="Browse MIDI File", command=self.browse_midi_file)
        self.file_btn.grid(row=4, column=1, sticky=tk.W, pady=(20, 5))

        # Row 5: name of the chosen file
        self.file_label = ttk.Label(container, text="No file selected", foreground="gray")
        self.file_label.grid(row=5, column=0, columnspan=2, pady=5)

        # Row 6: start button, disabled until both a model and a file are available
        self.process_btn = ttk.Button(
            container,
            text="Generate Style Transfer",
            command=self.process_midi,
            state=tk.DISABLED,
            style="Accent.TButton",
        )
        self.process_btn.grid(row=6, column=0, columnspan=2, pady=20)

        # Row 7: busy indicator
        self.progress = ttk.Progressbar(container, mode='indeterminate')
        self.progress.grid(row=7, column=0, columnspan=2, sticky=stretch_horizontal, pady=5)

        # Row 8: read-only status log and its vertical scrollbar
        self.status_text = tk.Text(container, height=10, width=70, state=tk.DISABLED)
        self.status_text.grid(row=8, column=0, columnspan=2, pady=10, sticky=stretch_both)
        log_scroll = ttk.Scrollbar(container, orient=tk.VERTICAL, command=self.status_text.yview)
        log_scroll.grid(row=8, column=2, sticky=(tk.N, tk.S), pady=10)
        self.status_text.configure(yscrollcommand=log_scroll.set)

        # Let the second column and the log row absorb any extra space
        container.columnconfigure(1, weight=1)
        container.rowconfigure(8, weight=1)
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
    
    def log_status(self, message):
        """Append ``message`` as a new line to the status log and scroll to it."""
        log_box = self.status_text
        # The log is kept read-only; unlock it only for the duration of the insert
        log_box.config(state=tk.NORMAL)
        log_box.insert(tk.END, f"{message}\n")
        log_box.see(tk.END)
        log_box.config(state=tk.DISABLED)
        # Flush pending redraws so the new line is shown right away
        self.root.update_idletasks()
    
    def on_model_selected(self, event=None):
        """Combobox callback: unlock the 'Load Model' button once a real model is chosen."""
        if self.model_var.get() == "Select a model...":
            # Still showing the placeholder text; nothing to load yet
            return
        self.load_model_btn.config(state=tk.NORMAL)
    
    def load_model(self):
        """Instantiate the model chosen in the drop-down and load its checkpoint.

        On success the model is put in eval mode, stored in ``self.models``
        under its name, recorded as ``self.selected_model`` and the process
        button state is refreshed. On any failure the error is shown in the
        status label, the log and a message box; nothing is raised.
        """
        model_name = self.model_var.get()
        if model_name == "Select a model...":
            return

        self.log_status(f"Loading {model_name} model...")

        try:
            checkpoint_paths = {
                'GAN': self.config.GAN_MODEL_PATH,
                'RaGAN': self.config.RAGAN_MODEL_PATH,
                'VAE': self.config.VAE_MODEL_PATH,
            }
            if model_name not in checkpoint_paths:
                raise ValueError(f"Unknown model type: {model_name}")

            checkpoint_path = checkpoint_paths[model_name]
            if not os.path.exists(checkpoint_path):
                raise FileNotFoundError(f"{model_name} model file not found: {checkpoint_path}")

            # GAN and RaGAN share the Generator architecture. The class name is
            # resolved here, inside the try, so that a failed import of the model
            # modules is reported through the normal error path below.
            model_class = TransformerVAE if model_name == 'VAE' else Generator
            model = model_class(self.config).to(self.config.DEVICE)

            # weights_only=True restricts unpickling to tensors and plain
            # containers, which is all a state_dict needs; it avoids executing
            # arbitrary pickled code and the warning torch emits without it.
            # NOTE(review): requires a torch version that supports the
            # `weights_only` argument.
            state_dict = torch.load(
                checkpoint_path,
                map_location=self.config.DEVICE,
                weights_only=True,
            )
            model.load_state_dict(state_dict)

            model.eval()
            self.models[model_name] = model
            self.selected_model = model_name

            self.model_status_label.config(text=f"{model_name} model loaded successfully", foreground="green")
            self.log_status(f"{model_name} model loaded successfully")
            self.update_process_button_state()

        except Exception as e:
            # UI boundary: report every failure to the user instead of crashing the app
            error_msg = f"Error loading {model_name} model: {str(e)}"
            self.model_status_label.config(text=error_msg, foreground="red")
            self.log_status(error_msg)
            messagebox.showerror("Model Loading Error", error_msg)
    
    def browse_midi_file(self):
        """Ask the user for a MIDI file and remember the choice.

        Does nothing when the dialog is cancelled.
        """
        chosen_path = filedialog.askopenfilename(
            title="Select MIDI File",
            filetypes=[("MIDI files", "*.mid *.midi"), ("All files", "*.*")],
        )
        if not chosen_path:
            return

        self.input_midi_path = chosen_path
        # Show only the file name, not the whole path, in the UI
        short_name = os.path.basename(chosen_path)
        self.file_label.config(text=f"Selected: {short_name}", foreground="black")
        self.log_status(f"Selected MIDI file: {short_name}")
        self.update_process_button_state()
    
    def update_process_button_state(self):
        """Enable the process button only when a model is loaded and a file is chosen."""
        ready = self.selected_model and self.input_midi_path
        new_state = tk.NORMAL if ready else tk.DISABLED
        self.process_btn.config(state=new_state)
    
    def midi_to_matrices(self, midi_path):
        """Convert the piano parts of a MIDI file into onset/sustain/velocity matrices.

        Notes from all non-drum instruments whose program is in the piano
        family (programs 0-7) are merged and quantised on a grid of 24 steps
        per beat, using the first tempo reported by the file (120 BPM when
        none is reported).

        Args:
            midi_path: Path of the MIDI file to read.

        Returns:
            Tuple ``(onset_matrix, sustain_matrix, velocity_matrix)`` of float32
            arrays shaped ``(timesteps, 88)``. Onset holds 1.0 at the step a
            note starts, sustain holds 1.0 from the start step up to (not
            including) the end step, velocity holds ``velocity / 127`` at the
            start step.

        Raises:
            Exception: Wrapping any failure (unreadable file, no piano notes,
                ...); the original error is attached as ``__cause__``.
        """
        try:
            pm = pretty_midi.PrettyMIDI(midi_path)

            # Collect notes of piano-family instruments. Percussion tracks usually
            # carry program 0 as well, so they must be excluded explicitly or
            # drum hits would be read as piano notes.
            PIANO_PROGRAMS = range(8)
            all_piano_notes = []
            for instrument in pm.instruments:
                if instrument.is_drum:
                    continue
                if instrument.program in PIANO_PROGRAMS:
                    all_piano_notes.extend(instrument.notes)

            if not all_piano_notes:
                raise ValueError("No piano tracks found in MIDI file")

            # Sort notes by start time
            all_piano_notes.sort(key=lambda x: x.start)

            # Timing grid: 24 steps per beat at the file's first tempo
            song_end_time = max(note.end for note in all_piano_notes)
            resolution = 24
            _, tempi = pm.get_tempo_changes()
            tempo_bpm = tempi[0] if len(tempi) > 0 else 120.0
            ticks_per_second = (resolution * tempo_bpm) / 60.0

            total_timesteps = int(np.ceil(song_end_time * ticks_per_second))

            # Initialize matrices
            onset_matrix = np.zeros((total_timesteps, 88), dtype=np.float32)
            sustain_matrix = np.zeros((total_timesteps, 88), dtype=np.float32)
            velocity_matrix = np.zeros((total_timesteps, 88), dtype=np.float32)

            # Populate matrices
            LOWEST_PITCH = 21
            for note in all_piano_notes:
                if 21 <= note.pitch <= 108:  # 88-key piano range
                    pitch_idx = note.pitch - LOWEST_PITCH

                    onset_step = int(round(note.start * ticks_per_second))
                    offset_step = int(round(note.end * ticks_per_second))

                    # Rounding can push a very late onset onto the step just past the grid
                    if onset_step < total_timesteps:
                        # Onset and velocity
                        onset_matrix[onset_step, pitch_idx] = 1.0
                        velocity_matrix[onset_step, pitch_idx] = note.velocity / 127.0

                        # Sustain
                        sustain_end = min(offset_step, total_timesteps)
                        sustain_matrix[onset_step:sustain_end, pitch_idx] = 1.0

            return onset_matrix, sustain_matrix, velocity_matrix

        except Exception as e:
            # Keep the original exception as the cause for easier debugging
            raise Exception(f"Error converting MIDI to matrices: {str(e)}") from e
    
    def matrices_to_segments(self, onset_matrix, sustain_matrix, velocity_matrix):
        """Convert matrices to model input segments"""
        segments = []
        timesteps = self.config.TIMESTEPS
        
        # Combine matrices into 3-channel representation
        combined_matrix = np.stack([onset_matrix, sustain_matrix, velocity_matrix], axis=0)  # Shape: (3, time, pitch)
        
        # Split into segments
        total_time = combined_matrix.shape[1]
        num_segments = (total_time + timesteps - 1) // timesteps  # Ceiling division
        
        for i in range(num_segments):
            start_idx = i * timesteps
            end_idx = min(start_idx + timesteps, total_time)
            
            # Extract segment
            segment = combined_matrix[:, start_idx:end_idx, :]
            
            # Pad if necessary
            if segment.shape[1] < timesteps:
                padding = np.zeros((3, timesteps - segment.shape[1], 88))
                segment = np.concatenate([segment, padding], axis=1)
            
            segments.append(torch.tensor(segment, dtype=torch.float32))
        
        return segments
    
    def segments_to_midi(self, segments, output_path, onset_threshold=0.5):
        """Decode model output segments into a single-piano MIDI file.

        Args:
            segments: Sequence of CPU tensors, each indexed as
                ``[channel, time, pitch]`` with channels onset, sustain and
                velocity; they are joined along the time axis.
            output_path: Where the MIDI file is written.
            onset_threshold: An onset value strictly above this starts a note.

        A note starts when its onset exceeds the threshold while the pitch is
        silent, and ends at the first later step whose sustain value is below
        0.5. An onset on a pitch that is already sounding is ignored. Notes
        still sounding at the end are closed at the final step. Timing is
        written at 120 BPM with 24 steps per beat.
        """
        steps_per_beat = 24
        base_pitch = 21
        seconds_per_step = 60.0 / (120.0 * steps_per_beat)

        # Join all segments along time: (3, total_time, 88)
        joined = torch.cat(segments, dim=1)
        onset_plane = joined[0].numpy()
        sustain_plane = joined[1].numpy()
        velocity_plane = joined[2].numpy()

        onset_flags = (onset_plane > onset_threshold).astype(int)
        step_count = onset_flags.shape[0]

        midi_file = pretty_midi.PrettyMIDI(initial_tempo=120.0)
        piano = pretty_midi.Instrument(program=0, name='Acoustic Grand Piano')

        def emit_note(key, begin, finish, loudness):
            # Zero-length notes are dropped; velocity is clamped to the MIDI range
            if finish > begin:
                piano.notes.append(pretty_midi.Note(
                    velocity=max(1, min(127, loudness)),
                    pitch=key + base_pitch,
                    start=begin,
                    end=finish
                ))

        # Maps a sounding key index to (start time in seconds, velocity)
        sounding = {}
        for step in range(step_count):
            now = step * seconds_per_step
            for key in range(88):
                if key in sounding:
                    # Key is held: release it once sustain drops
                    if sustain_plane[step, key] < 0.5:
                        begin, loudness = sounding.pop(key)
                        emit_note(key, begin, now, loudness)
                elif onset_flags[step, key] == 1:
                    # Key is silent and an onset fired: start a new note
                    loudness = int(velocity_plane[step, key] * 126) + 1
                    sounding[key] = (now, loudness)

        # Close whatever is still sounding at the end of the piece
        closing_time = step_count * seconds_per_step
        for key, (begin, loudness) in sounding.items():
            emit_note(key, begin, closing_time, loudness)

        midi_file.instruments.append(piano)
        midi_file.write(output_path)
    
    def midi_to_wav(self, midi_path, wav_path):
        """Render a MIDI file to a mono 16-bit 44.1 kHz WAV file via FluidSynth.

        Args:
            midi_path: MIDI file to render.
            wav_path: Destination WAV file.

        Raises:
            Exception: Wrapping any failure (missing soundfont, synthesis or
                write error); the original error is attached as ``__cause__``.
        """
        sample_rate = 44100
        try:
            soundfont_path = self.config.SOUNDFONT_PATH
            if not os.path.exists(soundfont_path):
                raise FileNotFoundError(f"Soundfont not found: {soundfont_path}")

            pm = pretty_midi.PrettyMIDI(midi_path)
            # Pass the configured soundfont explicitly; without sf2_path
            # pretty_midi falls back to its own default soundfont and the
            # file validated above would never be used.
            audio_data = pm.fluidsynth(fs=sample_rate, sf2_path=soundfont_path)

            with wave.open(wav_path, 'wb') as wf:
                wf.setnchannels(1)
                wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
                wf.setframerate(sample_rate)
                # Scale the float samples to the int16 range
                wf.writeframes((audio_data * 32767).astype(np.int16).tobytes())

        except Exception as e:
            # Keep the original exception as the cause for easier debugging
            raise Exception(f"Error converting MIDI to WAV: {str(e)}") from e
    
    def process_midi_thread(self):
        """Run the whole style-transfer pipeline; intended as a worker-thread target.

        Steps: MIDI -> matrices -> fixed-length segments -> model inference ->
        MIDI -> WAV. Progress is written to the status log. Success or failure
        is reported with a message box scheduled through ``root.after``, and
        ``processing_complete`` is always scheduled at the end.
        """
        try:
            # Snapshot the selection once: the model drop-down and the file
            # picker stay usable while this runs, so re-reading the attributes
            # later could mix two different models or files in a single run.
            model_name = self.selected_model
            midi_path = self.input_midi_path
            model = self.models[model_name]

            self.log_status("Starting MIDI style transfer process...")

            # Create output directory
            os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)

            # Step 1: Convert MIDI to matrices
            self.log_status("Converting MIDI to matrices...")
            onset_matrix, sustain_matrix, velocity_matrix = self.midi_to_matrices(midi_path)
            self.log_status(f"Matrix shapes: {onset_matrix.shape}")

            # Step 2: Convert to segments
            self.log_status("Preparing model input segments...")
            input_segments = self.matrices_to_segments(onset_matrix, sustain_matrix, velocity_matrix)
            self.log_status(f"Created {len(input_segments)} segments")
            if not input_segments:
                # Fail with a clear message rather than inside torch.cat later on
                raise ValueError("MIDI file produced no segments to process")

            # Step 3: Run model inference
            self.log_status(f"Running {model_name} model inference...")
            output_segments = []
            segment_count = len(input_segments)

            with torch.no_grad():
                for done, segment in enumerate(input_segments, start=1):
                    input_tensor = segment.unsqueeze(0).to(self.config.DEVICE)  # Add batch dimension

                    if model_name in ('GAN', 'RaGAN'):
                        noise = torch.randn(1, self.config.NOISE_DIM).to(self.config.DEVICE)
                        output = model(input_tensor, noise)
                    elif model_name == 'VAE':
                        # Only the first of the three returned values is used
                        output, _, _ = model(input_tensor)
                    else:
                        raise ValueError(f"Unknown model type: {model_name}")

                    output_segments.append(output.squeeze(0).cpu())  # Remove batch dimension

                    if done % 10 == 0:
                        self.log_status(f"Processed {done}/{segment_count} segments")

            # Step 4: Convert back to MIDI
            base_filename = os.path.splitext(os.path.basename(midi_path))[0]
            midi_output_path = os.path.join(self.config.OUTPUT_DIR, f"{base_filename}_{model_name}_output.mid")
            wav_output_path = os.path.join(self.config.OUTPUT_DIR, f"{base_filename}_{model_name}_output.wav")

            self.log_status("Converting output to MIDI...")
            # The VAE output is decoded with a lower onset threshold than the GANs
            onset_threshold = 0.1 if model_name == 'VAE' else 0.5
            self.segments_to_midi(output_segments, midi_output_path, onset_threshold)

            # Step 5: Convert MIDI to WAV
            self.log_status("Converting MIDI to WAV...")
            self.midi_to_wav(midi_output_path, wav_output_path)

            self.log_status("✓ Process completed successfully!")
            self.log_status(f"✓ MIDI saved to: {midi_output_path}")
            self.log_status(f"✓ WAV saved to: {wav_output_path}")

            # Show completion message
            self.root.after(0, lambda: messagebox.showinfo(
                "Success",
                f"Style transfer completed!\n\nFiles saved to:\n{self.config.OUTPUT_DIR}"
            ))

        except Exception as e:
            # Worker boundary: report every failure to the user instead of dying silently
            error_msg = f"Error during processing: {str(e)}"
            self.log_status(f"✗ {error_msg}")
            self.root.after(0, lambda: messagebox.showerror("Processing Error", error_msg))

        finally:
            # Re-enable button and stop progress bar
            self.root.after(0, self.processing_complete)
    
    def process_midi(self):
        """Button callback: lock the UI and run the pipeline on a background thread."""
        # Block repeated clicks and show activity while the worker runs
        self.process_btn.config(state=tk.DISABLED)
        self.progress.start()

        # Daemon thread, so a running job never keeps the interpreter alive on exit
        worker = threading.Thread(target=self.process_midi_thread, daemon=True)
        worker.start()
    
    def processing_complete(self):
        """Stop the busy indicator and make the process button clickable again."""
        progress_bar, start_button = self.progress, self.process_btn
        progress_bar.stop()
        start_button.config(state=tk.NORMAL)
    
    def run(self):
        """Print the usage instructions to the status log and enter the Tk main loop."""
        welcome_lines = (
            "MIDI Style Transfer Application Started",
            "1. Select a model and click 'Load Model'",
            "2. Browse and select a MIDI file",
            "3. Click 'Generate Style Transfer' to process",
        )
        for line in welcome_lines:
            self.log_status(line)

        # Blocks until the window is closed
        self.root.mainloop()

# Create and run the application.
# Taken when the code runs as a script; any start-up failure is printed
# together with its traceback instead of propagating.
if __name__ == "__main__":
    try:
        app = MidiStyleTransferApp()
        app.run()
    except Exception as e:
        print(f"Error starting application: {e}")
        import traceback
        traceback.print_exc()
else:
    # NOTE(review): a Jupyter cell normally executes with
    # __name__ == "__main__" and therefore takes the branch above. This
    # branch is reached when the file is imported as a module, where it
    # launches the (blocking) GUI at import time and without the error
    # reporting used above - confirm this is intended.
    app = MidiStyleTransferApp()
    app.run()

Model classes imported successfully


  model.load_state_dict(torch.load(self.config.RAGAN_MODEL_PATH, map_location=self.config.DEVICE))
