<a href="https://colab.research.google.com/github/carloseduds/fine-tuning/blob/main/LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Fine Tuning via Transformers

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

In [None]:
# 1. Load the base model (e.g. a small Llama or Bloom checkpoint).
model_name = "bigscience/bloom-1b7"

# Tokenizer and weights are pulled from the same Hugging Face Hub checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" lets transformers place the weights on the available
# device(s) (this option relies on the `accelerate` package).
base_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

tokenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/715 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

In [None]:
# 2. Define the LoRA configuration.
# This decides where the low-rank adapters (the "post-its") get attached.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,  # training mode: the adapters are trainable
    r=16,                  # rank r: inner dimension of the low-rank A and B matrices
    lora_alpha=32,         # scaling numerator; updates are scaled by lora_alpha / r (= 2 here)
    lora_dropout=0.05,     # dropout applied to the input of the LoRA branch
    # target_modules: which submodules (matched by name) receive LoRA.
    # Models with separate q_proj / v_proj layers are often adapted on just
    # the query and value projections. Bloom instead uses ONE fused
    # "query_key_value" projection, so targeting it adapts the query, key
    # AND value projections together.
    target_modules=["query_key_value"],
)

In [None]:
# 3. Create the PEFT model.
# This wraps the base model and injects the LoRA A and B matrices into the
# layers selected by ``peft_config.target_modules``.
# NOTE(review): peft is expected to inject the adapters into ``base_model``
# in place, so ``base_model`` itself is modified too — confirm for the peft
# version in use.
model = get_peft_model(base_model, peft_config)

In [None]:
# 4. Check how much the number of trainable parameters was reduced.
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Sums ``numel()`` over every parameter of ``model`` and over the subset
    with ``requires_grad=True``, then prints both counts and the percentage.

    Args:
        model: any ``torch.nn.Module`` (e.g. the PEFT-wrapped model).

    Returns:
        ``(trainable_params, all_params)``, so callers can use the numbers
        directly; callers that ignore the return value are unaffected.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A module without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0

    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_param} || "
        f"trainable%: {trainable_pct:.2f}%"
    )
    return trainable_params, all_param

# Report trainable vs. total parameters of the PEFT-wrapped model; only the
# injected LoRA matrices should require gradients.
print_trainable_parameters(model)

trainable params: 3145728 || all params: 1725554688 || trainable%: 0.18%


## Fine Tuning via Classe

In [None]:
import torch
import torch.nn as nn
import math

class LoRALayer(nn.Module):
    """Low-rank LoRA update ``scaling * (dropout(x) @ A @ B)``.

    ``A`` (``in_dim x rank``) projects the input down to ``rank`` dimensions
    and ``B`` (``rank x out_dim``) projects it back up. ``B`` starts at zero,
    so the layer outputs zeros until training updates ``B``.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
        rank: rank ``r`` of the update; must be positive.
        alpha: scaling numerator; the output is multiplied by ``alpha / rank``.
        dropout: dropout probability applied to the LoRA input (0 disables it).

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, in_dim, out_dim, rank, alpha, dropout=0.0):
        super().__init__()
        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank}")
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank

        # A: down-projection into the rank-r space.
        self.A = nn.Parameter(torch.empty(in_dim, rank))
        # B: up-projection back to out_dim; zero so the update starts at 0.
        self.B = nn.Parameter(torch.zeros(rank, out_dim))

        # A uses Kaiming-uniform init (the scheme nn.Linear uses for its
        # weight), not a normal distribution.
        # NOTE(review): since A is stored as (in_dim, rank), kaiming_uniform_
        # takes size(1) == rank as fan_in (bound 1/sqrt(rank)); an
        # nn.Linear(in_dim, rank) weight would use in_dim instead.
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))

        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

    def forward(self, x):
        """Return the scaled low-rank update for ``x``.

        ``x`` has shape ``[..., in_dim]`` (any leading dimensions); the
        result has shape ``[..., out_dim]`` and the same dtype as ``x``.
        """
        input_dtype = x.dtype
        # Compute in the adapter's dtype (e.g. fp32 adapters on fp16
        # activations) so the matmul does not fail on mixed dtypes, then
        # cast back so the result can be added to the base layer's output.
        x = x.to(self.A.dtype)
        lora_out = self.dropout(x) @ self.A @ self.B
        return (self.scaling * lora_out).to(input_dtype)

In [None]:
class LinearWithLoRA(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable LoRA branch.

    The forward pass computes ``linear(x) + lora(x)``. The wrapped layer's
    own parameters (weight and bias) are frozen, so only the LoRA matrices
    receive gradients.

    Args:
        linear: the layer to adapt; stored as ``self.linear``.
        rank: LoRA rank ``r``.
        alpha: LoRA scaling numerator (the update is scaled by alpha / rank).
        dropout: dropout probability on the LoRA input path.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: int, dropout: float = 0.0):
        super().__init__()
        self.linear = linear
        # LoRALayer allocates its parameters without an explicit device; move
        # them next to the wrapped weight (e.g. a GPU chosen by
        # device_map="auto") so forward() never mixes devices.
        self.lora = LoRALayer(
            linear.in_features,
            linear.out_features,
            rank,
            alpha,
            dropout=dropout,
        ).to(device=linear.weight.device)

        # Freeze the original weights.
        self.linear.requires_grad_(False)

    def forward(self, x):
        """Return the frozen base projection plus the LoRA update."""
        return self.linear(x) + self.lora(x)

In [None]:
def replace_linear_with_lora(
    model: nn.Module,
    rank: int,
    alpha: int,
    dropout: float = 0.0,
    *,
    target_modules=None,
):
    """Recursively replace ``nn.Linear`` submodules of ``model`` with ``LinearWithLoRA``.

    ``model`` is modified in place and also returned for convenience.
    Only the wrapped linear layers are frozen; every other parameter keeps
    its ``requires_grad`` flag. To train only the adapters, freeze the model
    first (e.g. ``model.requires_grad_(False)``) and then call this function.

    Args:
        model: module whose children are rewritten in place.
        rank: LoRA rank, forwarded to ``LinearWithLoRA``.
        alpha: LoRA scaling numerator, forwarded to ``LinearWithLoRA``.
        dropout: LoRA dropout, forwarded to ``LinearWithLoRA``.
        target_modules: optional name or collection of child attribute
            names (e.g. ``["query_key_value"]``). When given, only
            ``nn.Linear`` layers registered under one of these names are
            wrapped. ``None`` (the default) wraps every ``nn.Linear``.

    Returns:
        The same ``model`` object.
    """
    if target_modules is None:
        targets = None
    elif isinstance(target_modules, str):
        # A bare string is one name, not a collection of characters.
        targets = {target_modules}
    else:
        targets = set(target_modules)

    # Snapshot the children so the registry is not mutated mid-iteration.
    for name, module in list(model.named_children()):
        if isinstance(module, LinearWithLoRA):
            # Already adapted: descending would wrap its inner ``linear``
            # again if this function is called more than once.
            continue
        if isinstance(module, nn.Linear):
            if targets is None or name in targets:
                setattr(model, name, LinearWithLoRA(module, rank, alpha, dropout))
        else:
            replace_linear_with_lora(
                module, rank, alpha, dropout, target_modules=targets
            )
    return model