In [5]:
# Fetch the assignment repo and make it the working directory so that the
# `models`, `util` and `./data/...` paths used below resolve.
# NOTE: not idempotent — after `%cd`, re-running would clone a nested copy.
!git clone https://github.com/beckhamtoh/char-llm-assignment.git
%cd char-llm-assignment

Cloning into 'char-llm-assignment'...
remote: Enumerating objects: 15, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 15 (delta 0), reused 12 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (15/15), 30.14 MiB | 32.11 MiB/s, done.
/content/char-llm-assignment


In [6]:
# Enable autoreload of local Python modules (e.g., models)
# %load_ext autoreload
# %autoreload 2

# manual reload for local modules
import importlib

In [7]:
import os
# Expose only GPU 0; set before `import jax` below so the JAX backend sees it.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
import optax
import time

# local imports
# (modules from the cloned repo; require the working directory set by `%cd` above)
import models.models as models
import util.generation as generation
# NOTE(review): later cells call `plt.*` (matplotlib.pyplot), which is not imported
# anywhere in this notebook as shown — add `import matplotlib.pyplot as plt` here
# unless it is imported in an earlier, unshown cell.


In [8]:
# Root PRNG key (seed 0); reused below for parameter init and the sanity-check batch.
SEED = 0
key = jax.random.key(SEED)

# Load data

In [9]:
def _read_text(path):
    """Return the full contents of the text file at `path`."""
    with open(path, "r") as fh:
        return fh.read()

# text8 train/test splits as plain strings
train_text = _read_text("./data/text8_train.txt")
test_text = _read_text("./data/text8_test.txt")

# split sizes, with underscore thousands separators
print(f"Length of training text: {len(train_text):_} characters")
print(f"Length of test text: {len(test_text):_} characters")

Length of training text: 90_000_000 characters
Length of test text: 5_000_000 characters


In [10]:
# Build vocabulary: the 26 lowercase ASCII letters plus space (27 symbols, no punctuation).
char_set = list("abcdefghijklmnopqrstuvwxyz ")
char_to_int = {ch:i for i,ch in enumerate(char_set)}
int_to_char = {i:ch for ch,i in char_to_int.items()}

# Byte -> token-id lookup table used by `encode`; _UNKNOWN_ID marks bytes
# that are not in the vocabulary.
_UNKNOWN_ID = 255
assert len(char_set) < _UNKNOWN_ID and all(ord(c) < 128 for c in char_set), \
    "vocabulary must be ASCII and leave room for the sentinel id"
_byte_to_id = np.full(256, _UNKNOWN_ID, dtype=np.uint8)
for _ch, _i in char_to_int.items():
    _byte_to_id[ord(_ch)] = _i

def encode(s):
    """Encode a string to an array of token ids.

    Vectorized through a byte lookup table instead of a per-character Python
    loop, so encoding the ~90M-character training split does not build a
    Python list with one entry per character.

    Args:
      s: text made only of characters in `char_set`.

    Returns:
      np.ndarray of dtype uint8 (the vocabulary fits in a byte), shape (len(s),).

    Raises:
      KeyError: if `s` contains a character outside `char_set` (the same
        exception a `char_to_int[c]` lookup would raise).
    """
    if not s:
        return np.array([], dtype=np.uint8)
    try:
        raw = np.frombuffer(s.encode("ascii"), dtype=np.uint8)
    except UnicodeEncodeError as err:
        # Non-ASCII characters can never be in the vocabulary.
        raise KeyError(s[err.start]) from None
    ids = _byte_to_id[raw]  # fancy indexing -> new writable uint8 array
    unknown = ids == _UNKNOWN_ID
    if unknown.any():
        raise KeyError(chr(int(raw[np.argmax(unknown)])))
    return ids

In [11]:
# Token-id arrays for both splits (one uint8 id per character).
train_text_int, test_text_int = encode(train_text), encode(test_text)

In [12]:
# Sanity check: print five random 128-character windows of the raw training text.
T = 128
for _ in range(5):
    start = np.random.randint(low=0, high=len(train_text)-T)
    window = train_text[start:start + T]
    print(window)
    print()

of the illness and went to the fragrant mountain to give thanks to the person when he discovered that his own daughter gave up h

together they performed the arcade fire s song wake up from their album funeral he joined them again on one five september singi

ro asiatic language phylum its closest relatives are the berber semitic and beja groups of languages written records of the egyp

es were produced between about one nine three zero and one nine three five but the concept was abandonded because of its limited

xt is read aloud twice during the celebration setting the biblical book of esther is set in the third year of ahasuerus a king o



# Create a basic Transformer model

In [13]:
def create_train_state(rng, vocab_size=27, d_model=64, n_layers=6, n_heads=8, max_len=128):
    """Build a DecoderOnlyTransformer and initialize its parameters.

    Args:
      rng: JAX PRNG key used for parameter initialization.
      vocab_size, d_model, n_layers, n_heads, max_len: passed positionally, in
        this order, to `models.DecoderOnlyTransformer`.

    Returns:
      (model, params): the model object and the "params" collection from `init`.
    """
    net = models.DecoderOnlyTransformer(vocab_size, d_model, n_layers, n_heads, max_len)
    # Init only needs a correctly shaped token batch: one all-zero sequence.
    init_len = min(16, max_len)
    sample_tokens = jnp.zeros((1, init_len), dtype=jnp.int32)
    variables = net.init({"params": rng}, sample_tokens)
    return net, variables["params"]

In [14]:
# Model hyperparameters
vocab_size = len(char_set)  # 27 symbols
d_model = 256               # internal model dimension
n_heads = 8                 # attention heads
n_layers = 2                # Transformer layers
max_len = 128               # maximum sequence length

model, params = create_train_state(
    key, vocab_size, d_model, n_layers, n_heads, max_len
)

In [15]:
# compute the number of parameters
def count_params(params):
    """Return the total number of scalar entries over all leaves of the `params` pytree."""
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += leaf.size
    return total
n_params = count_params(params)
print(f"Number of parameters: {n_params:_}")

Number of parameters: 1_624_576


In [16]:
# Sanity check: push a random (B, T) batch of token ids through the model.
B, T = 4, 32
batch = jax.random.randint(key=key, shape=(B, T), minval=0, maxval=len(char_set))
logits = model.apply({"params": params}, batch)

print("batch shape:", batch.shape)    # expect (B, T)
print("logits shape:", logits.shape)  # expect (B, T, vocab_size)

batch shape: (4, 32)
logits shape: (4, 32, 27)


# Loss function

Original loss function (the baseline for the variants below):

In [17]:
@jax.jit
def loss_and_metrics(logits, targets):
    """Mean cross-entropy plus accuracy metrics.

    Assumes every entry of `targets` is a valid class id in [0, V-1]; there is
    no ignore index (e.g. -1).

    Args:
      logits: (B, T, V) float array of unnormalized scores.
      targets: (B, T) integer array of ground-truth class ids.

    Returns:
      loss: scalar cross-entropy averaged over all B*T positions.
      metrics: {"loss": loss,
                "acc": accuracy over all positions,
                "acc_last": accuracy at the last position (t = T-1) only}.
    """
    num_classes = logits.shape[-1]
    # optax works on (N, V) scores and (N,) labels, so flatten batch and time.
    losses = optax.softmax_cross_entropy_with_integer_labels(
        logits.reshape(-1, num_classes), targets.reshape(-1)
    )
    loss = losses.mean()

    # 1.0 where the argmax prediction equals the target, else 0.0; shape (B, T).
    correct = (jnp.argmax(logits, axis=-1) == targets).astype(jnp.float32)
    metrics = {
        "loss": loss,
        "acc": jnp.mean(correct),
        "acc_last": jnp.mean(correct[:, -1]),
    }
    return loss, metrics

### Loss everywhere vs. loss at last token

Compare `last_token_only=False` (loss averaged over every position) with `last_token_only=True` (loss on the last position of each sequence only).

In [None]:
def loss_and_metrics(logits, targets, last_token_only=False):
    """Compute cross-entropy loss and accuracy.

    Args:
      logits: (B, T, V) float array of unnormalized scores.
      targets: (B, T) integer array with ground-truth class ids.
      last_token_only: Python bool. If True, the loss uses only the last
        position (t = T-1) of each sequence; otherwise it averages over all
        B*T positions. It is a static jit argument, so each value compiles
        separately.

    Returns:
      loss: scalar average cross-entropy.
      metrics: dict with keys "loss", "acc" (accuracy over all positions) and
        "acc_last" (accuracy at the last position). The accuracies are the
        same whichever value `last_token_only` has.
    """
    vocab = logits.shape[-1]

    if last_token_only:
        # Only the final position contributes: (B, V) scores vs (B,) labels.
        per_pos = optax.softmax_cross_entropy_with_integer_labels(
            logits[:, -1, :], targets[:, -1])
    else:
        # Flatten batch/time so optax sees (B*T, V) scores and (B*T,) labels.
        per_pos = optax.softmax_cross_entropy_with_integer_labels(
            logits.reshape(-1, vocab), targets.reshape(-1))
    loss = per_pos.mean()

    # Predictions and accuracy (independent of last_token_only)
    preds = jnp.argmax(logits, axis=-1)  # (B, T)
    is_match = preds == targets
    acc_all = jnp.mean(is_match.astype(jnp.float32))
    acc_last = jnp.mean(is_match[:, -1].astype(jnp.float32))

    return loss, {"loss": loss, "acc": acc_all, "acc_last": acc_last}


# `last_token_only` drives a Python `if`, so it must be static under jit. With a
# bare @jax.jit, passing it as an argument turns it into a tracer and the `if`
# raises a TracerBoolConversionError.
loss_and_metrics = jax.jit(loss_and_metrics, static_argnames=("last_token_only",))

# Label smoothing

- label_smoothing = 0.0 - baseline (hard labels)
- label_smoothing = 0.05 - light smoothing
- label_smoothing = 0.1 - standard smoothing (common choice)
- label_smoothing = 0.2 - heavier smoothing

In [None]:
@jax.jit
def loss_and_metrics(logits, targets, label_smoothing=0.0):
    """Cross-entropy against label-smoothed targets, plus accuracy metrics.

    The target distribution is (1 - label_smoothing) * one_hot + label_smoothing / V,
    so label_smoothing=0.0 is ordinary hard-label cross-entropy. Under this
    @jax.jit, `label_smoothing` is traced rather than static; it is only used
    arithmetically.

    Args:
      logits: (B, T, V) float array of unnormalized scores.
      targets: (B, T) integer class ids.
      label_smoothing: float in [0, 1].

    Returns:
      loss: scalar mean cross-entropy over all positions.
      metrics: dict with "loss", "acc" (all positions), "acc_last" (last position).
    """
    num_classes = logits.shape[-1]
    labels = targets.reshape(-1)

    hard = jax.nn.one_hot(labels, num_classes)
    soft = hard * (1 - label_smoothing) + label_smoothing / num_classes
    loss = optax.softmax_cross_entropy(logits.reshape(-1, num_classes), soft).mean()

    correct = jnp.argmax(logits, axis=-1) == targets
    acc_all = jnp.mean(correct.astype(jnp.float32))
    acc_last = jnp.mean(correct[:, -1].astype(jnp.float32))
    return loss, {"loss": loss, "acc": acc_all, "acc_last": acc_last}

In [None]:
# train_step variant that forwards a label-smoothing factor to the loss.
def train_step(params, opt_state, x, y, tx, label_smoothing=0.0):
    """Run one optax update on a (x, y) batch with a label-smoothed loss.

    Uses the notebook-level `model` and `loss_and_metrics` (looked up at trace time).

    Args:
      params: pytree of model parameters.
      opt_state: optax optimizer state matching `params`.
      x: (B, T) int array of input tokens.
      y: (B, T) int array of target tokens.
      tx: optax.GradientTransformation (static; must be hashable).
      label_smoothing: float in [0, 1] (static).

    Returns:
      (new_params, new_opt_state, metrics), where metrics is the aux dict
      returned by `loss_and_metrics` (loss, acc, acc_last).
    """
    def objective(p):
        logits = model.apply({"params": p}, x)
        return loss_and_metrics(logits, y, label_smoothing=label_smoothing)

    # The loss is the differentiated scalar; the metrics dict rides along as aux.
    grad_fn = jax.value_and_grad(objective, has_aux=True)
    (_, metrics), grads = grad_fn(params)

    updates, next_opt_state = tx.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, metrics

# jit with tx and label_smoothing static; each distinct label_smoothing compiles once.
train_step = jax.jit(train_step, static_argnames=("tx", "label_smoothing"))

# Sample random training windows from a token-id array.
def get_batch(text_int, B, T):
    """Draw B random length-T windows and their next-token targets.

    Args:
      text_int: 1D array of token ids.
      B: number of sequences in the batch.
      T: tokens per sequence.

    Returns:
      x: (B, T) int32 array with text_int[i:i+T] for each sampled start i.
      y: (B, T) int32 array with text_int[i+1:i+T+1], i.e. x shifted by one.
    """
    # Starts are drawn so the (T+1)-long window always fits inside text_int.
    starts = np.random.randint(0, len(text_int) - T, size=B)
    windows = [text_int[s:s + T + 1] for s in starts]
    inputs = np.stack([w[:-1] for w in windows])
    targets = np.stack([w[1:] for w in windows])
    return jnp.array(inputs, dtype=jnp.int32), jnp.array(targets, dtype=jnp.int32)

# Adam (Optax) with a fixed learning rate, plus its state for the current params.
learning_rate = 0.001
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
print(f"Initialized optimizer: Adam lr={learning_rate}")

In [None]:
def train_with_label_smoothing(model, params_init, opt_state_init, tx, train_text_int,
                                test_text_int, train_step_fn, get_batch,
                                label_smoothing_values, niter=5000, B=128, T=32,
                                loss_fn=None, eval_label_smoothing=0.0,
                                B_test=1024, T_test=32):
    """
    Train one run per label-smoothing value and collect the training histories.

    Every run restarts from `params_init` / `opt_state_init`, so runs are
    independent and comparable.

    Test metrics use `eval_label_smoothing` (default 0.0, i.e. plain
    hard-label cross-entropy), not the run's training smoothing. Smoothed
    cross-entropy can never drop below the entropy of the smoothed target
    distribution, so scoring each run with its own smoothing would make test
    losses incomparable across runs. The recorded *training* loss is still
    the smoothed objective returned by `train_step_fn`.

    Args:
        model: Model exposing `apply({"params": params}, x)` -> logits.
        params_init: Initial model parameters (shared by every run).
        opt_state_init: Initial optimizer state (shared by every run).
        tx: Optimizer, passed through to `train_step_fn`.
        train_text_int: Training token ids.
        test_text_int: Test token ids.
        train_step_fn: Callable (params, opt_state, x, y, tx, label_smoothing=...)
            -> (params, opt_state, metrics) with metric keys 'loss', 'acc', 'acc_last'.
        get_batch: Callable (text_int, B, T) -> (x, y).
        label_smoothing_values: Smoothing values to try, one run each.
        niter: Number of training iterations per run.
        B: Training batch size.
        T: Training sequence length.
        loss_fn: Callable (logits, targets, label_smoothing=...) -> (loss, metrics)
            used for test evaluation. Defaults to the notebook-level
            `loss_and_metrics`, looked up when this function is called.
        eval_label_smoothing: Smoothing used for the test loss (0.0 = hard labels).
        B_test: Test batch size.
        T_test: Test sequence length.

    Returns:
        results: dict mapping each smoothing value to a dict of lists
            'loss_train', 'loss_test', 'acc_test', 'acc_last_test',
            'time_train', 'time_test'.
    """
    if loss_fn is None:
        loss_fn = loss_and_metrics
    # Roughly 50 test evaluations per run; max(1, ...) avoids `it % 0` when niter < 50.
    eval_every = max(1, niter // 50)
    results = {}

    for ls_val in label_smoothing_values:
        print(f"\n{'='*60}")
        print(f"Training with label_smoothing = {ls_val}")
        print(f"{'='*60}\n")

        # Reset model parameters and optimizer state for each run
        params = params_init
        opt_state = opt_state_init

        loss_history = []
        time_history = []
        loss_test_history = []
        acc_test_history = []
        acc_last_test_history = []
        time_test_history = []

        time_start = time.time()

        for it in range(niter):
            x, y = get_batch(train_text_int, B, T)
            params, opt_state, metrics = train_step_fn(
                params, opt_state, x, y, tx, label_smoothing=ls_val
            )
            acc = metrics['acc']
            acc_last = metrics['acc_last']
            loss = metrics['loss']

            loss_history.append(float(loss))
            time_history.append(time.time() - time_start)

            if it % eval_every == 0 or it == niter - 1:
                time_since_start = time.time() - time_start

                # Score a fresh random test batch with the current parameters
                test_x, test_y = get_batch(test_text_int, B_test, T_test)
                test_logits = model.apply({"params": params}, test_x)
                test_loss, test_metrics = loss_fn(
                    test_logits, test_y, label_smoothing=eval_label_smoothing
                )
                test_acc = test_metrics['acc']
                test_acc_last = test_metrics['acc_last']

                loss_test_history.append(float(test_loss))
                acc_test_history.append(float(test_acc))
                acc_last_test_history.append(float(test_acc_last))
                time_test_history.append(time_since_start)

                print(f"iteration {it:_}  time: {time_since_start:.1f} seconds")
                print(f"\t \t loss(train :: test): {loss:.4f} :: {test_loss:.4f}")
                print(f"\t \t accuracy (train :: test): {100*acc:.1f}% :: {100*test_acc:.1f}%")
                print(f"\t \t accuracy (last character) (train :: test): {100*acc_last:.1f}% :: {100*test_acc_last:.1f}%")
                print()

        results[ls_val] = {
            'loss_train': loss_history,
            'loss_test': loss_test_history,
            'acc_test': acc_test_history,
            'acc_last_test': acc_last_test_history,
            'time_train': time_history,
            'time_test': time_test_history
        }

    return results

In [None]:
def plot_label_smoothing_comparison(results):
    """
    Draw a 2x2 grid comparing runs trained with different label smoothing.

    Panels, each plotted against wall-clock seconds: training loss, test loss,
    overall test accuracy (%), last-token test accuracy (%).

    Args:
        results: Dictionary returned by train_with_label_smoothing()

    Returns:
        The matplotlib Figure.
    """
    # NOTE(review): relies on a module-level `plt` (matplotlib.pyplot).
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Label Smoothing Comparison', fontsize=16, fontweight='bold')

    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
    pretty_names = {
        0.0: 'Baseline (0.0)',
        0.05: 'Light (0.05)',
        0.1: 'Standard (0.1)',
        0.2: 'Heavy (0.2)'
    }

    # Per panel: grid position, x key, y key, show as percent?, line style,
    # y-axis label, title.
    panel_specs = [
        ((0, 0), 'time_train', 'loss_train', False,
         {'linewidth': 1.5, 'alpha': 0.7},
         'Training Loss', 'Training Loss vs Time'),
        ((0, 1), 'time_test', 'loss_test', False,
         {'linewidth': 2, 'marker': 'o', 'markersize': 4},
         'Test Loss', 'Test Loss vs Time'),
        ((1, 0), 'time_test', 'acc_test', True,
         {'linewidth': 2, 'marker': 's', 'markersize': 4},
         'Test Accuracy (%)', 'Overall Test Accuracy'),
        ((1, 1), 'time_test', 'acc_last_test', True,
         {'linewidth': 2, 'marker': '^', 'markersize': 4},
         'Last Token Accuracy (%)', 'Last Token Test Accuracy'),
    ]

    for pos, x_key, y_key, as_percent, line_style, y_label, title in panel_specs:
        ax = axes[pos]
        for run_idx, (ls_val, data) in enumerate(results.items()):
            y_vals = data[y_key]
            if as_percent:
                y_vals = [a * 100 for a in y_vals]
            ax.plot(data[x_key], y_vals,
                    label=pretty_names.get(ls_val, f'{ls_val}'),
                    color=palette[run_idx % len(palette)], **line_style)
        ax.set_xlabel('Time (seconds)', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

def plot_final_comparison(results):
    """
    Bar charts of the last recorded test metrics for each label-smoothing value.

    Panels: final test loss, final overall test accuracy (%), and final
    last-token test accuracy (%). Each bar is annotated with its value.

    Args:
        results: Dictionary returned by train_with_label_smoothing().

    Returns:
        The matplotlib Figure.
    """
    # NOTE(review): relies on a module-level `plt` (matplotlib.pyplot).
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle('Final Metrics Comparison', fontsize=16, fontweight='bold')

    smoothing_vals = list(results.keys())
    labels = [f'{v}' for v in smoothing_vals]

    # Last evaluation point of each run
    final_loss = [results[v]['loss_test'][-1] for v in smoothing_vals]
    final_acc = [results[v]['acc_test'][-1] * 100 for v in smoothing_vals]
    final_acc_last = [results[v]['acc_last_test'][-1] * 100 for v in smoothing_vals]

    palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
    bar_colors = palette[:len(labels)]

    # Per panel: values, y-axis label, title, y limits (None = automatic), bar-label format.
    panels = (
        (final_loss, 'Test Loss', 'Final Test Loss', None, '{:.4f}'),
        (final_acc, 'Accuracy (%)', 'Final Test Accuracy', [0, 100], '{:.1f}%'),
        (final_acc_last, 'Accuracy (%)', 'Final Last Token Accuracy', [0, 100], '{:.1f}%'),
    )

    for ax, (values, y_label, title, y_limits, fmt) in zip(axes, panels):
        bars = ax.bar(labels, values, color=bar_colors, alpha=0.7)
        ax.set_xlabel('Label Smoothing', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        if y_limits is not None:
            ax.set_ylim(y_limits)
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    fmt.format(height), ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    return fig


def print_summary(results):
    """Print a fixed-width table of each run's final test loss and accuracies."""
    heavy_rule = "=" * 80
    print("\n" + heavy_rule)
    print("FINAL RESULTS SUMMARY")
    print(heavy_rule)
    print(f"{'Label Smoothing':<20} {'Test Loss':<15} {'Test Acc (%)':<15} {'Last Token Acc (%)':<20}")
    print("-" * 80)
    for ls_val, data in results.items():
        final_loss = data['loss_test'][-1]
        final_acc = data['acc_test'][-1] * 100
        final_acc_last = data['acc_last_test'][-1] * 100
        print(f"{ls_val:<20.2f} {final_loss:<15.4f} "
              f"{final_acc:<15.2f} {final_acc_last:<20.2f}")

In [None]:
label_smoothing_values = [0.0, 0.05, 0.1, 0.2]

# One independent run per smoothing value, all from the same initial state.
results = train_with_label_smoothing(
    model=model,
    params_init=params,              # initial parameters from create_train_state
    opt_state_init=tx.init(params),  # fresh Adam state for those parameters
    tx=tx,
    train_text_int=train_text_int,
    test_text_int=test_text_int,
    # The label-smoothing-aware `train_step` defined above; the previously
    # referenced `train_step_with_smoothing` is not defined in this notebook.
    train_step_fn=train_step,
    get_batch=get_batch,
    label_smoothing_values=label_smoothing_values,
    niter=5000,
    B=128,
    T=32
)

In [None]:
# Learning curves for every label-smoothing value
fig1 = plot_label_smoothing_comparison(results)
plt.show()

# Bar charts of the final test metrics
fig2 = plot_final_comparison(results)
plt.show()

# Same final metrics as a text table
print_summary(results)

### Temperature scaling

- temperature = 1.0 - baseline (no scaling)
- temperature = 0.8 - sharper, more confident predictions
- temperature = 1.2 - softer, less confident predictions
- temperature = 0.5 - very sharp (risk of overconfidence)

In [None]:
from functools import partial

@partial(jax.jit, static_argnames=['temperature'])
def loss_and_metrics(logits, targets, temperature=1.0):
    """Compute temperature-scaled cross-entropy loss and accuracy.

    Args:
      logits: (B, T, V) float array of unnormalized scores.
      targets: (B, T) integer array with ground-truth class ids.
      temperature: Python float > 0. It is a static jit argument, so each
                   distinct value compiles separately. Logits are divided by
                   it before the softmax:
                   < 1.0 = sharper predictions (more confident)
                   > 1.0 = softer predictions (less confident)

    Returns:
      loss: scalar cross-entropy of the scaled logits, averaged over all positions.
      metrics: dict with keys "loss", "acc" (all positions), and "acc_last"
        (last position only). Dividing by a positive temperature does not
        change the argmax, so the accuracies do not depend on it.

    Raises:
      ValueError: if temperature <= 0. Zero would produce inf/nan logits, and a
        negative value would invert the ranking of the classes.
    """
    # `temperature` is static, so this is an ordinary Python check at trace time.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")

    vocab = logits.shape[-1]

    # Apply temperature scaling to logits
    scaled_logits = logits / temperature

    flat_logits = scaled_logits.reshape(-1, vocab)
    flat_targets = targets.reshape(-1)

    # Compute cross-entropy
    per_pos = optax.softmax_cross_entropy_with_integer_labels(flat_logits, flat_targets)
    loss = per_pos.mean()

    # Predictions and accuracy (use scaled logits)
    preds = jnp.argmax(scaled_logits, axis=-1)
    is_match = preds == targets
    acc_all = jnp.mean(is_match.astype(jnp.float32))
    acc_last = jnp.mean(is_match[:, -1].astype(jnp.float32))

    return loss, {"loss": loss, "acc": acc_all, "acc_last": acc_last}

In [None]:
def train_step(params, opt_state, x, y, tx, temperature=1.0):
    """One optax update whose loss divides the logits by `temperature`.

    Uses the notebook-level `model` and `loss_and_metrics` (looked up at trace time).

    Args:
      params: pytree of model parameters.
      opt_state: optax optimizer state matching `params`.
      x: (B, T) int array of input tokens.
      y: (B, T) int array of target tokens.
      tx: optax.GradientTransformation (static; must be hashable).
      temperature: float > 0 (static; one compilation per distinct value).

    Returns:
      (new_params, new_opt_state, metrics); metrics is the aux dict from
      `loss_and_metrics` (loss, acc, acc_last).
    """
    def objective(p):
        scores = model.apply({"params": p}, x)
        return loss_and_metrics(scores, y, temperature=temperature)

    (_, step_metrics), grads = jax.value_and_grad(objective, has_aux=True)(params)
    updates, next_opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), next_opt_state, step_metrics

# jit with tx and temperature static
train_step = jax.jit(train_step, static_argnames=("tx", "temperature"))

# Same batching helper as in the label-smoothing section (redefined so this section runs on its own).
def get_batch(text_int, B, T):
    """Sample B random (input, next-token target) windows of length T.

    Args:
      text_int: 1D array of token ids.
      B: batch size (number of sequences).
      T: sequence length (number of tokens per sequence).

    Returns:
      x: (B, T) int32 array, text_int[i:i+T] per sampled start i.
      y: (B, T) int32 array, text_int[i+1:i+T+1] (x shifted left by one).
    """
    offsets = np.random.randint(0, len(text_int) - T, size=B)
    x_rows, y_rows = [], []
    for off in offsets:
        x_rows.append(text_int[off:off + T])
        y_rows.append(text_int[off + 1:off + T + 1])
    x = np.stack(x_rows)
    y = np.stack(y_rows)
    return jnp.array(x, dtype=jnp.int32), jnp.array(y, dtype=jnp.int32)

# Fresh Adam optimizer (Optax) and state for the temperature experiments.
learning_rate = 0.001
tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
print(f"Initialized optimizer: Adam lr={learning_rate}")

In [None]:
import time

# Temperature values to compare (one independent training run each)
temperature_values = [0.5, 0.8, 1.0, 1.2]

# Run configuration shared by every temperature
niter = 5000
B, T = 128, 32
B_test, T_test = 1024, 32
eval_every = max(1, niter // 50)  # ~50 test evaluations per run; never 0

# Every run starts from the global `params` (the initialization from
# create_train_state). Training writes to `run_params` instead, so `params` is
# never overwritten and re-running this cell restarts from the same point.
params_init = params

results = {}

for temp_val in temperature_values:
    print(f"\n{'='*60}")
    print(f"Training with temperature = {temp_val}")
    print(f"{'='*60}\n")

    # Fresh parameters and optimizer state for each run
    run_params = params_init
    run_opt_state = tx.init(run_params)

    loss_history = []
    time_history = []
    loss_test_history = []
    acc_test_history = []
    acc_last_test_history = []
    time_test_history = []
    time_start = time.time()

    for it in range(niter):
        x, y = get_batch(train_text_int, B, T)

        # Train with specific temperature
        run_params, run_opt_state, metrics = train_step(
            run_params, run_opt_state, x, y, tx, temperature=temp_val)

        loss_history.append(float(metrics['loss']))
        time_history.append(time.time() - time_start)

        if it % eval_every == 0 or it == niter - 1:
            time_since_start = time.time() - time_start

            # Test evaluation on a fresh random test batch
            test_x, test_y = get_batch(test_text_int, B_test, T_test)
            test_logits = model.apply({"params": run_params}, test_x)

            # Same temperature as training: this run's model was fit to scaled logits
            test_loss, test_metrics = loss_and_metrics(test_logits, test_y, temperature=temp_val)

            loss_test_history.append(float(test_loss))
            acc_test_history.append(float(test_metrics['acc']))
            acc_last_test_history.append(float(test_metrics['acc_last']))
            time_test_history.append(time_since_start)

            print(f"iteration {it:_}  time: {time_since_start:.1f} seconds")
            print(f"\t \t loss(train :: test): {metrics['loss']:.4f} :: {test_loss:.4f}")
            print(f"\t \t accuracy (train :: test): {100*metrics['acc']:.1f}% :: {100*test_metrics['acc']:.1f}%")
            print(f"\t \t accuracy (last character) (train :: test): {100*metrics['acc_last']:.1f}% :: {100*test_metrics['acc_last']:.1f}%")
            print()

    results[temp_val] = {
        'loss_train': loss_history,
        'loss_test': loss_test_history,
        'acc_test': acc_test_history,
        'acc_last_test': acc_last_test_history,
        'time_train': time_history,
        'time_test': time_test_history
    }

In [None]:
def plot_temperature_comparison(results):
    """
    Draw a 2x2 grid comparing runs trained with different temperatures.

    Panels, each plotted against wall-clock seconds: training loss, test loss,
    overall test accuracy (%), last-token test accuracy (%).

    Args:
        results: Dictionary with training histories for each temperature

    Returns:
        The matplotlib Figure.
    """
    # NOTE(review): relies on a module-level `plt` (matplotlib.pyplot).
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Temperature Scaling Comparison', fontsize=16, fontweight='bold')

    palette = ['#d62728', '#ff7f0e', '#1f77b4', '#2ca02c']
    pretty_names = {
        0.5: 'Very Sharp (0.5)',
        0.8: 'Sharp (0.8)',
        1.0: 'Baseline (1.0)',
        1.2: 'Soft (1.2)'
    }

    # Per panel: grid position, x key, y key, show as percent?, line style,
    # y-axis label, title.
    panel_specs = [
        ((0, 0), 'time_train', 'loss_train', False,
         {'linewidth': 1.5, 'alpha': 0.7},
         'Training Loss', 'Training Loss vs Time'),
        ((0, 1), 'time_test', 'loss_test', False,
         {'linewidth': 2, 'marker': 'o', 'markersize': 4},
         'Test Loss', 'Test Loss vs Time'),
        ((1, 0), 'time_test', 'acc_test', True,
         {'linewidth': 2, 'marker': 's', 'markersize': 4},
         'Test Accuracy (%)', 'Overall Test Accuracy'),
        ((1, 1), 'time_test', 'acc_last_test', True,
         {'linewidth': 2, 'marker': '^', 'markersize': 4},
         'Last Token Accuracy (%)', 'Last Token Test Accuracy'),
    ]

    for pos, x_key, y_key, as_percent, line_style, y_label, title in panel_specs:
        ax = axes[pos]
        for run_idx, (temp_val, data) in enumerate(results.items()):
            series = [a * 100 for a in data[y_key]] if as_percent else data[y_key]
            ax.plot(data[x_key], series,
                    label=pretty_names.get(temp_val, f'{temp_val}'),
                    color=palette[run_idx % len(palette)], **line_style)
        ax.set_xlabel('Time (seconds)', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig


def plot_temperature_final_comparison(results):
    """
    Bar charts of the last recorded test metrics for each temperature.

    Panels: final test loss, final overall test accuracy (%), and final
    last-token test accuracy (%). Each bar is annotated with its value.

    Args:
        results: Dictionary with training histories for each temperature.

    Returns:
        The matplotlib Figure.
    """
    # NOTE(review): relies on a module-level `plt` (matplotlib.pyplot).
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle('Final Metrics Comparison - Temperature Scaling', fontsize=16, fontweight='bold')

    temp_vals = list(results.keys())
    labels = [f'{v}' for v in temp_vals]

    # Last evaluation point of each run
    final_loss = [results[v]['loss_test'][-1] for v in temp_vals]
    final_acc = [results[v]['acc_test'][-1] * 100 for v in temp_vals]
    final_acc_last = [results[v]['acc_last_test'][-1] * 100 for v in temp_vals]

    palette = ['#d62728', '#ff7f0e', '#1f77b4', '#2ca02c']
    bar_colors = palette[:len(labels)]

    # Per panel: values, y-axis label, title, y limits (None = automatic), bar-label format.
    panels = (
        (final_loss, 'Test Loss', 'Final Test Loss', None, '{:.4f}'),
        (final_acc, 'Accuracy (%)', 'Final Test Accuracy', [0, 100], '{:.1f}%'),
        (final_acc_last, 'Accuracy (%)', 'Final Last Token Accuracy', [0, 100], '{:.1f}%'),
    )

    for ax, (values, y_label, title, y_limits, fmt) in zip(axes, panels):
        bars = ax.bar(labels, values, color=bar_colors, alpha=0.7)
        ax.set_xlabel('Temperature', fontsize=12)
        ax.set_ylabel(y_label, fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')
        if y_limits is not None:
            ax.set_ylim(y_limits)
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    fmt.format(height), ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    return fig


def print_temperature_summary(results):
    """Print a fixed-width table of each temperature run's final test metrics."""
    heavy_rule = "=" * 80
    print("\n" + heavy_rule)
    print("FINAL RESULTS SUMMARY - TEMPERATURE SCALING")
    print(heavy_rule)
    print(f"{'Temperature':<20} {'Test Loss':<15} {'Test Acc (%)':<15} {'Last Token Acc (%)':<20}")
    print("-" * 80)
    for temp_val, data in results.items():
        final_loss = data['loss_test'][-1]
        final_acc = data['acc_test'][-1] * 100
        final_acc_last = data['acc_last_test'][-1] * 100
        print(f"{temp_val:<20.1f} {final_loss:<15.4f} "
              f"{final_acc:<15.2f} {final_acc_last:<20.2f}")

In [None]:
# Learning curves for every temperature
fig1 = plot_temperature_comparison(results)
plt.show()

# Bar charts of the final test metrics
fig2 = plot_temperature_final_comparison(results)
plt.show()

# Same final metrics as a text table
print_temperature_summary(results)