# Training

In [1]:
import os
import torch
import random
import numpy as np
from datetime import datetime
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model, PeftModel
import argparse



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#only necessary once: loads model into models folder in your data
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "google/gemma-2-9b-it"
local_model_path = "data/models/gemma-2-9b-instruct"
token = "..." # made on hugging face

# Download tokenizer with token
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_auth_token=token,       # <- required for gated model
    cache_dir=local_model_path  # <- store in your data directory
)

# Download model in 8-bit directly to your data directory
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_auth_token=token,       # <- required for gated model
    cache_dir=local_model_path, # <- store in your data directory
    load_in_8bit=True,
    device_map="auto"
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 4/4 [00:17<00:00,  4.30s/it]


# Set which model to train and it's seed

In [None]:
#merge files
import json

def merge_jsonl_unique_ids(file1, file2, output_file):
    print(file1)
    print(file2)
    print(output_file)
    seen_ids = set()
    count = 0
    with open(output_file, "w", encoding="utf-8") as out_f:
        for fname in [file1, file2]:
            with open(fname, "r", encoding="utf-8") as in_f:
                for line in in_f:
                    obj = json.loads(line)
                    obj_id = obj.get("id")
                    if obj_id not in seen_ids:
                        count +=1
                        out_f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                        seen_ids.add(obj_id)
    print(f"There are {count} entries in the merged file")

The following codeblocks can be used to merge different files to get one train, one val and one test file for each adapter.

In [None]:
simple_train1= "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_train.jsonl"#"data/Accessibility_Seminar/datasets_for_models_chunked/simplifier_train.jsonl",
simple_train2= "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_train_part2.jsonl"       
simple_val1= "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_val.jsonl"
simple_val2= "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_val_part2.jsonl"
simple_test1= "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_test.jsonl"
simple_test2= "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_test_part2.jsonl"

simple_train_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_train_final.jsonl"
simple_test_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_test_final.jsonl"
simple_val_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_val_final.jsonl"

merge_jsonl_unique_ids(simple_train1, simple_train2, simple_train_final)
merge_jsonl_unique_ids(simple_test1, simple_test2, simple_test_final)
merge_jsonl_unique_ids(simple_val1, simple_val2, simple_val_final)

In [None]:
highlighter_train1= "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_train.jsonl"#"data/Accessibility_Seminar/datasets_for_models_chunked/simplifier_train.jsonl",
highlighter_train2= "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_train_part2.jsonl"       
highlighter_val1= "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_val.jsonl"
highlighter_val2= "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_val_part2.jsonl"
highlighter_test1= "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_test.jsonl"
highlighter_test2= "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_test_part2.jsonl"

highlighter_train_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_train_final.jsonl"
highlighter_test_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_test_final.jsonl"
highlighter_val_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_val_final.jsonl"

merge_jsonl_unique_ids(highlighter_train1, highlighter_train2, highlighter_train_final)
merge_jsonl_unique_ids(highlighter_test1, highlighter_test2, highlighter_test_final)
merge_jsonl_unique_ids(highlighter_val1, highlighter_val2, highlighter_val_final)

In [None]:
end2end_train1= "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_train.jsonl"#"data/Accessibility_Seminar/datasets_for_models_chunked/simplifier_train.jsonl",
end2end_train2= "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_train_part2.jsonl"       
end2end_val1= "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_val.jsonl"
end2end_val2= "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_val_part2.jsonl"
end2end_test1= "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_test.jsonl"
end2end_test2= "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_test_part2.jsonl"

end2end_train_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_train_final.jsonl"
end2end_test_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_test_final.jsonl"
end2end_val_final = "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_val_final.jsonl"

merge_jsonl_unique_ids(end2end_train1, end2end_train2, end2end_train_final)
merge_jsonl_unique_ids(end2end_test1, end2end_test2, end2end_test_final)
merge_jsonl_unique_ids(end2end_val1, end2end_val2, end2end_val_final)

In [3]:
# -----------------------------
# 2. Model configurations
# -----------------------------
models_config = {
    "simplifier": {
        "train_file": "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_train_final.jsonl",#"data/Accessibility_Seminar/datasets_for_models_chunked/simplifier_train.jsonl",
        "val_file": "data/Accessibility_Seminar/datasets_for_models_synthetic/simplifier_val_final.jsonl",#"data/Accessibility_Seminar/datasets_for_models_chunked/simplifier_val.jsonl",
        "output_dir": "output_jupyter/simplifier_syn_final"
    },
    "highlighter": {
        "train_file": "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_train_final.jsonl",
        "val_file": "data/Accessibility_Seminar/datasets_for_models_synthetic/highlighter_val_final.jsonl",
        "output_dir": "output_jupyter/highlighter_syn_final"
    },
    "end2end": {
        "train_file": "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_train_final.jsonl",
        "val_file": "data/Accessibility_Seminar/datasets_for_models_synthetic/end2end_val_final.jsonl",
        "output_dir": "output_jupyter/end2end_syn_final"
    }
}

#cfg = models_config["simplifier"]#models_config[args.model_name]
#seed = 42#args.seed

cfg = models_config["end2end"]#models_config[args.model_name]
seed = 42#args.seed

In [None]:

#cfg = models_config[args.model_name]
#seed = args.seed
torch.cuda.empty_cache()

# -----------------------------
# 3. Training settings
# -----------------------------
training_args_template = {
    "eval_strategy": "steps",
    "save_strategy": "steps",
    "save_steps": 250,
    "eval_steps": 250,
    "logging_steps": 50,
    "num_train_epochs": 3,
    "learning_rate": 1e-4,
    "fp16": True,
    "logging_dir": "logs"
}

# -----------------------------
# 4. LoRA configuration
# -----------------------------
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# -----------------------------
# 5. Batch size based on VRAM
# -----------------------------
def get_batch_size():
    if not torch.cuda.is_available():
        return 1
    try:
        vram = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        if vram >= 80:  # H100
            return 4
        if vram >= 40:  # A100
            return 2
        return 1
    except:
        return 1



def tokenize_example(batch):
    texts = [
        f"Instruction: {i}\nInput: {inp}\nOutput: {out}"
        for i, inp, out in zip(batch["instruction"], batch["input"], batch["output"])
    ]
    
    tokenized = tokenizer(
        texts,
        truncation=True,
        padding="longest",
        max_length=2048
    )

    labels = []
    for i, inp, out in zip(batch["instruction"], batch["input"], batch["output"]):
        # Tokenize separately to know lengths
        inp_text = f"Instruction: {i}\nInput: {inp}\nOutput: "
        input_ids = tokenizer(inp_text, truncation=True, max_length=2048)["input_ids"]
        output_ids = tokenizer(out, truncation=True, max_length=2048)["input_ids"]

        # For labels: ignore instruction+input, only predict output
        labels.append([-100] * len(input_ids) + output_ids)

    # Pad labels to match input_ids
    max_len = max(len(l) for l in labels)
    labels = [l + [-100]*(max_len - len(l)) for l in labels]

    tokenized["labels"] = labels
    return tokenized

# -----------------------------
# 7. Set seeds
# -----------------------------
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# -----------------------------
# 8. Load datasets
# -----------------------------
train_dataset = load_dataset("json", data_files=cfg["train_file"])["train"]
val_dataset   = load_dataset("json", data_files=cfg["val_file"])["train"]

train_dataset = train_dataset.map(tokenize_example, batched=True)
val_dataset   = val_dataset.map(tokenize_example, batched=True)

for col in ["metadata", "id", "instruction", "input", "output"]:
    if col in train_dataset.column_names:
        train_dataset = train_dataset.remove_columns(col)
    if col in val_dataset.column_names:
        val_dataset = val_dataset.remove_columns(col)

# -----------------------------
# 9. Apply LoRA
# -----------------------------
if isinstance(model, PeftModel):
    model = model.base_model

model = get_peft_model(model, lora_config)

# -----------------------------
# 10. Data collator
# -----------------------------
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# -----------------------------
# 11. Trainer
# -----------------------------
batch_size = get_batch_size()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_dir = f"{cfg['output_dir']}_seed{seed}_{timestamp}"
os.makedirs(out_dir, exist_ok=True)

training_args = TrainingArguments(
    output_dir=out_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    remove_unused_columns=False,
    seed=seed,
    **training_args_template
)

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=collator
)

# -----------------------------
# 12. Train
# -----------------------------
trainer.train()

# -----------------------------
# 13. Save LoRA adapters
# -----------------------------
model.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)
print(f"✅ Saved LoRA adapter at {out_dir}")


Generating train split: 198 examples [00:00, 22682.44 examples/s]
Generating train split: 25 examples [00:00, 999.74 examples/s]
Map: 100%|██████████| 198/198 [00:00<00:00, 276.08 examples/s]
Map: 100%|██████████| 25/25 [00:00<00:00, 184.69 examples/s]
  trainer = Trainer(


Step,Training Loss,Validation Loss
250,1.09,1.199405


✅ Saved LoRA adapter at output_jupyter/end2end_syn_final_seed42_20251118_201751


In [None]:
# Sample a batch
sample = train_dataset[0]
print("input_ids length:", len(sample["input_ids"]))
print("labels length:", len(sample["labels"]))

# Check label masking
masked_count = sum([1 for x in sample["labels"] if x == -100])
print("Number of masked tokens (should cover prompt):", masked_count)
