In [1]:
!pip install -r requirements.txt



In [1]:
import torch, enum, time, os, random, math
from datetime import datetime
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn as nn
from mask import *
from process_data import *

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


## Loading in and inspecting model

In [2]:
model_name = "gpt2"

# Resolve the tokenizer and the causal-LM weights from the same checkpoint id,
# tokenizer first and model second.
tokenizer, model = (
    loader.from_pretrained(model_name)
    for loader in (AutoTokenizer, AutoModelForCausalLM)
)

# When the tokenizer defines no padding token, reuse the end-of-sequence token for it.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

  Referenced from: <3F789787-FE38-3CE7-8599-064BDD0416EE> /Users/jchang153/miniforge3/envs/tf-metal/lib/python3.9/site-packages/torchvision/image.so
  Expected in:     <6B8AC17B-04CC-36D0-BD01-780381EFB0CC> /Users/jchang153/miniforge3/envs/tf-metal/lib/python3.9/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Run on CUDA when it is available, otherwise on the CPU, and move the model there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

## Sampling from the pretrained model (sanity check)

In [4]:
# Manual top-k sampling loop: feeds the growing sequence back through the model one
# token at a time and streams each sampled token to stdout.
model.eval()

# Prompt
prompt = "Help me code up the fibonacci sequence in Python"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

max_new_tokens = 100   # number of tokens to sample after the prompt
temperature = 1.0      # logits are divided by this before top-k/softmax
top_k = 50             # sample only among the k highest-scoring tokens

generated = input_ids.clone()

# Step-by-step autoregressive loop. The whole sequence is re-encoded on every step
# (no cached key/values are passed), so cost grows with the generated length.
for step in range(max_new_tokens):
    with torch.no_grad():
        outputs = model(generated)
        # model outputs logits for each token in the input sequence; only get last token's logits
        next_token_logits = outputs.logits[:, -1, :] / temperature

        # Top-k sampling: renormalise over the k best logits and draw one position per row.
        topk_logits, topk_indices = torch.topk(next_token_logits, top_k)
        probs = torch.softmax(topk_logits, dim=-1)
        sampled_pos = torch.multinomial(probs, num_samples=1)      # [B, 1], index into the top-k list
        # Map each row's sampled position back to a vocabulary id using that row's own
        # top-k indices. (Indexing `topk_indices[0]` would use row 0 for every row.)
        next_token = topk_indices.gather(-1, sampled_pos)          # [B, 1]

    # append new token
    generated = torch.cat([generated, next_token], dim=1)

    # decode just the new token (first sequence in the batch)
    new_text = tokenizer.decode(next_token[0])
    print(new_text, end="", flush=True)

print("\n---\nFull output:")
print(tokenizer.decode(generated[0], skip_special_tokens=True))



def fibonacci(n, x): self.i = x

return 10

def lerp(n, x: Int): return 2+'1

return 30

def fibonacci(n, x: Color): self.x = x

return 12

def fibonacci4(n, x: Int): self.x1 = x

return 30

def fibonacci3(n, x:
---
Full output:
Help me code up the fibonacci sequence in Python

def fibonacci(n, x): self.i = x

return 10

def lerp(n, x: Int): return 2+'1

return 30

def fibonacci(n, x: Color): self.x = x

return 12

def fibonacci4(n, x: Int): self.x1 = x

return 30

def fibonacci3(n, x:


## Registering hooks & getting masks

In [5]:
# Register hooks on the model with split fraction a = 0.5.
# `register_hooks` comes from the star import of `mask`. Later cells switch between the
# two returned mask sets with `active.update(masks_1)` / `active.update(masks_2)`.
# NOTE(review): presumably `active` is the dict the hooks read at forward time and
# `masks_1` / `masks_2` are the two partitions produced by the split - confirm in mask.py.
active, masks_1, masks_2 = register_hooks(model, a=0.5)

## Building and loading the txt files

In [None]:
# Fetch the train split of the "educational_instruct" configuration.
ds = load_dataset(
    path="OpenCoder-LLM/opc-sft-stage2",
    name="educational_instruct",
    split="train",
)

In [20]:
# Drop every column except "code".
non_code_columns = list(filter(lambda name: name != "code", ds.column_names))
text_ds = ds.remove_columns(non_code_columns)

In [21]:
# Concatenate every row's "code" field into one string, in dataset order.
# A single join over a row iterator avoids both the per-index dataset lookup and the
# repeated re-allocation that `data1 += ...` in a loop can incur.
data1 = "".join(row["code"] for row in text_ds)
len(data1)

36322229

In [22]:
# Keep only the first 5,337,651 characters of the code corpus (the same number this
# notebook reports as the length of data/shakespeare.txt).
MAX_CODE_CHARS = 5_337_651
data1 = data1[:MAX_CODE_CHARS]

In [23]:
# Persist the truncated code corpus where the dataset cell reads it from
# ('data/code.txt'); writing to "sample_code.txt" left that path unpopulated on a
# fresh run. NOTE(review): confirm this is the intended source of data/code.txt.
# UTF-8 is set explicitly so the result does not depend on the platform's default encoding.
os.makedirs("data", exist_ok=True)
with open("data/code.txt", "w", encoding="utf-8") as f:
    f.write(data1)

In [187]:
# Read the Shakespeare corpus and show its length in characters.
# The encoding is pinned so the character count does not vary with the platform default.
with open("data/shakespeare.txt", "r", encoding="utf-8") as f:
    data2 = f.read()
len(data2)

5337651

## Creating datasets

In [6]:
# Sequence length used for every dataset window.
# `n_positions` is GPT2Config's documented maximum-position field; `n_ctx` is kept only
# as a fallback for configs that still carry that legacy key.
context = getattr(model.config, "n_positions", None) or model.config.n_ctx
batch_size = 8

In [7]:
# Dataset 1 (Shakespeare): one SimpleTextDataset per split, each wrapped in a DataLoader.
paths = ['data/shakespeare.txt']

train_ds, val_ds, test_ds = (
    SimpleTextDataset(paths, tokenizer, context, split)
    for split in (DatasetSplit.train, DatasetSplit.valid, DatasetSplit.test)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader_1 = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader_1, test_loader_1 = (
    DataLoader(split_ds, batch_size=batch_size) for split_ds in (val_ds, test_ds)
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1809728 > 1024). Running this sequence through the model will result in indexing errors


In [8]:
# Number of batches per split for dataset 1 (train, val, test).
tuple(len(loader) for loader in (train_loader_1, val_loader_1, test_loader_1))

(199, 11, 11)

In [9]:
# Dataset 2 (code): one SimpleTextDataset per split, each wrapped in a DataLoader.
# Note that `paths`, `train_ds`, `val_ds` and `test_ds` are rebound here.
paths = ['data/code.txt']

train_ds, val_ds, test_ds = (
    SimpleTextDataset(paths, tokenizer, context, split)
    for split in (DatasetSplit.train, DatasetSplit.valid, DatasetSplit.test)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader_2 = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader_2, test_loader_2 = (
    DataLoader(split_ds, batch_size=batch_size) for split_ds in (val_ds, test_ds)
)

In [10]:
# Number of batches per split for dataset 2 (train, val, test).
tuple(len(loader) for loader in (train_loader_2, val_loader_2, test_loader_2))

(294, 17, 17)

## Training loop

In [11]:
def evaluate(loader, active_masks=None, max_batches=None):
    """Return the average loss of the global `model` over `loader`.

    Each batch's loss is weighted by the number of elements in its input tensor,
    so a smaller final batch does not skew the average.

    Args:
        loader: iterable yielding `(x, y)` tensor pairs; `x` is passed as
            `input_ids` and `y` as `labels`.
        active_masks: optional mapping merged into the global `active` dict before
            evaluating. The update is NOT reverted afterwards.
        max_batches: optional cap on the number of batches consumed.

    Returns:
        float: weighted average loss, or 0.0 when `loader` yields no batches.

    Side effects:
        Puts the model in eval mode for the duration of the call and then restores
        whichever mode (train or eval) it was in on entry, also when an exception is
        raised. Previously the model was always left in train mode.
    """
    if active_masks is not None:
        active.update(active_masks)   # switch masks for this eval

    was_training = model.training     # remember the caller's mode so it can be restored
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    try:
        with torch.no_grad():
            for step, (x, y) in enumerate(loader, start=1):
                x, y = x.to(device), y.to(device)     # [B, T]
                # NOTE(review): Hugging Face causal-LM heads shift `labels` by one
                # position internally. If the dataset already returns `y` shifted
                # relative to `x`, the targets end up two steps ahead - confirm
                # against process_data.SimpleTextDataset.
                out = model(input_ids=x, labels=y)    # loss is computed inside the model
                # weight the batch-mean loss by the batch's element count (B*T)
                total_loss += out.loss.item() * x.numel()
                total_tokens += x.numel()

                if max_batches is not None and step >= max_batches:
                    break
    finally:
        model.train(was_training)

    return total_loss / max(1, total_tokens)   # max(1, ...) guards the empty-loader case

In [13]:
# ---- run directory: every run writes into its own timestamped folder ----
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
CKPT_DIR = f"checkpoints/{ts}/"
os.makedirs(CKPT_DIR, exist_ok=True)

LOG_EPOCH = os.path.join(CKPT_DIR, "log.csv")     # one row per epoch
LOG_STEP  = os.path.join(CKPT_DIR, "steps.csv")   # one row per logged step

# Start each CSV log with its header row, unless the file is already there.
for log_path, header in (
    (LOG_EPOCH, "epoch,elapsed_sec,train_loss1,train_ppl1,train_loss2,train_ppl2,val_loss1,val_ppl1,val_loss2,val_ppl2\n"),
    (LOG_STEP, "global_step,epoch,step_in_epoch,ms_per_batch,lr,avg_loss1,avg_loss2\n"),
):
    if os.path.exists(log_path):
        continue
    with open(log_path, "w") as f:
        f.write(header)

# ---- hyperparameters ----
lr = 3.5e-4          # Adam learning rate
clip = 0.25          # max gradient norm passed to clip_grad_norm_
epochs = 5
log_interval = 1     # log every N optimisation steps

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# ---- run-level bookkeeping ----
best_val_loss_1 = best_val_loss_2 = None   # lowest validation loss seen so far, per dataset
start_time = time.time()
global_step = 0

def _effective_len(loader1, loader2):
    return min(len(loader1), len(loader2))

# Joint training on two datasets with one shared optimizer.
# Each step runs one batch from each loader under its own mask set, accumulates both
# gradients, then applies a single clipped optimizer step. After every epoch both
# validation sets are scored, checkpoints are written and a CSV row is appended.
try:
    for epoch in range(1, epochs + 1):
        print(f"\n===== Epoch {epoch} =====")
        model.train()
        # running sums since the last log line (reset every `log_interval` steps)
        running_loss1 = 0.0
        running_loss2 = 0.0
        t_batch = time.time()

        # epoch-level (token-weighted) accumulators
        train_loss1_tokensum = 0.0
        train_loss2_tokensum = 0.0
        train_tokens1 = 0
        train_tokens2 = 0

        # zip() below stops at the shorter loader, so this is the real steps-per-epoch
        eff_train_len = _effective_len(train_loader_1, train_loader_2)

        for step, ((x1, y1), (x2, y2)) in enumerate(zip(train_loader_1, train_loader_2), start=1):
            x1, y1 = x1.to(device), y1.to(device)
            x2, y2 = x2.to(device), y2.to(device)

            # gradients are cleared once per step, then accumulated across both passes
            optimizer.zero_grad(set_to_none=True)

            # NOTE(review): Hugging Face causal-LM heads shift `labels` by one position
            # internally. If the dataset already returns y shifted relative to x, the
            # targets end up two steps ahead - confirm against process_data.

            # pass 1 (dataset 1)
            active.update(masks_1)
            loss1 = model(input_ids=x1, labels=y1).loss
            loss1.backward()

            # pass 2 (dataset 2)
            active.update(masks_2)
            loss2 = model(input_ids=x2, labels=y2).loss
            loss2.backward()

            # clip the combined gradient of both passes, then take one optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            # running stats
            l1 = loss1.item()
            l2 = loss2.item()
            running_loss1 += l1
            running_loss2 += l2

            # token-weighted epoch aggregates
            train_loss1_tokensum += l1 * x1.numel()
            train_tokens1 += x1.numel()
            train_loss2_tokensum += l2 * x2.numel()
            train_tokens2 += x2.numel()

            global_step += 1

            if (step % log_interval) == 0:
                # mean wall-clock time per batch and mean losses over the last interval
                elapsed_ms = (time.time() - t_batch) * 1000.0 / log_interval
                avg1 = running_loss1 / log_interval
                avg2 = running_loss2 / log_interval
                pct = 100.0 * step / float(eff_train_len)

                # `lr` is the constant set before the loop; no scheduler updates it here
                print(f'{epoch:3d}\t{pct:5.1f}%\t{elapsed_ms:8.2f}\t{lr:.3g}\t'
                      f'{avg1:9.5f}\t{math.exp(avg1):8.2f}\t{avg2:9.5f}\t{math.exp(avg2):8.2f}')

                with open(LOG_STEP, "a") as f:
                    f.write(f"{global_step},{epoch},{step},{elapsed_ms:.3f},{lr:.6g},"
                            f"{avg1:.6f},{avg2:.6f}\n")

                running_loss1 = 0.0
                running_loss2 = 0.0
                t_batch = time.time()

        # ---------- epoch end aggregates ----------
        epoch_train_loss1 = train_loss1_tokensum / max(1, train_tokens1)
        epoch_train_loss2 = train_loss2_tokensum / max(1, train_tokens2)
        epoch_train_ppl1 = math.exp(epoch_train_loss1)
        epoch_train_ppl2 = math.exp(epoch_train_loss2)

        # ---------- evaluation AFTER training ----------
        # both validation sets are capped at the shorter loader's batch count,
        # so the longer one is only partially evaluated
        t0 = time.time()
        eff_val_len = _effective_len(val_loader_1, val_loader_2)
        val_loss_1 = evaluate(val_loader_1, active_masks=masks_1, max_batches=eff_val_len)
        val_loss_2 = evaluate(val_loader_2, active_masks=masks_2, max_batches=eff_val_len)
        eval_time = time.time() - t0

        print('-' * 110)
        print(f'| epoch: {epoch:3d} | eval_time: {eval_time:6.2f}s | '
              f'val1: {val_loss_1:6.3f} (ppl: {math.exp(val_loss_1):8.2f}) | '
              f'val2: {val_loss_2:6.3f} (ppl: {math.exp(val_loss_2):8.2f})')
        print('-' * 110)

        # ---------- checkpointing ----------
        # always save latest
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "val_loss_1": val_loss_1,
            "val_loss_2": val_loss_2,
            "train_loss_1": epoch_train_loss1,
            "train_loss_2": epoch_train_loss2,
        }, os.path.join(CKPT_DIR, "latest.pt"))

        # best checkpoints: tracked independently per validation set
        if best_val_loss_1 is None or val_loss_1 < best_val_loss_1:
            best_val_loss_1 = val_loss_1
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "which": "best_1",
                "val_loss_1": val_loss_1,
                "val_loss_2": val_loss_2,
            }, os.path.join(CKPT_DIR, "best_1.pt"))

        if best_val_loss_2 is None or val_loss_2 < best_val_loss_2:
            best_val_loss_2 = val_loss_2
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "which": "best_2",
                "val_loss_1": val_loss_1,
                "val_loss_2": val_loss_2,
            }, os.path.join(CKPT_DIR, "best_2.pt"))

        # ---------- epoch log row ----------
        elapsed_sec = time.time() - start_time
        with open(LOG_EPOCH, "a") as f:
            f.write(f"{epoch},{elapsed_sec:.2f},"
                    f"{epoch_train_loss1:.6f},{epoch_train_ppl1:.4f},"
                    f"{epoch_train_loss2:.6f},{epoch_train_ppl2:.4f},"
                    f"{val_loss_1:.6f},{math.exp(val_loss_1):.4f},"
                    f"{val_loss_2:.6f},{math.exp(val_loss_2):.4f}\n")

# NOTE(review): checkpoints are only written at the end of an epoch, so interrupting
# before the first epoch finishes leaves no latest.pt in CKPT_DIR for later cells.
except KeyboardInterrupt:
    print('Graceful Exit')



===== Epoch 1 =====


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Graceful Exit


## Evaluating a saved checkpoint on the test sets

In [None]:
ckpt_path = os.path.join(CKPT_DIR, "latest.pt")  # <--- set this to whichever you want

# Read the checkpoint onto the active device.
# NOTE(review): torch.load unpickles the file; only point this at checkpoints you trust.
print(f"\n[TEST] Loading checkpoint from: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device)

# The file may be a bare state_dict (every key looks like a model parameter name)
# or a wrapped payload carrying the weights under "model_state".
looks_like_raw_state = isinstance(ckpt, dict) and all(
    k.startswith(("transformer", "wte", "lm_head")) for k in ckpt.keys()
)
if looks_like_raw_state:
    state_dict = ckpt
else:
    state_dict = ckpt.get("model_state", ckpt)

# Non-strict load: mismatched keys are reported below instead of raising.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    print(f"[TEST] load_state_dict: missing={len(missing)}, unexpected={len(unexpected)}")
    # Only spell out the key names when the list is short.
    for label, keys in (("  missing:", missing), ("  unexpected:", unexpected)):
        if len(keys) < 20:
            print(label, keys)

model.to(device)
model.eval()

In [None]:
# Score the loaded checkpoint on both test sets.
# Both sets are capped at the shorter loader's batch count, mirroring how validation
# is done during training. (The local re-definition of `_effective_len` was a duplicate
# of the one in the training cell and is replaced by the inline `min`.)
eff_test_len = min(len(test_loader_1), len(test_loader_2))

# evaluate() activates the given mask set itself and already runs under
# torch.no_grad(), so no extra `active.update(...)` or no_grad wrapper is needed here.

# Objective 1
test_loss_1 = evaluate(test_loader_1, active_masks=masks_1, max_batches=eff_test_len)
test_ppl_1  = math.exp(test_loss_1)

# Objective 2
test_loss_2 = evaluate(test_loader_2, active_masks=masks_2, max_batches=eff_test_len)
test_ppl_2  = math.exp(test_loss_2)

print("-" * 110)
print(f"| TEST @ {os.path.basename(ckpt_path)} | "
      f"loss1: {test_loss_1:6.4f} (ppl {test_ppl_1:8.2f}) | "
      f"loss2: {test_loss_2:6.4f} (ppl {test_ppl_2:8.2f})")
print("-" * 110)
