In [1]:
!pip install -r requirements.txt



In [85]:
import torch, enum, time, os, random, math, csv
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn as nn
from mask import *
from process_data import *
from utils import *

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

## Loading in and inspecting model

In [9]:
# Identifier of the pretrained checkpoint loaded for both tokenizer and model.
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# When the tokenizer defines no padding token, reuse its end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [10]:
# Prefer CUDA when it is available, otherwise fall back to the CPU.
# `device` is read as a global by later cells (e.g. `evaluate`, the training loop).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

## Inference Test

In [None]:
# Quick generation smoke test, placed before any dataset or masking cell.
prompt = "Help me code up the fibonacci sequence in Python"

# Generation settings forwarded to `generate_text`.
# NOTE(review): `generate_text` is not defined in this notebook — presumably it
# comes from one of the star imports (`utils`?); confirm how it interprets
# `temperature` and `top_k`.
max_new_tokens = 100
temperature = 1.0
top_k = 50

generate_text(model, tokenizer, prompt, device, max_new_tokens, temperature, top_k)

 2 or 3. As well as knowing that you can have multiple numbers, you can also do math if you don't already know which key is the right one. After this, the solution will then be stored and processed through the Riemann family for a specific number and we call it the Fibonacci Sequence. Now we will generate the result of the second Fibonacci Sequence with the Fibonacci sequence.

function main(string[] args) { var FibonacciSequence
---
Full output:
Help me code up the fibonacci sequence in Python 2 or 3. As well as knowing that you can have multiple numbers, you can also do math if you don't already know which key is the right one. After this, the solution will then be stored and processed through the Riemann family for a specific number and we call it the Fibonacci Sequence. Now we will generate the result of the second Fibonacci Sequence with the Fibonacci sequence.

function main(string[] args) { var FibonacciSequence


## Loading in txt files

In [None]:
# Load the "train" split of the "educational_instruct" configuration of the
# OpenCoder-LLM/opc-sft-stage2 dataset.
ds = load_dataset("OpenCoder-LLM/opc-sft-stage2", "educational_instruct", split="train")

In [None]:
# Drop every column except "code"; only the code text is used below.
text_ds = ds.remove_columns([c for c in ds.column_names if c != "code"])

In [None]:
# Concatenate every "code" entry (no separator, in dataset order) into a single
# string. Reading the column once and joining avoids both the per-row indexing
# overhead and the quadratic cost of repeated `+=` on a growing string.
data1 = "".join(text_ds["code"])
len(data1)

In [None]:
# Keep only the first 5,337,651 characters of the concatenated code corpus.
# NOTE(review): the cutoff is an unexplained magic number — presumably chosen to
# bring the code corpus to a target size relative to the other corpus; confirm
# and move it to a named constant.
data1 = data1[:5337651]

In [None]:
# Persist the truncated code corpus. The encoding is pinned to UTF-8 because
# source code can contain non-ASCII characters, and the platform default
# encoding (e.g. cp1252) could otherwise raise UnicodeEncodeError.
# NOTE(review): the dataset cell further down reads 'data/code.txt', not
# "sample_code.txt" — confirm whether this file is moved/renamed by hand.
with open("sample_code.txt", "w", encoding="utf-8") as f:
    f.write(data1)

In [None]:
# Read the Shakespeare corpus into memory and display its length in characters.
# The file is opened with the platform's default text encoding.
with open("data/shakespeare.txt", "r") as f:
    data2 = f.read()
len(data2)

## Creating datasets

In [13]:
# Context length passed to SimpleTextDataset, read from the model's config.
# NOTE(review): `n_ctx` is presumably present only because the pretrained
# config file carries it; `n_positions` may be the more portable attribute —
# confirm against the installed transformers version.
context = model.config.n_ctx
# Batch size shared by every DataLoader created below.
batch_size = 8

In [14]:
# Dataset 1: Shakespeare text.
# NOTE(review): SimpleTextDataset / DatasetSplit are not defined in this
# notebook — presumably they come from the `process_data` star import; confirm.
paths = ['data/shakespeare.txt']

# Build the three splits in train / valid / test order.
train_ds, val_ds, test_ds = (
    SimpleTextDataset(paths, tokenizer, context, split)
    for split in (DatasetSplit.train, DatasetSplit.valid, DatasetSplit.test)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader_1 = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
val_loader_1, test_loader_1 = (
    DataLoader(held_out, batch_size=batch_size) for held_out in (val_ds, test_ds)
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1809728 > 1024). Running this sequence through the model will result in indexing errors


In [15]:
# Number of batches in the dataset-1 train / validation / test loaders.
len(train_loader_1), len(val_loader_1), len(test_loader_1)

(199, 11, 11)

In [16]:
# Dataset 2: code text. Note that `paths`, `train_ds`, `val_ds` and `test_ds`
# are rebound here, so after this cell they refer to the code corpus.
paths = ['data/code.txt']

# Build the three splits in train / valid / test order.
train_ds, val_ds, test_ds = (
    SimpleTextDataset(paths, tokenizer, context, split)
    for split in (DatasetSplit.train, DatasetSplit.valid, DatasetSplit.test)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader_2 = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
val_loader_2, test_loader_2 = (
    DataLoader(held_out, batch_size=batch_size) for held_out in (val_ds, test_ds)
)

In [17]:
# Number of batches in the dataset-2 train / validation / test loaders.
len(train_loader_2), len(val_loader_2), len(test_loader_2)

(294, 17, 17)

## Registering Parametrizations, Getting Masks

In [None]:
# register hooks with split fraction a
# active, masks_1, masks_2 = register_hooks(model, a=0.5)
controller, masks_1, masks_2 = register_parametrizations(model, a=0.5)

## Training loop

In [None]:
def evaluate(loader, model, controller, max_batches=None):
    """Return the element-weighted average loss of `model` over `loader`.

    The controller is switched to the "ALL" setting and the model is put in
    eval mode for the duration of the call. On exit the model is returned to
    whichever mode (train or eval) it was in when the function was called, so
    evaluating an already-eval model no longer silently re-enables train mode.

    Args:
        loader: iterable yielding `(x, y)` tensor pairs.
        model: module called as `model(input_ids=x, labels=y)`; the result
            must expose a scalar `.loss`.
        controller: object with a `set_active(name)` method. It is left on
            "ALL" after the call.
        max_batches: if given, stop after this many batches.

    Returns:
        Sum of `batch_loss * x.numel()` divided by the total `x.numel()`,
        or 0.0 when no batch was processed.

    Notes:
        Reads the notebook-level global `device`.
        NOTE(review): Hugging Face causal-LM heads shift `labels` internally;
        if the dataset's `y` is already `x` shifted by one position the
        targets would be shifted twice — confirm against SimpleTextDataset.
    """
    controller.set_active("ALL")
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    try:
        with torch.no_grad():
            for step, (x, y) in enumerate(loader, start=1):
                x, y = x.to(device), y.to(device)
                out = model(input_ids=x, labels=y)
                # `out.loss` is a batch mean, so weight it by the batch size
                # to keep a short final batch from being over-represented.
                n_tokens = x.numel()
                total_loss += out.loss.item() * n_tokens
                total_tokens += n_tokens

                if max_batches is not None and step >= max_batches:
                    break
    finally:
        # Restore the caller's mode even if a batch raised.
        model.train(was_training)

    return total_loss / max(1, total_tokens)   # per-token average loss

In [None]:
# Learning rate given to both Adam optimizers and echoed in the training log line.
lr = 3.5e-4
# NOTE(review): `clip` is not referenced by any cell visible in this notebook,
# so no gradient clipping is actually applied — confirm whether that is intended.
clip = 0.25
# Number of passes of the training loop (`range(1, epochs + 1)`).
epochs = 3
# Print a running-average log line every `log_interval` training steps.
log_interval = 5

In [None]:
# Collect the parameters each optimizer may update.
#
# The previous selector was the constant string "belongs_to_L1_tensor", which
# is always truthy: every masked parameter landed in `params_D1`, `params_D2`
# stayed empty, and `torch.optim.Adam([])` raised on an empty parameter list.
#
# NOTE(review): assignment is now by membership of the parameter name in
# `masks_1` / `masks_2`. This assumes both mask collections are keyed by the
# names produced by `model.named_parameters()` — confirm against
# `register_parametrizations`. If both collections cover the same tensors, the
# two optimizers intentionally share parameters and rely on the active mask to
# zero the gradients of the inactive subset.
params_D1, params_D2 = [], []
for name, p in model.named_parameters():
    if name in masks_1:
        params_D1.append(p)
    if name in masks_2:
        params_D2.append(p)

# Fail early with an actionable message instead of Adam's generic error.
if not params_D1 or not params_D2:
    raise ValueError(
        "No parameters matched masks_1 and/or masks_2 "
        f"(D1: {len(params_D1)}, D2: {len(params_D2)}); check that the mask keys "
        "match model.named_parameters() and that the parametrizations are registered."
    )

opt1 = torch.optim.Adam(params_D1, lr=lr)
opt2 = torch.optim.Adam(params_D2, lr=lr)

In [None]:
# ---------------------------------------------------------------------------
# Joint training driver.
#
# Each step draws one batch from each training loader and performs two
# independent updates: controller set to "L1" + opt1 on dataset 1, then
# controller set to "L2" + opt2 on dataset 2. After every epoch both validation
# sets are evaluated, checkpoints are written, and one row is appended to the
# epoch log. The per-step loss plot and raw step log are written at the end,
# including after a KeyboardInterrupt.
#
# Reads notebook globals: model, controller, opt1, opt2, device, lr, epochs,
# log_interval, the four train/val loaders, and `evaluate`.
# ---------------------------------------------------------------------------
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
CKPT_DIR = f"checkpoints/{ts}/"
os.makedirs(CKPT_DIR, exist_ok=True)

LOG_EPOCH = os.path.join(CKPT_DIR, "log.csv")
LOG_STEP = os.path.join(CKPT_DIR, "steps_raw.csv")

with open(LOG_EPOCH, "w") as f:
    f.write("epoch,elapsed_sec,train_loss1,train_ppl1,train_loss2,train_ppl2,val_loss1,val_ppl1,val_loss2,val_ppl2\n")

with open(LOG_STEP, "w") as f:
    f.write("global_step,loss1,loss2\n")

best_val_loss_1 = None
best_val_loss_2 = None
start_time = time.time()
global_step = 0

# Step-level and log-window accumulators. They live outside the `try` so the
# history gathered so far is still available if training is interrupted.
loss_steps1 = []
loss_steps2 = []
loss_win1   = []
loss_win2   = []
step_index  = []

try:
    for epoch in range(1, epochs + 1):
        print(f"\n===== Epoch {epoch} =====")
        model.train()
        t_batch = time.time()

        # epoch-level accumulators
        train_loss1_tokensum = 0.0
        train_loss2_tokensum = 0.0
        train_tokens1 = 0
        train_tokens2 = 0

        # zip() stops at the shorter loader, so this is the real step count.
        eff_train_len = min(len(train_loader_1), len(train_loader_2))

        print(f"{'ep':>3}{'step':>11}{'ms/b':>14}{'lr':>12}"
              f"{'loss1':>12}{'ppl1':>11}{'loss2':>12}{'ppl2':>11}")

        for step, ((x1, y1), (x2, y2)) in enumerate(zip(train_loader_1, train_loader_2), start=1):
            x1, y1 = x1.to(device), y1.to(device)
            x2, y2 = x2.to(device), y2.to(device)

            # pass 1 (dataset 1)
            controller.set_active("L1")
            opt1.zero_grad(set_to_none=True)
            out1 = model(x1, labels=y1)
            out1.loss.backward()
            opt1.step()

            # pass 2 (dataset 2)
            controller.set_active("L2")
            # Clear opt2's own gradients: zeroing opt1 here left whatever
            # pass 1 (and earlier steps) deposited on opt2's parameters in place.
            opt2.zero_grad(set_to_none=True)
            out2 = model(x2, labels=y2)
            out2.loss.backward()
            opt2.step()

            l1 = out1.loss.item()
            l2 = out2.loss.item()
            loss_steps1.append(l1)
            loss_steps2.append(l2)
            loss_win1.append(l1)
            loss_win2.append(l2)
            step_index.append(global_step + 1)

            train_loss1_tokensum += l1 * x1.numel()
            train_tokens1 += x1.numel()
            train_loss2_tokensum += l2 * x2.numel()
            train_tokens2 += x2.numel()

            global_step += 1

            if (step % log_interval) == 0:
                elapsed_ms = (time.time() - t_batch) * 1000.0 / log_interval
                avg1 = sum(loss_win1) / len(loss_win1)
                avg2 = sum(loss_win2) / len(loss_win2)

                print(f"{epoch:3d}{step:7d}/{eff_train_len:<6d}{elapsed_ms:11.2f}{lr:12.1e}"
                      f"{avg1:12.4f}{math.exp(avg1):11.2f}{avg2:12.4f}{math.exp(avg2):11.2f}")

                loss_win1.clear()
                loss_win2.clear()
                t_batch = time.time()

        # ---------- epoch end aggregates ----------
        epoch_train_loss1 = train_loss1_tokensum / max(1, train_tokens1)
        epoch_train_loss2 = train_loss2_tokensum / max(1, train_tokens2)
        epoch_train_ppl1 = math.exp(epoch_train_loss1)
        epoch_train_ppl2 = math.exp(epoch_train_loss2)

        # ---------- evaluation AFTER training ----------
        t0 = time.time()
        # Cap both evaluations at the same number of batches.
        eff_val_len = min(len(val_loader_1), len(val_loader_2))
        val_loss_1 = evaluate(val_loader_1, model, controller, max_batches=eff_val_len)
        val_loss_2 = evaluate(val_loader_2, model, controller, max_batches=eff_val_len)
        eval_time = time.time() - t0

        print('-' * 100)
        print(f'| epoch: {epoch:3d} | eval_time: {eval_time:.2f}s | '
              f'val1: {val_loss_1:.3f} (ppl: {math.exp(val_loss_1):.2f}) | '
              f'val2: {val_loss_2:.3f} (ppl: {math.exp(val_loss_2):.2f})')
        print('-' * 100)

        # ---------- checkpointing ----------
        # save latest model (overwritten every epoch)
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "val_loss_1": val_loss_1,
            "val_loss_2": val_loss_2,
            "train_loss_1": epoch_train_loss1,
            "train_loss_2": epoch_train_loss2,
        }, os.path.join(CKPT_DIR, "latest.pt"))

        # save best checkpoints
        best_1 = best_val_loss_1 is None or val_loss_1 < best_val_loss_1
        best_2 = best_val_loss_2 is None or val_loss_2 < best_val_loss_2

        # Update the running bests from the flags; comparing against the
        # initial None directly raised TypeError on the first epoch.
        if best_1:
            best_val_loss_1 = val_loss_1
        if best_2:
            best_val_loss_2 = val_loss_2

        if best_1 or best_2:
            torch.save({
                "model_state": model.state_dict(),
                "val_loss_1": val_loss_1,
                "val_loss_2": val_loss_2,
            }, os.path.join(CKPT_DIR, f"epoch_{epoch}.pt"))

        # ---------- epoch log row ----------
        elapsed_sec = time.time() - start_time
        with open(LOG_EPOCH, "a") as f:
            f.write(f"{epoch},{elapsed_sec:.2f},"
                    f"{epoch_train_loss1:.6f},{epoch_train_ppl1:.4f},"
                    f"{epoch_train_loss2:.6f},{epoch_train_ppl2:.4f},"
                    f"{val_loss_1:.6f},{math.exp(val_loss_1):.4f},"
                    f"{val_loss_2:.6f},{math.exp(val_loss_2):.4f}\n")

except KeyboardInterrupt:
    print('Graceful Exit')

# ---------- step-level artefacts (also written after an interrupt) ----------
# An interrupt can land between the consecutive appends above, so trim the
# three histories to a common length before plotting / writing.
n_logged = min(len(step_index), len(loss_steps1), len(loss_steps2))
if n_logged:
    plt.figure()
    plt.plot(step_index[:n_logged], loss_steps1[:n_logged], label="train loss (L1)")
    plt.plot(step_index[:n_logged], loss_steps2[:n_logged], label="train loss (L2)")
    plt.xlabel("global step")
    plt.ylabel("loss")
    plt.title("Training loss vs steps")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(CKPT_DIR, "loss_vs_steps.png"), dpi=150)
    print(f"Saved plot to {os.path.join(CKPT_DIR, 'loss_vs_steps.png')}")

    # Append below the header written at start-up; reopening with "w" erased it.
    with open(LOG_STEP, "a") as f:
        for s, a, b in zip(step_index, loss_steps1, loss_steps2):
            f.write(f"{s},{a:.6f},{b:.6f}\n")

## Testing Model

In [None]:
# Checkpoint to evaluate.
# NOTE(review): the training cell writes "latest.pt" and "epoch_<n>.pt" into a
# timestamped directory; it never writes "best.pt" — confirm this path exists
# (or point it at one of the files the training cell actually produces).
ckpt_path = "checkpoints/20251106_024828/best.pt"

print(f"\n[TEST] Loading checkpoint from: {ckpt_path}")
# torch.load unpickles the file: only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
# Accept either the dict layout saved by the training cell or a bare state dict.
state_dict = ckpt.get("model_state", ckpt)

# strict=False tolerates key mismatches, so report them instead of letting a
# partially-loaded model be evaluated silently.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"[TEST] WARNING: checkpoint/model key mismatch | "
          f"missing: {len(load_result.missing_keys)} {load_result.missing_keys[:5]} | "
          f"unexpected: {len(load_result.unexpected_keys)} {load_result.unexpected_keys[:5]}")
model.to(device)

In [None]:
# Evaluate both test sets on the same number of batches (the shorter loader's
# length) and report loss and perplexity for each.
eff_test_len = min(len(test_loader_1), len(test_loader_2))

test_loss_1, test_loss_2 = (
    evaluate(held_out, model, controller, max_batches=eff_test_len)
    for held_out in (test_loader_1, test_loader_2)
)
# Perplexity is the exponential of the average loss.
test_ppl_1, test_ppl_2 = math.exp(test_loss_1), math.exp(test_loss_2)

print(
    "-" * 100,
    f"| TEST @ {os.path.basename(ckpt_path)} | "
    f"loss1: {test_loss_1:.4f} (ppl {test_ppl_1:.2f}) | "
    f"loss2: {test_loss_2:.4f} (ppl {test_ppl_2:.2f})",
    "-" * 100,
    sep="\n",
)

----------------------------------------------------------------------------------------------------
| TEST @ best_2.pt | loss1: 2.8741 (ppl 17.71) | loss2: 0.6462 (ppl 1.91)
----------------------------------------------------------------------------------------------------
