In [None]:
import json
import os
import zipfile
import random
import logging
from typing import List, Dict, Any
import math

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from tqdm.auto import tqdm
import numpy as np
from collections import Counter

# --- Model checkpoint and dataset locations -------------------------------
MODEL_NAME = "microsoft/deberta-v3-large"
DEV_FILE_PATH = "/kaggle/input/bae-acl-dataset/mrbench_v3_devset.json"
TEST_FILE_PATH = "/kaggle/input/bae-acl-dataset/mrbench_v3_testset.json"

# --- Output artefacts -------------------------------------------------------
OUTPUT_DIR = "bea2025_track4_output"
PREDICTIONS_FILENAME = "predictions.json"
# The archive is named after the JSON file it wraps.
ZIP_FILENAME = f"{PREDICTIONS_FILENAME}.zip"

# --- Task definition --------------------------------------------------------
# Key read from (and written to) each tutor response's "annotation" dict.
ANNOTATION_KEY = "Actionability"
TASK_TRACK_NAME = f"Track 4 - {ANNOTATION_KEY}"

# --- Training hyper-parameters ----------------------------------------------
BATCH_SIZE = 2
LEARNING_RATE = 1.5e-5
WEIGHT_DECAY = 0.01
EPOCHS = 12
MAX_SEQ_LENGTH = 512
GRADIENT_ACCUMULATION_STEPS = 2
WARMUP_PROPORTION = 0.1
SEED = 42

# --- Label vocabulary -------------------------------------------------------
# Class ids follow the order in which the label strings are listed here.
LABEL_MAP = {
    label_name: label_id
    for label_id, label_name in enumerate(("Yes", "To some extent", "No"))
}
ID_TO_LABEL_MAP = {label_id: label_name for label_name, label_id in LABEL_MAP.items()}
NUM_LABELS = len(LABEL_MAP)

# --- Logging ----------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

def set_seed(seed_value):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    When CUDA is available, every CUDA device is seeded as well and cuDNN is
    switched to deterministic mode with benchmarking turned off.

    Args:
        seed_value: Integer seed applied to all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed_value)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once, up front, before any data shuffling or model creation below.
set_seed(SEED)

def load_data(file_path: str) -> List[Dict[str, Any]]:
    """Parse a UTF-8 JSON file and return its content.

    The outcome (number of loaded entries, or the failure) is both logged and
    printed.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON payload.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist (re-raised after
            being reported).
        json.JSONDecodeError: If the file is not valid JSON (re-raised after
            being reported).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            conversations = json.load(handle)
    except (FileNotFoundError, json.JSONDecodeError) as err:
        # Report which of the two expected failures happened, then let the
        # original exception propagate to the caller.
        if isinstance(err, FileNotFoundError):
            problem = f"Error: File not found at {file_path}"
        else:
            problem = f"Error: Could not decode JSON from {file_path}"
        logger.error(problem)
        print(problem)
        raise

    summary = f"Loaded {len(conversations)} conversations from {file_path}"
    logger.info(summary)
    print(summary)
    return conversations

def preprocess_data(raw_data: List[Dict[str, Any]], annotation_key: str, is_test_set: bool = False) -> List[Dict[str, Any]]:
    """Flatten conversations into one example per tutor response.

    Each example carries the conversation id, the tutor id and a single text
    made of the conversation history followed by the tutor response. Unless
    ``is_test_set`` is true, the example also gets an integer ``label`` looked
    up in ``LABEL_MAP`` from ``response_data["annotation"][annotation_key]``.

    Responses with an empty/missing ``response`` text are skipped, as are
    (for labelled data) responses whose annotation is missing, invalid or
    otherwise unreadable. Every skip is logged as a warning and counted.

    Args:
        raw_data: Conversations; each needs a ``conversation_id`` and may have
            ``conversation_history`` and ``tutor_responses``.
        annotation_key: Name of the annotation dimension to use as label.
        is_test_set: When true, no labels are read.

    Returns:
        The list of example dicts, in input order.
    """
    examples = []
    n_skipped = 0

    for conv in raw_data:
        conv_id = conv["conversation_id"]
        history = conv.get("conversation_history", "").strip()
        responses = conv.get("tutor_responses", {})

        for tutor_id, response_data in responses.items():
            reply = response_data.get("response", "").strip()
            if not reply:
                logger.warning(f"Missing 'response' for {conv_id}/{tutor_id}. Skipping.")
                n_skipped += 1
                continue

            record = {
                "conversation_id": conv_id,
                "tutor_id": tutor_id,
                "text": f"Conversation History:\n{history}\n\nTutor Response:\n{reply}",
            }

            if not is_test_set:
                # A non-None reason after the try block means "skip this one".
                skip_reason = None
                try:
                    has_annotation = (
                        "annotation" in response_data
                        and annotation_key in response_data["annotation"]
                    )
                    if not has_annotation:
                        raise KeyError(f"Missing annotation structure or key '{annotation_key}'")
                    raw_label = response_data["annotation"][annotation_key]
                    if raw_label not in LABEL_MAP:
                        raise ValueError(f"Invalid label '{raw_label}' found for key '{annotation_key}'.")
                    record["label"] = LABEL_MAP[raw_label]
                except KeyError as e:
                    skip_reason = f"Missing or incomplete '{annotation_key}' annotation for {conv_id}/{tutor_id}: {e}. Skipping."
                except ValueError as e:
                    skip_reason = f"Annotation error for {conv_id}/{tutor_id}: {e}. Skipping."
                except Exception as e:
                    skip_reason = f"Unexpected error processing annotation for {conv_id}/{tutor_id}: {e}. Skipping."

                if skip_reason is not None:
                    logger.warning(skip_reason)
                    n_skipped += 1
                    continue

            examples.append(record)

    summary = f"Preprocessed into {len(examples)} individual examples for key '{annotation_key}'."
    if n_skipped > 0:
        summary += f" Skipped {n_skipped} due to missing response/annotation."
    logger.info(summary)
    print(summary)
    return examples

class PedagogicalAbilityDataset(Dataset):
    """Dataset that tokenizes pre-built example texts on access.

    Each item is a dict with flattened ``input_ids`` and ``attention_mask``
    tensors, a scalar long ``labels`` tensor (omitted when ``is_test`` is
    true) and a ``metadata`` dict holding the example's ``conversation_id``
    and ``tutor_id``.
    """

    def __init__(self, data: List[Dict[str, Any]], tokenizer, max_length: int, is_test: bool = False):
        """Store the examples, the tokenizer and the tokenization settings.

        Args:
            data: Example dicts with ``text``, ``conversation_id``,
                ``tutor_id`` and, unless ``is_test``, ``label``.
            tokenizer: Callable used to encode ``text``.
            max_length: Length every sequence is padded/truncated to.
            is_test: When true, items carry no ``labels`` entry.
        """
        self.data, self.tokenizer = data, tokenizer
        self.max_length, self.is_test = max_length, is_test

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors plus metadata."""
        record = self.data[idx]

        encoded = self.tokenizer(
            record["text"],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # The tokenizer is asked for 'pt' tensors; flatten them to 1-D.
        sample = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }

        if not self.is_test:
            sample['labels'] = torch.tensor(record['label'], dtype=torch.long)

        sample['metadata'] = {
            key: record[key] for key in ('conversation_id', 'tutor_id')
        }
        return sample

def calculate_metrics(preds: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    """Compute accuracy plus strict and lenient macro-F1.

    The strict score is the macro-F1 over the three class ids 0, 1 and 2.
    For the lenient score both arrays are first reduced to two classes:
    id 2 becomes 1 and every other id becomes 0.

    Args:
        preds: Predicted class ids.
        labels: Gold class ids.

    Returns:
        Dict with keys ``accuracy``, ``f1_macro_strict`` and
        ``f1_macro_lenient``.
    """
    scores = {"accuracy": accuracy_score(labels, preds)}

    scores["f1_macro_strict"] = f1_score(
        labels, preds, average='macro', labels=[0, 1, 2], zero_division=0
    )

    # Two-class view used by the lenient score: class 2 versus the rest.
    gold_binary = np.where(labels == 2, 1, 0)
    pred_binary = np.where(preds == 2, 1, 0)
    scores["f1_macro_lenient"] = f1_score(
        gold_binary, pred_binary, average='macro', labels=[0, 1], zero_division=0
    )

    return scores

def train_epoch(model, data_loader, loss_fct, optimizer, scheduler, device, grad_accum_steps, epoch_num, total_epochs):
    """Train ``model`` for one epoch with gradient accumulation.

    For every batch the logits are scored with ``loss_fct``; the loss is
    divided by ``grad_accum_steps`` before ``backward()``. Gradients are
    clipped to norm 1.0 and the optimizer and scheduler are stepped once every
    ``grad_accum_steps`` batches, and once more on the final batch of the
    epoch.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits``.
        data_loader: Iterable of batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        loss_fct: Callable ``loss_fct(logits, labels)`` returning a scalar
            loss tensor.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped after each optimizer step.
        device: Device the batch tensors are moved to.
        grad_accum_steps: Number of batches accumulated per optimizer step.
        epoch_num: 1-based epoch index, used for display only.
        total_epochs: Total number of epochs, used for display only.

    Returns:
        The mean of the (un-scaled) per-batch losses, or 0.0 if the loader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    optimizer.zero_grad()

    print(f"\n{'='*80}")
    print(f"EPOCH {epoch_num}/{total_epochs} - TRAINING ({TASK_TRACK_NAME})")
    print(f"{'='*80}")

    progress_bar = tqdm(data_loader, desc=f"Training Epoch {epoch_num}/{total_epochs}",
                        position=0, leave=True, dynamic_ncols=True)

    for i, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Scale so that the gradients summed over one accumulation window
        # correspond to the average loss of that window.
        loss = loss_fct(logits, labels) / grad_accum_steps
        loss.backward()

        is_window_end = (i + 1) % grad_accum_steps == 0
        is_last_batch = (i + 1) == num_batches
        if is_window_end or is_last_batch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Undo the accumulation scaling for reporting.
        batch_loss = loss.item() * grad_accum_steps
        total_loss += batch_loss
        progress_bar.set_postfix({'loss': f"{batch_loss:.4f}"})

    # total_loss is a sum of per-batch losses, so it has to be averaged over
    # the number of batches (as evaluate() does), not the number of samples.
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    log_message = f"Epoch {epoch_num}/{total_epochs} - Average Training Loss: {avg_loss:.6f}"
    logger.info(log_message)
    print(f"\n{log_message}")

    return avg_loss

def evaluate(model, data_loader, device, epoch_num=None, total_epochs=None):
    """Score ``model`` on a labelled loader and report the results.

    Runs the model without gradients, takes the loss the model itself returns
    for each batch and the argmax of its logits, then logs and prints the
    metrics from ``calculate_metrics`` together with the distribution of the
    predicted classes.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output exposes ``.loss`` and ``.logits``.
        data_loader: Iterable of batches holding those three tensors.
        device: Device the batch tensors are moved to.
        epoch_num: Optional epoch index, used for display only.
        total_epochs: Optional epoch count, used for display only.

    Returns:
        The metrics dict from ``calculate_metrics`` with an extra
        ``eval_loss`` entry (mean of the per-batch losses, 0 for an empty
        loader).
    """
    model.eval()
    predicted_ids = []
    gold_ids = []
    loss_sum = 0

    # Only show an "x/y" tag when both epoch values were supplied.
    if epoch_num is not None and total_epochs is not None:
        epoch_tag = f"{epoch_num}/{total_epochs}"
    else:
        epoch_tag = ""

    rule = '-' * 80
    print(f"\n{rule}")
    print(f"EPOCH {epoch_tag} - VALIDATION" if epoch_tag else "VALIDATION")
    print(rule)

    batches = tqdm(data_loader, desc="Validating", position=0, leave=True, dynamic_ncols=True)

    with torch.no_grad():
        for batch in batches:
            tensors = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'labels')
            }
            outputs = model(**tensors)

            step_loss = outputs.loss.item()
            loss_sum += step_loss
            batches.set_postfix({'eval_loss': f"{step_loss:.4f}"})

            predicted_ids.extend(torch.argmax(outputs.logits, dim=-1).cpu().numpy())
            gold_ids.extend(tensors['labels'].cpu().numpy())

    n_batches = len(data_loader)
    mean_loss = loss_sum / n_batches if n_batches > 0 else 0
    pred_array = np.array(predicted_ids)
    metrics = calculate_metrics(pred_array, np.array(gold_ids))
    metrics["eval_loss"] = mean_loss

    report_lines = [
        (f"Epoch {epoch_tag} - " if epoch_tag else "") + "Validation Results:",
        f" - Eval Loss (Unweighted): {metrics['eval_loss']:.4f}",
        f" - Accuracy: {metrics['accuracy']:.4f}",
        f" - F1 Macro (Strict): {metrics['f1_macro_strict']:.4f} (Primary Metric)",
        f" - F1 Macro (Lenient): {metrics['f1_macro_lenient']:.4f}",
    ]
    report = "\n".join(report_lines)
    logger.info(report)
    print(report)

    # How often each class id was predicted.
    histogram = np.bincount(pred_array, minlength=NUM_LABELS)
    print("\nValidation Prediction distribution:")
    n_preds = len(predicted_ids)
    if n_preds == 0:
        print("  No predictions made.")
    else:
        for label_id, count in enumerate(histogram):
            print(f"  {ID_TO_LABEL_MAP.get(label_id, 'Unknown')}: {count} ({count/n_preds*100:.1f}%)")

    return metrics

if __name__ == "__main__":
    # End-to-end pipeline, executed in this order:
    #   1. load the dev set and split it into train / validation
    #   2. derive class weights and build tokenizer, dataloaders, model,
    #      weighted loss, optimizer and LR schedule
    #   3. train for EPOCHS epochs, saving the checkpoint with the best
    #      strict macro-F1 on the validation split
    #   4. reload that checkpoint and predict the test set
    #   5. write the predictions in the submission layout, verify that every
    #      non-empty tutor response got a prediction, and zip the result
    # Unrecoverable problems stop the run with `raise SystemExit(1)` so that
    # execution halts immediately with a non-zero exit status.
    print("\n" + "="*80)
    print(f" BEA 2025 SHARED TASK {TASK_TRACK_NAME} SOLUTION ".center(80, "="))
    print("="*80 + "\n")

    logger.info(f"Starting BEA 2025 Shared Task {TASK_TRACK_NAME} Solution")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_message = f"Using device: {device}"
    logger.info(log_message)
    print(log_message)

    # ------------------------------------------------------------------
    # 1. Data preparation
    # ------------------------------------------------------------------
    print("\n" + "-"*80)
    print(" DATA PREPARATION ".center(80, "-"))
    print("-"*80)

    logger.info("Loading and preprocessing data...")
    print("Loading and preprocessing data...")

    raw_dev_data = load_data(DEV_FILE_PATH)
    processed_dev_data = preprocess_data(raw_dev_data, annotation_key=ANNOTATION_KEY, is_test_set=False)

    if not processed_dev_data:
        error_message = "No valid development examples found after preprocessing. Check data and ANNOTATION_KEY. Exiting."
        logger.error(error_message)
        print(error_message)
        raise SystemExit(1)

    # Hold out 10% for validation, stratified by label; fall back to a plain
    # random split if scikit-learn rejects the stratified one.
    try:
        train_labels = [d['label'] for d in processed_dev_data]
        train_data, val_data = train_test_split(
            processed_dev_data,
            test_size=0.1,
            random_state=SEED,
            stratify=train_labels
        )
        split_message = f"Split dev data: {len(train_data)} train, {len(val_data)} validation examples."
        logger.info(split_message)
        print(split_message)
    except ValueError as e:
        warning_message = f"Could not stratify split (maybe too few samples per class?): {e}. Using random split."
        logger.warning(warning_message)
        print(warning_message)
        train_data, val_data = train_test_split(
            processed_dev_data, test_size=0.1, random_state=SEED
        )
        split_message = f"Split dev data (random): {len(train_data)} train, {len(val_data)} validation examples."
        logger.info(split_message)
        print(split_message)

    # Inverse-frequency class weights computed on the training split only:
    # weight = total_samples / (NUM_LABELS * class_count). A class that does
    # not occur in the training split gets weight 0.
    print("\nCalculating class weights for weighted loss...")
    train_labels = [d['label'] for d in train_data]
    class_counts = Counter(train_labels)
    total_samples = len(train_labels)
    class_weights = []
    print("Training set label distribution:")
    for i in range(NUM_LABELS):
        count = class_counts.get(i, 0)
        label_name = ID_TO_LABEL_MAP.get(i, f"Class_{i}")
        percentage = (count / total_samples * 100) if total_samples > 0 else 0
        print(f"  {label_name}: {count} ({percentage:.1f}%)")
        if count == 0:
            weight = 0
            print(f"  WARN: Class '{label_name}' not present in training split. Assigning weight 0.")
        else:
            weight = total_samples / (NUM_LABELS * count)
        class_weights.append(weight)

    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
    print(f"\nCalculated class weights: {class_weights_tensor.cpu().numpy().round(3)}")
    logger.info(f"Calculated class weights: {class_weights_tensor.cpu().numpy().round(3).tolist()}")

    # ------------------------------------------------------------------
    # 2. Tokenizer, dataloaders, model, loss, optimizer, scheduler
    # ------------------------------------------------------------------
    print("\n" + "-"*80)
    print(" MODEL INITIALIZATION ".center(80, "-"))
    print("-"*80)

    tokenizer_message = f"Loading tokenizer: {MODEL_NAME}"
    logger.info(tokenizer_message)
    print(tokenizer_message)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    dataloader_message = "Creating datasets and dataloaders..."
    logger.info(dataloader_message)
    print(dataloader_message)

    train_dataset = PedagogicalAbilityDataset(train_data, tokenizer, MAX_SEQ_LENGTH)
    val_dataset = PedagogicalAbilityDataset(val_data, tokenizer, MAX_SEQ_LENGTH)

    # Evaluation uses twice the training batch size.
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=2, pin_memory=True)

    model_message = f"Loading pre-trained model: {MODEL_NAME} for {NUM_LABELS} labels."
    logger.info(model_message)
    print(model_message)

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=NUM_LABELS,
        id2label=ID_TO_LABEL_MAP,
        label2id=LABEL_MAP,
        ignore_mismatched_sizes=True
    )
    model.to(device)
    print(f"Model loaded and moved to {device}.")

    # Training loss is class-weighted; evaluate() reports the model's own
    # (unweighted) loss instead.
    loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)
    print("Using weighted CrossEntropyLoss for training.")
    logger.info("Using weighted CrossEntropyLoss for training.")

    # Parameters whose name contains "bias" or "LayerNorm.weight" are
    # excluded from weight decay; everything else uses WEIGHT_DECAY.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)

    # train_epoch() steps the scheduler once per optimizer update (every
    # GRADIENT_ACCUMULATION_STEPS batches plus the final batch), hence ceil.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / GRADIENT_ACCUMULATION_STEPS)
    num_training_steps = num_update_steps_per_epoch * EPOCHS
    num_warmup_steps = int(num_training_steps * WARMUP_PROPORTION)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    optimizer_message = f"Optimizer & Scheduler: Effective BS={BATCH_SIZE*GRADIENT_ACCUMULATION_STEPS}, Total Steps={num_training_steps}, Warmup Steps={num_warmup_steps}"
    logger.info(optimizer_message)
    print(optimizer_message)

    # ------------------------------------------------------------------
    # 3. Training with best-checkpoint selection on strict macro-F1
    # ------------------------------------------------------------------
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    best_f1_strict = -1.0
    best_epoch = -1  # stays -1 if no epoch ever improved on the initial value
    best_model_path = os.path.join(OUTPUT_DIR, "best_model")

    print("\n" + "="*80)
    print(" TRAINING ".center(80, "="))
    print("="*80)

    logger.info(f"Starting training for {TASK_TRACK_NAME}...")
    print(f"Starting training for {TASK_TRACK_NAME}...")

    for epoch in range(EPOCHS):
        train_epoch(model, train_dataloader, loss_fct, optimizer, scheduler, device, GRADIENT_ACCUMULATION_STEPS, epoch + 1, EPOCHS)
        eval_metrics = evaluate(model, val_dataloader, device, epoch + 1, EPOCHS)

        current_f1_strict = eval_metrics["f1_macro_strict"]
        if current_f1_strict > best_f1_strict:
            best_f1_strict = current_f1_strict
            best_epoch = epoch + 1
            save_message = f"*** New best model! F1 Strict: {best_f1_strict:.4f} (Epoch {best_epoch}). Saving to {best_model_path} ***"
            logger.info(save_message)
            print(f"\n{save_message}\n")
            # Model and tokenizer are saved together so the checkpoint
            # directory can be reloaded on its own for prediction.
            model.save_pretrained(best_model_path)
            tokenizer.save_pretrained(best_model_path)
        else:
            no_improve_message = f"F1 Strict ({current_f1_strict:.4f}) did not improve from best ({best_f1_strict:.4f} from epoch {best_epoch})."
            logger.info(no_improve_message)

        print("\n" + "-"*80 + "\n")

    training_finished_message = f"Training finished. Best model from Epoch {best_epoch} with Strict F1: {best_f1_strict:.4f}"
    logger.info(training_finished_message)
    print(training_finished_message)

    # ------------------------------------------------------------------
    # 4. Test-set prediction with the best checkpoint
    # ------------------------------------------------------------------
    print("\n" + "="*80)
    print(" TEST SET PREDICTION ".center(80, "="))
    print("="*80 + "\n")

    logger.info("Predicting on Test Set using the best model")
    print("Predicting on Test Set using the best model")

    if best_epoch == -1:
        # Nothing was saved in this run; a checkpoint directory left on disk
        # is used as a fallback, otherwise prediction is impossible.
        if os.path.exists(best_model_path):
            warn_message = "WARN: Training finished without improvement, but found existing 'best_model'. Loading it for prediction."
            logger.warning(warn_message)
            print(f"\n{warn_message}\n")
        else:
            error_message = "ERROR: No best model was saved during training and no pre-existing 'best_model' found. Cannot proceed with prediction."
            logger.error(error_message)
            print(f"\n{error_message}\n")
            raise SystemExit(1)
    else:
        load_message = f"Loading best model from {best_model_path} (Epoch {best_epoch})"
        logger.info(load_message)
        print(load_message)

    try:
        model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
        tokenizer = AutoTokenizer.from_pretrained(best_model_path)
        model.to(device)
        model.eval()
        print(f"Best model loaded and moved to {device}.")
    except Exception as e:
        error_message = f"ERROR: Failed to load the best model from {best_model_path}: {e}"
        logger.error(error_message)
        print(error_message)
        raise SystemExit(1)

    raw_test_data = load_data(TEST_FILE_PATH)
    processed_test_data = preprocess_data(raw_test_data, annotation_key=ANNOTATION_KEY, is_test_set=True)

    if not processed_test_data:
        error_message = "No test examples found after preprocessing. Cannot generate predictions."
        logger.error(error_message)
        print(error_message)
        raise SystemExit(1)

    test_dataset = PedagogicalAbilityDataset(processed_test_data, tokenizer, MAX_SEQ_LENGTH, is_test=True)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=2, pin_memory=True)

    all_predictions = []
    inference_message = "Running inference on the test set..."
    logger.info(inference_message)
    print(inference_message)

    progress_bar_test = tqdm(test_dataloader, desc="Predicting", position=0, leave=True, dynamic_ncols=True)
    with torch.no_grad():
        for batch in progress_bar_test:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

            # No custom collate_fn is used, so the per-item metadata dicts
            # arrive as one dict of parallel sequences; pair each entry with
            # its prediction.
            batch_meta = batch['metadata']
            for conv_id, tutor_id, pred_label_id in zip(
                batch_meta['conversation_id'], batch_meta['tutor_id'], preds
            ):
                all_predictions.append({
                    "conversation_id": conv_id,
                    "tutor_id": tutor_id,
                    "predicted_label": ID_TO_LABEL_MAP[int(pred_label_id)]
                })

    predictions_message = f"Generated {len(all_predictions)} predictions for the test set."
    logger.info(predictions_message)
    print(f"\n{predictions_message}")

    if all_predictions:
        pred_labels = [LABEL_MAP[p["predicted_label"]] for p in all_predictions]
        label_counts = np.bincount(np.array(pred_labels), minlength=NUM_LABELS)
        print("\nTest prediction distribution:")
        total_preds_count = len(pred_labels)
        for label_id, count in enumerate(label_counts):
            percentage = (count / total_preds_count * 100)
            print(f"  {ID_TO_LABEL_MAP.get(label_id, 'Unknown')}: {count} ({percentage:.1f}%)")
    else:
        print("No predictions were generated.")

    # ------------------------------------------------------------------
    # 5a. Re-shape predictions into the submission structure
    # ------------------------------------------------------------------
    print("\n" + "-"*80)
    print(" FORMATTING PREDICTIONS ".center(80, "-"))
    print("-"*80 + "\n")

    format_message = f"Formatting predictions into submission structure for {TASK_TRACK_NAME}..."
    logger.info(format_message)
    print(format_message)

    # The output mirrors the input layout: one entry per conversation with
    # its history and, per tutor, the original response text plus the
    # predicted label stored under ANNOTATION_KEY.
    submission_data_dict = {}
    test_conv_map = {conv["conversation_id"]: conv for conv in raw_test_data}

    for pred in all_predictions:
        conv_id = pred["conversation_id"]
        tutor_id = pred["tutor_id"]
        predicted_label = pred["predicted_label"]

        if conv_id not in submission_data_dict:
            original_conv = test_conv_map.get(conv_id)
            if not original_conv:
                warning_message = f"Original conversation {conv_id} not found in raw test data map. Skipping prediction for {tutor_id}."
                logger.warning(warning_message)
                print(f"WARN: {warning_message}")
                continue
            submission_data_dict[conv_id] = {
                "conversation_id": conv_id,
                "conversation_history": original_conv.get("conversation_history", "HISTORY_NOT_FOUND"),
                "tutor_responses": {}
            }

        original_tutor_responses = test_conv_map.get(conv_id, {}).get("tutor_responses", {})
        original_response_data = original_tutor_responses.get(tutor_id, {})
        original_response_text = original_response_data.get("response", "RESPONSE_NOT_FOUND")

        submission_data_dict[conv_id]["tutor_responses"][tutor_id] = {
            "response": original_response_text,
            "annotation": {
                ANNOTATION_KEY: predicted_label
            }
        }

    final_submission_list = list(submission_data_dict.values())
    format_done_message = f"Formatted {len(final_submission_list)} conversations for submission."
    logger.info(format_done_message)
    print(format_done_message)

    # ------------------------------------------------------------------
    # 5b. Save, check completeness, zip
    # ------------------------------------------------------------------
    print("\n" + "-"*80)
    print(" SAVING RESULTS & CHECKING COMPLETENESS ".center(80, "-"))
    print("-"*80 + "\n")

    output_json_path = os.path.join(OUTPUT_DIR, PREDICTIONS_FILENAME)
    output_zip_path = os.path.join(OUTPUT_DIR, ZIP_FILENAME)

    save_message = f"Saving formatted predictions to {output_json_path}"
    logger.info(save_message)
    print(save_message)

    try:
        with open(output_json_path, 'w', encoding='utf-8') as f:
            json.dump(final_submission_list, f, indent=2, ensure_ascii=False)
        print(f"Successfully saved {output_json_path}")
    except Exception as e:
        error_message = f"ERROR: Failed to save predictions JSON: {e}"
        logger.error(error_message)
        print(error_message)
        raise SystemExit(1)

    print("\n--- Checking Prediction Completeness ---")
    logger.info("Checking if all expected tutor responses have predictions...")

    # (conversation_id, tutor_id) pairs that carry a prediction in the
    # formatted submission.
    predicted_pairs = set()
    for item in final_submission_list:
        conv_id = item['conversation_id']
        tutor_responses = item.get('tutor_responses')
        if isinstance(tutor_responses, dict):
            for tutor_id in tutor_responses:
                if ANNOTATION_KEY in tutor_responses[tutor_id].get("annotation", {}):
                    predicted_pairs.add((conv_id, tutor_id))
                else:
                    warning_message = f"Annotation key '{ANNOTATION_KEY}' missing for {conv_id}/{tutor_id} in the final formatted list. This prediction might be considered missing."
                    logger.warning(warning_message)
                    print(f"WARN: {warning_message}")

    # Pairs expected from the raw test file: every tutor response with
    # non-empty text, i.e. the same filter preprocess_data() applies.
    missing_predictions = []
    total_expected_responses = 0
    expected_pairs = set()

    for conv in raw_test_data:
        conv_id = conv['conversation_id']
        tutor_responses_original = conv.get('tutor_responses')
        if isinstance(tutor_responses_original, dict):
            for tutor_id, resp_data in tutor_responses_original.items():
                if resp_data.get("response", "").strip():
                    expected_pair = (conv_id, tutor_id)
                    if expected_pair not in expected_pairs:
                        expected_pairs.add(expected_pair)
                        total_expected_responses += 1
                        if expected_pair not in predicted_pairs:
                            missing_predictions.append(expected_pair)

    # The zip is only produced for a complete submission. Remember why it
    # was not produced so the final summary reports the actual cause.
    final_zip_path = None
    zip_failure_reason = None

    if not missing_predictions:
        completeness_message = f"SUCCESS: All {total_expected_responses} expected tutor responses have predictions in {output_json_path}."
        logger.info(completeness_message)
        print(completeness_message)

        print("\n--- Zipping Predictions ---")
        zip_message = f"Zipping predictions to {output_zip_path}"
        logger.info(zip_message)
        print(zip_message)
        try:
            with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
                zf.write(output_json_path, arcname=PREDICTIONS_FILENAME)
            zip_success_message = f"Successfully created zip file: {output_zip_path}"
            logger.info(zip_success_message)
            print(zip_success_message)
            final_zip_path = output_zip_path
        except FileNotFoundError:
            error_message = f"ERROR: Failed to create zip file. Source file not found: {output_json_path}"
            logger.error(error_message)
            print(error_message)
            zip_failure_reason = "a missing source file for the archive"
        except Exception as e:
            error_message = f"ERROR: Failed creating zip file: {e}"
            logger.error(error_message)
            print(error_message)
            zip_failure_reason = "an error while writing the archive"

    else:
        error_message_header = f"ERROR: Found {len(missing_predictions)} missing predictions out of {total_expected_responses} expected responses."
        logger.error(error_message_header)
        print(f"\n{error_message_header}")
        print("Examples of missing predictions (ConvID, TutorID):")
        # List at most ten examples, then summarise the remainder.
        for i, (mcid, mtid) in enumerate(missing_predictions):
            if i >= 10:
                print(f"  ... and {len(missing_predictions) - 10} more.")
                break
            print(f"  - Missing: ('{mcid}', '{mtid}')")

        zip_skip_message = f"Submission file '{output_json_path}' was generated, but ZIP file '{output_zip_path}' will NOT be created due to missing predictions. Please check the preprocessing steps and prediction generation."
        logger.error(zip_skip_message)
        print(f"\n{zip_skip_message}\n")
        zip_failure_reason = "missing predictions"

    print("\n" + "="*80)
    print(f" {TASK_TRACK_NAME} TASK COMPLETED ".center(80, "="))
    print("="*80)

    task_completed_message = f"{TASK_TRACK_NAME} Task Completed"
    logger.info(task_completed_message)

    completion_details = [
        f"Task: {TASK_TRACK_NAME}",
        f"Best model saved in: {best_model_path}",
        f"Predictions saved to: {output_json_path}",
    ]
    if final_zip_path:
        completion_details.append(f"Submission zip file saved to: {final_zip_path}")
    else:
        completion_details.append(f"Submission zip file was NOT created due to {zip_failure_reason} (check logs).")

    for detail in completion_details:
        logger.info(detail)
        print(detail)



Using device: cuda

--------------------------------------------------------------------------------
------------------------------- DATA PREPARATION -------------------------------
--------------------------------------------------------------------------------
Loading and preprocessing data...
Loaded 300 conversations from /kaggle/input/bae-acl-dataset/mrbench_v3_devset.json
Preprocessed into 2476 individual examples for key 'Actionability'.
Split dev data: 2228 train, 248 validation examples.

Calculating class weights for weighted loss...
Training set label distribution:
  Yes: 1179 (52.9%)
  To some extent: 332 (14.9%)
  No: 717 (32.2%)

Calculated class weights: [0.63  2.237 1.036]

--------------------------------------------------------------------------------
----------------------------- MODEL INITIALIZATION -----------------------------
--------------------------------------------------------------------------------
Loading tokenizer: microsoft/deberta-v3-large


tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/580 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]



Creating datasets and dataloaders...
Loading pre-trained model: microsoft/deberta-v3-large for 3 labels.


2025-04-23 09:27:32.646519: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745400452.833364      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745400452.887092      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


pytorch_model.bin:   0%|          | 0.00/874M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/874M [00:00<?, ?B/s]

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded and moved to cuda.
Using weighted CrossEntropyLoss for training.
Optimizer & Scheduler: Effective BS=4, Total Steps=6684, Warmup Steps=668

Starting training for Track 4 - Actionability...

EPOCH 1/12 - TRAINING (Track 4 - Actionability)


Training Epoch 1/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 1/12 - Average Training Loss: 0.521028

--------------------------------------------------------------------------------
EPOCH 1/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 1/12 - Validation Results:
 - Eval Loss (Unweighted): 0.7995
 - Accuracy: 0.6653
 - F1 Macro (Strict): 0.4757 (Primary Metric)
 - F1 Macro (Lenient): 0.7468

Validation Prediction distribution:
  Yes: 141 (56.9%)
  To some extent: 0 (0.0%)
  No: 107 (43.1%)

*** New best model! F1 Strict: 0.4757 (Epoch 1). Saving to bea2025_track4_output/best_model ***


--------------------------------------------------------------------------------


EPOCH 2/12 - TRAINING (Track 4 - Actionability)


Training Epoch 2/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 2/12 - Average Training Loss: 0.447283

--------------------------------------------------------------------------------
EPOCH 2/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 2/12 - Validation Results:
 - Eval Loss (Unweighted): 0.7549
 - Accuracy: 0.7097
 - F1 Macro (Strict): 0.5032 (Primary Metric)
 - F1 Macro (Lenient): 0.8068

Validation Prediction distribution:
  Yes: 187 (75.4%)
  To some extent: 0 (0.0%)
  No: 61 (24.6%)

*** New best model! F1 Strict: 0.5032 (Epoch 2). Saving to bea2025_track4_output/best_model ***


--------------------------------------------------------------------------------


EPOCH 3/12 - TRAINING (Track 4 - Actionability)


Training Epoch 3/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 3/12 - Average Training Loss: 0.406540

--------------------------------------------------------------------------------
EPOCH 3/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 3/12 - Validation Results:
 - Eval Loss (Unweighted): 0.7193
 - Accuracy: 0.7379
 - F1 Macro (Strict): 0.6470 (Primary Metric)
 - F1 Macro (Lenient): 0.8544

Validation Prediction distribution:
  Yes: 157 (63.3%)
  To some extent: 25 (10.1%)
  No: 66 (26.6%)

*** New best model! F1 Strict: 0.6470 (Epoch 3). Saving to bea2025_track4_output/best_model ***


--------------------------------------------------------------------------------


EPOCH 4/12 - TRAINING (Track 4 - Actionability)


Training Epoch 4/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 4/12 - Average Training Loss: 0.381683

--------------------------------------------------------------------------------
EPOCH 4/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 4/12 - Validation Results:
 - Eval Loss (Unweighted): 0.8977
 - Accuracy: 0.7097
 - F1 Macro (Strict): 0.6096 (Primary Metric)
 - F1 Macro (Lenient): 0.8100

Validation Prediction distribution:
  Yes: 155 (62.5%)
  To some extent: 28 (11.3%)
  No: 65 (26.2%)

--------------------------------------------------------------------------------


EPOCH 5/12 - TRAINING (Track 4 - Actionability)


Training Epoch 5/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 5/12 - Average Training Loss: 0.318732

--------------------------------------------------------------------------------
EPOCH 5/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 5/12 - Validation Results:
 - Eval Loss (Unweighted): 1.0556
 - Accuracy: 0.7339
 - F1 Macro (Strict): 0.6698 (Primary Metric)
 - F1 Macro (Lenient): 0.8598

Validation Prediction distribution:
  Yes: 135 (54.4%)
  To some extent: 46 (18.5%)
  No: 67 (27.0%)

*** New best model! F1 Strict: 0.6698 (Epoch 5). Saving to bea2025_track4_output/best_model ***


--------------------------------------------------------------------------------


EPOCH 6/12 - TRAINING (Track 4 - Actionability)


Training Epoch 6/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 6/12 - Average Training Loss: 0.267351

--------------------------------------------------------------------------------
EPOCH 6/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 6/12 - Validation Results:
 - Eval Loss (Unweighted): 1.3414
 - Accuracy: 0.7218
 - F1 Macro (Strict): 0.6606 (Primary Metric)
 - F1 Macro (Lenient): 0.8391

Validation Prediction distribution:
  Yes: 123 (49.6%)
  To some extent: 44 (17.7%)
  No: 81 (32.7%)

--------------------------------------------------------------------------------


EPOCH 7/12 - TRAINING (Track 4 - Actionability)


Training Epoch 7/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 7/12 - Average Training Loss: 0.192471

--------------------------------------------------------------------------------
EPOCH 7/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 7/12 - Validation Results:
 - Eval Loss (Unweighted): 1.5379
 - Accuracy: 0.7218
 - F1 Macro (Strict): 0.6744 (Primary Metric)
 - F1 Macro (Lenient): 0.8724

Validation Prediction distribution:
  Yes: 120 (48.4%)
  To some extent: 55 (22.2%)
  No: 73 (29.4%)

*** New best model! F1 Strict: 0.6744 (Epoch 7). Saving to bea2025_track4_output/best_model ***


--------------------------------------------------------------------------------


EPOCH 8/12 - TRAINING (Track 4 - Actionability)


Training Epoch 8/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 8/12 - Average Training Loss: 0.132961

--------------------------------------------------------------------------------
EPOCH 8/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 8/12 - Validation Results:
 - Eval Loss (Unweighted): 1.8532
 - Accuracy: 0.7339
 - F1 Macro (Strict): 0.6517 (Primary Metric)
 - F1 Macro (Lenient): 0.8405

Validation Prediction distribution:
  Yes: 150 (60.5%)
  To some extent: 31 (12.5%)
  No: 67 (27.0%)

--------------------------------------------------------------------------------


EPOCH 9/12 - TRAINING (Track 4 - Actionability)


Training Epoch 9/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 9/12 - Average Training Loss: 0.082539

--------------------------------------------------------------------------------
EPOCH 9/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 9/12 - Validation Results:
 - Eval Loss (Unweighted): 1.8417
 - Accuracy: 0.7218
 - F1 Macro (Strict): 0.6555 (Primary Metric)
 - F1 Macro (Lenient): 0.8392

Validation Prediction distribution:
  Yes: 142 (57.3%)
  To some extent: 41 (16.5%)
  No: 65 (26.2%)

--------------------------------------------------------------------------------


EPOCH 10/12 - TRAINING (Track 4 - Actionability)


Training Epoch 10/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 10/12 - Average Training Loss: 0.052474

--------------------------------------------------------------------------------
EPOCH 10/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 10/12 - Validation Results:
 - Eval Loss (Unweighted): 1.8676
 - Accuracy: 0.7500
 - F1 Macro (Strict): 0.6960 (Primary Metric)
 - F1 Macro (Lenient): 0.8609

Validation Prediction distribution:
  Yes: 138 (55.6%)
  To some extent: 41 (16.5%)
  No: 69 (27.8%)

*** New best model! F1 Strict: 0.6960 (Epoch 10). Saving to bea2025_track4_output/best_model ***


--------------------------------------------------------------------------------


EPOCH 11/12 - TRAINING (Track 4 - Actionability)


Training Epoch 11/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 11/12 - Average Training Loss: 0.033087

--------------------------------------------------------------------------------
EPOCH 11/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 11/12 - Validation Results:
 - Eval Loss (Unweighted): 1.9424
 - Accuracy: 0.7500
 - F1 Macro (Strict): 0.6920 (Primary Metric)
 - F1 Macro (Lenient): 0.8482

Validation Prediction distribution:
  Yes: 139 (56.0%)
  To some extent: 37 (14.9%)
  No: 72 (29.0%)

--------------------------------------------------------------------------------


EPOCH 12/12 - TRAINING (Track 4 - Actionability)


Training Epoch 12/12:   0%|          | 0/1114 [00:00<?, ?it/s]


Epoch 12/12 - Average Training Loss: 0.023198

--------------------------------------------------------------------------------
EPOCH 12/12 - VALIDATION
--------------------------------------------------------------------------------


Validating:   0%|          | 0/62 [00:00<?, ?it/s]

Epoch 12/12 - Validation Results:
 - Eval Loss (Unweighted): 2.0042
 - Accuracy: 0.7419
 - F1 Macro (Strict): 0.6805 (Primary Metric)
 - F1 Macro (Lenient): 0.8482

Validation Prediction distribution:
  Yes: 139 (56.0%)
  To some extent: 37 (14.9%)
  No: 72 (29.0%)

--------------------------------------------------------------------------------

Training finished. Best model from Epoch 10 with Strict F1: 0.6960


Predicting on Test Set using the best model
Loading best model from bea2025_track4_output/best_model (Epoch 10)
Best model loaded and moved to cuda.
Loaded 191 conversations from /kaggle/input/bae-acl-dataset/mrbench_v3_testset.json
Preprocessed into 1547 individual examples for key 'Actionability'.
Running inference on the test set...


Predicting:   0%|          | 0/387 [00:00<?, ?it/s]


Generated 1547 predictions for the test set.

Test prediction distribution:
  Yes: 905 (58.5%)
  To some extent: 226 (14.6%)
  No: 416 (26.9%)

--------------------------------------------------------------------------------
---------------------------- FORMATTING PREDICTIONS ----------------------------
--------------------------------------------------------------------------------

Formatting predictions into submission structure for Track 4 - Actionability...
Formatted 191 conversations for submission.

--------------------------------------------------------------------------------
-------------------- SAVING RESULTS & CHECKING COMPLETENESS --------------------
--------------------------------------------------------------------------------

Saving formatted predictions to bea2025_track4_output/predictions.json
Successfully saved bea2025_track4_output/predictions.json

--- Checking Prediction Completeness ---
SUCCESS: All 1547 expected tutor responses have predictions in bea2025_