# Setup

In [1]:
import sys
sys.path.append('..')

In [2]:
from dataclasses import dataclass
from tqdm.autonotebook import tqdm
from utils.utils import set_seeds
import wandb

import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    TrainingArguments
)

set_seeds(seed=42)
tqdm.pandas()

  from tqdm.autonotebook import tqdm


# Config

In [3]:
@dataclass
class Config:
    data_path: str = "../../../data/unlp-2025/"
    cv_path: str = "../../../data/unlp-2025/cv_split.csv"
    
    pretrained: str = "google/gemma-3-4b-pt"
    max_length: int = 1024

    
    wandb_init_args = {
        'project': "sl-unlp-2025",
        'entity': "havlytskyi-thesis",
        'name': "gemma-3-4B"
    }

config = Config()

# Training Arguments

In [4]:
training_args = TrainingArguments(
    output_dir=f'./checkpoints/{config.wandb_init_args["name"]}',
    logging_dir=f'./logs/{config.wandb_init_args["name"]}',
    learning_rate=2e-5,
    weight_decay=0.01,
    lr_scheduler_type='cosine',
    warmup_ratio=0.0,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    # gradient_accumulation_steps=1,
    bf16=True,
    report_to="wandb",
    optim='adamw_torch',
    eval_strategy='steps',
    save_strategy="steps",
    eval_steps=100,
    logging_steps=10,
    save_steps=100,
    save_total_limit=10,
    metric_for_best_model='eval_f1',
    greater_is_better=True,
    load_best_model_at_end=True,
)

# Instantiate the tokenizer & model

In [5]:
from typing import Optional, Union
import copy

import torch
from torch import nn

from transformers.utils import logging
from transformers.cache_utils import Cache, StaticCache, HybridCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    MaskedLMOutput,
    TokenClassifierOutput
)
from transformers.processing_utils import Unpack

from transformers import Gemma3PreTrainedModel, Gemma3TextModel, Gemma3TextConfig, Gemma3ForCausalLM
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3TextScaledWordEmbedding,
    Gemma3DecoderLayer,
    Gemma3Attention,
    Gemma3MLP,
    Gemma3RMSNorm,
    Gemma3RotaryEmbedding, # TODO: try to customize this
)

from transformers.utils import (
    is_torch_flex_attn_available,
    logging,
)


class Gemma3ForTokenClassification(Gemma3PreTrainedModel):
    config_class = Gemma3TextConfig
    base_model_prefix = "language_model"

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma3TextModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs
    ) -> TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """

        outputs: BaseModelOutputWithPast = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [6]:
from transformers import LlamaForTokenClassification, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained(config.pretrained)


quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)


model = Gemma3ForTokenClassification.from_pretrained(
    config.pretrained,
    id2label={0: 0, 1: 1},
    label2id={0: 0, 1: 1},
    quantization_config=quant_config
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of Gemma3ForTokenClassification were not initialized from the model checkpoint at google/gemma-3-4b-pt and are newly initialized: ['score.bias', 'score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig, TaskType


lora_config = LoraConfig(
    r=64,  # the dimension of the low-rank matrices
    lora_alpha=128, # scaling factor for LoRA activations vs pre-trained weight activations
    lora_dropout=0.05, 
    bias='none',
    inference_mode=False,
    task_type=TaskType.CAUSAL_LM,
    target_modules=['o_proj', 'v_proj', "q_proj", "k_proj", "gate_proj", "down_proj", "up_proj"]
) 

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
# Trainable Parameters
model.print_trainable_parameters()


trainable params: 119,209,984 || all params: 3,999,478,274 || trainable%: 2.9806


# Data

In [8]:
import pandas as pd

df = pd.read_parquet(config.data_path + "train.parquet")
cv = pd.read_csv(config.cv_path)
df = df.merge(cv, on='id', how='left')

df_test = pd.read_csv(config.data_path + "test.csv")

In [9]:
from utils.data import preprocess_df

df.trigger_words = df.trigger_words.apply(lambda x: [] if x is None else x)

is_valid_mask = (df.fold == 4)
df_train = df[~is_valid_mask].copy()
df_valid = df[is_valid_mask].copy()


df_train = preprocess_df(df_train, tokenizer=tokenizer, max_length=config.max_length)
df_valid = preprocess_df(df_valid, tokenizer=tokenizer, max_length=None)
df_test = preprocess_df(df_test, tokenizer=tokenizer, max_length=None)

  0%|          | 0/3058 [00:00<?, ?it/s]

  0%|          | 0/764 [00:00<?, ?it/s]

  0%|          | 0/5735 [00:00<?, ?it/s]

In [10]:
train_columns = list(df_train.seq_labels.iloc[0].keys()) +\
                ['content', 'trigger_words']
test_columns = list(df_train.seq_labels.iloc[0].keys()) + ['content']

ds_train = Dataset.from_pandas(df_train[train_columns].reset_index(drop=True))
ds_valid = Dataset.from_pandas(df_valid[train_columns].reset_index(drop=True))
ds_test = Dataset.from_pandas(df_test[test_columns].reset_index(drop=True))

# Train

In [11]:
from itertools import chain

train_labels = df_train.labels.tolist() + df_valid.labels.tolist()
positive_class_balance = pd.Series(list(chain(*train_labels))).mean()

positive_class_balance

0.2294743405414578

In [12]:
from transformers import DataCollatorForTokenClassification
from utils.trainer import SpanIdentificationTrainer

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

trainer = SpanIdentificationTrainer(
    model=model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    data_collator=data_collator,
    tokenizer=tokenizer,
    desired_positive_ratio=positive_class_balance
)

  super().__init__(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [13]:
wandb.init(**config.wandb_init_args)
wandb.define_metric("*", summary="none")

trainer.train()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mivan-havlytskyi[0m ([33mivan-havlytskyiz[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss,Validation Loss,Precision,Recall,F1,Thold
100,0.4637,0.577363,0.55629,0.617501,0.5853,0.1
200,0.382,0.411677,0.513808,0.74566,0.608393,0.23
300,0.4452,0.416575,0.534237,0.724157,0.614865,0.2
400,0.3698,0.409037,0.517083,0.758719,0.615018,0.25
500,0.3404,0.414226,0.519052,0.75608,0.615536,0.31
600,0.2905,0.449333,0.529823,0.703783,0.604537,0.3
700,0.2621,0.453529,0.51632,0.726049,0.603482,0.23
800,0.2521,0.464003,0.519595,0.716443,0.602344,0.21
900,0.228,0.461278,0.517365,0.725583,0.604034,0.2


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-100)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-200)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-300)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-400)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-500)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-600)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-700)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-800)... Done. 1.5s
  return fn(*args, **kwargs)


  0%|          | 0/41 [00:00<?, ?it/s]

[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-900)... Done. 1.5s
  return fn(*args, **kwargs)
[34m[1mwandb[0m: Adding directory to artifact (./checkpoints/gemma-3-4B/checkpoint-960)... Done. 1.5s
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


TrainOutput(global_step=960, training_loss=0.34725696258246896, metrics={'train_runtime': 3983.7632, 'train_samples_per_second': 3.838, 'train_steps_per_second': 0.241, 'total_flos': 2.015598036588244e+17, 'train_loss': 0.34725696258246896, 'epoch': 5.0})

# Inference

## Checkpoint

In [14]:
from utils.metric import score as char_f1
from utils.utils import inference_aggregation

FINETUNED_MODEL = f'./checkpoints/{config.wandb_init_args["name"]}/checkpoint-500'

In [15]:
trainer._load_from_checkpoint(FINETUNED_MODEL)

## Threshold Selection

In [16]:
valid_preds = trainer.predict(ds_valid)
valid_metrics = trainer.compute_metrics((valid_preds.predictions, valid_preds.label_ids))

valid_metrics

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

{'precision': 0.5190520852144634,
 'recall': 0.7560799687728051,
 'f1': 0.615536066485439,
 'thold': 0.31}

In [17]:
from utils.utils import find_class_balance_threshold

test_preds = trainer.predict(ds_test)
test_probabilities = torch.softmax(torch.tensor(test_preds.predictions), dim=-1).cpu().numpy()

test_distr_th = find_class_balance_threshold(
    desired_positive_ratio=positive_class_balance,
    probabilities=test_probabilities,
    labels=test_preds.label_ids
    )

print(test_distr_th)

  0%|          | 0/41 [00:00<?, ?it/s]

0.495050505050505


In [18]:
final_th = valid_metrics['thold']

## CV-Score

In [19]:
valid_probabilities = torch.softmax(torch.tensor(valid_preds.predictions), dim=-1).cpu().numpy()
valid_results = inference_aggregation(
    probabilities=valid_probabilities,
    labels=valid_preds.label_ids,
    offset_mappings=ds_valid['offset_mapping'],
    thold=final_th
)

In [20]:
from copy import deepcopy

df_valid_gt = df[df.fold==4][['id', 'trigger_words']].reset_index(drop=True)
df_valid = deepcopy(df_valid_gt)
df_valid['trigger_words'] = valid_results

cv_score = char_f1(df_valid_gt, df_valid, row_id_column_name='id')
cv_score

0.615536066485439

## Predict Test

In [21]:
test_results = inference_aggregation(
    probabilities=test_probabilities,
    labels=test_preds.label_ids,
    offset_mappings=ds_test['offset_mapping'],
    thold=final_th
)

In [22]:
df_test_gt = pd.read_csv(config.data_path + 'solution.csv')[['id', 'trigger_words']]
df_test = deepcopy(df_test_gt)
df_test['trigger_words'] = test_results

test_score = char_f1(df_test_gt, df_test, row_id_column_name='id')
test_score

0.6051494792883939