In [5]:
!git clone https://github.com/beckhamtoh/char-llm-assignment.git
%cd char-llm-assignment

Cloning into 'char-llm-assignment'...
remote: Enumerating objects: 15, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 15 (delta 0), reused 12 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (15/15), 30.14 MiB | 32.11 MiB/s, done.
/content/char-llm-assignment


In [6]:
# Enable autoreload of local Python modules (e.g., models)
# %load_ext autoreload
# %autoreload 2

# manual reload for local modules
import importlib

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
import optax
import time
import matplotlib.pyplot as plt

# local imports
import models.models as models
import util.generation as generation


In [8]:
# initialize the jax random key
# NOTE(review): `key` is not referenced again in the visible part of this
# notebook; train_model_with_config builds its own key from its `seed` argument.
key = jax.random.key(0)

# Load data

In [9]:
# Load the ./data/text8_train.txt and ./data/text8_test.txt files as strings.
# The encoding is passed explicitly so decoding does not depend on the
# platform's locale default (the vocabulary used below is plain ASCII).
with open("./data/text8_train.txt", "r", encoding="utf-8") as f:
    train_text = f.read()
with open("./data/text8_test.txt", "r", encoding="utf-8") as f:
    test_text = f.read()

# print the length of the training text and test text
print(f"Length of training text: {len(train_text):_} characters")
print(f"Length of test text: {len(test_text):_} characters")

Length of training text: 90_000_000 characters
Length of test text: 5_000_000 characters


In [10]:
# Build vocabulary: the 26 lowercase letters plus the space character (27 symbols).
char_set = list("abcdefghijklmnopqrstuvwxyz ")
char_to_int = {ch: i for i, ch in enumerate(char_set)}
int_to_char = {i: ch for ch, i in char_to_int.items()}

def encode(s):
    """Encode a string as a 1-D array of vocabulary ids.

    Args:
      s: string made up only of characters present in `char_set`.

    Returns:
      np.ndarray of dtype uint8 and shape (len(s),), where entry k is
      `char_to_int[s[k]]`.

    Raises:
      KeyError: if `s` contains a character that is not in `char_set`.
    """
    # Stream the ids straight into the uint8 buffer instead of first building
    # a Python list holding one int object per character; `count` lets NumPy
    # allocate the output once.
    return np.fromiter((char_to_int[c] for c in s), dtype=np.uint8, count=len(s))

In [11]:
# encode the text
# Each split becomes a 1-D uint8 array of vocabulary ids (see `encode`).
train_text_int = encode(train_text)
test_text_int = encode(test_text)

In [12]:
# Sanity check: print five randomly positioned excerpts of T characters each
# from the raw training text.
T = 128
for _ in range(5):
    # random start offset, drawn so the whole excerpt lies inside the text
    N = np.random.randint(low=0, high=len(train_text) - T)
    excerpt = train_text[N:N + T]
    # excerpt followed by one blank line
    print(excerpt, end="\n\n")

of the illness and went to the fragrant mountain to give thanks to the person when he discovered that his own daughter gave up h

together they performed the arcade fire s song wake up from their album funeral he joined them again on one five september singi

ro asiatic language phylum its closest relatives are the berber semitic and beja groups of languages written records of the egyp

es were produced between about one nine three zero and one nine three five but the concept was abandonded because of its limited

xt is read aloud twice during the celebration setting the biblical book of esther is set in the third year of ahasuerus a king o



# Create a basic Transformer model

In [None]:
def create_train_state(rng, vocab_size=27, d_model=64, n_layers=6, n_heads=8, max_len=128, pos_encoding_type='learned'):
    """Build a `models.DecoderOnlyTransformer` and initialise its parameters.

    Args:
      rng: JAX PRNG key used for the "params" initialisation stream.
      vocab_size, d_model, n_layers, n_heads, max_len: forwarded positionally,
        in this order, to `models.DecoderOnlyTransformer`.
      pos_encoding_type: forwarded as the `pos_encoding_type` keyword.

    Returns:
      (model, params): the module instance and the "params" collection
      produced by `model.init`.
    """
    transformer = models.DecoderOnlyTransformer(
        vocab_size,
        d_model,
        n_layers,
        n_heads,
        max_len,
        pos_encoding_type=pos_encoding_type,
    )

    # Initialisation is driven by a dummy batch: a single sequence of zero
    # tokens, 16 long unless max_len is shorter than that.
    init_len = min(16, max_len)
    init_tokens = jnp.zeros((1, init_len), dtype=jnp.int32)

    variables = transformer.init({"params": rng}, init_tokens)
    return transformer, variables["params"]

In [None]:
# Model / optimizer hyperparameters used by the experiments below.
vocab_size = len(char_set)   # vocab size
d_model = 256                # internal model dimensions
n_heads = 8                  # number of attention heads
n_layers = 2                 # number of Transformer layers
max_len = 128                # maximum sequence length
learning_rate = 0.001        # learning rate for the optimizer

# Loss function

In [17]:
@jax.jit
def loss_and_metrics(logits, targets):
    """Mean cross-entropy plus two accuracy figures for a batch of predictions.

    `targets` must hold valid class ids in [0, V-1]; there is no ignore index.

    Args:
      logits: (B, T, V) float array of unnormalized scores.
      targets: (B, T) integer array of ground-truth class ids.

    Returns:
      loss: scalar cross-entropy averaged over every (batch, time) position.
      metrics: dict with scalar entries
        "loss"     - the same value as `loss`,
        "acc"      - fraction of positions whose argmax equals the target,
        "acc_last" - the same fraction restricted to the final time step.
    """
    # Collapse (B, T) into one axis so the loss sees (N, V) logits and (N,) labels.
    n_classes = logits.shape[-1]
    logits_2d = logits.reshape(-1, n_classes)
    labels_1d = targets.reshape(-1)

    token_losses = optax.softmax_cross_entropy_with_integer_labels(logits_2d, labels_1d)
    loss = token_losses.mean()

    # 1.0 where the most likely class matches the target, else 0.0; shape (B, T).
    hits = (jnp.argmax(logits, axis=-1) == targets).astype(jnp.float32)

    acc_all = jnp.mean(hits)          # every position
    acc_last = jnp.mean(hits[:, -1])  # final position of each sequence only

    return loss, {"loss": loss, "acc": acc_all, "acc_last": acc_last}

# Optimization step:

In [18]:
# create an update function
def train_step(params, opt_state, x, y, tx, model_def=None):
    """Single optimization step using an optax optimizer.

    Args:
      params: pytree of model parameters.
      opt_state: optax optimizer state corresponding to `params`.
      x: (B, T) int array input tokens.
      y: (B, T) int array target tokens.
      tx: optax.GradientTransformation (already initialized).
      model_def: optional module whose `.apply({"params": params}, x)` yields
        the logits. It is a static jit argument, so it must be hashable.
        When omitted, the notebook-level `model` is used, which keeps
        existing five-argument calls working.

    Returns:
      new_params: updated parameters after one gradient step.
      new_opt_state: updated optimizer state.
      metrics: dict of scalar metrics as returned by `loss_and_metrics`.

    Raises:
      NameError: if `model_def` is omitted and no notebook-level `model` exists.
    """
    # Resolved while tracing: with the global fallback, the compiled step keeps
    # using whichever `model` existed at trace time, so prefer passing
    # `model_def` explicitly when more than one model is trained in a session.
    net = model if model_def is None else model_def

    def loss_fn(p):
        logits = net.apply({"params": p}, x)
        return loss_and_metrics(logits, y)

    # loss is the differentiated scalar; metrics ride along as auxiliary output
    (_, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # optax update: compute parameter updates and new optimizer state
    updates, new_opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, metrics

# jit: `tx` and `model_def` are Python objects rather than arrays, so both are static
train_step = jax.jit(train_step, static_argnames=("tx", "model_def"))

# Batch creation:

In [19]:
# create a batch from the training data
def get_batch(text_int, B, T, rng=None):
    """Create a random batch of (input, next-token target) sequences.

    Args:
      text_int: 1D array of token ids.
      B: batch size (number of sequences).
      T: sequence length (number of tokens per sequence).
      rng: optional `np.random.Generator` used to draw the start offsets.
        When omitted, NumPy's global legacy RNG (`np.random.randint`) is used.

    Returns:
      x: (B, T) int32 array input tokens.
      y: (B, T) int32 array target tokens; y[b, t] is the token that follows
        x[b, t] in `text_int`.

    Raises:
      ValueError: if `text_int` has fewer than T + 1 tokens.
    """
    text_int = np.asarray(text_int)

    # A start offset i needs tokens i .. i+T (T inputs plus the final target),
    # so valid offsets are 0 .. len - T - 1.
    n_starts = len(text_int) - T
    if n_starts < 1:
        raise ValueError(
            f"text_int has {len(text_int)} tokens; need at least T + 1 = {T + 1}"
        )

    # choose random starting indices for each sequence in the batch
    if rng is None:
        ix = np.random.randint(0, n_starts, size=B)
    else:
        ix = rng.integers(0, n_starts, size=B)

    # Gather B windows of T + 1 consecutive tokens in one indexing operation;
    # inputs drop the last column, targets drop the first.
    window = text_int[ix[:, None] + np.arange(T + 1)]
    x = window[:, :-1]
    y = window[:, 1:]
    return jnp.array(x, dtype=jnp.int32), jnp.array(y, dtype=jnp.int32)

In [None]:
# Helper function for counting parameters
def count_params(params):
    """Return the summed `.size` of every leaf in the `params` pytree."""
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += leaf.size
    return total

# Position Encoding Test

In [None]:
def train_model_with_config(pos_encoding_type, niter=5000, B=128, T=32,
                           d_model=256, n_layers=2, n_heads=8,
                           max_len=128, learning_rate=0.001,
                           eval_interval=100, seed=42):
    """
    Train a model with the given positional encoding type and return results.

    The update step is compiled here around the model created by this call, so
    gradients are always taken through that model and never through whatever
    notebook-level `model` happens to exist.

    Args:
        pos_encoding_type: forwarded to `create_train_state`, e.g. 'learned',
            'sinusoidal', 'rotary', or 'none' (the supported values are
            defined by the model class, not here)
        niter: Number of training iterations (>= 1)
        B: Batch size
        T: Sequence length, used for both training and test batches
        d_model: Hidden dimension
        n_layers: Number of transformer layers
        n_heads: Number of attention heads
        max_len: Maximum sequence length
        learning_rate: Learning rate for the Adam optimizer
        eval_interval: Evaluate on a random test batch every this many
            iterations (>= 1); the last iteration is always evaluated
        seed: Random seed; seeds parameter initialisation and NumPy's global
            RNG, from which `get_batch` draws when called without a generator

    Returns:
        Dictionary with keys 'model', 'params', 'history',
        'pos_encoding_type' and 'config'.

    Raises:
        ValueError: if `niter` or `eval_interval` is smaller than 1.
    """
    if niter < 1:
        raise ValueError(f"niter must be >= 1, got {niter}")
    if eval_interval < 1:
        raise ValueError(f"eval_interval must be >= 1, got {eval_interval}")

    print(f"\n{'='*70}")
    print(f"Training with {pos_encoding_type.upper()} positional encoding")
    print(f"{'='*70}")

    # Seed batch sampling as well as parameter initialisation, so that runs
    # with the same seed see the same batches. Note this resets NumPy's global
    # RNG state as a side effect.
    np.random.seed(seed % (2**32))
    local_key = jax.random.PRNGKey(seed)

    # Create model
    model, params = create_train_state(
        local_key,
        vocab_size=vocab_size,
        d_model=d_model,
        n_layers=n_layers,
        n_heads=n_heads,
        max_len=max_len,
        pos_encoding_type=pos_encoding_type
    )

    # Print model info
    n_params = count_params(params)
    print(f"Number of parameters: {n_params:_}")

    # Create optimizer
    tx = optax.adam(learning_rate=learning_rate)
    opt_state = tx.init(params)

    def loss_fn(p, tokens, targets):
        logits = model.apply({"params": p}, tokens)
        return loss_and_metrics(logits, targets)

    # Both steps close over this call's `model` and `tx`, so each run gets its
    # own compiled functions.
    @jax.jit
    def step(p, state, tokens, targets):
        (_, step_metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(p, tokens, targets)
        updates, state = tx.update(grads, state, p)
        return optax.apply_updates(p, updates), state, step_metrics

    @jax.jit
    def eval_step(p, tokens, targets):
        return loss_fn(p, tokens, targets)

    # Training history
    history = {
        'iteration': [],
        'time': [],
        'train_loss': [],
        'train_acc': [],
        'train_acc_last': [],
        'test_loss': [],
        'test_acc': [],
        'test_acc_last': [],
    }

    # Test batches use the training sequence length so train and test metrics
    # (in particular last-position accuracy) are measured at the same context size.
    B_test, T_test = 1024, T

    time_start = time.time()

    # Training loop
    for it in range(niter):
        input_tokens, target_tokens = get_batch(train_text_int, B, T)
        params, opt_state, metrics = step(params, opt_state, input_tokens, target_tokens)

        # Evaluate at intervals and always on the final iteration
        if it % eval_interval == 0 or it == niter - 1:
            time_elapsed = time.time() - time_start

            test_input, test_target = get_batch(test_text_int, B_test, T_test)
            test_loss, test_metrics = eval_step(params, test_input, test_target)

            # Pull scalars to host floats once, for both storage and printing
            train_loss = float(metrics['loss'])
            train_acc = float(metrics['acc'])
            test_loss = float(test_loss)
            test_acc = float(test_metrics['acc'])

            history['iteration'].append(it)
            history['time'].append(time_elapsed)
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['train_acc_last'].append(float(metrics['acc_last']))
            history['test_loss'].append(test_loss)
            history['test_acc'].append(test_acc)
            history['test_acc_last'].append(float(test_metrics['acc_last']))

            # Print progress
            print(f"Iter {it:5d}/{niter} | Time: {time_elapsed:6.1f}s | "
                  f"Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | "
                  f"Train Acc: {100*train_acc:.1f}% | Test Acc: {100*test_acc:.1f}%")

    print(f"\nTraining completed in {time.time() - time_start:.1f} seconds")
    print(f"Final test accuracy: {100*history['test_acc'][-1]:.2f}%")
    print(f"Final test accuracy (last char): {100*history['test_acc_last'][-1]:.2f}%")

    return {
        'model': model,
        'params': params,
        'history': history,
        'pos_encoding_type': pos_encoding_type,
        'config': {
            'niter': niter,
            'B': B,
            'T': T,
            'd_model': d_model,
            'n_layers': n_layers,
            'n_heads': n_heads,
            'max_len': max_len,
            'learning_rate': learning_rate,
            'n_params': n_params,
        }
    }


# Parameters

In [None]:
# Settings shared by all experiment runs below.
NITER = 5000          # training iterations per model (passed as `niter`)
BATCH_SIZE = 128      # sequences per training batch (passed as `B`)
SEQ_LENGTH = 32       # tokens per training sequence (passed as `T`)
EVAL_INTERVAL = 100   # iterations between test evaluations (passed as `eval_interval`)

# Run the positional-encoding experiments

In [None]:
# Positional encoding variants to compare
encoding_types = ['learned', 'sinusoidal', 'none']

# Settings shared by every run; the model sizes are the notebook-level
# hyperparameters defined earlier.
shared_config = dict(
    niter=NITER,
    B=BATCH_SIZE,
    T=SEQ_LENGTH,
    d_model=d_model,
    n_layers=n_layers,
    n_heads=n_heads,
    max_len=max_len,
    learning_rate=learning_rate,
    eval_interval=EVAL_INTERVAL,
    seed=42,
)

# Train one model per variant and keep each result under its encoding name
results = {}
for encoding_type in encoding_types:
    run_output = train_model_with_config(pos_encoding_type=encoding_type, **shared_config)
    results[encoding_type] = run_output

banner = "="*70
print("\n" + banner)
print("ALL EXPERIMENTS COMPLETED")
print(banner)

# Plotting & Analysis

In [None]:
# One 2x3 figure comparing all trained models, one metric per panel.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Positional Encoding Comparison', fontsize=16, fontweight='bold')

# Line colour for each encoding type
colors = {
    'learned': '#2E86AB',
    'sinusoidal': '#A23B72',
    'none': '#F18F01',
    'rotary': '#06A77D'
}

# Panel specifications in row-major order:
# (history key for x, history key for y, show y as a percentage,
#  x-axis label, y-axis label, panel title)
panels = [
    ('iteration', 'train_loss', False,
     'Iteration', 'Training Loss', 'Training Loss Over Time'),
    ('iteration', 'test_loss', False,
     'Iteration', 'Test Loss', 'Test Loss Over Time'),
    ('iteration', 'train_acc', True,
     'Iteration', 'Training Accuracy (%)', 'Training Accuracy Over Time'),
    ('iteration', 'test_acc', True,
     'Iteration', 'Test Accuracy (%)', 'Test Accuracy Over Time'),
    ('iteration', 'test_acc_last', True,
     'Iteration', 'Test Accuracy (Last Char) (%)', 'Next-Character Prediction Accuracy'),
    ('time', 'test_loss', False,
     'Training Time (seconds)', 'Test Loss', 'Test Loss vs Training Time'),
]

for ax, (x_key, y_key, as_percent, x_label, y_label, panel_title) in zip(axes.flat, panels):
    # one curve per encoding type
    for enc_type, result in results.items():
        history = result['history']
        if as_percent:
            y_values = [a * 100 for a in history[y_key]]
        else:
            y_values = history[y_key]
        ax.plot(history[x_key], y_values,
                label=enc_type.capitalize(), color=colors[enc_type],
                linewidth=2, alpha=0.8)
    ax.set_xlabel(x_label, fontsize=11)
    ax.set_ylabel(y_label, fontsize=11)
    ax.set_title(panel_title, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('positional_encoding_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nFigure saved as 'positional_encoding_comparison.png'")

# Summary Table

In [None]:
# Final-evaluation summary table: one row per encoding type.
# The header is built first so the horizontal rules can match its width.
header = (f"{'Encoding Type':<20} {'Parameters':<15} {'Train Loss':<12} {'Test Loss':<12} "
          f"{'Test Acc':<12} {'Last Char Acc':<15} {'Time (s)':<10}")
rule_width = len(header)

print("\n" + "="*rule_width)
print("FINAL RESULTS SUMMARY")
print("="*rule_width)
print(header)
print("-"*rule_width)

for enc_type, result in results.items():
    history = result['history']
    config = result['config']

    # last recorded evaluation of the run
    final_train_loss = history['train_loss'][-1]
    final_test_loss = history['test_loss'][-1]
    final_test_acc = history['test_acc'][-1] * 100
    final_test_acc_last = history['test_acc_last'][-1] * 100
    final_time = history['time'][-1]
    n_params = config['n_params']

    # Attach '%' before padding so the sign sits next to the number and the
    # cell keeps the same width as its header column.
    test_acc_cell = f"{final_test_acc:.2f}%"
    test_acc_last_cell = f"{final_test_acc_last:.2f}%"

    print(f"{enc_type.capitalize():<20} {n_params:<15,} {final_train_loss:<12.4f} "
          f"{final_test_loss:<12.4f} {test_acc_cell:<12} {test_acc_last_cell:<15} "
          f"{final_time:<10.1f}")

print("="*rule_width)

# Find best performing model: highest last-character test accuracy at the final evaluation
best_enc_type = max(results.keys(),
                    key=lambda k: results[k]['history']['test_acc_last'][-1])
best_acc = results[best_enc_type]['history']['test_acc_last'][-1] * 100

print(f"\n🏆 Best performing encoding: {best_enc_type.upper()} "
      f"with {best_acc:.2f}% next-character prediction accuracy")

# Individual plots

In [None]:
def plot_detailed_comparison(results, metric='test_acc_last', title='Test Accuracy (Last Character)'):
    """Plot one history metric against iteration for every run, save and show it.

    Args:
        results: mapping from encoding name to a result dict whose 'history'
            entry contains 'iteration' and `metric` lists.
        metric: history key to plot. Keys containing 'acc' are scaled to
            percent and labelled as accuracy; all others are labelled as loss.
        title: figure title.

    Side effects:
        Saves the figure as "detailed_{metric}.png" in the working directory,
        shows it, and prints the file name. Line colours come from the
        notebook-level `colors` dict; names missing from it fall back to
        matplotlib's default colour cycle.
    """
    # The axis label depends only on the metric, so decide it once up front;
    # this also keeps it defined when `results` is empty.
    is_accuracy = 'acc' in metric
    ylabel = 'Accuracy (%)' if is_accuracy else 'Loss'

    plt.figure(figsize=(12, 6))

    for enc_type, result in results.items():
        history = result['history']
        values = [v * 100 for v in history[metric]] if is_accuracy else history[metric]

        plt.plot(history['iteration'], values,
                label=f"{enc_type.capitalize()}",
                color=colors.get(enc_type),  # None -> default colour cycle
                linewidth=2.5, alpha=0.85, marker='o', markersize=3, markevery=5)

    plt.xlabel('Iteration', fontsize=12, fontweight='bold')
    plt.ylabel(ylabel, fontsize=12, fontweight='bold')
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(fontsize=11, framealpha=0.9)
    plt.grid(True, alpha=0.3, linestyle='--')
    plt.tight_layout()

    filename = f"detailed_{metric}.png"
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"Saved: {filename}")

# Detailed single-metric figures for the key metrics, as (history key, title) pairs
detailed_plot_specs = [
    ('test_acc_last', 'Next-Character Prediction Accuracy (Test Set)'),
    ('test_loss', 'Test Loss Comparison'),
    ('train_loss', 'Training Loss Comparison'),
]
for metric_name, plot_title in detailed_plot_specs:
    plot_detailed_comparison(results, metric_name, plot_title)

# Generate Sample Text

In [None]:
# Sample a continuation of each prompt from every trained model.
print("\n" + "="*70)
print("TEXT GENERATION COMPARISON")
print("="*70)

prompts = [
    "the quick brown",
    "hello world",
    "once upon a time"
]

# number of tokens requested from the generator per sample
gen_len = 100

for prompt in prompts:
    print(f"\n{'='*70}")
    print(f"Prompt: '{prompt}'")
    print('='*70)

    # Encode the prompt once; it is identical for every model.
    # Only the first 64 characters are used, and characters outside the
    # vocabulary are mapped to id 0.
    prompt_int = jnp.array(
        [[char_to_int.get(c, 0) for c in prompt.lower()[:64]]],
        dtype=jnp.int32
    )

    for enc_type, result in results.items():
        # Local names deliberately differ from `model` / `params` so this cell
        # does not rebind notebook-level names that other cells may read.
        gen_model = result['model']
        gen_params = result['params']

        # Same key for every model and prompt, so sampling noise is comparable
        rng = jax.random.PRNGKey(42)
        out_ids = generation.generate_tokens(
            gen_model, gen_params, rng, prompt_int, gen_len,
            # use the context length this model was built with
            block_size=result['config']['max_len'], temperature=0.7, sample=True
        )

        # Decode; ids without a vocabulary entry are shown as '?'
        generated_text = ''.join(
            int_to_char.get(int(x), '?') for x in list(out_ids[0])
        )
        # NOTE(review): assumes generate_tokens returns only the newly sampled
        # ids; if it also echoes the prompt, the prompt appears twice - confirm.
        full_text = prompt + generated_text

        print(f"\n{enc_type.upper()}:")
        print(f"{full_text[:150]}...")  # Print first 150 chars