In [54]:
!git clone https://github.com/beckhamtoh/char-llm-assignment.git
%cd char-llm-assignment
!ls

Cloning into 'char-llm-assignment'...
remote: Enumerating objects: 18, done.[K
remote: Counting objects: 100% (18/18), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 18 (delta 1), reused 11 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (18/18), 30.17 MiB | 16.77 MiB/s, done.
Resolving deltas: 100% (1/1), done.
/content/char-llm-assignment/char-llm-assignment/char-llm-assignment/char-llm-assignment
data  models  README.md  transformer.ipynb  util


In [55]:
# Enable autoreload of local Python modules (e.g., models)
# %load_ext autoreload
# %autoreload 2

# manual reload for local modules
import importlib

In [56]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import jax
import jax.numpy as jnp
import optax
import time

# local imports
import models.models as models
import util.generation as generation


In [57]:
# ==== LOGGING: small, no-deps experiment logger ====
import os, csv, json, time, hashlib
from pathlib import Path

class ExperimentLogger:
    def __init__(self, root="runs", run_tag="run"):
        Path(root).mkdir(parents=True, exist_ok=True)
        ts = time.strftime("%Y-%m-%d_%H-%M-%S")
        rid = f"{ts}_{run_tag}_{hashlib.md5(str(time.time()).encode()).hexdigest()[:4]}"
        self.root = Path(root); self.run_id = rid; self.run_dir = self.root / rid
        self.run_dir.mkdir(parents=True, exist_ok=True)
        self.master = self.root / "experiments.csv"
        self.eval_csv = self.run_dir / "eval_log.csv"
        self.samples_txt = self.run_dir / "samples.txt"
        self.t0 = time.time()
        self.best = None; self.best_row = None

        # init header files
        if not self.master.exists():
            with open(self.master, "w", newline="") as f:
                csv.writer(f).writerow([
                    "run_id","run_tag","seed","B","T","steps","tokens_per_step","total_tokens",
                    "vocab_size","d_model","n_layers","n_heads","max_len","lr","optimizer",
                    "params","wall_time_s",
                    "val_loss_best","val_bpc_best","val_acc_best","val_acc_last_best","step_best",
                    "val_loss_final","val_bpc_final","val_acc_final","val_acc_last_final","step_final"
                ])
        with open(self.eval_csv, "w", newline="") as f:
            csv.writer(f).writerow(["step","elapsed_s","val_loss","val_bpc","val_acc","val_acc_last","tokens_seen"])

    def log_config(self, cfg:dict):
        with open(self.run_dir / "config.json", "w") as f:
            json.dump(cfg, f, indent=2)
        self.cfg = cfg
        self.tokens_per_step = int(cfg["B"]) * int(cfg["T"])

    def log_eval(self, step:int, val_loss:float, val_acc:float, val_acc_last:float):
        elapsed = time.time() - self.t0
        val_bpc = float(val_loss) / float(jnp.log(2.0))
        tokens_seen = step * self.tokens_per_step
        with open(self.eval_csv, "a", newline="") as f:
            csv.writer(f).writerow([step, f"{elapsed:.2f}", float(val_loss), val_bpc, float(val_acc), float(val_acc_last), tokens_seen])

        crit = (val_bpc, -val_acc_last)  # lower bpc, higher last-acc is better
        if (self.best is None) or (crit < self.best):
            self.best = crit
            self.best_row = dict(step_best=step,
                                 val_loss_best=float(val_loss),
                                 val_bpc_best=val_bpc,
                                 val_acc_best=float(val_acc),
                                 val_acc_last_best=float(val_acc_last))

    def save_sample(self, prompt:str, generated:str):
        with open(self.samples_txt, "a") as f:
            f.write("=== SAMPLE ===\nPrompt:\n" + prompt + "\n\nGenerated:\n" + generated + "\n\n")

    def finish(self, params_count:int, final_row:dict):
        wall = time.time() - self.t0
        cfg = self.cfg
        total_tokens = int(cfg["steps"]) * self.tokens_per_step
        row = [
            self.run_id, cfg.get("run_tag",""), cfg.get("seed",""), cfg["B"], cfg["T"], cfg["steps"],
            self.tokens_per_step, total_tokens,
            cfg.get("vocab_size",""), cfg.get("d_model",""), cfg.get("n_layers",""), cfg.get("n_heads",""),
            cfg.get("max_len",""), cfg.get("lr",""), cfg.get("optimizer",""),
            params_count, f"{wall:.2f}",
            # best snapshot
            self.best_row.get("val_loss_best","") if self.best_row else "",
            self.best_row.get("val_bpc_best","") if self.best_row else "",
            self.best_row.get("val_acc_best","") if self.best_row else "",
            self.best_row.get("val_acc_last_best","") if self.best_row else "",
            self.best_row.get("step_best","") if self.best_row else "",
            # final snapshot
            final_row["val_loss_final"], final_row["val_bpc_final"],
            final_row["val_acc_final"], final_row["val_acc_last_final"], final_row["step_final"],
        ]
        with open(self.master, "a", newline="") as f:
            csv.writer(f).writerow(row)
        print(f"[LOG] Wrote master row → {self.master}")
        print(f"[LOG] Artifacts in → {self.run_dir}")
# ==== /logger ====


In [58]:
# initialize the jax random key
key = jax.random.key(0)

# Load data

In [59]:
# load the ./data/text8_train.txt and ./data/text8_test.txt files
with open("./data/text8_train.txt", "r") as f:
    train_text = f.read()
with open("./data/text8_test.txt", "r") as f:
    test_text = f.read()

# print the length of the training text and test text
print(f"Length of training text: {len(train_text):_} characters")
print(f"Length of test text: {len(test_text):_} characters")

Length of training text: 90_000_000 characters
Length of test text: 5_000_000 characters


In [60]:
# Build vocabulary (lowercase + space + a few punctuations)
char_set = list("abcdefghijklmnopqrstuvwxyz ")
char_to_int = {ch:i for i,ch in enumerate(char_set)}
int_to_char = {i:ch for ch,i in char_to_int.items()}

def encode(s):
    """Encode string to array of integers"""
    ids = [char_to_int[c] for c in s]
    return np.array(ids, dtype=np.uint8)  # use np.uint8 to save space

In [61]:
# encode the text
train_text_int = encode(train_text)
test_text_int = encode(test_text)

In [62]:
# sanity check: display a few random characters from the training text
T = 128
for _ in range(5):
    # choose random position in text
    N = np.random.randint(low=0, high=len(train_text)-T)
    print(train_text[N:N+T])
    print()

ero one after an anthrax attack was perpetrated on the company newspaper companies of the united states supermarket tabloids two

ne two in addr arpa to its canonical name referrals icann org an ns record or name server record maps a domain name to a list of

its transportation infrastructure asphalting new roads improving its ports and repairing war damaged roads and bridges since the

y describing himself as merely a gold prospector who happened to find a nugget quote i believe the brain like any other organ ca

ed following the workshop a panel of two eight experts worked to develop this report scientific evidence on condom effectiveness



# Create a basic Transformer model

In [63]:
def create_train_state(rng, vocab_size=27, d_model=64, n_layers=6, n_heads=8, max_len=128):
    # create a basic Transformer model
    model = models.DecoderOnlyTransformer(vocab_size, d_model, n_layers, n_heads, max_len)
    # create a dummy input for initialization
    dummy = jnp.zeros((1, min(16, max_len)), dtype=jnp.int32)
    # pass the dummy input to the model to initialize the parameters
    params = model.init({"params": rng}, dummy)["params"]
    return model, params

In [104]:
# vocab size
vocab_size= len(char_set)

# internal model dimensions
d_model=256

# number of attention heads
n_heads=8

# number of Transformer layers
n_layers=3

# maximum sequence length
max_len=128

np.random.seed(42)          # affects get_batch sampling
key = jax.random.key(42)    # affects model parameter initialization

model, params = create_train_state(key, vocab_size, d_model, n_layers, n_heads, max_len)

In [105]:
# compute the number of parameters
def count_params(params):
    return sum(x.size for x in jax.tree_util.tree_leaves(params))
print(f"Number of parameters: {count_params(params):_}")

Number of parameters: 2_413_312


In [106]:
# sanity check: create a batch of data & run a forward pass
B, T = 4, 32
batch = jax.random.randint(
    key=key,
    shape=(B, T), minval=0, maxval=len(char_set))
logits = model.apply({"params": params}, batch)

print("batch shape:", batch.shape)  # (B, T)
print("logits shape:", logits.shape)  # (B, T, vocab_size)

batch shape: (4, 32)
logits shape: (4, 32, 27)


# Loss function

In [107]:
@jax.jit
def loss_and_metrics(logits, targets):
    """Compute cross-entropy loss and accuracy.

    Assumes `targets` contains only valid integer class ids in [0, V-1] (no -1 ignore tokens).

    Args:
      logits: (B, T, V) float array of unnormalized scores.
      targets: (B, T) integer array with ground-truth class ids.

    Returns:
      loss: scalar average cross-entropy over all positions.
      metrics: dict with keys "loss" and "acc" (both scalars).
    """
    # Flatten batch/time dims so optax works on shape (N, V) and (N,)
    vocab = logits.shape[-1]
    flat_logits = logits.reshape(-1, vocab)
    flat_targets = targets.reshape(-1)

    # Per-position cross-entropy, then mean over all positions
    per_pos = optax.softmax_cross_entropy_with_integer_labels(flat_logits, flat_targets)
    loss = per_pos.mean()

    # prediction over all positions
    preds = jnp.argmax(logits, axis=-1)  # (B, T)

    # compute accuracy over only the last position
    is_match = preds == targets

    # Accuracy over all positions
    acc_all = jnp.mean(is_match.astype(jnp.float32))

    # Accuracy over only last position
    acc_last = jnp.mean(is_match.astype(jnp.float32)[:,-1])

    return loss, {"loss": loss, "acc": acc_all, "acc_last": acc_last}

# Optimization step:

In [108]:
# create an update function
def train_step(params, opt_state, x, y, tx):
    """Single optimization step using optax optimizer.

    Args:
      params: pytree of model parameters.
      opt_state: optax optimizer state corresponding to `params`.
      x: (B, T) int array input tokens.
      y: (B, T) int array target tokens.
      tx: optax.GradientTransformation (already initialized).

    Returns:
      new_params: updated parameters after one gradient step.
      new_opt_state: updated optimizer state.
      metrics: dict of scalar metrics (loss, acc).
    """
    def loss_fn(params):
        logits = model.apply({"params": params}, x)
        loss, metrics = loss_and_metrics(logits, y)
        return loss, metrics

    # compute gradients (loss is scalar, metrics is auxiliary)
    (loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)

    # optax update: compute parameter updates and new optimizer state
    updates, new_opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, metrics

# jit: last argument should be static because it is an object
train_step = jax.jit(train_step, static_argnames=("tx",))

# Batch creation:

In [109]:
# create a batch from the training data
def get_batch(text_int, B, T):
    """Create a random batch of data from text_int.

    Args:
      text_int: 1D array of token ids.
      B: batch size (number of sequences).
      T: sequence length (number of tokens per sequence).

    Returns:
      x: (B, T) int array input tokens.
      y: (B, T) int array target tokens.
    """
    # choose random starting indices for each sequence in the batch
    ix = np.random.randint(0, len(text_int) - T, size=B)
    # inputs are text from i to i+T
    x = np.stack([text_int[i:i+T] for i in ix])
    # targets are text from i+1 to i+T+1
    y = np.stack([text_int[i+1:i+T+1] for i in ix])
    return jnp.array(x, dtype=jnp.int32), jnp.array(y, dtype=jnp.int32)

# Optimizer creation:

In [110]:
# define optax optimizer
learning_rate = 0.001
# Create Adam optimizer (Optax)
tx = optax.adam(learning_rate=learning_rate)
# Initialize optimizer state for current params
opt_state = tx.init(params)
print(f"Initialized optimizer: Adam lr={learning_rate}")



Initialized optimizer: Adam lr=0.001


In [None]:
niter = 10000
B, T = 128, 32

# ==== LOGGING: init (place here, after learning_rate / niter / B / T are set) ====
run_tag = "baseline_T32_B128_lr1e-3"   # rename per run each experiment
seed = 42
cfg = dict(
    run_tag=run_tag, seed=seed,
    B=B, T=T, steps=niter,                 # match your loop exactly
    vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads, max_len=max_len,
    lr=learning_rate, optimizer="Adam"
)
logger = ExperimentLogger(root="runs", run_tag=run_tag)
logger.log_config(cfg)
# ==== /init ====

loss_history = []
time_history = []
time_test_history = []
loss_test_history = []
time_start = time.time()
for it in range(niter):
    batch = get_batch(train_text_int, B, T)
    input, target = batch[0], batch[1]
    params_new, opt_state_new, metrics = train_step(params, opt_state, input, target, tx)

    # update params and opt_state
    params = params_new
    opt_state = opt_state_new
    acc = metrics['acc']
    acc_last = metrics['acc_last']
    loss = metrics['loss']

    loss_history.append(loss)
    time_history.append(time.time() - time_start)

    if it % (niter // 50) == 0 or it == niter - 1:
        time_since_start = time.time() - time_start
        # compute loss on test set
        B_test, T_test = 1024, 32
        test_batch = get_batch(test_text_int, B_test, T_test)
        test_input, test_target = test_batch[0], test_batch[1]
        test_logits = model.apply({"params": params}, test_input)
        test_loss, test_metrics = loss_and_metrics(test_logits, test_target)
        test_acc = test_metrics['acc']
        test_acc_last = test_metrics['acc_last']
        loss_test_history.append(test_loss)
        time_test_history.append(time_since_start)
        print(f"iteration {it:_}  time: {time_since_start:.1f} seconds")
        print(f"\t \t loss(train :: test): {loss:.4f} :: {test_loss:.4f}")
        print(f"\t \t accuracy (train :: test): {100*acc:.1f}% :: {100*test_acc:.1f}%")
        print(f"\t \t accuracy (last character) (train :: test): {100*acc_last:.1f}% :: {100*test_acc_last:.1f}%")
        print()
        # ==== LOGGING: record eval ====
        logger.log_eval(
            step=it,
            val_loss=float(test_loss),
            val_acc=float(test_acc),
            val_acc_last=float(test_acc_last)
        )
        # ==== /LOGGING ====



iteration 0  time: 4.0 seconds
	 	 loss(train :: test): 3.7973 :: 3.5791
	 	 accuracy (train :: test): 1.3% :: 17.1%
	 	 accuracy (last character) (train :: test): 0.8% :: 18.7%

iteration 200  time: 8.5 seconds
	 	 loss(train :: test): 2.0606 :: 2.0480
	 	 accuracy (train :: test): 38.0% :: 37.7%
	 	 accuracy (last character) (train :: test): 37.5% :: 35.1%

iteration 400  time: 13.2 seconds
	 	 loss(train :: test): 1.6994 :: 1.7725
	 	 accuracy (train :: test): 47.8% :: 45.9%
	 	 accuracy (last character) (train :: test): 49.2% :: 43.0%

iteration 600  time: 17.7 seconds
	 	 loss(train :: test): 1.6878 :: 1.6663
	 	 accuracy (train :: test): 48.1% :: 49.1%
	 	 accuracy (last character) (train :: test): 55.5% :: 53.4%

iteration 800  time: 22.2 seconds
	 	 loss(train :: test): 1.5882 :: 1.6131
	 	 accuracy (train :: test): 51.6% :: 50.6%
	 	 accuracy (last character) (train :: test): 49.2% :: 53.3%

iteration 1_000  time: 26.7 seconds
	 	 loss(train :: test): 1.5070 :: 1.5377
	 	 accu

In [99]:
# ==== LOGGING: finalize ====
final_val_loss = float(loss_test_history[-1]) if len(loss_test_history)>0 else float(loss_history[-1])
final_val_bpc  = final_val_loss / float(jnp.log(2.0))
# reuse last printed test metrics if available; otherwise do a small eval now
try:
    final_val_acc = float(test_acc)
    final_val_acc_last = float(test_acc_last)
    final_step = it
except NameError:
    B_eval, T_eval = 512, T
    test_batch = get_batch(test_text_int, B_eval, T_eval)
    test_logits = model.apply({"params": params}, test_batch[0])
    test_loss, test_metrics = loss_and_metrics(test_logits, test_batch[1])
    final_val_loss = float(test_loss); final_val_bpc = final_val_loss / float(jnp.log(2.0))
    final_val_acc = float(test_metrics['acc']); final_val_acc_last = float(test_metrics['acc_last'])
    final_step = niter



logger.finish(
    params_count=count_params(params),
    final_row=dict(
        val_loss_final=final_val_loss,
        val_bpc_final=final_val_bpc,
        val_acc_final=final_val_acc,
        val_acc_last_final=final_val_acc_last,
        step_final=final_step
    )
)
# ==== /finalize ====


[LOG] Wrote master row → runs/experiments.csv
[LOG] Artifacts in → runs/2025-10-28_09-02-00_baseline_T32_B128_lr1e-3_a9c0


In [101]:
# plot the loss history
import matplotlib.pyplot as plt

fig = plt.figure() # <<< add

plt.plot(time_history, loss_history, '-', label='train', color="blue")
plt.plot(time_test_history, loss_test_history, '-', label='test', lw=2, color="red")
plt.xlabel("Time (seconds)")
plt.ylabel("Loss")
plt.legend(loc='upper right')
plt.title("Training Loss History")
plt.grid()

# SAVE to the current run folder
plot_path_png = str(logger.run_dir / "loss_history.png")   # <<< add
plt.savefig(plot_path_png, dpi=200, bbox_inches='tight')   # <<< add

In [102]:
B = 1
seed = 42
rng = jax.random.PRNGKey(seed)
prompt = "hello my fri"
# prompt_int = encode(prompt.lower())
prompt_int = jnp.array([ [char_to_int.get(c, len(char_set)) for c in prompt.lower()[:64]] ], dtype=jnp.int32)

gen_len = 1000
out_ids = generation.generate_tokens(model, params, rng, prompt_int, gen_len, block_size=64,
                          temperature=0.7, sample=True)
print('generated ids shape:', out_ids.shape)
print('generated text:')
generated_text = ''.join(int_to_char.get(int(x), '?') for x in list(out_ids[0]))
# concatenate with prompt
print(prompt + generated_text)
#print(''.join(int_to_char.get(int(x), '?') for x in list(out_ids[0])))

# save the sample now that we have it
logger.save_sample(prompt, prompt + generated_text)

generated ids shape: (1, 1000)
generated text:
hello my friend socialism electromagnetic connecting behavioured in one construction and township of their translation two seven trilogy and water are known as the massachusetts for service constitution of the one nine eight zero s the consuls of the city of the collection of the subspecial sea rodin cabride the world s all of the three nine eight four three four six seven seven changing but the chinese field along with any education products within a common participated in a modern general policy with east king him for his fans and several percentage or that it community some internet principle you the salin recipient of the country fairly see the rationality of apiral in this subsequent decreased accounts of the problem of political military and prayer language for gas would be able to status was a created community and mandatory and are still used to do not the christian however model and remove however the summer of the first week with

In [103]:
import os, shutil
from google.colab import files

print("CWD:", os.getcwd())              # should show .../char-llm-assignment/char-llm-assignment/...
assert os.path.isdir("runs"), "No 'runs' folder here."

zip_path = "/content/runs_export.zip"
shutil.make_archive("/content/runs_export", "zip", "runs")  # zip the local ./runs
files.download(zip_path)


CWD: /content/char-llm-assignment/char-llm-assignment/char-llm-assignment/char-llm-assignment


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>