In [7]:
import os
import numpy as np
import tensorflow as tf
import tarfile
import json
from typing import Optional, Dict, Any, Tuple
import argparse
from pathlib import Path
import seaborn as sns


In [8]:
# MuseGAN Music Generator - Jupyter Notebook
# Run each cell sequentially to generate music using pretrained MuseGAN models

# Cell 1: Install Required Packages
# Run this cell first to install all necessary dependencies
import subprocess
import sys


# Uncomment the line below if you need to install packages
# install_packages()

# Cell 2: Import Libraries and Set Up Environment
import os
import numpy as np
import tensorflow as tf
import tarfile
import json
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Dict, Any, Tuple
from pathlib import Path
from IPython.display import display, Audio, HTML, Markdown
import warnings
warnings.filterwarnings('ignore')

# Disable TensorFlow warnings for cleaner output
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Set up matplotlib for better plotting in notebooks
plt.style.use('default')
sns.set_palette("husl")

print("✅ Libraries imported successfully!")
print(f"TensorFlow version: {tf.__version__}")
print(f"NumPy version: {np.__version__}")

# Cell 3: MuseGAN Generator Class Definition
class MuseGANGenerator:
    """
    Standalone MuseGAN music generator optimized for Jupyter notebooks.
    """
    
    def __init__(self, model_path: str, verbose: bool = True):
        """
        Initialize the MuseGAN generator.
        
        Args:
            model_path: Path to the pretrained model directory or tar file
            verbose: Whether to print detailed information
        """
        self.model_path = model_path
        self.model = None
        self.config = None
        self.sess = None
        self.graph = None
        self.verbose = verbose
        
        # Track names for visualization
        self.track_names = ['Drums', 'Piano', 'Guitar', 'Bass', 'Strings']
        self.track_colors = ['red', 'blue', 'green', 'orange', 'purple']
        
        # Music generation parameters
        self.n_tracks = 5
        self.n_bars = 4
        self.n_steps_per_bar = 24
        self.n_pitches = 128
        self.lowest_pitch = 24
        
        if self.verbose:
            print("🎵 MuseGAN Generator initialized!")
    
    def extract_model(self, tar_path: str) -> str:
        """Extract the tar file containing the pretrained model."""
        extract_dir = tar_path.replace('.tar', '_extracted')
        
        if not os.path.exists(extract_dir):
            if self.verbose:
                print(f"📦 Extracting {tar_path} to {extract_dir}...")
            with tarfile.open(tar_path, 'r') as tar:
                tar.extractall(extract_dir)
            if self.verbose:
                print("✅ Extraction complete!")
        else:
            if self.verbose:
                print(f"📁 Using existing extracted directory: {extract_dir}")
        
        return extract_dir
    
    def load_model(self) -> bool:
        """Load the pretrained MuseGAN model."""
        model_dir = self.model_path
        
        # If it's a tar file, extract it first
        if self.model_path.endswith('.tar'):
            model_dir = self.extract_model(self.model_path)
        
        # Look for model files
        ckpt_files = []
        config_file = None
        
        for root, dirs, files in os.walk(model_dir):
            for file in files:
                if file.endswith('.meta'):
                    ckpt_files.append(os.path.join(root, file.replace('.meta', '')))
                elif file == 'config.json':
                    config_file = os.path.join(root, file)
        
        if not ckpt_files:
            if self.verbose:
                print("❌ No checkpoint files found. Will create fallback generator.")
            return False
        
        # Load configuration if available
        if config_file and os.path.exists(config_file):
            with open(config_file, 'r') as f:
                self.config = json.load(f)
                if self.verbose:
                    print(f"📄 Loaded configuration: {config_file}")
        
        # Use the most recent checkpoint
        ckpt_path = ckpt_files[0]
        if self.verbose:
            print(f"🔄 Loading model from: {ckpt_path}")
        
        try:
            # Create TensorFlow session and load model
            tf.compat.v1.disable_eager_execution()
            self.sess = tf.compat.v1.Session()
            
            # Import the meta graph
            saver = tf.compat.v1.train.import_meta_graph(ckpt_path + '.meta')
            saver.restore(self.sess, ckpt_path)
            
            # Get the default graph
            self.graph = tf.compat.v1.get_default_graph()
            
            if self.verbose:
                print("✅ Model loaded successfully!")
            return True
            
        except Exception as e:
            if self.verbose:
                print(f"❌ Error loading model: {e}")
                print("🔄 Will create fallback generator.")
            return False
    
    def create_generator_network(self, z_dim: int = 32):
        """Create a simple generator network for music generation."""
        # Input noise vector
        z = tf.compat.v1.placeholder(tf.float32, [None, z_dim], name='z')
        
        # Generator architecture
        with tf.compat.v1.variable_scope('generator'):
            # Dense layers
            h1 = tf.compat.v1.layers.dense(z, 1024, activation=tf.nn.relu)
            h2 = tf.compat.v1.layers.dense(h1, 2048, activation=tf.nn.relu)
            
            # Reshape for convolutional layers
            h2_reshaped = tf.reshape(h2, [-1, 1, 1, 2048])
            
            # Transpose convolutions to generate music
            conv1 = tf.compat.v1.layers.conv2d_transpose(
                h2_reshaped, 512, (1, 4), strides=(1, 2), padding='same', activation=tf.nn.relu
            )
            conv2 = tf.compat.v1.layers.conv2d_transpose(
                conv1, 256, (1, 4), strides=(1, 2), padding='same', activation=tf.nn.relu
            )
            conv3 = tf.compat.v1.layers.conv2d_transpose(
                conv2, 128, (1, 4), strides=(1, 3), padding='same', activation=tf.nn.relu
            )
            
            # Final layer to generate music
            output = tf.compat.v1.layers.conv2d_transpose(
                conv3, self.n_tracks * self.n_pitches, (4, 4), 
                strides=(1, 1), padding='same', activation=tf.nn.sigmoid
            )
            
            # Reshape to music tensor format
            music_output = tf.reshape(output, 
                [-1, self.n_bars, self.n_steps_per_bar, self.n_tracks, self.n_pitches])
        
        return z, music_output
    
    def generate_music(self, n_samples: int = 4, temperature: float = 1.0) -> np.ndarray:
        """Generate music samples with progress indication."""
        if self.verbose:
            print(f"🎼 Generating {n_samples} music samples...")
        
        try:
            # Try to find generator output in the loaded graph
            try:
                # Look for common tensor names
                possible_inputs = ['z:0', 'noise:0', 'input:0', 'generator/input:0']
                possible_outputs = ['generator/output:0', 'generated_music:0', 'fake_data:0']
                
                input_tensor = None
                output_tensor = None
                
                for name in possible_inputs:
                    try:
                        input_tensor = self.graph.get_tensor_by_name(name)
                        break
                    except KeyError:
                        continue
                
                for name in possible_outputs:
                    try:
                        output_tensor = self.graph.get_tensor_by_name(name)
                        break
                    except KeyError:
                        continue
                
                if input_tensor is None or output_tensor is None:
                    raise KeyError("Could not find input/output tensors")
                
                # Generate random noise
                z_dim = input_tensor.shape[-1]
                noise = np.random.normal(0, temperature, (n_samples, z_dim))
                
                # Generate music
                generated_music = self.sess.run(output_tensor, feed_dict={input_tensor: noise})
                
            except KeyError:
                if self.verbose:
                    print("🔄 Creating fallback generator...")
                # Fallback: create our own generator
                z_input, music_output = self.create_generator_network()
                
                # Initialize variables
                init = tf.compat.v1.global_variables_initializer()
                self.sess.run(init)
                
                # Generate random noise
                noise = np.random.normal(0, temperature, (n_samples, 32))
                
                # Generate music
                generated_music = self.sess.run(music_output, feed_dict={z_input: noise})
            
            if self.verbose:
                print(f"✅ Generated music shape: {generated_music.shape}")
            return generated_music
            
        except Exception as e:
            if self.verbose:
                print(f"❌ Error generating music: {e}")
                print("🎲 Using random music as fallback...")
            return self.generate_random_music(n_samples)
    
    def generate_random_music(self, n_samples: int) -> np.ndarray:
        """Generate random music as a fallback."""
        return np.random.random((n_samples, self.n_bars, self.n_steps_per_bar, 
                               self.n_tracks, self.n_pitches))
    
    def postprocess_music(self, music: np.ndarray, threshold: float = 0.5) -> np.ndarray:
        """Postprocess generated music."""
        # Apply threshold to create binary piano roll
        binary_music = (music > threshold).astype(np.float32)
        
        # Remove notes below the lowest pitch
        binary_music[:, :, :, :, :self.lowest_pitch] = 0
        
        return binary_music
    
    def music_to_pianoroll(self, music: np.ndarray, track_idx: int = 1) -> np.ndarray:
        """Convert music tensor to piano roll for visualization."""
        n_bars, n_steps_per_bar, n_tracks, n_pitches = music.shape
        total_steps = n_bars * n_steps_per_bar
        
        # Extract the specific track
        track_data = music[:, :, track_idx, :]
        
        # Reshape to [total_steps, n_pitches]
        pianoroll = track_data.reshape(total_steps, n_pitches)
        
        return pianoroll
    
    def plot_pianoroll(self, music: np.ndarray, sample_idx: int = 0, 
                      track_idx: int = 1, figsize: tuple = (15, 8)):
        """Plot piano roll visualization."""
        pianoroll = self.music_to_pianoroll(music[sample_idx], track_idx)
        
        fig, ax = plt.subplots(figsize=figsize)
        
        # Create the plot
        im = ax.imshow(pianoroll.T, aspect='auto', origin='lower', 
                      cmap='Blues', interpolation='nearest')
        
        # Customize the plot
        ax.set_title(f'{self.track_names[track_idx]} - Sample {sample_idx}', 
                    fontsize=16, fontweight='bold')
        ax.set_xlabel('Time Steps', fontsize=12)
        ax.set_ylabel('MIDI Pitch', fontsize=12)
        
        # Add colorbar
        plt.colorbar(im, ax=ax, label='Note Velocity')
        
        # Add grid for bars
        for bar in range(1, self.n_bars):
            ax.axvline(x=bar * self.n_steps_per_bar, color='red', 
                      linestyle='--', alpha=0.7, linewidth=1)
        
        plt.tight_layout()
        plt.show()
    
    def plot_track_comparison(self, music: np.ndarray, sample_idx: int = 0, 
                            figsize: tuple = (20, 12)):
        """Plot comparison of all tracks."""
        fig, axes = plt.subplots(self.n_tracks, 1, figsize=figsize, sharex=True)
        
        for track_idx in range(self.n_tracks):
            pianoroll = self.music_to_pianoroll(music[sample_idx], track_idx)
            
            im = axes[track_idx].imshow(pianoroll.T, aspect='auto', origin='lower', 
                                      cmap='Blues', interpolation='nearest')
            
            axes[track_idx].set_title(f'{self.track_names[track_idx]}', 
                                    fontsize=14, fontweight='bold')
            axes[track_idx].set_ylabel('MIDI Pitch', fontsize=10)
            
            # Add bar lines
            for bar in range(1, self.n_bars):
                axes[track_idx].axvline(x=bar * self.n_steps_per_bar, 
                                       color='red', linestyle='--', alpha=0.5)
        
        axes[-1].set_xlabel('Time Steps', fontsize=12)
        plt.tight_layout()
        plt.show()
    
    def analyze_music_statistics(self, music: np.ndarray) -> dict:
        """Analyze and return statistics about the generated music."""
        stats = {}
        
        for i, track_name in enumerate(self.track_names[:music.shape[3]]):
            track_data = music[0, :, :, i, :]
            active_notes = np.sum(track_data > 0.5)
            total_notes = track_data.size
            density = active_notes / total_notes * 100
            
            # Pitch range analysis
            active_pitches = np.where(np.sum(track_data, axis=(0, 1)) > 0)[0]
            pitch_range = (active_pitches.min(), active_pitches.max()) if len(active_pitches) > 0 else (0, 0)
            
            stats[track_name] = {
                'active_notes': active_notes,
                'total_notes': total_notes,
                'density': density,
                'pitch_range': pitch_range,
                'unique_pitches': len(active_pitches)
            }
        
        return stats
    
    def save_as_midi(self, music: np.ndarray, filename: str, tempo: int = 120):
        """Save generated music as MIDI file."""
        try:
            import pretty_midi
            
            # Create a MIDI object
            midi = pretty_midi.PrettyMIDI(initial_tempo=tempo)
            
            # Process each track
            for track_idx in range(min(self.n_tracks, music.shape[3])):
                # Create instrument
                instrument = pretty_midi.Instrument(
                    program=0 if track_idx == 0 else track_idx * 8,
                    is_drum=(track_idx == 0),
                    name=self.track_names[track_idx] if track_idx < len(self.track_names) else f'Track_{track_idx}'
                )
                
                # Convert to piano roll
                pianoroll = self.music_to_pianoroll(music[0], track_idx)
                
                # Convert piano roll to MIDI notes
                for step, pitches in enumerate(pianoroll):
                    for pitch, velocity in enumerate(pitches):
                        if velocity > 0:
                            # Calculate timing
                            start_time = step * (60.0 / tempo / 6)
                            end_time = start_time + (60.0 / tempo / 6)
                            
                            # Create note
                            note = pretty_midi.Note(
                                velocity=int(velocity * 127),
                                pitch=pitch,
                                start=start_time,
                                end=end_time
                            )
                            instrument.notes.append(note)
                
                midi.instruments.append(instrument)
            
            # Save MIDI file
            midi.write(filename)
            if self.verbose:
                print(f"🎵 MIDI file saved as: {filename}")
            
        except ImportError:
            print("❌ pretty_midi not installed. Cannot save as MIDI.")
            print("Install with: !pip install pretty_midi")
    
    def save_as_numpy(self, music: np.ndarray, filename: str):
        """Save generated music as numpy array."""
        np.save(filename, music)
        if self.verbose:
            print(f"💾 Music saved as numpy array: {filename}")

✅ Libraries imported successfully!
TensorFlow version: 2.15.0
NumPy version: 1.26.4


In [11]:
# Update this path to point to your pretrained model
MODEL_PATH = "DL_music_MouseGAN/pretrained_models.tar"  # Change this to your model path

# Generation parameters
N_SAMPLES = 4
TEMPERATURE = 1.0
THRESHOLD = 0.5
OUTPUT_DIR = "./generated_music"

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize generator
print("🚀 Initializing MuseGAN generator...")
generator = MuseGANGenerator(MODEL_PATH, verbose=True)

🚀 Initializing MuseGAN generator...
🎵 MuseGAN Generator initialized!


In [None]:
print(os.getcwd())

In [12]:
print("🔄 Loading pretrained model...")
model_loaded = generator.load_model()

if model_loaded:
    print("✅ Model loaded successfully!")
else:
    print("⚠️  Using fallback generator (model not found or failed to load)")

🔄 Loading pretrained model...
📦 Extracting DL_music_MouseGAN/pretrained_models.tar to DL_music_MouseGAN/pretrained_models_extracted...


FileNotFoundError: [Errno 2] No such file or directory: 'DL_music_MouseGAN/pretrained_models.tar'

In [None]:
print(f"🎼 Generating {N_SAMPLES} music samples...")
print(f"Parameters: Temperature={TEMPERATURE}, Threshold={THRESHOLD}")

# Generate music
generated_music = generator.generate_music(
    n_samples=N_SAMPLES,
    temperature=TEMPERATURE
)

# Postprocess music
processed_music = generator.postprocess_music(generated_music, THRESHOLD)

print("✅ Music generation complete!")
print(f"Generated music shape: {processed_music.shape}")

In [None]:
print("📊 Analyzing generated music...")

# Get statistics
stats = generator.analyze_music_statistics(processed_music)

# Display statistics in a nice format
print("\n" + "="*60)
print("🎵 GENERATED MUSIC STATISTICS")
print("="*60)

for track_name, track_stats in stats.items():
    print(f"\n🎹 {track_name.upper()}:")
    print(f"   Active Notes: {track_stats['active_notes']:,}")
    print(f"   Density: {track_stats['density']:.1f}%")
    print(f"   Pitch Range: {track_stats['pitch_range'][0]} - {track_stats['pitch_range'][1]}")
    print(f"   Unique Pitches: {track_stats['unique_pitches']}")

In [None]:
print("🎨 Creating visualizations...")

# Plot individual track (Piano by default)
print("\n📊 Piano Roll Visualization (Piano Track):")
generator.plot_pianoroll(processed_music, sample_idx=0, track_idx=1)

In [None]:
print("📊 All Tracks Comparison:")
generator.plot_track_comparison(processed_music, sample_idx=0)

In [None]:
print("💾 Saving generated music...")

# Save each sample
for i in range(N_SAMPLES):
    sample_music = processed_music[i:i+1]
    
    # Save as numpy
    numpy_path = os.path.join(OUTPUT_DIR, f'generated_music_sample_{i}.npy')
    generator.save_as_numpy(sample_music, numpy_path)
    
    # Save as MIDI
    midi_path = os.path.join(OUTPUT_DIR, f'generated_music_sample_{i}.mid')
    generator.save_as_midi(sample_music, midi_path)

print(f"\n✅ All files saved to: {OUTPUT_DIR}")
