## MNIST

In [None]:
def load_mnist_data(n_train, n_test, binary=True, random_state=None):
    """Load and preprocess the MNIST dataset.

    Fetches ``mnist_784`` via ``fetch_openml``, reshapes every sample to a
    28x28 image, scales pixel values by 1/255, draws a stratified train/test
    subset and converts each image to patches with ``create_patches``.

    Args:
        n_train: Number of training samples.
        n_test: Number of test samples.
        binary: If True, only use digits 0 and 1.
        random_state: Forwarded to ``train_test_split``. Pass an int to make
            the split reproducible; the default ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``. The ``x_*`` entries are
        whatever ``create_patches`` returns. The ``y_*`` entries are float
        label columns of shape ``(n, 1)``.
    """
    # as_frame=False returns NumPy arrays; the labels come back as strings.
    X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
    X = X.reshape(-1, 28, 28)

    if binary:
        # Keep only the digits 0 and 1 (labels are the strings '0'/'1').
        mask = (y == '0') | (y == '1')
        X, y = X[mask], y[mask].astype(float)
    else:
        # Convert string labels to float.
        y = y.astype(float)

    # Normalize pixel values to [0, 1].
    X = X / 255.0

    # Stratify so both splits keep the class balance of the full data.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y,
        train_size=n_train,
        test_size=n_test,
        stratify=y,
        random_state=random_state,
    )

    X_train_patches = create_patches(X_train)
    X_test_patches = create_patches(X_test)

    return (
        X_train_patches,
        y_train.reshape(-1, 1),
        X_test_patches,
        y_test.reshape(-1, 1)
    )

In [None]:
def train_qvit(n_train, n_test, n_epochs):
    """Full-batch training of the QSANN classifier on MNIST.

    Each epoch performs exactly one Adam step on the whole training set. The
    train loss/accuracy come from the forward pass *before* that step. The
    test loss/accuracy are computed with the parameters *after* it.

    Args:
        n_train: Number of training images, forwarded to ``load_mnist_data``.
        n_test: Number of test images, forwarded to ``load_mnist_data``.
        n_epochs: Number of optimisation steps (epochs).

    Returns:
        dict with keys ``n_train``, ``step``, ``train_cost``, ``train_acc``,
        ``test_cost`` and ``test_acc``, each holding one entry per epoch.
    """
    x_tr, y_tr, x_te, y_te = load_mnist_data(n_train, n_test)

    # S=16 tokens per image; presumably must match the patch count produced by
    # create_patches -- TODO confirm.
    arch = dict(S=16, n=4, Denc=2, D=1, num_layers=1)
    model = QSANN_image_classifier(**arch)
    params = init_params(**arch)

    # Adam with lr=0.01, chosen to mirror the PyTorch reference setup.
    optimizer = optax.adam(learning_rate=0.01)
    opt_state = optimizer.init(params)

    # Per-epoch metric history; insertion order matters for the returned dict.
    history = {"train_cost": [], "train_acc": [], "test_cost": [], "test_acc": []}

    def objective(p, x, y):
        preds = model(x, p)
        return binary_cross_entropy(y, preds), preds

    # has_aux=True lets us reuse the training predictions for accuracy.
    grad_fn = jax.value_and_grad(objective, has_aux=True)

    @jax.jit
    def step(p, state, xa, ya, xb, yb):
        (train_loss, preds), grads = grad_fn(p, xa, ya)
        updates, state = optimizer.update(grads, state)
        p = optax.apply_updates(p, updates)
        acc_train = accuracy(ya, preds)
        loss_test, acc_test = evaluate(model, p, xb, yb)
        return p, state, train_loss, acc_train, loss_test, acc_test

    t0 = time.time()

    for epoch_no in range(1, n_epochs + 1):
        params, opt_state, tr_loss, tr_acc, te_loss, te_acc = step(
            params, opt_state, x_tr, y_tr, x_te, y_te
        )
        for key, value in zip(history, (tr_loss, tr_acc, te_loss, te_acc)):
            history[key].append(float(value))

        if epoch_no % 10 == 0:
            print(f"Train Size: {n_train}, Epoch: {epoch_no}/{n_epochs}, "
                  f"Train Loss: {tr_loss:.4f}, Train Acc: {tr_acc:.4f}, "
                  f"Test Loss: {te_loss:.4f}, Test Acc: {te_acc:.4f}")

    elapsed = time.time() - t0
    print(f"\nTraining completed in {elapsed:.2f} seconds")

    return dict(
        n_train=[n_train] * n_epochs,
        step=np.arange(1, n_epochs + 1, dtype=int),
        **history,
    )


In [None]:
def evaluate(model, params, x, y):
    """Score ``model`` with ``params`` on inputs ``x`` and targets ``y``.

    Returns:
        Tuple ``(loss, acc)``: the binary cross-entropy and the accuracy of
        the model's predictions on ``x``.
    """
    predictions = model(x, params)
    return binary_cross_entropy(y, predictions), accuracy(y, predictions)

## CIFAR

In [None]:
def load_cifar_data(n_train, n_test, batch_size, binary=True, augment=True):
    """Load CIFAR-10, subsample it and convert the images to patches.

    Args:
        n_train: Number of training images drawn without replacement.
        n_test: Number of test images drawn without replacement.
        batch_size: Batch size of the returned training dataset.
        binary: If True, keep only classes 0 (airplane) and 1 (automobile)
            and encode the label as float 1.0 for class 1, 0.0 for class 0.
        augment: If True, run ``augment_image`` over every training image.

    Returns:
        Tuple ``(train_dataset, x_test, y_test)``. ``train_dataset`` is a
        shuffled (seed 42), batched and prefetched ``tf.data.Dataset`` of
        ``(patches, label)`` pairs. ``x_test``/``y_test`` are JAX arrays.
    """
    (x_tr_all, y_tr_all), (x_te_all, y_te_all) = tf.keras.datasets.cifar10.load_data()

    if binary:
        keep_tr = np.isin(y_tr_all[:, 0], (0, 1))
        keep_te = np.isin(y_te_all[:, 0], (0, 1))
        x_tr_all, y_tr_all = x_tr_all[keep_tr], y_tr_all[keep_tr]
        x_te_all, y_te_all = x_te_all[keep_te], y_te_all[keep_te]
        # Binary target: 1.0 for automobile, 0.0 for airplane.
        y_tr_all = (y_tr_all == 1).astype(float)
        y_te_all = (y_te_all == 1).astype(float)

    # Scale uint8 pixels to [0, 1].
    x_tr_all = x_tr_all.astype('float32') / 255.0
    x_te_all = x_te_all.astype('float32') / 255.0

    # Random subsets (global NumPy RNG; train indices are drawn first).
    pick_tr = np.random.choice(len(x_tr_all), n_train, replace=False)
    pick_te = np.random.choice(len(x_te_all), n_test, replace=False)
    x_train, y_train = x_tr_all[pick_tr], y_tr_all[pick_tr]
    x_test, y_test = x_te_all[pick_te], y_te_all[pick_te]

    if augment:
        # NOTE: augmentation is applied a single time here, before the
        # dataset is built. Every epoch therefore sees the same augmented
        # images, and the un-augmented originals are not kept.
        x_train = tf.map_fn(augment_image, tf.convert_to_tensor(x_train)).numpy()

    train_patches = create_patches(x_train)
    test_patches = create_patches(x_test)

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((train_patches, y_train))
        .shuffle(buffer_size=n_train, seed=42)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    return train_dataset, jnp.array(test_patches), jnp.array(y_test)

In [None]:
def train_qvit(n_train, n_test, n_epochs, batch_size=128, log_every=1):
    """Mini-batch training of the QSANN classifier on CIFAR-10.

    Args:
        n_train: Number of training images sampled by ``load_cifar_data``.
        n_test: Number of test images sampled by ``load_cifar_data``.
        n_epochs: Number of passes over the training dataset.
        batch_size: Mini-batch size of the training dataset.
        log_every: Print a progress line every ``log_every`` epochs. The
            default of 1 prints every epoch; 0 or None disables printing.

    Returns:
        pandas.DataFrame with one row per epoch and the columns ``step``,
        ``train_cost`` and ``train_acc`` (unweighted means of the per-batch
        values), ``test_cost`` and ``test_acc`` (full test set, evaluated
        after the epoch), ``n_train`` and ``batch_size``.
    """
    train_dataset, x_test, y_test = load_cifar_data(n_train, n_test, batch_size)

    model = QSANN_image_classifier(S=64, n=5, Denc=2, D=1, num_layers=1)
    params = init_params(S=64, n=5, Denc=2, D=1, num_layers=1)

    # optax evaluates the schedule at the optimizer's update count, which
    # advances once per *batch*, not once per epoch. Decaying over only
    # ``n_epochs`` updates sent the learning rate to 0 after the first
    # ``n_epochs`` batches, so the rest of training did nothing. Decay over
    # the total number of updates instead. load_cifar_data batches exactly
    # ``n_train`` samples without drop_remainder, hence ceiling division.
    steps_per_epoch = -(-n_train // batch_size)
    total_steps = max(1, n_epochs * steps_per_epoch)  # optax needs decay_steps > 0

    initial_lr = 0.003
    lr_schedule = optax.cosine_decay_schedule(init_value=initial_lr, decay_steps=total_steps)
    optimizer = optax.adam(learning_rate=lr_schedule)
    opt_state = optimizer.init(params)

    # Per-epoch metrics.
    train_costs = []
    test_costs = []
    train_accs = []
    test_accs = []
    steps = []

    def loss_fn(p, x, y):
        y_pred = model(x, p)
        return binary_cross_entropy(y, y_pred), y_pred

    @jax.jit
    def update_batch(params, opt_state, x_batch, y_batch):
        """One Adam step on a single batch; returns batch loss/acc pre-update."""
        (loss_val, y_pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x_batch, y_batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        batch_acc = accuracy(y_batch, y_pred)
        return new_params, new_opt_state, loss_val, batch_acc

    start = time.time()

    for epoch in range(n_epochs):
        epoch_train_loss = 0.0
        epoch_train_acc = 0.0
        num_batches = 0

        for x_batch_tf, y_batch_tf in train_dataset:
            # tf.data yields TensorFlow tensors; hand plain arrays to JAX.
            x_batch = jnp.array(x_batch_tf.numpy())
            y_batch = jnp.array(y_batch_tf.numpy())

            params, opt_state, batch_loss, batch_acc = update_batch(
                params, opt_state, x_batch, y_batch
            )
            epoch_train_loss += batch_loss
            epoch_train_acc += batch_acc
            num_batches += 1

        avg_epoch_train_loss = epoch_train_loss / num_batches
        avg_epoch_train_acc = epoch_train_acc / num_batches

        # Evaluate on the full test set at the end of each epoch.
        test_loss, test_acc = evaluate(model, params, x_test, y_test)

        train_costs.append(float(avg_epoch_train_loss))
        train_accs.append(float(avg_epoch_train_acc))
        test_costs.append(float(test_loss))
        test_accs.append(float(test_acc))
        steps.append(epoch + 1)

        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch+1}/{n_epochs} | "
                  f"Train Loss: {avg_epoch_train_loss:.4f} | "
                  f"Train Acc: {avg_epoch_train_acc:.4f} | "
                  f"Test Loss: {test_loss:.4f} | "
                  f"Test Acc: {test_acc:.4f}")

    training_time = time.time() - start
    print(f"\nTraining completed in {training_time:.2f} seconds")

    results_df = pd.DataFrame({
        'step': steps,
        'train_cost': train_costs,
        'train_acc': train_accs,
        'test_cost': test_costs,
        'test_acc': test_accs,
        'n_train': [n_train] * len(steps),
        'batch_size': [batch_size] * len(steps)
    })

    return results_df

In [None]:
def evaluate(model, params, x, y):
    """Compute ``(loss, acc)`` of ``model`` under ``params`` on ``(x, y)``.

    The loss is the binary cross-entropy; both metrics reuse one forward pass.
    """
    y_hat = model(x, params)
    metrics = (binary_cross_entropy(y, y_hat), accuracy(y, y_hat))
    return metrics

## Digits

In [None]:
def load_digits_data(n_train, n_test, random_state=None):
    """Load scikit-learn's digits dataset restricted to the classes 0 and 1.

    Args:
        n_train: Number of training samples.
        n_test: Number of test samples.
        random_state: Forwarded to ``train_test_split``. Pass an int to make
            the split reproducible; the default ``None`` keeps the previous
            unseeded behaviour.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)`` of JAX arrays. The inputs
        have shape ``(n, 4, 16)`` with values in ``[0, 1]``. The labels are
        float32 columns of shape ``(n, 1)``.
    """
    digits = load_digits()
    X, y = digits.data, digits.target
    mask = (y == 0) | (y == 1)
    X, y = X[mask], y[mask]
    X = X / 16.0  # Normalize to [0, 1]
    # Regroup each sample's 64 values into 4 tokens of 16 values. Assuming
    # digits.data rows are row-major flattened 8x8 images, each token is two
    # consecutive pixel rows.
    X = X.reshape(-1, 4, 16)
    y = y.astype(jnp.float32)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=n_train, test_size=n_test, random_state=random_state)
    return (
        jnp.array(X_train),
        jnp.array(y_train).reshape(-1, 1),
        jnp.array(X_test),
        jnp.array(y_test).reshape(-1, 1)
    )

In [None]:
def train_qvit(n_train, n_test, n_epochs):
    """Full-batch training of the QSANN classifier on the 0/1 digits data.

    Each epoch performs one Adam step on the whole training set. The train
    loss/accuracy come from the forward pass before that step. The test
    loss/accuracy are computed with the updated parameters.

    Args:
        n_train: Number of training samples, forwarded to ``load_digits_data``.
        n_test: Number of test samples, forwarded to ``load_digits_data``.
        n_epochs: Number of optimisation steps (epochs).

    Returns:
        dict with keys ``n_train``, ``step``, ``train_cost``, ``train_acc``,
        ``test_cost`` and ``test_acc``, each holding one entry per epoch.
    """
    x_train, y_train, x_test, y_test = load_digits_data(n_train, n_test)

    # S=4 matches the 4 tokens per sample produced by load_digits_data.
    model = QSANN_image_classifier(S=4, n=4, Denc=2, D=1, num_layers=1)
    params = init_params(S=4, n=4, Denc=2, D=1, num_layers=1)

    # Same learning rate as the PyTorch reference implementation.
    optimizer = optax.adam(learning_rate=0.01)
    opt_state = optimizer.init(params)

    train_cost_epochs = []
    test_cost_epochs = []
    train_acc_epochs = []
    test_acc_epochs = []

    def loss_fn(p, x, y):
        y_pred = model(x, p)
        return binary_cross_entropy(y, y_pred), y_pred

    @jax.jit
    def update_step(params, opt_state, x_train, y_train, x_test, y_test):
        # has_aux=True returns the predictions so train accuracy needs no
        # second forward pass.
        (loss_val, y_pred), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x_train, y_train)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)

        train_acc = accuracy(y_train, y_pred)
        test_loss, test_acc = evaluate(model, new_params, x_test, y_test)

        return new_params, new_opt_state, loss_val, train_acc, test_loss, test_acc

    start = time.time()

    for epoch in range(n_epochs):
        params, opt_state, train_cost, train_acc, test_cost, test_acc = update_step(
            params, opt_state, x_train, y_train, x_test, y_test
        )

        train_cost_epochs.append(float(train_cost))
        train_acc_epochs.append(float(train_acc))
        test_cost_epochs.append(float(test_cost))
        test_acc_epochs.append(float(test_acc))

        # Print progress every 10 epochs.
        if (epoch + 1) % 10 == 0:
            print(f"Train Size: {n_train}, Epoch: {epoch + 1}/{n_epochs}, "
                  f"Train Loss: {train_cost:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Test Loss: {test_cost:.4f}, Test Acc: {test_acc:.4f}")

    training_time = time.time() - start
    print(f"\nTraining completed in {training_time:.2f} seconds")

    return dict(
        n_train=[n_train] * n_epochs,
        step=np.arange(1, n_epochs + 1, dtype=int),
        train_cost=train_cost_epochs,
        train_acc=train_acc_epochs,
        test_cost=test_cost_epochs,
        test_acc=test_acc_epochs,
    )

In [None]:
def evaluate(model, params, x, y):
    """Return ``(loss, acc)`` for ``model`` evaluated with ``params``.

    Runs one forward pass on ``x`` and scores it against ``y`` using binary
    cross-entropy (loss) and ``accuracy``.
    """
    out = model(x, params)
    bce = binary_cross_entropy(y, out)
    hit_rate = accuracy(y, out)
    return bce, hit_rate