**Task VIII: Vision Transformer / Quantum Vision Transformer**

Implement a classical Vision Transformer (ViT), apply it to MNIST, and report its performance on the test set. Then discuss potential ideas for extending this classical architecture to a quantum vision transformer (QViT), and sketch the quantum architecture in detail.

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

# --- 1. Data Loading and Preprocessing ---

def load_mnist(batch_size=64):
    """Build the MNIST train/test DataLoaders used by the ViT.

    Images are converted to tensors and standardized with the usual MNIST
    pixel mean and std. ``download=True`` lets torchvision fetch the files
    into ``./data``.

    Args:
        batch_size: Mini-batch size shared by both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std
        ]
    )

    def _split(is_train):
        # One MNIST split with the shared preprocessing pipeline.
        return datasets.MNIST(root='./data', train=is_train, download=True, transform=preprocess)

    train_set, test_set = _split(True), _split(False)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size, shuffle=False),
    )


# --- 2. Vision Transformer Implementation ---

class PatchEmbed(nn.Module):
    """Turn a batch of images into a token sequence for the transformer.

    A convolution with ``kernel_size == stride == patch_size`` cuts the image
    into non-overlapping patches and projects each one to an ``embed_dim``
    token. A learnable class token is prepended and a learnable position
    embedding is added.

    Args:
        img_size: Side length of the input image. ``pos_embed`` is sized for a
            square ``img_size`` x ``img_size`` input.
        patch_size: Side length of each patch (also the conv stride).
        in_channels: Number of input image channels.
        embed_dim: Width of every output token.
    """
    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side * patches_per_side

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # NOTE(review): both parameters start from a unit-variance normal; the
        # ViT reference implementation uses a much smaller init (std 0.02).
        self.class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # One position vector for the class token plus one per patch.
        self.pos_embed = nn.Parameter(torch.randn(1, self.n_patches + 1, embed_dim))

    def forward(self, x):
        """Map images ``(B, C, H, W)`` to tokens ``(B, 1 + n_patches, embed_dim)``."""
        grid = self.proj(x)                                   # (B, E, H/p, W/p)
        patch_tokens = grid.flatten(start_dim=2).transpose(1, 2)  # (B, n_patches, E)
        cls = self.class_token.expand(patch_tokens.shape[0], -1, -1)
        tokens = torch.cat([cls, patch_tokens], dim=1)
        tokens += self.pos_embed
        return tokens

class MultiHeadAttn(nn.Module):
    """Scaled dot-product multi-head self-attention.

    Queries, keys and values come from separate linear projections of the
    same input. The input is split into ``num_heads`` heads of width
    ``embed_dim // num_heads``, each head attends independently, and the
    merged heads are mixed by a final linear projection.

    Args:
        embed_dim: Token width; must be a multiple of ``num_heads``.
        num_heads: Number of attention heads.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embed dim must divide heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Attend over the sequence ``x`` of shape ``(B, N, E)``; returns ``(B, N, E)``."""
        bsz, n_tokens, width = x.size()

        def split_heads(t):
            # (B, N, E) -> (B, heads, N, head_dim)
            return t.view(bsz, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        queries, keys, values = (
            split_heads(proj(x)) for proj in (self.q_linear, self.k_linear, self.v_linear)
        )
        scale = self.head_dim ** 0.5
        weights = (queries @ keys.transpose(-2, -1) / scale).softmax(dim=-1)
        # Back to (B, N, E) by concatenating the heads along the feature axis.
        merged = (weights @ values).transpose(1, 2).contiguous().view(bsz, n_tokens, width)
        return self.out_linear(merged)

class TransformerBlock(nn.Module):
    """One pre-norm Transformer encoder block.

    ``x -> x + Dropout(Attn(LayerNorm(x))) -> x + MLP(LayerNorm(x))``

    Args:
        embed_dim: Token width (last dimension of input and output).
        num_heads: Number of heads passed to ``MultiHeadAttn``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: Dropout probability, applied after the MLP activation, to
            the MLP output, and to the attention output.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super(TransformerBlock, self).__init__()

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttn(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),  # output dropout of the MLP branch
        )
        # Output dropout of the attention branch; MultiHeadAttn has no dropout
        # of its own.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply attention and MLP residual updates; output has the shape of ``x``."""
        x = x + self.dropout(self.attn(self.norm1(x)))
        # self.mlp already ends in Dropout. Wrapping it in self.dropout too would
        # drop the MLP output twice (effective rate 1 - (1 - p)**2).
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: ``PatchEmbed`` -> ``num_layers`` x ``TransformerBlock`` ->
    LayerNorm on the class token -> linear head producing ``num_classes``
    logits.

    Args:
        img_size, patch_size, in_channels: Image/patch geometry for ``PatchEmbed``.
        num_classes: Number of output logits.
        embed_dim: Token width throughout the encoder.
        num_heads: Attention heads per block.
        num_layers: Number of stacked encoder blocks.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability inside each block.
    """
    def __init__(self, img_size, patch_size, in_channels, num_classes, embed_dim, num_heads, num_layers, mlp_ratio=4, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        blocks = [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(num_layers)]
        self.transformer_blocks = nn.Sequential(*blocks)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the images ``x``."""
        tokens = self.transformer_blocks(self.patch_embed(x))
        # Only the class token (position 0) feeds the classifier.
        cls_repr = self.norm(tokens[:, 0])
        return self.head(cls_repr)

# --- 3. Training and Evaluation ---

def train(model, train_load, optimizer, criterion, device):
    """Run one training epoch over ``train_load``.

    Every 100 batches the running mean of the per-batch losses is printed.

    Args:
        model: Module to optimize; switched to training mode.
        train_load: Sized iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    model.train()
    running = 0.0
    for step, (inputs, labels) in enumerate(train_load):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()
        if step % 100 == 0:
            print(f'Batch {step}, Loss: {running / (step + 1):.4f}')
    return running / len(train_load)

def evaluate(model, test_load, criterion, device):
    """Evaluate ``model`` on every batch of ``test_load`` and print a summary.

    Args:
        model: Classifier whose output has classes along dim 1 (logits).
        test_load: Iterable of ``(data, target)`` batches, e.g. a DataLoader.
        criterion: Loss assumed to return the *mean* over a batch (the
            default ``reduction='mean'`` of ``nn.CrossEntropyLoss``). It is
            re-weighted by batch size, so the reported loss is a true
            per-sample average even when the last batch is smaller.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(test_loss, accuracy)``: per-sample average loss and accuracy in
        percent. Both are ``nan`` if ``test_load`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    n_seen = 0
    with torch.no_grad():
        for data, target in test_load:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_n = target.size(0)
            # Undo the batch mean so a short final batch is not over-weighted.
            loss_sum += criterion(output, target).item() * batch_n
            correct += output.argmax(dim=1).eq(target).sum().item()
            n_seen += batch_n

    if n_seen == 0:
        # Nothing was evaluated: report nan rather than dividing by zero.
        test_loss = accuracy = float('nan')
    else:
        test_loss = loss_sum / n_seen
        accuracy = 100. * correct / n_seen
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{n_seen} ({accuracy:.2f}%)\n')
    return test_loss, accuracy

def main():
    """Train the classical ViT on MNIST and report test-set performance.

    Hyperparameters are fixed here. After every epoch the training loss and
    wall-clock time are printed and the model is evaluated on the test
    split; one more evaluation is printed after training finishes.
    """
    model_cfg = dict(
        img_size=28,
        patch_size=7,
        in_channels=1,
        num_classes=10,
        embed_dim=64,
        num_heads=4,
        num_layers=2,
        mlp_ratio=2,
        dropout=0.1,
    )
    batch_size = 64
    learning_rate = 0.001
    num_epochs = 10

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_load, test_load = load_mnist(batch_size=batch_size)
    model = VisionTransformer(**model_cfg).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        epoch_start = time.time()
        epoch_loss = train(model, train_load, optimizer, criterion, device)
        epoch_secs = time.time() - epoch_start
        print(f'Epoch {epoch}, Train Loss: {epoch_loss:.4f}, Time: {epoch_secs:.2f}s')
        evaluate(model, test_load, criterion, device)

    print("Final Evaluation:")
    evaluate(model, test_load, criterion, device)

# --- 4. Quantum Vision Transformer (Conceptual) ---
def quantum_vision_transformer_ideas():
    """Print a design sketch for a hybrid quantum vision transformer (QViT).

    Nothing here builds or simulates a circuit; the function only prints
    text. The qubit counts quoted are worked out for the configuration used
    in ``main``: 28x28 MNIST images cut into 7x7 patches, i.e. 16 patches of
    49 pixels each.
    """
    print("\n--- Thinking about Quantum Vision Transformers (QViT) ---")

    print("\n1. Quantum Patch Embedding (replaces the Conv2d projection in PatchEmbed):")
    print("    - Each 7x7 patch has 49 pixel values that must be loaded into a quantum circuit.")
    print("    - Encoding options:")
    print("        a) Angle encoding: one rotation RY(pi * x_i) per pixel -> 49 qubits per patch")
    print("           (fewer after classical down-sampling); shallow, NISQ-friendly circuits.")
    print("        b) Amplitude encoding: pad the 49 values to 64, L2-normalize them, and store them as")
    print("           the amplitudes of a 6-qubit state. Very compact, but generic state preparation")
    print("           needs O(2^n) gates for n qubits, i.e. linear in the number of pixels.")
    print("        c) Quantum convolution ('quanvolution'): apply one shared parameterized quantum")
    print("           circuit (PQC) to every patch and use its Pauli-Z expectation values as the embedding.")
    print("    - The class token and position embeddings can stay classical parameters, added to the")
    print("      measured patch embeddings exactly as in PatchEmbed.")
    print("    - Hard part: data loading. The encoding cost can erase any speed-up gained later.")

    print("\n2. Quantum Multi-Head Attention (QMHA):")
    print("    - Classical self-attention over N tokens of width d costs O(N^2 * d).")
    print("    - Options, from near-term to long-term:")
    print("        a) Variational Q/K/V: replace q_linear/k_linear/v_linear with PQCs acting on each")
    print("           encoded token, measure expectation values, and compute softmax(QK^T / sqrt(d))")
    print("           classically. Hybrid, and runnable on today's simulators.")
    print("        b) Quantum similarity: estimate the overlap |<q_i|k_j>|^2 with a swap test and use it")
    print("           in place of the dot product as the (unnormalized) attention score.")
    print("        c) Quantum linear algebra: block-encode Q, K, V and use quantum matrix multiplication.")
    print("           Any asymptotic speed-up assumes fast quantum data access (QRAM) and fault tolerance.")
    print("    - Multiple heads: one circuit (or disjoint qubit register) per head, each producing")
    print("      head_dim values that are concatenated and mixed by a classical output projection.")
    print("    - Hard part: beating classical attention once encoding and measurement (shot) costs count.")

    print("\n3. Quantum Transformer Encoder Block:")
    print("    - Keep the pre-norm residual layout of TransformerBlock:")
    print("        x = x + QMHA(LayerNorm(x));  x = x + QMLP(LayerNorm(x))")
    print("    - Quantum MLP: data re-uploading layers (single-qubit rotations + a ring of CNOTs),")
    print("      read out as d Pauli-Z expectation values.")
    print("    - LayerNorm and the residual additions stay classical and act on measured values:")
    print("      adding two quantum states is not a unitary operation.")

    print("\n4. Quantum Measurement and Output:")
    print("    - Measure Pauli-Z expectation values on the class-token register. Each estimate needs")
    print("      many shots (standard error ~ 1/sqrt(shots)).")
    print("    - Feed them to a classical LayerNorm + Linear head (as in VisionTransformer) to get 10 logits,")
    print("      or read the class directly from a 4-qubit measurement (2^4 = 16 >= 10 outcomes).")

    print("\n5. Hybrid Quantum-Classical Training:")
    print("    - Classical: pre-processing (normalization, patching), LayerNorm, residuals, final head.")
    print("    - Quantum: patch embedding, attention projections/similarities, feed-forward layers.")
    print("    - Circuit gradients via the parameter-shift rule, so the whole model trains end-to-end with")
    print("      the same Adam + CrossEntropyLoss loop used for the classical ViT.")
    print("    - Develop on exact state-vector simulators, then shot-based noisy simulators, then hardware.")

    print("\nQViT Idea Sketch:")
    print("""
    Image 28x28 (classical)
    ↓
    Normalize (mean 0.1307, std 0.3081), split into 16 patches of 7x7 = 49 pixels
    ↓
    Quantum Patch Embed, per patch: encode (angle / amplitude) -> PQC -> measure d expectations
    ↓
    Prepend class token, add position embeddings (classical)
    ↓
    [Quantum Transformer Block, repeated N times]
        ↓
        LayerNorm (classical)
        ↓
        QMHA: PQC Q/K/V per head -> scores (dot product or swap test) -> softmax -> weighted V
        ↓
        Residual add (classical)
        ↓
        LayerNorm (classical)
        ↓
        Quantum MLP: data re-uploading PQC -> measure d expectations
        ↓
        Residual add (classical)
    [End of blocks]
    ↓
    Class-token readout (Pauli-Z expectations)
    ↓
    LayerNorm + Linear head (classical) -> 10 logits
    ↓
    Predicted digit
    """)

    print("\nThings to think about:")
    print("    - NISQ hardware: few, noisy qubits and short coherence times limit circuit width and depth.")
    print("    - Data loading: encoding classical pixels into quantum states can dominate the total cost.")
    print("    - Trainability: deep or highly entangling random PQCs suffer from barren plateaus")
    print("      (exponentially vanishing gradients).")
    print("    - Shot noise: every expectation value is a statistical estimate; training needs many runs.")
    print("    - Scaling: larger images mean more patches and more qubits or circuit evaluations per image.")
    print("    - Benchmark honestly against a classical ViT with a matched parameter count.")

if __name__ == "__main__":
    # Script entry point: train and evaluate the classical ViT on MNIST,
    # then print the quantum ViT design notes.
    main()
    quantum_vision_transformer_ideas()

Using device: cuda
Batch 0, Loss: 2.3705
Batch 100, Loss: 1.2790
Batch 200, Loss: 0.8625
Batch 300, Loss: 0.6873
Batch 400, Loss: 0.5872
Batch 500, Loss: 0.5203
Batch 600, Loss: 0.4730
Batch 700, Loss: 0.4362
Batch 800, Loss: 0.4087
Batch 900, Loss: 0.3836
Epoch 1, Train Loss: 0.3755, Time: 16.52s

Test set: Average loss: 0.1521, Accuracy: 9560/10000 (95.60%)

Batch 0, Loss: 0.2043
Batch 100, Loss: 0.1705
Batch 200, Loss: 0.1732
Batch 300, Loss: 0.1740
Batch 400, Loss: 0.1711
Batch 500, Loss: 0.1680
Batch 600, Loss: 0.1688
Batch 700, Loss: 0.1655
Batch 800, Loss: 0.1624
Batch 900, Loss: 0.1610
Epoch 2, Train Loss: 0.1597, Time: 16.33s

Test set: Average loss: 0.1124, Accuracy: 9669/10000 (96.69%)

Batch 0, Loss: 0.1037
Batch 100, Loss: 0.1165
Batch 200, Loss: 0.1201
Batch 300, Loss: 0.1222
Batch 400, Loss: 0.1224
Batch 500, Loss: 0.1230
Batch 600, Loss: 0.1217
Batch 700, Loss: 0.1226
Batch 800, Loss: 0.1226
Batch 900, Loss: 0.1221
Epoch 3, Train Loss: 0.1229, Time: 16.46s

Test set: Av