In [1]:
import torch, enum, time, os, random, math
from datetime import datetime
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn as nn
from mask import *
from process_data import *

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


## Loading in and inspecting model

In [3]:
# Hugging Face checkpoint name shared by the tokenizer and the causal-LM weights.
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# If the tokenizer defines no padding token, reuse the end-of-sequence token
# so that a pad token is always available.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)

  Referenced from: <3F789787-FE38-3CE7-8599-064BDD0416EE> /Users/jchang153/miniforge3/envs/tf-metal/lib/python3.9/site-packages/torchvision/image.so
  Expected in:     <6B8AC17B-04CC-36D0-BD01-780381EFB0CC> /Users/jchang153/miniforge3/envs/tf-metal/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU,
# and move the model's parameters onto the chosen device.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model = model.to(device)

## Testing model

In [6]:
# Manual top-k sampling demo: generate `max_new_tokens` tokens one at a time,
# streaming each decoded token as it is sampled, then print the full text.
model.eval()

# Prompt
prompt = "Help me code up the fibonacci sequence in Python"
# The prompt must live on the same device as the model; otherwise the forward
# pass fails with a device-mismatch error whenever the model was moved to CUDA.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

max_new_tokens = 100
temperature = 1.0
top_k = 50

generated = input_ids.clone()

# Step-by-step autoregressive loop.
# No key/value cache is used: every step re-runs the model on the whole sequence.
for step in range(max_new_tokens):
    with torch.no_grad():
        outputs = model(generated)
        # model outputs logits for each token in the input sequence; only get last token's logits
        next_token_logits = outputs.logits[:, -1, :] / temperature

        # Top-k sampling: renormalise over the k best logits and draw one of them.
        topk_logits, topk_indices = torch.topk(next_token_logits, top_k)
        probs = torch.softmax(topk_logits, dim=-1)
        sampled_pos = torch.multinomial(probs, num_samples=1)  # position inside the top-k list, one per row
        # Map the sampled position back to a vocabulary id row by row
        # (indexing only row 0 would be wrong for more than one prompt).
        next_token = topk_indices.gather(-1, sampled_pos)

    # append new token
    generated = torch.cat([generated, next_token], dim=1)

    # decode just the new token
    new_text = tokenizer.decode(next_token[0])
    print(new_text, end="", flush=True)

print("\n---\nFull output:")
print(tokenizer.decode(generated[0], skip_special_tokens=True))

 so you have this thing that shows what it does.

$ fibonacci 1

There are four fibonacci blocks (which are also called the N-dimensional array) that we know to be N-dimensional: the integer, the square root, and the nonce. The third form of the N-dimensional array is the diagonal of the array. If you multiply the array by two, then it will represent the diagonal of the entire N-dimensional array (in other words
---
Full output:
Help me code up the fibonacci sequence in Python so you have this thing that shows what it does.

$ fibonacci 1

There are four fibonacci blocks (which are also called the N-dimensional array) that we know to be N-dimensional: the integer, the square root, and the nonce. The third form of the N-dimensional array is the diagonal of the array. If you multiply the array by two, then it will represent the diagonal of the entire N-dimensional array (in other words


## Registering hooks & getting masks

In [6]:
# register hooks with split fraction a = 0.5
# The call returns three objects: `active`, which later cells switch with
# `active.update(masks_1)` / `active.update(masks_2)`, and the two mask sets.
# NOTE(review): `register_hooks` is not defined in this notebook; presumably it
# comes from the wildcard `from mask import *` — confirm, and confirm what `a` splits.
active, masks_1, masks_2 = register_hooks(model, a=0.5)

## Building and loading the text corpora

In [None]:
# Load the train split of the "educational_instruct" configuration of the
# OpenCoder-LLM/opc-sft-stage2 dataset from the Hugging Face Hub.
ds = load_dataset("OpenCoder-LLM/opc-sft-stage2", "educational_instruct", split="train")

In [20]:
# Drop every column except "code", leaving a dataset with only the code samples.
drop_cols = [name for name in ds.column_names if name != "code"]
text_ds = ds.remove_columns(drop_cols)

In [21]:
# Concatenate every "code" sample into one string (no separator, as before).
# Joining the whole column at once avoids per-row dataset indexing and the
# repeated re-allocation caused by `+=` on an ever-growing string.
data1 = "".join(text_ds["code"])
len(data1)

36322229

In [22]:
# Keep only the first 5,337,651 characters of the code corpus; this equals the
# character count reported for data/shakespeare.txt further down the notebook.
char_budget = 5337651
data1 = data1[:char_budget]

In [23]:
# Persist the truncated code corpus. The encoding is pinned so that the file
# does not depend on the platform's default (non-ASCII source code would fail
# to encode, or be encoded differently, on machines whose default is not UTF-8).
# NOTE(review): the dataset cell further down reads 'data/code.txt', not this
# path — confirm how this file ends up there.
with open("sample_code.txt", "w", encoding="utf-8") as f:
    f.write(data1)

In [187]:
# Read the Shakespeare corpus and show its length in characters. The encoding
# is explicit so the character count is the same on every platform.
with open("data/shakespeare.txt", "r", encoding="utf-8") as f:
    data2 = f.read()
len(data2)

5337651

## Creating datasets

In [6]:
# Sequence length of each training example: the model's maximum number of positions.
# `max_position_embeddings` is the architecture-independent config attribute
# (GPT-2 maps it to `n_positions`); `n_ctx` is a GPT-2-only legacy key that is
# present only when the checkpoint's config file happens to carry it.
context = model.config.max_position_embeddings
batch_size = 32

In [7]:
# Corpus 1 (Shakespeare): one dataset per split, each wrapped in a DataLoader.
paths = ['data/shakespeare.txt']

train_ds, val_ds, test_ds = [
    SimpleTextDataset(paths, tokenizer, context, split)
    for split in (DatasetSplit.train, DatasetSplit.valid, DatasetSplit.test)
]

# Only the training loader shuffles; validation and test keep dataset order.
train_loader_1 = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader_1, test_loader_1 = [
    DataLoader(split_ds, batch_size=batch_size) for split_ds in (val_ds, test_ds)
]

Token indices sequence length is longer than the specified maximum sequence length for this model (1809728 > 1024). Running this sequence through the model will result in indexing errors


In [8]:
# Number of batches in the train / validation / test loaders of corpus 1.
tuple(len(loader) for loader in (train_loader_1, val_loader_1, test_loader_1))

(50, 3, 3)

In [22]:
# Corpus 2 (code): one dataset per split, each wrapped in a DataLoader.
# Note that `paths`, `train_ds`, `val_ds` and `test_ds` are rebound here and no
# longer refer to the objects built for corpus 1.
paths = ['data/code.txt']

train_ds, val_ds, test_ds = [
    SimpleTextDataset(paths, tokenizer, context, split)
    for split in (DatasetSplit.train, DatasetSplit.valid, DatasetSplit.test)
]

# Only the training loader shuffles; validation and test keep dataset order.
train_loader_2 = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader_2, test_loader_2 = [
    DataLoader(split_ds, batch_size=batch_size) for split_ds in (val_ds, test_ds)
]

In [23]:
# Number of batches in the train / validation / test loaders of corpus 2.
tuple(len(loader) for loader in (train_loader_2, val_loader_2, test_loader_2))

(74, 5, 5)

## Training loop

In [24]:
def evaluate(loader, active_masks=None, max_batches=None):
    """Return the size-weighted average loss of the global ``model`` over ``loader``.

    Args:
        loader: Iterable yielding ``(x, y)`` tensor pairs. Both tensors are moved
            to the global ``device`` and passed to the model as ``input_ids`` and
            ``labels`` respectively.
        active_masks: Optional mapping merged into the global ``active`` object
            via ``active.update`` before evaluating. The switch is not undone
            when the function returns.
        max_batches: Optional cap on the number of batches consumed. ``None``
            evaluates the whole loader; ``0`` (or a negative value) evaluates
            nothing and yields ``0.0``.

    Returns:
        float: ``sum(batch_loss * x.numel()) / sum(x.numel())``, or ``0.0`` when
        no batch was consumed.

    The model is put in eval mode for the duration of the call and is returned
    to whatever mode it was in beforehand, even if a batch raises.
    """
    from itertools import islice  # local import: keeps the cell self-contained

    if active_masks is not None:
        active.update(active_masks)   # switch masks for this eval

    # islice stops before fetching a batch that would not be used.
    batches = loader if max_batches is None else islice(loader, max(0, max_batches))

    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    try:
        with torch.no_grad():
            for x, y in batches:
                x, y = x.to(device), y.to(device)     # [B, T]
                # NOTE(review): Hugging Face causal LMs shift `labels` internally;
                # confirm the dataset does not already shift y relative to x.
                out = model(input_ids=x, labels=y)
                # out.loss is a per-batch mean; weight it by the batch's element
                # count so that a smaller final batch counts proportionally less.
                total_loss += out.loss.item() * x.numel()
                total_tokens += x.numel()
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    return total_loss / max(1, total_tokens)   # per-token average loss


In [None]:
# Training driver: alternates the two mask sets on the two corpora, takes one
# optimizer step per pair of batches, validates at the start of every epoch,
# and writes checkpoints plus a CSV log under checkpoints/<timestamp>/.
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
CKPT_DIR = f"checkpoints/{ts}/"
LOG_PATH = CKPT_DIR + "log.txt"
os.makedirs(CKPT_DIR, exist_ok=True)

# write header once if log doesn't exist
if not os.path.exists(LOG_PATH):
    with open(LOG_PATH, "w") as f:
        f.write("epoch,elapsed_sec,train_loss1,train_ppl1,train_loss2,train_ppl2,"
                "val_loss1,val_ppl1,val_loss2,val_ppl2\n")

lr = 3.5e-4
clip = 0.25
best_val_loss_1 = None
best_val_loss_2 = None
epochs = 1
log_interval = 50

optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def _checkpoint_payload(epoch, val_loss_1, val_loss_2, train_loss_1=None, train_loss_2=None):
    """Snapshot the current model/optimizer state together with the given metrics."""
    return {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "val_loss_1": val_loss_1,
        "val_loss_2": val_loss_2,
        "train_loss_1": train_loss_1,
        "train_loss_2": train_loss_2,
    }


# zip() below stops at the shorter loader, so that is the real epoch length.
steps_per_epoch = min(len(train_loader_1), len(train_loader_2))

start_time = time.time()

try:
    for epoch in range(1, epochs + 1):
        t0 = time.time()
        # Validation runs BEFORE this epoch's training, so these losses (and the
        # "best" checkpoints below) describe the weights entering the epoch.
        val_loss_1 = evaluate(val_loader_1, active_masks=masks_1, max_batches=3)
        val_loss_2 = evaluate(val_loader_2, active_masks=masks_2, max_batches=3)

        print('-' * 100)
        print('| checkpoint | epoch {:3d} | time: {:5.2f}s | '
              'validation loss 1 {:5.3f} | validation ppl 1 {:8.2f} | '
              'validation loss 2 {:5.3f} | validation ppl 2 {:8.2f}'
              .format(epoch, (time.time() - t0),
                      val_loss_1, math.exp(val_loss_1),
                      val_loss_2, math.exp(val_loss_2)))
        print('-' * 100)
        print('epoch\tstep%\tms/batch\tlr\tloss1\tppl1\tloss2\tppl2')

        # Save real state for the best-so-far checkpoints (the previous
        # `torch.save({...}, ...)` stored a set containing Ellipsis, not a model).
        if best_val_loss_1 is None or val_loss_1 < best_val_loss_1:
            best_val_loss_1 = val_loss_1
            torch.save(_checkpoint_payload(epoch, val_loss_1, val_loss_2),
                       os.path.join(CKPT_DIR, "best_1.pt"))
        if best_val_loss_2 is None or val_loss_2 < best_val_loss_2:
            best_val_loss_2 = val_loss_2
            torch.save(_checkpoint_payload(epoch, val_loss_1, val_loss_2),
                       os.path.join(CKPT_DIR, "best_2.pt"))

        # ---------------- training ----------------
        model.train()
        running_loss1 = 0.0
        running_loss2 = 0.0
        t_batch = time.time()

        # epoch-level (token-weighted) accumulators for train stats
        train_loss1_tokensum = 0.0
        train_loss2_tokensum = 0.0
        train_tokens1 = 0
        train_tokens2 = 0

        for step, ((x1, y1), (x2, y2)) in enumerate(zip(train_loader_1, train_loader_2), start=1):
            x1, y1 = x1.to(device), y1.to(device)  # [B, T]
            x2, y2 = x2.to(device), y2.to(device)  # [B, T]

            # Clear gradients exactly once per step. Both backward passes below
            # accumulate into .grad; zeroing again between them would throw away
            # the gradients of loss1 before the optimizer ever sees them.
            optimizer.zero_grad(set_to_none=True)

            # set L1 masks active; compute grads on L1 parameters
            active.update(masks_1)
            loss1 = model(input_ids=x1, labels=y1).loss
            loss1.backward()

            # set L2 masks active; accumulate grads on L2 parameters
            active.update(masks_2)
            loss2 = model(input_ids=x2, labels=y2).loss
            loss2.backward()

            # Clip the combined gradient, then take a single optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            # running display
            running_loss1 += loss1.item()
            running_loss2 += loss2.item()

            # epoch-level token-weighted stats
            train_loss1_tokensum += loss1.item() * x1.numel()
            train_tokens1 += x1.numel()
            train_loss2_tokensum += loss2.item() * x2.numel()
            train_tokens2 += x2.numel()

            if (step % log_interval) == 0:
                elapsed_ms = (time.time() - t_batch) * 1000.0 / log_interval
                avg1 = running_loss1 / log_interval
                avg2 = running_loss2 / log_interval
                pct = 100.0 * step / float(max(1, steps_per_epoch))  # zip() -> shorter loader length
                print(f'{epoch:3d}\t{pct:4.1f}%\t{elapsed_ms:7.2f}\t{lr:.3g}\t'
                      f'{avg1:6.3f}\t{math.exp(avg1):7.2f}\t{avg2:6.3f}\t{math.exp(avg2):7.2f}')
                running_loss1 = 0.0
                running_loss2 = 0.0
                t_batch = time.time()

        # ---------------- epoch end: compute epoch-level train stats ----------------
        epoch_train_loss1 = train_loss1_tokensum / max(1, train_tokens1)
        epoch_train_loss2 = train_loss2_tokensum / max(1, train_tokens2)
        epoch_train_ppl1 = math.exp(epoch_train_loss1)
        epoch_train_ppl2 = math.exp(epoch_train_loss2)

        # ---------------- always-save checkpoint for this epoch ----------------
        # Build the payload once and write it twice: a per-epoch file and a
        # convenient "latest.pt".
        epoch_payload = _checkpoint_payload(epoch, val_loss_1, val_loss_2,
                                            epoch_train_loss1, epoch_train_loss2)
        ckpt_path = os.path.join(CKPT_DIR, f"epoch_{epoch:03d}.pt")
        torch.save(epoch_payload, ckpt_path)
        torch.save(epoch_payload, os.path.join(CKPT_DIR, "latest.pt"))

        # ---------------- append to log.txt ----------------
        elapsed_sec = time.time() - start_time
        with open(LOG_PATH, "a") as f:
            f.write(f"{epoch},{elapsed_sec:.2f},"
                    f"{epoch_train_loss1:.6f},{epoch_train_ppl1:.4f},"
                    f"{epoch_train_loss2:.6f},{epoch_train_ppl2:.4f},"
                    f"{val_loss_1:.6f},{math.exp(val_loss_1):.4f},"
                    f"{val_loss_2:.6f},{math.exp(val_loss_2):.4f}\n")

except KeyboardInterrupt:
    # Manual interruption ends training without a traceback; files already
    # written to CKPT_DIR are kept.
    print('Graceful Exit')


Graceful Exit
