# The AI Telco Troubleshooting Challenge

Reducing the operational cost of network faults - whether caused by hardware failures or software misconfigurations - is a critical priority for modern telecom service providers (telcos).

Telelogs, the automatically generated fault and event logs produced by network equipment, offer a rich source of information. Recent research has demonstrated that these logs can be used to fine-tune specialised LLMs capable of performing root-cause analysis and assisting network engineers. However, building models that generalise across unseen faults, new data distributions, and entirely new network environments, while still running efficiently on constrained edge servers, remains a significant challenge.


The evaluation metric for this challenge is Pass@1.

This metric measures the ability of the model to produce a correct answer in a single attempt. It is computed by evaluating each of the 4 generated responses individually and averaging the correctness over all samples.
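
The averaging described above can be sketched as follows. This is only an illustration: the helper name `pass_at_1`, the toy question IDs, and the labels are hypothetical, and the official scorer may differ in details such as how malformed answers are handled.

```python
from collections import defaultdict


def pass_at_1(predictions, truth, samples_per_question=4):
    """Average per-response correctness, first within each question, then across questions.

    predictions: list of (question_id, predicted_label) pairs, one per generated response.
    truth: dict mapping question_id to the correct label.
    """
    hits_by_question = defaultdict(list)
    for question_id, predicted in predictions:
        hits_by_question[question_id].append(1.0 if predicted == truth[question_id] else 0.0)

    per_question_scores = []
    for question_id, hits in hits_by_question.items():
        if len(hits) != samples_per_question:
            raise ValueError(f"{question_id}: expected {samples_per_question} responses, got {len(hits)}")
        per_question_scores.append(sum(hits) / len(hits))

    return sum(per_question_scores) / len(per_question_scores)


# Toy data (hypothetical): q1 has 3 of 4 responses correct, q2 has 2 of 4.
preds = [
    ("q1", "C1"), ("q1", "C1"), ("q1", "C3"), ("q1", "C1"),
    ("q2", "C5"), ("q2", "C2"), ("q2", "C2"), ("q2", "C7"),
]
truth = {"q1": "C1", "q2": "C2"}
score = pass_at_1(preds, truth)
print(score)  # 0.625
```

Because every response counts equally, a model that is right only some of the time on a question earns partial credit for it, so sampling diversity that flips correct answers to wrong ones lowers the score.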

The models will be evaluated on their capability to troubleshoot network problems together with knowledge retention. Knowledge retention is the ability to maintain general knowledge accuracy after fine-tuning. The private dataset will include network faults whose data has a different structure than telelogs, and general knowledge questions.

In [None]:
# Install necessary libraries
!pip install pandas numpy matplotlib seaborn
!pip install transformers peft bitsandbytes trl accelerate datasets
!pip install -qU flash-attn --no-build-isolation
!pip install -qU "vllm>=0.6.0"



In [10]:
# Imports and data loading
import re
import pandas as pd
from datasets import Dataset
from io import StringIO

# Load datasets
train_df = pd.read_csv("train.csv")
phase1_truth = pd.read_csv('phase_1_test_truth.csv')
phase1_test = pd.read_csv("phase_1_test.csv")
phase2_test = pd.read_csv("phase_2_test.csv")

In [11]:
# Prepare train dataset

# Remove the underscore followed by a digit at the end of each ID
phase1_truth['ID'] = phase1_truth['ID'].str.replace(r'_\d$', '', regex=True)
phase1_truth = phase1_truth[['ID', 'Qwen2.5-1.5B-Instruct']]
phase1_truth = phase1_truth.rename(columns={'Qwen2.5-1.5B-Instruct': 'answer'})
phase1_truth = phase1_truth.drop_duplicates(subset=['ID'])
questions_df = phase1_test[['ID', 'question']].copy()
phase1_truth = phase1_truth.merge(questions_df, on='ID', how='left')
phase1_truth = phase1_truth[['ID', 'question', 'answer']]
train_df = pd.concat([train_df, phase1_truth], ignore_index=True)
train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)

In [12]:
# Review the question feature sample
with pd.option_context('display.max_colwidth', None):
    print(train_df['question'].sample(1))

1860    Analyze the 5G wireless network drive-test user plane data and engineering parameters.\nIdentify the reason for the throughput dropping below 600Mbps in certain road sections.\nFrom the following 8 potential root causes, select the most likely one and enclose its number in \boxed{{}} in the final answer.\n\nC1: The serving cell's downtilt angle is too large, causing weak coverage at the far end.\nC2: The serving cell's coverage distance exceeds 1km, resulting in over-shooting.\nC3: A neighboring cell provides higher throughput.\nC4: Non-colocated co-frequency neighboring cells cause severe overlapping coverage.\nC5: Frequent handovers degrade performance.\nC6: Neighbor cell and serving cell have the same PCI mod 30, leading to interference.\nC7: Test vehicle speed exceeds 40km/h, impacting user throughput.\nC8: Average scheduled RBs are below 160, affecting throughput.\n\nGiven:\n- The default electronic downtilt value is 255, representing a downtilt angle of 6 degrees. Other v

In [13]:
import re
import pandas as pd
from io import StringIO

CAUSE_DESC = {
    "C1": "downtilt too large",
    "C2": "coverage > 1km",
    "C3": "neighbor cell better throughput",
    "C4": "non-colocated co-frequency overlap",
    "C5": "frequent handovers",
    "C6": "PCI mod30 conflict",
    "C7": "speed > 40 km/h",
    "C8": "RBs < 160"
}

def extract_key_info(text: str) -> str:
    info_parts = []

    # --- Symptom ---
    info_parts.append("Symptom: Throughput dropped below 600 Mbps in some road sections.")

    # --- Parse User Plane Data ---
    user_plane_match = re.search(r'User plane drive test data as follows：?\s*\n(.*?)(?:\n\n|\Z)', text, re.DOTALL)
    if user_plane_match:
        table_str = user_plane_match.group(1).strip()
        try:
            table_str = table_str.replace('|', ',')
            df_user = pd.read_csv(StringIO(table_str))

            # Speed Analysis
            speeds = df_user.get('GPS Speed (km/h)', pd.Series())
            if not speeds.empty:
                max_speed = float(speeds.max())
                avg_speed = float(speeds.mean())
                info_parts.append(f"Speed: Max={max_speed:.1f} km/h, Avg={avg_speed:.1f} km/h")
                if max_speed > 40:
                    info_parts.append("Speed exceeds 40 km/h → potential C7")

            # Throughput Analysis
            throughputs = df_user.get('5G KPI PCell Layer2 MAC DL Throughput [Mbps]', pd.Series())
            if not throughputs.empty:
                min_tp = float(throughputs.min())
                avg_tp = float(throughputs.mean())
                info_parts.append(f"Throughput: Min={min_tp:.1f} Mbps, Avg={avg_tp:.1f} Mbps")
                if min_tp < 600:
                    info_parts.append("Throughput consistently low")

            # RB Analysis
            rbs = df_user.get('5G KPI PCell Layer1 DL RB Num (Including 0)', pd.Series())
            if not rbs.empty:
                avg_rbs = float(rbs.mean())
                info_parts.append(f"RBs: Avg={avg_rbs:.1f}")
                if avg_rbs < 160:
                    info_parts.append("Avg RBs < 160 → potential C8")

            # Handover Analysis
            pcis = df_user.get('5G KPI PCell RF Serving PCI', pd.Series())
            if not pcis.empty:
                handover_count = (pcis != pcis.shift(1)).sum() - 1
                unique_pcis = pcis.unique()
                info_parts.append(f"Handovers: {max(0, int(handover_count))} events")
                if handover_count > 3:
                    info_parts.append("Frequent handovers → potential C5")
                info_parts.append(f"Serving PCIs: {list(unique_pcis.astype(int))}")

                # PCI Mod30 Conflict Check
                neighbor_cols = [col for col in df_user.columns if 'Top' in col and 'PCI' in col]
                if neighbor_cols:
                    neighbor_pcis = pd.concat([df_user[col].dropna() for col in neighbor_cols]).unique()
                    for pci_s in unique_pcis:
                        for pci_n in neighbor_pcis:
                            if int(pci_s) % 30 == int(pci_n) % 30 and pci_s != pci_n:
                                info_parts.append(f"PCI Mod30 Conflict: Serving {int(pci_s)} ↔ Neighbor {int(pci_n)} → potential C6")
                                break

            # Neighbor Throughput Comparison (C3)
            neighbor_tp_cols = [col for col in df_user.columns if 'Top' in col and 'Throughput' in col]
            if neighbor_tp_cols:
                neighbor_tps = df_user[neighbor_tp_cols].max(axis=1)
                if not neighbor_tps.empty:
                    avg_neighbor_tp = neighbor_tps.mean()
                    if avg_neighbor_tp > avg_tp * 1.2:  # 20% higher
                        info_parts.append(f"Neighbor throughput avg={avg_neighbor_tp:.1f} Mbps > serving avg={avg_tp:.1f} → potential C3")

        except Exception as e:
            info_parts.append("Failed to parse user plane table; using fallback regex.")

    # --- Parse Engineering Parameters ---
    eng_match = re.search(r'Engeneering parameters data as follows：?\s*\n(.*?)(?:\n\n|\Z)', text, re.DOTALL)
    if eng_match:
        table_str = eng_match.group(1).strip()
        try:
            table_str = table_str.replace('|', ',')
            df_eng = pd.read_csv(StringIO(table_str))

            # Effective Downtilt (C1)
            digital_tilt = df_eng.get('Digital Tilt', pd.Series())
            mech_tilt = df_eng.get('Mechanical Downtilt', pd.Series())
            if not digital_tilt.empty and not mech_tilt.empty:
                eff_tilts = []
                for d, m in zip(digital_tilt, mech_tilt):
                    d = float(d); m = float(m)
                    eff_d = 6.0 if d == 255 else d
                    eff_tilts.append(eff_d + m)
                avg_downtilt = sum(eff_tilts) / len(eff_tilts)
                info_parts.append(f"Downtilt: Avg={avg_downtilt:.1f}°")
                if avg_downtilt > 10:  # Threshold based on typical deployment
                    info_parts.append("Downtilt >10° → potential C1 (weak far-end coverage)")

            # Beam Scenario → Vertical Beamwidth
            beam_scenarios = df_eng.get('Beam Scenario', pd.Series())
            if not beam_scenarios.empty:
                beam_widths = []
                for bs in beam_scenarios:
                    if pd.isna(bs): continue
                    if 'SCENARIO_' in str(bs):
                        num = int(re.search(r'SCENARIO_(\d+)', str(bs)).group(1))
                        if num <= 5:
                            bw = 6
                        elif 6 <= num <= 11:
                            bw = 12
                        else:
                            bw = 25
                    else:
                        bw = 6  # DEFAULT
                    beam_widths.append(bw)
                info_parts.append(f"Beamwidths: {beam_widths}")
                if any(bw > 12 for bw in beam_widths):
                    info_parts.append("Wide beamwidths → may indicate over-shooting or poor focus → potential C2")

            # Colocation & Frequency (C4)
            gnodeb_ids = df_eng.get('gNodeB ID', pd.Series()).dropna().unique()
            if len(gnodeb_ids) == 1:
                info_parts.append("Cells are colocated")
            else:
                info_parts.append("Non-colocated cells → potential C4 if co-frequency")
                info_parts.append("Non-colocated + co-frequency → high risk of overlapping coverage → potential C4")

            # Coverage Distance (C2)
            distances = df_eng.get('Distance to Cell (m)', pd.Series())
            if not distances.empty:
                max_dist = float(distances.max())
                info_parts.append(f"Max Distance: {max_dist:.0f} m")
                if max_dist > 1000:
                    info_parts.append("Coverage >1km → potential C2 (over-shooting)")

        except Exception as e:
            info_parts.append("Failed to parse engineering table; using fallback regex.")

    # --- Fallback Regex for Critical Fields ---
    if not any("Speed:" in p for p in info_parts):
        speed_match = re.search(r'speed\s*[=:]\s*(\d+\.?\d*)\s*km/h', text, re.IGNORECASE)
        if speed_match:
            info_parts.append(f"Speed: Max={float(speed_match.group(1)):.1f} km/h")
            if float(speed_match.group(1)) > 40:
                info_parts.append("Speed >40 km/h → potential C7")

    if not any("RBs:" in p for p in info_parts):
        rb_match = re.search(r'(RBs?|resource blocks?)\s*[=:]\s*(\d+)', text, re.IGNORECASE)
        if rb_match:
            avg_rbs = int(rb_match.group(2))
            info_parts.append(f"RBs: Avg={avg_rbs}")
            if avg_rbs < 160:
                info_parts.append("RBs <160 → potential C8")

    # --- Final Compact Summary ---
    compact_q = "Telco RCA Input:\n" + "\n".join(info_parts)
    return compact_q

In [14]:
def build_sft_example(question: str, answer: str) -> str:
    cause_desc = CAUSE_DESC[answer]
    compact_q = extract_key_info(question)

    # Construct instruction-aware prompt using Qwen2.5-Instruct's chat format
    return (
        f"<|im_start|>user\n{compact_q}<|im_end|>\n"
        f"<|im_start|>assistant\n"
        f"Based on the diagnostic evidence, the most likely root cause is {cause_desc}. "
        f"Final answer: \\boxed{{{answer}}}.<|im_end|>"
    )

In [15]:
# Cell 3: Build dataset using enhanced preprocessing
sft_texts = []
for _, row in train_df.iterrows():
    q = str(row["question"]).strip()
    a = str(row["answer"]).strip()
    
    # Skip invalid or out-of-scope answers
    if a not in CAUSE_DESC:
        continue
        
    try:
        txt = build_sft_example(q, a)
        sft_texts.append(txt)
    except Exception as e:
        # Optional: log skipped samples for debugging
        continue  # silently skip malformed entries

# Create Hugging Face Dataset
from datasets import Dataset

dataset = Dataset.from_dict({"text": sft_texts})
dataset = dataset.train_test_split(test_size=0.2, seed=42)  # reproducible split

print("Sample enhanced prompt:")
print(dataset["train"][0]["text"])

Sample enhanced prompt:
<|im_start|>user
Telco RCA Input:
Symptom: Throughput dropped below 600 Mbps in some road sections.
Speed: Max=39.0 km/h, Avg=23.4 km/h
Throughput: Min=244.9 Mbps, Avg=604.2 Mbps
Throughput consistently low
RBs: Avg=190.2
Handovers: 1 events
Serving PCIs: [np.int64(430), np.int64(374)]
Failed to parse user plane table; using fallback regex.
Downtilt: Avg=12.9°
Downtilt >10° → potential C1 (weak far-end coverage)
Beamwidths: [6, 6, 6, 6, 6, 12, 12]
Non-colocated cells → potential C4 if co-frequency
Non-colocated + co-frequency → high risk of overlapping coverage → potential C4<|im_end|>
<|im_start|>assistant
Based on the diagnostic evidence, the most likely root cause is downtilt too large. Final answer: \boxed{C1}.<|im_end|>


In [None]:
print(repr(dataset["train"][1000]["text"][:1024]))

In [16]:
# Cell 4: Load model and tokenizer
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
# MODEL_NAME = '/kaggle/input/qwen2.5/transformers/1.5b-instruct/1'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# Freeze base model
model.requires_grad_(False)

# LoRA config
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 73,859,072 || all params: 1,617,573,376 || trainable%: 4.5660


In [21]:
# pip install --upgrade trl

In [None]:
# from trl import SFTTrainer
# from trl import DataCollatorForCompletionOnlyLM
# from transformers import TrainingArguments

# # 1. Define the Response Template for Qwen2.5 (ChatML)
# # This identifies the start of the assistant's reply.
# response_template = "<|im_start|>assistant\n"

# # 2. Setup the specialized collator
# # mlm=False is implied here as it's specifically for causal completion.
# data_collator = DataCollatorForCompletionOnlyLM(
#     response_template=response_template, 
#     tokenizer=tokenizer
# )

# # 3. Training arguments tuned for T4×2 (2 × 16 GB VRAM)
# training_args = TrainingArguments(
#     output_dir="./qwen2.5-1.5b-rca",
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=8,
#     learning_rate=3e-4,
#     num_train_epochs=1,
#     logging_steps=1,
#     save_strategy="steps",
#     save_steps=20,
#     eval_strategy="steps",
#     eval_steps=20,
#     bf16=True,  # T4 (Turing) has no native bf16 support; consider fp16=True on that hardware
#     fp16=False,
#     optim="adamw_torch",
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.1,
#     gradient_checkpointing=True,
#     dataloader_num_workers=2,
#     report_to="none",
#     remove_unused_columns=False, # Set to False when using SFTTrainer with raw text
# )

# # 4. Initialize SFTTrainer
# # We pass the RAW dataset (dataset["train"]) instead of tokenized_dataset.
# # SFTTrainer will handle tokenization internally using the 'text' field.
# trainer = SFTTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=dataset["train"],
#     eval_dataset=dataset["test"],
#     dataset_text_field="text",      # Name of the column containing the ChatML strings
#     max_seq_length=1024,
#     data_collator=data_collator,    # Our new completion-only collator
# )

# print("Starting training with Completion-Only masking...")
# trainer.train()

In [None]:
# Cell 5: Pre-tokenize dataset and train with evaluation every 20 steps
from transformers import DataCollatorForLanguageModeling
from trl import SFTTrainer
from transformers import TrainingArguments
from peft import PeftModel  # Used below to check whether the model is PEFT-wrapped

# Tokenize function
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=1024,
        return_special_tokens_mask=False
    )

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
    num_proc=1
)

# Data collator for causal language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./qwen2.5-1.5b-rca",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=5e-4,
    num_train_epochs=1,
    logging_steps=1,
    save_strategy="steps",
    save_steps=20,
    eval_strategy="steps",
    eval_steps=20,
    bf16=True,
    fp16=False,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    gradient_checkpointing=True,
    dataloader_num_workers=2,
    report_to="none",
    remove_unused_columns=True,
    load_best_model_at_end=False,
)

# SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)

# Start training
print("Starting training...")
trainer.train()

# Save adapter-only version (optional, for later PEFT reuse)
print("Saving LoRA adapters...")
trainer.save_model("./qwen2.5-1.5b-rca-adapters")  # Saves only adapters
tokenizer.save_pretrained("./qwen2.5-1.5b-rca-adapters")

# Save FULL MERGED MODEL (base + LoRA fused)
print("Merging LoRA adapters into base model...")
if isinstance(model, PeftModel):
    merged_model = model.merge_and_unload()  # Fuses LoRA weights into base model
else:
    merged_model = model  # Fallback if not PEFT-wrapped (unlikely)

merged_output_dir = "./qwen2.5-1.5b-rca-merged"
merged_model.save_pretrained(merged_output_dir)
tokenizer.save_pretrained(merged_output_dir)

print(f" Full merged model saved to: {merged_output_dir}")
print("You can now load it with AutoModelForCausalLM.from_pretrained() without PEFT!")

Tokenizing dataset...


Map (num_proc=1):   0%|          | 0/2611 [00:00<?, ? examples/s]

Map (num_proc=1):   0%|          | 0/653 [00:00<?, ? examples/s]



Truncating train dataset:   0%|          | 0/2611 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/653 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


Starting training...


In [7]:
# # Cell 5: Pre-tokenize dataset and train with evaluation every 20 steps
# from transformers import DataCollatorForLanguageModeling
# from trl import SFTTrainer
# from transformers import TrainingArguments

# # Tokenize function
# def tokenize_function(examples):
#     return tokenizer(
#         examples["text"],
#         truncation=True,
#         padding="max_length",
#         max_length=1024,  # compressed prompts are short; 1024 is safe & efficient
#         return_special_tokens_mask=False  # not needed for causal LM
#     )

# print("Tokenizing dataset...")
# tokenized_dataset = dataset.map(
#     tokenize_function,
#     batched=True,
#     remove_columns=["text"],  # Free memory
#     num_proc=1  # Kaggle CPU cores are limited
# )

# # Data collator for causal language modeling
# data_collator = DataCollatorForLanguageModeling(
#     tokenizer=tokenizer,
#     mlm=False  # Important: this is a causal LM, not masked LM
# )

# # Training arguments tuned for T4×2 (2 × 16 GB VRAM)
# training_args = TrainingArguments(
#     output_dir="./qwen2.5-1.5b-rca",
#     per_device_train_batch_size=4,           # Fit in T4 VRAM
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=8,           # Effective batch = 32
#     learning_rate=1e-4,
#     num_train_epochs=1,
#     logging_steps=1,
#     save_strategy="steps",                   # Optional: save more frequently
#     save_steps=20,                           # Save checkpoint every 20 steps
#     eval_strategy="steps",                   
#     eval_steps=20,                           
#     bf16=True,                              
#     fp16=False,
#     optim="adamw_torch",
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.1,
#     gradient_checkpointing=True,            
#     dataloader_num_workers=2,
#     report_to="none",
#     remove_unused_columns=True,
#     load_best_model_at_end=False,           # Disable if using step-based eval without metric
#     greater_is_better=False,                # Not used unless you define a metric
# )

# # SFTTrainer WITHOUT tokenizer or formatting_func
# trainer = SFTTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=tokenized_dataset["train"],
#     eval_dataset=tokenized_dataset["test"],
#     data_collator=data_collator,
# )

# trainer.train()

Tokenizing dataset...


Map (num_proc=1):   0%|          | 0/2611 [00:00<?, ? examples/s]

Map (num_proc=1):   0%|          | 0/653 [00:00<?, ? examples/s]



Truncating train dataset:   0%|          | 0/2611 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/653 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 20   | 0.2697        | 0.260451        |
| 40   | 0.2344        | 0.237121        |
| 60   | 0.2263        | 0.224982        |
| 80   | 0.2253        | 0.221473        |


TrainOutput(global_step=82, training_loss=0.543404329113844, metrics={'train_runtime': 14227.6988, 'train_samples_per_second': 0.184, 'train_steps_per_second': 0.006, 'total_flos': 2.220530911936512e+16, 'train_loss': 0.543404329113844})

Supervised fine-tuning (SFT) is usually the first step in turning a raw "base" language model into an "instruct" model that follows user commands. A base model is trained only to predict the next token of web-scale text, whereas an SFT-trained model learns to respond as a helpful assistant. Here it is implemented with the **SFTTrainer** from Hugging Face's `trl` (**_Transformer Reinforcement Learning_**) library. SFT also reduces the likelihood of the model "hallucinating" a continuation of the prompt, because training keeps it focused on producing the answer.
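
One common way to keep training focused on the answer is to compute the loss only on the assistant's tokens, which is what the commented-out completion-only cell above aims for (the active training cell uses a plain causal-LM collator and therefore learns from the whole sequence). The sketch below shows the idea in plain Python. The helper names and the toy token IDs are hypothetical, not real Qwen IDs, and this is not the `trl` implementation.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def find_subsequence(sequence, pattern):
    """Return the index where `pattern` first occurs in `sequence`, or -1 if absent."""
    for start in range(len(sequence) - len(pattern) + 1):
        if sequence[start:start + len(pattern)] == pattern:
            return start
    return -1


def mask_prompt_labels(input_ids, response_template_ids):
    """Copy input_ids into labels, masking everything up to and including the response template."""
    start = find_subsequence(input_ids, response_template_ids)
    if start < 0:
        # Template not found (for example, truncated away): contribute nothing to the loss.
        return [IGNORE_INDEX] * len(input_ids)
    answer_start = start + len(response_template_ids)
    return [IGNORE_INDEX] * answer_start + list(input_ids[answer_start:])


# Toy example: [90, 91] stands in for the tokens of "<|im_start|>assistant\n".
input_ids = [11, 12, 13, 90, 91, 21, 22, 23]
response_template_ids = [90, 91]
labels = mask_prompt_labels(input_ids, response_template_ids)
print(labels)  # [-100, -100, -100, -100, -100, 21, 22, 23]
```

With labels built this way, the prompt tokens still provide context but only the answer tokens produce gradient.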

The `lr_scheduler_type` parameter determines how the learning rate changes over the course of training. Choosing a suitable value matters: a schedule that warms up and then decays (here `cosine` with `warmup_ratio=0.1`) helps the model "settle" into its new task without erasing the foundational knowledge it gained during pre-training.
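
To make the shape of such a schedule concrete, here is a minimal sketch of linear warmup followed by cosine decay, using the values from the training cell (`learning_rate=5e-4`, `warmup_ratio=0.1`) and the 82 optimizer steps reported in the training output. The function `lr_at_step` is a hypothetical helper that follows the usual formula; the library's exact rounding of warmup steps may differ.

```python
import math


def lr_at_step(step, total_steps, base_lr, warmup_ratio=0.1):
    """Learning rate at a given optimizer step: linear warmup, then cosine decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp linearly from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase already completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


TOTAL_STEPS = 82  # global_step from the training output
BASE_LR = 5e-4    # learning_rate in the active training cell

schedule = [lr_at_step(step, TOTAL_STEPS, BASE_LR) for step in range(TOTAL_STEPS + 1)]
for step in (0, 5, 9, 45, 82):
    print(f"step {step:2d}: lr = {schedule[step]:.2e}")
```

The peak is reached at the end of warmup, after which the rate falls smoothly, so the largest updates happen early and the final steps make only small adjustments.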

In [None]:
# Cell 6: FAST Batched Inference with Progress Bar + Immediate Parsing
import torch
import pandas as pd
from tqdm import tqdm
import re
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the FULL MERGED model (no PEFT needed!)
print("Loading fine-tuned merged model for inference...")
model = AutoModelForCausalLM.from_pretrained(
    "./qwen2.5-1.5b-rca-merged",
    device_map="auto",        # Automatically uses GPU if available
    torch_dtype=torch.bfloat16  # Match training dtype (bf16)
)
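tokenizer = AutoTokenizer.from_pretrained("./qwen2.5-1.5b-rca-merged")
# Decoder-only models should be left-padded for batched generation,
# so that every prompt ends exactly where generation begins
tokenizer.padding_side = "left"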

# Ensure extract_key_info and CAUSE_DESC are defined (from earlier cells)
# If not, re-import or re-define them here

def map_output_to_root_cause(text: str) -> str:
    """Lightweight version for inference-time parsing"""
    # Priority 1: Extract \boxed{C#}
    boxed_match = re.search(r'\\boxed\{(C[1-8])\}', text)
    if boxed_match:
        return boxed_match.group(1)
    
    # Priority 2: Direct C# mention
    direct_match = re.search(r'\b(C[1-8])\b', text)
    if direct_match:
        return direct_match.group(1)
    
    # Priority 3: Heuristic fallback using cause descriptions
    text_lower = text.lower()
    if "downtilt too large" in text_lower or "tilt too large" in text_lower: return "C1"
    if "coverage > 1km" in text_lower or ">1km" in text_lower: return "C2"
    if "neighbor cell better throughput" in text_lower: return "C3"
    if "non-colocated" in text_lower and ("co-frequency" in text_lower or "overlap" in text_lower): return "C4"
    if "frequent handover" in text_lower: return "C5"
    if "pci mod30" in text_lower or "mod30 conflict" in text_lower: return "C6"
    if "speed > 40" in text_lower or ("speed" in text_lower and re.search(r'[4-9]\d\s*km/h', text_lower)): return "C7"
    if "rb" in text_lower and ("<160" in text_lower or "below 160" in text_lower or "avg rbs" in text_lower): return "C8"
    
    return "C1"  # safe default


def generate_responses_batched(df: pd.DataFrame, batch_size=64, num_return_sequences=4):
    all_rows = []
    device = model.device

    # Precompute prompts using the SAME extract_key_info used in training
    all_prompts = []
    all_original_ids = []

    for _, row in df.iterrows():
        # CRITICAL: Use the enhanced extract_key_info
        compact_q = extract_key_info(str(row["question"]).strip())
        
        # Format as chat for Qwen2.5-Instruct
        messages = [{"role": "user", "content": compact_q}]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        # Repeat for ensemble-style sampling
        all_prompts.extend([prompt] * num_return_sequences)
        all_original_ids.extend([row['ID']] * num_return_sequences)

    # Process in batches
    for i in tqdm(range(0, len(all_prompts), batch_size), desc=f"Inference ({df.iloc[0]['ID'].split('_')[0]})"):
        batch_prompts = all_prompts[i:i + batch_size]
        batch_orig_ids = all_original_ids[i:i + batch_size]

        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024
        ).to(device)

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Remove input prompt from output
        input_lengths = inputs.input_ids.shape[1]
        decoded = tokenizer.batch_decode(
            outputs[:, input_lengths:],
            skip_special_tokens=True
        )

        # Parse and store
        for j, text in enumerate(decoded):
            cause = map_output_to_root_cause(text.strip())
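            all_rows.append({
                "original_ID": batch_orig_ids[j],
                # Use the global prompt index (i + j) so the sample number stays correct
                # even when batch_size is not a multiple of num_return_sequences
                "sample_ID": f"{batch_orig_ids[j]}_{(i + j) % num_return_sequences + 1}",
                "raw_output": text.strip(),
                "predicted_cause": cause
            })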

    return all_rows


# Run inference
print("Running batched inference on Phase 1...")
phase1_rows = generate_responses_batched(phase1_test, batch_size=64, num_return_sequences=4)

print("Running batched inference on Phase 2...")
phase2_rows = generate_responses_batched(phase2_test, batch_size=64, num_return_sequences=4)

all_pred_rows = phase1_rows + phase2_rows

In [None]:
# # Cell 6: FAST Batched Inference with Progress Bar + Immediate Parsing
# import torch
# import pandas as pd
# from tqdm import tqdm
# import re

# # Ensure extract_key_info and CAUSE_DESC are defined (from earlier cells)
# # If not, re-import or re-define them here

# def map_output_to_root_cause(text: str) -> str:
#     """Lightweight version for inference-time parsing"""
#     # Priority 1: Extract \boxed{C#}
#     boxed_match = re.search(r'\\boxed\{(C[1-8])\}', text)
#     if boxed_match:
#         return boxed_match.group(1)
    
#     # Priority 2: Direct C# mention
#     direct_match = re.search(r'\b(C[1-8])\b', text)
#     if direct_match:
#         return direct_match.group(1)
    
#     # Priority 3: Heuristic fallback using cause descriptions
#     text_lower = text.lower()
#     if "downtilt too large" in text_lower or "tilt too large" in text_lower: return "C1"
#     if "coverage > 1km" in text_lower or ">1km" in text_lower: return "C2"
#     if "neighbor cell better throughput" in text_lower: return "C3"
#     if "non-colocated" in text_lower and ("co-frequency" in text_lower or "overlap" in text_lower): return "C4"
#     if "frequent handover" in text_lower: return "C5"
#     if "pci mod30" in text_lower or "mod30 conflict" in text_lower: return "C6"
#     if "speed > 40" in text_lower or ("speed" in text_lower and re.search(r'[4-9]\d\s*km/h', text_lower)): return "C7"
#     if "rb" in text_lower and ("<160" in text_lower or "below 160" in text_lower or "avg rbs" in text_lower): return "C8"
    
#     return "C1"  # safe default


# def generate_responses_batched(df: pd.DataFrame, batch_size=64, num_return_sequences=4):
#     all_rows = []
#     device = model.device

#     # Precompute prompts using the SAME extract_key_info used in training
#     all_prompts = []
#     all_original_ids = []

#     for _, row in df.iterrows():
#         # CRITICAL: Use the enhanced extract_key_info
#         compact_q = extract_key_info(str(row["question"]).strip())
        
#         # Format as chat for Qwen2.5-Instruct
#         messages = [{"role": "user", "content": compact_q}]
#         prompt = tokenizer.apply_chat_template(
#             messages,
#             tokenize=False,
#             add_generation_prompt=True
#         )
        
#         # Repeat for ensemble-style sampling
#         all_prompts.extend([prompt] * num_return_sequences)
#         all_original_ids.extend([row['ID']] * num_return_sequences)

#     # Process in batches
#     for i in tqdm(range(0, len(all_prompts), batch_size), desc=f"Inference ({df.iloc[0]['ID'].split('_')[0]})"):
#         batch_prompts = all_prompts[i:i + batch_size]
#         batch_orig_ids = all_original_ids[i:i + batch_size]

#         inputs = tokenizer(
#             batch_prompts,
#             return_tensors="pt",
#             padding=True,
#             truncation=True,
#             max_length=1024  # Increased to accommodate richer context
#         ).to(device)

#         with torch.inference_mode():
#             outputs = model.generate(
#                 **inputs,
#                 max_new_tokens=256,          # Reduced (answers are short)
#                 do_sample=True,
#                 temperature=0.7,
#                 top_p=0.9,
#                 pad_token_id=tokenizer.eos_token_id,
#                 eos_token_id=tokenizer.eos_token_id
#             )

#         # Remove input prompt from output
#         input_lengths = inputs.input_ids.shape[1]
#         decoded = tokenizer.batch_decode(
#             outputs[:, input_lengths:],
#             skip_special_tokens=True
#         )

#         # Parse and store
#         for j, text in enumerate(decoded):
#             cause = map_output_to_root_cause(text.strip())
#             all_rows.append({
#                 "original_ID": batch_orig_ids[j],
#                 "sample_ID": f"{batch_orig_ids[j]}_{(i + j) % num_return_sequences + 1}",  # global index, valid for any batch_size
#                 "raw_output": text.strip(),
#                 "predicted_cause": cause
#             })

#     return all_rows


# # Run inference
# print("Running batched inference on Phase 1...")
# phase1_rows = generate_responses_batched(phase1_test, batch_size=64, num_return_sequences=4)

# print("Running batched inference on Phase 2...")
# phase2_rows = generate_responses_batched(phase2_test, batch_size=64, num_return_sequences=4)

# all_pred_rows = phase1_rows + phase2_rows
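
Because each prompt is repeated `num_return_sequences` times before batching, the per-question sample number has to be recovered from a position in the flattened prompt list. A position inside a batch agrees with the global position (modulo `num_return_sequences`) only when `batch_size` is a multiple of `num_return_sequences`, as with 64 and 4. The sketch below illustrates this without calling the model; `sample_suffixes` is a hypothetical helper that is not part of the notebook, and the sizes are made up for the example.

```python
def sample_suffixes(n_questions, num_return_sequences, batch_size, use_global_index):
    """Return the 1-based sample suffix assigned to each flattened prompt.

    Hypothetical helper for illustration only: it mimics the batching loop
    (prompts repeated num_return_sequences times, then sliced into batches).
    """
    total = n_questions * num_return_sequences
    suffixes = []
    for start in range(0, total, batch_size):
        batch_len = min(batch_size, total - start)
        for j in range(batch_len):
            # Either the position in the whole list or the position in the batch
            position = start + j if use_global_index else j
            suffixes.append(position % num_return_sequences + 1)
    return suffixes


# batch_size is a multiple of num_return_sequences: both indexings agree
aligned_local = sample_suffixes(3, 4, 8, use_global_index=False)
aligned_global = sample_suffixes(3, 4, 8, use_global_index=True)

# batch_size is not a multiple: the batch-local index drifts after the first batch
drift_local = sample_suffixes(3, 4, 6, use_global_index=False)
drift_global = sample_suffixes(3, 4, 6, use_global_index=True)

print(aligned_local == aligned_global)
print(drift_local)
print(drift_global)
```

Indexing by the global position keeps the suffixes cycling 1 to 4 for every question regardless of the batch size chosen.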

In [None]:
# Cell 7: Build final submission with voting over 4 samples per test case
import re
from collections import Counter
import pandas as pd

# Redefine CAUSE_DESC so this cell can run on its own
CAUSE_DESC = {
    "C1": "downtilt too large",
    "C2": "coverage > 1km",
    "C3": "neighbor cell better throughput",
    "C4": "non-colocated co-frequency overlap",
    "C5": "frequent handovers",
    "C6": "PCI mod30 conflict",
    "C7": "speed > 40 km/h",
    "C8": "RBs < 160"
}

def map_output_to_root_cause(text: str) -> str:
    """Robust parser for model output (used during inference fallback)"""
    text_lower = text.lower()
    
    # Priority 1: \boxed{Cx}
    boxed_match = re.search(r'\\boxed\{(C[1-8])\}', text)
    if boxed_match:
        return boxed_match.group(1)
    
    # Priority 2: standalone Cx
    direct_match = re.search(r'\b(C[1-8])\b', text)
    if direct_match:
        return direct_match.group(1)
    
    # Priority 3: keyword triggers (aligned with cause descriptions)
    triggers = {
        "C1": [r"downtilt.*large", r"tilt.*too.*large", r"effective.*downtilt.*[7-9]\d*"],
        "C2": [r"coverage.*>.*1.?km", r"distance.*>.*1000", r"over.?shoot"],
        "C3": [r"neighbor.*cell.*better.*throughput", r"neighboring.*higher.*throughput"],
        "C4": [r"non.?colocated.*co.?frequency", r"overlap.*coverage.*non.?colocated"],
        "C5": [r"frequent.*handover", r"handovers.*\b([5-9]\d|\d{3,})\b"],
        "C6": [r"pci.*mod.*30.*conflict", r"mod30.*conflict", r"pci.*mod30"],
        "C7": [r"speed.*>.*40.*km/h", r"max.*speed.*>.*40"],
        "C8": [r"rb.*<.*160", r"avg.*rb.*<.*160", r"resource.*blocks.*below.*160"]
    }
    
    for cause, patterns in triggers.items():
        for pat in patterns:
            if re.search(pat, text_lower):
                return cause
                
    return "C7"  # safe default


# Convert inference results to DataFrame
pred_df = pd.DataFrame(all_pred_rows)

# Ensure 'original_ID' exists (e.g., "P1_001")
assert "original_ID" in pred_df.columns, "Missing 'original_ID' in inference output"
assert "raw_output" in pred_df.columns, "Missing 'raw_output' in inference output"

# Re-parse predictions (optional but safe: in case parsing was skipped during inference)
pred_df["predicted_cause"] = pred_df["raw_output"].apply(map_output_to_root_cause)

# Group by base test ID (e.g., "P1_001") and vote
final_predictions = {}
for base_id, group in pred_df.groupby("original_ID"):
    causes = group["predicted_cause"].tolist()
    counter = Counter(causes)
    # Get most common; if tie, pick the first one encountered (most_common is stable)
    final_cause = counter.most_common(1)[0][0]
    final_predictions[base_id] = final_cause

# Load sample submission
sample_sub = pd.read_csv("SampleSubmission.csv")
submission = sample_sub.copy()

# Map predictions onto the submission IDs
mapped = submission["ID"].map(final_predictions)

# Count missing IDs before formatting (should not happen if test sets match)
missing = mapped.isna().sum()
if missing > 0:
    print(f"  Warning: {missing} IDs had no prediction. Filling with \\boxed{{C7}}.")

# Format as \boxed{Cx}, using C7 for any ID without a prediction
submission["Qwen2.5-1.5B-Instruct"] = (
    mapped
    .fillna("C7")
    .apply(lambda c: f"\\boxed{{{c}}}")
)

# Save
output_file = "lightning_submission_track3_forcompletiononly.csv"
submission.to_csv(output_file, index=False)
print(f"✅ Submission saved to '{output_file}' with {len(submission)} rows.")
submission.head()
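
The voting step relies on `Counter.most_common` returning equal-count items in first-encountered order, so a 2-2 split resolves to whichever cause appears first among the four samples. A minimal sketch of that behaviour, assuming made-up predictions (the `vote` helper and the cause lists are illustrative, not real model outputs):

```python
from collections import Counter


def vote(causes):
    """Majority vote; on a tie, the cause seen first in the list wins."""
    return Counter(causes).most_common(1)[0][0]


clear_majority = vote(["C3", "C1", "C3", "C3"])
tie_first_wins = vote(["C5", "C2", "C2", "C5"])
tie_reordered = vote(["C2", "C5", "C5", "C2"])

print(clear_majority, tie_first_wins, tie_reordered)
```

Since the order of the sampled outputs is arbitrary, a tie is effectively broken at random; a deterministic rule (for example, preferring the cause with the lowest index) may be worth considering if reproducibility matters.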