# Fine‑tuning Gemma 3 270M (IT) for Plain vs Technical Classification

Este cuaderno está listo para **Google Colab**. Cubre:
- Login a Hugging Face (gated access) y smoke test del repo.
- Carga de CSVs (`text`, `label`), tokenización y `DataCollatorWithPadding`.
- Modelo `Gemma3TextForSequenceClassification` para clasificación.
- Entrenamiento, evaluación (métricas + matriz de confusión).
- Guardado/recarga e inferencia con softmax.
- (Opcional) push al Hub privado.

**Requisitos previos**: aceptar la licencia de `google/gemma-3-270m-it` en Hugging Face y disponer de un token de acceso de Hugging Face con permiso de lectura.

## 1) Instalación de dependencias

In [None]:
# Remove any existing transformers install, then install it straight from the
# GitHub repository. --no-deps keeps pip from touching its dependencies here.
!pip -q uninstall -y transformers >/dev/null
!pip -q install --no-deps -U git+https://github.com/huggingface/transformers.git >/dev/null
# Install/upgrade the remaining libraries; -q and >/dev/null silence pip's output.
!pip -q install -U huggingface_hub datasets accelerate evaluate scikit-learn sentencepiece bitsandbytes matplotlib >/dev/null

# Import the core libraries and print their versions for reference.
import transformers, datasets, huggingface_hub
print('Transformers:', transformers.__version__)
print('Datasets:', datasets.__version__)
print('HF Hub:', huggingface_hub.__version__)

## 2) Login a Hugging Face (gated)

In [None]:
from getpass import getpass
from huggingface_hub import hf_hub_download, login, whoami

# Read the access token with getpass and strip stray whitespace from the paste.
HF_TOKEN = getpass('Pega tu token de Hugging Face (empieza por hf_): ').strip()
login(HF_TOKEN)
print(f"Usuario: {whoami().get('name')}")

# Smoke test: fetching config.json only succeeds if this token can read the
# gated repository.
_ = hf_hub_download(
    repo_id='google/gemma-3-270m-it',
    filename='config.json',
    token=HF_TOKEN,
)
print('OK: acceso a google/gemma-3-270m-it')

## 3) Imports y utilidades

In [None]:
import os, numpy as np, pandas as pd, torch
import matplotlib.pyplot as plt
from datasets import Dataset, DatasetDict, Value
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report, roc_auc_score
from transformers import AutoTokenizer, AutoConfig, DataCollatorWithPadding, TrainingArguments, Trainer
try:
    from transformers import Gemma3TextForSequenceClassification
except Exception:
    from transformers.models.gemma3.modeling_gemma3 import Gemma3TextForSequenceClassification

# Seed every random number generator used in this notebook so that repeated
# runs start from the same state. Python's built-in `random` module is seeded
# as well, since libraries may draw from it.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## 4) Carga de datos (CSV con columnas `text`, `label`)

In [None]:
# EDIT THESE PATHS SO THEY POINT AT YOUR CSV FILES
TRAIN_CSV = '/content/train.csv'
VAL_CSV   = '/content/val.csv'
TEST_CSV  = '/content/test.csv'

# Columns every split must provide.
REQUIRED_COLUMNS = ('text', 'label')


def _load_split(csv_path):
    """Read one CSV split and fail fast when a required column is absent.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The loaded ``pandas.DataFrame``.

    Raises:
        ValueError: If any of ``REQUIRED_COLUMNS`` is missing; the message
            names both the file and the missing columns.
    """
    frame = pd.read_csv(csv_path)
    missing = [col for col in REQUIRED_COLUMNS if col not in frame.columns]
    if missing:
        # An explicit exception (instead of `assert`) still fires under
        # `python -O` and tells the user which file needs fixing.
        raise ValueError(f"{csv_path} is missing required column(s): {missing}")
    return frame


train_df = _load_split(TRAIN_CSV)
val_df   = _load_split(VAL_CSV)
test_df  = _load_split(TEST_CSV)

print('train:', train_df.shape, '| val:', val_df.shape, '| test:', test_df.shape)
train_df.head(2)

## 5) Tokenizador y preparación de `DatasetDict`

In [None]:
# Checkpoint to fine-tune and the two-class label scheme.
model_id = 'google/gemma-3-270m-it'
id2label = {0: 'plain', 1: 'technical'}
label2id = {name: idx for idx, name in id2label.items()}
num_labels = len(id2label)

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    use_fast=True,
    token=HF_TOKEN,
)
# Padding needs a pad token; reuse EOS when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def ensure_int_labels(df, mapping=None):
    """Return a copy of ``df`` whose ``label`` column holds int64 class ids.

    Non-numeric labels (e.g. ``'plain'`` / ``'technical'``) are translated
    through ``mapping``; numeric labels are only cast to ``int64``. The input
    frame is left untouched.

    Args:
        df: DataFrame with a ``label`` column.
        mapping: Dict from label name to integer id. Defaults to the
            notebook-level ``label2id``.

    Returns:
        A new DataFrame with ``label`` as ``int64``.

    Raises:
        ValueError: If a label is missing, or is a name not present in
            ``mapping``.
    """
    if mapping is None:
        mapping = label2id

    out = df.copy()
    labels = out['label']

    # Check the dtype family rather than comparing against 'object', so that
    # pandas string/categorical columns are mapped as well.
    if not pd.api.types.is_numeric_dtype(labels):
        mapped = labels.map(mapping)
        unmapped = mapped.isna()
        if unmapped.any():
            unknown = sorted({str(value) for value in labels[unmapped]})
            raise ValueError(
                f"Unknown label value(s) {unknown}; expected one of {sorted(mapping)}"
            )
        labels = mapped
    elif labels.isna().any():
        raise ValueError("The 'label' column contains missing values")

    out['label'] = labels.astype('int64')
    return out

# Normalise the label column of every split, in train/val/test order.
train_df, val_df, test_df = (
    ensure_int_labels(frame) for frame in (train_df, val_df, test_df)
)

# Wrap each frame in a `Dataset` (pandas index dropped) and pin the label
# column to int64.
train_ds, val_ds, test_ds = [
    Dataset.from_pandas(frame, preserve_index=False).cast_column('label', Value('int64'))
    for frame in (train_df, val_df, test_df)
]

def tokenize_function(batch, max_length=2048):
    """Tokenize the ``text`` field of a batch with an explicit length cap.

    Args:
        batch: Mapping with a ``text`` entry, as passed by ``Dataset.map``
            with ``batched=True``.
        max_length: Maximum number of tokens kept per example. The default
            matches the ``max_length=2048`` default of the inference helpers
            in this notebook, so training and inference truncate alike.

    Returns:
        The tokenizer output for the batch.
    """
    # NOTE(review): with `truncation=True` alone the cap comes from the
    # tokenizer's own `model_max_length`, which may be effectively unbounded
    # for this checkpoint - confirm. An explicit cap keeps sequence length
    # (and memory use) predictable.
    return tokenizer(batch['text'], truncation=True, max_length=max_length)

# Columns that survive tokenization; everything else is dropped by `map`.
keep_cols = ['text', 'label']
rm_train, rm_val, rm_test = (
    [col for col in split.column_names if col not in keep_cols]
    for split in (train_ds, val_ds, test_ds)
)

# Tokenize each split in batches and collect the results under the usual
# split names.
tokenized = DatasetDict({
    split_name: split.map(tokenize_function, batched=True, remove_columns=extra_cols)
    for split_name, split, extra_cols in (
        ('train', train_ds, rm_train),
        ('validation', val_ds, rm_val),
        ('test', test_ds, rm_test),
    )
})

# Collator that pads each batch using the tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
print(tokenized)

## 6) Modelo Gemma3 Text para clasificación

In [None]:
# Start from the checkpoint's config and attach the label metadata to it.
config = AutoConfig.from_pretrained(model_id, token=HF_TOKEN)
config.num_labels = num_labels
config.id2label = id2label
config.label2id = label2id

model = Gemma3TextForSequenceClassification.from_pretrained(
    model_id,
    config=config,
    token=HF_TOKEN,
)
# Give the model the same pad token id the tokenizer uses.
model.config.pad_token_id = tokenizer.pad_token_id

# Report the first parameter whose name contains 'score' (the classifier head).
score_entry = next(
    ((name, weight) for name, weight in model.named_parameters() if 'score' in name),
    None,
)
if score_entry is not None:
    print('Capa de clasificación:', score_entry[0], score_entry[1].shape)

## 7) Entrenamiento con `Trainer`

In [None]:
# bf16 is enabled only on a CUDA device whose compute-capability major version
# is at least 8; fp16 stays disabled.
if torch.cuda.is_available():
    bf16_flag = torch.cuda.get_device_capability(0)[0] >= 8
else:
    bf16_flag = False
fp16_flag = False

def compute_metrics(eval_pred):
    """Compute accuracy and binary precision/recall/F1 for the Trainer.

    Args:
        eval_pred: Pair ``(predictions, labels)``; the predictions are reduced
            to class ids with an argmax over the last axis.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    scores, gold = eval_pred
    predicted = scores.argmax(-1)
    precision, recall, f1_value, _support = precision_recall_fscore_support(
        gold, predicted, average='binary', zero_division=0
    )
    return {
        'accuracy': accuracy_score(gold, predicted),
        'precision': precision,
        'recall': recall,
        'f1': f1_value,
    }

# Evaluate and checkpoint once per epoch so that `metric_for_best_model` has
# scores to compare, and reload the checkpoint with the best validation F1
# when training ends (save and eval strategies must match for that).
training_args = TrainingArguments(
    output_dir='runs_gemma3_270m_cls',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    greater_is_better=True,
    report_to='none',
    bf16=bf16_flag,
    fp16=fp16_flag,
)

# `processing_class` replaces the deprecated `tokenizer` keyword of Trainer;
# the tokenizer is still saved together with the model.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized['train'],
    eval_dataset=tokenized['validation'],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Fine-tune, then score the validation split; the metrics dict is the cell output.
train_result = trainer.train()
metrics_val = trainer.evaluate(tokenized['validation'])
metrics_val

## 8) Guardar y recargar

In [None]:
# Directory that receives the fine-tuned model and its tokenizer.
save_dir = 'finetuned_gemma3_270m_cls'
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)

# Reload both from the saved directory; the inference cells below use
# `tok_ft` / `model_ft` rather than the in-memory training objects.
tok_ft = AutoTokenizer.from_pretrained(save_dir, use_fast=True)
model_ft = Gemma3TextForSequenceClassification.from_pretrained(save_dir)
# Keep the reloaded model's pad token id in sync with the reloaded tokenizer.
model_ft.config.pad_token_id = tok_ft.pad_token_id
print('Recargado desde:', save_dir)

## 9) Inferencia con softmax

In [None]:
import torch.nn.functional as F

def classify_text(text: str, max_length=2048):
    """Classify one text with the reloaded model and return class probabilities.

    Args:
        text: Input text.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        Dict with ``label_id`` (predicted class index), ``label`` (its name),
        ``probs`` (label name -> softmax probability) and ``logits`` (raw
        logits as a nested list).
    """
    enc = tok_ft(text, return_tensors='pt', truncation=True, max_length=max_length)
    # Inputs must live on the same device as the model; this is a no-op while
    # the model is on CPU and keeps the helper working if it is moved to a GPU.
    enc = enc.to(model_ft.device)
    with torch.no_grad():
        logits = model_ft(**enc).logits
        probs_t = F.softmax(logits, dim=-1).squeeze(0)
    pred_id = int(logits.argmax(-1).item())

    # `id2label` may be a dict (possibly with string keys) or a plain sequence;
    # normalise it to {int id: label name} once.
    id2label_cfg = model_ft.config.id2label
    if isinstance(id2label_cfg, dict):
        label_map = {int(k) if isinstance(k, str) else k: v for k, v in id2label_cfg.items()}
    else:
        label_map = dict(enumerate(id2label_cfg))

    probs = {label_map[i]: prob for i, prob in enumerate(probs_t.tolist())}
    return {
        'label_id': pred_id,
        'label': label_map[pred_id],
        'probs': probs,
        'logits': logits.tolist(),
    }

classify_text('This text should look very plain and simple to read.')

## 10) Métricas en validation + Matriz de confusión

In [None]:
# Run the trained model over the validation split.
pred = trainer.predict(tokenized['validation'])
logits = pred.predictions
y_true = pred.label_ids
y_pred = np.argmax(logits, axis=1)

# Softmax over the class axis; subtracting the row maximum first keeps
# np.exp from overflowing.
m = logits.max(axis=1, keepdims=True)
probs = np.exp(logits - m)
probs = probs / probs.sum(axis=1, keepdims=True)
p_pos = probs[:, 1]  # probability of class 1 ('technical')

acc = accuracy_score(y_true, y_pred)
p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)

# ROC AUC is only computed when both classes occur in y_true. A failure is
# reported instead of being silently swallowed; `auc` then stays None.
auc = None
if len(np.unique(y_true)) == 2:
    try:
        auc = roc_auc_score(y_true, p_pos)
    except ValueError as err:
        print('ROC AUC skipped:', err)

print({'accuracy': round(acc,4), 'precision_tech': round(p,4), 'recall_tech': round(r,4), 'f1_tech': round(f1,4), 'auc_roc_tech': None if auc is None else round(auc,4)})

cm = confusion_matrix(y_true, y_pred, labels=[0,1])
# Row-normalise; a class with no true samples yields a zero row instead of a
# division by zero (NaN plus a RuntimeWarning).
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

def plot_cm(cm, labels, title):
    """Draw a confusion matrix as an annotated heat map and show it.

    Args:
        cm: 2-D array of cell values (counts or normalised rates).
        labels: Tick labels, used for both axes.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(4.5, 4))
    ax.imshow(cm, interpolation='nearest')
    ax.set_title(title)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')

    tick_positions = np.arange(len(labels))
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Cells above half the maximum value get white text, the rest black.
    half_max = cm.max() / 2.0
    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        cell = cm[row, col]
        # Float cells are shown with three decimals, anything else as-is.
        caption = f"{cell:.3f}" if isinstance(cell, float) else f"{cell}"
        text_color = 'white' if cell > half_max else 'black'
        ax.text(col, row, caption, ha='center', va='center', color=text_color)

    fig.tight_layout()
    plt.show()

# Build an {int id: label name} lookup from the reloaded model's config, which
# may store `id2label` as a dict (possibly with string keys) or as a sequence.
id2label_cfg = model_ft.config.id2label
if not isinstance(id2label_cfg, dict):
    id2label_plot = dict(enumerate(id2label_cfg))
else:
    id2label_plot = {
        (int(key) if isinstance(key, str) else key): name
        for key, name in id2label_cfg.items()
    }

# Plot raw counts first, then the row-normalised matrix, with the same labels.
class_names = [id2label_plot[0], id2label_plot[1]]
for matrix, heading in (
    (cm, 'Confusion Matrix (counts)'),
    (cm_norm, 'Confusion Matrix (row-normalized)'),
):
    plot_cm(matrix, class_names, heading)

## 11) (Opcional) Subir al Hugging Face Hub (privado)

In [None]:
from huggingface_hub import create_repo, whoami

# The account name returned by whoami() becomes the namespace of the target repo.
me = whoami(); username = me['name']
repo_id = f"{username}/gemma3-270m-plaintech-ft"
print('Creando/subiendo a:', repo_id)
# Create the repo as private; exist_ok=True lets the cell be re-run when the
# repo already exists.
create_repo(repo_id, private=True, exist_ok=True)

# Load the saved tokenizer and model from the local directory...
tok_local = AutoTokenizer.from_pretrained('finetuned_gemma3_270m_cls', use_fast=True)
mdl_local = Gemma3TextForSequenceClassification.from_pretrained('finetuned_gemma3_270m_cls')

# ...and upload both to the same Hub repo.
tok_local.push_to_hub(repo_id)
mdl_local.push_to_hub(repo_id)
print('Subido a HF Hub como:', repo_id)

## 12) (Opcional) Demo rápida en esta sesión

In [None]:
import ipywidgets as widgets
from IPython.display import display

def classify_text_runtime(text: str, max_length=2048):
    """Classify one text for the interactive demo.

    Args:
        text: Input text.
        max_length: Maximum number of tokens kept after truncation.

    Returns:
        Dict with ``label`` (predicted label name) and ``probs`` (label name ->
        softmax probability rounded to four decimals).
    """
    # Tokenize and place the inputs on the model's device (a no-op on CPU).
    enc = tok_ft(text, return_tensors='pt', truncation=True, max_length=max_length)
    enc = enc.to(model_ft.device)
    with torch.no_grad():
        logits = model_ft(**enc).logits
        probs_t = torch.softmax(logits, dim=-1).squeeze(0)
    pred_id = int(torch.argmax(probs_t).item())

    # Normalise `id2label` (dict, possibly with string keys, or a sequence)
    # to {int id: label name}.
    id2label_cfg = model_ft.config.id2label
    if isinstance(id2label_cfg, dict):
        label_map = {int(k) if isinstance(k, str) else k: v for k, v in id2label_cfg.items()}
    else:
        label_map = dict(enumerate(id2label_cfg))

    rounded = {label_map[i]: round(prob, 4) for i, prob in enumerate(probs_t.tolist())}
    return {'label': label_map[pred_id], 'probs': rounded}

# Widgets: a text box for the input, a trigger button and an output area.
ta = widgets.Textarea(
    value='Type or paste your text here...',
    layout=widgets.Layout(width='100%', height='120px'),
)
btn = widgets.Button(description='Classify')
out = widgets.Output()

def on_click(_):
    """Classify the current text-box content and print the result into ``out``."""
    out.clear_output()
    with out:
        result = classify_text_runtime(ta.value)
        for caption, field in (('Prediction:', 'label'), ('Probabilities:', 'probs')):
            print(caption, result[field])

# Register the callback and render the widgets.
btn.on_click(on_click)
display(ta, btn, out)