In [1]:
from data.utils import prepare_data
from tokenization.utils import build_asin_id_tokenizer

In [2]:
ds_train, ds_valid, ds_test, asin2id = prepare_data("All_Beauty")
tokenizer = build_asin_id_tokenizer(asin2id)


Map:   0%|          | 0/285 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Map:   0%|          | 0/37 [00:00<?, ? examples/s]

In [5]:
ds_train[2]


{'prompt': '53 257 52 289 407 181 220 68 260 428 338 400 439 400 195 422 416 280 122 248 315 343 101 200 457 119 450',
 'completion': '59'}

In [4]:
def print_model_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable ratio: {100 * trainable_params / total_params:.2f}%")


In [None]:
# tiny_qwen3_addition_for_prompts.py
# Same training loop, now for datasets with {prompt, completion} and a prebuilt tokenizer.

import os, random
import numpy as np
import torch

from transformers import (
    Qwen3Config,
    Qwen3ForCausalLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    set_seed,
)

from datasets import DatasetDict

# -------------------------
# 0) Repro
# -------------------------
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); set_seed(SEED)

# You already have:
# ds_train, ds_valid, ds_test, asin2id = prepare_data("All_Beauty")
# tokenizer = build_asin_id_tokenizer(asin2id)

# --- helper: robust encode for str OR list (ints/strs) ---
def _encode_field(value, tok):
    # If it's already a string: standard encode
    if isinstance(value, str):
        return tok(value, add_special_tokens=False)

    # If it's a list (e.g., [123, 45, 6] or ["B08..", "B07.."]):
    # 1) cast ints -> strings
    # 2) tell HF these are pre-tokenized "words"
    if isinstance(value, list):
        tokens = [str(x) for x in value]
        return tok(tokens, add_special_tokens=False, is_split_into_words=True)

    # Anything else is unsupported
    raise ValueError(f"Unsupported field type for tokenization: {type(value)}")

# inverse map: id -> asin token
id2asin = {v: k for k, v in asin2id.items()}

# robust converter: value -> list[str] tokens (ASINs)
def _to_token_words(value):
    if isinstance(value, str):
        return [value]
    if isinstance(value, int):
        return [id2asin.get(value, str(value))]  # fallback: stringified id
    if isinstance(value, list):
        out = []
        for x in value:
            if isinstance(x, int):
                out.append(id2asin.get(x, str(x)))
            else:
                out.append(str(x))
        return out
    raise ValueError(f"Unsupported field type for tokenization: {type(value)}")

# ensure left padding + special tokens exist
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
if tokenizer.eos_token_id is None:
    tokenizer.add_special_tokens({"eos_token": "<eos>"})


def tok_train(ex, tok):
    enc = tok(_to_token_words(ex["prompt"]), add_special_tokens=False, is_split_into_words=True)
    dec = tok(_to_token_words(ex["completion"]), add_special_tokens=False, is_split_into_words=True)
    eos = tok.eos_token_id
    input_ids = enc["input_ids"] + dec["input_ids"] + [eos]
    labels    = [-100] * len(enc["input_ids"]) + dec["input_ids"] + [eos]
    return {"input_ids": input_ids, "attention_mask": [1]*len(input_ids), "labels": labels}

def tok_eval(ex, tok):
    enc = tok(_to_token_words(ex["prompt"]), add_special_tokens=False, is_split_into_words=True)
    dec = tok(_to_token_words(ex["completion"]), add_special_tokens=False, is_split_into_words=True)
    eos = tok.eos_token_id
    input_ids = enc["input_ids"]
    labels    = dec["input_ids"] + [eos]
    return {"input_ids": input_ids, "attention_mask": [1]*len(input_ids), "labels": labels}


# -------------------------
# 2) Collator (LEFT padding everywhere)
# -------------------------
def make_collate_fn(tokenizer):
    LABEL_PAD_ID = -100
    PAD_ID = tokenizer.pad_token_id

    def collate_fn(features):
        feats = [{
            "input_ids": list(f["input_ids"]),
            "attention_mask": list(f["attention_mask"]),
            "labels": list(f["labels"]),
        } for f in features]

        max_len = max(len(f["input_ids"]) for f in feats)

        batch_input_ids, batch_attn, batch_labels = [], [], []
        for f in feats:
            L_in, L_lb = len(f["input_ids"]), len(f["labels"])
            pad_in = max_len - L_in
            pad_lb = max_len - L_lb

            input_ids = [PAD_ID]*pad_in + f["input_ids"]
            attn      = [0]*pad_in     + f["attention_mask"]
            labels    = [LABEL_PAD_ID]*pad_lb + f["labels"]

            input_ids = input_ids[:max_len]
            attn      = attn[:max_len]
            labels    = labels[:max_len]

            batch_input_ids.append(input_ids)
            batch_attn.append(attn)
            batch_labels.append(labels)

        return {
            "input_ids": torch.tensor(batch_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(batch_attn, dtype=torch.long),
            "labels": torch.tensor(batch_labels, dtype=torch.long),
        }
    return collate_fn

# -------------------------
# 3) Metrics: compare tail(pred) with labels length (no "=" assumption)
# -------------------------
def make_compute_metrics(tokenizer, max_print=3):
    LABEL_PAD_ID = -100
    eos_id = tokenizer.eos_token_id

    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        n_total = 0
        n_wrong = 0
        printed = 0
        skipped_empty = 0

        for i, (p, l) in enumerate(zip(preds, labels)):
            p = p.tolist() if hasattr(p, "tolist") else list(p)
            l = l.tolist() if hasattr(l, "tolist") else list(l)

            gold_ids = [t for t in l if t != LABEL_PAD_ID]
            if not gold_ids:
                skipped_empty += 1
                continue

            first_gold = gold_ids[0]
            first_pred = p[len(p) - len(gold_ids)]  # first predicted token aligned to labels

            n_total += 1
            if first_pred != first_gold:
                n_wrong += 1
                if printed < max_print:
                    print(f"[mistake {printed+1}] idx={i}\n"
                          f"  pred_id={first_pred}, gold_id={first_gold}\n"
                          f"  pred_tok={tokenizer.convert_ids_to_tokens([first_pred])}\n"
                          f"  gold_tok={tokenizer.convert_ids_to_tokens([first_gold])}")
                    printed += 1

        if n_total == 0:
            print(f"[metric] WARNING: no non-empty labels (skipped={skipped_empty}).")
            return {"accuracy_first_token": 0.0, "n_mistakes": 0}

        acc = 1.0 - n_wrong / n_total
        return {"accuracy_first_token": acc, "n_mistakes": n_wrong}

    return compute_metrics


# -------------------------
# 4) Build tokenized datasets
# -------------------------
ds = DatasetDict({
    "train": ds_train,
    "validation": ds_valid,
    "test": ds_test,
})

tokenized_train = ds["train"].map(lambda ex: tok_train(ex, tokenizer),
                                  remove_columns=["prompt", "completion"])
tokenized_valid = ds["validation"].map(lambda ex: tok_eval(ex, tokenizer),
                                       remove_columns=["prompt", "completion"])
tokenized_test  = ds["test"].map(lambda ex: tok_eval(ex, tokenizer),
                                 remove_columns=["prompt", "completion"])

collate_fn = make_collate_fn(tokenizer)
compute_metrics = make_compute_metrics(tokenizer)

# -------------------------
# 5) Tiny model config
# -------------------------
scale = 16
config = Qwen3Config(
    vocab_size=tokenizer.vocab_size,
    hidden_size=max(1, int(4096/scale)),
    intermediate_size=max(1, int(22016/scale)),
    num_hidden_layers=max(1, int(32/scale)),
    num_attention_heads=max(1, int(32/scale)),
    num_key_value_heads=max(1, int(32/scale)),
    head_dim=max(1, int(128/scale)),
)
model = Qwen3ForCausalLM(config)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.generation_config.max_new_tokens = 1  # adjust if your completions are longer

def count_params(m):
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    print(f"Total parameters:     {total:,}  ({total/1e6:.2f}M)")
    print(f"Trainable parameters: {trainable:,}  ({trainable/1e6:.2f}M)")
count_params(model)

# -------------------------
# 6) Train
# -------------------------
args = Seq2SeqTrainingArguments(
    output_dir="./qwen3_prompt_completion",
    eval_strategy="steps",                  # <- correct arg name
    eval_steps=0.1,  # int steps
    logging_steps=100,
    save_steps=1000,
    save_total_limit=2,
    per_device_train_batch_size=32,               # safer defaults; raise if you have GPU
    per_device_eval_batch_size=64,
    num_train_epochs=100,
    learning_rate=5e-4,
    warmup_steps=0,
    weight_decay=0.0,
    report_to="none",
    predict_with_generate=True,
    include_inputs_for_metrics=False,
    fp16=False, bf16=False,
    seed=SEED,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=collate_fn,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_valid,
    compute_metrics=compute_metrics,
)

trainer.train()
print("Validation:", trainer.evaluate())
print("Test:", trainer.evaluate(eval_dataset=tokenized_test))


Map:   0%|          | 0/285 [00:00<?, ? examples/s]

Map:   0%|          | 0/35 [00:00<?, ? examples/s]

Map:   0%|          | 0/37 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 481, 'bos_token_id': 480, 'pad_token_id': 479}.


Total parameters:     2,395,424  (2.40M)
Trainable parameters: 2,395,424  (2.40M)


Step,Training Loss,Validation Loss,Accuracy First Token,N Mistakes
90,No log,4.686253,0.2,28
180,1.697900,5.144432,0.2,28


[mistake 1] idx=0
  pred_id=256, gold_id=129
  pred_tok=['256']
  gold_tok=['129']
[mistake 2] idx=1
  pred_id=458, gold_id=477
  pred_tok=['458']
  gold_tok=['477']
[mistake 3] idx=2
  pred_id=297, gold_id=38
  pred_tok=['297']
  gold_tok=['38']
[mistake 1] idx=0
  pred_id=256, gold_id=129
  pred_tok=['256']
  gold_tok=['129']
[mistake 2] idx=1
  pred_id=458, gold_id=477
  pred_tok=['458']
  gold_tok=['477']
[mistake 3] idx=2
  pred_id=297, gold_id=38
  pred_tok=['297']
  gold_tok=['38']
