<a href="https://colab.research.google.com/github/janbanot/msc-project/blob/main/test_notebooks/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install or upgrade the third-party libraries this notebook depends on.
!uv pip install --upgrade transformers datasets captum quantus accelerate

In [None]:
import os
import re
import pandas as pd
import numpy as np
import torch
from datasets import Dataset
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding
)

In [None]:
# Mount Google Drive at /drive; DATA_PATH and OUTPUT_MODEL_DIR point into it.
from google.colab import drive
drive.mount('/drive')

In [None]:
# --- CONFIGURATION ---
# Path to the CSV file with the Jigsaw data
DATA_PATH = '/drive/MyDrive/msc-project/jigsaw-toxic-comment/train.csv'
# Where to save the trained model
OUTPUT_MODEL_DIR = "/drive/MyDrive/msc-project/models/distilbert-jigsaw-full"
MODEL_CHECKPOINT = "distilbert-base-uncased"
BATCH_SIZE = 16          # Raise to 32 on a GPU with plenty of memory (e.g. A100); 16 is safe on a T4 (Colab free)
NUM_EPOCHS = 3           # 2-4 epochs is the usual range for BERT fine-tuning
LEARNING_RATE = 2e-5     # Standard learning rate for fine-tuning

In [None]:
# --- 1. DATA PREPARATION ---

def clean_text(example):
    """Normalize the ``comment_text`` field of one record, in place.

    The text is lower-cased, URL-like tokens (``http...`` / ``www...``) and
    dotted IPv4-looking number groups are removed, and every run of
    whitespace (newlines included) is collapsed to one space, then trimmed.

    Args:
        example: Mapping with a string under the ``comment_text`` key.

    Returns:
        The same ``example`` object, with ``comment_text`` replaced.
    """
    cleaned = example['comment_text'].lower()
    # Strip URLs first, then IP-like groups; order matches the patterns below.
    for pattern in (r'http\S+|www\S+', r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}'):
        cleaned = re.sub(pattern, '', cleaned)
    # `\s+` also covers newlines, so no separate newline replacement is needed.
    example['comment_text'] = re.sub(r'\s+', ' ', cleaned).strip()
    return example

print(">>> Wczytywanie danych...")
# Load the full dataset. If this takes too long, .sample(n=50000) can be used instead
df = pd.read_csv(DATA_PATH)

# (Optional) Balance or limit the data if the set is too large for Colab
# df = df.sample(frac=1.0, random_state=42) # Shuffle
df = df.head(10000) # NOTE: this line is active - only the first 10,000 rows (unshuffled) are kept; comment it out to train on the full set

# Wrap the frame in a Hugging Face Dataset and normalize every comment.
dataset = Dataset.from_pandas(df)
dataset = dataset.map(clean_text)

# Tokenization
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
def tokenize_function(examples):
    """Tokenize a batch of comments, truncating each to at most 256 tokens.

    No padding is applied here: padding every row to the maximum length
    wastes memory and compute on short comments. Batches are padded to their
    longest member at collation time instead (the Trainer uses
    ``DataCollatorWithPadding`` when it is given a tokenizer).

    Args:
        examples: Batch mapping with a list of strings under ``comment_text``.

    Returns:
        The tokenizer output for the batch (variable-length sequences).
    """
    return tokenizer(examples["comment_text"], truncation=True, max_length=256)

print(">>> Tokenizacja...")
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Labelling (multi-label)
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
def format_labels(example):
    """Add a ``labels`` entry holding the target columns as a list of floats.

    The values are read from ``example`` in the order given by ``label_cols``.

    Args:
        example: Mapping that contains every column named in ``label_cols``.

    Returns:
        The same ``example`` object, with the ``labels`` key set.
    """
    raw_values = (example[name] for name in label_cols)
    example['labels'] = list(map(float, raw_values))
    return example

# Attach the label vectors, then drop every column the model does not need
tokenized_dataset = tokenized_dataset.map(format_labels).remove_columns(
    ['id', 'comment_text', *label_cols]
)
tokenized_dataset.set_format("torch")

# Train/eval split (10% held out, fixed seed)
splits = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = splits['train'], splits['test']

print(f"Dane gotowe. Train size: {len(train_ds)}, Eval size: {len(eval_ds)}")

In [None]:
# --- 2. MODEL AND METRICS DEFINITION ---

# Sequence-classification head on top of the pretrained checkpoint, configured
# for multi-label targets.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_CHECKPOINT,
    # Derived from the label list so the head size cannot drift from the data.
    num_labels=len(label_cols),
    problem_type="multi_label_classification"
)

def compute_metrics(p):
    """Compute multi-label evaluation metrics from raw model outputs.

    Args:
        p: A ``(predictions, labels)`` pair, where ``predictions`` holds the
            logits and ``labels`` the 0/1 targets, one column per class.

    Returns:
        Dict with ``f1_micro``, ``f1_macro``, ``accuracy`` (exact match over
        all labels of a row) and ``roc_auc`` (macro average; ``0.0`` when it
        is undefined).
    """
    predictions, labels = p
    # Multi-label: independent sigmoid per class (logits -> probabilities).
    # Convert to NumPy once so scikit-learn gets plain arrays everywhere.
    probs = torch.sigmoid(torch.tensor(predictions)).numpy()
    preds = (probs > 0.5).astype(int)

    # Metrics on the thresholded predictions
    f1_micro = f1_score(labels, preds, average='micro')
    f1_macro = f1_score(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)

    try:
        roc_auc = roc_auc_score(labels, probs, average='macro')
    except ValueError:
        # ROC AUC is undefined when a label has only one class in the
        # evaluated data; report 0.0 instead of aborting the evaluation.
        roc_auc = 0.0

    return {
        'f1_micro': f1_micro,
        'f1_macro': f1_macro,
        'accuracy': acc,
        'roc_auc': roc_auc
    }

In [None]:
# --- 3. TRAINING ---

training_args = TrainingArguments(
    output_dir=f"{OUTPUT_MODEL_DIR}_checkpoints",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=5,
    load_best_model_at_end=True,      # Reload the best checkpoint once training ends
    metric_for_best_model="f1_micro", # Criterion used to pick the best model
    save_total_limit=2,               # Keep at most 2 checkpoints (saves disk space)
    fp16=torch.cuda.is_available(),   # Mixed precision needs a CUDA GPU; stay in fp32 otherwise
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # `tokenizer=` is deprecated in recent transformers releases in favour of
    # `processing_class=`; this notebook installs the latest version.
    processing_class=tokenizer,
    # Explicit collator: pads each batch to its longest sequence.
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
)

print(">>> Rozpoczynanie treningu...")
trainer.train()

In [None]:
# --- 4. FINAL SAVE ---

print(f">>> Zapisywanie modelu do: {OUTPUT_MODEL_DIR}")
# Model weights and tokenizer files go to the same directory.
trainer.save_model(OUTPUT_MODEL_DIR)
tokenizer.save_pretrained(OUTPUT_MODEL_DIR)

# Run a final evaluation and record the resulting statistics
metrics = trainer.evaluate()
print("Final Metrics:", metrics)

with open(os.path.join(OUTPUT_MODEL_DIR, "training_results.txt"), "w") as f:
    f.write(f"{metrics}")