In [1]:
!pip uninstall transformers -y
!pip install transformers==4.48.3 -qqq

Found existing installation: transformers 4.44.2
Uninstalling transformers-4.44.2:
  Successfully uninstalled transformers-4.44.2
[0m

In [1]:
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("TRANSFORMERS_CACHE"):
    os.environ["HF_HOME"] = os.environ.pop("TRANSFORMERS_CACHE")

import math
import re
import shutil
from typing import Any, Dict, List, Optional

import flash_attn
import pandas as pd
import tabulate
import torch
import torch.nn as nn
from datasets import Dataset, load_dataset
from flash_attn import flash_attn_qkvpacked_func
from huggingface_hub import Repository, whoami
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup,
)

In [2]:
# Constants
NUM_EXAMPLES_TO_TRAIN = 3000
MODEL_CHECKPOINT = "answerdotai/ModernBERT-base"
USERNAME = "emdemor"
TOKENIZER_PATH = "domain_tokenizer"
TESTING = True
FLASH_ATTENTION = False
PUSH_INTERVAL = 10_000 if TESTING else 100_000
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

mlm_probabilities = [0.3, 0.2, 0.18, 0.16, 0.14]

In [3]:
def load_and_prepare_data(num_examples: int = NUM_EXAMPLES_TO_TRAIN) -> Dataset:
    """Load and prepare the dataset for training."""
    dataset = load_dataset("emdemor/news-of-the-brazilian-newspaper", split="train")
    df = dataset.to_pandas().sample(frac=1).reset_index(drop=True)
    temp = df.sample(min(num_examples, len(df)))
    texts = temp["text"].to_list() + temp["title"].to_list()
    texts = [phrase for text in texts if text for phrase in split_into_sentences(text)]
    return Dataset.from_dict({"text": list(set(texts))[:num_examples]})


def split_into_sentences(text: str) -> List[str]:
    """Split text into sentences."""
    return [
        sentence.strip()
        for sentence in re.split(r"(?<=[.!?])\s+", text)
        if sentence.strip()
    ]


def set_attention(model):

    def check_flash_attention_support():
        if not torch.cuda.is_available():
            return False
        try:
            qkv = torch.randn(1, 1, 3, 16, 64, dtype=torch.float16, device="cuda")
            flash_attn_qkvpacked_func(qkv, causal=False)
            return True
        except RuntimeError as e:
            print("Flash Attention não é compatível:", str(e))
            return False

    if FLASH_ATTENTION and check_flash_attention_support():
        print("Replacing standard attention with FlashAttention...")
        for module in model.modules():
            if isinstance(module, nn.MultiheadAttention):
                module.attention = FlashAttention()
        print("FlashAttention integrated.")

    return model


def setup_model_and_tokenizer(
    model_checkpoint: str, tokenizer_path: str, device: torch.device
):
    """Setup model and tokenizer."""
    print(f"Loading custom tokenizer from {tokenizer_path}...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    print(f"Loading model config from {model_checkpoint}...")
    config = AutoConfig.from_pretrained(model_checkpoint)
    config.torch_dtype = torch.float16
    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, config=config).to(
        device
    )
    model = set_attention(model)
    return model, tokenizer


def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        # No truncation and max_length to allow dynamic padding truncation=True, max_length=chunk_size, padding="longest",
        return_special_tokens_mask=True,
    )


def tokenize_dataset(dataset):

    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )

    return tokenized_dataset

In [4]:
from dataclasses import dataclass


@dataclass
class ModelInfo:
    model_name: str
    output_dir: str


def get_model_info():
    model_name = MODEL_CHECKPOINT.split("/")[-1]
    model_info = ModelInfo(
        model_name=model_name,
        output_dir=f"{model_name}-ptbr-{'test' if TESTING else 'full'}",
    )
    if os.path.exists(model_info.output_dir):
        shutil.rmtree(model_info.output_dir)
    return model_info

In [5]:
import tabulate


@dataclass
class TrainingConfig:
    num_train_epochs: int
    chunk_size: int | None
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    eval_size_ratio: float
    total_save_limit: int
    estimated_dataset_size_in_rows: int

    @property
    def effective_batch_size(self):
        return self.per_device_train_batch_size * self.gradient_accumulation_steps

    @property
    def total_steps_per_epoch(self):
        return math.ceil(
            self.estimated_dataset_size_in_rows / self.effective_batch_size
        )

    @property
    def total_train_steps(self):
        return self.total_steps_per_epoch * self.num_train_epochs

    @property
    def eval_size_per_chunk(self):
        return int(self.estimated_dataset_size_in_rows * self.eval_size_ratio)

    def __repr__(self):
        data = [
            ["num_train_epochs", self.num_train_epochs],
            ["chunk_size", self.chunk_size],
            ["per_device_train_batch_size", self.per_device_train_batch_size],
            ["gradient_accumulation_steps", self.gradient_accumulation_steps],
            ["eval_size_ratio", self.eval_size_ratio],
            ["total_save_limit", self.total_save_limit],
            ["estimated_dataset_size_in_rows", self.estimated_dataset_size_in_rows],
            ["effective_batch_size", self.effective_batch_size],
            ["total_steps_per_epoch", self.total_steps_per_epoch],
            ["total_train_steps", self.total_train_steps],
            ["eval_size_per_chunk", self.eval_size_per_chunk],
        ]

        return tabulate.tabulate(data, headers=["Attribute", "Value"], tablefmt="grid")

In [6]:
# --- Helper Function to Fix Batch Inputs ---
def fix_batch_inputs(inputs: dict) -> dict:
    """
    Ensures that input tensors have the correct shape and dtype.
    - Removes any extra dimensions (e.g., [1, batch, seq_len] -> [batch, seq_len]).
    - Casts input_ids to torch.long.
    """
    for key in ["input_ids", "attention_mask", "token_type_ids"]:
        if key in inputs:
            if inputs[key].dim() == 3 and inputs[key].shape[0] == 1:
                inputs[key] = inputs[key].squeeze(0)
            elif inputs[key].dim() > 2:
                raise ValueError(
                    f"Unexpected tensor shape for {key}: {inputs[key].shape}"
                )
    if "input_ids" in inputs and inputs["input_ids"].dtype != torch.long:
        inputs["input_ids"] = inputs["input_ids"].long()
    return inputs


# --- Forward Pass Function ---
def forward_pass(model, inputs):
    """
    Performs a forward pass with autocast for FP16.
    Returns the loss.
    """
    inputs = fix_batch_inputs(inputs)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.amp.autocast("cuda", enabled=(DEVICE.type == "cuda")):
        outputs = model(**inputs, return_dict=True)
    if outputs.loss is None:
        raise ValueError("Model did not return a loss.")
    return outputs.loss


# --- Evaluation Function ---
def evaluate(model, eval_dataset, data_collator):
    """
    Evaluates the model on the evaluation dataset.
    Returns the average loss.
    """
    model.eval()
    losses = []
    eval_iterator = eval_dataset.iter(batch_size=per_device_train_batch_size)
    for batch in tqdm(eval_iterator, desc="Evaluating"):
        with torch.no_grad(), torch.amp.autocast(
            "cuda", enabled=(DEVICE.type == "cuda")
        ):
            inputs = data_collator(batch)
            try:
                loss = forward_pass(model, inputs)
                losses.append(loss.item())
            except Exception as e:
                print(f"Evaluation batch failed: {e}. Skipping.")
                continue
    model.train()
    average_loss = sum(losses) / len(losses) if losses else float("inf")
    return average_loss


class DynamicPaddingDataCollator(DataCollatorForLanguageModeling):
    """
    Data collator that dynamically pads the inputs for language modeling.
    This ensures that all sequences within a batch have the same length,
    but the overall length can vary between batches.
    """

    def __call__(self, examples: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        # Find the maximum length within the current batch
        max_length = max(len(input_ids) for input_ids in examples["input_ids"])

        # Pad or truncate each example to the max_length
        batch = []
        input_ids = examples["input_ids"]
        attention_mask = examples["attention_mask"]

        for ids, mask in zip(input_ids, attention_mask):
            padding_length = max_length - len(ids)
            if padding_length > 0:
                # Pad
                ids = torch.tensor(ids + [self.tokenizer.pad_token_id] * padding_length)
                mask = torch.tensor(mask + [0] * padding_length)
            elif padding_length <= 0:
                # Truncate (if enabled in your tokenizer)
                ids = torch.tensor(ids[:max_length])
                mask = torch.tensor(mask[:max_length])

            batch.append({"input_ids": ids, "attention_mask": mask})

        # Apply the rest of the data collation logic (MLM masking, etc.)
        batch = self.torch_call(
            batch
        )  # Use torch_call instead of __call__ to call the parent's method

        # Ensure correct shapes and dtypes
        batch = fix_batch_inputs(batch)

        return batch

In [7]:
dataset = load_and_prepare_data()

In [8]:
dataset_iterator = iter(dataset)

In [9]:
model, tokenizer = setup_model_and_tokenizer(
    model_checkpoint=MODEL_CHECKPOINT,
    tokenizer_path=TOKENIZER_PATH,
    device=DEVICE,
)

Loading custom tokenizer from domain_tokenizer...
Loading model config from answerdotai/ModernBERT-base...


You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [10]:
tokenized_dataset = tokenize_dataset(dataset)

Tokenizing dataset...


Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

In [11]:
model_info = get_model_info()

In [12]:
mlm_probabilities = [0.3, 0.2, 0.18, 0.16, 0.14]

chunk_size_dataset = len(dataset) // len(mlm_probabilities)

In [13]:
training_config = TrainingConfig(
    num_train_epochs=1,
    chunk_size=None,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    eval_size_ratio=0.05,
    total_save_limit=2,
    estimated_dataset_size_in_rows=len(dataset),
)

print(training_config)

+--------------------------------+---------+
| Attribute                      |   Value |
| num_train_epochs               |    1    |
+--------------------------------+---------+
| chunk_size                     |         |
+--------------------------------+---------+
| per_device_train_batch_size    |    4    |
+--------------------------------+---------+
| gradient_accumulation_steps    |    2    |
+--------------------------------+---------+
| eval_size_ratio                |    0.05 |
+--------------------------------+---------+
| total_save_limit               |    2    |
+--------------------------------+---------+
| estimated_dataset_size_in_rows | 3000    |
+--------------------------------+---------+
| effective_batch_size           |    8    |
+--------------------------------+---------+
| total_steps_per_epoch          |  375    |
+--------------------------------+---------+
| total_train_steps              |  375    |
+--------------------------------+---------+
| eval_siz

In [17]:
from torch.optim import AdamW

# --- Optimizer and Scheduler ---
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=training_config.total_train_steps
)

# --- AMP scaler for mixed precision ---
scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE.type == "cuda"))

model.train()
global_step = 0

for epoch in range(training_config.num_train_epochs):
    for i, mlm_probability in enumerate(mlm_probabilities):
        print(
            f"\nEpoch {epoch + 1}/{training_config.num_train_epochs}, MLM Probability: {mlm_probability}"
        )

        data_collator = DynamicPaddingDataCollator(
            tokenizer=tokenizer, mlm_probability=mlm_probability
        )

        train_dataset = (
            tokenized_dataset.skip(i * chunk_size_dataset + training_config.eval_size_per_chunk)
            .take(chunk_size_dataset)
            .shuffle(seed=42)
        )
        eval_dataset = tokenized_dataset.skip(i * chunk_size_dataset).take(
            training_config.eval_size_per_chunk
        )

        train_iterator = train_dataset.iter(batch_size=training_config.per_device_train_batch_size)
        for step, batch in enumerate(
            tqdm(train_iterator, desc=f"Training (MLM {mlm_probability})")
        ):
            try:
                inputs = data_collator(batch)
                loss = forward_pass(model, inputs)
            except Exception as e:
                print(f"Training batch failed: {e}. Skipping.")
                continue

            scaler.scale(loss / gradient_accumulation_steps).backward()

            if (step + 1) % gradient_accumulation_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
                torch.cuda.empty_cache()  # Clear cache
                global_step += 1

                # Evaluation
                eval_interval = total_steps_per_epoch // (num_train_epochs * 4)
                if eval_interval > 0 and (global_step % eval_interval == 0):
                    eval_loss = evaluate(model, eval_dataset, data_collator)
                    print(f"Evaluation loss at step {global_step}: {eval_loss}")

                # Push to hub incl TESTING
                if global_step % push_interval == 0:
                    print(f"Saving and pushing model at step {global_step}...")
                    model.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    print(f"Model saved and pushed at step {global_step}.")


Epoch 1/1, MLM Probability: 0.3


Training (MLM 0.3): 0it [00:00, ?it/s]

Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs o

Training (MLM 0.2): 0it [00:00, ?it/s]

Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs o

Training (MLM 0.18): 0it [00:00, ?it/s]

Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs o

Training (MLM 0.16): 0it [00:00, ?it/s]

Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs or newer.. Skipping.
Training batch failed: FlashAttention only supports Ampere GPUs o

IndexError: Index 599 out of range for dataset of size 450.