In [31]:
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
import json

# Dataset Preparation

In [32]:
# Build the supervised fine-tuning dataset: one prompt+answer string per record.
#
# Decode the file explicitly as UTF-8 instead of relying on the platform default
# encoding; the records are re-serialised below with ensure_ascii=False, so any
# non-ASCII characters must survive the read unchanged.
with open("entities.json", "r", encoding="utf-8") as f:
    raw_data = json.load(f)

processed_data = []

for item in raw_data:
    # Missing keys fall back to an empty input / empty JSON object.
    input_text = item.get("input", "")
    output_json = item.get("output", {})

    # Serialise the target as a single-line JSON string, keeping non-ASCII
    # characters literal rather than \u-escaped.
    response_str = json.dumps(output_json, ensure_ascii=False)

    # The same instruction/Input/JSON Output layout is reused as the prompt
    # prefix in the inference cell.
    # NOTE(review): "<eos>" is hard-coded here; presumably it matches the
    # tokenizer's EOS token -- confirm against the tokenizer.
    full_text = (
        f"Extract entities and relationships from the text below as JSON.\n\n"
        f"Input: {input_text}\n\n"
        f"JSON Output:\n{response_str}<eos>"
    )

    processed_data.append({"text": full_text})

# Single "text" column, referenced later via dataset_text_field="text".
dataset = Dataset.from_list(processed_data)

# Sanity check: number of examples and the first formatted example.
len(dataset), dataset[0]

(8425,
 {'text': 'Extract entities and relationships from the text below as JSON.\n\nInput: The new processor manufactured by Intel significantly improves the performance of the latest Macbook Pro.\n\nJSON Output:\n{"entities": ["processor", "Intel", "Macbook Pro"], "relationships": [["processor", "MANUFACTURED_BY", "Intel"], ["processor", "IMPROVES_PERFORMANCE_OF", "Macbook Pro"]]}<eos>'})

# Training Device Configuration

In [33]:
MODEL_ID = "google/gemma-3-1b-it"
OUTPUT_DIR = "gemma-ner-lora"

# Pick the accelerator in priority order: Apple MPS, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Both accelerators use bfloat16; the CPU fallback stays in float32.
torch_dtype = torch.float32 if device == "cpu" else torch.bfloat16

print(f"Using device: {device}")

Using device: mps


# Model Configuration

In [34]:
# Tokenizer for the base checkpoint, padding on the right-hand side.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.padding_side = "right"

# Base causal LM, loaded in the dtype and on the device selected earlier.
# The KV cache is switched off because this instance is used for training.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    device_map=device,
    use_cache=False,
)

# LoRA adapter settings: only the adapter weights attached to the listed
# projection modules are trained, a small fraction of the model's parameters.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# Training Configuration

In [35]:
# Supervised fine-tuning settings.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    report_to="none",
    save_strategy="no",

    # Data handling: read the "text" column, cap sequences at 512, no packing.
    dataset_text_field="text",
    max_length=512,
    packing=False,

    # Optimisation: 2 examples per device x 4 accumulation steps per update,
    # for a fixed budget of 100 optimizer steps.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    optim="adamw_torch",
    max_steps=100,
    logging_steps=10,

    # Precision: bf16 on either accelerator, neither half-precision mode on CPU.
    fp16=False,
    bf16=device in ("mps", "cuda"),
)

# The trainer receives the LoRA config alongside the base model and dataset.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    peft_config=peft_config,
)

Adding EOS to train dataset: 100%|██████████| 8425/8425 [00:00<00:00, 113598.26 examples/s]
Tokenizing train dataset: 100%|██████████| 8425/8425 [00:00<00:00, 12593.91 examples/s]
Truncating train dataset: 100%|██████████| 8425/8425 [00:00<00:00, 1465475.52 examples/s]
The model is already on multiple devices. Skipping the move to device specified in `args`.


# Train and Save Model

In [36]:
# Run the fine-tuning loop with the trainer built in the previous cell, then
# write the result to OUTPUT_DIR -- the same path the inference cell passes to
# PeftModel.from_pretrained. Order matters: save only after training finishes.
trainer.train()
trainer.save_model(OUTPUT_DIR)

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1}.


Step,Training Loss
10,1.0824
20,0.5457
30,0.4042
40,0.3786
50,0.3939
60,0.3537
70,0.3341
80,0.3652
90,0.327
100,0.3286


# Load Model for Inference

In [37]:
# Reload a fresh copy of the base model and attach the LoRA adapter saved in
# OUTPUT_DIR by the training cell.
#
# Reuse the `torch_dtype` and `device` chosen in the device-configuration cell
# instead of hard-coding bfloat16 / "auto", so inference runs with the same
# precision and placement as training (including float32 on the CPU fallback).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    device_map=device,
)
model = PeftModel.from_pretrained(model, OUTPUT_DIR)

# The tokenizer is loaded from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Inference

In [38]:

# Smoke-test the fine-tuned model on one unseen sentence.
test_input = "Mukul is writing a blog on how to train LLMs"

# Same instruction / Input / JSON Output layout as the training examples, with
# the answer left off so the model completes it.
prompt = f"Extract entities and relationships from the text below as JSON.\n\nInput: {test_input}\n\nJSON Output:\n"

# Tokenize and move the tensors to wherever the model lives.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generation samples (do_sample=True), so seed torch's RNG first; otherwise the
# printed output can differ on every Restart & Run All.
GENERATION_SEED = 42
torch.manual_seed(GENERATION_SEED)

outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.1,  # low temperature keeps sampling close to greedy
)

# Decodes the full sequence, so the prompt is printed along with the completion.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Extract entities and relationships from the text below as JSON.

Input: Mukul is writing a blog on how to train LLMs

JSON Output:
{"entities": ["Mukul", "blog", "LLMs"], "relationships": [["Mukul", "WRITING", "blog"], ["blog", "ABOUT", "LLMs"]]}
