In [None]:
# Install dependencies: third-party packages used by this notebook (Colab runtime).
# NOTE(review): `datasets` and `accelerate` are not imported by the cells visible here —
# confirm they are still needed.
# NOTE(review): versions are unpinned, so a fresh runtime may pull incompatible releases;
# consider pinning once a working set is known (and `%pip` targets the kernel's env).
!pip install -q transformers datasets torch accelerate seqeval scikit-learn

In [None]:
# Mount Google Drive so data and checkpoints live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Setup paths for Google Colab
import os
from pathlib import Path

PROJECT_ROOT = Path('/content/drive/MyDrive/574-assignment')
DATA_PATH = PROJECT_ROOT / 'data' / 'generated'
MODEL_PATH = PROJECT_ROOT / 'models' / 'two_stage_joint'

# Ensure model output directory exists (exist_ok makes re-running this cell safe).
MODEL_PATH.mkdir(parents=True, exist_ok=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Data path: {DATA_PATH}")
print(f"Model path: {MODEL_PATH}")

# Verify the files later cells read from DATA_PATH. This check only reports;
# a missing file still surfaces as an error when a later cell opens it.
REQUIRED_DATA_FILES = ('train.json', 'val.json', 'test.json', 'label_mappings.json')
print("\nChecking data files:")
for fname in REQUIRED_DATA_FILES:
    if (DATA_PATH / fname).exists():
        print(f"  ✓ Found {fname}")
    else:
        print(f"  ✗ Missing {fname} - please upload to Google Drive")

In [None]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    DebertaV2TokenizerFast,
    DebertaV2Model,
    DebertaV2PreTrainedModel,
    get_linear_schedule_with_warmup,
    AutoConfig
)
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from seqeval.metrics import f1_score as seqeval_f1, classification_report
from collections import defaultdict
import numpy as np
from tqdm.auto import tqdm

# Select the compute device: first CUDA GPU if one is visible, otherwise CPU.
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 1. Configuration & Label Mappings

In [None]:
# Model configuration: every tunable hyperparameter for this run lives here.
CONFIG = {
    # Model
    'model_name': 'microsoft/deberta-v3-base',
    
    # Training
    'batch_size': 32,
    'learning_rate': 2e-5,
    'head_learning_rate': 1e-4,  # Higher LR for classification heads
    'max_length': 128,
    'weight_decay': 0.01,
    'gradient_accumulation_steps': 1,
    
    # Staged training
    'stage1_epochs': 3,      # Tool + NER only
    'stage2_epochs': 7,      # Joint training
    'finetune_epochs': 5,    # Fine-tuning with lower LR
    'warmup_ratio': 0.1,
    
    # Loss weights
    'tool_loss_weight': 1.0,
    'intent_loss_weight': 1.0,
    'ner_loss_weight': 1.0,
    
    # Early stopping
    'patience': 3,
    
    # Dropout
    'classifier_dropout': 0.1,

    # Reproducibility
    'seed': 42,
}

# Seed every RNG later cells draw from (freshly initialised heads, DataLoader
# shuffling, dropout masks) so Restart & Run All reproduces the same run.
# NOTE: bit-exact GPU results would additionally require deterministic kernels.
import random

random.seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])
torch.manual_seed(CONFIG['seed'])  # seeds the CPU and all CUDA generators

# Label vocabularies written by the data-generation step.
with open(DATA_PATH / 'label_mappings.json', 'r') as f:
    LABEL_MAPPINGS = json.load(f)


def _int_keyed(mapping):
    """Copy a JSON-decoded dict, casting its (string) keys back to int."""
    return {int(key): value for key, value in mapping.items()}


# Tool vocabulary (TOOLS follows the key order of tool_to_idx).
TOOL_TO_IDX = LABEL_MAPPINGS['tool_to_idx']
IDX_TO_TOOL = _int_keyed(LABEL_MAPPINGS['idx_to_tool'])
TOOLS = [*TOOL_TO_IDX]

# NER tag vocabulary.
TAG_TO_IDX = LABEL_MAPPINGS['tag_to_idx']
IDX_TO_TAG = _int_keyed(LABEL_MAPPINGS['idx_to_tag'])

# Per-tool intent vocabularies: each tool has its own softmax label space.
# Inner dicts are used exactly as loaded (idx_to_intent_per_tool keeps string keys).
INTENT_TO_IDX_PER_TOOL = LABEL_MAPPINGS['intent_to_idx_per_tool']
IDX_TO_INTENT_PER_TOOL = LABEL_MAPPINGS['idx_to_intent_per_tool']
NUM_INTENTS_PER_TOOL = LABEL_MAPPINGS['num_intents_per_tool']

# Global intent -> tool lookup (reference only).
INTENT_TO_TOOL = LABEL_MAPPINGS['intent_to_tool']

print(f"Number of tools: {len(TOOLS)}")
print(f"Number of NER tags: {len(TAG_TO_IDX)}")
print("\nIntents per tool:")
for tool, num in NUM_INTENTS_PER_TOOL.items():
    print(f"  {tool}: {num}")
total_intents = sum(NUM_INTENTS_PER_TOOL.values())
print(f"\nTotal intents: {total_intents}")

## 2. Load Data

In [None]:
def load_dataset(split='train'):
    """Read ``DATA_PATH/<split>.json`` and return the decoded samples.

    Prints the number of samples loaded and the file they came from.
    """
    source = DATA_PATH / f'{split}.json'
    with source.open('r') as handle:
        samples = json.load(handle)
    print(f"Loaded {len(samples)} {split} samples from {source}")
    return samples

# Load the three splits in order: train, val, test.
train_data, val_data, test_data = (
    load_dataset(split_name) for split_name in ('train', 'val', 'test')
)

In [None]:
# Pretty-print the first training sample to eyeball the schema
# (TwoStageDataset reads its query / tools / intents / bio_tags fields).
first_sample = train_data[0]
print("Sample structure:")
print(json.dumps(first_sample, indent=2))

## 3. Dataset Class (Per-Tool Intent Labels)

In [None]:
class TwoStageDataset(Dataset):
    """
    Dataset for the two-stage joint model.

    Each sample is a dict with:
        - ``query``: raw text,
        - ``tools``: list of tool names,
        - ``intents``: dict mapping tool name -> intent name,
        - ``bio_tags``: one BIO tag per whitespace-separated word of ``query``.

    Key difference from the flat architecture:
    - Intent labels are per-tool categorical indices (not multi-hot).
    - Tools without a known intent for the sample get label -100, which the
      model's cross-entropy ignores.
    """

    # Tool name -> key under which that tool's intent label is returned.
    _INTENT_LABEL_KEYS = {
        'character_data': 'intent_label_character',
        'session_notes': 'intent_label_session',
        'rulebook': 'intent_label_rulebook',
    }

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        text = sample['query']
        tools = sample['tools']
        intents = sample['intents']
        bio_tags = sample['bio_tags']

        # Fixed-length tokenization; offsets + special-token mask are needed to
        # align the word-level BIO tags to subword tokens.
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_offsets_mapping=True,
            return_special_tokens_mask=True
        )

        # Tool labels (multi-hot; unknown tool names are ignored).
        tool_labels = torch.zeros(len(TOOLS))
        for tool in tools:
            if tool in TOOL_TO_IDX:
                tool_labels[TOOL_TO_IDX[tool]] = 1

        # Per-tool intent labels: class index when the tool's intent is known,
        # otherwise -100 so the corresponding head ignores this sample.
        intent_labels = {
            key: torch.tensor(-100, dtype=torch.long)
            for key in self._INTENT_LABEL_KEYS.values()
        }
        for tool, intent in intents.items():
            label_key = self._INTENT_LABEL_KEYS.get(tool)
            if label_key is not None and intent in INTENT_TO_IDX_PER_TOOL[tool]:
                intent_labels[label_key] = torch.tensor(
                    INTENT_TO_IDX_PER_TOOL[tool][intent], dtype=torch.long
                )

        ner_labels = self._align_labels(
            text, bio_tags, encoding, encoding['offset_mapping'][0]
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'tool_labels': tool_labels,
            'intent_label_character': intent_labels['intent_label_character'],
            'intent_label_session': intent_labels['intent_label_session'],
            'intent_label_rulebook': intent_labels['intent_label_rulebook'],
            'ner_labels': ner_labels,
            'tools': tools,
            'intents': intents
        }

    def _align_labels(self, text, bio_tags, encoding, offset_mapping):
        """
        Align word-level BIO tags to subword tokens.

        Args:
            text: the string that was tokenized.
            bio_tags: one tag per ``text.split()`` word; padded with 'O' or
                truncated if the counts disagree.
            encoding: tokenizer output; only ``special_tokens_mask`` is read.
            offset_mapping: (seq_len, 2) character spans, one per token.

        Returns:
            LongTensor of length seq_len with a tag index per token: -100 for
            special/padding tokens, the word's tag for its first piece, the
            I- variant of a B- tag for continuation pieces, and 'O' for pieces
            that fall outside every word (e.g. a whitespace-only piece).
        """
        words = text.split()

        # Reconcile tag count with word count.
        if len(bio_tags) < len(words):
            bio_tags = bio_tags + ['O'] * (len(words) - len(bio_tags))
        else:
            bio_tags = bio_tags[:len(words)]

        # Character span of every word.
        word_boundaries = []
        cursor = 0
        for word in words:
            start = text.find(word, cursor)
            if start == -1:
                start = cursor
            end = start + len(word)
            word_boundaries.append((start, end))
            cursor = end

        outside = TAG_TO_IDX['O']
        special_tokens_mask = encoding['special_tokens_mask'][0].tolist()
        aligned_labels = []

        for idx, (start, end) in enumerate(offset_mapping.tolist()):
            if special_tokens_mask[idx] or (start == 0 and end == 0):
                aligned_labels.append(-100)
                continue

            # SentencePiece-based fast tokenizers (e.g. DeBERTa-v3) may report a
            # word-initial "▁" piece's span starting at the preceding space.
            # Skip leading whitespace so the piece maps to the word it begins.
            char_pos = start
            while char_pos < end and text[char_pos].isspace():
                char_pos += 1
            if char_pos == end:
                # Whitespace-only (or empty) piece: belongs to no word.
                aligned_labels.append(outside)
                continue

            word_idx = next(
                (i for i, (w_start, w_end) in enumerate(word_boundaries)
                 if w_start <= char_pos < w_end),
                None,
            )
            if word_idx is None:
                aligned_labels.append(outside)
                continue

            tag = bio_tags[word_idx]
            # Continuation pieces of a B- word are tagged I-.
            if char_pos > word_boundaries[word_idx][0] and tag.startswith('B-'):
                tag = 'I-' + tag[2:]
            aligned_labels.append(TAG_TO_IDX.get(tag, outside))

        return torch.tensor(aligned_labels, dtype=torch.long)


def collate_fn(batch):
    """
    Merge a list of TwoStageDataset items into one batch dict.

    Tensor fields are stacked along a new leading batch dimension. The raw
    ``tools`` lists and ``intents`` dicts vary per sample, so they are kept as
    plain Python lists.
    """
    stacked_keys = (
        'input_ids',
        'attention_mask',
        'tool_labels',
        'intent_label_character',
        'intent_label_session',
        'intent_label_rulebook',
        'ner_labels',
    )
    merged = {key: torch.stack([item[key] for item in batch]) for key in stacked_keys}
    for key in ('tools', 'intents'):
        merged[key] = [item[key] for item in batch]
    return merged

## 4. Two-Stage Joint Model Architecture

In [None]:
class TwoStageJointModel(DebertaV2PreTrainedModel):
    """
    Two-Stage Joint DeBERTa-v3 Model.

    Stage 1 (query-intrinsic):
        - Tool classification: which tools are needed? (multi-label, sigmoid/BCE)
        - NER: which entities are mentioned? (per-token BIO tagging with a
          plain softmax classifier and argmax decoding; no CRF layer)

    Stage 2 (context-dependent, gated by tool selection):
        - Per-tool intent classification, one softmax head per tool:
            - character_data: ``num_character_intents`` classes
            - session_notes: ``num_session_intents`` classes
            - rulebook: ``num_rulebook_intents`` classes
          A sample contributes to a head's loss only when its label for that
          tool is not -100.

    Loss weights are read from the notebook-level ``CONFIG`` dict.
    """

    def __init__(self, config, num_tools=3, num_ner_tags=25,
                 num_character_intents=10, num_session_intents=20, num_rulebook_intents=30):
        super().__init__(config)

        self.num_tools = num_tools
        self.num_ner_tags = num_ner_tags
        self.num_character_intents = num_character_intents
        self.num_session_intents = num_session_intents
        self.num_rulebook_intents = num_rulebook_intents

        # Shared encoder. get_optimizer() recognises encoder parameters by the
        # 'deberta' name prefix (presumably also the pretrained checkpoint's prefix).
        self.deberta = DebertaV2Model(config)

        # Dropout: config.classifier_dropout when set, else the encoder's hidden dropout.
        classifier_dropout = getattr(config, 'classifier_dropout', None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)

        hidden = config.hidden_size

        # ========== STAGE 1 HEADS ==========
        # Tool classification head (multi-label, sigmoid applied in loss / eval).
        self.tool_classifier = self._mlp_head(hidden, hidden, num_tools, classifier_dropout)
        # NER head (token classification, no CRF for stability).
        self.ner_classifier = self._mlp_head(hidden, hidden // 2, num_ner_tags, classifier_dropout)

        # ========== STAGE 2 HEADS (Per-Tool Intent) ==========
        # Each head is a softmax over that tool's own intent vocabulary.
        self.character_intent_head = self._mlp_head(hidden, hidden // 2, num_character_intents, classifier_dropout)
        self.session_intent_head = self._mlp_head(hidden, hidden // 2, num_session_intents, classifier_dropout)
        self.rulebook_intent_head = self._mlp_head(hidden, hidden // 2, num_rulebook_intents, classifier_dropout)

        # Initialize weights
        self.post_init()

    @staticmethod
    def _mlp_head(in_features, hidden_features, out_features, dropout):
        """Linear -> GELU -> Dropout -> Linear (layout shared by every head)."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, out_features),
        )

    @staticmethod
    def _masked_intent_loss(logits, labels):
        """Cross-entropy over rows whose label is not -100; None if there are none."""
        if logits is None or labels is None:
            return None
        valid_mask = labels != -100
        if not valid_mask.any():
            return None
        return F.cross_entropy(logits[valid_mask], labels[valid_mask])

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        tool_labels=None,
        intent_label_character=None,
        intent_label_session=None,
        intent_label_rulebook=None,
        ner_labels=None,
        stage='all',  # 'stage1', 'stage2', or 'all'
        return_dict=True
    ):
        """
        Run the shared encoder and the heads required by ``stage``.

        Args:
            input_ids, attention_mask: encoder inputs.
            tool_labels: multi-hot float targets, shape (batch, num_tools).
            intent_label_character / _session / _rulebook: per-tool class
                indices, -100 where the sample has no intent for that tool.
            ner_labels: per-token tag indices, -100 on ignored positions.
            stage: 'stage1' computes tool + NER only; 'stage2' and 'all' also
                run the three intent heads.
            return_dict: unused; kept for call-site compatibility (a dict is
                always returned).

        Returns:
            dict with the combined ``loss`` (weighted sum of every available
            component loss; None when no labels were given or ``stage`` is
            unrecognised), each component loss, and all logits (intent logits
            are None for 'stage1').
        """
        encoder_outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        sequence_output = encoder_outputs.last_hidden_state  # (batch, seq_len, hidden)
        # Sentence vector = first token; dropout is applied to it only, the
        # NER head consumes the raw token states.
        cls_output = self.dropout(sequence_output[:, 0, :])  # (batch, hidden)

        # ========== STAGE 1: Tools + NER (always computed) ==========
        tool_logits = self.tool_classifier(cls_output)  # (batch, num_tools)
        ner_logits = self.ner_classifier(sequence_output)  # (batch, seq_len, num_tags)

        # ========== STAGE 2: Per-Tool Intents ==========
        # All three heads run; loss masking handles tools not selected per sample.
        use_intents = stage in ('stage2', 'all')
        if use_intents:
            character_intent_logits = self.character_intent_head(cls_output)
            session_intent_logits = self.session_intent_head(cls_output)
            rulebook_intent_logits = self.rulebook_intent_head(cls_output)
        else:
            character_intent_logits = session_intent_logits = rulebook_intent_logits = None

        # ========== COMPONENT LOSSES ==========
        tool_loss = None
        if tool_labels is not None:
            tool_loss = F.binary_cross_entropy_with_logits(tool_logits, tool_labels)

        ner_loss = None
        if ner_labels is not None:
            ner_loss = F.cross_entropy(
                ner_logits.view(-1, self.num_ner_tags),
                ner_labels.view(-1),
                ignore_index=-100
            )

        character_intent_loss = session_intent_loss = rulebook_intent_loss = None
        if use_intents:
            character_intent_loss = self._masked_intent_loss(character_intent_logits, intent_label_character)
            session_intent_loss = self._masked_intent_loss(session_intent_logits, intent_label_session)
            rulebook_intent_loss = self._masked_intent_loss(rulebook_intent_logits, intent_label_rulebook)

        # ========== COMBINED LOSS ==========
        # Every known stage sums whichever component losses are available, so
        # stage 1 no longer returns None when only one of tool/NER labels is given.
        loss = None
        if stage in ('stage1', 'stage2', 'all'):
            weighted = []
            if tool_loss is not None:
                weighted.append(CONFIG['tool_loss_weight'] * tool_loss)
            if ner_loss is not None:
                weighted.append(CONFIG['ner_loss_weight'] * ner_loss)
            # Intent losses are None for heads with no labelled sample in the batch.
            intent_losses = [
                l for l in (character_intent_loss, session_intent_loss, rulebook_intent_loss)
                if l is not None
            ]
            if intent_losses:
                avg_intent_loss = sum(intent_losses) / len(intent_losses)
                weighted.append(CONFIG['intent_loss_weight'] * avg_intent_loss)
            if weighted:
                loss = sum(weighted)

        return {
            'loss': loss,
            'tool_loss': tool_loss,
            'ner_loss': ner_loss,
            'character_intent_loss': character_intent_loss,
            'session_intent_loss': session_intent_loss,
            'rulebook_intent_loss': rulebook_intent_loss,
            'tool_logits': tool_logits,
            'ner_logits': ner_logits,
            'character_intent_logits': character_intent_logits,
            'session_intent_logits': session_intent_logits,
            'rulebook_intent_logits': rulebook_intent_logits,
        }

    def decode_ner(self, ner_logits, attention_mask):
        """Argmax tag index per token as nested lists (attention_mask is unused)."""
        return ner_logits.argmax(dim=-1).tolist()

    def _set_intent_heads_trainable(self, trainable):
        """Set requires_grad on every parameter of the three intent heads."""
        for head in (self.character_intent_head, self.session_intent_head, self.rulebook_intent_head):
            for param in head.parameters():
                param.requires_grad = trainable

    def freeze_intent_heads(self):
        """Freeze intent heads for Stage 1 training."""
        self._set_intent_heads_trainable(False)

    def unfreeze_intent_heads(self):
        """Unfreeze intent heads for Stage 2 training."""
        self._set_intent_heads_trainable(True)

## 5. Initialize Model and Data

In [None]:
# Tokenizer and encoder config come from the same pretrained checkpoint;
# the notebook's classifier dropout overrides the checkpoint default.
print(f"Loading model: {CONFIG['model_name']}")
tokenizer = DebertaV2TokenizerFast.from_pretrained(CONFIG['model_name'])
config = AutoConfig.from_pretrained(CONFIG['model_name'])
config.classifier_dropout = CONFIG['classifier_dropout']

# Head sizes come from the label mappings so they always match the data.
head_sizes = dict(
    num_tools=len(TOOLS),
    num_ner_tags=len(TAG_TO_IDX),
    num_character_intents=NUM_INTENTS_PER_TOOL['character_data'],
    num_session_intents=NUM_INTENTS_PER_TOOL['session_notes'],
    num_rulebook_intents=NUM_INTENTS_PER_TOOL['rulebook'],
)
model = TwoStageJointModel.from_pretrained(
    CONFIG['model_name'],
    config=config,
    ignore_mismatched_sizes=True,
    **head_sizes,
).to(device)

# Parameter counts (total vs. currently trainable).
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, trainable in param_stats if trainable)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Per-tool intent head sizes.
print("\nIntent head sizes:")
for head_tool in ('character_data', 'session_notes', 'rulebook'):
    print(f"  {head_tool}: {NUM_INTENTS_PER_TOOL[head_tool]} classes")

In [None]:
# One dataset per split, sharing the tokenizer and max sequence length.
train_dataset, val_dataset, test_dataset = (
    TwoStageDataset(split_samples, tokenizer, CONFIG['max_length'])
    for split_samples in (train_data, val_data, test_data)
)


def _make_loader(dataset, shuffle=False):
    """Wrap a dataset in a DataLoader using the notebook's shared batching settings."""
    return DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True,
    )


# Only the training split is shuffled.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset)
test_loader = _make_loader(test_dataset)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

# Sanity-check the tensor shapes of one collated training batch.
sample_batch = next(iter(train_loader))
print("\nSample batch shapes:")
for batch_key in ('input_ids', 'tool_labels', 'intent_label_character',
                  'intent_label_session', 'intent_label_rulebook', 'ner_labels'):
    print(f"  {batch_key}: {sample_batch[batch_key].shape}")

## 6. Training Setup

In [None]:
def get_optimizer(model, stage='stage1'):
    """
    Build an AdamW optimizer with per-component parameter groups.

    Parameters are bucketed by name:
      - encoder (name contains 'deberta')                -> CONFIG['learning_rate']
      - stage-1 heads ('tool_classifier' / 'ner_classifier' / 'crf')
                                                          -> CONFIG['head_learning_rate']
      - everything else (the intent heads)                -> CONFIG['head_learning_rate'],
        added only when ``stage != 'stage1'``
    Each bucket is split into a weight-decayed group and a no-decay group
    (biases and LayerNorm weights).

    Returns:
        torch.optim.AdamW with 4 param groups for 'stage1', otherwise 6.
    """
    no_decay = ('bias', 'LayerNorm.weight', 'layernorm.weight')

    def _skips_decay(param_name):
        return any(marker in param_name for marker in no_decay)

    def _decay_split(named_params, lr):
        """(decayed, not-decayed) param groups for one bucket at learning rate ``lr``."""
        return [
            {
                'params': [p for n, p in named_params if not _skips_decay(n)],
                'lr': lr,
                'weight_decay': CONFIG['weight_decay'],
            },
            {
                'params': [p for n, p in named_params if _skips_decay(n)],
                'lr': lr,
                'weight_decay': 0.0,
            },
        ]

    buckets = {'encoder': [], 'stage1': [], 'stage2': []}
    for name, param in model.named_parameters():
        if 'deberta' in name:
            bucket = 'encoder'
        elif any(head in name for head in ('tool_classifier', 'ner_classifier', 'crf')):
            bucket = 'stage1'
        else:
            bucket = 'stage2'
        buckets[bucket].append((name, param))

    param_groups = _decay_split(buckets['encoder'], CONFIG['learning_rate'])
    param_groups += _decay_split(buckets['stage1'], CONFIG['head_learning_rate'])
    # Intent heads are left out of the optimizer entirely during stage 1.
    if stage != 'stage1':
        param_groups += _decay_split(buckets['stage2'], CONFIG['head_learning_rate'])

    return torch.optim.AdamW(param_groups)


def get_scheduler(optimizer, num_training_steps, warmup_ratio=0.1):
    """
    Linear warmup followed by linear decay.

    The first ``int(num_training_steps * warmup_ratio)`` steps are warmup;
    the schedule spans ``num_training_steps`` steps in total.
    """
    n_warmup = int(num_training_steps * warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=n_warmup,
        num_training_steps=num_training_steps,
    )
    return scheduler

## 7. Evaluation Functions

In [None]:
def evaluate(model, dataloader, device, stage='all'):
    """
    Evaluate the model on a validation/test set.

    Args:
        model: model with the TwoStageJointModel forward / decode_ner API.
        dataloader: yields batches in the ``collate_fn`` format.
        device: device the batch tensors are moved to.
        stage: forwarded to the model; intent metrics are only collected for
            'stage2' / 'all'.

    Returns:
        (metrics, (ner_label_seqs, ner_pred_seqs)). ``metrics`` holds losses
        averaged per batch, tool micro P/R/F1 and exact match (threshold 0.5),
        per-tool intent accuracy, its mean, entity-level NER F1 (seqeval), and
        per-tool intent losses averaged over batches that had such labels.
        Intent accuracy is scored only on samples whose gold label for that
        tool is not -100, i.e. gated by the gold tool set, not by predictions.
    """
    model.eval()

    intent_heads = ('character', 'session', 'rulebook')
    tensor_keys = (
        'input_ids', 'attention_mask', 'tool_labels',
        'intent_label_character', 'intent_label_session', 'intent_label_rulebook',
        'ner_labels',
    )

    # Collectors
    all_tool_preds, all_tool_labels = [], []
    intent_preds = {head: [] for head in intent_heads}
    intent_golds = {head: [] for head in intent_heads}
    all_ner_preds, all_ner_labels = [], []

    # Loss tracking
    total_loss = 0
    total_tool_loss = 0
    total_ner_loss = 0
    total_intent_losses = {head: 0 for head in intent_heads}
    intent_loss_counts = {head: 0 for head in intent_heads}

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            inputs = {key: batch[key].to(device) for key in tensor_keys}
            outputs = model(**inputs, stage=stage)

            if outputs['loss'] is not None:
                total_loss += outputs['loss'].item()
            if outputs['tool_loss'] is not None:
                total_tool_loss += outputs['tool_loss'].item()
            if outputs['ner_loss'] is not None:
                total_ner_loss += outputs['ner_loss'].item()
            for head in intent_heads:
                head_loss = outputs[f'{head}_intent_loss']
                if head_loss is not None:
                    total_intent_losses[head] += head_loss.item()
                    intent_loss_counts[head] += 1

            # Tool predictions: independent sigmoid per tool, threshold 0.5.
            tool_preds = (torch.sigmoid(outputs['tool_logits']) > 0.5).float()
            all_tool_preds.append(tool_preds.cpu())
            all_tool_labels.append(inputs['tool_labels'].cpu())

            # Per-tool intent predictions (argmax) where a gold label exists.
            if stage in ('stage2', 'all') and outputs['character_intent_logits'] is not None:
                for head in intent_heads:
                    gold = inputs[f'intent_label_{head}']
                    valid = gold != -100
                    if valid.any():
                        preds = outputs[f'{head}_intent_logits'][valid].argmax(dim=-1)
                        intent_preds[head].extend(preds.cpu().tolist())
                        intent_golds[head].extend(gold[valid].cpu().tolist())

            # NER: tag strings compared only at positions with a real label.
            ner_preds = model.decode_ner(outputs['ner_logits'], inputs['attention_mask'])
            for pred_seq, label_seq in zip(ner_preds, inputs['ner_labels'].cpu().tolist()):
                pairs = [
                    (IDX_TO_TAG[pred_tag], IDX_TO_TAG[label_tag])
                    for pred_tag, label_tag in zip(pred_seq, label_seq)
                    if label_tag != -100
                ]
                if pairs:
                    all_ner_preds.append([pred for pred, _ in pairs])
                    all_ner_labels.append([gold for _, gold in pairs])

    num_batches = len(dataloader)

    def _per_batch(total):
        return total / num_batches if num_batches > 0 else 0

    # Tool metrics
    tool_pred_arr = torch.cat(all_tool_preds, dim=0).numpy()
    tool_label_arr = torch.cat(all_tool_labels, dim=0).numpy()
    tool_f1 = f1_score(tool_label_arr, tool_pred_arr, average='micro', zero_division=0)
    tool_precision = precision_score(tool_label_arr, tool_pred_arr, average='micro', zero_division=0)
    tool_recall = recall_score(tool_label_arr, tool_pred_arr, average='micro', zero_division=0)
    tool_exact_match = np.mean(np.all(tool_pred_arr == tool_label_arr, axis=1))

    # Per-tool intent accuracy (0.0 for heads that saw no labelled sample).
    intent_acc = {
        head: accuracy_score(intent_golds[head], intent_preds[head]) if intent_golds[head] else 0.0
        for head in intent_heads
    }
    # Average over heads that were actually scored. Filtering on "acc > 0"
    # would drop a head that had samples but got them all wrong.
    scored_accs = [intent_acc[head] for head in intent_heads if intent_golds[head]]
    avg_intent_acc = np.mean(scored_accs) if scored_accs else 0.0

    # NER F1 (entity-level, seqeval)
    ner_f1 = seqeval_f1(all_ner_labels, all_ner_preds, average='micro', zero_division=0) if all_ner_labels else 0.0

    metrics = {
        'loss': _per_batch(total_loss),
        'tool_loss': _per_batch(total_tool_loss),
        'ner_loss': _per_batch(total_ner_loss),
        'tool_f1': tool_f1,
        'tool_precision': tool_precision,
        'tool_recall': tool_recall,
        'tool_exact_match': tool_exact_match,
        'character_intent_acc': intent_acc['character'],
        'session_intent_acc': intent_acc['session'],
        'rulebook_intent_acc': intent_acc['rulebook'],
        'avg_intent_acc': avg_intent_acc,
        'ner_f1': ner_f1,
    }

    # Per-tool intent losses, averaged over batches that produced one.
    for head in intent_heads:
        count = intent_loss_counts[head]
        metrics[f'{head}_intent_loss'] = total_intent_losses[head] / count if count > 0 else 0.0

    return metrics, (all_ner_labels, all_ner_preds)

## 8. Training Loop (Staged)

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, device, stage='all'):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Joint model called with the batch tensors plus ``stage``; its output
            dict must contain ``loss``, ``tool_loss``, ``ner_loss`` and
            ``character_intent_loss`` / ``session_intent_loss`` /
            ``rulebook_intent_loss`` (component losses may be ``None``).
        dataloader: Iterable of batch dicts holding ``input_ids``, ``attention_mask``,
            ``tool_labels``, ``intent_label_character``, ``intent_label_session``,
            ``intent_label_rulebook`` and ``ner_labels`` tensors.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch, after the optimizer.
        device: Device every batch tensor is moved to.
        stage: Stage name forwarded to the model (e.g. ``'stage1'`` or ``'all'``).

    Returns:
        dict with the mean ``loss``, ``tool_loss`` and ``ner_loss`` over all batches,
        and ``intent_loss`` averaged over only the batches that produced an intent loss.
        All values are 0 when the dataloader yields no batches.
    """
    model.train()

    total_loss = 0.0
    total_tool_loss = 0.0
    total_ner_loss = 0.0
    total_intent_loss = 0.0
    intent_batch_count = 0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Training ({stage})")

    for batch in progress_bar:
        optimizer.zero_grad()

        outputs = model(
            input_ids=batch['input_ids'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            tool_labels=batch['tool_labels'].to(device),
            intent_label_character=batch['intent_label_character'].to(device),
            intent_label_session=batch['intent_label_session'].to(device),
            intent_label_rulebook=batch['intent_label_rulebook'].to(device),
            ner_labels=batch['ner_labels'].to(device),
            stage=stage
        )

        loss = outputs['loss']
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        # Read each scalar once: every .item() forces a host-device sync.
        loss_value = loss.item()
        tool_loss = outputs['tool_loss']
        ner_loss = outputs['ner_loss']
        tool_value = tool_loss.item() if tool_loss is not None else None
        ner_value = ner_loss.item() if ner_loss is not None else None

        num_batches += 1
        total_loss += loss_value
        if tool_value is not None:
            total_tool_loss += tool_value
        if ner_value is not None:
            total_ner_loss += ner_value

        # Track intent losses: mean of whichever per-tool intent losses exist this batch
        intent_values = [
            outputs[key].item()
            for key in ('character_intent_loss', 'session_intent_loss', 'rulebook_intent_loss')
            if outputs[key] is not None
        ]
        if intent_values:
            total_intent_loss += sum(intent_values) / len(intent_values)
            intent_batch_count += 1

        # Explicit None checks: a loss tensor equal to 0 is falsy and would
        # otherwise be displayed as "N/A".
        progress_bar.set_postfix({
            'loss': f"{loss_value:.4f}",
            'tool': f"{tool_value:.4f}" if tool_value is not None else "N/A",
            'ner': f"{ner_value:.4f}" if ner_value is not None else "N/A",
        })

    # Avoid ZeroDivisionError on an empty dataloader.
    denominator = max(num_batches, 1)
    return {
        'loss': total_loss / denominator,
        'tool_loss': total_tool_loss / denominator,
        'ner_loss': total_ner_loss / denominator,
        'intent_loss': total_intent_loss / intent_batch_count if intent_batch_count > 0 else 0
    }

In [None]:
# ========== STAGED TRAINING ==========
# Stage 1: Tool + NER only (freeze intent heads)
# Stage 2: Joint training (unfreeze intent heads)
# Stage 3: Fine-tuning (lower learning rate)
#
# State shared with the Stage 2 / Stage 3 cells below:
#   training_history    - one dict per epoch, across all stages
#   best_combined_score - best validation combined score so far (updated in Stages 2/3)
#   patience_counter    - epochs since the last improvement (early stopping)
training_history = []
best_combined_score = 0
patience_counter = 0

banner = "=" * 70
print(banner)
print("STAGE 1: Training Tool + NER heads (intent heads frozen)")
print(banner)

# Only the encoder plus the tool and NER heads learn during Stage 1.
model.freeze_intent_heads()

n_stage1_epochs = CONFIG['stage1_epochs']
stage1_steps = n_stage1_epochs * len(train_loader)
# The 'stage1' optimizer excludes the intent heads.
optimizer = get_optimizer(model, stage='stage1')
scheduler = get_scheduler(optimizer, stage1_steps, CONFIG['warmup_ratio'])

for epoch in range(n_stage1_epochs):
    print(f"\nEpoch {epoch + 1}/{n_stage1_epochs} (Stage 1)")
    print("-" * 40)

    train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, stage='stage1')
    val_metrics, _ = evaluate(model, val_loader, device, stage='stage1')

    # Intent heads are frozen, so Stage 1 is judged on tool F1 and NER F1 only.
    stage1_tool_f1, stage1_ner_f1 = val_metrics['tool_f1'], val_metrics['ner_f1']
    stage1_score = (stage1_tool_f1 + stage1_ner_f1) / 2

    history_entry = dict(
        epoch=epoch + 1,
        stage='stage1',
        train_loss=train_metrics['loss'],
        val_loss=val_metrics['loss'],
        val_tool_f1=stage1_tool_f1,
        val_tool_exact_match=val_metrics['tool_exact_match'],
        val_ner_f1=stage1_ner_f1,
        val_avg_intent_acc=0.0,  # intent accuracy is not tracked in Stage 1
        stage_score=stage1_score,
    )
    training_history.append(history_entry)

    print("\n".join([
        f"Train Loss: {train_metrics['loss']:.4f}",
        f"Val Loss: {val_metrics['loss']:.4f}",
        f"Val Tool F1: {stage1_tool_f1:.4f} (Exact Match: {val_metrics['tool_exact_match']:.4f})",
        f"Val NER F1: {stage1_ner_f1:.4f}",
        f"Stage 1 Score: {stage1_score:.4f}",
    ]))

In [None]:
# ========== STAGE 2: Joint Training ==========
banner = "=" * 70
print("\n" + banner)
print("STAGE 2: Joint training (all heads)")
print(banner)

# From here on every head, including the per-tool intent heads, is trained.
model.unfreeze_intent_heads()

n_stage2_epochs = CONFIG['stage2_epochs']
all_stage_epochs = CONFIG['stage1_epochs'] + CONFIG['stage2_epochs'] + CONFIG['finetune_epochs']
stage2_steps = n_stage2_epochs * len(train_loader)
# Rebuild optimizer/scheduler so the newly unfrozen intent heads are included.
optimizer = get_optimizer(model, stage='stage2')
scheduler = get_scheduler(optimizer, stage2_steps, CONFIG['warmup_ratio'])

# Validation metrics recorded in each history entry as 'val_<name>', in this order.
val_history_keys = ('loss', 'tool_f1', 'tool_exact_match', 'ner_f1', 'avg_intent_acc',
                    'character_intent_acc', 'session_intent_acc', 'rulebook_intent_acc')
# Model selection score: equal-weight mean of these three metrics (summed in this order).
score_keys = ('tool_f1', 'ner_f1', 'avg_intent_acc')

for epoch in range(n_stage2_epochs):
    global_epoch = CONFIG['stage1_epochs'] + epoch + 1
    print(f"\nEpoch {global_epoch}/{all_stage_epochs} (Stage 2)")
    print("-" * 40)

    train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, stage='all')
    val_metrics, _ = evaluate(model, val_loader, device, stage='all')

    combined_score = sum(val_metrics[k] for k in score_keys) / 3

    history_entry = {'epoch': global_epoch, 'stage': 'stage2', 'train_loss': train_metrics['loss']}
    history_entry.update((f'val_{k}', val_metrics[k]) for k in val_history_keys)
    history_entry['stage_score'] = combined_score
    training_history.append(history_entry)

    print(f"Train Loss: {train_metrics['loss']:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f}")
    print(f"Val Tool F1: {val_metrics['tool_f1']:.4f} (Exact Match: {val_metrics['tool_exact_match']:.4f})")
    print(f"Val NER F1: {val_metrics['ner_f1']:.4f}")
    print(f"Val Intent Accuracy: {val_metrics['avg_intent_acc']:.4f}")
    for tool_name, acc_key in (("Character", 'character_intent_acc'),
                               ("Session", 'session_intent_acc'),
                               ("Rulebook", 'rulebook_intent_acc')):
        print(f"  - {tool_name}: {val_metrics[acc_key]:.4f}")
    print(f"Combined Score: {combined_score:.4f}")

    if not combined_score > best_combined_score:
        patience_counter += 1
        print(f"\n‚úó No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
    else:
        best_combined_score = combined_score
        patience_counter = 0

        # Checkpoint = weights + tokenizer + the JSON files inference.py reads.
        model.save_pretrained(MODEL_PATH)
        tokenizer.save_pretrained(MODEL_PATH)
        for file_name, payload in (('training_config.json', CONFIG),
                                   ('label_mappings.json', LABEL_MAPPINGS)):
            with open(MODEL_PATH / file_name, 'w') as f:
                json.dump(payload, f, indent=2)

        print(f"\n‚úì New best model saved! Combined Score: {combined_score:.4f}")

    if patience_counter >= CONFIG['patience']:
        print("\nEarly stopping triggered!")
        break

In [None]:
# ========== STAGE 3: Fine-tuning ==========
# Runs only if Stage 2 did not early-stop. patience_counter is carried over from
# Stage 2, so Stages 2 and 3 share a single early-stopping budget.
if patience_counter < CONFIG['patience']:
    print("\n" + "=" * 70)
    print("STAGE 3: Fine-tuning (lower learning rate)")
    print("=" * 70)

    # Lower learning rate for fine-tuning: 10x below the configured rates
    original_lr = CONFIG['learning_rate']
    original_head_lr = CONFIG['head_learning_rate']
    finetune_lr = original_lr / 10
    finetune_head_lr = original_head_lr / 10

    # The reduced rates reach get_optimizer by temporarily patching CONFIG
    # (it is assumed to read its learning rates from CONFIG). Restore in
    # `finally` so a failure while building the optimizer/scheduler cannot
    # leave CONFIG holding the reduced rates. Otherwise they would leak into
    # training_config.json and into any re-run of the training cells.
    CONFIG['learning_rate'] = finetune_lr
    CONFIG['head_learning_rate'] = finetune_head_lr
    try:
        stage3_steps = len(train_loader) * CONFIG['finetune_epochs']
        optimizer = get_optimizer(model, stage='stage2')
        scheduler = get_scheduler(optimizer, stage3_steps, CONFIG['warmup_ratio'])
    finally:
        CONFIG['learning_rate'] = original_lr
        CONFIG['head_learning_rate'] = original_head_lr

    for epoch in range(CONFIG['finetune_epochs']):
        global_epoch = CONFIG['stage1_epochs'] + CONFIG['stage2_epochs'] + epoch + 1
        print(f"\nEpoch {global_epoch}/{CONFIG['stage1_epochs'] + CONFIG['stage2_epochs'] + CONFIG['finetune_epochs']} (Fine-tune)")
        print("-" * 40)

        train_metrics = train_epoch(model, train_loader, optimizer, scheduler, device, stage='all')
        val_metrics, _ = evaluate(model, val_loader, device, stage='all')

        # Same equal-weight selection score as Stage 2
        combined_score = (
            val_metrics['tool_f1'] +
            val_metrics['ner_f1'] +
            val_metrics['avg_intent_acc']
        ) / 3

        history_entry = {
            'epoch': global_epoch,
            'stage': 'finetune',
            'train_loss': train_metrics['loss'],
            'val_loss': val_metrics['loss'],
            'val_tool_f1': val_metrics['tool_f1'],
            'val_tool_exact_match': val_metrics['tool_exact_match'],
            'val_ner_f1': val_metrics['ner_f1'],
            'val_avg_intent_acc': val_metrics['avg_intent_acc'],
            'val_character_intent_acc': val_metrics['character_intent_acc'],
            'val_session_intent_acc': val_metrics['session_intent_acc'],
            'val_rulebook_intent_acc': val_metrics['rulebook_intent_acc'],
            'stage_score': combined_score
        }
        training_history.append(history_entry)

        print(f"Train Loss: {train_metrics['loss']:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}")
        print(f"Val Tool F1: {val_metrics['tool_f1']:.4f}")
        print(f"Val NER F1: {val_metrics['ner_f1']:.4f}")
        print(f"Val Intent Accuracy: {val_metrics['avg_intent_acc']:.4f}")
        print(f"Combined Score: {combined_score:.4f}")

        if combined_score > best_combined_score:
            best_combined_score = combined_score
            patience_counter = 0

            model.save_pretrained(MODEL_PATH)
            tokenizer.save_pretrained(MODEL_PATH)

            # Save config and mappings as in Stage 2. A best checkpoint first
            # produced here must still carry the files inference.py loads.
            # CONFIG has been restored to the original rates at this point.
            with open(MODEL_PATH / 'training_config.json', 'w') as f:
                json.dump(CONFIG, f, indent=2)

            with open(MODEL_PATH / 'label_mappings.json', 'w') as f:
                json.dump(LABEL_MAPPINGS, f, indent=2)

            print(f"\n‚úì New best model saved! Combined Score: {combined_score:.4f}")
        else:
            patience_counter += 1
            print(f"\n‚úó No improvement. Patience: {patience_counter}/{CONFIG['patience']}")

        if patience_counter >= CONFIG['patience']:
            print(f"\nEarly stopping triggered!")
            break

print("\n" + "=" * 70)
print(f"Training complete! Best Combined Score: {best_combined_score:.4f}")
print("=" * 70)

## 9. Final Evaluation on Test Set

In [None]:
# Reload the best checkpoint saved during Stages 2/3 and score it on the held-out test split.
print("Loading best model for final evaluation...")
head_sizes = dict(
    num_tools=len(TOOLS),
    num_ner_tags=len(TAG_TO_IDX),
    num_character_intents=NUM_INTENTS_PER_TOOL['character_data'],
    num_session_intents=NUM_INTENTS_PER_TOOL['session_notes'],
    num_rulebook_intents=NUM_INTENTS_PER_TOOL['rulebook'],
)
model = TwoStageJointModel.from_pretrained(MODEL_PATH, **head_sizes)
model = model.to(device)

# Evaluate on test set; NER label/prediction sequences are kept for the seqeval report.
test_metrics, (ner_labels, ner_preds) = evaluate(model, test_loader, device, stage='all')

banner = "=" * 70
print("\n" + banner)
print("TEST SET RESULTS")
print(banner)

print(f"\nüìä Tool Classification:")
for metric_label, metric_key in (("F1 Score", 'tool_f1'),
                                 ("Precision", 'tool_precision'),
                                 ("Recall", 'tool_recall'),
                                 ("Exact Match", 'tool_exact_match')):
    print(f"  {metric_label}: {test_metrics[metric_key]:.4f}")

print(f"\nüéØ Intent Classification (Per-Tool Accuracy):")
for metric_label, metric_key in (("Character Data", 'character_intent_acc'),
                                 ("Session Notes", 'session_intent_acc'),
                                 ("Rulebook", 'rulebook_intent_acc'),
                                 ("Average", 'avg_intent_acc')):
    print(f"  {metric_label}: {test_metrics[metric_key]:.4f}")

print(f"\nüè∑Ô∏è NER (Entity Extraction):")
print(f"  F1 Score: {test_metrics['ner_f1']:.4f}")

# Same equal-weight score used for model selection during training
combined_test_score = sum(test_metrics[k] for k in ('tool_f1', 'ner_f1', 'avg_intent_acc')) / 3
print(f"\n‚≠ê Combined Test Score: {combined_test_score:.4f}")

# Entity-level per-type breakdown from seqeval
rule = "-" * 40
print("\n" + rule)
print("Detailed NER Classification Report:")
print(rule)
print(classification_report(ner_labels, ner_preds, zero_division=0))

In [None]:
# Save test metrics, the per-epoch training history and the run config next to the model.


def _json_default(obj):
    """Fallback for json.dump: convert numpy/torch scalars, arrays and paths.

    numpy.float64 already subclasses float, but e.g. numpy.float32/int64 and
    0-d torch tensors do not. Without this hook json.dump raises TypeError on them.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


test_results = {
    'model_name': CONFIG['model_name'],
    'test_metrics': test_metrics,
    'training_history': training_history,
    'best_combined_score': best_combined_score,
    'config': CONFIG
}

with open(MODEL_PATH / 'test_results.json', 'w', encoding='utf-8') as f:
    json.dump(test_results, f, indent=2, default=_json_default)

print(f"\n‚úì Test results saved to {MODEL_PATH}/test_results.json")

## 10. Training History Visualization

In [None]:
import matplotlib.pyplot as plt

# Display names for the 'stage' tags stored in training_history.
STAGE_DISPLAY_NAMES = {'stage1': 'Stage 1', 'stage2': 'Stage 2', 'finetune': 'Fine-tune'}


def _stage_spans(history):
    """Map each stage tag to its (first_epoch, last_epoch), in order of first appearance.

    Spans come from the recorded history rather than CONFIG. A stage cut short
    by early stopping, or skipped entirely, is therefore drawn with its real extent.
    """
    spans = {}
    for entry in history:
        epoch_num = entry['epoch']
        first, last = spans.get(entry['stage'], (epoch_num, epoch_num))
        spans[entry['stage']] = (min(first, epoch_num), max(last, epoch_num))
    return spans


def _mark_stages(ax, spans, skip=()):
    """Draw a dashed line where each stage starts and label each stage near the top of ``ax``.

    Stages listed in ``skip`` get neither a label nor a leading boundary line.
    """
    visible = [(stage, span) for stage, span in spans.items() if stage not in skip]
    for position, (stage, (first, last)) in enumerate(visible):
        if position > 0:
            ax.axvline(x=first - 0.5, color='gray', linestyle='--', alpha=0.5)
        # Blended transform: x in data units, y as a fraction of the axes height,
        # so labels sit at 95% height regardless of each panel's y-range.
        ax.text((first + last) / 2, 0.95, STAGE_DISPLAY_NAMES.get(stage, stage),
                transform=ax.get_xaxis_transform(), ha='center', fontsize=8, alpha=0.7)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

epochs = [h['epoch'] for h in training_history]
stage_spans = _stage_spans(training_history)
# Intent accuracy is only computed once the intent heads train (Stage 2 onward).
intent_history = [h for h in training_history if h['stage'] != 'stage1']
intent_epochs = [h['epoch'] for h in intent_history]

# Loss
loss_ax = axes[0, 0]
loss_ax.plot(epochs, [h['train_loss'] for h in training_history], 'b-', label='Train', alpha=0.7)
loss_ax.plot(epochs, [h['val_loss'] for h in training_history], 'r-', label='Validation', alpha=0.7)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training & Validation Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Tool F1
tool_ax = axes[0, 1]
tool_ax.plot(epochs, [h['val_tool_f1'] for h in training_history], 'g-', marker='o', markersize=4)
tool_ax.set_xlabel('Epoch')
tool_ax.set_ylabel('F1 Score')
tool_ax.set_title('Tool Classification F1')
tool_ax.grid(True, alpha=0.3)
tool_ax.set_ylim([0, 1])

# Intent Accuracy (only available from Stage 2)
intent_ax = axes[1, 0]
if intent_history:
    intent_ax.plot(intent_epochs, [h['val_avg_intent_acc'] for h in intent_history],
                   'm-', marker='o', markersize=4, label='Average')
    # Also plot per-tool accuracies
    for metric_key, line_style, tool_label in (
        ('val_character_intent_acc', 'b--', 'Character'),
        ('val_session_intent_acc', 'g--', 'Session'),
        ('val_rulebook_intent_acc', 'r--', 'Rulebook'),
    ):
        intent_ax.plot(intent_epochs, [h.get(metric_key, 0) for h in intent_history],
                       line_style, alpha=0.5, label=tool_label)
    intent_ax.legend(loc='lower right')
intent_ax.set_xlabel('Epoch')
intent_ax.set_ylabel('Accuracy')
intent_ax.set_title('Intent Classification Accuracy')
intent_ax.grid(True, alpha=0.3)
intent_ax.set_ylim([0, 1])

# NER F1
ner_ax = axes[1, 1]
ner_ax.plot(epochs, [h['val_ner_f1'] for h in training_history], 'c-', marker='o', markersize=4)
ner_ax.set_xlabel('Epoch')
ner_ax.set_ylabel('F1 Score')
ner_ax.set_title('NER F1 Score')
ner_ax.grid(True, alpha=0.3)
ner_ax.set_ylim([0, 1])

# Stage boundaries and labels (the intent panel has no Stage 1 data)
for panel in (loss_ax, tool_ax, ner_ax):
    _mark_stages(panel, stage_spans)
_mark_stages(intent_ax, stage_spans, skip=('stage1',))

fig.tight_layout()
fig.savefig(MODEL_PATH / 'training_history.png', dpi=150)
plt.show()

print(f"‚úì Training history plot saved to {MODEL_PATH}/training_history.png")

## 11. Inference Example

In [None]:
# SentencePiece word-boundary marker (U+2581) that starts word-initial DeBERTa-v3
# tokens. It is written as an escape so a mis-decoded notebook cannot corrupt the literal.
_SP_WORD_BOUNDARY = '\u2581'
# Special tokens that never belong to an entity.
_SKIP_TOKENS = frozenset(['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>'])
# Intent-logit output used for each tool (dict order = order of keys in the result).
_INTENT_LOGIT_KEYS = {
    'character_data': 'character_intent_logits',
    'session_notes': 'session_intent_logits',
    'rulebook': 'rulebook_intent_logits',
}


def _decode_bio_entities(tokens, tag_ids, idx_to_tag):
    """Group BIO-tagged subword tokens into entity spans.

    Args:
        tokens: Subword token strings, e.g. from ``tokenizer.convert_ids_to_tokens``.
        tag_ids: Predicted tag indices aligned with ``tokens``. Pairing stops at the
            shorter of the two sequences.
        idx_to_tag: Mapping from tag index to a BIO tag string (``'B-X'``, ``'I-X'``, ``'O'``).

    Returns:
        List of ``{'text': str, 'type': str}`` dicts in order of appearance. An ``I-``
        tag that does not continue an open entity of the same type closes it.
    """
    entities = []
    current_entity = None
    current_type = None

    for token, tag_idx in zip(tokens, tag_ids):
        if token in _SKIP_TOKENS:
            continue

        tag = idx_to_tag[tag_idx]
        # Word-initial pieces become a leading space; WordPiece '##' continuations are glued.
        piece = token.replace(_SP_WORD_BOUNDARY, ' ').replace('##', '')

        if tag.startswith('B-'):
            if current_entity:
                entities.append({'text': current_entity.strip(), 'type': current_type})
            current_entity = piece
            current_type = tag[2:]
        elif tag.startswith('I-') and current_type == tag[2:]:
            current_entity += piece
        else:
            if current_entity:
                entities.append({'text': current_entity.strip(), 'type': current_type})
            current_entity = None
            current_type = None

    # Don't forget last entity
    if current_entity:
        entities.append({'text': current_entity.strip(), 'type': current_type})

    return entities


def predict(model, tokenizer, text, device, threshold=0.5):
    """
    Run inference on a single query.

    Args:
        model: Trained joint model. Called with ``stage='all'``, it must return
            ``tool_logits``, the three ``*_intent_logits`` and ``ner_emissions``,
            and it must provide ``decode_ner(emissions, attention_mask)``.
        tokenizer: Tokenizer used in training. Input is padded/truncated to
            ``CONFIG['max_length']``.
        text: Query string.
        device: Device for the forward pass.
        threshold: Sigmoid probability above which a tool counts as selected
            (default 0.5, the previously hard-coded value).

    Returns:
        dict with the query, selected tools, per-tool probabilities, intents
        (only for selected tools) and extracted entities
    """
    model.eval()

    encoding = tokenizer(
        text,
        max_length=CONFIG['max_length'],
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, stage='all')

    # Multi-label tool selection: independent sigmoid per tool
    tool_probs = torch.sigmoid(outputs['tool_logits']).cpu().numpy()[0]
    predicted_tools = [TOOLS[i] for i, p in enumerate(tool_probs) if p > threshold]

    # Per-tool intent predictions (only for selected tools). Intent maps are keyed by str index.
    predicted_intents = {}
    for tool, logits_key in _INTENT_LOGIT_KEYS.items():
        if tool in predicted_tools:
            intent_idx = outputs[logits_key].argmax(dim=-1).item()
            predicted_intents[tool] = IDX_TO_INTENT_PER_TOOL[tool][str(intent_idx)]

    # NER predictions (CRF decode), then group BIO tags into entity spans
    ner_preds = model.decode_ner(outputs['ner_emissions'], attention_mask)[0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    entities = _decode_bio_entities(tokens, ner_preds, IDX_TO_TAG)

    return {
        'query': text,
        'tools': predicted_tools,
        'tool_probs': {TOOLS[i]: float(p) for i, p in enumerate(tool_probs)},
        'intents': predicted_intents,
        'entities': entities
    }

In [None]:
# Hand-written sample queries: single-tool and multi-tool requests
test_queries = [
    "What's my armor class and how does Shield work?",
    "Did we meet anyone named the Archmage in our last session?",
    "How much damage does Fireball do and can I cast it at 5th level?",
    "What are my spell slots and did we find any magic items last time?",
    "Tell me about the Beholder's legendary actions",
    "What happened when we fought the dragon?",
]

separator = "=" * 70
print(separator)
print("EXAMPLE PREDICTIONS")
print(separator)

for sample_query in test_queries:
    prediction = predict(model, tokenizer, sample_query, device)
    probs_text = ', '.join(f'{k}: {v:.2f}' for k, v in prediction['tool_probs'].items())

    print(f"\nüìù Query: \"{sample_query}\"")
    print(f"   üîß Tools: {prediction['tools']}")
    print(f"   üìä Tool Probs: {probs_text}")
    print(f"   üéØ Intents: {prediction['intents']}")
    print(f"   üè∑Ô∏è Entities: {prediction['entities']}")
    print("-" * 70)

## 12. Export Inference Module

In [None]:
# Create a standalone inference module.
# NOTE(review): the exported TwoStageJointModel must mirror the training model's
# module names and head shapes, or from_pretrained will re-initialise mismatched
# heads. Confirm it against the model class defined earlier in this notebook.
# It decodes NER with a per-token argmax over emission scores (no CRF
# transitions), so its entities can differ from predict() above, which calls the
# training model's CRF decode.
inference_code = '''"""
Two-Stage Joint DeBERTa Model - Inference Module

Usage:
    from inference import TwoStageInference
    
    model = TwoStageInference("path/to/model")
    result = model.predict("What's my AC and how does Fireball work?")
    print(result)

Command line (model files are loaded from this script's directory):
    python inference.py "Your query here"
"""

import json
import os
import torch
import torch.nn as nn
from transformers import DebertaV2TokenizerFast, DebertaV2Model, DebertaV2PreTrainedModel


class TwoStageJointModel(DebertaV2PreTrainedModel):
    """Two-Stage Joint DeBERTa-v3 Model."""
    
    def __init__(self, config, num_tools=3, num_ner_tags=25, 
                 num_character_intents=10, num_session_intents=20, num_rulebook_intents=30):
        super().__init__(config)
        
        self.num_tools = num_tools
        self.num_ner_tags = num_ner_tags
        self.num_character_intents = num_character_intents
        self.num_session_intents = num_session_intents
        self.num_rulebook_intents = num_rulebook_intents
        
        self.deberta = DebertaV2Model(config)
        
        classifier_dropout = (
            config.classifier_dropout 
            if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        
        # Stage 1 heads
        self.tool_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size, num_tools)
        )
        
        self.ner_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size // 2, num_ner_tags)
        )
        
        # Stage 2 heads
        self.character_intent_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size // 2, num_character_intents)
        )
        
        self.session_intent_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size // 2, num_session_intents)
        )
        
        self.rulebook_intent_head = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size // 2, num_rulebook_intents)
        )
        
        self.post_init()
        
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        
        sequence_output = outputs.last_hidden_state
        cls_output = self.dropout(sequence_output[:, 0, :])
        
        return {
            'tool_logits': self.tool_classifier(cls_output),
            'ner_logits': self.ner_classifier(sequence_output),
            'character_intent_logits': self.character_intent_head(cls_output),
            'session_intent_logits': self.session_intent_head(cls_output),
            'rulebook_intent_logits': self.rulebook_intent_head(cls_output),
        }
    
    def decode_ner(self, ner_logits, attention_mask):
        """Decode NER predictions with a per-token argmax (no CRF transitions are applied)."""
        return ner_logits.argmax(dim=-1).tolist()


class TwoStageInference:
    """Inference wrapper for the Two-Stage Joint Model."""
    
    def __init__(self, model_path, device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        
        # Load mappings
        with open(f'{model_path}/label_mappings.json') as f:
            self.mappings = json.load(f)
        
        # Order tool names by their index so position i lines up with tool_logits[:, i]
        # (do not rely on the key order of the JSON object).
        tool_to_idx = self.mappings['tool_to_idx']
        self.tools = sorted(tool_to_idx, key=tool_to_idx.get)
        self.idx_to_tag = {int(k): v for k, v in self.mappings['idx_to_tag'].items()}
        self.idx_to_intent_per_tool = self.mappings['idx_to_intent_per_tool']
        
        # Load config
        with open(f'{model_path}/training_config.json') as f:
            self.config = json.load(f)
        
        # Load model
        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(model_path)
        self.model = TwoStageJointModel.from_pretrained(
            model_path,
            num_tools=len(self.tools),
            num_ner_tags=len(self.mappings['tag_to_idx']),
            num_character_intents=self.mappings['num_intents_per_tool']['character_data'],
            num_session_intents=self.mappings['num_intents_per_tool']['session_notes'],
            num_rulebook_intents=self.mappings['num_intents_per_tool']['rulebook']
        )
        self.model.to(self.device)
        self.model.eval()
    
    def predict(self, text):
        """Run inference on a query."""
        encoding = self.tokenizer(
            text,
            max_length=self.config['max_length'],
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        
        # Tools
        tool_probs = torch.sigmoid(outputs['tool_logits']).cpu().numpy()[0]
        predicted_tools = [self.tools[i] for i, p in enumerate(tool_probs) if p > 0.5]
        
        # Intents (per selected tool)
        intents = {}
        if 'character_data' in predicted_tools:
            idx = outputs['character_intent_logits'].argmax(dim=-1).item()
            intents['character_data'] = self.idx_to_intent_per_tool['character_data'][str(idx)]
        if 'session_notes' in predicted_tools:
            idx = outputs['session_intent_logits'].argmax(dim=-1).item()
            intents['session_notes'] = self.idx_to_intent_per_tool['session_notes'][str(idx)]
        if 'rulebook' in predicted_tools:
            idx = outputs['rulebook_intent_logits'].argmax(dim=-1).item()
            intents['rulebook'] = self.idx_to_intent_per_tool['rulebook'][str(idx)]
        
        # Entities
        ner_preds = self.model.decode_ner(outputs['ner_logits'], attention_mask)[0]
        entities = self._extract_entities(input_ids[0], ner_preds)
        
        return {
            'tools': predicted_tools,
            'intents': intents,
            'entities': entities
        }
    
    def _extract_entities(self, input_ids, ner_preds):
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        entities = []
        current_entity = None
        current_type = None
        
        for token, tag_idx in zip(tokens, ner_preds):
            if token in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
                continue
            
            tag = self.idx_to_tag[tag_idx]
            # U+2581 is the SentencePiece word-boundary marker
            piece = token.replace('\\u2581', ' ').replace('##', '')
            
            if tag.startswith('B-'):
                if current_entity:
                    entities.append({'text': current_entity.strip(), 'type': current_type})
                current_entity = piece
                current_type = tag[2:]
            elif tag.startswith('I-') and current_type == tag[2:]:
                current_entity += piece
            else:
                if current_entity:
                    entities.append({'text': current_entity.strip(), 'type': current_type})
                current_entity = None
                current_type = None
        
        if current_entity:
            entities.append({'text': current_entity.strip(), 'type': current_type})
        
        return entities


if __name__ == "__main__":
    import sys
    
    if len(sys.argv) < 2:
        print("Usage: python inference.py 'Your query here'")
        sys.exit(1)
    
    # Load the model files that sit next to this script, regardless of the current directory.
    model = TwoStageInference(os.path.dirname(os.path.abspath(__file__)))
    result = model.predict(sys.argv[1])
    print(json.dumps(result, indent=2))
'''

with open(MODEL_PATH / 'inference.py', 'w', encoding='utf-8') as f:
    f.write(inference_code)

print(f"‚úì Inference module saved to {MODEL_PATH}/inference.py")

## Summary

This notebook trained a **Two-Stage Joint DeBERTa-v3 Model** for D&D query understanding:

### Architecture
- **Shared Encoder**: DeBERTa-v3-base (86M backbone parameters, plus ~98M in the 128K-token embedding layer)
- **Stage 1 Heads**: Tool Classification (3 classes) + NER with CRF (25 BIO tags)
- **Stage 2 Heads**: Per-tool Intent Classification (10 + 20 + 30 = 60 intents)

### Training Strategy
1. **Stage 1 (Epochs 1-3)**: Train Tool + NER heads only (intent heads frozen)
2. **Stage 2 (Epochs 4-10)**: Joint training of all heads
3. **Fine-tuning (Epochs 11-15)**: Lower learning rate for refinement

### Key Differences from Flat Architecture
- ✅ Per-tool intent heads instead of single 61-class head
- ✅ Gated intent prediction (all intent heads run, but an intent is only reported for tools whose predicted probability exceeds 0.5)
- ✅ Masked intent loss (ignore non-selected tools)
- ✅ Staged training for better tool/NER convergence first

### Model Saved To
- `models/two_stage_joint/`
  - `model.safetensors` (or `pytorch_model.bin` with older `transformers` versions) - Model weights
  - `config.json` - Model config
  - `tokenizer.json` - Tokenizer
  - `label_mappings.json` - All label mappings
  - `training_config.json` - Training hyperparameters
  - `test_results.json` - Final evaluation metrics
  - `training_history.png` - Training curves
  - `inference.py` - Standalone inference module