# Fine-Tune Qwen3-8B for Financial Extraction (Colab)

Fine-tunes Qwen3-8B with QLoRA (4-bit quantized base model + trainable LoRA adapters) to extract financial data from SEC 10-K filings.

**Features**: custom token-weighted loss (up-weights JSON keys, JSON punctuation and numbers), `/no_think` mode (thinking disabled), **saves merged model for MLX**

## 1. Install Dependencies

In [None]:
%%capture
# Output of this cell is hidden by %%capture; remove that line to see pip errors.
%pip install unsloth
# --no-deps: pip installs these without resolving their dependencies. Because no
# -U/--upgrade flag is given, a package already present (e.g. if the unsloth install
# pulled it in) is left at its installed version.
%pip install --no-deps trl peft accelerate bitsandbytes
%pip install datasets pydantic loguru

## 2. Clone Repository

In [None]:
# Fetch the project and make it the working directory; later cells read
# "data/train.jsonl" relative to it. Intended to run once per runtime.
!git clone https://github.com/ineedmoney527/fine-tuning-sec-fillings.git
%cd fine-tuning-sec-fillings

## 3. Load Model

In [None]:
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

# Hugging Face repo id of the base model. "unsloth/Qwen3" is not a checkpoint;
# the 8B Qwen3 model this notebook targets is published as "unsloth/Qwen3-8B".
MODEL_NAME = "unsloth/Qwen3-8B"
# Maximum tokens per training example; also passed to SFTConfig in section 7.
MAX_SEQ_LENGTH = 8192

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # let Unsloth pick the compute dtype for the current GPU
    load_in_4bit=True,  # QLoRA: quantized base weights; LoRA adapters are added in section 4
)
print(f"Loaded: {MODEL_NAME}")

## 4. Configure LoRA

In [None]:
# Wrap the 4-bit model with trainable LoRA adapters on every attention and MLP projection.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,  # scaling = lora_alpha / r = 2
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)

# Keep the tokenizer's native Qwen3 chat template (loaded with the model in section 3).
# Swapping in Unsloth's "qwen-2.5" template would render training text in Qwen2.5's
# format: that template has no notion of Qwen3's <think></think> block, which Qwen3's
# own template emits (empty) before non-thinking answers such as `/no_think` turns.
if tokenizer.pad_token is None:
    # Padding needs a token id; fall back to EOS only if the tokenizer defines no pad token.
    tokenizer.pad_token = tokenizer.eos_token
print("LoRA configured")

## 5. Load Data

In [None]:
import json
from datasets import Dataset

# One JSON object per non-blank line; each object is expected to carry a "messages"
# list of {"role": ..., "content": ...} dicts (the loop below relies on that shape).
# The context manager closes the file; the encoding is explicit so the result does
# not depend on the runtime's locale.
with open("data/train.jsonl", encoding="utf-8") as jsonl_file:
    data = [json.loads(line) for line in jsonl_file if line.strip()]

formatted_data = []
for ex in data:
    messages = []
    for m in ex["messages"]:
        content = m["content"]
        # Append Qwen3's `/no_think` switch to every user turn so the model is trained
        # in non-thinking mode; other roles pass through unchanged.
        if m["role"] == "user":
            content = content + " /no_think"
        messages.append({"role": m["role"], "content": content})
    # Render the whole conversation (assistant answer included) as one training string.
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    formatted_data.append({"text": text})

dataset = Dataset.from_list(formatted_data)
print(f"Loaded {len(dataset)} examples")

## 6. Custom Trainer

In [None]:
import os
import torch
from trl import SFTTrainer, SFTConfig

# CRITICAL: CustomFinancialTrainer.compute_loss reads outputs.logits, so ask Unsloth to
# return logits during training (reportedly off by default since Nov 2024 --
# NOTE(review): confirm against Unsloth's release notes).
os.environ.update({"UNSLOTH_RETURN_LOGITS": "1"})

# JSON field names whose tokens receive the "json_key" loss weight.
FINANCIAL_KEYS = {
    "revenue",
    "net_income",
    "operating_income",
    "total_assets",
    "cash_and_equivalents",
    "diluted_eps",
    "value",
    "unit",
}

# Per-token loss multipliers used by CustomFinancialTrainer.
WEIGHTS = dict(
    json_key=2.0,        # tokens of FINANCIAL_KEYS
    number=1.5,          # digits, "." and "-"
    json_structure=1.2,  # { } [ ] : " ,
    default=1.0,         # baseline weight for all other tokens
)

class CustomFinancialTrainer(SFTTrainer):
    """SFTTrainer with a token-weighted cross-entropy loss for JSON financial extraction.

    Each target token's cross-entropy is multiplied by a weight from ``WEIGHTS``:

    * ``json_key``       -- token ids produced when encoding the names in ``FINANCIAL_KEYS``;
    * ``json_structure`` -- token ids of the single characters ``{}[]:",``;
    * ``number``         -- token ids of the single characters ``0123456789.-``;
    * ``default``        -- every other target token.

    When an id falls in several sets, key beats structure beats number. Positions
    labelled -100 get weight 0, and the loss is the weighted mean over the rest.
    Membership is by token id, so a multi-character token that merges punctuation
    with neighbouring text is weighted only if its id is in one of the sets.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # transformers >= 4.46 stores the tokenizer as `processing_class`; older releases as `tokenizer`.
        tok = getattr(self, "processing_class", None)
        if tok is None:
            tok = self.tokenizer

        def encode(text):
            return tok.encode(text, add_special_tokens=False)

        self.struct_tokens = {i for ch in '{}[]:",' for i in encode(ch)}
        self.digit_tokens = {i for ch in "0123456789.-" for i in encode(ch)}
        key_tokens = {i for key in FINANCIAL_KEYS for i in encode(f'"{key}"') + encode(key)}
        # Encoding '"revenue"' typically yields a standalone quote token as well. Key weights
        # take precedence, so keeping it would give every '"' the json_key weight instead of
        # json_structure; drop ids that are structural punctuation.
        self.key_tokens = key_tokens - self.struct_tokens

        # compute_loss returns a per-micro-batch weighted mean and ignores num_items_in_batch.
        # Tell Trainer so it divides by gradient_accumulation_steps itself; otherwise, for
        # models whose forward accepts loss kwargs, it skips that division and accumulated
        # gradients come out gradient_accumulation_steps times too large.
        self.model_accepts_loss_kwargs = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the token-weighted cross-entropy, plus the model outputs if requested.

        ``num_items_in_batch`` is accepted for Trainer compatibility but unused: the loss
        is normalised by the sum of token weights in this micro-batch.
        """
        labels = inputs.get("labels")
        if labels is None:
            labels = inputs["input_ids"].clone()

        # Forward without labels; otherwise the model would also compute its own
        # (discarded) full-vocabulary loss.
        outputs = model(**{k: v for k, v in inputs.items() if k != "labels"})
        logits = outputs.logits
        vocab_size = logits.size(-1)

        # Shift labels left and pad the last position with -100 instead of slicing off the
        # last logit: the (prediction, target) pairs are identical, but the large logits
        # tensor is not copied.
        shift_labels = torch.nn.functional.pad(labels[..., 1:], (0, 1), value=-100).to(logits.device)

        def is_in(token_ids):
            lookup = torch.tensor(sorted(token_ids), dtype=shift_labels.dtype, device=shift_labels.device)
            return torch.isin(shift_labels, lookup)

        weights = torch.full(
            shift_labels.shape, WEIGHTS["default"], dtype=torch.float32, device=shift_labels.device
        )
        # Assign the lowest priority first so key > structure > number on overlap.
        weights[is_in(self.digit_tokens)] = WEIGHTS["number"]
        weights[is_in(self.struct_tokens)] = WEIGHTS["json_structure"]
        weights[is_in(self.key_tokens)] = WEIGHTS["json_key"]
        weights[shift_labels == -100] = 0.0

        # cross_entropy's default ignore_index=-100 gives ignored positions a loss of 0.
        per_token = torch.nn.functional.cross_entropy(
            logits.reshape(-1, vocab_size), shift_labels.reshape(-1), reduction="none"
        )
        flat_weights = weights.reshape(-1)
        # Weighted mean over non-ignored positions. clamp_min avoids a host sync and 0/0
        # when every label is -100 (the loss is then 0).
        loss = (per_token * flat_weights).sum() / flat_weights.sum().clamp_min(1e-8)
        return (loss, outputs) if return_outputs else loss

print("CustomFinancialTrainer ready (logits enabled)")

## 7. Train

In [None]:
from unsloth import is_bfloat16_supported

# With dtype=None (section 3) Unsloth loads bfloat16 weights on GPUs that support it and
# float16 otherwise, and it raises an error if the AMP setting contradicts the weight
# dtype. Derive fp16/bf16 from the hardware instead of hard-coding fp16=True.
use_bf16 = is_bfloat16_supported()

training_args = SFTConfig(
    output_dir="./outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # 2 x 4 = 8 sequences per optimizer step per device
    warmup_steps=5,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=not use_bf16,
    bf16=use_bf16,
    logging_steps=1,
    save_strategy="epoch",
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=42,
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
    dataset_text_field="text",
)

# `processing_class` is the current name of the tokenizer argument in transformers/TRL;
# CustomFinancialTrainer reads the tokenizer back from self.processing_class.
trainer = CustomFinancialTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    args=training_args,
)
trainer.train()
print("Training complete!")