**Task VIII: Vision transformer/Quantum Vision Transformer**

Implement a classical Vision Transformer (ViT) and apply it to MNIST. Report its performance on the test data. Comment on potential ideas for extending this classical Vision Transformer architecture to a Quantum Vision Transformer, and sketch out the architecture in detail.

In [1]:
!pip install pennylane
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pennylane as qml
import numpy as np
import matplotlib.pyplot as plt
import time

# For reproducible results, set random number generator seeds.
# NumPy's global RNG drives the synthetic image generation (noise, shifts, shuffle);
# PyTorch's global RNG drives parameter initialisation (e.g. torch.randn tokens).
torch.manual_seed(42)
np.random.seed(42)

# Determine if CUDA (GPU) is available, otherwise use CPU.
# Models and every batch are moved to this device before use.
computing_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {computing_device}")

# Create a custom dataset that mimics MNIST for demonstration
class SyntheticMNISTDataset(Dataset):
    """Synthetic stand-in for MNIST built from ten fixed 28x28 glyphs.

    Every sample is one of ten hand-drawn base patterns (one per class),
    circularly shifted by up to two pixels along each axis, perturbed with
    Gaussian noise (sigma = 0.1) and clipped to [0, 1].  All randomness comes
    from NumPy's global RNG, so seed it with ``np.random.seed`` beforehand for
    reproducible data.

    Classes are balanced: each class receives ``num_samples // 10`` samples
    and the first ``num_samples % 10`` classes receive one extra, so no slot
    is ever left as a blank image.
    """

    IMAGE_SIZE = 28
    NUM_CLASSES = 10

    def __init__(self, num_samples=10000, is_train_set=True):
        """
        Initializes the Synthetic MNIST-like dataset.

        Args:
            num_samples (int): Total number of samples in the dataset.
            is_train_set (bool): True for training set, False for test set.
                Stored for reference only; generation is the same either way.

        Raises:
            ValueError: If ``num_samples`` is negative.
        """
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        self.num_samples = num_samples
        self.is_train_set = is_train_set

        size = self.IMAGE_SIZE
        digit_patterns = self._build_digit_patterns()

        # Class of every sample before shuffling: equal-sized runs per class,
        # then one extra sample for each of the first `remainder` classes so
        # that sizes not divisible by ten are still fully populated.
        samples_per_digit_class, remainder = divmod(num_samples, self.NUM_CLASSES)
        sample_classes = np.concatenate([
            np.repeat(np.arange(self.NUM_CLASSES), samples_per_digit_class),
            np.arange(remainder),
        ]).astype(np.int64)

        self.images = np.zeros((num_samples, size, size), dtype=np.float32)
        self.labels = np.zeros(num_samples, dtype=np.int64)

        for sample_index, digit_value in enumerate(sample_classes):
            # Draw order (noise, axis-0 shift, axis-1 shift) is kept fixed so
            # a given seed always produces the same images.
            noise = np.random.normal(0, 0.1, (size, size))
            row_shift = np.random.randint(-2, 3)
            col_shift = np.random.randint(-2, 3)

            shifted = np.roll(digit_patterns[digit_value], (row_shift, col_shift), axis=(0, 1))
            self.images[sample_index] = np.clip(shifted + noise, 0, 1)
        self.labels[:] = sample_classes

        # Randomly shuffle the generated dataset
        shuffle_indices = np.random.permutation(num_samples)
        self.images = self.images[shuffle_indices]
        self.labels = self.labels[shuffle_indices]

    @staticmethod
    def _build_digit_patterns():
        """Return a (10, 28, 28) float32 array holding the clean glyph of each class."""
        rows, cols = np.indices((28, 28))
        patterns = np.zeros((10, 28, 28), dtype=np.float32)

        def distance_from(center_row, center_col):
            return np.sqrt((rows - center_row) ** 2 + (cols - center_col) ** 2)

        row_band = (rows > 8) & (rows < 20)

        # 0: ring centred on (14, 14)
        ring_distance = distance_from(14, 14)
        patterns[0][row_band & (cols > 8) & (cols < 20) & (ring_distance > 4) & (ring_distance < 6)] = 1.0

        # 1: thin bar along axis 0
        patterns[1][5:23, 13:15] = 1.0

        # 2: four contiguous 4-row stripes, i.e. one solid rectangle
        patterns[2][6:22, 8:20] = 1.0

        # 3: thick bar crossed by a thin perpendicular bar
        patterns[3][10:18, 8:20] = 1.0
        patterns[3][6:22, 13:15] = 1.0

        # 4: hollow square
        patterns[4][8:20, 8:20] = 1.0
        patterns[4][10:18, 10:18] = 0.0

        # 5: filled diamond (L1 ball of radius 8)
        patterns[5][(np.abs(rows - 14) + np.abs(cols - 14)) < 8] = 1.0

        # 6: thin plus sign
        patterns[6][13:15, 8:20] = 1.0
        patterns[6][8:20, 13:15] = 1.0

        # 7: T-shape
        patterns[7][8:10, 8:20] = 1.0
        patterns[7][10:22, 13:15] = 1.0

        # 8: two overlapping filled discs
        patterns[8][row_band & ((distance_from(14, 10) < 4) | (distance_from(14, 18) < 4))] = 1.0

        # 9: filled disc with a short tail
        patterns[9][row_band & (cols > 8) & (cols < 16) & (distance_from(14, 12) < 4)] = 1.0
        patterns[9][14:16, 16:22] = 1.0

        return patterns

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, index):
        """
        Retrieves an image and its label given an index.

        Args:
            index (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, label) where image is a float32 tensor of shape
            (1, 28, 28) and label is a Python int.
        """
        # torch.tensor copies, so callers cannot mutate the stored array through the sample.
        image = torch.tensor(self.images[index], dtype=torch.float32).unsqueeze(0)  # Add channel dimension for grayscale
        label = int(self.labels[index])
        return image, label

# Create training and testing datasets.
# NOTE: is_train_set is only stored on the dataset; both sets come from the same
# generator and differ by sample count and by their position in the NumPy RNG stream.
training_dataset = SyntheticMNISTDataset(num_samples=5000, is_train_set=True)
testing_dataset = SyntheticMNISTDataset(num_samples=1000, is_train_set=False)

# Define hyperparameters for training
training_batch_size = 64
num_training_epochs = 3  # Reduced epochs for quicker demonstration
learning_rate = 0.001

# Create data loaders for training and testing.
# Only the training loader reshuffles; the test loader keeps a fixed order.
training_dataloader = DataLoader(training_dataset, batch_size=training_batch_size, shuffle=True)
testing_dataloader = DataLoader(testing_dataset, batch_size=training_batch_size, shuffle=False)

# Vision Transformer Model Parameters
image_resolution = 28  # Input image size: 28x28
patch_resolution = 7  # Each patch is 7x7 pixels, giving a 4x4 grid of patches
num_image_patches = (image_resolution // patch_resolution) ** 2  # Total number of patches (16 here)
embedding_dimension = 64  # Dimension of patch embeddings
num_attention_heads = 4  # Number of attention heads in Transformer
transformer_depth = 2  # Number of Transformer encoder blocks
mlp_hidden_dimension = 128  # Hidden dimension in MLP layers
num_output_classes = 10  # 10 classes for digits 0-9
input_channels = 1  # Grayscale images have 1 channel

# Define the Patch Embedding Layer
class PatchEmbeddingLayer(nn.Module):
    """Cuts an image into non-overlapping square patches and embeds each one.

    A strided convolution whose kernel size and stride both equal the patch
    size plays the role of "flatten each patch, then apply a shared linear
    projection".
    """

    def __init__(self, image_resolution, patch_resolution, input_channels, embedding_dimension):
        """
        Builds the patch projection.

        Args:
            image_resolution (int): Height/Width of the (square) input image.
            patch_resolution (int): Height/Width of each (square) patch.
            input_channels (int): Number of input channels (e.g., 1 for grayscale, 3 for RGB).
            embedding_dimension (int): Size of the embedding produced per patch.
        """
        super().__init__()
        patches_per_side = image_resolution // patch_resolution

        self.image_resolution = image_resolution
        self.patch_resolution = patch_resolution
        self.num_image_patches = patches_per_side * patches_per_side

        # kernel == stride, so every output position sees exactly one patch
        self.patch_projection = nn.Conv2d(
            input_channels,
            embedding_dimension,
            kernel_size=patch_resolution,
            stride=patch_resolution,
        )

    def forward(self, input_images):
        """
        Embeds every patch of every image in the batch.

        Args:
            input_images (torch.Tensor): Input images of shape (B, C, H, W).

        Returns:
            torch.Tensor: Patch embeddings of shape (B, num_patches, embedding_dimension).
        """
        # (B, C, H, W) -> (B, embedding_dimension, H/patch, W/patch)
        feature_grid = self.patch_projection(input_images)
        # Collapse the spatial grid into a token axis, then put tokens before features.
        return feature_grid.flatten(start_dim=2).permute(0, 2, 1)

# Define a Transformer Encoder Block
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention then MLP, each with a residual."""

    def __init__(self, embedding_dimension, num_attention_heads, mlp_hidden_dimension, dropout_rate=0.1):
        """
        Transformer encoder block with multi-head attention and MLP.

        Args:
            embedding_dimension (int): Embedding dimension.
            num_attention_heads (int): Number of attention heads.
            mlp_hidden_dimension (int): Hidden dimension in MLP.
            dropout_rate (float): Dropout probability used inside the MLP.
        """
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embedding_dimension)
        # NOTE: dropout_rate is intentionally not forwarded here, so attention
        # itself runs without dropout (unchanged from the original behaviour).
        self.multi_head_attention = nn.MultiheadAttention(embedding_dimension, num_attention_heads)
        self.layer_norm_2 = nn.LayerNorm(embedding_dimension)
        self.mlp_feedforward = nn.Sequential(
            nn.Linear(embedding_dimension, mlp_hidden_dimension),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_hidden_dimension, embedding_dimension),
            nn.Dropout(dropout_rate)
        )

    def forward(self, input_embeddings):
        """
        Forward pass of the Transformer encoder block.

        Args:
            input_embeddings (torch.Tensor): Input embeddings of shape
                (B, sequence_length, embedding_dimension).

        Returns:
            torch.Tensor: Output embeddings with the same shape as the input.
        """
        # The attention module was built without batch_first, so it expects
        # (sequence, batch, features).  Transpose once and reuse the same
        # tensor for query, key and value so the module sees self-attention.
        sequence_first = self.layer_norm_1(input_embeddings).transpose(0, 1)
        # need_weights=False: the averaged attention map is never used here,
        # so skip computing it.
        attention_output, _ = self.multi_head_attention(
            sequence_first, sequence_first, sequence_first, need_weights=False
        )
        embeddings_after_attention = input_embeddings + attention_output.transpose(0, 1)

        # Feed-forward MLP with its own residual connection
        mlp_output = self.mlp_feedforward(self.layer_norm_2(embeddings_after_attention))
        return embeddings_after_attention + mlp_output

# Define the Vision Transformer (ViT) Model
class VisionTransformerModel(nn.Module):
    """Vision Transformer classifier: patch tokens + class token -> encoder stack -> linear head."""

    def __init__(self, image_resolution, patch_resolution, input_channels, embedding_dimension, num_attention_heads, transformer_depth, mlp_hidden_dimension, num_output_classes):
        """
        Assembles the ViT.

        Args:
            image_resolution (int): Height/Width of the input image.
            patch_resolution (int): Height/Width of each patch.
            input_channels (int): Number of input channels.
            embedding_dimension (int): Embedding dimension for patches.
            num_attention_heads (int): Number of attention heads per encoder block.
            transformer_depth (int): Number of stacked encoder blocks.
            mlp_hidden_dimension (int): Hidden dimension of each block's MLP.
            num_output_classes (int): Number of output classes.
        """
        super().__init__()
        self.patch_embed = PatchEmbeddingLayer(image_resolution, patch_resolution, input_channels, embedding_dimension)
        self.num_image_patches = self.patch_embed.num_image_patches

        # One learnable class token, plus one learnable position vector per
        # token (class token included, hence the +1).
        token_count = self.num_image_patches + 1
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dimension))
        self.position_embeddings = nn.Parameter(torch.randn(1, token_count, embedding_dimension))

        encoder_blocks = []
        for _ in range(transformer_depth):
            encoder_blocks.append(
                TransformerEncoderBlock(embedding_dimension, num_attention_heads, mlp_hidden_dimension)
            )
        self.transformer_encoder_layers = nn.ModuleList(encoder_blocks)

        self.layer_norm_final = nn.LayerNorm(embedding_dimension)
        self.classification_head = nn.Linear(embedding_dimension, num_output_classes)

    def forward(self, input_images):
        """
        Classifies a batch of images.

        Args:
            input_images (torch.Tensor): Input images of shape (B, C, H, W).

        Returns:
            torch.Tensor: Classification output logits of shape (B, num_output_classes).
        """
        patch_tokens = self.patch_embed(input_images)  # (B, num_patches, embedding_dimension)

        # Broadcast the single class token across the batch and put it first.
        class_tokens = self.class_token.expand(input_images.shape[0], -1, -1)
        hidden_states = torch.cat((class_tokens, patch_tokens), dim=1)  # (B, num_patches + 1, embedding_dimension)
        hidden_states = hidden_states + self.position_embeddings

        for encoder_block in self.transformer_encoder_layers:
            hidden_states = encoder_block(hidden_states)

        # Only the class token (position 0) feeds the classifier.
        class_token_embedding = self.layer_norm_final(hidden_states[:, 0])
        return self.classification_head(class_token_embedding)

# Initialize the Vision Transformer model from the hyperparameters defined above
# and move it to the selected device.
vision_transformer_model = VisionTransformerModel(
    image_resolution=image_resolution,
    patch_resolution=patch_resolution,
    input_channels=input_channels,
    embedding_dimension=embedding_dimension,
    num_attention_heads=num_attention_heads,
    transformer_depth=transformer_depth,
    mlp_hidden_dimension=mlp_hidden_dimension,
    num_output_classes=num_output_classes
).to(computing_device)

# Report the total parameter count (sum of element counts over all parameters)
print(f"Vision Transformer initialized with {sum(p.numel() for p in vision_transformer_model.parameters())} parameters")

# Loss function and optimizer for training.
# CrossEntropyLoss takes the raw logits returned by the model and integer class labels.
loss_criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vision_transformer_model.parameters(), lr=learning_rate)

# Training loop: one pass over training_dataloader per epoch, recording the
# mean per-batch loss of each epoch in training_losses.
total_steps_per_epoch = len(training_dataloader)
training_losses = []
print("Starting training process...")
for epoch in range(num_training_epochs):
    vision_transformer_model.train()  # Set model to training mode
    current_running_loss = 0.0  # Sum of batch losses within this epoch
    epoch_start_time = time.time()
    for step_index, (images, labels) in enumerate(training_dataloader):
        # Batches must live on the same device as the model
        images = images.to(computing_device)
        labels = labels.to(computing_device)

        # Forward pass
        outputs = vision_transformer_model(images)
        loss = loss_criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()  # Clear gradients from previous step
        loss.backward()  # Compute gradients
        optimizer.step()  # Update model parameters

        current_running_loss += loss.item()

        # Progress report every 10 steps
        if (step_index + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_training_epochs}], Step [{step_index+1}/{total_steps_per_epoch}], Loss: {loss.item():.4f}')

    # Average over batches (not samples), so a smaller final batch weighs the same as a full one
    epoch_loss = current_running_loss / len(training_dataloader)
    training_losses.append(epoch_loss)
    epoch_end_time = time.time()
    print(f'Epoch [{epoch+1}/{num_training_epochs}], Loss: {epoch_loss:.4f}, Time: {epoch_end_time - epoch_start_time:.2f}s')

# Evaluate the trained model on the test dataset
vision_transformer_model.eval()  # Set model to evaluation mode (disables dropout)
num_correct_predictions = 0
total_samples = 0
with torch.no_grad():  # Disable gradient calculation during inference
    for images, labels in testing_dataloader:
        images = images.to(computing_device)
        labels = labels.to(computing_device)
        outputs = vision_transformer_model(images)
        # The model returns raw logits; the arg-max over the class axis is the
        # predicted digit.  No `.data` access is needed inside no_grad().
        predicted_labels = outputs.argmax(dim=1)
        total_samples += labels.size(0)
        num_correct_predictions += (predicted_labels == labels).sum().item()

# Accuracy as a percentage of all test samples
test_accuracy = 100 * num_correct_predictions / total_samples
print(f'Test Accuracy: {test_accuracy:.2f}%')

# Plot the training loss curve (one point per epoch; the x-axis is the
# zero-based epoch index) and write it to disk instead of displaying it.
plt.figure(figsize=(10, 5))
plt.plot(training_losses)
plt.title('Training Loss Curve')
plt.xlabel('Epoch Number')
plt.ylabel('Training Loss')
plt.grid(True)
plt.savefig('training_loss.png')
plt.close()  # Release the figure once it has been saved

# ===================================
# Quantum Vision Transformer Architecture
# ===================================

# Define a quantum computing device using PennyLane simulator.
# The wire count here fixes the width of the circuit defined on this device.
num_quantum_qubits = 4  # Number of qubits for quantum components
quantum_device = qml.device("default.qubit", wires=num_quantum_qubits)
print(f"Quantum device initialized with {num_quantum_qubits} qubits")

# Define a quantum circuit for feature extraction
@qml.qnode(quantum_device)
def quantum_feature_circuit(input_features, quantum_weights):
    """
    Variational circuit: angle-encode the inputs, entangle, read out <Z> per qubit.

    Args:
        input_features (list[float]): Classical features, applied as rotation
            angles (angle embedding, not amplitude encoding).
        quantum_weights (array): Trainable parameters for the
            StronglyEntanglingLayers template.

    Returns:
        list[float]: Expectation value of PauliZ on each of the
        ``num_quantum_qubits`` wires.
    """
    circuit_wires = range(num_quantum_qubits)

    # Data encoding: features become rotation angles on the wires
    qml.templates.AngleEmbedding(input_features, wires=circuit_wires)

    # Trainable part: parameterized rotations interleaved with entangling gates
    qml.templates.StronglyEntanglingLayers(quantum_weights, wires=circuit_wires)

    # Readout: one PauliZ expectation value per wire
    return [qml.expval(qml.PauliZ(wire)) for wire in circuit_wires]

# Quantum-enhanced Multi-Head Attention Mechanism
class QuantumAttentionLayer(nn.Module):
    """Multi-head self-attention whose queries are partly rewritten by a quantum circuit.

    For a limited number of (token, head) pairs, the leading features of the
    query vector are angle-encoded into ``quantum_feature_circuit`` and
    replaced by the measured expectation values.  Keys, values, the softmax
    and the output projection are classical.

    The circuit is called with torch tensors that are still attached to the
    autograd graph, so ``quantum_circuit_weights`` and the query projection
    both receive gradients through the quantum path.
    """

    def __init__(self, embedding_dimension, num_attention_heads, num_quantum_qubits=4, dropout_rate=0.1,
                 *, max_quantum_tokens=10, max_quantum_heads=2):
        """
        Multi-head attention layer enhanced with quantum processing in the query part.

        Args:
            embedding_dimension (int): Embedding dimension of inputs.
            num_attention_heads (int): Number of attention heads.
            num_quantum_qubits (int): Number of qubits to use in quantum circuit.
                NOTE(review): the circuit itself is built on a module-level
                device; this value should match that device's wire count.
            dropout_rate (float): Dropout probability applied to attention weights.
            max_quantum_tokens (int): At most this many tokens (counted over the
                flattened batch * sequence axis) are sent through the circuit.
                The small default keeps the simulation fast.
            max_quantum_heads (int): At most this many heads per token are sent
                through the circuit.

        Raises:
            ValueError: If ``embedding_dimension`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()
        if embedding_dimension % num_attention_heads != 0:
            raise ValueError(
                f"embedding_dimension ({embedding_dimension}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )

        self.embedding_dimension = embedding_dimension
        self.num_attention_heads = num_attention_heads
        self.num_quantum_qubits = num_quantum_qubits
        self.head_dimension = embedding_dimension // num_attention_heads
        self.max_quantum_tokens = max_quantum_tokens
        self.max_quantum_heads = max_quantum_heads

        # Quantum circuit parameters - initialized randomly in [0, 2*pi).
        # Shape (layers=2, qubits, 3 rotation angles) as expected by StronglyEntanglingLayers.
        self.quantum_circuit_weights = nn.Parameter(
            torch.FloatTensor(2, num_quantum_qubits, 3).uniform_(0, 2 * np.pi)
        )

        # Linear layers for projecting Query, Key, Value
        self.query_projection = nn.Linear(embedding_dimension, embedding_dimension)
        self.key_projection = nn.Linear(embedding_dimension, embedding_dimension)
        self.value_projection = nn.Linear(embedding_dimension, embedding_dimension)
        self.output_projection = nn.Linear(embedding_dimension, embedding_dimension)

        self.dropout_layer = nn.Dropout(dropout_rate)

    def _quantum_enhance_queries(self, query_layer):
        """Replace the leading query features of a few (token, head) pairs with circuit outputs."""
        batch_size_val, sequence_length, _ = query_layer.shape
        token_count = batch_size_val * sequence_length

        # Use at most num_quantum_qubits features for quantum processing
        quantum_feature_dimension = min(self.num_quantum_qubits, self.head_dimension)
        per_token_queries = query_layer.reshape(token_count, self.num_attention_heads, self.head_dimension)

        token_limit = min(token_count, self.max_quantum_tokens)
        head_limit = min(self.num_attention_heads, self.max_quantum_heads)
        if token_limit <= 0 or head_limit <= 0 or quantum_feature_dimension <= 0:
            return query_layer

        enhanced_queries = per_token_queries.clone()

        # The simulator runs on the CPU.  `.cpu()` (unlike `.detach()`) keeps
        # the autograd link, so the weights stay trainable.
        circuit_weights = self.quantum_circuit_weights.cpu()

        for token_index in range(token_limit):
            for head_index in range(head_limit):
                features_to_process = per_token_queries[token_index, head_index, :quantum_feature_dimension].cpu()

                # Affine map to angles; this lands in [0, 2*pi] only if the
                # features lie in [-1, 1], which is assumed rather than enforced.
                scaled_features = (features_to_process * 0.5 + 0.5) * 2 * np.pi

                circuit_output = quantum_feature_circuit(scaled_features, circuit_weights)
                quantum_results = torch.stack([torch.as_tensor(value) for value in circuit_output])

                # The circuit returns one value per device wire, which can be
                # more than the slice being overwritten; keep the leading ones
                # and match the query's device/dtype.
                enhanced_queries[token_index, head_index, :quantum_feature_dimension] = (
                    quantum_results[:quantum_feature_dimension].to(
                        device=enhanced_queries.device, dtype=enhanced_queries.dtype
                    )
                )

        return enhanced_queries.view(batch_size_val, sequence_length, -1)

    def forward(self, input_embeddings):
        """
        Forward pass of the Quantum Attention Layer.

        Args:
            input_embeddings (torch.Tensor): Input embeddings of shape
                (B, sequence_length, embedding_dimension).

        Returns:
            torch.Tensor: Output embeddings after quantum-enhanced attention,
            same shape as the input.
        """
        batch_size_val = input_embeddings.size(0)
        sequence_length = input_embeddings.size(1)

        # Project input into Query, Key, and Value spaces
        query_layer = self._quantum_enhance_queries(self.query_projection(input_embeddings))
        key_layer = self.key_projection(input_embeddings)
        value_layer = self.value_projection(input_embeddings)

        def split_heads(tensor):
            # (B, S, E) -> (B, heads, S, head_dimension)
            return tensor.view(
                batch_size_val, sequence_length, self.num_attention_heads, self.head_dimension
            ).transpose(1, 2)

        query_heads = split_heads(query_layer)
        key_heads = split_heads(key_layer)
        value_heads = split_heads(value_layer)

        # Scaled dot-product attention
        attention_scores = torch.matmul(query_heads, key_heads.transpose(-2, -1)) / (self.head_dimension ** 0.5)
        attention_probabilities = torch.softmax(attention_scores, dim=-1)
        attention_probabilities = self.dropout_layer(attention_probabilities)

        # Weighted sum of values, heads merged back into the embedding axis
        attention_output_heads = torch.matmul(attention_probabilities, value_heads)
        attention_output_merged = attention_output_heads.transpose(1, 2).reshape(
            batch_size_val, sequence_length, self.embedding_dimension
        )
        return self.output_projection(attention_output_merged)

# Quantum Transformer Encoder Block
class QuantumTransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block whose attention sub-layer is a QuantumAttentionLayer."""

    def __init__(self, embedding_dimension, num_attention_heads, mlp_hidden_dimension, num_quantum_qubits=4, dropout_rate=0.1):
        """
        Builds the hybrid block: quantum-enhanced attention plus a classical MLP.

        Args:
            embedding_dimension (int): Embedding dimension.
            num_attention_heads (int): Number of attention heads.
            mlp_hidden_dimension (int): Hidden dimension in MLP.
            num_quantum_qubits (int): Number of qubits for quantum attention.
            dropout_rate (float): Dropout probability.
        """
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(embedding_dimension)
        self.quantum_attention = QuantumAttentionLayer(
            embedding_dimension,
            num_attention_heads,
            num_quantum_qubits,
            dropout_rate,
        )
        self.layer_norm_2 = nn.LayerNorm(embedding_dimension)

        # The feed-forward half stays fully classical in this hybrid design.
        expand = nn.Linear(embedding_dimension, mlp_hidden_dimension)
        contract = nn.Linear(mlp_hidden_dimension, embedding_dimension)
        self.mlp_feedforward = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout_rate),
            contract,
            nn.Dropout(dropout_rate),
        )

    def forward(self, input_embeddings):
        """
        Runs attention and MLP, each preceded by LayerNorm and wrapped in a residual.

        Args:
            input_embeddings (torch.Tensor): Input embeddings.

        Returns:
            torch.Tensor: Output embeddings after quantum transformer block.
        """
        # Residual 1: quantum-enhanced attention on the normalized input
        residual_stream = input_embeddings + self.quantum_attention(self.layer_norm_1(input_embeddings))

        # Residual 2: classical MLP on the normalized attention result
        residual_stream = residual_stream + self.mlp_feedforward(self.layer_norm_2(residual_stream))
        return residual_stream

# Quantum Vision Transformer Model
class QuantumVisionTransformerModel(nn.Module):
    """Hybrid ViT: classical patch embedding and head around quantum-attention encoder blocks."""

    def __init__(self, image_resolution, patch_resolution, input_channels, embedding_dimension, num_attention_heads, transformer_depth, mlp_hidden_dimension, num_output_classes, num_quantum_qubits=4):
        """
        Assembles the Quantum Vision Transformer (QViT).

        Args:
            image_resolution (int): Height/Width of the input image.
            patch_resolution (int): Height/Width of each patch.
            input_channels (int): Number of input channels.
            embedding_dimension (int): Embedding dimension for patches.
            num_attention_heads (int): Number of attention heads per encoder block.
            transformer_depth (int): Number of stacked encoder blocks.
            mlp_hidden_dimension (int): Hidden dimension of each block's MLP.
            num_output_classes (int): Number of output classes.
            num_quantum_qubits (int): Number of qubits for quantum attention.
        """
        super().__init__()
        # Patch embedding is unchanged from the classical model.
        self.patch_embed = PatchEmbeddingLayer(image_resolution, patch_resolution, input_channels, embedding_dimension)
        self.num_image_patches = self.patch_embed.num_image_patches

        # Classical learnable class token and per-token position vectors
        # (the +1 accounts for the class token).
        token_count = self.num_image_patches + 1
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dimension))
        self.position_embeddings = nn.Parameter(torch.randn(1, token_count, embedding_dimension))

        quantum_blocks = []
        for _ in range(transformer_depth):
            quantum_blocks.append(
                QuantumTransformerEncoderBlock(embedding_dimension, num_attention_heads, mlp_hidden_dimension, num_quantum_qubits)
            )
        self.transformer_encoder_layers = nn.ModuleList(quantum_blocks)

        self.layer_norm_final = nn.LayerNorm(embedding_dimension)
        self.classification_head = nn.Linear(embedding_dimension, num_output_classes)  # classical head

    def forward(self, input_images):
        """
        Classifies a batch of images with the hybrid model.

        Args:
            input_images (torch.Tensor): Input images of shape (B, C, H, W).

        Returns:
            torch.Tensor: Classification output logits from QViT.
        """
        patch_tokens = self.patch_embed(input_images)  # (B, num_patches, embedding_dimension)

        # Broadcast the class token over the batch and place it at position 0.
        class_tokens = self.class_token.expand(input_images.shape[0], -1, -1)
        hidden_states = torch.cat((class_tokens, patch_tokens), dim=1)  # (B, num_patches + 1, embedding_dimension)
        hidden_states = hidden_states + self.position_embeddings

        for quantum_block in self.transformer_encoder_layers:
            hidden_states = quantum_block(hidden_states)

        # Classify from the normalized class token only.
        class_token_embedding = self.layer_norm_final(hidden_states[:, 0])
        return self.classification_head(class_token_embedding)

# Initialize the Quantum Vision Transformer Model with the same hyperparameters
# as the classical model, plus the qubit count. The model is only constructed
# here; this section does not train or evaluate it.
quantum_vision_transformer_model = QuantumVisionTransformerModel(
    image_resolution=image_resolution,
    patch_resolution=patch_resolution,
    input_channels=input_channels,
    embedding_dimension=embedding_dimension,
    num_attention_heads=num_attention_heads,
    transformer_depth=transformer_depth,
    mlp_hidden_dimension=mlp_hidden_dimension,
    num_output_classes=num_output_classes,
    num_quantum_qubits=num_quantum_qubits
).to(computing_device)

# Parameter count includes the quantum circuit weights, which are registered as nn.Parameter
print(f"Quantum Vision Transformer initialized with {sum(p.numel() for p in quantum_vision_transformer_model.parameters())} parameters")

# Display Quantum Vision Transformer Architecture Summary.
# Everything below is report text only: the configured hyperparameters first,
# then a prose description of the design and possible extensions.
print("\nQuantum Vision Transformer Architecture:")
print(f"- Image size: {image_resolution}x{image_resolution}")
print(f"- Patch size: {patch_resolution}x{patch_resolution}")
print(f"- Number of patches: {num_image_patches}")
print(f"- Embedding dimension: {embedding_dimension}")
print(f"- Number of attention heads: {num_attention_heads}")
print(f"- Number of transformer blocks: {transformer_depth}")
print(f"- MLP dimension: {mlp_hidden_dimension}")
print(f"- Number of qubits per patch: {num_quantum_qubits}")

# High-level description of the four stages of the hybrid model
print("\nKey Components of the Quantum Vision Transformer:")
print("1. Classical Patch Embedding Layer: Divides the image into patches and embeds them using convolution.")
print("2. Quantum Attention Mechanism: Enhances the query representation using quantum circuits for attention calculation.")
print("3. Quantum Transformer Blocks: Combines quantum-enhanced attention with classical MLP layers in each transformer block.")
print("4. Classical Classification Head: Uses a linear layer to classify based on the final class token embedding.")

# How the quantum circuit is built and read out
print("\nDetails of Quantum Enhancement:")
print("- Quantum Circuit: Employs AngleEmbedding for classical data encoding into quantum states.")
print("- Entangling Layers: Utilizes StronglyEntanglingLayers to introduce entanglement and quantum non-linearity.")
print("- Quantum Attention: Applies quantum processing specifically to the query vectors within the attention mechanism.")
print("- Measurement: Uses PauliZ expectation values to extract classical information from quantum states for subsequent layers.")

# Ideas for extending the hybrid design (the "comment on potential ideas" part of the task)
print("\nPotential Extensions for Future Research:")
print("1. Fully Quantum Patch Embedding: Replace classical convolution in patch embedding with quantum circuits.")
print("2. Quantum MLP: Implement the feed-forward MLP layers using variational quantum circuits to explore further quantum benefits.")
print("3. Quantum Position Encoding: Investigate quantum methods for encoding positional information of patches.")
print("4. Quantum-Enhanced Key and Value: Extend quantum processing to key and value vectors in the attention mechanism, not just queries.")
print("5. Hardware-Efficient Quantum Circuits: Optimize the quantum circuits for better performance on Noisy Intermediate-Scale Quantum (NISQ) devices.")
print("6. Hybrid Classical-Quantum Training Strategies: Explore techniques like classical pre-training followed by quantum fine-tuning to stabilize and improve quantum training.")
print("7. Quantum Advantage Benchmarking: Conduct rigorous comparative analysis against classical models to quantify and demonstrate potential quantum advantage.")

Collecting pennylane
  Downloading PennyLane-0.40.0-py3-none-any.whl.metadata (10 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting tomlkit (from pennylane)
  Downloading tomlkit-0.13.2-py3-none-any.whl.metadata (2.7 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.40 (from pennylane)
  Downloading PennyLane_Lightning-0.40.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (27 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.40->pennylane)
  Downloading scipy_openblas32-0.3.29.0.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5