In [None]:
# ==============================================================================
# Task 3: Fine-Tuning the Model
# ==============================================================================
# Now we bring everything together to train the model.

# --- Data Collator ---
# Helper object that assembles individual tokenized examples into training
# batches. It is also expected to pad the sentences in a batch to a common
# length, which is why it is given the tokenizer.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# --- Evaluation Metrics ---
# Load the "seqeval" metric once at module level. compute_metrics (defined
# right after this) calls seqeval.compute(...) to obtain the precision,
# recall, and F1-score required by the assignment, plus accuracy.
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    """Score a batch of token-classification predictions with seqeval.

    Args:
        p: A ``(predictions, labels)`` pair. ``predictions`` is reduced with
            ``np.argmax(..., axis=2)``, so the class scores are taken from its
            third axis. ``labels`` holds the reference label ids, where -100
            marks positions that must be ignored.

    Returns:
        A dict with the keys ``precision``, ``recall``, ``f1`` and
        ``accuracy``, taken from seqeval's ``overall_*`` results.
    """
    raw_scores, reference_ids = p
    predicted_ids = np.argmax(raw_scores, axis=2)

    # Keep only the positions whose reference label is a real label (not the
    # ignored index -100), pairing each predicted id with its reference id.
    kept_pairs = [
        [
            (pred_id, ref_id)
            for pred_id, ref_id in zip(pred_row, ref_row)
            if ref_id != -100
        ]
        for pred_row, ref_row in zip(predicted_ids, reference_ids)
    ]

    # Translate ids to label strings, predictions first and references second.
    true_predictions = [
        [id2label[pred_id] for pred_id, _ in sentence] for sentence in kept_pairs
    ]
    true_labels = [
        [id2label[ref_id] for _, ref_id in sentence] for sentence in kept_pairs
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)

    summary = {}
    summary["precision"] = results["overall_precision"]
    summary["recall"] = results["overall_recall"]
    summary["f1"] = results["overall_f1"]
    summary["accuracy"] = results["overall_accuracy"]
    return summary

# --- Calculate Class Weights ---
# To address class imbalance, derive one weight per label id from how often
# that label occurs in the training split:
#     weight_i = total_labels / (num_classes * count_i)
# so less frequent labels get higher weights. The result is `class_weights`,
# a float tensor indexed by label id, or None when there is no training split.
if "train" in tokenized_datasets:
    from collections import Counter

    import torch

    # Tally every real label; -100 marks padding/special tokens and is skipped.
    label_counts = Counter(
        label_id
        for example in tokenized_datasets["train"]
        for label_id in example["labels"]
        if label_id != -100
    )
    total_labels = sum(label_counts.values())
    num_classes = len(id2label)

    # Make sure every label id from 0 to num_classes - 1 has an entry, even
    # if it never occurs in the training data.
    full_label_counts = {
        class_id: label_counts.get(class_id, 0) for class_id in range(num_classes)
    }

    # Floor the count at 1 instead of adding a tiny epsilon: with `count + 1e-5`
    # a label that never occurs received a weight about 1e5 times larger than
    # a label seen once, which would swamp any weighted loss as soon as that
    # label shows up (e.g. in the evaluation split). With the floor, an unseen
    # label is treated like a label seen exactly once, and division by zero
    # is still impossible.
    weights = [
        total_labels / (num_classes * max(count, 1))
        for count in full_label_counts.values()
    ]

    class_weights = torch.tensor(weights, dtype=torch.float)
    print("\nCalculated class weights:")
    print(class_weights)
else:
    class_weights = None
    print("\nSkipping class weight calculation as training data is not available.")

# --- Load the Pre-trained Model ---
# Build a configuration for the `model_checkpoint` checkpoint, passing the
# number of labels and both label mappings so the token-classification model
# knows what to predict.
config = AutoConfig.from_pretrained(
    model_checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

# Attach the class weights to the configuration when they were calculated.
# They are stored as a plain list because a tensor is not JSON-serializable.
# NOTE(review): this only stores the weights on the config. The stock Hugging
# Face token-classification heads appear to compute an unweighted
# cross-entropy loss and not to read `config.class_weights` -- confirm that
# the training code (e.g. a custom Trainer.compute_loss) consumes them,
# otherwise they have no effect on training.
has_class_weights = class_weights is not None
if has_class_weights:
    config.class_weights = class_weights.tolist()

model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    config=config,  # Pass the (possibly extended) configuration
)

# Report accurately whether class weights were attached; the message used to
# claim "with class weights" even when none had been calculated.
weights_note = "with" if has_class_weights else "without"
print(f"\nModel '{model_checkpoint}' loaded and configured for NER {weights_note} class weights.")
