In [None]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    AutoConfig
)
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Hyperparameters and run configuration ---
MODEL_NAME = "FacebookAI/xlm-roberta-large"  # Hugging Face checkpoint used for tokenizer and model
MAX_LEN = 512  # maximum number of tokens per example; longer texts are truncated
BATCH_SIZE = 2  # training micro-batch size (validation uses 2x, test prediction 4x)
LEARNING_RATE = 2e-5  # learning rate of the classifier head and of the top encoder layer
EPOCHS = 8  # upper bound on training epochs (early stopping may end sooner)
SEED = 42  # seed for torch / NumPy RNGs and for the train/validation split
TRAIN_VAL_SPLIT = 0.15  # fraction of the data held out for validation
WEIGHT_DECAY = 0.01  # applied to encoder/embedding weights other than bias and LayerNorm.weight
ACCUMULATION_STEPS = 4  # micro-batches accumulated per optimizer step
LLRD_RATE = 0.9  # multiplicative learning-rate decay per encoder layer (top to bottom)
SPAN_MERGE_DISTANCE = 1  # predicted spans at most this many characters apart are merged
PATIENCE = 3  # epochs without validation-F1 improvement before early stopping

# Seed the torch and NumPy random number generators (and every CUDA device, if present).
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

class WeightedFocalLoss(nn.Module):
    """
    Class-weighted focal loss for token classification.

    Down-weights well-classified tokens so that training focuses on hard
    examples, and applies a per-class weight to counter class imbalance.

    Args:
        alpha: Per-class weights, indexed by class id
            (e.g. ``(O_weight, B_weight, I_weight)``).
        gamma: Focusing parameter (>= 0). Higher values focus more on hard
            examples; ``0`` reduces to weighted cross-entropy.
        ignore_index: Target value marking positions excluded from the loss.
    """
    def __init__(self, alpha=(0.1, 0.45, 0.45), gamma=2.0, ignore_index=-100):
        super().__init__()
        # Registered as a (non-persistent) buffer so that ``.to(device)`` on the
        # module moves the weights together with it.
        self.register_buffer(
            "alpha", torch.as_tensor(alpha, dtype=torch.float), persistent=False
        )
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, inputs, targets):
        """
        Compute the mean focal loss over all non-ignored positions.

        Args:
            inputs: Logits of shape ``(N, num_classes)``.
            targets: Class ids of shape ``(N,)``; entries equal to
                ``ignore_index`` are skipped.

        Returns:
            A scalar tensor. When every target is ignored the result is a zero
            that is still connected to ``inputs``, so ``backward()`` yields
            zero gradients instead of leaving them unset.
        """
        mask = targets != self.ignore_index
        valid_inputs = inputs[mask]
        valid_targets = targets[mask]

        if valid_targets.numel() == 0:
            # Sum over an empty selection is 0 but keeps the autograd graph.
            return valid_inputs.sum() * 0.0

        # No-op when the module was already moved to the right device.
        alpha = self.alpha.to(inputs.device)

        log_probs = self.log_softmax(valid_inputs)
        # Log-probability assigned to the true class of each token.
        true_log_probs = log_probs.gather(1, valid_targets.unsqueeze(1)).squeeze(1)
        true_probs = torch.exp(true_log_probs)

        alpha_t = alpha[valid_targets]
        focal_loss = alpha_t * torch.pow(1 - true_probs, self.gamma) * (-true_log_probs)

        return focal_loss.mean()

def compute_span_f1(true_spans, pred_spans):
    """
    Compute span-level precision, recall and F1 using an overlap criterion.

    A predicted span counts as correct when it overlaps at least one true
    span; a true span counts as found when it overlaps at least one predicted
    span. Spans are ``(start, end)`` pairs with an exclusive end; duplicates
    are collapsed. Lists and tuples are both accepted as span containers.

    Returns:
        ``(precision, recall, f1)``. Special cases: both inputs empty gives
        ``(1.0, 1.0, 1.0)``; only ``true_spans`` empty gives
        ``(0.0, 1.0, 0.0)``; only ``pred_spans`` empty gives
        ``(1.0, 0.0, 0.0)``.
    """
    # Normalise to tuples so that list-typed spans are hashable.
    true_spans = {tuple(span) for span in true_spans}
    pred_spans = {tuple(span) for span in pred_spans}

    if not true_spans and not pred_spans:
        return 1.0, 1.0, 1.0
    if not true_spans:
        return 0.0, 1.0, 0.0
    if not pred_spans:
        return 1.0, 0.0, 0.0

    def _overlaps(a, b):
        return max(a[0], b[0]) < min(a[1], b[1])

    # Precision counts matched predictions; recall counts matched true spans.
    matched_preds = sum(
        1 for p_span in pred_spans
        if any(_overlaps(p_span, t_span) for t_span in true_spans)
    )
    matched_trues = sum(
        1 for t_span in true_spans
        if any(_overlaps(p_span, t_span) for p_span in pred_spans)
    )

    precision = matched_preds / len(pred_spans)
    recall = matched_trues / len(true_spans)
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return precision, recall, f1

def load_data(file_path):
    """Read the parquet file at ``file_path`` and return it as a DataFrame."""
    return pd.read_parquet(file_path)

def process_numpy_trigger_words(trigger_array):
    """
    Turn a NumPy array of trigger-word offsets into a list of (start, end) tuples.

    Two layouts are understood: a numeric 2-D array of shape ``(n, 2)``, and a
    1-D object array whose items are 2-element lists, tuples or arrays.
    Anything else (including non-arrays, 0-d and empty arrays) yields ``[]``.
    Pairs that cannot be converted to ints, or where ``start >= end``, are
    dropped.
    """
    if not isinstance(trigger_array, np.ndarray):
        return []
    if trigger_array.ndim == 0 or trigger_array.size == 0:
        return []

    def _as_span(raw_start, raw_end):
        # Returns a valid (start, end) tuple, or None when the pair is unusable.
        try:
            lo, hi = int(raw_start), int(raw_end)
        except (ValueError, TypeError):
            return None
        return (lo, hi) if lo < hi else None

    if trigger_array.ndim == 2 and trigger_array.shape[1] == 2:
        candidates = [
            row for row in trigger_array
            if np.issubdtype(row.dtype, np.number) and len(row) == 2
        ]
    elif trigger_array.ndim == 1 and trigger_array.dtype == 'object':
        candidates = [
            item for item in trigger_array
            if isinstance(item, (list, tuple, np.ndarray)) and len(item) == 2
        ]
    else:
        candidates = []

    converted = (_as_span(pair[0], pair[1]) for pair in candidates)
    return [span for span in converted if span is not None]


def align_tokens_and_spans(tokenizer, text, spans, max_length=MAX_LEN):
    """
    Tokenize ``text`` and project character-level spans onto tokens (BIO scheme).

    Args:
        tokenizer: Callable tokenizer that supports ``return_offsets_mapping``.
        text: Raw input string.
        spans: Iterable of ``(start, end)`` character spans (end exclusive).
            May be empty or ``None``; spans with ``start >= end`` are ignored.
        max_length: Truncation length in tokens.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``labels``,
        ``offset_mapping`` and the original ``text``. Labels are
        ``0`` = Outside, ``1`` = Beginning, ``2`` = Inside and ``-100`` for
        tokens whose offsets are ``(0, 0)`` (special tokens). Special tokens
        are masked whether or not the example has spans, so they never
        contribute to the loss.
    """
    encoded = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length,
        padding=False
    )
    input_ids = encoded["input_ids"]
    offset_mapping = encoded["offset_mapping"]

    labels = [0] * len(input_ids)

    if spans:
        sorted_spans = sorted((s for s in spans if s[0] < s[1]), key=lambda s: s[0])
    else:
        sorted_spans = []

    span_idx = 0
    for i, (start, end) in enumerate(offset_mapping):
        if start == end == 0:
            # Special token: exclude from the loss.
            labels[i] = -100
            continue

        # Skip leading spans that end before this token starts; offsets are
        # non-decreasing, so those spans cannot match any later token either.
        while span_idx < len(sorted_spans) and sorted_spans[span_idx][1] <= start:
            span_idx += 1

        token_label = 0
        for k in range(span_idx, len(sorted_spans)):
            span_start, span_end = sorted_spans[k]
            if span_start >= end:
                break
            if max(start, span_start) < min(end, span_end):
                # The token containing the span's first character is B; any
                # token that begins after the span start is I.
                token_label = 1 if span_start >= start else 2
                break

        labels[i] = token_label

    return {
        "input_ids": input_ids,
        "attention_mask": encoded["attention_mask"],
        "labels": labels,
        "offset_mapping": offset_mapping,
        "text": text,
    }


class ManipulationSpanDataset(Dataset):
    """
    Dataset of texts with character-level spans, pre-tokenized at construction.

    Every example is tokenized and BIO-aligned once in ``__init__`` via
    ``align_tokens_and_spans``; ``__getitem__`` only converts the cached lists
    to tensors. Span entries that are not lists are treated as "no spans" for
    alignment, while ``self.spans`` keeps the values as given.
    """

    def __init__(self, texts, spans, tokenizer, max_len=MAX_LEN):
        self.texts = texts
        self.spans = spans
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.encodings = []

        print(f"Tokenizing and aligning spans for {len(texts)} examples...")
        progress = tqdm(zip(texts, spans), total=len(texts), desc="Processing dataset")
        for text, span_list in progress:
            usable_spans = span_list if isinstance(span_list, list) else []
            self.encodings.append(
                align_tokens_and_spans(tokenizer, text, usable_spans, max_len)
            )

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        record = self.encodings[idx]
        # Model inputs become long tensors; offsets and text stay plain Python.
        sample = {
            key: torch.tensor(record[key], dtype=torch.long)
            for key in ("input_ids", "attention_mask", "labels")
        }
        sample["offset_mapping"] = record["offset_mapping"]
        sample["text"] = record["text"]
        return sample


# Per-process cache of pad token ids, keyed by model name.
_PAD_TOKEN_ID_CACHE = {}


def _default_pad_token_id():
    """Return the pad token id of ``MODEL_NAME``, loading its tokenizer at most once per process."""
    if MODEL_NAME not in _PAD_TOKEN_ID_CACHE:
        _PAD_TOKEN_ID_CACHE[MODEL_NAME] = AutoTokenizer.from_pretrained(MODEL_NAME).pad_token_id
    return _PAD_TOKEN_ID_CACHE[MODEL_NAME]


def collate_fn(batch, pad_token_id=None):
    """
    Pad a list of dataset items to the longest sequence in the batch.

    Args:
        batch: Items with 1-D long tensors ``input_ids``, ``attention_mask``
            and ``labels``, a list ``offset_mapping`` and a string ``text``.
        pad_token_id: Id used to pad ``input_ids``. When ``None`` the pad id
            of ``MODEL_NAME``'s tokenizer is used; it is looked up once and
            cached instead of reloading the tokenizer for every batch.

    Returns:
        Dict with stacked ``input_ids``, ``attention_mask`` and ``labels``
        tensors plus the per-example ``offset_mappings`` and ``texts`` lists.
        Padding uses attention mask ``0``, label ``-100`` and offset ``(0, 0)``.
    """
    if pad_token_id is None:
        pad_token_id = _default_pad_token_id()

    max_len = max(len(item["input_ids"]) for item in batch)

    input_ids_padded = []
    attention_mask_padded = []
    labels_padded = []
    offset_mappings = []
    texts = []

    for item in batch:
        padding_len = max_len - len(item["input_ids"])

        input_ids_padded.append(torch.cat([
            item["input_ids"],
            torch.full((padding_len,), pad_token_id, dtype=torch.long),
        ]))
        attention_mask_padded.append(torch.cat([
            item["attention_mask"],
            torch.zeros(padding_len, dtype=torch.long),
        ]))
        labels_padded.append(torch.cat([
            item["labels"],
            torch.full((padding_len,), -100, dtype=torch.long),
        ]))
        offset_mappings.append(item["offset_mapping"] + [(0, 0)] * padding_len)
        texts.append(item["text"])

    return {
        "input_ids": torch.stack(input_ids_padded),
        "attention_mask": torch.stack(attention_mask_padded),
        "labels": torch.stack(labels_padded),
        "offset_mappings": offset_mappings,
        "texts": texts,
    }

def tokens_to_char_spans(tokenizer, text, token_preds, offset_mapping, merge_distance=1):
    """
    Decode BIO token predictions into merged character-level spans.

    Tokens with offsets ``(0, 0)`` and tokens without a prediction are skipped.
    Tag ``1`` opens a new span, tag ``2`` extends the open span (or opens one
    if none is open), tag ``0`` closes it. Neighbouring spans whose gap is at
    most ``merge_distance`` characters are merged. ``tokenizer`` and ``text``
    are accepted for interface compatibility but not used.

    Returns:
        List of ``(start, end)`` tuples sorted by start.
    """
    raw_spans = []
    open_span = None  # [start, end] of the span being built, or None

    # zip stops at the shorter sequence, i.e. tokens lacking a prediction are ignored.
    for (tok_start, tok_end), tag in zip(offset_mapping, token_preds):
        if tok_start == 0 and tok_end == 0:
            continue

        if tag == 0:
            if open_span is not None:
                raw_spans.append(tuple(open_span))
            open_span = None
        elif tag == 1:
            if open_span is not None:
                raw_spans.append(tuple(open_span))
            open_span = [tok_start, tok_end]
        elif tag == 2:
            if open_span is None:
                open_span = [tok_start, tok_end]
            elif tok_start >= open_span[0]:
                open_span[1] = max(open_span[1], tok_end)
            else:
                raw_spans.append(tuple(open_span))
                open_span = [tok_start, tok_end]

    if open_span is not None:
        raw_spans.append(tuple(open_span))

    ordered = sorted((s for s in raw_spans if s[0] < s[1]), key=lambda s: s[0])
    if not ordered:
        return []

    merged = [ordered[0]]
    for candidate in ordered[1:]:
        last = merged[-1]
        if candidate[0] - last[1] <= merge_distance:
            merged[-1] = (last[0], max(last[1], candidate[1]))
        else:
            merged.append(candidate)

    return merged

def _decay_split_groups(module, lr, weight_decay, no_decay):
    """Return two optimizer groups for ``module``: weights with decay, then no-decay params."""
    decayed, undecayed = [], []
    for name, param in module.named_parameters():
        if any(marker in name for marker in no_decay):
            undecayed.append(param)
        else:
            decayed.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay, "lr": lr},
        {"params": undecayed, "weight_decay": 0.0, "lr": lr},
    ]


def get_optimizer_grouped_parameters(
    model, learning_rate, weight_decay, layerwise_lr_decay_rate
):
    """
    Build optimizer parameter groups with Layer-wise Learning Rate Decay (LLRD).

    Args:
        model: Model exposing ``base_model_prefix``; the attribute of that
            name must provide ``encoder.layer`` and ``embeddings``.
        learning_rate: Learning rate of the classifier/pooler group and of the
            top encoder layer.
        weight_decay: Weight decay for encoder and embedding parameters whose
            name contains neither ``bias`` nor ``LayerNorm.weight``.
        layerwise_lr_decay_rate: Factor by which the learning rate shrinks for
            each layer further from the output; embeddings receive
            ``learning_rate * rate ** num_layers``.

    Returns:
        List of parameter-group dicts: the classifier/pooler group (no weight
        decay), then a decay and a no-decay group per encoder layer (bottom to
        top), then the two embedding groups.

    Raises:
        AssertionError: If some model parameter is not assigned to any group.
    """
    no_decay = ["bias", "LayerNorm.weight"]
    base_model = getattr(model, model.base_model_prefix)
    layers = base_model.encoder.layer

    num_layers = len(layers)
    print(f"Applying LLRD with rate {layerwise_lr_decay_rate} over {num_layers} layers.")

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters() if "classifier" in n or "pooler" in n
            ],
            "weight_decay": 0.0,
            "lr": learning_rate,
        },
    ]

    for i, layer in enumerate(layers):
        # The top layer (i == num_layers - 1) keeps the full learning rate.
        layer_lr = learning_rate * layerwise_lr_decay_rate ** (num_layers - 1 - i)
        optimizer_grouped_parameters += _decay_split_groups(
            layer, layer_lr, weight_decay, no_decay
        )

    embeddings_lr = learning_rate * layerwise_lr_decay_rate ** num_layers
    optimizer_grouped_parameters += _decay_split_groups(
        base_model.embeddings, embeddings_lr, weight_decay, no_decay
    )

    # Sanity check: every model parameter must appear in some group. Parameters
    # are matched by identity through a single id -> name map.
    name_by_id = {id(p): n for n, p in model.named_parameters()}
    all_param_names = set(name_by_id.values())
    grouped_param_names = {
        name_by_id[id(param)]
        for group in optimizer_grouped_parameters
        for param in group["params"]
        if id(param) in name_by_id
    }
    missing = sorted(all_param_names - grouped_param_names)
    assert all_param_names == grouped_param_names, (
        f"Not all parameters were assigned to optimizer groups! Missing: {missing}"
    )

    return optimizer_grouped_parameters


def _true_spans_by_text(dataset):
    """Map each text of ``dataset`` to the gold spans of its first occurrence."""
    lookup = {}
    try:
        pairs = zip(dataset.texts, dataset.spans)
    except AttributeError:
        # Dataset exposes no texts/spans; every lookup then misses and warns.
        return lookup
    for text, spans in pairs:
        # setdefault keeps the first occurrence, matching ``list.index``.
        lookup.setdefault(text, spans)
    return lookup


def _gold_spans_for(lookup, text, description):
    """Return the gold spans recorded for ``text``; warn and return ``[]`` if unknown."""
    if text in lookup:
        return lookup[text]
    print(f"Warning: Could not find true spans for {description}: {text[:50]}...")
    return []


def train_model(model, train_dataloader, val_dataloader, optimizer, scheduler, device, epochs, tokenizer, patience=PATIENCE, accumulation_steps=ACCUMULATION_STEPS, span_merge_distance=SPAN_MERGE_DISTANCE):
    """
    Fine-tune ``model`` with gradient accumulation and early stopping on span F1.

    Each epoch trains on ``train_dataloader`` with ``WeightedFocalLoss``,
    stepping the optimizer and scheduler every ``accumulation_steps`` batches
    (and on the last batch) after clipping gradients to norm 1.0, then
    evaluates on ``val_dataloader``. The weights with the best mean
    per-example validation F1 are restored before returning.

    Args:
        model: Token-classification model returning ``.logits``.
        train_dataloader, val_dataloader: Loaders yielding batches with
            ``input_ids``, ``attention_mask``, ``labels``, ``offset_mappings``
            and ``texts``; their datasets are expected to expose ``texts`` and
            ``spans`` for looking up gold spans.
        optimizer, scheduler: Optimizer and learning-rate scheduler.
        device: Device the batches are moved to.
        epochs: Maximum number of epochs.
        tokenizer: Passed through to ``tokens_to_char_spans``.
        patience: Epochs without validation-F1 improvement before stopping.
        accumulation_steps: Batches accumulated per optimizer step.
        span_merge_distance: Merge distance used when decoding spans.

    Returns:
        The model, loaded with the best validation-F1 weights (or left at the
        last-epoch weights if validation F1 never rose above zero).
    """
    best_val_f1 = 0.0
    best_model_state = None
    early_stop_counter = 0

    criterion = WeightedFocalLoss(alpha=[0.1, 0.45, 0.45], gamma=2.0, ignore_index=-100).to(device)

    # Gold spans are looked up by text; building the maps once avoids an O(N)
    # list scan per sample. NOTE: duplicated texts share the spans of their
    # first occurrence, exactly as the previous ``list.index`` lookup did.
    train_gold = _true_spans_by_text(train_dataloader.dataset)
    val_gold = _true_spans_by_text(val_dataloader.dataset)
    num_train_batches = len(train_dataloader)

    for epoch in range(epochs):
        print(f"\n--- Epoch {epoch+1}/{epochs} ---")

        # ---------------- Training ----------------
        model.train()
        total_train_loss = 0
        all_train_preds_spans = []
        all_train_true_spans = []

        train_pbar = tqdm(train_dataloader, desc=f"Training Epoch {epoch+1}", leave=False)
        optimizer.zero_grad()

        for step, batch in enumerate(train_pbar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            offset_mappings = batch["offset_mappings"]
            texts = batch["texts"]

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            loss = criterion(logits.view(-1, model.config.num_labels), labels.view(-1))

            # Scale so that accumulated gradients average over the micro-batches.
            loss = loss / accumulation_steps
            total_train_loss += loss.item() * accumulation_steps

            loss.backward()

            if (step + 1) % accumulation_steps == 0 or (step + 1) == num_train_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            token_predictions = torch.argmax(logits, dim=2).cpu().numpy()

            for i in range(len(texts)):
                pred_spans = tokens_to_char_spans(
                    tokenizer, texts[i], token_predictions[i], offset_mappings[i], merge_distance=span_merge_distance
                )
                all_train_preds_spans.append(pred_spans)
                all_train_true_spans.append(_gold_spans_for(train_gold, texts[i], "text"))

            train_pbar.set_postfix({
                "loss": f"{loss.item() * accumulation_steps:.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}"
            })

        avg_train_loss = total_train_loss / num_train_batches

        train_f1s = [
            compute_span_f1(true, pred)[2]
            for true, pred in zip(all_train_true_spans, all_train_preds_spans)
        ]
        avg_train_f1 = np.mean(train_f1s) if train_f1s else 0.0
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f} | Train F1: {avg_train_f1:.4f}")

        # ---------------- Validation ----------------
        model.eval()
        total_val_loss = 0
        all_val_preds_spans = []
        all_val_true_spans = []

        val_pbar = tqdm(val_dataloader, desc=f"Validation Epoch {epoch+1}", leave=False)

        with torch.no_grad():
            for batch in val_pbar:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                offset_mappings = batch["offset_mappings"]
                texts = batch["texts"]

                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                logits = outputs.logits

                loss = criterion(logits.view(-1, model.config.num_labels), labels.view(-1))
                total_val_loss += loss.item()

                token_predictions = torch.argmax(logits, dim=2).cpu().numpy()
                for i in range(len(texts)):
                    pred_spans = tokens_to_char_spans(
                        tokenizer, texts[i], token_predictions[i], offset_mappings[i], merge_distance=span_merge_distance
                    )
                    all_val_preds_spans.append(pred_spans)
                    all_val_true_spans.append(_gold_spans_for(val_gold, texts[i], "validation text"))

        avg_val_loss = total_val_loss / len(val_dataloader)
        val_f1s = []
        val_precisions = []
        val_recalls = []
        for true, pred in zip(all_val_true_spans, all_val_preds_spans):
            p, r, f1 = compute_span_f1(true, pred)
            val_precisions.append(p)
            val_recalls.append(r)
            val_f1s.append(f1)

        avg_val_precision = np.mean(val_precisions) if val_precisions else 0.0
        avg_val_recall = np.mean(val_recalls) if val_recalls else 0.0
        avg_val_f1 = np.mean(val_f1s) if val_f1s else 0.0

        print(f"Epoch {epoch+1} Val Loss: {avg_val_loss:.4f} | Val Precision: {avg_val_precision:.4f} | Val Recall: {avg_val_recall:.4f} | Val F1: {avg_val_f1:.4f}")

        # ---------------- Checkpointing / early stopping ----------------
        if avg_val_f1 > best_val_f1:
            best_val_f1 = avg_val_f1
            # state_dict() returns references to the live parameters, so a
            # shallow dict copy would keep changing as training continues.
            # Clone every tensor (to CPU, to avoid doubling GPU memory).
            best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"✨ New best model saved with F1: {best_val_f1:.4f}!")

            early_stop_counter = 0
        else:
            early_stop_counter += 1
            print(f"No F1 improvement ({avg_val_f1:.4f} vs best {best_val_f1:.4f}). Counter: {early_stop_counter}/{patience}")
            if early_stop_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    if best_model_state is not None:
        print(f"Loading best model state with F1: {best_val_f1:.4f}")
        model.load_state_dict(best_model_state)
    else:
        print("Warning: No best model state found (e.g., validation F1 never improved). Using model from last epoch.")

    return model


def predict_spans(model, tokenizer, texts, device, max_len=MAX_LEN, batch_size=16, span_merge_distance=SPAN_MERGE_DISTANCE):
    """
    Predict character-level spans for a list of texts.

    Texts are tokenized lazily, padded per batch to the longest sequence, run
    through ``model`` in eval mode without gradients, and the argmax BIO tags
    are decoded with ``tokens_to_char_spans``.

    Args:
        model: Token-classification model returning ``.logits``.
        tokenizer: Tokenizer supporting ``return_offsets_mapping``.
        texts: List of input strings.
        device: Device the batches are moved to.
        max_len: Truncation length in tokens.
        batch_size: Number of texts per forward pass.
        span_merge_distance: Merge distance used when decoding spans.

    Returns:
        One list of ``(start, end)`` spans per input text, in input order.
    """
    model.eval()
    all_char_spans = []

    class InferenceDataset(Dataset):
        """Tokenizes one text per item, without padding."""

        def __init__(self, texts, tokenizer, max_len):
            self.texts = texts
            self.tokenizer = tokenizer
            self.max_len = max_len

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            text = self.texts[idx]
            encoding = self.tokenizer(
                text,
                return_offsets_mapping=True,
                add_special_tokens=True,
                truncation=True,
                max_length=self.max_len,
                padding=False,
                return_tensors=None,
            )
            return {
                "text": text,
                "input_ids": encoding["input_ids"],
                "attention_mask": encoding["attention_mask"],
                "offset_mapping": encoding["offset_mapping"]
            }

    def inference_collate_fn(batch):
        # Pad every example to the longest sequence of this batch. Padded
        # positions get offset (0, 0), which the span decoder skips.
        max_batch_len = max(len(item["input_ids"]) for item in batch)
        pad_token_id = tokenizer.pad_token_id

        input_ids_padded = []
        attention_mask_padded = []
        offset_mappings_padded = []

        for item in batch:
            padding_len = max_batch_len - len(item["input_ids"])
            input_ids_padded.append(torch.tensor(
                item["input_ids"] + [pad_token_id] * padding_len, dtype=torch.long
            ))
            attention_mask_padded.append(torch.tensor(
                item["attention_mask"] + [0] * padding_len, dtype=torch.long
            ))
            offset_mappings_padded.append(item["offset_mapping"] + [(0, 0)] * padding_len)

        return {
            "texts": [item["text"] for item in batch],
            "input_ids": torch.stack(input_ids_padded),
            "attention_mask": torch.stack(attention_mask_padded),
            "offset_mappings": offset_mappings_padded
        }

    inference_dataset = InferenceDataset(texts, tokenizer, max_len)
    inference_dataloader = DataLoader(inference_dataset, batch_size=batch_size, shuffle=False, collate_fn=inference_collate_fn)

    print(f"Predicting spans for {len(texts)} texts with batch size {batch_size}...")
    with torch.no_grad():
        for batch in tqdm(inference_dataloader, desc="Prediction"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            offset_mappings_batch = batch["offset_mappings"]
            texts_batch = batch["texts"]

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            token_predictions = torch.argmax(outputs.logits, dim=2).cpu().numpy()

            for i in range(len(texts_batch)):
                current_offset_mapping = offset_mappings_batch[i]
                current_predictions = token_predictions[i][:len(current_offset_mapping)]

                char_spans = tokens_to_char_spans(
                    tokenizer, texts_batch[i], current_predictions, current_offset_mapping, merge_distance=span_merge_distance
                )
                all_char_spans.append(char_spans)

    return all_char_spans

def format_spans_for_submission(spans_list):
    """
    Render spans as the submission string, e.g. '[(start1, end1), (start2, end2)]'.

    Falsy input (empty list, ``None``) yields ``"[]"``; span bounds are cast to
    built-in ints before formatting.
    """
    if spans_list:
        int_pairs = [(int(span[0]), int(span[1])) for span in spans_list]
        return str(int_pairs)
    return "[]"


def main():
    """
    Run the full pipeline: load data, train, save the model, write a submission.

    Reads the training parquet and test CSV from the Kaggle input directory,
    fine-tunes ``MODEL_NAME`` for span identification with LLRD and a linear
    warmup schedule, saves model and tokenizer to
    ``./manipulation_span_model_improved`` and writes ``submission.csv``.
    """
    print("Loading training data...")
    train_df = pd.read_parquet("/kaggle/input/unlp-2025-shared-task-span-identification/train.parquet")

    print("Processing trigger words...")
    # Rows whose 'trigger_words' is not an ndarray (e.g. missing) become an empty array,
    # then every array is converted to a list of (start, end) tuples.
    train_df['trigger_words'] = train_df['trigger_words'].fillna('').apply(lambda x: x if isinstance(x, np.ndarray) else np.array([]))
    train_df['trigger_words_processed'] = train_df['trigger_words'].apply(process_numpy_trigger_words)

    texts = train_df['content'].tolist()
    spans = train_df['trigger_words_processed'].tolist()
    print(f"Total raw examples: {len(texts)}")

    # Keep only examples whose text is a non-blank string.
    valid_indices = [i for i, txt in enumerate(texts) if isinstance(txt, str) and len(txt.strip()) > 0]
    texts = [texts[i] for i in valid_indices]
    spans = [spans[i] for i in valid_indices]
    print(f"Using {len(texts)} non-empty text examples for training/validation.")


    print(f"Splitting data into train/validation ({1-TRAIN_VAL_SPLIT:.0%}/{TRAIN_VAL_SPLIT:.0%})...")

    train_texts, val_texts, train_spans, val_spans = train_test_split(
        texts, spans, test_size=TRAIN_VAL_SPLIT, random_state=SEED
    )
    print(f"Training examples: {len(train_texts)}")
    print(f"Validation examples: {len(val_texts)}")


    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    print("Creating datasets (this may take a while)...")
    # Tokenization and BIO alignment happen eagerly inside the dataset constructors.
    train_dataset = ManipulationSpanDataset(train_texts, train_spans, tokenizer, max_len=MAX_LEN)
    val_dataset = ManipulationSpanDataset(val_texts, val_spans, tokenizer, max_len=MAX_LEN)

    print("Creating dataloaders...")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=2
    )
    # Validation uses twice the training batch size and keeps the original order.
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE * 2,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=2
    )

    print("Loading model...")
    # Three labels: 0 = Outside, 1 = Beginning, 2 = Inside.
    config = AutoConfig.from_pretrained(MODEL_NAME, num_labels=3)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, config=config)
    model.to(device)


    print(f"Setting up AdamW optimizer with LLRD (Rate: {LLRD_RATE})...")
    optimizer_parameters = get_optimizer_grouped_parameters(
        model,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        layerwise_lr_decay_rate=LLRD_RATE
    )
    # Each group carries its own "lr"; the lr passed here is only the default.
    optimizer = torch.optim.AdamW(optimizer_parameters, lr=LEARNING_RATE, eps=1e-8)


    print("Setting up learning rate scheduler...")
    # Optimizer steps per epoch = ceil(batches / ACCUMULATION_STEPS), matching the
    # training loop, which also steps on the final (possibly partial) accumulation.
    num_update_steps_per_epoch = (len(train_dataloader) + ACCUMULATION_STEPS - 1) // ACCUMULATION_STEPS
    total_steps = num_update_steps_per_epoch * EPOCHS
    num_warmup_steps = int(0.1 * total_steps)  # first 10% of the steps are warmup

    print(f"Total optimization steps: {total_steps}, Warmup steps: {num_warmup_steps}")
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps
    )


    print("Starting training...")
    model = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=EPOCHS,
        tokenizer=tokenizer,
        patience=PATIENCE,
        accumulation_steps=ACCUMULATION_STEPS,
        span_merge_distance=SPAN_MERGE_DISTANCE
    )


    print("Saving the fine-tuned model...")
    output_dir = "./manipulation_span_model_improved"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")


    print("Loading test data...")
    test_df = pd.read_csv("/kaggle/input/unlp-2025-shared-task-span-identification/test.csv")
    test_texts = test_df['content'].tolist()
    test_ids = test_df['id'].tolist()

    # Only non-blank string texts are sent to the model; the remaining ids
    # receive an empty prediction when the submission is assembled below.
    valid_test_indices = [i for i, txt in enumerate(test_texts) if isinstance(txt, str) and len(txt.strip()) > 0]
    test_texts_filtered = [test_texts[i] for i in valid_test_indices]
    test_ids_filtered = [test_ids[i] for i in valid_test_indices]
    print(f"Predicting on {len(test_texts_filtered)} non-empty test examples.")


    print("Making predictions on test data...")
    predictions = predict_spans(
        model=model,
        tokenizer=tokenizer,
        texts=test_texts_filtered,
        device=device,
        max_len=MAX_LEN,
        batch_size=BATCH_SIZE * 4,
        span_merge_distance=SPAN_MERGE_DISTANCE
    )

    # Predictions come back in input order, so they can be zipped with the filtered ids.
    prediction_map = {id_: spans for id_, spans in zip(test_ids_filtered, predictions)}

    print("Formatting predictions for submission...")
    submission_data = []
    # Iterate over ALL test ids so the submission has one row per test example.
    for id_ in test_ids:
        spans = prediction_map.get(id_, [])
        submission_data.append({
            'id': id_,
            'trigger_words': format_spans_for_submission(spans)
        })

    submission_df = pd.DataFrame(submission_data)
    submission_df.to_csv("submission.csv", index=False)
    print("Submission file 'submission.csv' saved successfully!")


# Run the pipeline only when executed as a script / notebook cell, not on import.
if __name__ == "__main__":
    main()

Using device: cuda
Loading training data...
Processing trigger words...
Total raw examples: 3822
Using 3822 non-empty text examples for training/validation.
Splitting data into train/validation (85%/15%)...
Training examples: 3248
Validation examples: 574
Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

Creating datasets (this may take a while)...
Tokenizing and aligning spans for 3248 examples...


Processing dataset:   0%|          | 0/3248 [00:00<?, ?it/s]

Tokenizing and aligning spans for 574 examples...


Processing dataset:   0%|          | 0/574 [00:00<?, ?it/s]

Creating dataloaders...
Loading model...


model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at FacebookAI/xlm-roberta-large and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Setting up AdamW optimizer with LLRD (Rate: 0.9)...
Applying LLRD with rate 0.9 over 24 layers.
Setting up learning rate scheduler...
Total optimization steps: 3248, Warmup steps: 324
Starting training...

--- Epoch 1/8 ---


Training Epoch 1:   0%|          | 0/1624 [00:00<?, ?it/s]

Epoch 1 Train Loss: 0.0518 | Train F1: 0.4314


Validation Epoch 1:   0%|          | 0/144 [00:00<?, ?it/s]

Epoch 1 Val Loss: 0.0336 | Val Precision: 0.4246 | Val Recall: 0.9721 | Val F1: 0.4786
✨ New best model saved with F1: 0.4786!

--- Epoch 2/8 ---


Training Epoch 2:   0%|          | 0/1624 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Epoch 2 Train Loss: 0.0349 | Train F1: 0.5077


Validation Epoch 2:   0%|          | 0/144 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>Traceback (most recent call last):

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
Traceback (most recent call last):
    self._shutdown_workers()  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
        if w.is_alive():if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a

Epoch 2 Val Loss: 0.0317 | Val Precision: 0.5223 | Val Recall: 0.9528 | Val F1: 0.5621
✨ New best model saved with F1: 0.5621!

--- Epoch 3/8 ---


Training Epoch 3:   0%|          | 0/1624 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Epoch 3 Train Loss: 0.0295 | Train F1: 0.5693


Validation Epoch 3:   0%|          | 0/144 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>Exception ignored in: 
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__

Traceback (most recent call last):
      File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
self._shutdown_workers()    
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    self._shutdown_workers()if w.is_alive():

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only test a child process'if w.is_alive():
AssertionError
:   File "/usr/lib/python3.10/multiprocessing/pro

Epoch 3 Val Loss: 0.0341 | Val Precision: 0.6404 | Val Recall: 0.9117 | Val F1: 0.6421
✨ New best model saved with F1: 0.6421!

--- Epoch 4/8 ---


Training Epoch 4:   0%|          | 0/1624 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50><function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/u

Epoch 4 Train Loss: 0.0255 | Train F1: 0.6106


Validation Epoch 4:   0%|          | 0/144 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
Exception ignored in:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>self._shutdown_workers()

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
        if w.is_alive():
self._shutdown_workers()  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive

    assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

AssertionError:     if w.is_alive():can only test a child process

  File "/usr/lib/

Epoch 4 Val Loss: 0.0384 | Val Precision: 0.6339 | Val Recall: 0.8999 | Val F1: 0.6321
No F1 improvement (0.6321 vs best 0.6421). Counter: 1/3

--- Epoch 5/8 ---


Training Epoch 5:   0%|          | 0/1624 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Epoch 5 Train Loss: 0.0217 | Train F1: 0.6575


Validation Epoch 5:   0%|          | 0/144 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers

    if w.is_alive():Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__

      File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()    
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
assert self._parent_pid == os.getpid(), 'can only test a child process'
    if w.is_alive():AssertionError
:   File "/usr/lib/python3.10/multiprocessing/pro

Epoch 5 Val Loss: 0.0470 | Val Precision: 0.6695 | Val Recall: 0.8816 | Val F1: 0.6454
✨ New best model saved with F1: 0.6454!

--- Epoch 6/8 ---


Training Epoch 6:   0%|          | 0/1624 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50><function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
<function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>Traceback (most recent call last):

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__



Epoch 6 Train Loss: 0.0183 | Train F1: 0.6891


Validation Epoch 6:   0%|          | 0/144 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Epoch 6 Val Loss: 0.0526 | Val Precision: 0.6511 | Val Recall: 0.8879 | Val F1: 0.6328
No F1 improvement (0.6328 vs best 0.6454). Counter: 1/3

--- Epoch 7/8 ---


Training Epoch 7:   0%|          | 0/1624 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Epoch 7 Train Loss: 0.0161 | Train F1: 0.7029


Validation Epoch 7:   0%|          | 0/144 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/

Epoch 7 Val Loss: 0.0554 | Val Precision: 0.6383 | Val Recall: 0.8829 | Val F1: 0.6176
No F1 improvement (0.6176 vs best 0.6454). Counter: 2/3

--- Epoch 8/8 ---


Training Epoch 8:   0%|          | 0/1624 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionErrorException ignored in: : can only test a child process<function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>

Traceback (most recent call last):
Exception ignored in:   File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>self._shutdown_workers()

Traceback (most recent call last):
  File "/usr/local/lib/pyt

Epoch 8 Train Loss: 0.0142 | Train F1: 0.7139


Validation Epoch 8:   0%|          | 0/144 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50><function _MultiProcessingDataLoaderIter.__del__ at 0x790c4f094e50>

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1604, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
        if w.is_alive():if w.is_alive():

  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
        assert self._parent_pid == os.getpid(), 'can only te

Epoch 8 Val Loss: 0.0674 | Val Precision: 0.6857 | Val Recall: 0.8649 | Val F1: 0.6467
✨ New best model saved with F1: 0.6467!
Loading best model state with F1: 0.6467
Saving the fine-tuned model...
Model saved to ./manipulation_span_model_improved
Loading test data...
Predicting on 5735 non-empty test examples.
Making predictions on test data...
Predicting spans for 5735 texts with batch size 8...


Prediction:   0%|          | 0/717 [00:00<?, ?it/s]

Formatting predictions for submission...
Submission file 'submission.csv' saved successfully!
