# Chezz Fine‑Tuning Notebook

This notebook demonstrates how to prepare data and fine‑tune **TinyLlama** on chess move prediction using LoRA adapters, then continue training that adapter on a smaller humor dataset in a second stage.  
It is reorganized and fully annotated so you can reproduce the results end‑to‑end or adapt it to your own chess dataset.

Run the cells sequentially **top‑to‑bottom** on Google Colab (GPU recommended).  
If you have questions, feel free to comment on the GitHub repo!

---



## Table of Contents
1. [Runtime and libraries (6 min)](#runtime-and-libraries-6-min)
2. [Mount Drive for automatic resumability (1 min)](#mount-drive-for-automatic-resumability-1-min)
3. [Upload your 500k file](#upload-your-500k-file)
4. [Patch the tokenizer  ► Code](#patch-the-tokenizer-code)
5. [Load rows → prompt / completion pairs  ► Code](#load-rows-prompt-completion-pairs-code)
6. [Tokenise with masking  ► Code](#tokenise-with-masking-code)
7. [Build TinyLlama + LoRA adapters  ► Code](#build-tinyllama-lora-adapters-code)
8. [Training arguments & Trainer  ► Code](#training-arguments-trainer-code)
9. [(Opt.) peek at generation every 1000 steps  ► Code](#opt-peek-at-generation-every-1000-steps-code)
10. [Train (or resume)  ► Code](#train-or-resume-code)
11. [Save the first stage adapter](#save-the-first-stage-adapter)
12. [Test the finished checkpoint](#test-the-finished-checkpoint)
13. [Import libraries](#import-libraries)
14. [Paths & knobs](#paths-knobs)
15. [Tokenizer & special tokens](#tokenizer-special-tokens)
16. [Prompt header](#prompt-header)
17. [Helpers to build prompt/completion](#helpers-to-build-prompt-completion)
18. [Load & oversample datasets](#load-oversample-datasets)
19. [Token-level packing](#token-level-packing)
20. [Load previous LoRA checkpoint](#load-previous-lora-checkpoint)
21. [TrainingArguments](#trainingarguments)
22. [peek callback](#peek-callback)
23. [Trainer & launch](#trainer-launch)
24. [Test the finished checkpoint](#test-the-finished-checkpoint-1)
25. [Save to adapter stage 2](#save-to-adapter-stage-2)

## Runtime and libraries (6 min)

In [None]:
# Shell commands / package installation
!pip install \
  transformers \
  peft \
  accelerate \
  bitsandbytes \
  triton \
  datasets

## Mount Drive for automatic resumability (1 min)

In [None]:
# Imports
from google.colab import drive
drive.mount("/content/drive")
PROJECT_DIR = "/content/drive/MyDrive/chezz"


## Upload your 500k file

In [None]:
# If the file is already on Drive, copy it to the local disk (otherwise upload it to /content/):
!cp "/content/drive/MyDrive/chezz/data/train_500k.jsonl.gz" /content/


## Patch the tokenizer  ► Code

In [None]:
# Imports
from transformers import AutoTokenizer
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tok = AutoTokenizer.from_pretrained(MODEL_ID)
tok.pad_token    = tok.eos_token        # TinyLlama has no PAD by default
tok.padding_side = "right"

SPECIALS = {"additional_special_tokens": ["<|json|>", "</|json|>"]}
tok.add_special_tokens(SPECIALS)
end_json_id = tok.convert_tokens_to_ids("</|json|>")   # unused below: completions end with eos, so eos is what stops generate()


## Load rows → prompt / completion pairs  ► Code

In [None]:
# Imports
import json, re
from datasets import load_dataset

SCHEMA = """
{
  "from": "<square>",        # e.g. "e2"
  "to":   "<square>",        # e.g. "e4"
  "piece": "<piece>",        # "pawn","knight",…
  "explanation": "<text>",   # short rationale
  "taunt": "<text>"          # optional cheeky comment
}
"""


SYSTEM = (
    "<|system|> You are **ChezzBot-β**, a dry-humored, mildly anxious chess coach "
    "who is utterly certain every move you make is textbook-perfect. You always play as "
    "the side to move; your opponent is the other color whom you tease with sarcastic digs. "
    "Explain your move in one confident sentence (≤25 words) using real chess ideas, then "
    "taunt the user in ≤15 words of playful mockery. "
    "Respond *only* with JSON: "
    + SCHEMA + " "
)



text_re = re.compile(
    r"FEN:\s*(.*?)\s*Best Move JSON:\s*(\{.*\})",
    re.DOTALL
)

# ───────────────────────────────────────────────────────────────────────────────
# Build prompt/completion pairs
# ───────────────────────────────────────────────────────────────────────────────
def to_pairs(record):
    m = text_re.search(record.get("text",""))
    if not m:
        raise ValueError("Can't parse record")
    fen, json_str = m.groups()
    bm = json.loads(json_str)
    comp = {
        "from":        bm["from"].lower(),
        "to":          bm["to"].lower(),
        "piece":       bm["piece"].lower(),
        "explanation": bm.get("explanation",""),
        "taunt":       bm.get("taunt","")
    }
    comp_str = json.dumps(comp, separators=(",",":"))
    # Prompt = system header + user turn with the FEN + assistant turn opened by <|json|>
    prompt = (
        SYSTEM
      + "<|user|>FEN: " + fen + "\n\n"
      + "<|assistant|><|json|>"
    )
    completion = comp_str + tok.eos_token
    return {"prompt": prompt, "completion": completion}


# load & map
raw   = load_dataset("json", data_files="/content/train_500k.jsonl.gz", split="train")
pairs = raw.map(to_pairs, remove_columns=raw.column_names)
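To see what `text_re` and `to_pairs` expect, here is a minimal sketch that runs the same regex on a made-up record and serialises the completion the same compact way. The record text, FEN, and move below are hypothetical examples shaped to match the pattern, not rows from the real dataset.

```python
import json
import re

# Same pattern as text_re above: a FEN, then a JSON object after "Best Move JSON:".
text_re = re.compile(r"FEN:\s*(.*?)\s*Best Move JSON:\s*(\{.*\})", re.DOTALL)

# Hypothetical record, shaped to match the pattern.
record = {
    "text": "FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\n"
            'Best Move JSON: {"from": "E2", "to": "E4", "piece": "Pawn", '
            '"explanation": "Grab the centre.", "taunt": "Keep up."}'
}

m = text_re.search(record["text"])
fen, json_str = m.groups()
move = json.loads(json_str)

# Lower-case squares and piece, keep the free text as-is (as to_pairs does).
comp = {
    "from": move["from"].lower(),
    "to": move["to"].lower(),
    "piece": move["piece"].lower(),
    "explanation": move.get("explanation", ""),
    "taunt": move.get("taunt", ""),
}
# separators=(",", ":") drops the spaces, so the target JSON costs fewer tokens.
comp_str = json.dumps(comp, separators=(",", ":"))
print(fen)
print(comp_str)
```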

In [None]:
# after you’ve built your `pairs` dataset (with the richer prompts)…
import numpy as np

# helper to get length of any text
def tok_len(txt):
    return len(tok(txt)["input_ids"])

# get all prompt lengths
prompt_lens = [tok_len(p) for p in pairs["prompt"]]
# get all completion lengths (JSON + eos; tok_len also counts a BOS token)
completion_lens = [tok_len(c) for c in pairs["completion"]]

max_pr = np.max(prompt_lens)
max_co = np.max(completion_lens)
print("max prompt:", max_pr, "max completion:", max_co,
      "→ total:", max_pr + max_co)




MAX_LEN = max_pr + max_co + 5          # longest prompt + longest completion + 5 tokens of slack


## Tokenise with masking  ► Code

In [None]:
# Function definitions


def tok_fn(example):
    # Prepend BOS so training sequences match tok(prompt), which adds BOS at inference.
    pr_ids = [tok.bos_token_id] + tok(example["prompt"], add_special_tokens=False)["input_ids"]
    co_ids = tok(example["completion"], add_special_tokens=False)["input_ids"]

    # truncate the prompt *only if* needed, keeping the JSON completion intact
    take_pr = max(0, min(len(pr_ids), MAX_LEN - len(co_ids)))
    pr_ids  = pr_ids[len(pr_ids) - take_pr:]   # trim from the front; safe when take_pr == 0

    ids    = pr_ids + co_ids
    attn   = [1] * len(ids)
    labels = [-100] * len(pr_ids) + co_ids     # loss only on the completion

    # pad right (pad == eos, but padded positions are masked out of the loss)
    pad = [tok.pad_token_id] * (MAX_LEN - len(ids))
    ids    += pad
    attn   += [0] * len(pad)
    labels += [-100] * len(pad)

    return {"input_ids": ids, "attention_mask": attn, "labels": labels}

tok_ds = pairs.map(tok_fn, remove_columns=pairs.column_names)
train_ds, val_ds = tok_ds.train_test_split(test_size=5_000, seed=42).values()
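The layout `tok_fn` produces is easiest to see on toy ids. This is a minimal sketch on plain Python lists, assuming pad and eos share id 2 (as they do after `tok.pad_token = tok.eos_token`) and using a toy `MAX_LEN` of 10; `build_example` is a hypothetical stand-in, not part of the notebook.

```python
PAD_ID = 2    # assumption: pad == eos == 2, as set in the tokenizer cell
MAX_LEN = 10  # toy value; the notebook derives MAX_LEN from the data

def build_example(pr_ids, co_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Toy re-implementation of tok_fn's layout on plain id lists."""
    take_pr = max(0, min(len(pr_ids), max_len - len(co_ids)))
    pr_ids = pr_ids[len(pr_ids) - take_pr:]          # keep the *end* of the prompt
    ids = pr_ids + co_ids
    labels = [-100] * len(pr_ids) + co_ids           # no loss on prompt tokens
    n_pad = max_len - len(ids)
    return {
        "input_ids": ids + [pad_id] * n_pad,
        "attention_mask": [1] * len(ids) + [0] * n_pad,
        "labels": labels + [-100] * n_pad,           # no loss on padding either
    }

# Toy ids: 1 stands in for BOS, 101-103 the prompt, 201-202 the JSON, 2 the eos.
ex = build_example([1, 101, 102, 103], [201, 202, 2])
for key, row in ex.items():
    print(f"{key:>14}: {row}")
```

The eos at the end of the completion keeps its label (2), while the identical pad ids after it get -100, so the model still learns to stop even though pad and eos share an id.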


In [None]:
# Save to Drive so later sessions can skip tokenisation
tok_ds.save_to_disk('/content/drive/MyDrive/chezz/chezz_tok_ds')
train_ds.save_to_disk('/content/drive/MyDrive/chezz/chezz_train_ds')
val_ds.save_to_disk('/content/drive/MyDrive/chezz/chezz_val_ds')


In [None]:
# Load from disk
from datasets import load_from_disk

tok_ds   = load_from_disk('/content/drive/MyDrive/chezz/chezz_tok_ds')
train_ds = load_from_disk('/content/drive/MyDrive/chezz/chezz_train_ds')
val_ds   = load_from_disk('/content/drive/MyDrive/chezz/chezz_val_ds')

## Build TinyLlama + LoRA adapters  ► Code

In [None]:
# Imports
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

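base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto"
)

# Grow the embeddings for <|json|> / </|json|> before wrapping with PEFT.
# The new rows are not LoRA targets, so they stay at their initial values; add
# modules_to_save=["embed_tokens", "lm_head"] to LoraConfig to train and save them.
base.resize_token_embeddings(len(tok))
base.config.pad_token_id = tok.pad_token_id

lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj"],
    task_type="CAUSAL_LM"
)
model = get_peft_model(base, lora_cfg)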

model.print_trainable_parameters()
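As a cross-check on what `print_trainable_parameters()` reports, the LoRA parameter count can be worked out by hand: wrapping a `Linear(in, out)` adds an `r × in` matrix A and an `out × r` matrix B, i.e. `r * (in + out)` parameters. The TinyLlama dimensions below (hidden 2048, intermediate 5632, 22 layers, 32 query heads, 4 key/value heads) are assumed from the model's published config; verify them against `model.config`.

```python
# Assumed TinyLlama-1.1B dimensions; verify against model.config.
HIDDEN = 2048
INTERMEDIATE = 5632
N_LAYERS = 22
HEAD_DIM = HIDDEN // 32          # 32 attention heads -> 64
KV_DIM = 4 * HEAD_DIM            # 4 key/value heads (grouped-query attention)

R = 16  # lora_cfg.r

# (in_features, out_features) of each targeted projection in one decoder layer
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
}

# LoRA adds A (r x in) and B (out x r) per wrapped Linear: r * (in + out) params.
per_layer = sum(R * (i + o) for i, o in shapes.values())
total = per_layer * N_LAYERS
print(f"per layer: {per_layer:,}  total: {total:,}")
print(f"adapter size in fp32: {total * 4 / 1e6:.1f} MB")
```

If the assumed config is right, the trainable count should come out near 7.2M, which also tells you roughly how large the saved adapter will be.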


## Training arguments & Trainer  ► Code

In [None]:
# Imports
from transformers import TrainingArguments, Trainer, default_data_collator

TOTAL_STEPS   = 46_000          # your target budget
WARMUP_STEPS  = 500             # ~1.1 % of the run
PEAK_LR       = 5e-5            # LR right after warm-up
BATCH_SIZE    = 32              # per-device
GRAD_ACCUM    = 1               # effective batch = 32

args = TrainingArguments(
    output_dir                  = PROJECT_DIR,
    max_steps                   = TOTAL_STEPS,    # use steps, not epochs
    per_device_train_batch_size = BATCH_SIZE,
    gradient_accumulation_steps = GRAD_ACCUM,
    learning_rate               = PEAK_LR,
    lr_scheduler_type           = "cosine",       # maps to get_cosine_schedule_with_warmup
    warmup_steps                = WARMUP_STEPS,   # explicit beats warmup_ratio
    fp16                        = True,
    tf32                        = True,           # needs an Ampere+ GPU (e.g. A100, L4); drop on a T4
    optim                       = "adamw_torch_fused",
    logging_steps               = 20,
    save_steps                  = 1000,           # must be a multiple of eval_steps
    eval_strategy               = "steps",
    eval_steps                  = 1000,
    load_best_model_at_end      = True,
    metric_for_best_model       = "eval_loss",
    label_names                 = ["labels"],     # PeftModel hides the base forward signature
    report_to                   = [],
)


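# Optional: sequence-level exact match on the completion tokens. Not passed to the
# Trainer below; to use it, add compute_metrics=compute_metrics and
# preprocess_logits_for_metrics=lambda logits, labels: logits.argmax(-1).
def compute_metrics(pred):
    preds, labels = pred.predictions[:, :-1], pred.label_ids[:, 1:]  # logits at t predict token t+1
    mask = labels != -100                                            # completion tokens only
    em = ((preds == labels) | ~mask).all(axis=1).mean()
    return {"exact_match": float(em)}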

trainer = Trainer(
    model          = model,
    args           = args,
    train_dataset  = train_ds,
    eval_dataset   = val_ds,
    data_collator  = default_data_collator
)
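To see what `lr_scheduler_type="cosine"` with `warmup_steps=500` does over the 46 000 steps, here is a minimal sketch of the warm-up plus half-cosine formula, written to mirror (as far as I understand it) `transformers.get_cosine_schedule_with_warmup` with its default `num_cycles=0.5`. It is a re-implementation for inspection, not the library code.

```python
import math

TOTAL_STEPS = 46_000
WARMUP_STEPS = 500
PEAK_LR = 5e-5

def lr_at(step, total=TOTAL_STEPS, warmup=WARMUP_STEPS, peak=PEAK_LR):
    """Linear warm-up to `peak`, then half-cosine decay to 0 at `total`."""
    if step < warmup:
        return peak * step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return peak * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

for s in (0, 250, 500, 23_250, 46_000):
    print(f"step {s:>6}: lr = {lr_at(s):.2e}")
```

The LR reaches half its peak midway through the decay (step 23 250 here) and hits 0 at `max_steps`, which is why resuming with a different `max_steps` changes the remaining schedule.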


## (Opt.) peek at generation every 1000 steps  ► Code

In [None]:
# Imports
import torch
from transformers import TrainerCallback


class Peek(TrainerCallback):
    # on_step_end fires once per optimizer step; on_log would fire twice at
    # step 1000 (training log + eval log).
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step and state.global_step % 1000 == 0:
            model.eval()

            # 1) Build the prompt exactly as in training: it ends with the
            #    opening <|json|> tag and the model writes the JSON after it
            prompt = (
                SYSTEM
              + "<|user|>FEN: R7/8/5pk1/8/5r1p/7K/5P2/8 w - -\n\n"
              + "<|assistant|><|json|>"
            )

            # 2) Tokenize & send to device (single prompt, no padding needed)
            inputs = tok(prompt, return_tensors="pt").to(model.device)

            # 3) Greedy decoding under inference mode; use_cache=False trades
            #    speed for slightly lower memory
            with torch.inference_mode():
                out = model.generate(
                    **inputs,
                    max_new_tokens=MAX_LEN,     # large enough for the JSON
                    do_sample=False,            # greedy
                    pad_token_id=tok.pad_token_id,
                    use_cache=False             # disables the KV cache
                )

            # 4) Decode & keep only the text after the opening <|json|> tag
            raw = tok.decode(out[0], skip_special_tokens=False)
            payload = raw.split("<|json|>", 1)[-1]

            print(f"\n[Sample @ step {state.global_step}]\n{payload}\n")

            # 5) Return to training mode & free up GPU memory
            model.train()
            torch.cuda.empty_cache()

# register it
trainer.add_callback(Peek)
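If you want the peek (or an evaluation loop) to show a parsed object rather than raw text, a small extraction helper is enough. This is a hypothetical sketch, not part of the original notebook: it assumes the text was decoded with `skip_special_tokens=False` (as in the callback above) so the `<|json|>` marker survives, and it only looks after that marker because `SYSTEM` itself contains a brace-delimited schema.

```python
import json

def extract_json(text, marker="<|json|>"):
    """Hypothetical helper: parse the first JSON object after `marker`.

    Returns None if the marker is missing or the JSON is incomplete.
    """
    pos = text.find(marker)
    if pos < 0:
        return None
    tail = text[pos + len(marker):]
    brace = tail.find("{")
    if brace < 0:
        return None
    try:
        # raw_decode stops at the end of the object and ignores trailing text
        obj, _ = json.JSONDecoder().raw_decode(tail, brace)
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None

sample = ('<|system|> Respond only with JSON: {"from": "<square>"} '
          '<|user|>FEN: 8/8/8/8/8/8/8/4K3 w - -\n\n'
          '<|assistant|><|json|>{"from":"e1","to":"e2","piece":"king",'
          '"explanation":"Centralise the king.","taunt":"Cute."}</s></s>')
print(extract_json(sample)["to"])   # e2
```

`raw_decode` is used instead of `json.loads` so trailing `</s>` tokens or padding after the object do not cause a parse failure.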

## Train (or resume)  ► Code

In [None]:
# First run:
# trainer.train()

# After a disconnect, resume from a specific checkpoint
# (or pass resume_from_checkpoint=True to pick the latest one in output_dir):
last = "/content/drive/MyDrive/chezz/checkpoint-21000"
trainer.train(resume_from_checkpoint=last)
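Hard-coding `checkpoint-21000` means editing the cell after every disconnect. A minimal sketch of picking the newest checkpoint automatically is below; `latest_checkpoint` is a hypothetical helper (transformers also ships `transformers.trainer_utils.get_last_checkpoint` for the same job). The key detail is comparing the step number as an integer: a plain string sort would rank `checkpoint-9000` above `checkpoint-21000`.

```python
import re
import tempfile
from pathlib import Path

def latest_checkpoint(output_dir):
    """Return the checkpoint-N folder with the largest N, or None (hypothetical helper)."""
    pattern = re.compile(r"^checkpoint-(\d+)$")
    best = None
    for p in Path(output_dir).iterdir():
        m = pattern.match(p.name)
        if p.is_dir() and m and (best is None or int(m.group(1)) > best[0]):
            best = (int(m.group(1)), p)
    return str(best[1]) if best else None

# Demo on a throwaway directory laid out like PROJECT_DIR.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("checkpoint-1000", "checkpoint-9000", "checkpoint-21000", "runs"):
        (Path(tmp) / name).mkdir()
    found = latest_checkpoint(tmp)
    print(Path(found).name)   # checkpoint-21000
```

In the notebook this would become `trainer.train(resume_from_checkpoint=latest_checkpoint(PROJECT_DIR))`.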


## Save the first stage adapter

In [None]:
# Imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

MODEL_ID    = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
BASE_ID   = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
CHKPT_DIR = "/content/drive/MyDrive/chezz/checkpoint-46000"   # your stage-A folder

tok = AutoTokenizer.from_pretrained(MODEL_ID)
tok.pad_token    = tok.eos_token
tok.padding_side = "right"
tok.add_special_tokens({
    "additional_special_tokens": ["<|json|>", "</|json|>"]
})



# 1. load base weights in fp16 on the CPU (halves RAM; no GPU needed just to export)
base  = AutoModelForCausalLM.from_pretrained(
    BASE_ID, torch_dtype=torch.float16, device_map="cpu"
)

base.resize_token_embeddings(len(tok))   # 32 002 rows, matching the training vocab

# 2. attach the LoRA adapter that lives inside the checkpoint
skel  = PeftModel.from_pretrained(base, CHKPT_DIR)

# 3. write only the adapter weights + config to disk
#    (about 7.2M LoRA params, roughly 15-30 MB depending on dtype)
skel.save_pretrained("/content/drive/MyDrive/chezz/adapters/adapter_chezz_move")


## Test the finished checkpoint

In [None]:
# Test the finished checkpoint here

# Imports
import time, json, re, torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# ───────── Config ─────────
MODEL_ID    = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
CHECKPOINT  = "/content/drive/MyDrive/chezz/checkpoint-46000"
TEST_FILE   = "/content/drive/MyDrive/chezz/data/test_5k.jsonl"
TEST_SIZE    = 1000
BATCH_SIZE   = 64
MAX_LEN     = 80          # max new tokens per answer (overwrites the training MAX_LEN)
DEVICE      = "cuda"
OUT_PATH    = "/content/predictions.jsonl"

# ───────── Tokenizer ─────────
tok = AutoTokenizer.from_pretrained(MODEL_ID)
tok.pad_token    = tok.eos_token
tok.padding_side = "left"             # batched generation with a decoder-only model needs left padding
tok.add_special_tokens({
    "additional_special_tokens": ["<|json|>", "</|json|>"]
})

# ───────── Model + LoRA ─────────
base  = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)
base.resize_token_embeddings(len(tok))
model = PeftModel.from_pretrained(base, CHECKPOINT).to(DEVICE)
model.eval()

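# ───────── Prompt context ─────────
# NOTE: this is a shortened header; the model was trained with the full SYSTEM
# string above, so reusing that exact string keeps eval prompts in-distribution.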
SCHEMA = '{"from":"","to":"","piece":"","explanation":"","taunt":""}'
SYSTEM = (
    "<|system|> You are ChezzBot-β, a dry-humored, mildly anxious chess coach "
    "… Respond only with JSON matching this schema: "
    + SCHEMA + " "
)

# ───────── Load test set ─────────
ds     = load_dataset("json", data_files=TEST_FILE, split="train")
fens   = ds["fen"][:TEST_SIZE]
truths = ds["best_move_json"][:TEST_SIZE]

# ───────── Prepare output file ─────────
out_f = open(OUT_PATH, "w")

# ───────── Batched eval + logging ─────────
total = correct = 0
start = time.time()

for i in range(0, len(fens), BATCH_SIZE):     # len(fens) handles test files smaller than TEST_SIZE
    batch_fens   = fens[i : i + BATCH_SIZE]
    batch_truths = truths[i : i + BATCH_SIZE]

    prompts = [
        SYSTEM
      + "<|user|>FEN: " + fen + "\n\n"
      + "<|assistant|><|json|>"
        for fen in batch_fens
    ]
    inputs = tok(prompts, return_tensors="pt", padding=True, truncation=True).to(DEVICE)

    with torch.inference_mode():
        outs = model.generate(
            **inputs,
            max_new_tokens=MAX_LEN,
            do_sample=False,
            pad_token_id=tok.pad_token_id,
            use_cache=True
        )
    raws = tok.batch_decode(outs, skip_special_tokens=True)

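    for fen, raw, true in zip(batch_fens, raws, batch_truths):
        # 1) Keep only the generated part: skip_special_tokens=True already
        #    dropped <|json|>, so the text after <|assistant|> is the JSON.
        result = raw.split("<|assistant|>", 1)[-1].strip()
        print(result)

        # 2) Parse JSON; on failure fall back to the first {...} in the
        #    completion (searching `raw` would hit the schema inside SYSTEM).
        try:
            pred = json.loads(result)
        except json.JSONDecodeError:
            m = re.search(r"\{.*?\}", result, re.S)
            try:
                pred = json.loads(m.group(0)) if m else {}
            except json.JSONDecodeError:
                pred = {}
        if not isinstance(pred, dict):
            pred = {}

        # 3) Score: correct if both squares match (case-insensitive)
        is_corr = (str(pred.get("from", "")).lower() == str(true.get("from", "")).lower()
                   and str(pred.get("to", "")).lower() == str(true.get("to", "")).lower())
        total   += 1
        correct += int(is_corr)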

        # 4) Write to file
        out_f.write(json.dumps({
            "fen": fen,
            "pred": pred,
            "true": true,
            "correct": is_corr
        }) + "\n")

# ───────── Wrap up ─────────
duration = time.time() - start
out_f.close()

print(f"\nMove-accuracy on {total} positions: {correct/total:.2%}")
print(f"Total GPU eval time: {duration:.1f}s  ({duration/total:.3f}s per pos)")
print(f"Predictions written to {OUT_PATH}")
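Why `padding_side = "left"` matters for the batched `generate()` calls: a decoder-only model appends each new token after the last column of every row, so with right padding a shorter prompt's continuation starts after pad tokens instead of after its own last token (transformers typically logs a warning about right-padding in this situation). A toy sketch with made-up ids; `pad_batch` is a hypothetical helper, not the tokenizer's implementation.

```python
PAD = 0  # hypothetical pad id for the toy example

def pad_batch(seqs, side, pad=PAD):
    """Pad token-id lists to the longest length on `side` ("left" or "right")."""
    width = max(len(s) for s in seqs)
    ids, mask = [], []
    for s in seqs:
        n = width - len(s)
        if side == "left":
            ids.append([pad] * n + s)
            mask.append([0] * n + [1] * len(s))
        else:
            ids.append(s + [pad] * n)
            mask.append([1] * len(s) + [0] * n)
    return ids, mask

prompts = [[5, 6, 7, 8], [5, 6]]   # two toy prompts of different lengths
for side in ("right", "left"):
    ids, _ = pad_batch(prompts, side)
    # generate() continues from the last column, so this is the token each
    # row's first new token is conditioned on most directly:
    print(side, [row[-1] for row in ids])
# right [8, 0]
# left [8, 6]
```

Training can keep right padding because the loss masks pads; only generation needs the real tokens flush against the right edge.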


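## Import libraries

Stage 2 takes the finished stage-1 LoRA checkpoint and keeps training it on the humor dataset (`final_humor.jsonl`), mixed with a slice of the original 500k move data to limit forgetting.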

In [None]:
# Imports
import json, re, torch, numpy as np
from datasets     import load_dataset, concatenate_datasets
from transformers import (AutoTokenizer,
                          TrainingArguments, Trainer,
                          default_data_collator)

## Paths & knobs

In [None]:
# Code execution
CKPT_DIR  = "/content/drive/MyDrive/chezz/stage_moves/checkpoint-46000"   # finished stage-1 run
NEW_PATH  = "/content/drive/MyDrive/chezz/data/final_humor.jsonl"         # new data
OLD_PATH  = "/content/drive/MyDrive/chezz/data/train_500k.jsonl.gz"       # old data (only a slice is used)

BOOST_NEW    = 10          # oversample factor; share of each batch = 10·N_new / (10·N_new + KEEP_OLD_N)
KEEP_OLD_N   = 50_000      # how many old rows to keep for anti-forgetting
EXTRA_STEPS  = 4_000
PEAK_LR      = 2e-5
BATCH_SIZE   = 8           # per-GPU (fp16 fits on ≤16 GB)

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

## Tokenizer & special tokens

In [None]:
# Tokenizer setup
tok = AutoTokenizer.from_pretrained(MODEL_ID)
tok.pad_token    = tok.eos_token
tok.padding_side = "right"
tok.add_special_tokens({"additional_special_tokens": ["<|json|>", "</|json|>"]})
end_json_id = tok.convert_tokens_to_ids("</|json|>")

## Prompt header

In [None]:
# Code execution
SCHEMA = """
{
  "from": "<square>",
  "to": "<square>",
  "piece": "<piece>",
  "explanation": "<text>",
  "taunt": "<text>"
}
"""

SYSTEM = (
    "<|system|> You are **ChezzBot-β**, a dry-humored, mildly anxious chess coach "
    "who is utterly certain every move you make is textbook-perfect. You always play as "
    "the side to move; your opponent is the other color whom you tease with sarcastic digs. "
    "Explain your move in one confident sentence (≤25 words) using real chess ideas, then "
    "taunt the user in ≤15 words of playful mockery. Respond *only* with JSON: "
    + SCHEMA + " "
)

## Helpers to build prompt/completion

In [None]:
# Function definitions
text_re = re.compile(r"FEN:\s*(.*?)\s*Best Move JSON:\s*(\{.*\})", re.S)

def build_pair(fen, move):
    comp = {
        "from":        move["from"].lower(),
        "to":          move["to"].lower(),
        "piece":       move["piece"].lower(),
        "explanation": move.get("explanation", ""),
        "taunt":       move.get("taunt", "")
    }
    comp_str = json.dumps(comp, separators=(",",":"))
    prompt = SYSTEM + "<|user|>FEN: " + fen + "\n\n<|assistant|><|json|>"
    return {"prompt": prompt, "completion": comp_str + tok.eos_token}

def to_pairs_old(rec):
    m = text_re.search(rec["text"])
    if not m:
        raise ValueError("Can't parse record")
    fen, js = m.groups()
    return build_pair(fen, json.loads(js))

def to_pairs_new(rec):
    return build_pair(rec["fen"], rec["best_move_json"])

## Load & oversample datasets

In [None]:
# Dataset loading
new_raw   = load_dataset("json", data_files=NEW_PATH, split="train")
new_pairs = new_raw.map(to_pairs_new, remove_columns=new_raw.column_names)

# oversample
new_pairs_boost = concatenate_datasets([new_pairs] * BOOST_NEW)

# optional small slice of the old dataset for stability
old_raw   = load_dataset("json", data_files=OLD_PATH, split=f"train[:{KEEP_OLD_N}]")
old_pairs = old_raw.map(to_pairs_old, remove_columns=old_raw.column_names)

pairs = concatenate_datasets([old_pairs, new_pairs_boost]).shuffle(seed=42)
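How much of each shuffled batch comes from the humor data depends on how many rows `final_humor.jsonl` has, not on `BOOST_NEW` alone. A minimal sketch of the arithmetic; the file sizes below are hypothetical.

```python
def new_share(n_new, boost, n_old):
    """Expected fraction of rows (hence of a shuffled batch) that come from the new data."""
    boosted = n_new * boost
    return boosted / (boosted + n_old)

BOOST_NEW, KEEP_OLD_N = 10, 50_000   # values from "Paths & knobs"
for n_new in (500, 5_000, 20_000):   # hypothetical sizes of final_humor.jsonl
    share = new_share(n_new, BOOST_NEW, KEEP_OLD_N)
    print(f"{n_new:>6} new rows -> {share:.0%} of each batch")
```

Solving for a 10 % share gives `n_new = 50_000 * 0.1 / (10 * 0.9)`, roughly 556 rows; a larger humor file with the same `BOOST_NEW` tilts the mix much further toward the new data.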

## Token-level packing

In [None]:
# Function definitions
def tok_len(txt): return len(tok(txt)["input_ids"])

MAX_LEN = max(map(tok_len, pairs["prompt"])) + \
          max(map(tok_len, pairs["completion"])) + 5

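def tok_fn(ex):
    # Prepend BOS so training matches tok(prompt) at inference time.
    pr_ids = [tok.bos_token_id] + tok(ex["prompt"], add_special_tokens=False)["input_ids"]
    co_ids = tok(ex["completion"], add_special_tokens=False)["input_ids"]

    # Trim the prompt from the front only if needed (safe when take_pr == 0).
    take_pr = max(0, min(len(pr_ids), MAX_LEN - len(co_ids)))
    pr_ids  = pr_ids[len(pr_ids) - take_pr:]

    ids    = pr_ids + co_ids
    attn   = [1]*len(ids)
    labels = [-100]*len(pr_ids) + co_ids      # loss only on the completion

    # Pad on the right to MAX_LEN (no sequence packing happens here).
    pad = [tok.pad_token_id]*(MAX_LEN-len(ids))
    ids += pad; attn += [0]*len(pad); labels += [-100]*len(pad)
    return {"input_ids": ids, "attention_mask": attn, "labels": labels}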

tok_ds  = pairs.map(tok_fn, remove_columns=pairs.column_names,
                    num_proc=4, desc="pack")
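# NOTE: the split runs after oversampling, so copies of the same new row can land
# in both splits, which can make eval_loss look optimistic.
train_ds, val_ds = tok_ds.train_test_split(test_size=500, seed=42).values()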

## Load previous LoRA checkpoint

In [None]:
# Imports
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft          import PeftModel

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token    = tokenizer.eos_token
tokenizer.padding_side = "right"
tokenizer.add_special_tokens({"additional_special_tokens": ["<|json|>", "</|json|>"]})

base = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype="auto",
            device_map="auto"
        )
base.resize_token_embeddings(len(tokenizer))      # ← 32 002

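# is_trainable=True matters: PeftModel.from_pretrained loads adapters frozen
# (inference mode) by default, which would leave the Trainer nothing to update.
model = PeftModel.from_pretrained(base, CKPT_DIR, is_trainable=True)
model.print_trainable_parameters()  # sanity check: trainable params should be non-zero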


## TrainingArguments

In [None]:
# Code execution
args = TrainingArguments(
    output_dir                  = CKPT_DIR,       # new checkpoints land in CKPT_DIR/checkpoint-N
    max_steps                   = EXTRA_STEPS,
    per_device_train_batch_size = BATCH_SIZE,
    learning_rate               = PEAK_LR,
    lr_scheduler_type           = "cosine",
    warmup_steps                = 200,
    fp16                        = True,
    tf32                        = True,           # needs an Ampere+ GPU; drop on a T4

    # ─── logging / saving / eval cadence ─────────────────────────
    logging_steps               = 20,
    save_steps                  = 500,            # must be a multiple of eval_steps
    eval_strategy               = "steps",
    eval_steps                  = 250,
    load_best_model_at_end      = True,
    metric_for_best_model       = "eval_loss",

    label_names                 = ["labels"],     # PeftModel hides the base forward signature
    report_to                   = [],
)


## peek callback

In [None]:
# Imports
from transformers import TrainerCallback

class Peek(TrainerCallback):
    # on_step_end runs once per step; on_log would also fire for the eval log at step 500
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step and state.global_step % 500 == 0:
            model.eval()
            prompt = SYSTEM + (
                "<|user|>FEN: 8/8/8/8/8/8/8/4K3 w - -\n\n"
                "<|assistant|><|json|>"
            )
            with torch.inference_mode():
                out = model.generate(
                    **tok(prompt, return_tensors="pt").to(model.device),
                    max_new_tokens=128, do_sample=False,
                    pad_token_id=tok.pad_token_id, use_cache=False
                )
            print("\n[peek]", tok.decode(out[0]))
            model.train()

peek = Peek()

## Trainer & launch

In [None]:
# Training setup
trainer = Trainer(
    model          = model,
    args           = args,
    train_dataset  = train_ds,
    eval_dataset   = val_ds,
    data_collator  = default_data_collator,
    callbacks      = [peek]
)

trainer.train()        # fresh 4 000-step run starting from the stage-1 adapter weights (optimizer and LR schedule start over)

## Test the finished checkpoint

In [None]:
# ONE-CELL 1 000-FEN BENCHMARK 

import json, re, textwrap, torch, gc, time
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

BATCH       = 32                         # adjust to GPU RAM (16-32 on a 16 GB T4)
MAX_NEW     = 64                         # JSON rarely needs more
MODEL_ID    = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
CKPT_DIR    = "/content/drive/MyDrive/chezz/stage_moves/checkpoint-46000/checkpoint-4000"  # final stage-2 checkpoint (written inside the stage-1 folder)
DATA_PATH   = "/content/drive/MyDrive/chezz/data/final_humor.jsonl"

# ── Load tokenizer ─────────────────────────────────────────────────────────
tok = AutoTokenizer.from_pretrained(MODEL_ID)
tok.pad_token    = tok.eos_token
tok.padding_side = "left"             # batched generation with a decoder-only model needs left padding
tok.add_special_tokens({"additional_special_tokens": ["<|json|>", "</|json|>"]})

# ── Load base + best adapter, keep entire model on GPU in fp16 if possible ─
base  = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
base.resize_token_embeddings(len(tok))
model = PeftModel.from_pretrained(base, CKPT_DIR).half().to("cuda").eval()

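# ── Build prompts for the first 1 000 rows in one go ───────────────────────
# NOTE: this SYSTEM is a shortened header; reusing the exact training SYSTEM
# string keeps the benchmark prompts closer to what the model saw in training.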
SYSTEM = textwrap.dedent("""\
<|system|>You are **ChezzBot-β**, a dry-humored, mildly anxious chess coach…
Respond *only* with JSON like:
{"from":"<square>","to":"<square>","piece":"<piece>","explanation":"<text>","taunt":"<text>"} """)

ds       = load_dataset("json", data_files=DATA_PATH, split="train[:1000]")
prompts  = [SYSTEM + f"<|user|>FEN: {r['fen']}\n\n<|assistant|><|json|>" for r in ds]

# ── Tokenise once (padding to longest) ─────────────────────────────────────
enc = tok(prompts, return_tensors="pt", padding=True).to("cuda")

# ── Batched generation ─────────────────────────────────────────────────────
hits = 0
start = time.time()

for i in range(0, len(prompts), BATCH):
    batch_in = {k: v[i:i+BATCH] for k, v in enc.items()}
    with torch.inference_mode():
        out = model.generate(
            **batch_in,
            max_new_tokens=MAX_NEW,
            do_sample=True,
            pad_token_id=tok.pad_token_id
        )

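    decoded = tok.batch_decode(out, skip_special_tokens=True)
    for txt in decoded:
        completion = txt.split("<|assistant|>", 1)[-1].strip()
        print(completion)
        # Count a hit only if the completion parses as JSON with non-empty
        # "explanation" and "taunt" (assumed to be the two fields reported below).
        try:
            obj, _ = json.JSONDecoder().raw_decode(completion)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict) and obj.get("explanation") and obj.get("taunt"):
            hits += 1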

torch.cuda.empty_cache(); gc.collect()
elapsed = time.time() - start
print(f"\n{hits}/{len(ds)} outputs had both fields.  Time: {elapsed:.1f} s")
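The "outputs had both fields" figure is only meaningful if each output is actually checked. Here is a standalone sketch of such a check, assuming the two fields of interest are `explanation` and `taunt` (the fields stage 2 adds humor to); `has_fields` and `REQUIRED` are hypothetical names, and the sample outputs are made up.

```python
import json

REQUIRED = ("explanation", "taunt")   # assumption: the two fields the benchmark counts

def has_fields(completion, required=REQUIRED):
    """True if `completion` starts with a JSON object whose required fields are non-empty strings."""
    try:
        obj, _ = json.JSONDecoder().raw_decode(completion.strip())
    except json.JSONDecodeError:
        return False
    return isinstance(obj, dict) and all(
        isinstance(obj.get(k), str) and obj[k].strip() != "" for k in required
    )

outputs = [
    '{"from":"e2","to":"e4","piece":"pawn","explanation":"Centre.","taunt":"Yawn."}',
    '{"from":"e2","to":"e4","piece":"pawn","explanation":"","taunt":"Yawn."}',
    '{"from":"e2","to":',   # truncated by max_new_tokens
]
hits = sum(has_fields(o) for o in outputs)
print(f"{hits}/{len(outputs)} outputs had both fields")  # 1/3 outputs had both fields
```

Truncated JSON is the most likely failure with `MAX_NEW = 64`, so a low hit rate is a hint to raise the token budget before blaming the adapter.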


## Save to adapter stage 2

In [None]:
# Imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch, pathlib

# ── paths ───────────────────────────────────────────────────────────
BASE_ID   = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"          # base model
CHKPT_DIR = "/content/drive/MyDrive/chezz/stage_moves/checkpoint-46000/checkpoint-4000"  # folder with adapter_config.json + adapter_model.safetensors (or .bin)
DEST_DIR  = "/content/drive/MyDrive/chezz/adapters/adapter_exp_taunt"

# ── tokenizer (must include the 2 extra tokens) ────────────────────
tok = AutoTokenizer.from_pretrained(BASE_ID)
tok.pad_token    = tok.eos_token
tok.padding_side = "right"
tok.add_special_tokens({"additional_special_tokens": ["<|json|>", "</|json|>"]})

# ── 1. load the base weights (CPU or GPU, fp16 saves RAM) ──────────
base = AutoModelForCausalLM.from_pretrained(
           BASE_ID,
           torch_dtype=torch.float16,     # fp16 ≈ 2 GB on CPU
           device_map="cpu"
       )
base.resize_token_embeddings(len(tok))    # now 32 002 tokens

# ── 2. attach the LoRA adapter from your checkpoint ────────────────
model = PeftModel.from_pretrained(base, CHKPT_DIR)

# (optional) sanity check: the adapter is loaded for inference (frozen), so this
# reports 0 trainable params, which is expected when only exporting
model.print_trainable_parameters()

# ── 3. save just the adapter (weights + config) ────────────────────
pathlib.Path(DEST_DIR).mkdir(parents=True, exist_ok=True)
model.save_pretrained(DEST_DIR)
print("✓ adapter exported to", DEST_DIR)
