In [None]:
!python -V
!pip list

In [None]:
!pip install -q -U git+https://github.com/fabienfrfr/tptt@main

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m925.2 kB/s[0m eta [36m0:00:00[0m0:01[0m00:01[0mm
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━

In [None]:
import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding, DataCollatorForLanguageModeling, TrainerCallback
from datasets import load_dataset
import tptt

In [None]:
# Step 1: Load configuration and initialize the TPTT model
# Using a pretrained backbone (TinyLlama in this example).
# NOTE(review): mag_weight and inject_liza are tptt-specific options; confirm
# their exact semantics in tptt.TpttConfig before tuning them.
config = tptt.TpttConfig(base_model_name="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", mag_weight=0.1, inject_liza=False)
model = tptt.TpttModel(config)

# Step 2: (Optional) Inject LoRA adapters for parameter-efficient fine-tuning.
# Called on the already-constructed model, before training.
model.add_lora()

# Step 3: Load the tokenizer corresponding to the base model.
# NOTE(review): base_tokenizer_name is read from the config; presumably it is
# derived from base_model_name -- confirm in tptt.TpttConfig.
tokenizer = AutoTokenizer.from_pretrained(config.base_tokenizer_name)
# Ensure the tokenizer has a padding token for batching. Reusing the existing
# EOS token avoids introducing a token the model has no embedding for; the
# "[PAD]" string is only a last-resort fallback when there is no EOS token.
# NOTE(review): if "[PAD]" is not in the vocabulary it may not map to a real id.
# NOTE: when pad == eos, collators that mask pad ids also mask genuine EOS tokens.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "[PAD]"

In [None]:
# Display the model's repr to inspect its module structure (e.g. to check that
# the LoRA adapters added in Step 2 are present).
model

In [None]:
# Step 4: Prepare the training dataset.
# Only the first 100 rows of the cleaned Alpaca "train" split are kept -- enough
# for a quick demonstration run, far too few for a real fine-tune.
raw_dataset = (
    load_dataset("yahma/alpaca-cleaned")["train"]
    .select(range(100))
)

def preprocess_fn(samples, max_length=256, tok=None):
    """
    Tokenize a batch of Alpaca samples for causal language modeling.

    Each prompt is the instruction, then the input (only when non-empty), then
    the output, joined by newlines, so the model is trained on the full
    instruction/response text.

    Args:
        samples: Batched columns as passed by ``Dataset.map(batched=True)``: a
            mapping with lists under "instruction", "input" and "output".
        max_length: Every sequence is truncated / padded to exactly this length.
        tok: Tokenizer to use; defaults to the notebook-level ``tokenizer``.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``) plus
        ``labels``: a fresh copy of ``input_ids`` in which padding positions
        (``attention_mask == 0``) are set to -100 so the loss ignores them.
    """
    if tok is None:
        tok = tokenizer

    prompts = []
    for instr, inp, out in zip(samples["instruction"], samples["input"], samples["output"]):
        prompt = f"{instr}\n{inp}" if inp else instr
        # Append the expected output for supervised fine-tuning.
        # NOTE(review): no EOS is appended unless the tokenizer adds one itself.
        prompts.append(f"{prompt}\n{out}")

    tokens = tok(
        prompts,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_attention_mask=True,
    )
    # Build new label lists: a shallow ``list.copy()`` would share the inner
    # lists with input_ids. Masking by attention_mask (not pad_token_id) keeps
    # genuine EOS tokens even when pad_token == eos_token.
    tokens["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
    return tokens

# Tokenize the dataset in batches and drop the original text columns so only
# the model inputs (input_ids, attention_mask, labels) remain.
# (Previously this map ran twice, tokenizing the whole dataset a second time.)
tokenized_dataset = raw_dataset.map(
    preprocess_fn, batched=True, remove_columns=raw_dataset.column_names
)

In [None]:
# Step 5: Set up the data collator.
# Sequences are already padded to max_length during preprocessing, so there is
# nothing left to pad dynamically. With mlm=False this collator rebuilds
# `labels` from `input_ids` and sets every pad-token position to -100 (ignored
# by the loss), replacing the labels produced during preprocessing.
# NOTE: if Step 3 set pad_token to eos_token, genuine EOS tokens are masked too.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)


# Step 6: Define HuggingFace TrainingArguments for reproducible training
training_args = TrainingArguments(
    output_dir="./tptt_output",
    per_device_train_batch_size=2,  # batch size on EACH device, not the GPU count
    num_train_epochs=1,
    learning_rate=1e-5,  # open question: 2e-4 may be too aggressive here
    max_grad_norm=1.0,  # gradient clipping
    # fp16 mixed precision requires a CUDA device; enable it only when one is
    # available so the notebook also runs on CPU.
    # NOTE(review): an earlier note reported NaN losses with fp16 -- if that
    # recurs, try bf16=True (on supporting GPUs) or plain fp32.
    fp16=torch.cuda.is_available(),
    logging_steps=1,
    save_strategy="epoch",
    report_to="tensorboard",  # requires the `tensorboard` package
    seed=42,  # Trainer's default, stated explicitly for reproducibility
)

# Step 7: Configure the MaG-weight schedule callback.
# No trailing commas: `x = 0.01,` would bind the tuple (0.01,) instead of the
# float 0.01 and pass tuples into the callback.
initial_weight = 0.01
final_weight = 0.5
# NOTE(review): this run is 100 examples / batch 2 / 1 epoch = about 50 steps on one
# device, so a 500-step transition is never completed -- adjust if the
# schedule is meant to finish within the run.
transition_step = 500
# NOTE: the callback only takes effect if it is passed to the Trainer via
# `callbacks=[liza_callback]`.
liza_callback = tptt.AdjustMaGWeightCallback(
    model,
    initial_weight=initial_weight,
    final_weight=final_weight,
    transition_step=transition_step,
)

# Build the HuggingFace Trainer; it takes care of moving the model and each
# batch to the available device(s).
# The full TPTT wrapper (not `model.backbone`) is trained. No callbacks are
# registered, so neither liza_callback nor DebugLossCallback (the latter is only
# defined in the debug section at the end of the notebook) runs here; pass
# `callbacks=[...]` to enable them.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
)

# Step 8: Launch training. To inspect a collated batch first (e.g. to confirm
# that padding labels are -100), use the debug cell at the end of the notebook.
trainer.train()

In [6]:
# Step 9: Prepare the model for inference.
# `device` follows the HF pipeline convention: a CUDA index, or -1 for CPU.
device = 0 if torch.cuda.is_available() else -1
model.backbone.to("cpu" if device == -1 else f"cuda:{device}")
# trainer.train() leaves the model in training mode; switch to eval mode so any
# dropout layers are disabled during generation.
model.backbone.eval()

# Step 10: Build the inference pipeline with the matching device and tokenizer.
pipe = tptt.TpttPipeline(model=model.backbone, tokenizer=tokenizer, device=device)

# Step 11: Generate up to 150 new tokens from a prompt and print the text.
result = pipe("Once upon a time,", max_new_tokens=150)
print(result[0]["generated_text"])

Device set to use cuda:0


Once upon a time,


In [None]:
# Step 12: Persist the fine-tuned model and its tokenizer into the same
# directory (model first, then tokenizer) for later reuse or deployment.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./my_tptt_model")

#### Debugging helpers

In [6]:
### DEBUG Callback
class DebugLossCallback(TrainerCallback):
    """Print the training loss every time the Trainer emits a log entry containing one."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Ignore log events without a loss (e.g. learning-rate-only entries).
        if logs is None or "loss" not in logs:
            return
        print(f"[DEBUG][Callback] Step {state.global_step} loss: {logs['loss']}")

# Pull one collated batch exactly as the Trainer sees it and sanity-check it:
# labels should mirror input_ids, with padding positions set to -100.
batch = next(iter(trainer.get_train_dataloader()))
print("\n[DEBUG] Batch keys:", batch.keys())
print("[DEBUG] input_ids (first 10):", batch["input_ids"][0][:10])
print("[DEBUG] labels (first 10):", batch["labels"][0][:10])
print("[DEBUG] labels unique:", torch.unique(batch["labels"]))
print("[DEBUG] nb non-masked labels:", (batch["labels"] != -100).sum().item())
# Inspection only: skip building an autograd graph (saves memory/time).
with torch.no_grad():
    output = model(**batch)
print("[DEBUG] Forward output.loss:", output.loss)


[DEBUG] Batch keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
[DEBUG] input_ids (first 10): tensor([    1, 12027,  7420,  2020,   278,  2183,  5023,   338,  2743, 29889],
       device='cuda:0')
[DEBUG] labels (first 10): tensor([    1, 12027,  7420,  2020,   278,  2183,  5023,   338,  2743, 29889],
       device='cuda:0')
[DEBUG] labels unique: tensor([ -100,     1,    13,   263,   278,   286,   287,   289,   292,   297,
          304,   310,   322,   338,   339,   341,   353,   363,   367,   372,
          385,   393,   394,   403,   408,   411,   421,   440,   445,   450,
          451,   470,   471,   472,   488,   491,   508,   512,   513,   515,
          526,   573,   607,   798,   884,   896,   921,   935,   947,   964,
          967,  1009,  1033,  1090,  1152,  1206,  1209,  1304,  1338,  1363,
         1412,  1438,  1494,  1565,  1576,  1598,  1840,  2000,  2020,  2057,
         2089,  2183,  2319,  2330,  2485,  2486,  2688,  2743,  2920,  2998,
         3148,  