# Joint DeBERTa-v3 Model for Intent Classification & NER

This notebook implements a **single joint model** that performs all three tasks:
1. **Multi-label Tool Classification** - Which tools to use (character_data, rulebook, session_notes)
2. **Intent Classification** - What specific intent for each selected tool (61 total intents)
3. **Named Entity Recognition (NER)** - Extract entities with BIO tagging

## Architecture
```
                    DeBERTa-v3 Encoder
                          │
            ┌─────────────┼─────────────┐
            │             │             │
        [CLS] token   [CLS] token   All tokens
            │             │             │
      Tool Head     Intent Head    NER Head
     (3 sigmoid)   (61 sigmoid)   (25 BIO tags)
            │             │             │
    Multi-label    Multi-label     CRF Layer
     BCE Loss       BCE Loss      (NLL loss)
```
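
Written out, the joint objective that `forward()` minimizes is a weighted sum of the three task losses, with the weights $w$ taken from `CONFIG`. With batch size $B$, sigmoid $\sigma$, tool logits $z$, intent logits $u$, multi-hot targets $y$ and $v$, NER emission scores $\mathbf{e}_b$ and gold tag sequence $\mathbf{t}_b$, the loss below follows from the default `reduction='mean'` of `F.binary_cross_entropy_with_logits` (mean over all label cells) and `reduction='mean'` in the CRF (mean over the batch):

$$
\begin{aligned}
\mathcal{L} &= w_{\text{tool}}\,\mathcal{L}_{\text{tool}} + w_{\text{intent}}\,\mathcal{L}_{\text{intent}} + w_{\text{ner}}\,\mathcal{L}_{\text{ner}} \\
\mathcal{L}_{\text{tool}} &= -\frac{1}{3B}\sum_{b=1}^{B}\sum_{k=1}^{3}\Big[y_{bk}\log\sigma(z_{bk}) + (1-y_{bk})\log\big(1-\sigma(z_{bk})\big)\Big] \\
\mathcal{L}_{\text{intent}} &= -\frac{1}{61B}\sum_{b=1}^{B}\sum_{j=1}^{61}\Big[v_{bj}\log\sigma(u_{bj}) + (1-v_{bj})\log\big(1-\sigma(u_{bj})\big)\Big] \\
\mathcal{L}_{\text{ner}} &= -\frac{1}{B}\sum_{b=1}^{B}\log p_{\text{CRF}}\big(\mathbf{t}_b \mid \mathbf{e}_b\big)
\end{aligned}
$$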

## Why DeBERTa-v3?
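- **Disentangled attention** - each token is represented by separate content and relative-position vectors, and attention is computed from content-to-content, content-to-position and position-to-content terms
- **Replaced Token Detection (RTD)** - ELECTRA-style pre-training, which is more sample-efficient than masked language modeling (MLM)
- **Strong NLU results** - the original DeBERTa paper reports gains of +0.9% to +3.6% over RoBERTa-large (MNLI, SQuAD v2.0, RACE)
- **Better data efficiency** - important with only ~10K synthetic training samples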

In [None]:
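# Install dependencies (sentencepiece + protobuf are needed to build the DeBERTa-v3 fast tokenizer)
!pip install -q transformers datasets torch accelerate seqeval scikit-learn wandb sentencepiece protobuf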
!pip install -q pytorch-crf

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set paths
DRIVE_PATH = '/content/drive/MyDrive/574-assignment'
DATA_PATH = f'{DRIVE_PATH}/data'
MODEL_PATH = f'{DRIVE_PATH}/models/joint_deberta'

In [None]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    DebertaV2Tokenizer,
    DebertaV2Model,
    DebertaV2PreTrainedModel,
    get_linear_schedule_with_warmup,
    AutoConfig
)
from torchcrf import CRF
from sklearn.metrics import f1_score, precision_score, recall_score
from seqeval.metrics import f1_score as seqeval_f1, classification_report
from collections import defaultdict
import numpy as np
from tqdm.auto import tqdm
import os

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Configuration

In [None]:
# Model configuration
CONFIG = {
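    # Model - using deberta-v3-base on an A100 (86M backbone params; the 128K-token
    # vocabulary adds ~98M embedding params, so the total printed below is larger).
    # Alternatives: 'microsoft/deberta-v3-small' (44M backbone), 'microsoft/deberta-v3-large' (304M backbone)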
    'model_name': 'microsoft/deberta-v3-base',
    
    # Training
    'batch_size': 32,  # Can increase on A100
    'learning_rate': 2e-5,
    'num_epochs': 10,
    'warmup_ratio': 0.1,
    'weight_decay': 0.01,
    'max_length': 128,
    'gradient_accumulation_steps': 1,
    
    # Loss weights (tune these)
    'tool_loss_weight': 1.0,
    'intent_loss_weight': 1.0,
    'ner_loss_weight': 1.0,
    
    # Early stopping
    'patience': 3,
    
    # Dropout
    'classifier_dropout': 0.1,
}

# Tool and intent mappings
TOOLS = ['character_data', 'rulebook', 'session_notes']
TOOL_TO_IDX = {tool: idx for idx, tool in enumerate(TOOLS)}

# Intent mappings per tool
CHARACTER_INTENTS = [
    'ability_scores', 'combat_info', 'skills_proficiencies', 'class_features',
    'spell_slots', 'inventory', 'background', 'race_traits', 'level_info', 'full_character'
]

SESSION_INTENTS = [
    'recent_events', 'npc_interactions', 'location_info', 'quest_status',
    'party_decisions', 'combat_history', 'treasure_loot', 'plot_threads',
    'character_development', 'time_tracking', 'relationship_status', 'faction_standing',
    'unresolved_mysteries', 'player_notes', 'dm_notes', 'session_summary',
    'next_session_hooks', 'world_lore', 'house_rules', 'campaign_timeline'
]

RULEBOOK_INTENTS = [
    'spell_details', 'spell_list_query', 'class_info', 'subclass_info',
    'race_info', 'feat_info', 'condition_rules', 'combat_rules',
    'skill_rules', 'ability_check_rules', 'saving_throw_rules', 'death_rules',
    'rest_rules', 'movement_rules', 'cover_rules', 'action_rules',
    'reaction_rules', 'equipment_info', 'weapon_info', 'armor_info',
    'magic_item_info', 'monster_info', 'multiclassing_rules', 'spellcasting_rules',
    'concentration_rules', 'ritual_rules', 'component_rules', 'aoe_rules',
    'range_rules', 'duration_rules', 'damage_type_rules'
]

ALL_INTENTS = CHARACTER_INTENTS + SESSION_INTENTS + RULEBOOK_INTENTS
INTENT_TO_IDX = {intent: idx for idx, intent in enumerate(ALL_INTENTS)}
IDX_TO_INTENT = {idx: intent for intent, idx in INTENT_TO_IDX.items()}

# Intent to tool mapping
INTENT_TO_TOOL = {}
for intent in CHARACTER_INTENTS:
    INTENT_TO_TOOL[intent] = 'character_data'
for intent in SESSION_INTENTS:
    INTENT_TO_TOOL[intent] = 'session_notes'
for intent in RULEBOOK_INTENTS:
    INTENT_TO_TOOL[intent] = 'rulebook'

# NER tags
NER_TAGS = ['O', 'B-SPELL', 'I-SPELL', 'B-CLASS', 'I-CLASS', 'B-RACE', 'I-RACE',
            'B-CREATURE', 'I-CREATURE', 'B-ITEM', 'I-ITEM', 'B-LOCATION', 'I-LOCATION',
            'B-ABILITY', 'I-ABILITY', 'B-SKILL', 'I-SKILL', 'B-CONDITION', 'I-CONDITION',
            'B-DAMAGE_TYPE', 'I-DAMAGE_TYPE', 'B-FEAT', 'I-FEAT', 'B-BACKGROUND', 'I-BACKGROUND']
TAG_TO_IDX = {tag: idx for idx, tag in enumerate(NER_TAGS)}
IDX_TO_TAG = {idx: tag for tag, idx in TAG_TO_IDX.items()}

print(f"Number of tools: {len(TOOLS)}")
print(f"Number of intents: {len(ALL_INTENTS)}")
print(f"Number of NER tags: {len(NER_TAGS)}")
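
Because `INTENT_TO_TOOL` is built by successive assignment, an intent name that appeared in two tool lists would silently end up mapped to the last one. A small sanity check over the label inventories can catch that, along with a BIO list that is missing a `B-`/`I-` partner. This is a minimal pure-Python sketch; `check_label_inventory` is a hypothetical helper, not part of the notebook's pipeline.

```python
from collections import Counter


def check_label_inventory(intent_groups, ner_tags):
    """Hypothetical helper: return a list of problems (empty list = consistent).

    intent_groups: dict mapping tool name -> list of intent names
    ner_tags: list of tags, expected to be 'O' followed by B-/I- pairs
    """
    problems = []

    # Every intent should belong to exactly one tool
    counts = Counter(i for intents in intent_groups.values() for i in intents)
    for intent, n in sorted(counts.items()):
        if n > 1:
            problems.append(f"intent '{intent}' appears {n} times")

    # BIO scheme: 'O' first, then one B-/I- pair per entity type
    if ner_tags[:1] != ['O']:
        problems.append("first NER tag should be 'O'")
    b_types = {t[2:] for t in ner_tags if t.startswith('B-')}
    i_types = {t[2:] for t in ner_tags if t.startswith('I-')}
    for missing in sorted(b_types ^ i_types):
        problems.append(f"entity type '{missing}' lacks a B-/I- pair")
    if len(ner_tags) != 1 + 2 * len(b_types | i_types):
        problems.append("unexpected number of NER tags")
    return problems
```

Called as `check_label_inventory({'character_data': CHARACTER_INTENTS, 'session_notes': SESSION_INTENTS, 'rulebook': RULEBOOK_INTENTS}, NER_TAGS)`, it is expected to return an empty list for the inventories defined above.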

## 2. Load Data

In [None]:
def load_dataset(split='train'):
    """Load dataset from JSON file."""
    path = f'{DATA_PATH}/{split}.json'
    with open(path, 'r') as f:
        data = json.load(f)
    print(f"Loaded {len(data)} {split} samples")
    return data

train_data = load_dataset('train')
val_data = load_dataset('val')
test_data = load_dataset('test')

In [None]:
# Examine sample structure
print("Sample structure:")
print(json.dumps(train_data[0], indent=2))
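
`JointDataset` reads four fields from each record: `query`, `tools`, `intents` (a mapping from tool to intent) and `bio_tags` (one tag per whitespace-separated word). The shape below is inferred from that code rather than from a documented schema, and the example record is hypothetical. A pre-flight check like this can surface records whose tag count does not match the query, which the dataset otherwise pads or truncates silently.

```python
def validate_sample(sample, tools, intent_to_tool, ner_tags):
    """Hypothetical helper: return schema problems for one record (empty = looks OK)."""
    problems = []
    for key in ('query', 'tools', 'intents', 'bio_tags'):
        if key not in sample:
            problems.append(f"missing key '{key}'")
    if problems:
        return problems

    unknown_tools = [t for t in sample['tools'] if t not in tools]
    if unknown_tools:
        problems.append(f"unknown tools: {unknown_tools}")

    # Each (tool, intent) pair should agree with the intent-to-tool mapping
    for tool, intent in sample['intents'].items():
        if intent not in intent_to_tool:
            problems.append(f"unknown intent '{intent}'")
        elif intent_to_tool[intent] != tool:
            problems.append(f"intent '{intent}' belongs to '{intent_to_tool[intent]}', not '{tool}'")
        if tool not in sample['tools']:
            problems.append(f"intent given for unselected tool '{tool}'")

    # One BIO tag per whitespace-separated word, all from the known tag set
    n_words = len(sample['query'].split())
    if len(sample['bio_tags']) != n_words:
        problems.append(f"{len(sample['bio_tags'])} BIO tags for {n_words} words")
    bad_tags = sorted(set(sample['bio_tags']) - set(ner_tags))
    if bad_tags:
        problems.append(f"unknown BIO tags: {bad_tags}")
    return problems


# Hypothetical record in the shape JointDataset expects
example = {
    'query': 'How does Shield work?',
    'tools': ['rulebook'],
    'intents': {'rulebook': 'spell_details'},
    'bio_tags': ['O', 'O', 'B-SPELL', 'O'],
}
```

Over the real data this could be run as `[p for s in train_data for p in validate_sample(s, TOOLS, INTENT_TO_TOOL, NER_TAGS)]`.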

## 3. Dataset Class

In [None]:
class JointDataset(Dataset):
    """Dataset for joint tool/intent classification and NER."""
    
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        text = sample['query']
        tools = sample['tools']
        intents = sample['intents']    # dict: tool -> intent
        bio_tags = sample['bio_tags']  # one tag per whitespace-separated word
        
        # Tokenize the pre-split words so word_ids() maps every subword back to its
        # source word. This is more robust than matching character offsets against
        # word boundaries (SentencePiece-based fast tokenizers can report offsets that
        # include a word's leading space). word_ids() requires a fast tokenizer.
        words = text.split()
        encoding = self.tokenizer(
            words,
            is_split_into_words=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Tool labels (multi-label)
        tool_labels = torch.zeros(len(TOOLS))
        for tool in tools:
            if tool in TOOL_TO_IDX:
                tool_labels[TOOL_TO_IDX[tool]] = 1
        
        # Intent labels - multi-label for all selected intents
        intent_labels = torch.zeros(len(ALL_INTENTS))
        for tool, intent in intents.items():
            if intent in INTENT_TO_IDX:
                intent_labels[INTENT_TO_IDX[intent]] = 1
        
        # Align word-level BIO tags with subword tokens
        ner_labels = self._align_labels(bio_tags, encoding.word_ids(0), len(words))
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'tool_labels': tool_labels,
            'intent_labels': intent_labels,
            'ner_labels': ner_labels,
            'tools': tools,  # Keep for evaluation
            'intents': intents
        }
    
    def _align_labels(self, bio_tags, word_ids, num_words):
        """Align word-level BIO tags to subword tokens.
        
        word_ids holds, for each token position, the index of its source word,
        or None for special tokens and padding.
        """
        # Pad with 'O' or truncate so there is exactly one tag per word
        bio_tags = (list(bio_tags) + ['O'] * num_words)[:num_words]
        
        aligned_labels = []
        previous_word = None
        for word_idx in word_ids:
            if word_idx is None:
                # Special token or padding - use -100 (ignored in loss)
                aligned_labels.append(-100)
            else:
                tag = bio_tags[word_idx]
                # Continuation subwords of a B- word are labeled I- of the same type
                if word_idx == previous_word and tag.startswith('B-'):
                    tag = 'I-' + tag[2:]
                aligned_labels.append(TAG_TO_IDX.get(tag, TAG_TO_IDX['O']))
            previous_word = word_idx
        
        return torch.tensor(aligned_labels, dtype=torch.long)


def collate_fn(batch):
    """Custom collate function."""
    return {
        'input_ids': torch.stack([x['input_ids'] for x in batch]),
        'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
        'tool_labels': torch.stack([x['tool_labels'] for x in batch]),
        'intent_labels': torch.stack([x['intent_labels'] for x in batch]),
        'ner_labels': torch.stack([x['ner_labels'] for x in batch]),
        'tools': [x['tools'] for x in batch],
        'intents': [x['intents'] for x in batch]
    }
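
The core of BIO alignment does not depend on the tokenizer: given, for every token position, the index of the word it came from (`None` for special tokens and padding, the convention of `word_ids()` on Hugging Face fast tokenizers), each subword inherits its word's tag and non-initial subwords of a `B-` word are relabeled `I-`. A minimal pure-Python sketch; the helper name `align_bio_to_subwords` and the toy tokenization are illustrative assumptions.

```python
def align_bio_to_subwords(bio_tags, word_ids, tag_to_idx, ignore_index=-100):
    """Hypothetical helper: expand word-level BIO tags to one label per token position."""
    labels = []
    previous = None
    for word_idx in word_ids:
        if word_idx is None:
            labels.append(ignore_index)  # [CLS], [SEP], padding: ignored by the loss
        else:
            tag = bio_tags[word_idx] if word_idx < len(bio_tags) else 'O'
            if word_idx == previous and tag.startswith('B-'):
                tag = 'I-' + tag[2:]  # later subwords continue the entity
            labels.append(tag_to_idx.get(tag, tag_to_idx['O']))
        previous = word_idx
    return labels


tag_to_idx = {'O': 0, 'B-SPELL': 1, 'I-SPELL': 2}
# "Cast Fireball" -> [CLS] ▁Cast ▁Fire ball [SEP] [PAD]  (illustrative tokenization)
word_ids = [None, 0, 1, 1, None, None]
labels = align_bio_to_subwords(['O', 'B-SPELL'], word_ids, tag_to_idx)
```

Here `labels` is `[-100, 0, 1, 2, -100, -100]`: `Fire` keeps `B-SPELL` and `ball` becomes `I-SPELL`. Labeling every subword (rather than only the first) keeps the labeled positions contiguous, which suits a CRF that expects a contiguous mask.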

## 4. Joint Model Architecture

In [None]:
class JointDeBERTaModel(DebertaV2PreTrainedModel):
    """
    Joint DeBERTa-v3 model for:
    1. Multi-label tool classification (from [CLS])
    2. Multi-label intent classification (from [CLS])
    3. NER/slot filling with CRF (from all tokens)
    """
    
    def __init__(self, config, num_tools=3, num_intents=61, num_ner_tags=25):
        super().__init__(config)
        
        self.num_tools = num_tools
        self.num_intents = num_intents
        self.num_ner_tags = num_ner_tags
        
        # Shared encoder
        self.deberta = DebertaV2Model(config)
        
        # Dropout
        classifier_dropout = (
            config.classifier_dropout 
            if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        
        # Tool classification head (multi-label)
        self.tool_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size, num_tools)
        )
        
        # Intent classification head (multi-label)
        self.intent_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size, num_intents)
        )
        
        # NER head with CRF
        self.ner_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.GELU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(config.hidden_size // 2, num_ner_tags)
        )
        self.crf = CRF(num_ner_tags, batch_first=True)
        
        # Initialize weights
        self.post_init()
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        tool_labels=None,
        intent_labels=None,
        ner_labels=None,
        return_dict=True
    ):
        # Get encoder outputs
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        sequence_output = outputs.last_hidden_state  # (batch, seq_len, hidden)
        cls_output = sequence_output[:, 0, :]  # (batch, hidden)
        cls_output = self.dropout(cls_output)
        
        # Tool classification
        tool_logits = self.tool_classifier(cls_output)  # (batch, num_tools)
        
        # Intent classification
        intent_logits = self.intent_classifier(cls_output)  # (batch, num_intents)
        
        # NER classification
        ner_emissions = self.ner_classifier(sequence_output)  # (batch, seq_len, num_tags)
        
        # Compute losses if labels provided
        loss = None
        tool_loss = None
        intent_loss = None
        ner_loss = None
        
        if tool_labels is not None:
            tool_loss = F.binary_cross_entropy_with_logits(tool_logits, tool_labels)
        
        if intent_labels is not None:
            intent_loss = F.binary_cross_entropy_with_logits(intent_logits, intent_labels)
        
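        if ner_labels is not None:
            # torchcrf requires the first timestep of the mask to be on and the mask
            # to be a contiguous prefix. Position 0 is [CLS] (always -100), so the CRF
            # runs on positions 1: (real tokens, then [SEP]/padding masked off).
            ner_emissions_crf = ner_emissions[:, 1:]
            valid_mask = (ner_labels[:, 1:] != -100)
            valid_mask[:, 0] = True  # guard for empty queries; that label becomes 0 ('O')
            # Replace -100 with 0 for CRF (masked positions do not contribute)
            ner_labels_clean = ner_labels[:, 1:].clone()
            ner_labels_clean[ner_labels_clean == -100] = 0
            # CRF loss (negative log likelihood, averaged over the batch)
            ner_loss = -self.crf(ner_emissions_crf, ner_labels_clean, mask=valid_mask, reduction='mean')
        
        if tool_loss is not None and intent_loss is not None and ner_loss is not None:
            loss = (
                CONFIG['tool_loss_weight'] * tool_loss +
                CONFIG['intent_loss_weight'] * intent_loss +
                CONFIG['ner_loss_weight'] * ner_loss
            )
        
        return {
            'loss': loss,
            'tool_loss': tool_loss,
            'intent_loss': intent_loss,
            'ner_loss': ner_loss,
            'tool_logits': tool_logits,
            'intent_logits': intent_logits,
            'ner_emissions': ner_emissions,
        }
    
    def decode_ner(self, ner_emissions, attention_mask):
        """Viterbi-decode NER tags, returning one tag index per attended position.
        
        [CLS] and [SEP] are excluded from decoding (as in training) and filled with
        0 ('O') so the output stays aligned with the input positions.
        Assumes single-sequence inputs with right padding.
        """
        lengths = attention_mask.sum(dim=1)                  # [CLS] + tokens + [SEP]
        content_mask = attention_mask[:, 1:].bool().clone()  # drop [CLS]
        rows = torch.arange(content_mask.size(0), device=content_mask.device)
        content_mask[rows, lengths - 2] = False              # drop [SEP]
        content_mask[:, 0] = True                            # torchcrf needs the first step on
        decoded = self.crf.decode(ner_emissions[:, 1:], mask=content_mask)
        return [
            [0] + seq + [0] * (n - 1 - len(seq))
            for seq, n in zip(decoded, lengths.tolist())
        ]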
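
`decode_ner` delegates to `crf.decode`, which returns the highest-scoring tag sequence under the CRF's emission and transition scores (Viterbi decoding). To illustrate what that computes, here is a plain-Python Viterbi using the same score decomposition as `pytorch-crf` (start score + emissions + pairwise transitions + end score). The function name and the hand-set scores are illustrative assumptions; in the real model the transitions are learned, and this is not the library's implementation.

```python
def viterbi_decode(emissions, transitions, start, end):
    """Illustrative Viterbi: return (best tag path, its score).

    emissions[t][j]: score of tag j at step t
    transitions[i][j]: score of moving from tag i to tag j
    start[j], end[j]: scores for starting / ending the sequence in tag j
    """
    num_tags = len(start)
    # score[j] = best score of any path ending in tag j at the current step
    score = [start[j] + emissions[0][j] for j in range(num_tags)]
    backpointers = []
    for t in range(1, len(emissions)):
        new_score, pointers = [], []
        for j in range(num_tags):
            best_i = max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            new_score.append(score[best_i] + transitions[best_i][j] + emissions[t][j])
            pointers.append(best_i)
        score = new_score
        backpointers.append(pointers)
    final = [score[j] + end[j] for j in range(num_tags)]
    best_last = max(range(num_tags), key=lambda j: final[j])
    # Follow back-pointers from the last step to recover the full path
    path = [best_last]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])
    path.reverse()
    return path, final[best_last]


O, B, I = 0, 1, 2
NEG = -10.0                       # hand-set penalty standing in for a learned one
transitions = [[0.0] * 3 for _ in range(3)]
transitions[O][I] = NEG           # O -> I is invalid in BIO
start = [0.0, 0.0, NEG]           # a sequence should not start with I
end = [0.0, 0.0, 0.0]
emissions = [[1.0, 0.8, 0.0],     # per-position argmax would pick O ...
             [0.0, 0.0, 2.0]]     # ... then I, an invalid BIO sequence
path, best = viterbi_decode(emissions, transitions, start, end)
```

Viterbi returns `[B, I]` (score 2.8) instead of the greedy `[O, I]`, because the O→I penalty outweighs the 0.2 emission gap at the first step. Enforcing consistent tag sequences like this is the main reason to put a CRF on top of the token classifier.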

## 5. Initialize Model and Data

In [None]:
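# Load tokenizer and config
# A fast (Rust-backed) tokenizer is required: JointDataset needs subword-to-word
# alignment information, which the slow SentencePiece tokenizer does not provide.
from transformers import DebertaV2TokenizerFast

print(f"Loading model: {CONFIG['model_name']}")
tokenizer = DebertaV2TokenizerFast.from_pretrained(CONFIG['model_name'])
config = AutoConfig.from_pretrained(CONFIG['model_name'])
config.classifier_dropout = CONFIG['classifier_dropout']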

# Initialize model
model = JointDeBERTaModel.from_pretrained(
    CONFIG['model_name'],
    config=config,
    num_tools=len(TOOLS),
    num_intents=len(ALL_INTENTS),
    num_ner_tags=len(NER_TAGS),
    ignore_mismatched_sizes=True
)
model = model.to(device)

# Print model size
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Create datasets
train_dataset = JointDataset(train_data, tokenizer, CONFIG['max_length'])
val_dataset = JointDataset(val_data, tokenizer, CONFIG['max_length'])
test_dataset = JointDataset(test_data, tokenizer, CONFIG['max_length'])

# Create dataloaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=CONFIG['batch_size'], 
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=CONFIG['batch_size'],
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset, 
    batch_size=CONFIG['batch_size'],
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

## 6. Training Setup

In [None]:
# Optimizer with two learning rates: base LR for the encoder, 10x for the task heads and CRF
def get_optimizer_grouped_parameters(model, learning_rate, weight_decay):
    """Group parameters with different learning rates."""
    no_decay = ['bias', 'LayerNorm.weight', 'layernorm.weight']
    
    # Different learning rates for encoder vs classifiers
    encoder_params = []
    classifier_params = []
    
    for name, param in model.named_parameters():
        if 'deberta' in name:
            encoder_params.append((name, param))
        else:
            classifier_params.append((name, param))
    
    optimizer_grouped_parameters = [
        # Encoder with weight decay
        {
            'params': [p for n, p in encoder_params if not any(nd in n for nd in no_decay)],
            'lr': learning_rate,
            'weight_decay': weight_decay
        },
        # Encoder without weight decay
        {
            'params': [p for n, p in encoder_params if any(nd in n for nd in no_decay)],
            'lr': learning_rate,
            'weight_decay': 0.0
        },
        # Classifiers with higher learning rate
        {
            'params': [p for n, p in classifier_params if not any(nd in n for nd in no_decay)],
            'lr': learning_rate * 10,  # Higher LR for classification heads
            'weight_decay': weight_decay
        },
        {
            'params': [p for n, p in classifier_params if any(nd in n for nd in no_decay)],
            'lr': learning_rate * 10,
            'weight_decay': 0.0
        },
    ]
    
    return optimizer_grouped_parameters

optimizer = torch.optim.AdamW(
    get_optimizer_grouped_parameters(model, CONFIG['learning_rate'], CONFIG['weight_decay'])
)

# Learning rate scheduler with warmup
total_steps = len(train_loader) * CONFIG['num_epochs'] // CONFIG['gradient_accumulation_steps']
warmup_steps = int(total_steps * CONFIG['warmup_ratio'])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")
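
Both parameter groups share one scheduler, so the heads always run at 10x the encoder's learning rate while both follow the same warmup-then-decay curve. The multiplier applied at optimizer step `s` ramps linearly from 0 to 1 over the warmup steps, then decays linearly to 0 at `total_steps`. A pure-Python sketch of that multiplier, useful for checking the LR each group sees at a given step (helper names are hypothetical):

```python
def linear_warmup_multiplier(step, warmup_steps, total_steps):
    """Hypothetical helper: LR multiplier, linear warmup 0 -> 1, then linear decay 1 -> 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def scheduled_lrs(step, base_lr, warmup_steps, total_steps, head_lr_factor=10):
    """Hypothetical helper: (encoder LR, head LR) at an optimizer step."""
    m = linear_warmup_multiplier(step, warmup_steps, total_steps)
    return base_lr * m, base_lr * head_lr_factor * m
```

For example, `scheduled_lrs(warmup_steps, CONFIG['learning_rate'], warmup_steps, total_steps)` gives the peak rates (2e-5 for the encoder, 2e-4 for the heads). Note that `scheduler.step()` is called once per optimizer step, so with gradient accumulation the schedule advances once every `gradient_accumulation_steps` batches.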

## 7. Evaluation Functions

In [None]:
def evaluate(model, dataloader, device):
    """Evaluate model on validation/test set."""
    model.eval()
    
    all_tool_preds = []
    all_tool_labels = []
    all_intent_preds = []
    all_intent_labels = []
    all_ner_preds = []
    all_ner_labels = []
    
    total_loss = 0
    total_tool_loss = 0
    total_intent_loss = 0
    total_ner_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            tool_labels = batch['tool_labels'].to(device)
            intent_labels = batch['intent_labels'].to(device)
            ner_labels = batch['ner_labels'].to(device)
            
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                tool_labels=tool_labels,
                intent_labels=intent_labels,
                ner_labels=ner_labels
            )
            
            total_loss += outputs['loss'].item()
            total_tool_loss += outputs['tool_loss'].item()
            total_intent_loss += outputs['intent_loss'].item()
            total_ner_loss += outputs['ner_loss'].item()
            
            # Tool predictions (threshold 0.5)
            tool_preds = (torch.sigmoid(outputs['tool_logits']) > 0.5).float()
            all_tool_preds.append(tool_preds.cpu())
            all_tool_labels.append(tool_labels.cpu())
            
            # Intent predictions (threshold 0.5 for multi-label)
            intent_preds = (torch.sigmoid(outputs['intent_logits']) > 0.5).float()
            all_intent_preds.append(intent_preds.cpu())
            all_intent_labels.append(intent_labels.cpu())
            
            # NER predictions (CRF decode)
            ner_preds = model.decode_ner(outputs['ner_emissions'], attention_mask)
            
            # Convert to tag sequences for seqeval
            for i, (pred_seq, label_seq) in enumerate(zip(ner_preds, ner_labels.cpu().tolist())):
                pred_tags = []
                label_tags = []
                for pred_tag, label_tag in zip(pred_seq, label_seq):
                    if label_tag != -100:  # Only consider valid tokens
                        pred_tags.append(IDX_TO_TAG[pred_tag])
                        label_tags.append(IDX_TO_TAG[label_tag])
                if pred_tags:
                    all_ner_preds.append(pred_tags)
                    all_ner_labels.append(label_tags)
    
    # Compute metrics
    num_batches = len(dataloader)
    
    # Tool metrics
    all_tool_preds = torch.cat(all_tool_preds, dim=0).numpy()
    all_tool_labels = torch.cat(all_tool_labels, dim=0).numpy()
    tool_f1 = f1_score(all_tool_labels, all_tool_preds, average='micro', zero_division=0)
    tool_precision = precision_score(all_tool_labels, all_tool_preds, average='micro', zero_division=0)
    tool_recall = recall_score(all_tool_labels, all_tool_preds, average='micro', zero_division=0)
    
    # Exact match for tools
    tool_exact_match = np.mean(np.all(all_tool_preds == all_tool_labels, axis=1))
    
    # Intent metrics
    all_intent_preds = torch.cat(all_intent_preds, dim=0).numpy()
    all_intent_labels = torch.cat(all_intent_labels, dim=0).numpy()
    intent_f1 = f1_score(all_intent_labels, all_intent_preds, average='micro', zero_division=0)
    intent_precision = precision_score(all_intent_labels, all_intent_preds, average='micro', zero_division=0)
    intent_recall = recall_score(all_intent_labels, all_intent_preds, average='micro', zero_division=0)
    
    # Exact match for intents
    intent_exact_match = np.mean(np.all(all_intent_preds == all_intent_labels, axis=1))
    
    # NER metrics (using seqeval)
    ner_f1 = seqeval_f1(all_ner_labels, all_ner_preds, average='micro', zero_division=0)
    
    metrics = {
        'loss': total_loss / num_batches,
        'tool_loss': total_tool_loss / num_batches,
        'intent_loss': total_intent_loss / num_batches,
        'ner_loss': total_ner_loss / num_batches,
        'tool_f1': tool_f1,
        'tool_precision': tool_precision,
        'tool_recall': tool_recall,
        'tool_exact_match': tool_exact_match,
        'intent_f1': intent_f1,
        'intent_precision': intent_precision,
        'intent_recall': intent_recall,
        'intent_exact_match': intent_exact_match,
        'ner_f1': ner_f1,
    }
    
    return metrics, (all_ner_labels, all_ner_preds)
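
`evaluate` reports two views of each multi-label head. Micro-averaged precision, recall and F1 pool every (sample, label) decision, while exact match credits a sample only if its entire label vector is correct. For intents, exact match is the stricter number, since one extra or missing intent fails the whole query. The pure-Python sketch below reproduces both definitions on toy data; it mirrors, but is not, the scikit-learn computation, and the helper names are hypothetical.

```python
def micro_prf(labels, preds):
    """Hypothetical helper: micro P/R/F1 over all cells of multi-hot rows (0 when undefined)."""
    tp = fp = fn = 0
    for y_row, p_row in zip(labels, preds):
        for y, p in zip(y_row, p_row):
            tp += int(y == 1 and p == 1)
            fp += int(y == 0 and p == 1)
            fn += int(y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def exact_match(labels, preds):
    """Hypothetical helper: fraction of samples whose whole label vector is predicted exactly."""
    return sum(list(y) == list(p) for y, p in zip(labels, preds)) / len(labels)


gold = [[1, 0, 1], [0, 1, 0]]   # toy multi-hot tool labels
pred = [[1, 0, 0], [0, 1, 0]]   # first sample misses one tool
```

On this toy batch micro F1 is 0.8 (precision 1.0, recall 2/3) while exact match is only 0.5, which illustrates why both numbers are worth tracking.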

## 8. Training Loop

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, device, accumulation_steps=1):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    total_tool_loss = 0
    total_intent_loss = 0
    total_ner_loss = 0
    
    progress_bar = tqdm(dataloader, desc="Training")
    optimizer.zero_grad()
    
    for step, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        tool_labels = batch['tool_labels'].to(device)
        intent_labels = batch['intent_labels'].to(device)
        ner_labels = batch['ner_labels'].to(device)
        
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            tool_labels=tool_labels,
            intent_labels=intent_labels,
            ner_labels=ner_labels
        )
        
        loss = outputs['loss'] / accumulation_steps
        loss.backward()
        
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        total_loss += outputs['loss'].item()
        total_tool_loss += outputs['tool_loss'].item()
        total_intent_loss += outputs['intent_loss'].item()
        total_ner_loss += outputs['ner_loss'].item()
        
        progress_bar.set_postfix({
            'loss': f"{outputs['loss'].item():.4f}",
            'tool': f"{outputs['tool_loss'].item():.4f}",
            'intent': f"{outputs['intent_loss'].item():.4f}",
            'ner': f"{outputs['ner_loss'].item():.4f}"
        })
    
    num_batches = len(dataloader)
    return {
        'loss': total_loss / num_batches,
        'tool_loss': total_tool_loss / num_batches,
        'intent_loss': total_intent_loss / num_batches,
        'ner_loss': total_ner_loss / num_batches
    }

In [None]:
# Training loop with early stopping
best_val_f1 = float('-inf')  # first epoch always counts as an improvement, so a checkpoint is saved
patience_counter = 0
training_history = []

print("Starting training...")
print("=" * 60)

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")
    print("-" * 40)
    
    # Train
    train_metrics = train_epoch(
        model, train_loader, optimizer, scheduler, device, 
        CONFIG['gradient_accumulation_steps']
    )
    
    # Evaluate
    val_metrics, _ = evaluate(model, val_loader, device)
    
    # Combined F1 score (average of all three tasks)
    combined_f1 = (val_metrics['tool_f1'] + val_metrics['intent_f1'] + val_metrics['ner_f1']) / 3
    
    # Log metrics
    history_entry = {
        'epoch': epoch + 1,
        'train_loss': train_metrics['loss'],
        'val_loss': val_metrics['loss'],
        'val_tool_f1': val_metrics['tool_f1'],
        'val_tool_exact_match': val_metrics['tool_exact_match'],
        'val_intent_f1': val_metrics['intent_f1'],
        'val_intent_exact_match': val_metrics['intent_exact_match'],
        'val_ner_f1': val_metrics['ner_f1'],
        'val_combined_f1': combined_f1
    }
    training_history.append(history_entry)
    
    print(f"\nTrain Loss: {train_metrics['loss']:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f}")
    print(f"Val Tool F1: {val_metrics['tool_f1']:.4f} (Exact Match: {val_metrics['tool_exact_match']:.4f})")
    print(f"Val Intent F1: {val_metrics['intent_f1']:.4f} (Exact Match: {val_metrics['intent_exact_match']:.4f})")
    print(f"Val NER F1: {val_metrics['ner_f1']:.4f}")
    print(f"Combined F1: {combined_f1:.4f}")
    
    # Save best model
    if combined_f1 > best_val_f1:
        best_val_f1 = combined_f1
        patience_counter = 0
        
        # Save model
        os.makedirs(MODEL_PATH, exist_ok=True)
        model.save_pretrained(MODEL_PATH)
        tokenizer.save_pretrained(MODEL_PATH)
        
        # Save config
        with open(f'{MODEL_PATH}/training_config.json', 'w') as f:
            json.dump(CONFIG, f, indent=2)
        
        # Save label mappings
        mappings = {
            'tool_to_idx': TOOL_TO_IDX,
            'intent_to_idx': INTENT_TO_IDX,
            'tag_to_idx': TAG_TO_IDX,
            'intent_to_tool': INTENT_TO_TOOL
        }
        with open(f'{MODEL_PATH}/label_mappings.json', 'w') as f:
            json.dump(mappings, f, indent=2)
        
        print(f"\n✓ New best model saved! Combined F1: {combined_f1:.4f}")
    else:
        patience_counter += 1
        print(f"\n✗ No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
    
    # Early stopping
    if patience_counter >= CONFIG['patience']:
        print(f"\nEarly stopping triggered after {epoch + 1} epochs")
        break

print("\n" + "=" * 60)
print(f"Training complete! Best Combined F1: {best_val_f1:.4f}")
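
The early-stopping rule in the loop above can be factored into a small tracker: a strictly higher validation score resets the patience counter and marks a new best checkpoint; anything else increments the counter, and training stops once it reaches `patience`. A minimal sketch (the class name is hypothetical); starting from negative infinity guarantees that the first epoch counts as an improvement.

```python
class EarlyStopping:
    """Hypothetical helper tracking the best score and when to stop."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float('-inf')
        self.counter = 0

    def step(self, score):
        """Return (improved, should_stop) for one epoch's validation score."""
        if score > self.best:
            self.best = score
            self.counter = 0
            return True, False
        self.counter += 1
        return False, self.counter >= self.patience


stopper = EarlyStopping(patience=3)
decisions = [stopper.step(s) for s in [0.50, 0.60, 0.60, 0.55, 0.58]]
```

Ties do not count as improvements, so the repeated 0.60 already uses up one unit of patience.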

## 9. Final Evaluation on Test Set

In [None]:
# Load best model for testing
print("Loading best model for final evaluation...")
model = JointDeBERTaModel.from_pretrained(
    MODEL_PATH,
    num_tools=len(TOOLS),
    num_intents=len(ALL_INTENTS),
    num_ner_tags=len(NER_TAGS)
)
model = model.to(device)

# Evaluate on test set
test_metrics, (ner_labels, ner_preds) = evaluate(model, test_loader, device)

print("\n" + "=" * 60)
print("TEST SET RESULTS")
print("=" * 60)
print(f"\nTool Classification:")
print(f"  F1 Score: {test_metrics['tool_f1']:.4f}")
print(f"  Precision: {test_metrics['tool_precision']:.4f}")
print(f"  Recall: {test_metrics['tool_recall']:.4f}")
print(f"  Exact Match: {test_metrics['tool_exact_match']:.4f}")

print(f"\nIntent Classification:")
print(f"  F1 Score: {test_metrics['intent_f1']:.4f}")
print(f"  Precision: {test_metrics['intent_precision']:.4f}")
print(f"  Recall: {test_metrics['intent_recall']:.4f}")
print(f"  Exact Match: {test_metrics['intent_exact_match']:.4f}")

print(f"\nNER (Entity Extraction):")
print(f"  F1 Score: {test_metrics['ner_f1']:.4f}")

# Detailed NER report
print("\nDetailed NER Classification Report:")
print(classification_report(ner_labels, ner_preds, zero_division=0))
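
The NER scores above are entity-level: a predicted entity counts only if both its type and its exact boundaries match a gold span, so a partially tagged multi-token entity earns no credit. Because `evaluate` collects one tag per subword, these spans are over subword positions. The pure-Python sketch below shows that scoring; the helper names are hypothetical, and the chunking is intended to follow the conlleval convention that seqeval is designed to reproduce (a stray `I-` starts a new span).

```python
def bio_spans(tags):
    """Hypothetical helper: BIO tags -> set of (type, start, end) spans, end exclusive."""
    spans = set()
    start, etype = None, None
    for i, tag in enumerate(list(tags) + ['O']):  # sentinel 'O' closes a trailing span
        prefix, _, label = tag.partition('-')
        continues = prefix == 'I' and label == etype
        if etype is not None and not continues:
            spans.add((etype, start, i))
            start, etype = None, None
        if prefix in ('B', 'I') and not continues:
            start, etype = i, label
    return spans


def entity_f1(gold_seqs, pred_seqs):
    """Hypothetical helper: micro F1 over exact (type, start, end) span matches."""
    tp = n_gold = n_pred = 0
    for gold, pred in zip(gold_seqs, pred_seqs):
        g, p = bio_spans(gold), bio_spans(pred)
        tp += len(g & p)
        n_gold += len(g)
        n_pred += len(p)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


gold = [['B-SPELL', 'I-SPELL', 'O', 'B-ITEM']]
pred = [['B-SPELL', 'O', 'O', 'B-ITEM']]   # spell span truncated by one token
```

The truncated spell span counts as both a false positive and a false negative, so `entity_f1(gold, pred)` is 0.5 even though three of four tags are correct.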

In [None]:
# Save test results
test_results = {
    'model_name': CONFIG['model_name'],
    'test_metrics': test_metrics,
    'training_history': training_history,
    'best_val_f1': best_val_f1,
    'config': CONFIG
}

with open(f'{MODEL_PATH}/test_results.json', 'w') as f:
    json.dump(test_results, f, indent=2)

print(f"\nTest results saved to {MODEL_PATH}/test_results.json")

## 10. Training History Visualization

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

epochs = [h['epoch'] for h in training_history]

# Loss
axes[0, 0].plot(epochs, [h['train_loss'] for h in training_history], 'b-', label='Train')
axes[0, 0].plot(epochs, [h['val_loss'] for h in training_history], 'r-', label='Validation')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Tool F1
axes[0, 1].plot(epochs, [h['val_tool_f1'] for h in training_history], 'g-', marker='o')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('F1 Score')
axes[0, 1].set_title('Tool Classification F1')
axes[0, 1].grid(True)

# Intent F1
axes[1, 0].plot(epochs, [h['val_intent_f1'] for h in training_history], 'm-', marker='o')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('F1 Score')
axes[1, 0].set_title('Intent Classification F1')
axes[1, 0].grid(True)

# NER F1
axes[1, 1].plot(epochs, [h['val_ner_f1'] for h in training_history], 'c-', marker='o')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('F1 Score')
axes[1, 1].set_title('NER F1 Score')
axes[1, 1].grid(True)

plt.tight_layout()
plt.savefig(f'{MODEL_PATH}/training_history.png', dpi=150)
plt.show()

## 11. Inference Example
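
At inference, `predict` turns the two sigmoid heads into a structured answer: it keeps every tool whose probability exceeds 0.5 and then, for each kept tool, picks the highest-probability intent among that tool's own intents, ignoring intents that belong to unselected tools. The pure-Python sketch below isolates that rule; the function name and toy probabilities are illustrative assumptions, not model outputs.

```python
def select_tools_and_intents(tool_probs, intent_probs, intent_to_tool, threshold=0.5):
    """Hypothetical helper: tools above threshold, then the top intent within each tool."""
    tools = [tool for tool, p in tool_probs.items() if p > threshold]
    intents = {}
    for tool in tools:
        candidates = [i for i, t in intent_to_tool.items() if t == tool]
        if candidates:
            intents[tool] = max(candidates, key=lambda i: intent_probs.get(i, 0.0))
    return tools, intents


# Toy probabilities for a query like "What's my armor class and how does Shield work?"
tool_probs = {'character_data': 0.91, 'rulebook': 0.77, 'session_notes': 0.08}
intent_probs = {'combat_info': 0.85, 'inventory': 0.10, 'spell_details': 0.64,
                'combat_rules': 0.30, 'recent_events': 0.95}
intent_to_tool = {'combat_info': 'character_data', 'inventory': 'character_data',
                  'spell_details': 'rulebook', 'combat_rules': 'rulebook',
                  'recent_events': 'session_notes'}
tools, intents = select_tools_and_intents(tool_probs, intent_probs, intent_to_tool)
```

Because selection is restricted per tool, a confident intent from an unselected tool (`recent_events` here) is discarded; conversely, every selected tool always receives an intent, even if none of its intents passes 0.5.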

In [None]:
def predict(model, tokenizer, text, device):
    """Run inference on a single query."""
    model.eval()
    
    # Tokenize
    encoding = tokenizer(
        text,
        max_length=CONFIG['max_length'],
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
    # Tool predictions
    tool_probs = torch.sigmoid(outputs['tool_logits']).cpu().numpy()[0]
    predicted_tools = [TOOLS[i] for i, p in enumerate(tool_probs) if p > 0.5]
    
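    # Intent predictions: for each selected tool, take the highest-scoring intent
    # among that tool's own intents. Sigmoid is monotonic, so this argmax matches
    # ranking the raw logits (or a softmax over them).
    intent_probs = torch.sigmoid(outputs['intent_logits']).cpu().numpy()[0]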
    predicted_intents = {}
    for tool in predicted_tools:
        # Get intents for this tool
        tool_intents = [intent for intent, t in INTENT_TO_TOOL.items() if t == tool]
        tool_intent_probs = [(intent, intent_probs[INTENT_TO_IDX[intent]]) 
                             for intent in tool_intents]
        best_intent = max(tool_intent_probs, key=lambda x: x[1])
        predicted_intents[tool] = best_intent[0]
    
    # NER predictions
    ner_preds = model.decode_ner(outputs['ner_emissions'], attention_mask)[0]
    
    # Extract entities
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    entities = []
    current_entity = None
    current_type = None
    
    for token, tag_idx in zip(tokens, ner_preds):
        if token in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
            continue
        
        tag = IDX_TO_TAG[tag_idx]
        
        if tag.startswith('B-'):
            if current_entity:
                entities.append({'text': current_entity, 'type': current_type})
            current_entity = token.replace('▁', ' ').strip()
            current_type = tag[2:]
        elif tag.startswith('I-') and current_type == tag[2:]:
            current_entity += token.replace('▁', ' ')
        else:
            if current_entity:
                entities.append({'text': current_entity, 'type': current_type})
            current_entity = None
            current_type = None
    
    if current_entity:
        entities.append({'text': current_entity, 'type': current_type})
    
    return {
        'tools': predicted_tools,
        'tool_probs': {TOOLS[i]: float(p) for i, p in enumerate(tool_probs)},
        'intents': predicted_intents,
        'entities': entities
    }


# Test with example queries
test_queries = [
    "What's my armor class and how does Shield work?",
    "Did we meet anyone named the Archmage in our last session?",
    "How much damage does Fireball do and can I cast it at 5th level?",
    "What are my spell slots and did we find any magic items last time?"
]

print("\nExample Predictions:")
print("=" * 60)

for query in test_queries:
    result = predict(model, tokenizer, query, device)
    print(f"\nQuery: {query}")
    print(f"Tools: {result['tools']}")
    print(f"Tool probabilities: {result['tool_probs']}")
    print(f"Intents: {result['intents']}")
    print(f"Entities: {result['entities']}")
    print("-" * 40)
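
Entity extraction in `predict` merges SentencePiece pieces according to their BIO tags. The function below is a standalone sketch of the same merging rule, pulled out so it can be exercised on toy input without loading the model. It assumes the tags are already strings (for example, the output of `IDX_TO_TAG`) and that word-initial pieces carry the `▁` marker, as the DeBERTa-v3 tokenizer produces. `bio_to_entities` and the `SPELL`/`NPC` tag types are hypothetical.

```python
from typing import Dict, List, Optional, Sequence

SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "<s>", "</s>", "<pad>"}

def bio_to_entities(tokens: Sequence[str], tags: Sequence[str]) -> List[Dict[str, str]]:
    """Merge SentencePiece tokens into entities using BIO tags.

    Strict decoding, like predict(): an I- tag that does not continue an open
    entity of the same type closes the current entity and is dropped.
    """
    entities: List[Dict[str, str]] = []
    text: Optional[str] = None
    etype: Optional[str] = None

    def flush() -> None:
        # Emit the entity currently being built, if any
        if text:
            entities.append({"text": text.strip(), "type": etype})

    for token, tag in zip(tokens, tags):
        if token in SPECIAL_TOKENS:
            continue  # special tokens neither open nor close an entity
        piece = token.replace("▁", " ")  # '▁' marks the start of a new word
        if tag.startswith("B-"):
            flush()
            text, etype = piece, tag[2:]
        elif tag.startswith("I-") and etype == tag[2:]:
            text += piece  # subword pieces without '▁' attach directly
        else:
            flush()
            text, etype = None, None
    flush()
    return entities
```

A lenient variant would treat an orphan `I-X` as `B-X` instead of dropping it; strict decoding is kept here to mirror the notebook.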

## 12. Export Model for Inference

In [None]:
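# Create a standalone inference module.
# NOTE: the generated inference.py uses JointDeBERTaModel but does not define it.
# Paste the JointDeBERTaModel class from its definition cell earlier in this
# notebook into the marked spot below (or import it) before using the file.
inference_code = '''
import json

import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import DebertaV2Tokenizer, DebertaV2Model, DebertaV2PreTrainedModel

# JointDeBERTaModel must be defined or imported here: JointDeBERTaInference
# below loads it via JointDeBERTaModel.from_pretrained(...).
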
class JointDeBERTaInference:
    """Inference wrapper for the joint DeBERTa model."""
    
    def __init__(self, model_path, device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        
        # Load mappings
        with open(f'{model_path}/label_mappings.json') as f:
            mappings = json.load(f)
        
        self.tool_to_idx = mappings['tool_to_idx']
        self.idx_to_tool = {int(v): k for k, v in self.tool_to_idx.items()}
        self.intent_to_idx = mappings['intent_to_idx']
        self.idx_to_intent = {int(v): k for k, v in self.intent_to_idx.items()}
        self.tag_to_idx = mappings['tag_to_idx']
        self.idx_to_tag = {int(v): k for k, v in self.tag_to_idx.items()}
        self.intent_to_tool = mappings['intent_to_tool']
        
        # Load config
        with open(f'{model_path}/training_config.json') as f:
            self.config = json.load(f)
        
        # Load model and tokenizer
        self.tokenizer = DebertaV2Tokenizer.from_pretrained(model_path)
        self.model = JointDeBERTaModel.from_pretrained(
            model_path,
            num_tools=len(self.tool_to_idx),
            num_intents=len(self.intent_to_idx),
            num_ner_tags=len(self.tag_to_idx)
        )
        self.model.to(self.device)
        self.model.eval()
    
    def predict(self, text):
        """Run inference on a query."""
        encoding = self.tokenizer(
            text,
            max_length=self.config['max_length'],
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        
        # Parse predictions
        tool_probs = torch.sigmoid(outputs['tool_logits']).cpu().numpy()[0]
        predicted_tools = [self.idx_to_tool[i] for i, p in enumerate(tool_probs) if p > 0.5]
        
        intent_probs = torch.sigmoid(outputs['intent_logits']).cpu().numpy()[0]
        predicted_intents = {}
        for tool in predicted_tools:
            tool_intents = [intent for intent, t in self.intent_to_tool.items() if t == tool]
            tool_intent_probs = [(intent, intent_probs[self.intent_to_idx[intent]]) 
                                 for intent in tool_intents]
            best_intent = max(tool_intent_probs, key=lambda x: x[1])
            predicted_intents[tool] = best_intent[0]
        
        # NER
        ner_preds = self.model.decode_ner(outputs['ner_emissions'], attention_mask)[0]
        entities = self._extract_entities(input_ids[0], ner_preds)
        
        return {
            'tools': predicted_tools,
            'intents': predicted_intents,
            'entities': entities
        }
    
    def _extract_entities(self, input_ids, ner_preds):
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        entities = []
        current_entity = None
        current_type = None
        
        for token, tag_idx in zip(tokens, ner_preds):
            if token in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>']:
                continue
            tag = self.idx_to_tag[tag_idx]
            if tag.startswith('B-'):
                if current_entity:
                    entities.append({'text': current_entity.strip(), 'type': current_type})
                current_entity = token.replace('▁', ' ')
                current_type = tag[2:]
            elif tag.startswith('I-') and current_type == tag[2:]:
                current_entity += token.replace('▁', ' ')
            else:
                if current_entity:
                    entities.append({'text': current_entity.strip(), 'type': current_type})
                current_entity = None
                current_type = None
        if current_entity:
            entities.append({'text': current_entity.strip(), 'type': current_type})
        return entities
'''

with open(f'{MODEL_PATH}/inference.py', 'w') as f:
    f.write(inference_code)

print(f"Inference code saved to {MODEL_PATH}/inference.py")
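
Because `inference_code` is a plain string, nothing checks it until someone imports the file. A cheap guard, sketched below with the standard-library `ast` module, is to parse the exported source and list names that are read somewhere but never bound anywhere in it; on the string above this should flag `JointDeBERTaModel`, which the file uses but does not define. `missing_names` is a hypothetical helper, and the check is deliberately shallow: it ignores scoping, so a name bound in one function hides a missing use of the same name in another.

```python
import ast
import builtins
from typing import Set

def missing_names(source: str) -> Set[str]:
    """Names loaded somewhere in `source` but never bound in it (scope-insensitive)."""
    tree = ast.parse(source)
    loaded: Set[str] = set()
    bound: Set[str] = set(dir(builtins))  # open, len, zip, ... count as bound
    for node in ast.walk(tree):
        if isinstance(node, ast.Name):
            # Load contexts are uses; Store/Del contexts are bindings
            (loaded if isinstance(node.ctx, ast.Load) else bound).add(node.id)
        elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            bound.add(node.name)
        elif isinstance(node, ast.arg):
            bound.add(node.arg)  # function, method, and lambda parameters
        elif isinstance(node, ast.ExceptHandler) and node.name:
            bound.add(node.name)
        elif isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # `import torch.nn` binds `torch`; `import torch.nn as nn` binds `nn`
                bound.add(alias.asname or alias.name.split(".")[0])
    return loaded - bound

# Usage in this notebook:
# missing = missing_names(inference_code)
# assert not missing, f"inference.py references undefined names: {sorted(missing)}"
```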

## Summary

This notebook trains a **joint DeBERTa-v3 model** that performs three tasks simultaneously:

1. **Multi-label Tool Classification** - Predicts which tools (character_data, rulebook, session_notes) are needed
2. **Intent Classification** - Predicts the specific intent for each selected tool (61 total intents)
3. **Named Entity Recognition** - Extracts D&D entities using BIO tagging with CRF
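
Both `predict` implementations pick, for each tool whose probability clears 0.5, the highest-scoring intent among that tool's own intents. The sketch below does the same selection with NumPy masking. It works directly on intent logits, which gives the same argmax as ranking sigmoid or softmax scores because both preserve the ordering of the logits. `select_intents` and the intent names in the test are hypothetical; in the notebook, `TOOLS`, an intent list ordered by `INTENT_TO_IDX`, and `INTENT_TO_TOOL` would play these roles.

```python
from typing import Dict, Sequence

import numpy as np

def select_intents(
    tool_logits: np.ndarray,         # shape (num_tools,)
    intent_logits: np.ndarray,       # shape (num_intents,)
    tools: Sequence[str],            # tool name for each tool index
    intents: Sequence[str],          # intent name for each intent index
    intent_to_tool: Dict[str, str],  # owning tool of every intent
    threshold: float = 0.5,
) -> Dict[str, str]:
    """For every tool predicted above `threshold`, return its best-scoring intent."""
    tool_probs = 1.0 / (1.0 + np.exp(-tool_logits))  # sigmoid, as in the tool head
    owner = np.array([tools.index(intent_to_tool[name]) for name in intents])
    chosen: Dict[str, str] = {}
    for t_idx in np.flatnonzero(tool_probs > threshold):
        belongs = owner == t_idx
        if not belongs.any():
            continue  # a tool with no intents cannot contribute one
        # Mask out intents owned by other tools before taking the argmax
        masked = np.where(belongs, intent_logits, -np.inf)
        chosen[tools[t_idx]] = intents[int(np.argmax(masked))]
    return chosen
```

Masking before the argmax is what keeps a very confident intent from an unselected tool (here, `session_notes`) from leaking into the output.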

### Key Features:
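- **Shared encoder** feeds all three heads, so a single forward pass yields tool, intent, and NER predictions
- **CRF layer** for NER scores transitions between adjacent BIO tags, so it learns to penalize invalid sequences such as `O` followed by `I-X` (it does not strictly forbid them unless the transition scores are constrained)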
- **Layer-wise learning rate decay** - higher LR for classification heads
- **Early stopping** based on combined F1 score
- **Google Drive persistence** for Colab
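
A linear-chain CRF learns transition scores between tags, so it discourages, but does not by itself forbid, invalid BIO moves such as `O` → `I-SPELL`. If hard guarantees are wanted, one common approach is to build an allowed-transition mask and apply it at decode time. The sketch below builds that mask with NumPy. The tag names are hypothetical, and the commented wiring into the `torchcrf.CRF` layer (pushing disallowed entries of its `transitions` and `start_transitions` parameters to a large negative value) is an assumption about how one might apply it, not something this notebook does.

```python
from typing import Sequence, Tuple

import numpy as np

def bio_allowed_transitions(tags: Sequence[str]) -> Tuple[np.ndarray, np.ndarray]:
    """Return (start_allowed, allowed) boolean masks for a BIO tag set.

    allowed[i, j] is True when moving from tags[i] to tags[j] is a valid BIO step,
    following pytorch-crf's convention that transitions[i, j] scores i -> j.
    """
    n = len(tags)
    # A sequence may not begin inside an entity
    start_allowed = np.array([not t.startswith("I-") for t in tags])
    allowed = np.ones((n, n), dtype=bool)
    for j, nxt in enumerate(tags):
        if not nxt.startswith("I-"):
            continue  # O and B-* may follow any tag
        etype = nxt[2:]
        for i, prev in enumerate(tags):
            # I-X is only valid directly after B-X or I-X
            allowed[i, j] = prev in (f"B-{etype}", f"I-{etype}")
    return start_allowed, allowed

# Hypothetical decode-time wiring, assuming `crf` is the model's torchcrf.CRF layer:
# start_ok, ok = bio_allowed_transitions([IDX_TO_TAG[i] for i in range(len(IDX_TO_TAG))])
# with torch.no_grad():
#     crf.transitions[torch.from_numpy(~ok).to(crf.transitions.device)] = -1e4
#     crf.start_transitions[torch.from_numpy(~start_ok).to(crf.transitions.device)] = -1e4
```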

### Model saved to:
- `MODEL_PATH` (`{DRIVE_PATH}/models/joint_deberta/`), which also holds `training_history.png` and the exported `inference.py`