In [1]:
# must be set before importing torch/transformers
import os

# If reserved unallocated memory is large
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"

# (optional) avoid the fork/threads warning and nested parallelism
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Ensures that only 1 GPU is visible to torch/accelerate/transformers/trl
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, setup_chat_format
import torch
from pathlib import Path

torch.backends.cuda.matmul.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# logging.set_verbosity_error()

OUTPUT_DIR = Path.cwd().joinpath("ft")

In [2]:
# ------------------------------
# 0) Setup
# ------------------------------
# model_id = "google/gemma-3-270m"
model_id = "HuggingFaceTB/SmolLM2-135M"

# Use bf16 if available, else fp16
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Many Qwen tokenizers have no explicit pad_token; for training we usually
# set pad_token = eos_token so padding is benign for loss.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

max_memory = {
    0: "8GiB",  # keep part of the model on GPU0
    # 1: "8GiB",  # rest of the model goes here
    # "cpu": "24GiB",  # optional spillover/offload safety
}


try:
    # First, try to use community vLLM Flash-Attn 3
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",  # let HF place layers under the caps
        max_memory=max_memory,
        attn_implementation="kernels-community/vllm-flash-attn3",
        offload_folder="./offload",  # only used if it needs to spill to CPU
    )
    # Fallback to Flash-Attn 2
except:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",  # let HF place layers under the caps
        max_memory=max_memory,
        attn_implementation="eager",  # flash_attention_2
        offload_folder="./offload",  # only used if it needs to spill to CPU
    )

print("Attn-Implementation:", model.config._attn_implementation)

# (Training tip) disable cache + enable checkpointing to reduce activations
model.config.use_cache = False
model.gradient_checkpointing_enable()

# Setup for the model specific chat format
if not getattr(tokenizer, "chat_template", None):
    model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

# When you create batches, put inputs on the device that holds the FIRST layer.
# Auto placement usually puts embeddings & early blocks on the *smaller* device.
first_device = next(model.parameters()).device

Attn-Implementation: eager


In [3]:
# ------------------------------
# 1) Load the dataset
#    trl-lib/tldr has columns: "prompt" (the post) and "completion" (TL;DR)
# ------------------------------
dataset_name = "trl-lib/tldr"

raw_train = load_dataset(dataset_name, split="train[:5%]")
raw_val = load_dataset(dataset_name, split="test[:1%]")

In [4]:
# ------------------------------
# 2) Formatting: produce a single string per row with a clear response boundary
#
# Important details:
# - The boundary string *must* match what you pass to response_template below
#   (including spaces/punctuation/case).
# - Because Reddit posts can be long, we cap the prompt to keep the completion
#   inside the max_seq_length budget. Two versions are shown:
#     (A) simple char cap (very fast, approximate)
#     (B) token-budgeted cap (more precise, a little slower)
# Pick one and comment out the other.
# ------------------------------

MAX_LENGTH = 512
BOUNDARY = "TL;DR: "  # <-- Will be used as response_template


# (A) Simple, fast char-cap (good enough for many runs)
def format_pc_char_cap(example, max_prompt_chars=MAX_LENGTH):
    prompt = example["prompt"]
    # Trim the prompt aggressively so the summary isn't truncated
    if len(prompt) > max_prompt_chars:
        prompt = prompt[:max_prompt_chars] + "…"
    text = (
        "Summarize the post below in a single concise TL;DR.\n\n"
        f"{prompt}\n\n{BOUNDARY}{example['completion']}"
    )
    return {"text": text}


# (B) Token-budgeted cap (keeps the *end* with completion intact more reliably)
def format_pc_token_cap(example, max_len=MAX_LENGTH, reserve_for_completion=128):
    # Tokenize completion to estimate space needed for summary + boundary + EOS
    comp_ids = tokenizer(example["completion"], add_special_tokens=False)["input_ids"]
    # reserve a bit more for the boundary + eos
    reserve = min(max_len // 3, reserve_for_completion) + 16

    # Budget for the prompt = total - reserve
    prompt_budget = max_len - (len(comp_ids) + reserve)
    prompt_budget = max(prompt_budget, 32)  # still keep some prompt

    # Take only the first `prompt_budget` tokens of the prompt
    prompt_ids = tokenizer(example["prompt"], add_special_tokens=False)["input_ids"][
        :prompt_budget
    ]
    prompt_trimmed = tokenizer.decode(prompt_ids, skip_special_tokens=True)

    text = (
        "Summarize the post below in a single concise TL;DR.\n\n"
        f"{prompt_trimmed}\n\n{BOUNDARY}{example['completion']}{tokenizer.eos_token}"
    )
    return {"text": text}

In [5]:
# Choose ONE formatter:
use_token_budget = True
fmt_fn = format_pc_token_cap if use_token_budget else format_pc_char_cap

train = raw_train.map(
    fmt_fn, remove_columns=raw_train.column_names, desc="Formatting train"
)
val = raw_val.map(fmt_fn, remove_columns=raw_val.column_names, desc="Formatting val")

# Optional: drop everything except the 'text' column (keeps memory small)
train = train.remove_columns([c for c in train.column_names if c != "text"])
val = val.remove_columns([c for c in val.column_names if c != "text"])

Formatting train:   0%|          | 0/5836 [00:00<?, ? examples/s]

In [6]:
# ------------------------------
# 3) Trainer config
# Notes:
# - packing=True concatenates multiple short samples together to reach the max
#   length. This reduces padding and usually increases throughput.
# - bf16 is fast on Ampere+ (RTX 30xx); if unsupported, set bf16=False and fp16=True.
# - Adjust max_length based on your VRAM and throughput goals.
# ------------------------------
ft_filename = "SmoLM2-135M-tldr-sft"

cfg = SFTConfig(
    output_dir=OUTPUT_DIR.joinpath(ft_filename),
    dataset_text_field="text",
    max_length=MAX_LENGTH,
    packing=False,
    bf16=True,  # assumes Ampere; set to False if needed
    # fp16=True,  # mutual exclusive with bf16
    tf32=True,
    gradient_checkpointing=False,  # lowers memory, costs time; toggle as needed
    # gradient_checkpointing_kwargs={"use_reentrant": False},
    per_device_train_batch_size=32,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    # logging_steps=25,
    # save_steps=100,
    # eval_steps=25,
    use_liger_kernel=True,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    logging_strategy="epoch",
    save_strategy="epoch",
    eval_strategy="epoch",
    optim="adamw_torch_fused",  # if your torch supports it; else "adamw_torch"
    # optim="paged_adamw_8bit",
    learning_rate=1e-4,
    weight_decay=0.01,
    completion_only_loss=True,
    remove_unused_columns=True,
    dataloader_num_workers=8,  # try 4–8, depends on CPU
    dataloader_persistent_workers=True,
    dataloader_prefetch_factor=4,
    dataloader_drop_last=True,
    dataloader_pin_memory=True,
    report_to=[],  # disable W&B by default
)

# If using packing, the attention implementation should be set to
# 'flash_attention_2' or 'kernels-community/vllm-flash-attn3'. Packing flattens
# batches into a single sequence, and Flash Attention is the only known attention
# mechanisms that reliably support this. Using other implementations may lead to
# cross-contamination between batches. To avoid this, either disable packing by setting
# `packing=False`, or set `attn_implementation='flash_attention_2'` or
# `attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration.

In [7]:
# ------------------------------
# 4) Trainer
# If you’re memory-constrained, you can add LoRA or 4-bit later.
# For clarity, this example fine-tunes full weights in bf16/fp16.
# ------------------------------
train = train.with_format("torch", columns=["text"])
val = val.with_format("torch", columns=["text"])


trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train,
    eval_dataset=val,
    args=cfg,
)

# Required for training with checkpointing (turns off KV cache during train)
trainer.model.config.use_cache = False

Adding EOS to train dataset:   0%|          | 0/5836 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/5836 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/5836 [00:00<?, ? examples/s]

In [8]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,2.6157,2.559173


TrainOutput(global_step=182, training_loss=2.6157454605940935, metrics={'train_runtime': 223.1241, 'train_samples_per_second': 26.156, 'train_steps_per_second': 0.816, 'total_flos': 1454780038053888.0, 'train_loss': 2.6157454605940935})

In [9]:
# Save the model
trainer.save_model(OUTPUT_DIR.joinpath(ft_filename))

In [40]:
# BEFORE/AFTER QUICK CHECK — works for full-FT or LoRA outputs
FT_DIR = OUTPUT_DIR.joinpath(ft_filename)  # your SFTConfig.output_dir

device = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(
    model_id, use_fast=True, trust_remote_code=True
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def load_base():
    return (
        AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=DTYPE, trust_remote_code=True
        )
        .to(device)
        .eval()
    )


def load_finetuned():
    # If LoRA/PEFT adapters exist, attach them to the base; else load full FT weights.
    if os.path.isfile(os.path.join(FT_DIR, "adapter_config.json")):
        from peft import PeftModel

        base = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=DTYPE, trust_remote_code=True
        ).to(device)
        ft = PeftModel.from_pretrained(base, FT_DIR).to(device)
        ft.eval()
        return ft
    else:
        return (
            AutoModelForCausalLM.from_pretrained(
                FT_DIR, torch_dtype=DTYPE, trust_remote_code=True
            )
            .to(device)
            .eval()
        )


def generate(model, prompt, max_new_tokens=64, do_sample=False):
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=0.8 if do_sample else None,
            top_p=0.95 if do_sample else None,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_type_id,
        )
    gen = tokenizer.decode(
        out[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
    )
    return gen.strip()


ds = load_dataset("trl-lib/tldr", split="validation")
prompt_text = f"Summarize the post below in a single TL;DR.\n\n{ds[42]['prompt']}"

# ---- Run BEFORE/AFTER ----
base_model = load_base()
ft_model = load_finetuned()

In [42]:
print("Prompt Text:", prompt_text)

Prompt Text: Summarize the post below in a single TL;DR.

SUBREDDIT: r/jobs

TITLE: Cold applying for a marketing position in a small local company by attaching a proposal for their business website. Feasible idea?

POST: Hello /r/jobs, I graduated a few months ago and had no luck so far to get a job in marketing/sales. 

There's a small local company (perhaps 30 employees) but they are actually pretty successful in what they're doing (known worldwide). I checked their website and it's awful. Looks like a website from the early 2000's. So I guess they are not pretty good in (online-)marketing. 

I would like to do a cold application (not sure if they are looking for a marketing guy) but I had no luck with this kind of application in the past. That's why I thought I try something different. I have good skills in photoshop, indesign and illustrator. As a teenager I also built websites using HTML, so I thought I build a dummy website fitted to their company and attach some screenshots to 

In [43]:
print("\n--- BEFORE (base) ---")
print(generate(base_model, prompt_text))


--- BEFORE (base) ---
I'm not sure if they are looking for a marketing guy or not.

I'm not sure if they are looking for a marketing guy or not.

I'm not sure if they are looking for a marketing guy or not.

I'm not sure if they are looking for a marketing guy or


In [44]:
print("\n--- AFTER (fine-tuned) ---")
print(generate(ft_model, prompt_text))


--- AFTER (fine-tuned) ---
TL;DR:  I want to do a cold application for a small local company by attaching a proposal for their business website. I don't know if they are looking for a marketing guy or not. TL;DR:  I don't know if they are looking for a marketing guy or not. TL
