<a href="https://colab.research.google.com/github/hks-9697-v2/Jax-training/blob/main/Convergence_Testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Reference: "Progress measures for grokking via mechanistic interpretability" - https://arxiv.org/pdf/2301.05217v1

In [2]:
import jax
import flax

In [3]:
# Show which devices JAX can see; the recorded run below reports a single CPU device.
jax.devices()

[CpuDevice(id=0)]

In [4]:
import jax.numpy as jnp
import flax.linen as nn
# Task: the model takes two integers in the range 0 to 100 as input and predicts a single integer as output.
class Grok(nn.Module):
  """Attention + MLP classifier over a pair of integer tokens.

  Forward pass: embed both tokens, run one multi-head self-attention layer over
  the resulting length-2 sequence, average the two positions, apply
  `d_dense_layers` ReLU-activated dense layers, and project to
  `num_embeddings` logits.

  Attributes:
    n_attention_heads: number of heads in the attention block.
    d_dense_layers: number of hidden dense layers after the attention block.
    x_dense_layer_size: width of each hidden dense layer.
    num_embeddings: vocabulary size, also the number of output classes.
    embed_dim: embedding width, reused as the attention qkv/out feature size.
  """
  n_attention_heads: int
  d_dense_layers: int
  x_dense_layer_size: int
  num_embeddings: int = 101 # Tokens 0..100 inclusive
  embed_dim: int = 64 # Default embedding dimension

  def setup(self):
    """Create all submodules; `__call__` only wires them together."""
    # Token index -> embed_dim vector
    self.embedding_layer = nn.Embed(num_embeddings=self.num_embeddings, features=self.embed_dim)

    # Self-attention that lets the two input tokens interact
    self.attention_block = nn.MultiHeadDotProductAttention(
        num_heads=self.n_attention_heads,
        qkv_features=self.embed_dim,
        out_features=self.embed_dim
    )

    # Hidden MLP stack: d_dense_layers layers of width x_dense_layer_size
    self.dense_layers = [nn.Dense(features=self.x_dense_layer_size) for _ in range(self.d_dense_layers)]

    # Final projection to one logit per possible output token
    self.unembedding_layer = nn.Dense(features=self.num_embeddings)

  def __call__(self, x):
    """Map a batch of index pairs to logits.

    Args:
      x: integer token indices, expected shape (batch_size, 2).

    Returns:
      Logits with a trailing dimension of size `num_embeddings`.
    """
    # (batch_size, 2) -> (batch_size, 2, embed_dim)
    embedded_x = self.embedding_layer(x)

    # Self-attention: the same tensor serves as query and key/value source.
    # Passed positionally so the call works with both the older
    # (inputs_q, inputs_kv) and newer (inputs_q, inputs_k, inputs_v) signatures.
    attended_x = self.attention_block(embedded_x, embedded_x)

    # Average over the two sequence positions -> (batch_size, embed_dim)
    processed_x = jnp.mean(attended_x, axis=1)

    # Every hidden dense layer is followed by ReLU
    for dense_layer in self.dense_layers:
      processed_x = nn.relu(dense_layer(processed_x))

    return self.unembedding_layer(processed_x)


In [5]:
import jax
import jax.numpy as jnp

# Hyperparameters of the model instance used in the rest of the notebook
n_attention_heads = 4
d_dense_layers = 3
x_dense_layer_size = 512
num_embeddings = 101 # Tokens 0..100 inclusive
embed_dim = 32

# Build the model from the hyperparameters above
grok_model = Grok(
    n_attention_heads=n_attention_heads,
    d_dense_layers=d_dense_layers,
    x_dense_layer_size=x_dense_layer_size,
    num_embeddings=num_embeddings,
    embed_dim=embed_dim,
)

# A single (batch_size=1) pair of token indices is enough to trace shapes during init.
# With num_embeddings = 101 the valid indices are 0 to 100.
dummy_input = jnp.asarray([[10, 20]], dtype=jnp.int32)

# Create the parameter tree
key = jax.random.PRNGKey(1234)
params = grok_model.init(key, dummy_input)

print("Grok model initialized successfully!")
print("Initialized parameters structure:")
print(jax.tree.map(lambda leaf: leaf.shape, params))

Grok model initialized successfully!
Initialized parameters structure:
{'params': {'attention_block': {'key': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'out': {'bias': (32,), 'kernel': (4, 8, 32)}, 'query': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'value': {'bias': (4, 8), 'kernel': (32, 4, 8)}}, 'dense_layers_0': {'bias': (512,), 'kernel': (32, 512)}, 'dense_layers_1': {'bias': (512,), 'kernel': (512, 512)}, 'dense_layers_2': {'bias': (512,), 'kernel': (512, 512)}, 'embedding_layer': {'embedding': (101, 32)}, 'unembedding_layer': {'bias': (101,), 'kernel': (512, 101)}}}


In [6]:
# Draw ten random index pairs; maxval is exclusive, so entries lie in [0, num_embeddings - 1]
random_key, subkey = jax.random.split(key)
random_inputs = jax.random.randint(
    subkey, shape=(10, 2), minval=0, maxval=num_embeddings, dtype=jnp.int32
)

print(f"Random input numbers: {random_inputs}")

# Forward pass with the freshly initialized (untrained) parameters
logits = grok_model.apply(params, random_inputs)
probabilities = jax.nn.softmax(logits, axis=-1)

print(f"Model predicted output: {jnp.argmax(probabilities, axis=-1)}")


Random input numbers: [[88 75]
 [ 3 70]
 [94 28]
 [58  5]
 [ 2 54]
 [58 46]
 [ 4 25]
 [31 62]
 [56 87]
 [33 48]]
Model predicted output: [67 79  5 11 11 11 67 11 67 21]


In [7]:
import optax # Optimization library for JAX

# 1. Define the loss function
# For classification tasks with integer labels, softmax_cross_entropy_with_integer_labels is suitable.
# It expects logits (model output before softmax) and integer labels.
# The function takes `logits` and `labels` as input and returns the per-example loss.
def cross_entropy_loss(logits, labels):
    """Return the batch-mean softmax cross-entropy.

    Args:
        logits: unnormalized model outputs (pre-softmax).
        labels: integer class indices, one per example.

    Returns:
        Scalar mean of the per-example losses.
    """
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    return per_example.mean()

# 2. Optimizer: AdamW with a fixed step size (all other settings left at optax defaults)
learning_rate = 1e-3
optimizer = optax.adamw(learning_rate)


In [8]:
import jax.numpy as jnp

# 3. Generate a dataset with all possible combinations
# Inputs x and y are numbers from 0 to num_embeddings-1
# Output is (x + y) % num_embeddings

def generate_full_dataset(num_embeddings):
    """Enumerate every (x, y) pair together with its modular-sum label.

    Args:
        num_embeddings: the modulus; x and y each range over 0..num_embeddings-1.

    Returns:
        A tuple (inputs, targets). `inputs` is an int32 array of shape
        (num_embeddings**2, 2) ordered by x first, then y. `targets` is the
        matching int32 array of shape (num_embeddings**2,) holding
        (x + y) % num_embeddings.
    """
    values = jnp.arange(num_embeddings, dtype=jnp.int32)

    # With indexing='ij' the first grid varies slowest, so a row-major flatten
    # already yields pairs sorted by x and then by y; no explicit sort is needed.
    X, Y = jnp.meshgrid(values, values, indexing='ij')
    inputs = jnp.stack([X.reshape(-1), Y.reshape(-1)], axis=1)

    targets = (inputs[:, 0] + inputs[:, 1]) % num_embeddings

    return inputs, targets

# Build the complete table of (x, y) -> (x + y) % num_embeddings examples.
# num_embeddings comes from the model-configuration cell above.
full_inputs, full_targets = generate_full_dataset(num_embeddings)

print(f"Full dataset generated with {full_inputs.shape[0]} samples (all {num_embeddings}x{num_embeddings} combinations).")
print(f"Full inputs shape: {full_inputs.shape}, dtype: {full_inputs.dtype}")
print(f"Full targets shape: {full_targets.shape}, dtype: {full_targets.dtype}")

# Fraction of examples used for training; the remainder becomes the test set
train_ratio = 0.6
num_total_samples = full_inputs.shape[0]
num_train_samples = int(num_total_samples * train_ratio)

# The table is ordered, so split on a random permutation; the fixed seed keeps the split reproducible.
# Note that this rebinds the notebook-level `key`.
key, _ = jax.random.split(jax.random.PRNGKey(42))
shuffled_indices = jax.random.permutation(key, num_total_samples)

train_indices, test_indices = (
    shuffled_indices[:num_train_samples],
    shuffled_indices[num_train_samples:],
)

train_inputs, train_targets = full_inputs[train_indices], full_targets[train_indices]
test_inputs, test_targets = full_inputs[test_indices], full_targets[test_indices]

print(f"\nDataset split: {train_inputs.shape[0]} training samples, {test_inputs.shape[0]} test samples.")

# Preview the first few rows of each split
print("First 5 training input-target pairs:")
for i in range(5):
    print(f"  Input: {train_inputs[i]}, Target: {train_targets[i]}")

print("\nFirst 5 test input-target pairs:")
for i in range(5):
    print(f"  Input: {test_inputs[i]}, Target: {test_targets[i]}")

Full dataset generated with 10201 samples (all 101x101 combinations).
Full inputs shape: (10201, 2), dtype: int32
Full targets shape: (10201,), dtype: int32

Dataset split: 6120 training samples, 4081 test samples.
First 5 training input-target pairs:
  Input: [98 24], Target: 21
  Input: [45 71], Target: 15
  Input: [100  69], Target: 68
  Input: [25 96], Target: 20
  Input: [70 53], Target: 22

First 5 test input-target pairs:
  Input: [28 80], Target: 7
  Input: [43 15], Target: 58
  Input: [47 99], Target: 45
  Input: [27 29], Target: 56
  Input: [63  6], Target: 69


In [9]:
import jax
import jax.numpy as jnp
import optax

# The Grok model instance and cross_entropy_loss function are already defined in previous cells.
# We need to make sure `grok_model` and `optimizer` are accessible in this scope.
# For the purpose of this cell, we assume they are globally available from previous executions.

@jax.jit
def train_step(params, opt_state, batch_inputs, batch_labels):
    """Run one optimizer update on a single batch.

    Uses the notebook-level `grok_model`, `cross_entropy_loss` and `optimizer`.

    Args:
        params: current model parameters.
        opt_state: current optimizer state.
        batch_inputs: inputs fed to `grok_model.apply`.
        batch_labels: integer targets for the batch.

    Returns:
        (new_params, new_opt_state, loss_value), where loss_value is the batch
        loss evaluated at the incoming `params`.
    """
    def batch_loss(candidate_params):
        predictions = grok_model.apply(candidate_params, batch_inputs)
        return cross_entropy_loss(predictions, batch_labels)

    # Loss and its gradient with respect to the parameters, in one pass
    loss_value, grads = jax.value_and_grad(batch_loss)(params)

    # params is passed to update() as well because AdamW's weight decay needs it
    updates, new_opt_state = optimizer.update(grads, opt_state, params)

    return optax.apply_updates(params, updates), new_opt_state, loss_value

print("The 'train_step' function has been defined and compiled with JIT.")

The 'train_step' function has been defined and compiled with JIT.


In [10]:
# Build the optimizer state that matches the parameter tree
opt_state = optimizer.init(params)

print("Optimizer state initialized successfully!")
print(f"Type of opt_state: {type(opt_state)}")
print("First few elements of opt_state:")
# Show shapes for array leaves and the raw value for anything without a .shape
print(jax.tree.map(lambda leaf: getattr(leaf, 'shape', leaf), opt_state))

Optimizer state initialized successfully!
Type of opt_state: <class 'tuple'>
First few elements of opt_state:
(ScaleByAdamState(count=(), mu={'params': {'attention_block': {'key': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'out': {'bias': (32,), 'kernel': (4, 8, 32)}, 'query': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'value': {'bias': (4, 8), 'kernel': (32, 4, 8)}}, 'dense_layers_0': {'bias': (512,), 'kernel': (32, 512)}, 'dense_layers_1': {'bias': (512,), 'kernel': (512, 512)}, 'dense_layers_2': {'bias': (512,), 'kernel': (512, 512)}, 'embedding_layer': {'embedding': (101, 32)}, 'unembedding_layer': {'bias': (101,), 'kernel': (512, 101)}}}, nu={'params': {'attention_block': {'key': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'out': {'bias': (32,), 'kernel': (4, 8, 32)}, 'query': {'bias': (4, 8), 'kernel': (32, 4, 8)}, 'value': {'bias': (4, 8), 'kernel': (32, 4, 8)}}, 'dense_layers_0': {'bias': (512,), 'kernel': (32, 512)}, 'dense_layers_1': {'bias': (512,), 'kernel': (512, 512)}, 'dense_layer

**Reasoning**:
The previous step initialized the optimizer state. Next, define the `train_model` function, which encapsulates the full training loop: for a specified number of epochs it shuffles the training set, processes it in batches, calls `train_step` on each batch, and collects the loss values.



In [11]:
import jax.numpy as jnp
import numpy as np # For data shuffling
import time # Import time module for step time calculation

# Helper function to calculate loss and accuracy
def calculate_metrics(params, inputs, targets):
    """Evaluate the notebook-level `grok_model` on a labelled set.

    Args:
        params: model parameters to evaluate.
        inputs: inputs passed to `grok_model.apply`.
        targets: integer labels aligned with `inputs`.

    Returns:
        (loss, accuracy): mean cross-entropy and the fraction of examples whose
        arg-max prediction equals the target.
    """
    logits = grok_model.apply(params, inputs)
    is_correct = jnp.argmax(logits, axis=-1) == targets
    return cross_entropy_loss(logits, targets), jnp.mean(is_correct)

def train_model(initial_params, initial_opt_state, train_inputs, train_targets, test_inputs, test_targets, num_epochs, batch_size, accuracy_threshold=0.8, random_seed=0):
    """Train with mini-batches and evaluate on the full splits after every epoch.

    Each epoch reshuffles the training set, calls `train_step` on every full
    batch (a trailing partial batch is dropped), then measures accuracy on the
    complete training and test sets. Training stops early once test accuracy
    reaches `accuracy_threshold`.

    Args:
        initial_params: starting model parameters.
        initial_opt_state: optimizer state matching `initial_params`.
        train_inputs, train_targets: training examples and integer labels.
        test_inputs, test_targets: held-out examples and integer labels.
        num_epochs: maximum number of passes over the training set.
        batch_size: number of examples per `train_step` call.
        accuracy_threshold: test accuracy at which training stops early.
        random_seed: seed for the per-epoch shuffling.

    Returns:
        (params, opt_state, loss_history, train_accuracy_history,
        test_accuracy_history, avg_step_time_history). `loss_history` has one
        entry per training step; the other histories have one entry per epoch.
    """
    current_params = initial_params
    current_opt_state = initial_opt_state
    loss_history = []
    train_accuracy_history = []
    test_accuracy_history = []
    avg_step_time_history = [] # Average step time per epoch

    key = jax.random.PRNGKey(random_seed)

    num_train_samples = train_inputs.shape[0]
    num_train_batches = num_train_samples // batch_size

    print(f"Starting training for {num_epochs} epochs with batch size {batch_size}...")
    print(f"Early stopping if test accuracy reaches {accuracy_threshold}")

    for epoch in range(num_epochs):
        key, subkey = jax.random.split(key)
        # Fresh shuffle of the training set for this epoch
        shuffled_indices = jax.random.permutation(subkey, num_train_samples)
        shuffled_train_inputs = train_inputs[shuffled_indices]
        shuffled_train_targets = train_targets[shuffled_indices]

        epoch_losses = []
        epoch_step_times = []

        for i in range(num_train_batches):
            batch_start = i * batch_size
            batch_end = batch_start + batch_size
            batch_inputs = shuffled_train_inputs[batch_start:batch_end]
            batch_labels = shuffled_train_targets[batch_start:batch_end]

            start_step_time = time.perf_counter()
            current_params, current_opt_state, loss_value = train_step(
                current_params, current_opt_state, batch_inputs, batch_labels
            )
            # JAX dispatches work asynchronously; wait for the result so the
            # timer covers the computation rather than just the dispatch.
            loss_value = jax.block_until_ready(loss_value)
            epoch_step_times.append(time.perf_counter() - start_step_time)

            epoch_losses.append(loss_value)
            loss_history.append(loss_value)

        # One extra key advance per epoch; retained so the shuffle sequence
        # produced for a given random_seed stays stable.
        key = jax.random.split(key, 3)[0]

        # Accuracy on the complete train and test sets at the end of the epoch
        _, train_accuracy = calculate_metrics(current_params, train_inputs, train_targets)
        _, test_accuracy = calculate_metrics(current_params, test_inputs, test_targets)

        train_accuracy_history.append(train_accuracy)
        test_accuracy_history.append(test_accuracy)

        avg_epoch_step_time = jnp.mean(jnp.array(epoch_step_times))
        avg_step_time_history.append(avg_epoch_step_time)

        avg_epoch_loss = jnp.mean(jnp.array(epoch_losses))

        print(f"Epoch {epoch + 1}/{num_epochs}, Avg Train Loss: {avg_epoch_loss:.32f}, Train Acc: {train_accuracy:.4f}, Test Acc: {test_accuracy:.4f}, Avg Step Time: {avg_epoch_step_time:.4f}s")

        # Early stopping condition
        if test_accuracy >= accuracy_threshold:
            print(f"Early stopping triggered: Test accuracy {test_accuracy:.4f} >= {accuracy_threshold:.4f}")
            break

    print("Training complete!")
    return current_params, current_opt_state, loss_history, train_accuracy_history, test_accuracy_history, avg_step_time_history

In [12]:
# Sweep configuration: powers of two from 1 to 128, up to 30 epochs per run,
# stopping a run early only when test accuracy reaches 1.
batch_sizes_to_test = [2 ** exponent for exponent in range(8)]
num_epochs = 30
accuracy_threshold = 1

print(f"Defined batch sizes to test: {batch_sizes_to_test}")

Defined batch sizes to test: [1, 2, 4, 8, 16, 32, 64, 128]


In [13]:
# Results keyed by batch size: last recorded test accuracy and per-step loss history
final_accuracies, all_loss_history = {}, {}

# Re-initialize the model and optimizer for each run. Need initial_params and initial_opt_state
# The original 'params' and 'opt_state' were initialized in previous cells.
# We need to make sure we start from a fresh model for each batch size.

# Function to re-initialize model parameters and optimizer state
def reinitialize_model_and_optimizer(key, dummy_input, grok_model, optimizer):
    """Create a fresh parameter tree and the optimizer state that goes with it.

    Args:
        key: PRNG key forwarded to `grok_model.init`.
        dummy_input: example input forwarded to `grok_model.init`.
        grok_model: object exposing `init(key, dummy_input)`.
        optimizer: object exposing `init(params)`.

    Returns:
        (params, opt_state) for a newly initialized model.
    """
    fresh_params = grok_model.init(key, dummy_input)
    return fresh_params, optimizer.init(fresh_params)


print("Starting training iterations for different batch sizes...")
# Fixed for the whole sweep so every batch size starts from identical initial
# parameters; otherwise initialization would be confounded with batch size.
key_initial_params = jax.random.PRNGKey(0)
# Separate key stream that only supplies the per-run shuffling seeds
seed_key = jax.random.PRNGKey(0)

for bs in batch_sizes_to_test:
    print(f"\n----------------------------------------")
    print(f"Training with batch_size = {bs}")
    print(f"----------------------------------------")

    # Fresh model and optimizer state, always built from the same init key
    new_params, new_opt_state = reinitialize_model_and_optimizer(key_initial_params, dummy_input, grok_model, optimizer)

    # Each run gets its own shuffling seed
    seed_key, subkey_train = jax.random.split(seed_key)

    _, _, loss_history, _, test_acc_history, _ = train_model(
        new_params, new_opt_state, train_inputs, train_targets, test_inputs, test_targets,
        num_epochs, bs, accuracy_threshold, random_seed=subkey_train[0].item() # First word of the subkey as an integer seed
    )

    # Store a plain float (NaN if no epoch ran) so later formatting and plotting
    # do not depend on device arrays.
    final_accuracies[bs] = float(test_acc_history[-1]) if test_acc_history else float("nan")
    all_loss_history[bs] = loss_history
    print(f"Final test accuracy for batch_size {bs}: {final_accuracies[bs]:.4f}")

print("\n----------------------------------------")
print("Training iterations complete!")
print("Summary of final test accuracies:")
for bs, acc in final_accuracies.items():
    print(f"Batch Size: {bs}, Final Test Accuracy: {acc:.4f}")

Starting training iterations for different batch sizes...

----------------------------------------
Training with batch_size = 1
----------------------------------------
Starting training for 30 epochs with batch size 1...
Early stopping if test accuracy reaches 1
Epoch 1/30, Avg Train Loss: 4.62979793548583984375000000000000, Train Acc: 0.0118, Test Acc: 0.0071, Avg Step Time: 0.0013s
Epoch 2/30, Avg Train Loss: 4.62303113937377929687500000000000, Train Acc: 0.0118, Test Acc: 0.0066, Avg Step Time: 0.0010s
Epoch 3/30, Avg Train Loss: 4.62077379226684570312500000000000, Train Acc: 0.0119, Test Acc: 0.0069, Avg Step Time: 0.0010s


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Training-set size; used to work out how many steps made up one epoch for each batch size.
num_train_samples = train_inputs.shape[0]

plt.figure(figsize=(12, 8))

# One curve per batch size, built from that run's flat per-step loss history
for bs, loss_history_flat in all_loss_history.items():
    num_train_batches = num_train_samples // bs

    # A batch size larger than the training set means no step was ever taken
    if num_train_batches == 0:
        print(f"Warning: Batch size {bs} is larger than num_train_samples {num_train_samples}. Skipping plot for this batch size.")
        continue

    # Number of complete epochs recorded (runs may have stopped early)
    actual_epochs_run = len(loss_history_flat) // num_train_batches

    # Mean loss over each epoch-sized window of the flat history
    epoch_losses = []
    for epoch_index in range(actual_epochs_run):
        window_start = epoch_index * num_train_batches
        window = loss_history_flat[window_start : window_start + num_train_batches]
        if not window:
            # Not expected given how actual_epochs_run is derived.
            print(f"Warning: No batch losses found for epoch {epoch_index+1} for batch size {bs}. Stopping processing for this batch size.")
            break
        epoch_losses.append(np.mean(window))

    if not epoch_losses:
        print(f"No valid epoch losses to plot for batch size {bs}.")
    else:
        # The small epsilon keeps log() finite if a loss reaches zero
        log_epoch_losses = np.log(np.array(epoch_losses) + 1e-9)
        plt.plot(range(1, actual_epochs_run + 1), log_epoch_losses, label=f'Batch Size: {bs}')

plt.xlabel('Epoch')
plt.ylabel('Log(Average Training Loss)')
plt.title('Logarithm of Average Training Loss per Epoch for Different Batch Sizes')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

print("Plot of log of average training loss per epoch for different batch sizes generated.")

In [None]:
import matplotlib.pyplot as plt

# Requires final_accuracies to have been filled in by the batch-size sweep cell.
# Unpack it into parallel sequences for plotting.
batch_sizes = list(final_accuracies)
accuracies = [final_accuracies[size] for size in batch_sizes]

plt.figure(figsize=(10, 6))
plt.plot(batch_sizes, accuracies, marker='o', linestyle='-', color='skyblue')

# Batch sizes are powers of two, so a base-2 log axis spaces them evenly
plt.xscale('log', base=2)
plt.xticks(batch_sizes, labels=list(map(str, batch_sizes)))
plt.xlabel('Batch Size (log scale)')
plt.ylabel('Final Test Accuracy')
plt.title('Impact of Batch Size on Final Test Accuracy')
plt.grid(True, which="both", ls="--", c='0.7')
plt.ylim(0, 1.05) # Accuracy lies in [0, 1]; leave headroom for the labels

# Label every point with its accuracy value
for size, accuracy in zip(batch_sizes, accuracies):
    plt.annotate(f'{accuracy:.4f}', (size, accuracy), textcoords="offset points", xytext=(0,5), ha='center')

plt.tight_layout()
plt.show()

print("Plot of final test accuracy vs. batch size generated.")


Output of the batch-size testing run: [pastebin log](https://pastebin.com/f2VncuYV)