In [9]:
import pennylane as qml
import numpy as np
import torch
import random
from torch.optim import Adam
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import time
import matplotlib.pyplot as plt
import psutil
import gc
from functools import partial
from copy import deepcopy

# Load the existing code from paste.txt
def set_seeds(seed=42):
    """Set seeds for reproducibility"""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

class QuantumAutoencoder:
    def __init__(self, n_qubits, latent_qubits, depth=4, params=None):
        self.n_qubits = n_qubits
        self.latent_qubits = latent_qubits
        self.depth = depth
        self.dev = qml.device("default.qubit", wires=n_qubits)
        
        self.n_params = self._calculate_params()
        if params is not None:
            self.params = params
        else:
            self.params = self._initialize_parameters()
        
        self.encoder = qml.QNode(self._encoder_circuit, self.dev, interface="torch")
        self.decoder = qml.QNode(self._decoder_circuit, self.dev, interface="torch")
        
        # Add importance scores for pruning
        self.importance_scores = np.ones_like(self.params)
    
    def _calculate_params(self):
        """Calculate total number of parameters"""
        params_per_qubit = 6
        params_per_layer = self.n_qubits * params_per_qubit
        total_layers = 2 * self.depth
        return params_per_layer * total_layers
    
    def _initialize_parameters(self):
        """Improved parameter initialization"""
        params = np.zeros(self.n_params)
        # Xavier/Glorot initialization scaled to [0, 2π]
        scale = np.sqrt(2.0 / (self.n_qubits + self.latent_qubits)) * np.pi
        for i in range(self.n_params):
            params[i] = np.random.uniform(-scale, scale)
        return params
    
    def _encoder_circuit(self, data, params):
        """Enhanced encoder circuit"""
        qml.AmplitudeEmbedding(data, wires=range(self.n_qubits), normalize=True)
        
        param_idx = 0
        for d in range(self.depth):
            # More rotations per qubit
            for i in range(self.n_qubits):
                qml.Rot(params[param_idx], params[param_idx + 1], 
                       params[param_idx + 2], wires=i)
                qml.RX(params[param_idx + 3], wires=i)
                qml.RY(params[param_idx + 4], wires=i)
                qml.RZ(params[param_idx + 5], wires=i)
                param_idx += 6
            
            # Enhanced entanglement pattern
            for i in range(self.n_qubits - 1):
                qml.CRZ(params[param_idx % self.n_params], wires=[i, i + 1])
                qml.CNOT(wires=[i, i + 1])
            
            if self.n_qubits > 2:
                for i in range(0, self.n_qubits - 2, 2):
                    qml.CRX(params[(param_idx + 1) % self.n_params], wires=[i, i + 2])
                    qml.CNOT(wires=[i, i + 2])
        
        return qml.state()
    
    def _decoder_circuit(self, latent_state, params):
        """Enhanced decoder circuit"""
        qml.QubitStateVector(latent_state, wires=range(self.latent_qubits))
        
        # Better initialization of non-latent qubits
        for i in range(self.latent_qubits, self.n_qubits):
            qml.Hadamard(wires=i)
            qml.RY(np.pi/4, wires=i)
            qml.RZ(np.pi/4, wires=i)
        
        param_idx = self.n_params // 2
        for d in range(self.depth):
            for i in range(self.n_qubits):
                qml.Rot(params[param_idx], params[param_idx + 1], 
                       params[param_idx + 2], wires=i)
                qml.RX(params[param_idx + 3], wires=i)
                qml.RY(params[param_idx + 4], wires=i)
                qml.RZ(params[param_idx + 5], wires=i)
                param_idx += 6
            
            for i in range(self.n_qubits - 1):
                qml.CRZ(params[param_idx % self.n_params], wires=[i, i + 1])
                qml.CNOT(wires=[i, i + 1])
            
            if self.n_qubits > 2:
                for i in range(0, self.n_qubits - 2, 2):
                    qml.CRX(params[(param_idx + 1) % self.n_params], wires=[i, i + 2])
                    qml.CNOT(wires=[i, i + 2])
        
        return qml.state()
    
    def get_latent_state(self, encoded_state):
        """Improved latent state extraction"""
        if torch.is_tensor(encoded_state):
            encoded_state = encoded_state.detach().numpy()
        
        state_matrix = np.outer(encoded_state, np.conjugate(encoded_state))
        dim = 2**self.latent_qubits
        reduced_matrix = state_matrix[:dim, :dim]
        
        eigenvals, eigenvecs = np.linalg.eigh(reduced_matrix)
        # Use top 2 eigenvectors with proper weighting
        top_k = 2
        top_indices = np.argsort(eigenvals)[-top_k:]
        weights = eigenvals[top_indices] / np.sum(eigenvals[top_indices])
        latent_state = np.sum([w * eigenvecs[:, i] for w, i in zip(weights, top_indices)], axis=0)
        return latent_state / np.linalg.norm(latent_state)
    
    def forward(self, x):
        """Forward pass"""
        encoded = self.encoder(x, self.params)
        latent = self.get_latent_state(encoded)
        decoded = self.decoder(latent, self.params)
        return decoded

def preprocess_data(X):
    """Enhanced data preprocessing"""
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    # Additional normalization step
    X_scaled = X_scaled / np.max(np.abs(X_scaled))
    return X_scaled / np.sqrt(np.sum(X_scaled**2, axis=1))[:, np.newaxis]

# Function to measure memory usage
def get_memory_usage():
    """Get current memory usage in MB"""
    process = psutil.Process()
    memory_info = process.memory_info()
    return memory_info.rss / (1024 * 1024)  # Convert bytes to MB

# Rényi Entropy-based Pruning
def calculate_renyi_entropy(param_values, alpha=2.0):
    """
    Calculate the Rényi entropy of parameter values
    
    Args:
        param_values: Parameter values
        alpha: Rényi entropy parameter (alpha=1 corresponds to Shannon entropy,
               alpha=2 corresponds to collision entropy)
    
    Returns:
        Rényi entropy value
    """
    # Normalize the values to form a probability distribution
    abs_values = np.abs(param_values)
    prob_dist = abs_values / np.sum(abs_values)
    
    # Handle zero probabilities by adding a small epsilon
    epsilon = 1e-10
    prob_dist = prob_dist + epsilon
    prob_dist = prob_dist / np.sum(prob_dist)
    
    if alpha == 1:
        # Shannon entropy for alpha=1
        entropy = -np.sum(prob_dist * np.log(prob_dist))
    else:
        # Rényi entropy for alpha≠1
        entropy = 1 / (1 - alpha) * np.log(np.sum(prob_dist ** alpha))
    
    return entropy

def calculate_importance_scores(model, X_sample, alpha=2.0):
    """
    Calculate importance scores for each parameter based on Rényi entropy
    
    Args:
        model: QuantumAutoencoder model
        X_sample: Sample data for evaluation
        alpha: Rényi entropy parameter
    
    Returns:
        Array of importance scores
    """
    n_params = len(model.params)
    importance_scores = np.zeros(n_params)
    
    # Create parameter masks for each parameter
    for i in range(n_params):
        # Create a copy of the parameters
        temp_params = np.copy(model.params)
        
        # Create parameter perturbations
        perturbations = np.linspace(-0.1, 0.1, 5) * np.pi
        entropy_values = []
        
        for perturbation in perturbations:
            temp_params[i] = model.params[i] + perturbation
            
            # Collect outputs for the sample data
            outputs = []
            for x in X_sample:
                model.params = temp_params
                decoded = model.forward(x)
                if torch.is_tensor(decoded):
                    decoded = decoded.detach().numpy()
                outputs.append(decoded)
            
            # Calculate entropy of the output distribution
            outputs = np.array(outputs).flatten()
            entropy = calculate_renyi_entropy(outputs, alpha=alpha)
            entropy_values.append(entropy)
        
        # Higher entropy variation means the parameter is more important
        importance = np.std(entropy_values)
        importance_scores[i] = importance
    
    # Restore original parameters
    model.params = np.copy(model.params)
    
    # Normalize importance scores
    importance_scores = importance_scores / np.sum(importance_scores)
    return importance_scores

def prune_model(model, pruning_ratio=0.5, X_sample=None, alpha=2.0):
    """
    Prune the quantum model based on Rényi entropy importance scores
    
    Args:
        model: QuantumAutoencoder model
        pruning_ratio: Ratio of parameters to prune (0.0 to 1.0)
        X_sample: Sample data for calculating importance scores
        alpha: Rényi entropy parameter
        
    Returns:
        Pruned model
    """
    if X_sample is not None:
        # Calculate importance scores
        importance_scores = calculate_importance_scores(model, X_sample, alpha)
    else:
        importance_scores = model.importance_scores
    
    # Create a pruned model
    pruned_model = deepcopy(model)
    pruned_model.importance_scores = importance_scores
    
    # Sort parameters by importance scores
    indices = np.argsort(importance_scores)
    n_prune = int(len(indices) * pruning_ratio)
    
    # Zero out the least important parameters
    prune_indices = indices[:n_prune]
    pruned_model.params[prune_indices] = 0.0
    
    return pruned_model

# Quantization
def quantize_model(model, bits=8):
    """
    Quantize model parameters to reduce precision
    
    Args:
        model: QuantumAutoencoder model
        bits: Number of bits for quantization (1-32)
        
    Returns:
        Quantized model
    """
    quantized_model = deepcopy(model)
    
    # Determine the range of parameters
    param_min = np.min(model.params)
    param_max = np.max(model.params)
    param_range = param_max - param_min
    
    # Calculate the quantization step
    levels = 2**bits - 1
    step = param_range / levels
    
    # Quantize the parameters
    quantized_params = np.round((model.params - param_min) / step) * step + param_min
    quantized_model.params = quantized_params
    
    # Store quantization info in the model for potential dequantization
    quantized_model.quant_info = {
        'bits': bits,
        'param_min': param_min,
        'param_max': param_max,
        'levels': levels
    }
    
    return quantized_model

def evaluate_model(model, X_test, metrics_only=False):
    """
    Evaluate the model on test data
    
    Args:
        model: QuantumAutoencoder model
        X_test: Test data
        metrics_only: If True, only return metrics, not predictions
        
    Returns:
        Dictionary of evaluation metrics (and predictions if metrics_only=False)
    """
    start_time = time.time()
    losses = []
    predictions = []
    
    memory_before = get_memory_usage()
    for x in X_test:
        decoded = model.forward(x)
        if torch.is_tensor(decoded):
            decoded = decoded.detach().numpy()
        
        # Calculate reconstruction loss
        loss = np.mean((np.real(decoded) - x)**2)
        losses.append(loss)
        
        if not metrics_only:
            predictions.append(decoded)
    
    memory_after = get_memory_usage()
    inference_time = time.time() - start_time
    test_loss = np.mean(losses)
    test_accuracy = 1 / (1 + test_loss)
    
    non_zero_params = np.count_nonzero(model.params)
    model_sparsity = 1.0 - (non_zero_params / len(model.params))
    
    metrics = {
        'loss': test_loss,
        'accuracy': test_accuracy,
        'inference_time': inference_time,
        'inference_memory': memory_after - memory_before,
        'params_total': len(model.params),
        'params_nonzero': non_zero_params,
        'sparsity': model_sparsity
    }
    
    if not metrics_only:
        metrics['predictions'] = predictions
    
    return metrics

def train_and_evaluate_models(n_epochs=100, batch_size=4, learning_rate=0.002, 
                              pruning_ratio=0.5, quant_bits=8, alpha=2.0, seed=42):
    """
    Train the original model, prune it, quantize it, and evaluate all three
    
    Args:
        n_epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Learning rate for optimizer
        pruning_ratio: Ratio of parameters to prune
        quant_bits: Number of bits for quantization
        seed: Random seed for reproducibility
        
    Returns:
        Dictionary of models and results
    """
    set_seeds(seed)
    
    # Generate synthetic data
    n_features = 16
    X, y = make_classification(
        n_samples=200,
        n_features=n_features,
        n_classes=2,
        n_informative=6,
        n_redundant=0,
        n_clusters_per_class=2,
        class_sep=2.5,
        random_state=seed
    )
    
    X_train, X_test = train_test_split(X, test_size=0.2, random_state=seed)
    X_train = preprocess_data(X_train)
    X_test = preprocess_data(X_test)
    
    # Initialize the original model
    n_qubits = int(np.log2(X_train.shape[1]))
    latent_qubits = n_qubits - 1
    original_model = QuantumAutoencoder(n_qubits=n_qubits, latent_qubits=latent_qubits)
    
    # Train the original model
    print("\nTraining the original model...")
    train_start_time = time.time()
    memory_before_training = get_memory_usage()
    
    params = torch.tensor(original_model.params, requires_grad=True)
    optimizer = Adam([params], lr=learning_rate)
    
    best_accuracy = 0.0
    best_params = None
    patience = 10
    patience_counter = 0
    
    metrics = {'train_losses': [], 'train_accuracies': [], 
              'val_losses': [], 'val_accuracies': []}
    
    for epoch in range(n_epochs):
        original_model.params = params.detach().numpy()
        epoch_loss = 0
        n_batches = max(1, len(X_train) // batch_size)
        
        for _ in range(n_batches):
            optimizer.zero_grad()
            batch_idx = np.random.choice(len(X_train), min(batch_size, len(X_train)))
            batch_data = X_train[batch_idx]
            
            total_loss = torch.tensor(0.0, requires_grad=True)
            for x in batch_data:
                decoded = original_model.forward(x)
                if torch.is_tensor(decoded):
                    decoded = decoded.real
                decoded = torch.tensor(np.real(decoded), dtype=torch.float64, requires_grad=True)
                x_tensor = torch.tensor(x, dtype=torch.float32)
                
                # Enhanced loss calculation
                reconstruction_loss = torch.mean((decoded - x_tensor)**2)
                l2_reg = 0.0001 * torch.sum(params**2)
                loss = reconstruction_loss + l2_reg
                
                total_loss = total_loss + loss
            
            avg_loss = total_loss / len(batch_data)
            avg_loss.backward()
            torch.nn.utils.clip_grad_norm_([params], max_norm=1.0)
            optimizer.step()
            epoch_loss += avg_loss.item()
        
        train_loss = epoch_loss / n_batches
        train_accuracy = 1 / (1 + train_loss)
        
        # Validation
        val_losses = []
        with torch.no_grad():
            for x in X_test:
                decoded = original_model.forward(x)
                decoded = np.real(decoded) if not torch.is_tensor(decoded) else decoded.real.numpy()
                val_loss = np.mean((decoded - x)**2)
                val_losses.append(val_loss)
        
        val_loss = np.mean(val_losses)
        val_accuracy = 1 / (1 + val_loss)
        
        metrics['train_losses'].append(train_loss)
        metrics['train_accuracies'].append(train_accuracy)
        metrics['val_losses'].append(val_loss)
        metrics['val_accuracies'].append(val_accuracy)
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs} | Train Loss: {train_loss:.4f} | Val Acc: {val_accuracy:.4f}")
        
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_params = params.detach().clone()
            patience_counter = 0
        else:
            patience_counter += 1
        
        if patience_counter >= patience:
            print("Early stopping triggered!")
            break
        
        if val_accuracy >= 0.95 and train_accuracy >= 0.95:
            print("Target accuracy achieved!")
            break
    
    original_model.params = best_params.numpy()
    train_time = time.time() - train_start_time
    memory_after_training = get_memory_usage()
    training_memory = memory_after_training - memory_before_training
    
    # Evaluate the original model
    print("\nEvaluating the original model...")
    original_results = evaluate_model(original_model, X_test)
    
    # Calculate importance scores for the original model
    print("\nCalculating importance scores for pruning...")
    X_sample = X_train[:min(10, len(X_train))]  # Use a small sample for calculating importance
    importance_scores = calculate_importance_scores(original_model, X_sample, alpha=alpha)
    original_model.importance_scores = importance_scores
    
    # Create and evaluate the pruned model
    print(f"\nPruning the model (pruning ratio: {pruning_ratio}, alpha: {alpha})...")
    pruned_model = prune_model(original_model, pruning_ratio=pruning_ratio, alpha=alpha)
    
    # Fine-tune the pruned model
    print("\nFine-tuning the pruned model...")
    pruned_train_start_time = time.time()
    pruned_memory_before = get_memory_usage()
    
    params = torch.tensor(pruned_model.params, requires_grad=True)
    optimizer = Adam([params], lr=learning_rate / 2)  # Lower learning rate for fine-tuning
    
    for epoch in range(n_epochs // 3):  # Fewer epochs for fine-tuning
        pruned_model.params = params.detach().numpy()
        epoch_loss = 0
        n_batches = max(1, len(X_train) // batch_size)
        
        for _ in range(n_batches):
            optimizer.zero_grad()
            batch_idx = np.random.choice(len(X_train), min(batch_size, len(X_train)))
            batch_data = X_train[batch_idx]
            
            total_loss = torch.tensor(0.0, requires_grad=True)
            for x in batch_data:
                decoded = pruned_model.forward(x)
                if torch.is_tensor(decoded):
                    decoded = decoded.real
                decoded = torch.tensor(np.real(decoded), dtype=torch.float64, requires_grad=True)
                x_tensor = torch.tensor(x, dtype=torch.float32)
                
                reconstruction_loss = torch.mean((decoded - x_tensor)**2)
                # No regularization for pruned model fine-tuning
                loss = reconstruction_loss
                
                total_loss = total_loss + loss
            
            avg_loss = total_loss / len(batch_data)
            avg_loss.backward()
            optimizer.step()
            
            # Re-zero the pruned parameters to maintain sparsity
            with torch.no_grad():
                prune_mask = (pruned_model.params == 0.0)
                params.data[prune_mask] = 0.0
            
            epoch_loss += avg_loss.item()
        
        train_loss = epoch_loss / n_batches
        
        if (epoch + 1) % 5 == 0:
            print(f"Fine-tune Epoch {epoch+1} | Train Loss: {train_loss:.4f}")
    
    pruned_model.params = params.detach().numpy()
    pruned_train_time = time.time() - pruned_train_start_time
    pruned_memory_after = get_memory_usage()
    pruned_training_memory = pruned_memory_after - pruned_memory_before
    
    print("\nEvaluating the pruned model...")
    pruned_results = evaluate_model(pruned_model, X_test)
    pruned_results['training_time'] = pruned_train_time
    pruned_results['training_memory'] = pruned_training_memory
    
    # Create and evaluate the quantized model
    print(f"\nQuantizing the pruned model (bits: {quant_bits})...")
    quantized_model = quantize_model(pruned_model, bits=quant_bits)
    
    # Fine-tune the quantized model
    print("\nFine-tuning the quantized model...")
    quant_train_start_time = time.time()
    quant_memory_before = get_memory_usage()
    
    params = torch.tensor(quantized_model.params, requires_grad=True)
    optimizer = Adam([params], lr=learning_rate / 3)  # Even lower learning rate for quantized model
    
    for epoch in range(n_epochs // 4):  # Even fewer epochs for fine-tuning
        quantized_model.params = params.detach().numpy()
        epoch_loss = 0
        n_batches = max(1, len(X_train) // batch_size)
        
        for _ in range(n_batches):
            optimizer.zero_grad()
            batch_idx = np.random.choice(len(X_train), min(batch_size, len(X_train)))
            batch_data = X_train[batch_idx]
            
            total_loss = torch.tensor(0.0, requires_grad=True)
            for x in batch_data:
                decoded = quantized_model.forward(x)
                if torch.is_tensor(decoded):
                    decoded = decoded.real
                decoded = torch.tensor(np.real(decoded), dtype=torch.float64, requires_grad=True)
                x_tensor = torch.tensor(x, dtype=torch.float32)
                
                reconstruction_loss = torch.mean((decoded - x_tensor)**2)
                loss = reconstruction_loss
                
                total_loss = total_loss + loss
            
            avg_loss = total_loss / len(batch_data)
            avg_loss.backward()
            optimizer.step()
            
            # Re-quantize parameters after each update
            with torch.no_grad():
                # Keep pruned parameters at zero
                prune_mask = (pruned_model.params == 0.0)
                params.data[prune_mask] = 0.0
                
                # Re-quantize non-zero parameters
                param_min = quantized_model.quant_info['param_min']
                param_max = quantized_model.quant_info['param_max']
                levels = quantized_model.quant_info['levels']
                step = (param_max - param_min) / levels
                
                # Clip to the range
                params.data = torch.clamp(params.data, param_min, param_max)
                
                # Quantize
                params.data = torch.round((params.data - param_min) / step) * step + param_min
            
            epoch_loss += avg_loss.item()
        
        train_loss = epoch_loss / n_batches
        
        if (epoch + 1) % 5 == 0:
            print(f"Quant Fine-tune Epoch {epoch+1} | Train Loss: {train_loss:.4f}")
    
    quantized_model.params = params.detach().numpy()
    quant_train_time = time.time() - quant_train_start_time
    quant_memory_after = get_memory_usage()
    quant_training_memory = quant_memory_after - quant_memory_before
    
    print("\nEvaluating the quantized model...")
    quantized_results = evaluate_model(quantized_model, X_test)
    quantized_results['training_time'] = quant_train_time
    quantized_results['training_memory'] = quant_training_memory
    
    # Add training time and memory to original results
    original_results['training_time'] = train_time
    original_results['training_memory'] = training_memory
    
    # Compare models
    print("\n=== Model Comparison ===")
    print(f"{'Model':<10} | {'Accuracy':<10} | {'Train Time':<12} | {'Infer Time':<12} | {'Train Mem (MB)':<15} | {'Infer Mem (MB)':<15} | {'Parameters':<12} | {'Sparsity':<10}")
    print("-" * 110)
    
    print(f"{'Original':<10} | {original_results['accuracy']:<10.4f} | {original_results['training_time']:<12.4f}s | {original_results['inference_time']:<12.4f}s | {original_results['training_memory']:<15.2f} | {original_results['inference_memory']:<15.2f} | {original_results['params_nonzero']}/{original_results['params_total']} | {original_results['sparsity']:<10.2%}")
    
    print(f"{'Pruned':<10} | {pruned_results['accuracy']:<10.4f} | {pruned_results['training_time']:<12.4f}s | {pruned_results['inference_time']:<12.4f}s | {pruned_results['training_memory']:<15.2f} | {pruned_results['inference_memory']:<15.2f} | {pruned_results['params_nonzero']}/{pruned_results['params_total']} | {pruned_results['sparsity']:<10.2%}")
    
    print(f"{'Quantized':<10} | {quantized_results['accuracy']:<10.4f} | {quantized_results['training_time']:<12.4f}s | {quantized_results['inference_time']:<12.4f}s | {quantized_results['training_memory']:<15.2f} | {quantized_results['inference_memory']:<15.2f} | {quantized_results['params_nonzero']}/{quantized_results['params_total']} | {quantized_results['sparsity']:<10.2%}")
    
    # Calculate memory and time savings
    mem_saving_pruned = (original_results['inference_memory'] - pruned_results['inference_memory']) / original_results['inference_memory'] * 100
    mem_saving_quantized = (original_results['inference_memory'] - quantized_results['inference_memory']) / original_results['inference_memory'] * 100
    
    time_saving_pruned = (original_results['inference_time'] - pruned_results['inference_time']) / original_results['inference_time'] * 100
    time_saving_quantized = (original_results['inference_time'] - quantized_results['inference_time']) / original_results['inference_time'] * 100
    
    accuracy_change_pruned = (pruned_results['accuracy'] - original_results['accuracy']) / original_results['accuracy'] * 100
    accuracy_change_quantized = (quantized_results['accuracy'] - original_results['accuracy']) / original_results['accuracy'] * 100
    
    print("\n=== Savings Summary ===")
    print(f"Pruned Model Memory Savings: {mem_saving_pruned:.2f}%")
    print(f"Pruned Model Time Savings: {time_saving_pruned:.2f}%")
    print(f"Pruned Model Accuracy Change: {accuracy_change_pruned:.2f}%")
    print(f"Quantized Model Memory Savings: {mem_saving_quantized:.2f}%")
    print(f"Quantized Model Time Savings: {time_saving_quantized:.2f}%")
    print(f"Quantized Model Accuracy Change: {accuracy_change_quantized:.2f}%")
    
    # Plot metrics for visualization
    plt.figure(figsize=(15, 10))
    
    # Plot accuracy comparison
    plt.subplot(2, 2, 1)
    models = ['Original', 'Pruned', 'Quantized']
    accuracies = [original_results['accuracy'], pruned_results['accuracy'], quantized_results['accuracy']]
    plt.bar(models, accuracies, color=['blue', 'green', 'orange'])
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy Comparison')
    plt.ylim(0.5, 1.0)  # Set y-axis limits for better visualization
    
    # Plot inference time comparison
    plt.subplot(2, 2, 2)
    inference_times = [original_results['inference_time'], pruned_results['inference_time'], quantized_results['inference_time']]
    plt.bar(models, inference_times, color=['blue', 'green', 'orange'])
    plt.ylabel('Inference Time (s)')
    plt.title('Inference Time Comparison')
    
    # Plot inference memory comparison
    plt.subplot(2, 2, 3)
    inference_memory = [original_results['inference_memory'], pruned_results['inference_memory'], quantized_results['inference_memory']]
    plt.bar(models, inference_memory, color=['blue', 'green', 'orange'])
    plt.ylabel('Inference Memory (MB)')
    plt.title('Inference Memory Comparison')
    
    # Plot parameter counts
    plt.subplot(2, 2, 4)
    params_nonzero = [original_results['params_nonzero'], pruned_results['params_nonzero'], quantized_results['params_nonzero']]
    plt.bar(models, params_nonzero, color=['blue', 'green', 'orange'])
    plt.ylabel('Non-zero Parameters')
    plt.title('Model Parameters Comparison')
    
    plt.tight_layout()
    plt.savefig('model_comparison.png')
    plt.close()
    
    # Return all models and their results
    return {
        'original': {
            'model': original_model,
            'results': original_results
        },
        'pruned': {
            'model': pruned_model,
            'results': pruned_results
        },
        'quantized': {
            'model': quantized_model,
            'results': quantized_results
        }
    }

In [11]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def visualize_renyi_entropy_heatmap(model, X_sample, alpha_range=[0.5, 1.0, 2.0, 3.0, 4.0]):
    """
    Create a heatmap showing how Rényi entropy changes across different alpha values
    
    Args:
        model: QuantumAutoencoder model
        X_sample: Sample data for evaluation
        alpha_range: Range of alpha values to test
    """
    n_params = len(model.params)
    n_to_show = min(100, n_params)  # Show at most 100 parameters for clarity
    
    # Select parameter indices
    if n_params > n_to_show:
        indices = np.linspace(0, n_params-1, n_to_show, dtype=int)
    else:
        indices = np.arange(n_params)
    
    # Create a matrix to store entropy values
    entropy_matrix = np.zeros((len(alpha_range), len(indices)))
    
    # Calculate entropy for each alpha and parameter
    for i, alpha in enumerate(alpha_range):
        for j, param_idx in enumerate(indices):
            # Create a copy of the parameters
            temp_params = np.copy(model.params)
            
            # Zero out the parameter
            temp_params[param_idx] = 0.0
            
            # Collect outputs with zero parameter
            outputs = []
            model.params = temp_params
            for x in X_sample:
                decoded = model.forward(x)
                if hasattr(decoded, 'detach'):
                    decoded = decoded.detach().numpy()
                outputs.append(decoded)
            
            # Calculate entropy of the output distribution
            outputs = np.array(outputs).flatten()
            
            # Normalize the values to form a probability distribution
            abs_values = np.abs(outputs)
            prob_dist = abs_values / np.sum(abs_values)
            
            # Handle zero probabilities by adding a small epsilon
            epsilon = 1e-10
            prob_dist = prob_dist + epsilon
            prob_dist = prob_dist / np.sum(prob_dist)
            
            if alpha == 1:
                # Shannon entropy for alpha=1
                entropy = -np.sum(prob_dist * np.log(prob_dist))
            else:
                # Rényi entropy for alpha≠1
                entropy = 1 / (1 - alpha) * np.log(np.sum(prob_dist ** alpha))
            
            entropy_matrix[i, j] = entropy
    
    # Restore original parameters
    model.params = np.copy(model.params)
    
    # Create the heatmap
    plt.figure(figsize=(12, 6))
    
    # Create a custom colormap from blue to red
    colors = [(0, 0, 1), (1, 1, 1), (1, 0, 0)]  # Blue -> White -> Red
    cmap = LinearSegmentedColormap.from_list("entropy_cmap", colors, N=100)
    
    # Normalize the entropy values to [-1, 1] for better visualization
    entropy_norm = 2 * (entropy_matrix - np.min(entropy_matrix)) / (np.max(entropy_matrix) - np.min(entropy_matrix)) - 1
    
    # Plot the heatmap
    plt.imshow(entropy_norm, cmap=cmap, aspect='auto')
    plt.colorbar(label='Normalized Entropy')
    plt.xlabel('Parameter Index')
    plt.ylabel('Alpha Value')
    plt.title('Rényi Entropy Across Alpha Values and Parameters')
    
    # Set y-ticks to alpha values
    plt.yticks(np.arange(len(alpha_range)), [f"{alpha:.1f}" for alpha in alpha_range])
    
    # Set x-ticks to parameter indices
    plt.xticks(np.arange(0, len(indices), max(1, len(indices)//10)), 
              [f"{idx}" for idx in indices[::max(1, len(indices)//10)]])
    
    plt.tight_layout()
    plt.savefig('renyi_entropy_heatmap.png')
    plt.close()

def plot_circuit_diagram(model, filename='quantum_circuit.png'):
    """
    Plot the quantum circuit diagram of the model
    
    Args:
        model: QuantumAutoencoder model
        filename: Output filename for the circuit diagram
    """
    import pennylane as qml
    
    # Create a new figure
    plt.figure(figsize=(12, 8))
    
    # Get a sample input
    x_sample = np.random.random(2**model.n_qubits)
    x_sample = x_sample / np.linalg.norm(x_sample)
    
    # Create a circuit diagram for the encoder
    encoder_fig, encoder_ax = qml.draw_mpl(model.encoder, expansion_strategy="device")(x_sample, model.params)
    encoder_ax.set_title("Encoder Circuit")
    
    # Save the encoder circuit diagram
    encoder_fig.tight_layout()
    encoder_fig.savefig('encoder_circuit.png')
    
    # Create a latent state
    encoded_state = model.encoder(x_sample, model.params)
    latent_state = model.get_latent_state(encoded_state)
    
    # Create a circuit diagram for the decoder
    decoder_fig, decoder_ax = qml.draw_mpl(model.decoder, expansion_strategy="device")(latent_state, model.params)
    decoder_ax.set_title("Decoder Circuit")
    
    # Save the decoder circuit diagram
    decoder_fig.tight_layout()
    decoder_fig.savefig('decoder_circuit.png')
    
    # Combined image for both encoder and decoder
    plt.figure(figsize=(15, 10))
    
    plt.subplot(2, 1, 1)
    plt.imshow(plt.imread('encoder_circuit.png'))
    plt.axis('off')
    
    plt.subplot(2, 1, 2)
    plt.imshow(plt.imread('decoder_circuit.png'))
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()
    
    return filename

def visualize_model_size_comparison(original_results, pruned_results, quantized_results):
    """
    Visualize the model size comparison between original, pruned, and quantized models
    
    Args:
        original_results: Results dictionary for the original model
        pruned_results: Results dictionary for the pruned model
        quantized_results: Results dictionary for the quantized model
    """
    # Calculate model sizes
    param_size_original = original_results['params_total'] * 8  # 8 bytes for float64
    param_size_pruned = pruned_results['params_nonzero'] * 8  # 8 bytes for float64
    param_size_quantized = quantize

In [13]:
import pennylane as qml
import numpy as np
import torch
import random
from torch.optim import Adam
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import time
import matplotlib.pyplot as plt
import psutil
import gc
from functools import partial
from copy import deepcopy
import os
import sys

# Import the quantum model optimization code
from quantum_model_optimization import *

# Import visualization helpers
from visualization_helper import *

def main():
    """
    Main function to run the complete model optimization pipeline
    """
    # Set random seed for reproducibility
    SEED = 42
    set_seeds(SEED)
    
    # Create a directory for results if it doesn't exist
    results_dir = "optimization_results"
    os.makedirs(results_dir, exist_ok=True)
    os.chdir(results_dir)
    
    # Configuration parameters
    config = {
        'n_epochs': 100,
        'batch_size': 4,
        'learning_rate': 0.002,
        'pruning_ratio': 0.5,
        'quant_bits': 8,
        'alpha': 2.0,
        'seed': SEED,
    }
    
    print("=== Quantum Model Optimization Pipeline ===")
    print("\nConfiguration:")
    for key, value in config.items():
        print(f"- {key}: {value}")
    
    # Train and evaluate models
    print("\nRunning main optimization pipeline...")
    results = train_and_evaluate_models(
        n_epochs=config['n_epochs'],
        batch_size=config['batch_size'],
        learning_rate=config['learning_rate'],
        pruning_ratio=config['pruning_ratio'],
        quant_bits=config['quant_bits'],
        alpha=config['alpha'],
        seed=config['seed']
    )
    
    # Extract models
    original_model = results['original']['model']
    pruned_model = results['pruned']['model']
    quantized_model = results['quantized']['model']
    
    # Extract results
    original_results = results['original']['results']
    pruned_results = results['pruned']['results']
    quantized_results = results['quantized']['results']
    
    # Generate data for additional analysis
    n_features = 16
    X, y = make_classification(
        n_samples=200,
        n_features=n_features,
        n_classes=2,
        n_informative=6,
        n_redundant=0,
        n_clusters_per_class=2,
        class_sep=2.5,
        random_state=SEED
    )
    
    X_train, X_test = train_test_split(X, test_size=0.2, random_state=SEED)
    X_train = preprocess_data(X_train)
    X_test = preprocess_data(X_test)
    
    # Additional visualizations
    print("\nGenerating parameter distribution visualization...")
    visualize_parameter_distribution(original_model, pruned_model, quantized_model)
    
    print("\nGenerating importance vs value visualization...")
    visualize_importance_vs_value(original_model)
    
    print("\nVisualizing model size comparison...")
    visualize_model_size_comparison(original_results, pruned_results, quantized_results)
    
    print("\nVisualizing parameter importance...")
    visualize_parameter_importance(original_model)
    
    print("\nVisualizing latent space...")
    visualize_latent_space(original_model, X_test)
    
    print("\nVisualizing reconstruction quality...")
    visualize_reconstruction_quality(original_model, pruned_model, quantized_model, X_test)
    
    print("\nGenerating Rényi entropy heatmap...")
    X_sample = X_train[:min(10, len(X_train))]
    visualize_renyi_entropy_heatmap(original_model, X_sample)
    
    try:
        print("\nGenerating circuit diagrams...")
        plot_circuit_diagram(original_model)
    except Exception as e:
        print(f"Warning: Could not generate circuit diagrams. Error: {str(e)}")
    
    # Sensitivity analyses
    print("\nRunning pruning sensitivity analysis...")
    pruning_sensitivity = run_pruning_sensitivity_analysis(
        original_model=original_model,
        X_train=X_train,
        X_test=X_test,
        pruning_ratios=[0.3, 0.5, 0.7, 0.9]
    )
    
    print("\nRunning alpha sensitivity analysis...")
    alpha_sensitivity = run_alpha_sensitivity_analysis(
        original_model=original_model,
        X_train=X_train,
        X_test=X_test,
        alphas=[0.5, 1.0, 2.0, 4.0]
    )
    
    print("\nRunning quantization sensitivity analysis...")
    quant_sensitivity = run_quantization_sensitivity_analysis(
        pruned_model=pruned_model,
        X_test=X_test,
        bits_list=[4, 8, 12, 16, 24, 32]
    )
    
    # Create summary report
    print("\nCreating summary report...")
    create_comparison_report(
        original_results=original_results,
        pruned_results=pruned_results,
        quantized_results=quantized_results,
        pruning_ratio=config['pruning_ratio'],
        quant_bits=config['quant_bits'],
        alpha=config['alpha']
    )
    
    print("\n=== Optimization Pipeline Completed ===")
    print(f"All results and visualizations have been saved to: {results_dir}")
    print("\nSummary of improvements:")
    
    # Calculate improvements
    accuracy_change_pruned = (pruned_results['accuracy'] - original_results['accuracy']) / original_results['accuracy'] * 100
    accuracy_change_quantized = (quantized_results['accuracy'] - original_results['accuracy']) / original_results['accuracy'] * 100
    
    time_saving_pruned = (original_results['inference_time'] - pruned_results['inference_time']) / original_results['inference_time'] * 100
    time_saving_quantized = (original_results['inference_time'] - quantized_results['inference_time']) / original_results['inference_time'] * 100
    
    mem_saving_pruned = (original_results['inference_memory'] - pruned_results['inference_memory']) / original_results['inference_memory'] * 100
    mem_saving_quantized = (original_results['inference_memory'] - quantized_results['inference_memory']) / original_results['inference_memory'] * 100
    
    print(f"Pruned Model:")
    print(f"  - Accuracy Change: {accuracy_change_pruned:+.2f}%")
    print(f"  - Inference Time Savings: {time_saving_pruned:+.2f}%")
    print(f"  - Memory Usage Savings: {mem_saving_pruned:+.2f}%")
    print(f"  - Parameter Reduction: {pruned_results['sparsity']:.2%}")
    
    print(f"\nQuantized Model:")
    print(f"  - Accuracy Change: {accuracy_change_quantized:+.2f}%")
    print(f"  - Inference Time Savings: {time_saving_quantized:+.2f}%")
    print(f"  - Memory Usage Savings: {mem_saving_quantized:+.2f}%")
    print(f"  - Effective Memory Reduction: {(pruned_results['sparsity'] + (1-pruned_results['sparsity'])*(1-config['quant_bits']/64)):.2%}")

if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'quantum_model_optimization'

In [1]:
def visualize_latent_space(model, X_test, n_samples=100):
    """
    Visualize the latent space representation of the test data
    
    Args:
        model: QuantumAutoencoder model
        X_test: Test data
        n_samples: Number of samples to visualize
        
    Returns:
        Path to the saved figure
    """
    # Use a subset of test data for visualization
    n_samples = min(n_samples, len(X_test))
    X_subset = X_test[:n_samples]
    
    # Get encoded states and latent representations
    encoded_states = []
    latent_states = []
    
    for x in X_subset:
        encoded = model.encoder(x, model.params)
        if hasattr(encoded, 'detach'):
            encoded = encoded.detach().numpy()
        latent = model.get_latent_state(encoded)
        
        encoded_states.append(encoded)
        latent_states.append(latent)
    
    # Convert to numpy arrays
    encoded_states = np.array(encoded_states)
    latent_states = np.array(latent_states)
    
    # Compute PCA of the latent states for 2D visualization
    from sklearn.decomposition import PCA
    
    # For encoded states
    pca_encoded = PCA(n_components=2)
    encoded_2d = pca_encoded.fit_transform(np.real(encoded_states.reshape(n_samples, -1)))
    
    # For latent states
    pca_latent = PCA(n_components=2)
    latent_2d = pca_latent.fit_transform(np.real(latent_states.reshape(n_samples, -1)))
    
    # Create a figure for visualization
    plt.figure(figsize=(15, 6))
    
    # Plot encoded states PCA
    plt.subplot(1, 2, 1)
    plt.scatter(encoded_2d[:, 0], encoded_2d[:, 1], c=range(n_samples), cmap='viridis', alpha=0.8)
    plt.colorbar(label='Sample Index')
    plt.title('PCA of Encoded States')
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.grid(alpha=0.3)
    
    # Plot latent states PCA
    plt.subplot(1, 2, 2)
    plt.scatter(latent_2d[:, 0], latent_2d[:, 1], c=range(n_samples), cmap='viridis', alpha=0.8)
    plt.colorbar(label='Sample Index')
    plt.title('PCA of Latent States')
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('latent_space_visualization.png')
    plt.close()
    
    return 'latent_space_visualization.png'

def visualize_reconstruction_quality(original_model, pruned_model, quantized_model, X_test, n_samples=5):
    """
    Visualize reconstruction quality comparison between original, pruned, and quantized models
    
    Args:
        original_model: Original quantum autoencoder model
        pruned_model: Pruned model
        quantized_model: Quantized model
        X_test: Test data
        n_samples: Number of samples to visualize
        
    Returns:
        Path to the saved figure
    """
    # Use a subset of test data for visualization
    n_samples = min(n_samples, len(X_test))
    X_subset = X_test[:n_samples]
    
    # Get reconstructions from each model
    original_recon = []
    pruned_recon = []
    quantized_recon = []
    
    for x in X_subset:
        # Original model reconstruction
        orig_decoded = original_model.forward(x)
        if hasattr(orig_decoded, 'detach'):
            orig_decoded = orig_decoded.detach().numpy()
        original_recon.append(np.real(orig_decoded))
        
        # Pruned model reconstruction
        pruned_decoded = pruned_model.forward(x)
        if hasattr(pruned_decoded, 'detach'):
            pruned_decoded = pruned_decoded.detach().numpy()
        pruned_recon.append(np.real(pruned_decoded))
        
        # Quantized model reconstruction
        quant_decoded = quantized_model.forward(x)
        if hasattr(quant_decoded, 'detach'):
            quant_decoded = quant_decoded.detach().numpy()
        quantized_recon.append(np.real(quant_decoded))
    
    # Convert to numpy arrays
    original_recon = np.array(original_recon)
    pruned_recon = np.array(pruned_recon)
    quantized_recon = np.array(quantized_recon)
    
    # Create a figure for visualization
    plt.figure(figsize=(15, 3*n_samples))
    
    for i in range(n_samples):
        # Original input
        plt.subplot(n_samples, 4, i*4 + 1)
        plt.stem(X_subset[i], use_line_collection=True)
        plt.title('Original Input' if i == 0 else '')
        plt.ylim(-1.1, 1.1)
        if i == n_samples - 1:
            plt.xlabel('Feature Index')
        
        # Original model reconstruction
        plt.subplot(n_samples, 4, i*4 + 2)
        plt.stem(original_recon[i], use_line_collection=True)
        plt.title('Original Model' if i == 0 else '')
        plt.ylim(-1.1, 1.1)
        if i == n_samples - 1:
            plt.xlabel('Feature Index')
        
        # Pruned model reconstruction
        plt.subplot(n_samples, 4, i*4 + 3)
        plt.stem(pruned_recon[i], use_line_collection=True)
        plt.title('Pruned Model' if i == 0 else '')
        plt.ylim(-1.1, 1.1)
        if i == n_samples - 1:
            plt.xlabel('Feature Index')
        
        # Quantized model reconstruction
        plt.subplot(n_samples, 4, i*4 + 4)
        plt.stem(quantized_recon[i], use_line_collection=True)
        plt.title('Quantized Model' if i == 0 else '')
        plt.ylim(-1.1, 1.1)
        if i == n_samples - 1:
            plt.xlabel('Feature Index')
    
    plt.tight_layout()
    plt.savefig('reconstruction_quality.png')
    plt.close()
    
    return 'reconstruction_quality.png'

def visualize_parameter_importance(model, top_n=20):
    """
    Visualize the most important parameters based on their importance scores
    
    Args:
        model: Model with importance scores
        top_n: Number of top important parameters to visualize
        
    Returns:
        Path to the saved figure
    """
    # Sort parameters by importance
    indices = np.argsort(model.importance_scores)[::-1]
    top_indices = indices[:top_n]
    top_importance = model.importance_scores[top_indices]
    top_values = model.params[top_indices]
    
    # Create a figure for visualization
    plt.figure(figsize=(12, 8))
    
    # Plot importance scores
    plt.subplot(2, 1, 1)
    plt.bar(range(top_n), top_importance, color='purple', alpha=0.7)
    plt.title('Top Parameter Importance Scores')
    plt.xlabel('Parameter Rank')
    plt.ylabel('Importance Score')
    plt.grid(alpha=0.3)
    
    # Plot parameter values
    plt.subplot(2, 1, 2)
    plt.bar(range(top_n), top_values, color='green', alpha=0.7)
    plt.title('Top Parameter Values')
    plt.xlabel('Parameter Rank')
    plt.ylabel('Parameter Value')
    plt.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('parameter_importance.png')
    plt.close()
    
    # Also create a scatter plot of importance vs. value
    plt.figure(figsize=(10, 6))
    plt.scatter(np.abs(model.params), model.importance_scores, alpha=0.5)
    plt.title('Parameter Importance vs. Magnitude')
    plt.xlabel('Parameter Magnitude (absolute value)')
    plt.ylabel('Importance Score')
    plt.yscale('log')  # Log scale for better visualization
    plt.colorbar(plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=1)))
    plt.grid(True, alpha=0.3)
    plt.savefig('importance_vs_magnitude.png')
    plt.close()
    
    return 'parameter_importance.png', 'importance_vs_magnitude.png'

def create_comparison_report(original_results, pruned_results, quantized_results, 
                             pruning_ratio, quant_bits, alpha):
    """
    Create a summary report comparing the performance of all models
    
    Args:
        original_results: Results dictionary for the original model
        pruned_results: Results dictionary for the pruned model
        quantized_results: Results dictionary for the quantized model
        pruning_ratio: Pruning ratio used
        quant_bits: Quantization bits used
        alpha: Rényi entropy parameter used
        
    Returns:
        Path to the saved report
    """
    # Calculate improvements/reductions
    accuracy_change_pruned = (pruned_results['accuracy'] - original_results['accuracy']) / original_results['accuracy'] * 100
    accuracy_change_quantized = (quantized_results['accuracy'] - original_results['accuracy']) / original_results['accuracy'] * 100
    
    time_saving_pruned = (original_results['inference_time'] - pruned_results['inference_time']) / original_results['inference_time'] * 100
    time_saving_quantized = (original_results['inference_time'] - quantized_results['inference_time']) / original_results['inference_time'] * 100
    
    mem_saving_pruned = (original_results['inference_memory'] - pruned_results['inference_memory']) / original_results['inference_memory'] * 100
    mem_saving_quantized = (original_results['inference_memory'] - quantized_results['inference_memory']) / original_results['inference_memory'] * 100
    
    # Create a report as a Markdown file
    with open('model_comparison_report.md', 'w') as f:
        f.write("# Quantum Autoencoder Optimization Report\n\n")
        
        f.write("## Optimization Configuration\n")
        f.write(f"- Pruning Ratio: {pruning_ratio:.2f}\n")
        f.write(f"- Quantization Bits: {quant_bits}\n")
        f.write(f"- Rényi Entropy Alpha: {alpha:.2f}\n\n")
        
        f.write("## Performance Metrics\n\n")
        
        f.write("### Accuracy\n")
        f.write(f"- Original Model: {original_results['accuracy']:.4f}\n")
        f.write(f"- Pruned Model: {pruned_results['accuracy']:.4f} ({accuracy_change_pruned:+.2f}%)\n")
        f.write(f"- Quantized Model: {quantized_results['accuracy']:.4f} ({accuracy_change_quantized:+.2f}%)\n\n")
        
        f.write("### Inference Time\n")
        f.write(f"- Original Model: {original_results['inference_time']:.4f}s\n")
        f.write(f"- Pruned Model: {pruned_results['inference_time']:.4f}s ({time_saving_pruned:+.2f}%)\n")
        f.write(f"- Quantized Model: {quantized_results['inference_time']:.4f}s ({time_saving_quantized:+.2f}%)\n\n")
        
        f.write("### Memory Usage\n")
        f.write(f"- Original Model: {original_results['inference_memory']:.2f} MB\n")
        f.write(f"- Pruned Model: {pruned_results['inference_memory']:.2f} MB ({mem_saving_pruned:+.2f}%)\n")
        f.write(f"- Quantized Model: {quantized_results['inference_memory']:.2f} MB ({mem_saving_quantized:+.2f}%)\n\n")
        
        f.write("### Model Parameters\n")
        f.write(f"- Original Model: {original_results['params_nonzero']}/{original_results['params_total']} parameters (Sparsity: {original_results['sparsity']:.2%})\n")
        f.write(f"- Pruned Model: {pruned_results['params_nonzero']}/{pruned_results['params_total']} parameters (Sparsity: {pruned_results['sparsity']:.2%})\n")
        f.write(f"- Quantized Model: {quantized_results['params_nonzero']}/{quantized_results['params_total']} parameters (Sparsity: {quantized_results['sparsity']:.2%})\n\n")
        
        f.write("## Visualization\n\n")
        f.write("Several visualizations have been generated to help analyze the models:\n\n")
        f.write("1. `model_comparison.png`: Bar charts comparing accuracy, inference time, and memory usage\n")
        f.write("2. `parameter_distributions.png`: Histograms of parameter distributions for all models\n")
        f.write("3. `importance_vs_value.png`: Scatter plot of parameter importance vs. parameter value\n")
        f.write("4. `pruning_sensitivity.png`: Analysis of model performance across different pruning ratios\n")
        f.write("5. `alpha_sensitivity.png`: Analysis of model performance across different alpha values\n")
        f.write("6. `quantization_sensitivity.png`: Analysis of model performance across different bit depths\n")
        f.write("7. `reconstruction_quality.png`: Comparison of reconstruction quality between models\n")
        f.write("8. `latent_space_visualization.png`: PCA visualization of the latent space\n")
        f.write("9. `renyi_entropy_heatmap.png`: Heatmap of Rényi entropy across parameters and alpha values\n\n")
        
        f.write("## Conclusion\n\n")
        
        # Overall assessment
        overall_assessment = "The optimization process "
        if accuracy_change_quantized >= -1:  # Less than 1% accuracy loss
            overall_assessment += "successfully maintained model accuracy while "
        elif accuracy_change_quantized >= -5:  # Less than 5% accuracy loss
            overall_assessment += "resulted in minimal accuracy loss while "
        else:
            overall_assessment += "resulted in some accuracy degradation but "
            
        if time_saving_quantized > 30 or mem_saving_quantized > 30:
            overall_assessment += "significantly reducing computational resources."
        elif time_saving_quantized > 10 or mem_saving_quantized > 10:
            overall_assessment += "moderately reducing computational resources."
        else:
            overall_assessment += "yielding some resource efficiency improvements."
            
        f.write(overall_assessment + "\n\n")
        
        # Specific recommendations
        f.write("### Recommendations\n\n")
        
        if accuracy_change_pruned > accuracy_change_quantized:
            f.write("- The pruned model provides better accuracy-efficiency trade-off than the quantized model.\n")
        else:
            f.write("- The quantized model provides better accuracy-efficiency trade-off than the pruned model.\n")
            
        if accuracy_change_quantized < -5:
            f.write("- Consider using a lower pruning ratio or higher bit depth to preserve more accuracy.\n")
            
        if time_saving_quantized < 10 and mem_saving_quantized < 10:
            f.write("- Explore more aggressive optimization techniques as current savings are modest.\n")
            
        f.write("- For deployment scenarios where memory is the primary constraint, the quantized model is recommended.\n")
        f.write("- For deployment scenarios where inference speed is critical, the pruned model may be preferred.\n")
    
    return 'model_comparison_report.md'import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

def visualize_renyi_entropy_heatmap(model, X_sample, alpha_range=[0.5, 1.0, 2.0, 3.0, 4.0]):
    """
    Create a heatmap showing how Rényi entropy changes across different alpha values
    
    Args:
        model: QuantumAutoencoder model
        X_sample: Sample data for evaluation
        alpha_range: Range of alpha values to test
    """
    n_params = len(model.params)
    n_to_show = min(100, n_params)  # Show at most 100 parameters for clarity
    
    # Select parameter indices
    if n_params > n_to_show:
        indices = np.linspace(0, n_params-1, n_to_show, dtype=int)
    else:
        indices = np.arange(n_params)
    
    # Create a matrix to store entropy values
    entropy_matrix = np.zeros((len(alpha_range), len(indices)))
    
    # Calculate entropy for each alpha and parameter
    for i, alpha in enumerate(alpha_range):
        for j, param_idx in enumerate(indices):
            # Create a copy of the parameters
            temp_params = np.copy(model.params)
            
            # Zero out the parameter
            temp_params[param_idx] = 0.0
            
            # Collect outputs with zero parameter
            outputs = []
            model.params = temp_params
            for x in X_sample:
                decoded = model.forward(x)
                if hasattr(decoded, 'detach'):
                    decoded = decoded.detach().numpy()
                outputs.append(decoded)
            
            # Calculate entropy of the output distribution
            outputs = np.array(outputs).flatten()
            
            # Normalize the values to form a probability distribution
            abs_values = np.abs(outputs)
            prob_dist = abs_values / np.sum(abs_values)
            
            # Handle zero probabilities by adding a small epsilon
            epsilon = 1e-10
            prob_dist = prob_dist + epsilon
            prob_dist = prob_dist / np.sum(prob_dist)
            
            if alpha == 1:
                # Shannon entropy for alpha=1
                entropy = -np.sum(prob_dist * np.log(prob_dist))
            else:
                # Rényi entropy for alpha≠1
                entropy = 1 / (1 - alpha) * np.log(np.sum(prob_dist ** alpha))
            
            entropy_matrix[i, j] = entropy
    
    # Restore original parameters
    model.params = np.copy(model.params)
    
    # Create the heatmap
    plt.figure(figsize=(12, 6))
    
    # Create a custom colormap from blue to red
    colors = [(0, 0, 1), (1, 1, 1), (1, 0, 0)]  # Blue -> White -> Red
    cmap = LinearSegmentedColormap.from_list("entropy_cmap", colors, N=100)
    
    # Normalize the entropy values to [-1, 1] for better visualization
    entropy_norm = 2 * (entropy_matrix - np.min(entropy_matrix)) / (np.max(entropy_matrix) - np.min(entropy_matrix)) - 1
    
    # Plot the heatmap
    plt.imshow(entropy_norm, cmap=cmap, aspect='auto')
    plt.colorbar(label='Normalized Entropy')
    plt.xlabel('Parameter Index')
    plt.ylabel('Alpha Value')
    plt.title('Rényi Entropy Across Alpha Values and Parameters')
    
    # Set y-ticks to alpha values
    plt.yticks(np.arange(len(alpha_range)), [f"{alpha:.1f}" for alpha in alpha_range])
    
    # Set x-ticks to parameter indices
    plt.xticks(np.arange(0, len(indices), max(1, len(indices)//10)), 
              [f"{idx}" for idx in indices[::max(1, len(indices)//10)]])
    
    plt.tight_layout()
    plt.savefig('renyi_entropy_heatmap.png')
    plt.close()

def plot_circuit_diagram(model, filename='quantum_circuit.png'):
    """
    Plot the quantum circuit diagram of the model
    
    Args:
        model: QuantumAutoencoder model
        filename: Output filename for the circuit diagram
    """
    import pennylane as qml
    
    # Create a new figure
    plt.figure(figsize=(12, 8))
    
    # Get a sample input
    x_sample = np.random.random(2**model.n_qubits)
    x_sample = x_sample / np.linalg.norm(x_sample)
    
    # Create a circuit diagram for the encoder
    encoder_fig, encoder_ax = qml.draw_mpl(model.encoder, expansion_strategy="device")(x_sample, model.params)
    encoder_ax.set_title("Encoder Circuit")
    
    # Save the encoder circuit diagram
    encoder_fig.tight_layout()
    encoder_fig.savefig('encoder_circuit.png')
    
    # Create a latent state
    encoded_state = model.encoder(x_sample, model.params)
    latent_state = model.get_latent_state(encoded_state)
    
    # Create a circuit diagram for the decoder
    decoder_fig, decoder_ax = qml.draw_mpl(model.decoder, expansion_strategy="device")(latent_state, model.params)
    decoder_ax.set_title("Decoder Circuit")
    
    # Save the decoder circuit diagram
    decoder_fig.tight_layout()
    decoder_fig.savefig('decoder_circuit.png')
    
    # Combined image for both encoder and decoder
    plt.figure(figsize=(15, 10))
    
    plt.subplot(2, 1, 1)
    plt.imshow(plt.imread('encoder_circuit.png'))
    plt.axis('off')
    
    plt.subplot(2, 1, 2)
    plt.imshow(plt.imread('decoder_circuit.png'))
    plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()
    
    return filename

def visualize_model_size_comparison(original_results, pruned_results, quantized_results):
    """
    Visualize the model size comparison between original, pruned, and quantized models
    
    Args:
        original_results: Results dictionary for the original model
        pruned_results: Results dictionary for the pruned model
        quantized_results: Results dictionary for the quantized model
    """
    # Calculate model sizes
    param_size_original = original_results['params_total'] * 8  # 8 bytes for float64
    param_size_pruned = pruned_results['params_nonzero'] * 8  # 8 bytes for float64
    param_size_quantized = quantized_results['params_nonzero'] * (quantized_results.get('quant_info', {}).get('bits', 8) / 8)  # Convert bits to bytes
    
    # If quantization info is not available, estimate based on typical bit depth
    if 'quant_info' not in quantized_results:
        param_size_quantized = quantized_results['params_nonzero'] * 1  # Assume 8-bit (1 byte) quantization
    
    # Calculate memory reduction percentages
    pruned_reduction = 100 * (1 - param_size_pruned / param_size_original)
    quantized_reduction = 100 * (1 - param_size_quantized / param_size_original)
    
    # Create a bar chart for model size comparison
    plt.figure(figsize=(10, 6))
    
    models = ['Original', 'Pruned', 'Quantized']
    sizes = [param_size_original / 1024, param_size_pruned / 1024, param_size_quantized / 1024]  # Convert to KB
    
    bars = plt.bar(models, sizes, color=['blue', 'green', 'orange'])
    
    # Add labels and title
    plt.ylabel('Model Size (KB)')
    plt.title('Model Size Comparison')
    
    # Add size values and reduction percentages on top of bars
    for i, bar in enumerate(bars):
        height = bar.get_height()
        if i == 0:
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                    f'{sizes[i]:.2f} KB',
                    ha='center', va='bottom')
        else:
            reduction = pruned_reduction if i == 1 else quantized_reduction
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                    f'{sizes[i]:.2f} KB\n(-{reduction:.1f}%)',
                    ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig('model_size_comparison.png')
    plt.close()
    
    return sizes, [pruned_reduction, quantized_reduction]

SyntaxError: invalid syntax (761780950.py, line 321)