# Imports

In [None]:
%load_ext autoreload
%autoreload 2


import os
from typing import List
import json
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shutil
import sys
import logging 

# Root logger: INFO and above, prefixed with "[HH:MM:SS|LEVEL|module.py:line]".
logging.basicConfig(
    format='[%(asctime)s|%(levelname)s|%(module)s.py:%(lineno)s] %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)

import tqdm.notebook as tq
from tqdm import tqdm

# Register tqdm-backed progress variants of pandas methods (e.g. `progress_apply`).
tqdm.pandas()

from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score, precision_score, recall_score, ConfusionMatrixDisplay

import torch
from torch import nn
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, 
    TrainingArguments, Trainer, EarlyStoppingCallback, IntervalStrategy, get_linear_schedule_with_warmup
)

from defi_textmine_2025.data.utils import TARGET_COL, INTERIM_DIR, MODELS_DIR, get_cat_var_distribution, compute_class_weights

# Constants

In [None]:
# Pretrained checkpoint to fine-tune.
BASE_CHECKPOINT = "camembert/camembert-large"

# Seed the global NumPy and PyTorch RNGs once so that later random operations
# in the notebook are reproducible (the duplicate np.random.seed call is gone).
RANDOM_SEED = 0  # random reproducibility
logging.info(f"{RANDOM_SEED=}")
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


# Step-1 task; presumably a "does this entity pair hold a relation" target.
TASK_NAME = "hasrelation"
logging.info(f"{TASK_NAME=}")
STEP1_TASK_TARGET_COL = f"{TASK_NAME}_label"
logging.info(f"{STEP1_TASK_TARGET_COL=}")
TASK_INPUT_COL = "input_text"

# Cross-validation fold whose train/validation parquet files are loaded below.
FOLD_NUM = 1
logging.info(f"{FOLD_NUM=}")


# Entity types of the challenge. NOTE(review): not referenced by the visible cells.
entity_classes = {'TERRORIST_OR_CRIMINAL', 'LASTNAME', 'LENGTH', 'NATURAL_CAUSES_DEATH', 'COLOR', 'STRIKE', 'DRUG_OPERATION', 'HEIGHT', 'INTERGOVERNMENTAL_ORGANISATION', 'TRAFFICKING', 'NON_MILITARY_GOVERNMENT_ORGANISATION', 'TIME_MIN', 'DEMONSTRATION', 'TIME_EXACT', 'FIRE', 'QUANTITY_MIN', 'MATERIEL', 'GATHERING', 'PLACE', 'CRIMINAL_ARREST', 'CBRN_EVENT', 'ECONOMICAL_CRISIS', 'ACCIDENT', 'LONGITUDE', 'BOMBING', 'MATERIAL_REFERENCE', 'WIDTH', 'FIRSTNAME', 'MILITARY_ORGANISATION', 'CIVILIAN', 'QUANTITY_MAX', 'CATEGORY', 'POLITICAL_VIOLENCE', 'EPIDEMIC', 'TIME_MAX', 'TIME_FUZZY', 'NATURAL_EVENT', 'SUICIDE', 'CIVIL_WAR_OUTBREAK', 'POLLUTION', 'ILLEGAL_CIVIL_DEMONSTRATION', 'NATIONALITY', 'GROUP_OF_INDIVIDUALS', 'QUANTITY_FUZZY', 'RIOT', 'WEIGHT', 'THEFT', 'MILITARY', 'NON_GOVERNMENTAL_ORGANISATION', 'LATITUDE', 'COUP_D_ETAT', 'ELECTION', 'HOOLIGANISM_TROUBLEMAKING', 'QUANTITY_EXACT', 'AGITATING_TROUBLE_MAKING'}

# Columns read from the preprocessed parquet files.
USED_COLUMNS = ["text_index", "e1_id", "e2_id", "e1_type", "e2_type", TARGET_COL, TASK_INPUT_COL, STEP1_TASK_TARGET_COL]
logging.info(f"{USED_COLUMNS=}")

# Output directory for Trainer checkpoints, named after task, fold and base model.
model_checkpoints_dir = os.path.join(MODELS_DIR, f"mth2-{TASK_NAME}-fold{FOLD_NUM}-{BASE_CHECKPOINT.split('/')[-1]}-uncased")
logging.info(f"{model_checkpoints_dir=}")

# Run on GPU when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

# Load data

In [None]:
def load_preprossed_data(parquet_path: str) -> pd.DataFrame:
    """Read a preprocessed parquet split, keeping only the columns in ``USED_COLUMNS``.

    The function name keeps its historical spelling so existing calls still work.
    """
    wanted_columns = USED_COLUMNS
    frame = pd.read_parquet(parquet_path, columns=wanted_columns)
    return frame

In [None]:
# Training split of the current fold.
# NOTE(review): path is hard-coded relative to the working directory although
# INTERIM_DIR is imported; consider os.path.join(INTERIM_DIR, ...) once it is
# confirmed INTERIM_DIR points at the same folder.
train_df = load_preprossed_data(f"data/defi-text-mine-2025/interim/train-fold{FOLD_NUM}-mth2.parquet")
train_df.head(2)

In [None]:
# Validation split of the current fold (same hard-coded path caveat as the training split).
val_df = load_preprossed_data(f"data/defi-text-mine-2025/interim/validation-fold{FOLD_NUM}-mth2.parquet")
val_df.head(2)

In [None]:
# Distribution of the step-1 target on the training split (presumably counts per label).
get_cat_var_distribution(train_df[STEP1_TASK_TARGET_COL])

In [None]:
# Distribution of the step-1 target on the validation split.
get_cat_var_distribution(val_df[STEP1_TASK_TARGET_COL])

# Create the tokenized datasets for model input

## init the tokenizer

In [None]:
# Hyperparameters
tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
tokenizer

## init the train-valid datasets from dataframe

In [None]:
def tokenize_function(example: dict):
    """Tokenize a batch of examples' input text.

    Sequences are truncated to 300 tokens; according to the original author
    this is the largest length that does not lose any entity (see the step-0
    data-preparation notebook).
    """
    return tokenizer(
        example[TASK_INPUT_COL],
        max_length=300,
        truncation=True,
    )


def _lowercase_input_text(example: dict) -> dict:
    """Return the example's input text lower-cased (the fine-tuned model is tagged "uncased")."""
    return {TASK_INPUT_COL: example[TASK_INPUT_COL].lower()}


# Train split is shuffled with RANDOM_SEED; validation keeps its original row order.
tokenized_datasets = (
    DatasetDict(
        {
            "train": Dataset.from_pandas(train_df, preserve_index=False).shuffle(seed=RANDOM_SEED),
            "validation": Dataset.from_pandas(val_df, preserve_index=False),
        }
    )
    .map(_lowercase_input_text)
    .map(tokenize_function, batched=True)
)
tokenized_datasets

In [None]:
# Peek at the first two tokenized validation examples.
tokenized_datasets["validation"][:2]

# Init the data collator

In [None]:
# Pads each batch at collation time (tokenization above does not pad).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Fine-tuning a model with the Trainer API

## Compute the weight of classes to handle imbalance

In [None]:
# "count" column of the training label distribution, indexed by category.
get_cat_var_distribution(train_df[STEP1_TASK_TARGET_COL])["count"]

In [None]:
# Same counts after moving the category out of the index (positional index).
get_cat_var_distribution(train_df[STEP1_TASK_TARGET_COL]).reset_index(drop=False)["count"]

In [None]:
# Per-class weights for the weighted cross-entropy loss, to counter label imbalance.
# Reference recipe (TensorFlow imbalanced-data tutorial):
# https://www.tensorflow.org/tutorials/structured_data/imbalanced_data#calculate_class_weights
# weight_c = (1 / count_c) * (n_examples / n_classes); scaling by the total keeps
# the loss at a similar magnitude to the unweighted one.
n_examples = train_df.shape[0]
n_classes = train_df[STEP1_TASK_TARGET_COL].nunique()
# Previous inline implementation of that recipe, kept for reference:
# def compute_class_weights2(lbl_df: pd.DataFrame) -> pd.Series:
#     return get_cat_var_distribution(lbl_df[STEP1_TASK_TARGET_COL]).reset_index(drop=False)["count"].apply(lambda x: (1 / x) * (n_examples / n_classes)).rename("weight")
# class_weights_df = compute_class_weights2(train_df)
# NOTE(review): the project helper's exact formula and the index/shape of its
# result are not visible here — confirm it matches the recipe above.
class_weights_df = compute_class_weights(train_df, label_columns=[STEP1_TASK_TARGET_COL])
# Show per-class counts next to their weights.
pd.concat([get_cat_var_distribution(train_df[STEP1_TASK_TARGET_COL]), class_weights_df], axis=1)

In [None]:
# CrossEntropyLoss expects a flat 1-D weight vector whose i-th entry is the
# weight of class id i. `.values.tolist()` on a single-column DataFrame yields
# a nested list ([[w0], [w1], ...]) that torch rejects as a loss weight, so
# flatten explicitly. sort_index() orders the weights by label.
# NOTE(review): assumes class_weights_df is indexed by label id; confirm.
class_weights = np.asarray(class_weights_df.sort_index(), dtype=float).ravel().tolist()
if len(class_weights) != n_classes:
    raise ValueError(f"Expected {n_classes} class weights, got {len(class_weights)}: {class_weights}")
class_weights

### Init the model

In [None]:
# Number of distinct target labels, which sets the size of the classification head.
n_classes = train_df[STEP1_TASK_TARGET_COL].nunique()
print(f"{n_classes=}")
# Base encoder with a sequence-classification head of n_classes outputs
# (the head is presumably freshly initialised, since the base checkpoint is not a classifier).
model = AutoModelForSequenceClassification.from_pretrained(BASE_CHECKPOINT, num_labels=n_classes)
# Match the embedding matrix to the tokenizer's size; presumably a no-op here,
# since no tokens are added to the tokenizer in this notebook — confirm.
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Display the model architecture.
model

### Init the trainer and launch the training

Source: https://stackoverflow.com/questions/69087044/early-stopping-in-bert-trainer-instances#69087153

1. Use `load_best_model_at_end = True` (EarlyStoppingCallback() requires this to be True).
2. `eval_strategy = 'steps'` (named `evaluation_strategy` in older `transformers` releases) or `IntervalStrategy.STEPS` instead of `'epoch'`.
3. `eval_steps = 50` in the cited answer (evaluate the metrics every N steps; this notebook uses N = 3000).
4. `metric_for_best_model = 'f1'`

In [None]:
def compute_metrics(p):
    """Compute accuracy and macro-averaged precision/recall/F1 for the Trainer.

    Args:
        p: a ``transformers.EvalPrediction`` or a ``(predictions, labels)``
            tuple, where ``predictions`` are per-class scores of shape
            (n_examples, n_classes) and ``labels`` are integer class ids.

    Returns:
        dict with keys "accuracy", "precision", "recall" and "f1".
    """
    # Read attributes instead of tuple-unpacking: EvalPrediction also yields
    # extra fields (e.g. `inputs`) when enabled, which breaks `pred, labels = p`.
    if hasattr(p, "predictions"):
        logits, labels = p.predictions, p.label_ids
    else:
        logits, labels = p
    # Some models return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    pred = np.argmax(logits, axis=1)
    # zero_division=0 keeps scikit-learn's default value (0) for undefined
    # per-class scores without emitting a warning at every evaluation.
    return {
        "accuracy": accuracy_score(y_true=labels, y_pred=pred),
        "precision": precision_score(y_true=labels, y_pred=pred, average="macro", zero_division=0),
        "recall": recall_score(y_true=labels, y_pred=pred, average="macro", zero_division=0),
        "f1": f1_score(y_true=labels, y_pred=pred, average="macro", zero_division=0),
    }

# Training hyper-parameters.
MAX_EPOCHS = 50  # upper bound; the EarlyStoppingCallback on the trainer can stop sooner
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 8
LEARNING_RATE = 1e-6
WEIGHT_DECAY = 0.01

training_args = TrainingArguments(
    output_dir=model_checkpoints_dir,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=VAL_BATCH_SIZE,
    num_train_epochs=MAX_EPOCHS,
    # Evaluate and checkpoint on the same step schedule: load_best_model_at_end
    # requires matching eval/save strategies, with save_steps a multiple of eval_steps.
    eval_strategy=IntervalStrategy.STEPS,
    eval_steps=3000,  # evaluate every 3000 optimisation steps
    save_strategy=IntervalStrategy.STEPS,
    save_steps=3000,  # checkpoint every 3000 steps
    # Keep one checkpoint; with load_best_model_at_end=True the best checkpoint
    # is always retained too, so up to two may remain on disk.
    save_total_limit=1,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    # Propagate the notebook seed (TrainingArguments defaults to 42) so the
    # Trainer's own seeding follows RANDOM_SEED.
    seed=RANDOM_SEED,
    push_to_hub=False,
    # Select the best checkpoint / early-stop on validation macro-F1 from compute_metrics.
    metric_for_best_model='f1',
    greater_is_better=True,
    load_best_model_at_end=True,
    report_to="none",
)

class CustomTrainer(Trainer):
    """Trainer whose loss is a class-weighted cross-entropy.

    The weights come from the module-level ``class_weights`` list (entry i is
    the weight of class id i, as CrossEntropyLoss expects) to counter label
    imbalance.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the class-weighted cross-entropy loss for one batch.

        Args:
            model: the model being trained or evaluated.
            inputs: batch dict from the data collator; must contain "labels".
            return_outputs: if True, also return the model outputs (needed by
                the Trainer's prediction step).
            num_items_in_batch: accepted for compatibility with newer
                ``transformers`` releases, which pass it to ``compute_loss``;
                unused because the loss is averaged within each batch.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        # Pop the labels so the model does not also compute its own unweighted loss.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.get('logits')
        # Build the weights on the logits' device (not the global `device`) and
        # flatten them, since CrossEntropyLoss needs a 1-D weight vector.
        weight = torch.tensor(class_weights, dtype=torch.float32, device=logits.device).flatten()
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


# A custom optimizer/scheduler pair (e.g. torch.optim.RAdam with
# get_linear_schedule_with_warmup) can be supplied via
# `optimizers=(optimizer, scheduler)`; otherwise the Trainer builds its own.

# The Trainer expects the target in a column named "label", so the step-1
# target column is renamed in both splits.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_datasets["train"].rename_column(STEP1_TASK_TARGET_COL, "label"),
    eval_dataset=tokenized_datasets["validation"].rename_column(STEP1_TASK_TARGET_COL, "label"),
    compute_metrics=compute_metrics,
    # Stop once validation F1 has not improved for 4 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=4)],
)

## Launch training

In [None]:
# Echo the base checkpoint being fine-tuned.
BASE_CHECKPOINT

In [None]:
# Alternative: resume from the best saved checkpoint instead of starting over.
# trainer.train(resume_from_checkpoint=trainer.state.best_model_checkpoint)
# Fine-tune; with load_best_model_at_end=True the best checkpoint (by validation
# F1) is reloaded when training ends.
trainer.train()

# Evaluate

In [None]:
# trainer.train(resume_from_checkpoint=True)

In [None]:
# Best value of metric_for_best_model ("f1", validation macro-F1) seen during training.
trainer.state.best_metric

In [None]:
# Path of the checkpoint that reached trainer.state.best_metric
# (presumably None if no checkpoint was saved).
best_ckpt_path = trainer.state.best_model_checkpoint
best_ckpt_path

In [None]:
# Same attribute as best_ckpt_path, displayed directly.
trainer.state.best_model_checkpoint

In [None]:
# NOTE(review): `_load_best_model` is a private Trainer method. With
# load_best_model_at_end=True the best checkpoint should already be loaded after
# trainer.train(), so this call is presumably redundant — confirm.
trainer._load_best_model()

## Get the predictions and the true labels

In [None]:
# Predict on the training split *with* its labels. trainer.train_dataset is the
# copy whose target column was renamed to "label", so the Trainer passes labels
# through and compute_metrics fills train_pred_output.metrics. The raw tokenized
# split only has STEP1_TASK_TARGET_COL, which the Trainer drops as an unused
# column. Row order is unchanged, so predictions still line up with
# tokenized_datasets["train"].
train_pred_output = trainer.predict(trainer.train_dataset, metric_key_prefix="train")

In [None]:
# Predicted class = argmax over the raw logits (same rule as compute_metrics).
# No sigmoid: in float32 it saturates to exactly 1.0 for large logits, creating
# ties that argmax resolves to class 0.
train_y_pred = np.argmax(train_pred_output.predictions, axis=1)
train_y_pred

In [None]:
# Ground-truth training labels, in the same (shuffled) order as the predictions.
# train_pred_output.label_ids holds the same values only when predict() received
# a dataset with a "label" column:
# train_y_true = train_pred_output.label_ids
train_y_true = tokenized_datasets['train'][STEP1_TASK_TARGET_COL]
print(train_y_true)

In [None]:
# Predict on the validation split *with* its labels: trainer.eval_dataset has the
# target renamed to "label", so val_pred_output.metrics includes the
# compute_metrics scores. Row order matches tokenized_datasets["validation"].
val_pred_output = trainer.predict(trainer.eval_dataset, metric_key_prefix="validation")

In [None]:
# Predicted class = argmax over the raw logits (same rule as compute_metrics).
# No sigmoid: in float32 it saturates to exactly 1.0 for large logits, creating
# ties that argmax resolves to class 0.
val_y_pred = np.argmax(val_pred_output.predictions, axis=1)
val_y_pred

In [None]:
# Ground-truth validation labels (validation split is not shuffled).
# val_y_true = val_pred_output.label_ids
val_y_true = tokenized_datasets['validation'][STEP1_TASK_TARGET_COL]
print(val_y_true)

## Global metrics

In [None]:
# Metrics from predict() on the training split (keys prefixed "train_"); the
# compute_metrics scores appear only if the predicted dataset carried labels.
train_pred_output.metrics

In [None]:
# Metrics from predict() on the validation split (keys prefixed "validation_").
val_pred_output.metrics

## Classification report

In [None]:
# Per-class precision/recall/F1 on the training split.
print(classification_report(y_true=train_y_true, y_pred=train_y_pred))

In [None]:
# Per-class precision/recall/F1 on the validation split.
print(classification_report(y_true=val_y_true, y_pred=val_y_pred))

## Confusion matrix

In [None]:
# Validation confusion matrix of raw counts (rows: true label, columns: predicted label).
cm = confusion_matrix(y_true=val_y_true, y_pred=val_y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()

In [None]:
# Row-normalised confusion matrix: each row sums to 1, so the diagonal is per-class recall.
cm = confusion_matrix(y_true=val_y_true, y_pred=val_y_pred, normalize='true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()

## Error analysis

TODO...

In [None]:
# Validation rows restricted to USED_COLUMNS (input text is the lower-cased version).
tokenized_datasets["validation"].select_columns(USED_COLUMNS).to_pandas()

In [None]:
# True vs predicted labels next to each validation example. Rows align by
# position: the validation split was not shuffled and both frames use a default
# 0..n-1 index.
val_y_true_vs_pred_df = pd.concat([pd.DataFrame({"y_true": val_y_true, "y_pred": val_y_pred}), tokenized_datasets["validation"].select_columns(USED_COLUMNS).to_pandas()], axis=1)
# val_y_true_vs_pred_df = pd.DataFrame({"y_true": val_y_true, "y_pred": val_y_pred})
val_y_true_vs_pred_df

## false negatives

In [None]:
# False negatives: examples labelled 1 (has relation) that were predicted as a different class.
false_neg_df = val_y_true_vs_pred_df.query("y_true==1 & y_true != y_pred")
false_neg_df

In [None]:
# Missed relations by type, sorted ascending so the most frequent appear last.
false_neg_df[TARGET_COL].value_counts().sort_values()

In [None]:
# Input texts of missed HAS_FOR_LENGTH relations (TARGET_COL holds stringified lists).
false_neg_df.query(f""" {TARGET_COL}=="['HAS_FOR_LENGTH']" """)[TASK_INPUT_COL].values.tolist()

In [None]:
# Full rows of the missed HAS_FOR_LENGTH relations.
false_neg_df.query(f""" {TARGET_COL}=="['HAS_FOR_LENGTH']" """)

In [None]:
from transformers import pipeline

# Reload the best checkpoint as a standalone text-classification pipeline.
# Guard against a missing checkpoint: pipeline(model=None) would silently load
# the task's default model instead.
if best_ckpt_path is None:
    raise ValueError("No best checkpoint recorded (trainer.state.best_model_checkpoint is None).")
# Use the device chosen at the top of the notebook rather than a hard-coded
# "cuda", so this cell also runs on CPU-only machines.
classifier = pipeline("text-classification", model=best_ckpt_path, device=device)

classifier

In [None]:
# Sanity check on a hand-written example with entities marked by { } and [ ].
# Note: passed with its original casing, whereas training inputs were lower-cased.
text = 'Le { super-navire } { Thang Long }, mis en service le 6 mai dernier, a coulé avec à son bord plusieurs passagers. Le { bateau } à grande vitesse était un { engin } à simple coque et constituait le plus grand { navire } du pays. Il mesurait [ 77,46 mètres ] de long avec une capacité de 1017 passagers.'
classifier(text)

In [None]:
# Same example lower-cased, matching the preprocessing applied to the training data.
classifier(text.lower())