## EMAIL vs CALENDAR QUERY CLASSIFIER


## Overview

This notebook implements a **binary text classifier** using **DistilBERT** to categorize user queries into:

- **Class 0**: Gmail-related queries (emails, attachments, inbox, etc.)
- **Class 1**: Calendar-related queries (meetings, events, appointments, etc.)

---

## Approach

1. **Fine-tune** a pre-trained DistilBERT model on a labeled query dataset
2. Use **PyTorch** for training with the AdamW optimizer and a linear learning-rate schedule
3. **Evaluate** on a held-out test set with accuracy, precision, recall, and F1
4. Implement an inference function: `predict_class(query: str) -> int`
5. **BONUS Features**:
   - Extract time ranges from calendar queries
   - Extract people/entities from queries

---

## Model Selection: Why DistilBERT?

| Feature | Benefit |
|---------|---------|
| **Size** | ~40% fewer parameters than BERT-base (66M vs. 110M) |
| **Performance** | Retains ~97% of BERT's language-understanding performance (GLUE) |
| **Speed** | ~60% faster inference |
| **Use Case** | Good accuracy/latency trade-off for production deployment |

---

## Configuration

| Parameter | Value |
|-----------|-------|
| **Device** | MPS (Apple Silicon) / CUDA (NVIDIA GPU) / CPU, auto-selected in that order |
| **Max Sequence Length** | 128 tokens |
| **Training Batch Size** | 16 |
| **Inference Batch Size** | 32-128 |
| **Epochs** | 2-3 (small dataset), 2 (large dataset) |
| **Learning Rate** | 2e-5 |
| **Optimizer** | AdamW with a linear decay schedule (no warmup steps) |

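The **Optimizer** row relies on `get_linear_schedule_with_warmup` from `transformers`. As we understand that scheduler, it scales the base learning rate by a factor that rises linearly from 0 to 1 during warmup and then falls linearly to 0 at the last training step. The plain-Python sketch below reproduces that multiplier so you can see what the training loop's `num_warmup_steps=0` implies: the rate starts at 2e-5 and decays from the very first step. `lr_at_step` is a hypothetical helper, not a library function.

```python
def lr_at_step(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate after `step` calls to scheduler.step() (linear warmup + linear decay)."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from base_lr to 0 at total_steps (never below 0)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Example: base LR 2e-5 over 1,000 total steps, with and without warmup
for warmup in (0, 100):
    lrs = [lr_at_step(s, 2e-5, warmup, 1000) for s in (0, 50, 100, 500, 1000)]
    print(f"warmup={warmup:>3}: " + ", ".join(f"{lr:.2e}" for lr in lrs))
```

In the notebook, `total_steps` is `len(train_loader) * EPOCHS`, so the rate reaches 0 exactly at the end of the final epoch.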

## Implementation

- Environment Setup & Import Libraries
- Data Loading & Exploration
- Data Validation & Quality Checks
- Tokenization & Dataset Preparation
- Model Initialization
- Model Training Loop
- Evaluation on Test Set
- Error Analysis
- Helper Functions
- Stress Test on a Larger Dataset
- Testing the Model with Queries (Submission Criteria)

In [None]:

# ============================================================================
# SECTION 1: ENVIRONMENT SETUP
# ============================================================================

import os
import random
import math
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from pathlib import Path
from tqdm import tqdm
import re
import calendar
from datetime import date, datetime
from functools import lru_cache

# Set device
device = torch.device(
    "mps" if torch.backends.mps.is_available() 
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
print("Using device:", device)

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# torch.manual_seed seeds the default generators; seed all CUDA GPUs explicitly too
if device.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

# Suppress transformers warnings
import transformers
transformers.logging.set_verbosity_error()


In [None]:
# Define paths
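DATA_DIR = Path("../data/set1")
MODEL_DIR = Path("../models/distilbert")                     # optional local model to start from (see Section 5)
MODEL_SAVE = Path("../models/trained/distilbert_trained/")   # where the best fine-tuned checkpoint is saved
MODEL_SAVE.mkdir(parents=True, exist_ok=True)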

# Model configuration
MODEL_NAME = "distilbert-base-uncased"
MAX_LEN = 128
BATCH_SIZE = 16
EPOCHS = 2
LEARNING_RATE = 2e-5

print(f"\nConfiguration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Max length: {MAX_LEN}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")


In [None]:
# ============================================================================
# SECTION 2: DATA LOADING & EXPLORATION
# ============================================================================

print("\n" + "="*80)
print("LOADING DATASET")
print("="*80)

# Load pre-split datasets
train_df = pd.read_csv(DATA_DIR / "mail_calendar_dataset_train.csv")
val_df = pd.read_csv(DATA_DIR / "mail_calendar_dataset_val.csv")
test_df = pd.read_csv(DATA_DIR / "mail_calendar_dataset_test.csv")

# Combine for full dataset statistics
df = pd.concat([train_df, val_df, test_df], ignore_index=True)

print(f"\nDataset sizes:")
print(f"  Total: {len(df):,}")
print(f"  Train: {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)")
print(f"  Val:   {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)")
print(f"  Test:  {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)")

# Class distribution
print(f"\nClass distribution:")
print(df["label"].value_counts())
print(f"  Gmail (0):    {(df['label'] == 0).sum():,} ({(df['label'] == 0).sum()/len(df)*100:.1f}%)")
print(f"  Calendar (1): {(df['label'] == 1).sum():,} ({(df['label'] == 1).sum()/len(df)*100:.1f}%)")

# Display sample queries
def show_samples(df, label_value, text_col='query', n=5, seed=SEED):
    subset = df.loc[df['label'] == label_value, text_col]
    k = min(n, len(subset))
    label_name = "Gmail" if label_value == 0 else "Calendar"
    print(f"\n{label_name} samples ({k} shown):")
    print("-" * 60)
    if k == 0:
        print("  (no rows)")
    else:
        for query in subset.sample(n=k, random_state=seed):
            print(f"  • {query}")

show_samples(df, 0)
show_samples(df, 1)

# Query length distribution
df["len"] = df["query"].str.len()
print(f"\nQuery length statistics:")
print(f"  Mean: {df['len'].mean():.1f} characters")
print(f"  Median: {df['len'].median():.0f} characters")
print(f"  Min: {df['len'].min()}")
print(f"  Max: {df['len'].max()}")

plt.figure(figsize=(8, 4))
sns.histplot(df["len"], bins=30, kde=True)
plt.title("Query Length Distribution")
plt.xlabel("Character count")
plt.ylabel("Frequency")
plt.tight_layout()
plt.show()
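
The length statistics above are in characters, while `MAX_LEN = 128` caps *tokens*. A quick way to check the cap is to measure what fraction of queries would be truncated and look at a high percentile. The helper below is a hypothetical stdlib sketch; with the notebook's tokenizer, token counts (including `[CLS]`/`[SEP]`) could come from `[len(ids) for ids in tokenizer(df["query"].tolist())["input_ids"]]`.

```python
import statistics


def length_coverage(lengths: list[int], max_len: int) -> dict:
    """Summarize how a length cap would affect a list of sequence lengths (needs >= 2 values)."""
    truncated = sum(1 for n in lengths if n > max_len)
    # quantiles(n=100) returns the 99 percentile cut points; index 94 is the 95th percentile
    p95 = statistics.quantiles(lengths, n=100)[94]
    return {
        "max": max(lengths),
        "p95": p95,
        "truncated_frac": truncated / len(lengths),
    }


# Toy token counts (illustrative only, not from the dataset)
print(length_coverage([8, 10, 12, 9, 15, 30, 140, 11, 9, 13], max_len=128))
```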



In [None]:


# ============================================================================
# SECTION 3: DATA VALIDATION & QUALITY CHECKS
# ============================================================================

print("\n" + "="*80)
print("DATA QUALITY CHECKS")
print("="*80)

# Check for data leakage between splits
train_queries = set(train_df['query'].tolist())
val_queries = set(val_df['query'].tolist())
test_queries = set(test_df['query'].tolist())

train_val_overlap = train_queries.intersection(val_queries)
train_test_overlap = train_queries.intersection(test_queries)
val_test_overlap = val_queries.intersection(test_queries)

print(f"\nData leakage check:")
print(f"  Train-Val overlap:  {len(train_val_overlap)} queries")
print(f"  Train-Test overlap: {len(train_test_overlap)} queries")
print(f"  Val-Test overlap:   {len(val_test_overlap)} queries")

all_overlaps = train_val_overlap | train_test_overlap | val_test_overlap
if all_overlaps:
    print("\nWARNING: Data leakage detected!")
    print(f"  Sample overlapping queries: {list(all_overlaps)[:3]}")
else:
    print("\nNo data leakage detected")
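
The exact-match check above misses near-duplicates that differ only in case, whitespace, or edge punctuation (e.g. "Find emails from Sarah" vs. "find emails from  sarah?"). A stricter variant, sketched below under the assumption that such variants should count as leakage, normalizes queries before intersecting and also flags queries that occur with both labels. `normalize_query`, `normalized_overlap`, and `find_conflicts` are hypothetical helpers; here you would pass e.g. `train_df['query'].tolist()` and `zip(df['query'], df['label'])`.

```python
import re
from collections import defaultdict


def normalize_query(text: str) -> str:
    """Lowercase, collapse whitespace, and strip edge spaces/punctuation."""
    text = re.sub(r"\s+", " ", text.lower()).strip()
    return text.strip(" .?!")


def normalized_overlap(a: list[str], b: list[str]) -> set[str]:
    """Queries shared by two splits after normalization."""
    return {normalize_query(q) for q in a} & {normalize_query(q) for q in b}


def find_conflicts(rows) -> dict[str, set[int]]:
    """Normalized queries that appear with more than one label."""
    labels_seen = defaultdict(set)
    for query, label in rows:
        labels_seen[normalize_query(query)].add(int(label))
    return {q: ls for q, ls in labels_seen.items() if len(ls) > 1}


train_q = ["Find emails from Sarah", "Show my meetings today"]
test_q = ["find emails from  sarah?", "Book a room for Friday"]
print(normalized_overlap(train_q, test_q))
print(find_conflicts([("Schedule call", 1), ("schedule call.", 0), ("Open inbox", 0)]))
```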



In [None]:

# ============================================================================
# SECTION 4: TOKENIZATION & DATASET PREPARATION
# ============================================================================

print("\n" + "="*80)
print("PREPARING DATASETS")
print("="*80)

from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"\nTokenizer loaded: {MODEL_NAME}")

# Custom Dataset class
class QueryDataset(Dataset):
    """PyTorch Dataset for query classification."""
    
    def __init__(self, df, tokenizer, max_len):
        self.texts = df["query"].tolist()
        self.labels = df["label"].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = int(self.labels[idx])
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Create datasets
train_ds = QueryDataset(train_df, tokenizer, MAX_LEN)
val_ds = QueryDataset(val_df, tokenizer, MAX_LEN)
test_ds = QueryDataset(test_df, tokenizer, MAX_LEN)

# Create dataloaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

print(f"\nDatasets created:")
print(f"  Train: {len(train_ds):,} samples, {len(train_loader)} batches")
print(f"  Val:   {len(val_ds):,} samples, {len(val_loader)} batches")
print(f"  Test:  {len(test_ds):,} samples, {len(test_loader)} batches")

# Demo tokenization
sample_text = train_df.iloc[0]['query']
sample_encoding = tokenizer(sample_text, truncation=True, padding='max_length', max_length=MAX_LEN)
print(f"\nSample tokenization:")
print(f"  Text: '{sample_text}'")
print(f"  Sequence length (token IDs): {len(sample_encoding['input_ids'])}")
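
`QueryDataset` pads every query to `MAX_LEN = 128` tokens (`padding='max_length'`), so short queries spend most of each forward pass on padding. A common alternative is dynamic padding: pad each batch only to its longest member and build the attention mask to match. The stdlib sketch below shows the idea on raw token-ID lists; it assumes pad ID 0, which is `[PAD]` in the `distilbert-base-uncased` vocabulary (check `tokenizer.pad_token_id`). `pad_batch` is hypothetical; in practice `transformers.DataCollatorWithPadding(tokenizer=tokenizer)` passed as `collate_fn` to `DataLoader` does this, provided the dataset returns unpadded encodings.

```python
def pad_batch(batch_ids: list[list[int]], pad_id: int = 0) -> dict[str, list[list[int]]]:
    """Pad token-ID lists to the longest sequence in the batch (dynamic padding)."""
    longest = max(len(ids) for ids in batch_ids)
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        n_pad = longest - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 = real token the model should attend to, 0 = padding
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Toy IDs shaped like [CLS] ... [SEP] (101/102 in the BERT uncased vocab)
padded = pad_batch([[101, 2424, 102], [101, 2424, 6959, 2013, 102]])
print(padded["input_ids"])       # both rows now have length 5
print(padded["attention_mask"])
```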


In [None]:

# ============================================================================
# SECTION 5: MODEL INITIALIZATION
# ============================================================================

print("\n" + "="*80)
print("MODEL INITIALIZATION")
print("="*80)

from transformers import AutoModelForSequenceClassification

# Load or initialize model
num_labels = 2  # Binary classification
id2label = {0: "Gmail", 1: "Calendar"}
label2id = {name: idx for idx, name in id2label.items()}
if MODEL_DIR.exists() and any(MODEL_DIR.iterdir()):
    print(f"\nLoading model from local directory {MODEL_DIR}")
    source = MODEL_DIR
else:
    print(f"\nInitializing new model: {MODEL_NAME}")
    source = MODEL_NAME
model = AutoModelForSequenceClassification.from_pretrained(
    source,
    num_labels=num_labels,
    id2label=id2label,   # stored in config.json so saved checkpoints carry label names
    label2id=label2id,
)

model.to(device)
print(f" Model loaded to {device}")

# Model architecture summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")


In [None]:
# ============================================================================
# SECTION 6: TRAINING LOOP
# ============================================================================

print("\n" + "="*80)
print("TRAINING")
print("="*80)

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score

# Initialize optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

print(f"\nTraining configuration:")
print(f"  Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"  Scheduler: Linear decay (no warmup)")
print(f"  Total steps: {total_steps:,}")

best_val_acc = -1.0

for epoch in range(1, EPOCHS + 1):
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch}/{EPOCHS}")
    print(f"{'='*60}")
    
    # Training phase
    model.train()
    total_loss = 0.0
    train_preds = []
    train_labels = []
    
    train_pbar = tqdm(train_loader, desc=f"Training")
    for batch in train_pbar:
        optimizer.zero_grad()
        
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        
        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        # Track metrics
        total_loss += loss.item()
        preds = torch.argmax(outputs.logits.detach(), dim=-1)
        train_preds.extend(preds.cpu().tolist())
        train_labels.extend(labels.cpu().tolist())
        
        train_pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    # Training metrics
    train_loss = total_loss / len(train_loader)
    train_acc = accuracy_score(train_labels, train_preds)
    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    
    # Validation phase
    model.eval()
    val_preds = []
    val_labels = []
    
    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc="Validation")
        for batch in val_pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=-1)
            
            val_preds.extend(preds.cpu().tolist())
            val_labels.extend(labels.cpu().tolist())
    
    # Validation metrics
    val_acc = accuracy_score(val_labels, val_preds)
    print(f"Val Acc: {val_acc:.4f}")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        model.save_pretrained(MODEL_SAVE)
        tokenizer.save_pretrained(MODEL_SAVE)
        print(f"Saved best checkpoint (val_acc={val_acc:.4f}) to {MODEL_SAVE}")

print(f"\n{'='*80}")
print(f"Training complete! Best validation accuracy: {best_val_acc:.4f}")
print(f"{'='*80}")


In [None]:
# ============================================================================
# SECTION 7: EVALUATION ON TEST SET
# ============================================================================

print("\n" + "="*80)
print("TEST SET EVALUATION")
print("="*80)

from sklearn.metrics import (
    accuracy_score, 
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)

# Load the best checkpoint saved during training
model = AutoModelForSequenceClassification.from_pretrained(MODEL_SAVE)
model.to(device)
model.eval()

# Run predictions on test set
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=-1)
        
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

# Calculate metrics
acc = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average='macro'
)

print(f"\n{'='*60}")
print(f"TEST RESULTS")
print(f"{'='*60}")
print(f"Accuracy:  {acc:.4f} ({acc*100:.2f}%)")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"{'='*60}")


# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.title('Confusion Matrix (Counts)')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks([0.5, 1.5], ['Gmail', 'Calendar'])
plt.yticks([0.5, 1.5], ['Gmail', 'Calendar'], rotation=0)

plt.subplot(1, 2, 2)
cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Greens', cbar=False)
plt.title('Confusion Matrix (Normalized)')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.xticks([0.5, 1.5], ['Gmail', 'Calendar'])
plt.yticks([0.5, 1.5], ['Gmail', 'Calendar'], rotation=0)

plt.tight_layout()
plt.show()
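
The test report above uses macro-averaged precision, recall, and F1 (`average='macro'`), whereas scikit-learn's `precision_score`, `recall_score`, and `f1_score` default to `average='binary'`, which scores only the positive class (label 1, Calendar). The two can differ noticeably when one class is harder. The sketch below derives both from a 2×2 matrix laid out like scikit-learn's `confusion_matrix` output (rows = true, columns = predicted); `prf` and `binary_and_macro` are hypothetical helpers and the counts are made up.

```python
def prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Precision, recall, F1 for one class (0.0 when a denominator is zero)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def binary_and_macro(cm: list[list[int]]) -> dict:
    """cm[i][j] = count of true class i predicted as class j (2 classes)."""
    (tn, fp), (fn, tp) = cm
    calendar = prf(tp, fp, fn)   # class 1 treated as positive ("binary")
    gmail = prf(tn, fn, fp)      # class 0 treated as positive
    # Macro = unweighted mean of the per-class scores (F1 averaged, not recomputed)
    macro = tuple((a + b) / 2 for a, b in zip(gmail, calendar))
    return {"binary": calendar, "macro": macro}


# Made-up counts: 90 Gmail correct, 10 Gmail->Calendar, 5 Calendar->Gmail, 45 Calendar correct
print(binary_and_macro([[90, 10], [5, 45]]))
```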


In [None]:
# ============================================================================
# SECTION 8: ERROR ANALYSIS
# ============================================================================

print("\n" + "="*80)
print("ERROR ANALYSIS")
print("="*80)

# Create results dataframe
test_results = pd.DataFrame({
    'query': test_df['query'].tolist(),
    'true_label': test_df['label'].tolist(),
    'pred_label': all_preds
})
test_results['correct'] = test_results['true_label'] == test_results['pred_label']

correct_count = test_results['correct'].sum()
wrong_count = (~test_results['correct']).sum()

print(f"\nPrediction summary:")
print(f"  Correct:        {correct_count} / {len(test_results)}")
print(f"  Misclassified:  {wrong_count}")

# Show misclassified examples
if wrong_count > 0:
    print(f"\nMISCLASSIFIED QUERIES ({wrong_count}):")
    print("-" * 80)
    for _, row in test_results[~test_results['correct']].iterrows():
        true_label = "Gmail" if row['true_label'] == 0 else "Calendar"
        pred_label = "Gmail" if row['pred_label'] == 0 else "Calendar"
        print(f"  True: {true_label:8} | Pred: {pred_label:8} | {row['query']}")
else:
    print("\nPerfect classification! No errors on the test set.")
    print("\nSample predictions:")
    print("-" * 80)
    n_show = min(10, len(test_results))
    for _, row in test_results.sample(n_show, random_state=SEED).iterrows():
        label_name = "Gmail" if row['true_label'] == 0 else "Calendar"
        print(f"  {label_name:8} | {row['query']}")
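
When there are misclassifications, it helps to know which direction they go (Gmail queries predicted as Calendar, or the reverse), since each direction usually points at a different kind of ambiguous wording. A minimal sketch using `collections.Counter` over (true, predicted) pairs; with this notebook's variables you would call `error_directions(test_results['true_label'], test_results['pred_label'])`. `error_directions` is a hypothetical helper.

```python
from collections import Counter

LABEL_NAMES = {0: "Gmail", 1: "Calendar"}


def error_directions(true_labels, pred_labels) -> Counter:
    """Count misclassifications by (true -> predicted) direction."""
    return Counter(
        f"{LABEL_NAMES[int(t)]} -> {LABEL_NAMES[int(p)]}"
        for t, p in zip(true_labels, pred_labels)
        if t != p
    )


print(error_directions([0, 0, 1, 1, 0], [1, 0, 0, 1, 1]))
```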



In [None]:
# ============================================================================
# SECTION 9: HELPER FUNCTIONS
# ============================================================================

# Month-name lookup built once: "january"/"jan" -> 1, ..., "december"/"dec" -> 12
MONTHS = {m.lower(): i for i, m in enumerate(calendar.month_name) if m}
MONTHS.update({m.lower(): i for i, m in enumerate(calendar.month_abbr) if m})
# Regex alternation of month names, longest first so "june" is preferred over "jun"
MONTH_RE = "|".join(sorted(MONTHS, key=len, reverse=True))


def extract_time_range(query: str) -> dict:
    """
    Extract time range from calendar queries.
    Handles: "June 2025", "May 1 to May 31", "May", "next Tuesday", etc.
    
    Returns:
        dict with 'from' and 'to' keys (ISO format), or None if no time found
    """
    query_lower = query.lower()
    current_year = datetime.now().year
    
    # Pattern 1: "Month YYYY" (e.g., "June 2025") -> the whole month
    match = re.search(rf'\b({MONTH_RE})\s+(\d{{4}})\b', query_lower)
    if match:
        month_num = MONTHS[match.group(1)]
        year = int(match.group(2))
        last_day = calendar.monthrange(year, month_num)[1]
        return {
            "from": date(year, month_num, 1).isoformat(),
            "to": date(year, month_num, last_day).isoformat()
        }
    
    # Pattern 2: date range like "May 1 to May 31" -> current year.
    # Checked before the month-only pattern, which would otherwise match "May" first.
    pattern_date_range = rf'\b({MONTH_RE})\s+(\d{{1,2}})\s*(?:to|-)\s*({MONTH_RE})\s+(\d{{1,2}})\b'
    match = re.search(pattern_date_range, query_lower)
    if match:
        month1, day1, month2, day2 = match.groups()
        try:
            return {
                "from": date(current_year, MONTHS[month1], int(day1)).isoformat(),
                "to": date(current_year, MONTHS[month2], int(day2)).isoformat()
            }
        except ValueError:
            pass  # impossible date such as "Feb 30"; try the remaining patterns
    
    # Pattern 3: "Month" alone (e.g., "May", "June") -> assume current year
    match = re.search(rf'\b({MONTH_RE})\b', query_lower)
    if match:
        month_num = MONTHS[match.group(1)]
        last_day = calendar.monthrange(current_year, month_num)[1]
        return {
            "from": date(current_year, month_num, 1).isoformat(),
            "to": date(current_year, month_num, last_day).isoformat()
        }
    
    # Pattern 4: "next/this <weekday>" (e.g., "next Tuesday")
    # Placeholder: returns today's date; the exact day still needs to be resolved.
    weekdays = ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']
    for day in weekdays:
        if f'next {day}' in query_lower or f'this {day}' in query_lower:
            today = datetime.now().date()
            return {
                "from": today.isoformat(),
                "to": today.isoformat(),
                "note": "Relative day detected. Compute exact date in production."
            }
    
    # Pattern 5: fall back to dateparser if it is installed
    try:
        from dateparser.search import search_dates
        results = search_dates(query)
        if results:
            iso_date = results[0][1].date().isoformat()
            return {"from": iso_date, "to": iso_date}
    except Exception:
        pass  # dateparser missing or unable to parse the query
    
    return None


def extract_people(query: str) -> list:
    """
    Extract people names from queries using rule-based patterns.
    Titled names are matched first so that partial matches of the same
    person (e.g. "Dr" or "Johnson" on their own) can be dropped as duplicates.
    
    Returns:
        list of person names found in the query
    """
    found = []
    
    # Pattern-based extraction (order matters - most specific first!)
    patterns = [
        # Titles with names (MUST come first so dedup can drop partial matches)
        (r'\b(Dr\.|Mr\.|Ms\.|Mrs\.)\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)', 2),
        # "from John" or "from John Doe"
        (r'\bfrom\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)', 1),
        # "with Sarah"
        (r'\bwith\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)', 1),
        # "to Alice"
        (r'\bto\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)', 1),
        # "Sarah about"
        (r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)?)\s+about', 1),
    ]
    
    for pattern, group_idx in patterns:
        for match in re.finditer(pattern, query):
            if group_idx == 2:
                # For titled names, combine title + name
                name = f"{match.group(1)} {match.group(2)}".strip()
            else:
                # For other patterns, just use the name
                name = match.group(1).strip()
            found.append(name)
    
    # Deduplicate while preserving order. A candidate is skipped if it appears as
    # whole words inside a name already kept: exact repeats ("Sarah" from both
    # "from Sarah" and "Sarah about") and partial matches of titled names
    # ("Dr" from "with Dr. Johnson", "Johnson" from "Dr. Johnson about ...").
    result = []
    for name in found:
        key = re.escape(name.lower())
        if any(re.search(rf'\b{key}\b', kept.lower()) for kept in result):
            continue
        result.append(name)
    
    return result
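
The "next/this <weekday>" branch of `extract_time_range` returns today's date and leaves the exact day to production. One way to resolve it is with `date.weekday()` arithmetic, sketched below. The semantics are an assumption, since "next Tuesday" is ambiguous in English: here "this <day>" means the nearest occurrence on or after today, and "next <day>" the nearest occurrence strictly after today. `resolve_weekday` is a hypothetical helper; a fixed `today` is passed in so results are reproducible.

```python
from datetime import date, timedelta

WEEKDAYS = ["monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"]


def resolve_weekday(modifier: str, day_name: str, today: date) -> date:
    """Map 'this'/'next' + weekday name to a concrete date (assumed semantics)."""
    target = WEEKDAYS.index(day_name.lower())    # Monday = 0, as in date.weekday()
    days_ahead = (target - today.weekday()) % 7  # 0..6, where 0 means "today"
    if modifier == "next" and days_ahead == 0:
        days_ahead = 7                           # "next" never resolves to today
    return today + timedelta(days=days_ahead)


ref = date(2025, 6, 4)  # a Wednesday
print(resolve_weekday("next", "Tuesday", ref))    # 2025-06-10
print(resolve_weekday("this", "Wednesday", ref))  # 2025-06-04
print(resolve_weekday("next", "Wednesday", ref))  # 2025-06-11
```

The resolved date can then fill both `"from"` and `"to"` in place of today's date.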



In [None]:
# ============================================================================
# SECTION 10: STRESS TEST ON A LARGER DATASET
# ============================================================================

import pandas as pd
from tqdm import tqdm
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the large test dataset
large_test_df = pd.read_csv("../data/set2/email_calendar_dataset_v2.csv")
print(f"Large test dataset size: {len(large_test_df):,}")
print(f"Label distribution:")
print(f"  Gmail (0): {(large_test_df['label'] == 0).sum():,}")
print(f"  Calendar (1): {(large_test_df['label'] == 1).sum():,}")
MAX_LEN = 128  # same token cap as training

# Batch prediction on large dataset
def batch_predict_large(queries, model, tokenizer, batch_size=128):
    """Predict in batches for efficiency"""
    model.eval()
    all_preds = []
    
    for i in tqdm(range(0, len(queries), batch_size), desc="Predicting"):
        batch = queries[i:i+batch_size]
        enc = tokenizer(batch, padding=True, truncation=True, max_length=MAX_LEN, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        
        with torch.no_grad():
            outputs = model(**enc)
            preds = torch.argmax(outputs.logits, dim=1).cpu().tolist()
            all_preds.extend(preds)
    
    return all_preds

# Run predictions
print("\nRunning predictions on the larger dataset of unseen queries...")
queries = large_test_df['query'].tolist()
labels = large_test_df['label'].tolist()
MODEL_DIR = "../models/trained/distilbert_trained/"  # fine-tuned checkpoint saved in Section 6
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)

# Set device (same MPS -> CUDA -> CPU order as Section 1)
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else ("cuda" if torch.cuda.is_available() else "cpu")
)
model.to(device)
model.eval()
preds = batch_predict_large(queries, model, tokenizer, batch_size=128)

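# Calculate metrics (macro-averaged, matching the Section 7 test-set report)
accuracy = accuracy_score(labels, preds)
precision = precision_score(labels, preds, average='macro')
recall = recall_score(labels, preds, average='macro')
f1 = f1_score(labels, preds, average='macro')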

print(f"\n" + "="*60)
print(f"STRESS TEST RESULTS (on {len(large_test_df):,} queries)")
print(f"="*60)
print(f"Accuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"="*60)


In [None]:
# ============================================================================
# SECTION 11: TESTING THE MODEL WITH QUERIES [SUBMISSION CRITERIA]
# ============================================================================


@lru_cache(maxsize=1)
def _get_inference_model():
    """Load model and tokenizer once (cached for efficiency)"""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.to(device)
    model.eval()  # inference mode (disables dropout)
    return model, tokenizer

def predict_class(query: str) -> int:
    """
    Classify a query as Gmail (0) or Calendar (1)
    
    Args:
        query: User's query string
        
    Returns:
        int: 0 for Gmail, 1 for Calendar
    """
    model, tokenizer = _get_inference_model()  # loaded once, then served from the cache
    
    # Tokenize
    encoding = tokenizer(
        query,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=MAX_LEN
    )
    encoding = {k: v.to(device) for k, v in encoding.items()}
    
    # Predict
    with torch.no_grad():
        outputs = model(**encoding)
        logits = outputs.logits
        prediction = torch.argmax(logits, dim=-1).item()
    
    return int(prediction)

def analyze_query(query: str) -> dict:
    """
    End-to-end query analysis with all features.
    
    Returns:
        {
            "query": str,
            "label_id": int (0=Gmail, 1=Calendar),
            "label_name": str,
            "time_range": dict or None,
            "people": list,
            "summary": str (human-readable)
        }
    """
    # 1) Classify
    label_id = predict_class(query)
    label_name = "Gmail" if label_id == 0 else "Calendar"
    
    # 2) Extract time range (only for Calendar queries)
    time_range = extract_time_range(query) if label_id == 1 else None
    
    # 3) Extract people
    people = extract_people(query)
    
    # 4) Build human-readable summary
    summary = f"{label_name} query"
    if people:
        summary += f" | People: {', '.join(people)}"
    if time_range:
        summary += f" | Time: {time_range['from']} to {time_range['to']}"
    
    return {
        "query": query,
        "label_id": label_id,
        "label_name": label_name,
        "time_range": time_range,
        "people": people,
        "summary": summary
    }

# Test unified analysis
unified_test_queries = [
    "Find emails with PDF attachments",
    "Show me all events scheduled for next Tuesday",
    "Search for unread messages in my inbox",
    "When is my next meeting with the design team?",
    "Find messages with subject line 'quarterly report'",
    "Show me all-day events in May",
    "Find emails from Sarah about project proposal",
    "Find appointments with Dr. Johnson",
    "Find my meetings for June 2025",
]

print("\nTesting analyze_query() unified function:\n")

for q in unified_test_queries:
    result = analyze_query(q)
    print(f"Query: {q}")
    print(f"  Label:      [{result['label_id']}] {result['label_name']}")
    print(f"  Time range: {result['time_range']}")
    print(f"  People:     {result['people']}")
    print()
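
`analyze_query` returns a label but no confidence score. If you want one, a common choice is the softmax probability of the predicted class, computed from `outputs.logits` (in PyTorch, `torch.softmax(logits, dim=-1)`). The stdlib sketch below shows the arithmetic on a plain list of two logits; `softmax` and `predict_with_confidence` are hypothetical helpers. Softmax scores from a fine-tuned classifier can be overconfident, so treat them as a ranking signal rather than a calibrated probability.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_with_confidence(logits: list[float]) -> tuple[int, float]:
    """Return (argmax label, probability of that label)."""
    probs = softmax(logits)
    label = max(range(len(probs)), key=probs.__getitem__)
    return label, probs[label]


# Hypothetical logits for [Gmail, Calendar]
label, conf = predict_with_confidence([-1.2, 2.3])
print(label, round(conf, 3))
```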


In [None]:
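import torch
from functools import lru_cache
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Device setup (MPS -> CUDA -> CPU)
device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))

# Path to the fine-tuned checkpoint saved in Section 6 (update if you saved it elsewhere)
MODEL_DIR = "../models/trained/distilbert_trained/"
MAX_LEN = 128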

@lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load model and tokenizer once (cached for efficiency)"""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.to(device)
    model.eval()
    return tokenizer, model

def predict_class(test_query: str) -> int:
    """
    Classify a query as Gmail (0) or Calendar (1)
    
    Args:
        test_query: User's query string
        
    Returns:
        int: 0 for Gmail, 1 for Calendar
    """
    tokenizer, model = _load_model_and_tokenizer()
    
    # Tokenize
    encoding = tokenizer(
        test_query,
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=MAX_LEN
    )
    
    # Move to device
    encoding = {k: v.to(device) for k, v in encoding.items()}
    
    # Predict
    with torch.no_grad():
        outputs = model(**encoding)
        logits = outputs.logits
        prediction = torch.argmax(logits, dim=-1).item()
    
    return int(prediction)

# Test it
if __name__ == "__main__":
    print(predict_class("Find emails with PDF attachments"))  
    print(predict_class("Show me all events scheduled for next Tuesday"))  
