# M3 RoBERTa + GRU Implementation
## One-Turn-Ahead Frustration Prediction

**Goal**: Beat M2's Macro-F1 (0.7396) while achieving production latency (≤15ms)

**Architecture**: RoBERTa embeddings → GRU temporal modeling → Classification


In [31]:
# Setup and imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaModel, RobertaTokenizer
import json
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report, f1_score, roc_auc_score
from tqdm import tqdm
import time
import os
import warnings
warnings.filterwarnings('ignore')

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [32]:
# Model Configuration
CONFIG = {
    'model_name': 'roberta-base',
    'max_length': 512,
    'context_window': 3,
    'gru_hidden_size': 128,
    'gru_num_layers': 1,
    'dropout': 0.1,
    'batch_size': 16,
    'learning_rate': 2e-5,
    'epochs': 3,
    'weight_decay': 0.01,
    'class_weight_ratio': 13.7,
    'patience': 2
}

print("M3 Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")


M3 Configuration:
  model_name: roberta-base
  max_length: 512
  context_window: 3
  gru_hidden_size: 128
  gru_num_layers: 1
  dropout: 0.1
  batch_size: 16
  learning_rate: 2e-05
  epochs: 3
  weight_decay: 0.01
  class_weight_ratio: 13.7
  patience: 2


In [33]:
# M3 RoBERTa + GRU Model Architecture
class RobertaGRU(nn.Module):
    def __init__(self, config):
        super(RobertaGRU, self).__init__()
        self.config = config
        
        # RoBERTa for turn embeddings
        self.roberta = RobertaModel.from_pretrained(config['model_name'])
        self.roberta_hidden_size = self.roberta.config.hidden_size
        
        # GRU for temporal modeling
        self.gru = nn.GRU(
            input_size=self.roberta_hidden_size,
            hidden_size=config['gru_hidden_size'],
            num_layers=config['gru_num_layers'],
            batch_first=True,
            dropout=config['dropout'] if config['gru_num_layers'] > 1 else 0
        )
        
        # Classification head
        self.dropout = nn.Dropout(config['dropout'])
        self.classifier = nn.Linear(config['gru_hidden_size'], 1)
        
    def forward(self, input_ids, attention_mask):
        batch_size, seq_len, max_len = input_ids.shape
        
        # Reshape for RoBERTa processing
        input_ids = input_ids.view(-1, max_len)
        attention_mask = attention_mask.view(-1, max_len)
        
        # Get RoBERTa embeddings for each turn
        roberta_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        
        # Use pooler output (CLS token representation)
        turn_embeddings = roberta_outputs.pooler_output
        
        # Reshape back to sequence format
        turn_embeddings = turn_embeddings.view(batch_size, seq_len, self.roberta_hidden_size)
        
        # Process through GRU
        gru_output, _ = self.gru(turn_embeddings)
        
        # Use final hidden state
        final_hidden = gru_output[:, -1, :]
        
        # Classification
        output = self.dropout(final_hidden)
        logits = self.classifier(output)
        
        return logits

print("M3 RoBERTa + GRU model architecture defined")


M3 RoBERTa + GRU model architecture defined


In [34]:
# Dataset class for M3 temporal data
class EmoWOZTemporalDataset(Dataset):
    def __init__(self, data_path, tokenizer, config):
        self.tokenizer = tokenizer
        self.max_length = config['max_length']
        self.context_window = config['context_window']
        
        # Load data with error handling for corrupted lines
        self.data = []
        skipped_lines = 0
        
        with open(data_path, 'r') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                try:
                    data_item = json.loads(line)
                    self.data.append(data_item)
                except json.JSONDecodeError as e:
                    print(f"Warning: Skipping invalid JSON at line {line_num} in {data_path}: {e}")
                    skipped_lines += 1
                    continue
        
        print(f"Loaded {len(self.data)} samples from {data_path}")
        if skipped_lines > 0:
            print(f"Skipped {skipped_lines} invalid lines")
        
    def __len__(self):
        return len(self.data)
    
    def parse_context_string(self, context_str):
        """Parse context string into individual turns"""
        import re
        
        # Split by USER and SYSTEM tags
        turns = []
        
        # Find all [USER] and [SYSTEM] tags and their positions
        pattern = r'\[(USER|SYSTEM)\]'
        matches = list(re.finditer(pattern, context_str))
        
        for i, match in enumerate(matches):
            speaker = match.group(1)
            start_pos = match.end()
            
            # Find the end position (start of next tag or end of string)
            if i + 1 < len(matches):
                end_pos = matches[i + 1].start()
                text = context_str[start_pos:end_pos].strip()
            else:
                text = context_str[start_pos:].strip()
            
            if text:  # Only add non-empty turns
                turns.append({
                    'speaker': speaker,
                    'text': text
                })
        
        return turns
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Parse context string into turns
        context_str = item['context']
        context_turns = self.parse_context_string(context_str)
        
        # Limit to context window (take last N turns)
        if len(context_turns) > self.context_window:
            context_turns = context_turns[-self.context_window:]
        
        # Tokenize each turn separately
        turn_tokens = []
        turn_masks = []
        
        for turn in context_turns:
            text = turn['text']
            encoded = self.tokenizer(
                text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            
            turn_tokens.append(encoded['input_ids'].squeeze(0))
            turn_masks.append(encoded['attention_mask'].squeeze(0))
        
        # Pad sequence if needed
        while len(turn_tokens) < self.context_window:
            # Add padding turn
            padding_tokens = torch.zeros(self.max_length, dtype=torch.long)
            padding_mask = torch.zeros(self.max_length, dtype=torch.long)
            turn_tokens.append(padding_tokens)
            turn_masks.append(padding_mask)
        
        # Stack into tensors
        input_ids = torch.stack(turn_tokens)
        attention_mask = torch.stack(turn_masks)
        
        # Label
        label = torch.tensor(item['label'], dtype=torch.float)
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'label': label
        }

print("Temporal dataset class defined")


Temporal dataset class defined


In [35]:
# Load tokenizer and create datasets
tokenizer = RobertaTokenizer.from_pretrained(CONFIG['model_name'])

# Create datasets
train_dataset = EmoWOZTemporalDataset('../data/train.jsonl', tokenizer, CONFIG)
val_dataset = EmoWOZTemporalDataset('../data/val.jsonl', tokenizer, CONFIG)
test_dataset = EmoWOZTemporalDataset('../data/test.jsonl', tokenizer, CONFIG)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")


Loaded 25738 samples from ../data/train.jsonl
Skipped 1 invalid lines
Loaded 7409 samples from ../data/val.jsonl
Loaded 7534 samples from ../data/test.jsonl
Train batches: 1609
Validation batches: 464
Test batches: 471


In [36]:
# Initialize model and training components
model = RobertaGRU(CONFIG).to(device)

# Loss function with class weights
pos_weight = torch.tensor(CONFIG['class_weight_ratio']).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])

# Model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024**2:.1f} MB")


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total parameters: 124,990,593
Trainable parameters: 124,990,593
Model size: 476.8 MB


In [37]:
# Training and evaluation functions
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch in tqdm(train_loader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        
        outputs = model(input_ids, attention_mask)
        # Squeeze only the last dimension, preserve batch dimension
        outputs = outputs.squeeze(-1)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(train_loader)

def evaluate(model, eval_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(eval_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask)
            # Squeeze only the last dimension, preserve batch dimension
            outputs = outputs.squeeze(-1)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            
            # Get predictions
            preds = torch.sigmoid(outputs).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    avg_loss = total_loss / len(eval_loader)
    preds_binary = (np.array(all_preds) > 0.5).astype(int)
    
    macro_f1 = f1_score(all_labels, preds_binary, average='macro')
    auc = roc_auc_score(all_labels, all_preds)
    
    return avg_loss, macro_f1, auc, np.array(all_labels), np.array(all_preds), preds_binary

print("Training and evaluation functions defined")


Training and evaluation functions defined


In [38]:
# Training loop
print("Starting M3 training...")
print("="*50)

best_macro_f1 = 0
patience_counter = 0
training_history = []

start_time = time.time()

for epoch in range(CONFIG['epochs']):
    print(f"\nEpoch {epoch + 1}/{CONFIG['epochs']}")
    
    # Training
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validation
    val_loss, val_macro_f1, val_auc, _, _, _ = evaluate(model, val_loader, criterion, device)
    
    # Log results
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val Macro-F1: {val_macro_f1:.4f}")
    print(f"Val AUC: {val_auc:.4f}")
    
    # Save training history
    training_history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_macro_f1': val_macro_f1,
        'val_auc': val_auc
    })
    
    # Early stopping
    if val_macro_f1 > best_macro_f1:
        best_macro_f1 = val_macro_f1
        patience_counter = 0
        # Save best model
        os.makedirs('../checkpoints/M3_roberta_gru', exist_ok=True)
        torch.save(model.state_dict(), '../checkpoints/M3_roberta_gru/best_model.pt')
        print(f"✅ New best model saved! Macro-F1: {best_macro_f1:.4f}")
    else:
        patience_counter += 1
        print(f"⏸️ No improvement. Patience: {patience_counter}/{CONFIG['patience']}")
    
    if patience_counter >= CONFIG['patience']:
        print(f"Early stopping triggered after {epoch + 1} epochs")
        break

training_time = time.time() - start_time
print(f"\n🎉 Training completed in {training_time:.1f} seconds ({training_time/60:.1f} minutes)")
print(f"Best validation Macro-F1: {best_macro_f1:.4f}")


Starting M3 training...

Epoch 1/3


Training: 100%|██████████| 1609/1609 [15:31<00:00,  1.73it/s]
Evaluating: 100%|██████████| 464/464 [01:33<00:00,  4.97it/s]


Train Loss: 0.8781
Val Loss: 0.7599
Val Macro-F1: 0.7180
Val AUC: 0.8837
✅ New best model saved! Macro-F1: 0.7180

Epoch 2/3


Training: 100%|██████████| 1609/1609 [15:32<00:00,  1.73it/s]
Evaluating: 100%|██████████| 464/464 [01:33<00:00,  4.98it/s]


Train Loss: 0.7995
Val Loss: 0.7650
Val Macro-F1: 0.7212
Val AUC: 0.8858
✅ New best model saved! Macro-F1: 0.7212

Epoch 3/3


Training: 100%|██████████| 1609/1609 [15:32<00:00,  1.73it/s]
Evaluating: 100%|██████████| 464/464 [01:33<00:00,  4.97it/s]


Train Loss: 0.7968
Val Loss: 0.7132
Val Macro-F1: 0.7227
Val AUC: 0.8882
✅ New best model saved! Macro-F1: 0.7227

🎉 Training completed in 3078.0 seconds (51.3 minutes)
Best validation Macro-F1: 0.7227


In [39]:
# Load best model and test evaluation
model.load_state_dict(torch.load('../checkpoints/M3_roberta_gru/best_model.pt', weights_only=True))
print("✅ Best model loaded for final evaluation")

# Final test evaluation
test_loss, test_macro_f1, test_auc, test_labels, test_probs, test_preds = evaluate(model, test_loader, criterion, device)

print("\n📊 FINAL M3 TEST RESULTS")
print("="*40)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Macro-F1: {test_macro_f1:.4f}")
print(f"Test AUC: {test_auc:.4f}")

# Detailed classification report
print("\nDetailed Classification Report:")
target_names = ['Not Frustrated', 'Will Be Frustrated']
print(classification_report(test_labels, test_preds, target_names=target_names, digits=4))


✅ Best model loaded for final evaluation


Evaluating: 100%|██████████| 471/471 [01:34<00:00,  4.96it/s]


📊 FINAL M3 TEST RESULTS
Test Loss: 0.8602
Test Macro-F1: 0.7408
Test AUC: 0.8768

Detailed Classification Report:
                    precision    recall  f1-score   support

    Not Frustrated     0.9807    0.9007    0.9390      6930
Will Be Frustrated     0.4115    0.7964    0.5426       604

          accuracy                         0.8924      7534
         macro avg     0.6961    0.8485    0.7408      7534
      weighted avg     0.9350    0.8924    0.9072      7534






In [40]:
# Latency benchmarking
print("\n⚡ LATENCY BENCHMARKING")
print("="*30)

model.eval()
latencies = []

# Warm-up
for i, batch in enumerate(test_loader):
    if i >= 5:  # Warm-up with 5 batches
        break
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    with torch.no_grad():
        _ = model(input_ids, attention_mask)

print("Warm-up completed, measuring latency...")

# Measure latency
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Latency test"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        for i in range(input_ids.shape[0]):  # Process each sample individually
            single_input = input_ids[i:i+1]
            single_mask = attention_mask[i:i+1]
            
            start_time = time.perf_counter()
            _ = model(single_input, single_mask)
            end_time = time.perf_counter()
            
            latencies.append((end_time - start_time) * 1000)  # Convert to milliseconds

# Calculate latency statistics
latencies = np.array(latencies)
avg_latency = np.mean(latencies)
median_latency = np.median(latencies)
p95_latency = np.percentile(latencies, 95)
p99_latency = np.percentile(latencies, 99)

print(f"Average Latency: {avg_latency:.2f}ms")
print(f"Median Latency: {median_latency:.2f}ms")
print(f"95th Percentile: {p95_latency:.2f}ms")
print(f"99th Percentile: {p99_latency:.2f}ms")
print(f"Throughput: {1000/avg_latency:.1f} samples/sec")

# Check latency target
latency_target = 15.0  # ms
if avg_latency <= latency_target:
    print(f"✅ LATENCY TARGET MET: {avg_latency:.2f}ms ≤ {latency_target}ms")
else:
    print(f"❌ LATENCY TARGET MISSED: {avg_latency:.2f}ms > {latency_target}ms")



⚡ LATENCY BENCHMARKING
Warm-up completed, measuring latency...


Latency test: 100%|██████████| 471/471 [01:55<00:00,  4.08it/s]

Average Latency: 11.57ms
Median Latency: 15.17ms
95th Percentile: 15.42ms
99th Percentile: 15.45ms
Throughput: 86.5 samples/sec
✅ LATENCY TARGET MET: 11.57ms ≤ 15.0ms





In [41]:
# Performance comparison with M1 and M2
print("\n📈 MODEL COMPARISON")
print("="*50)

# M1 and M2 results (from previous reports)
m1_results = {
    'macro_f1': 0.7156,
    'latency': 10.07,
    'accuracy': 0.9158
}

m2_results = {
    'macro_f1': 0.7396,
    'latency': 72.39,
    'accuracy': 0.8912
}

m3_results = {
    'macro_f1': test_macro_f1,
    'latency': avg_latency,
    'accuracy': np.mean(test_labels == test_preds)
}

# Create comparison table
comparison_df = pd.DataFrame({
    'Model': ['M1 BERT-CLS', 'M2 RoBERTa-CLS', 'M3 RoBERTa-GRU'],
    'Macro-F1': [m1_results['macro_f1'], m2_results['macro_f1'], m3_results['macro_f1']],
    'Latency (ms)': [m1_results['latency'], m2_results['latency'], m3_results['latency']],
    'Accuracy': [m1_results['accuracy'], m2_results['accuracy'], m3_results['accuracy']]
})

print(comparison_df.to_string(index=False, float_format='%.4f'))

# Check if M3 beats M2
if m3_results['macro_f1'] > m2_results['macro_f1']:
    improvement = m3_results['macro_f1'] - m2_results['macro_f1']
    print(f"\n🏆 M3 BEATS M2! Improvement: +{improvement:.4f} Macro-F1 ({improvement/m2_results['macro_f1']*100:.1f}%)")
else:
    decline = m2_results['macro_f1'] - m3_results['macro_f1']
    print(f"\n📉 M3 underperforms M2: -{decline:.4f} Macro-F1 ({decline/m2_results['macro_f1']*100:.1f}%)")

# Production readiness assessment
print("\n🚀 PRODUCTION READINESS ASSESSMENT")
print("="*40)
target_f1 = 0.30
target_latency = 15.0

f1_status = "✅ PASS" if m3_results['macro_f1'] >= target_f1 else "❌ FAIL"
latency_status = "✅ PASS" if m3_results['latency'] <= target_latency else "❌ FAIL"

print(f"Macro-F1 Target (≥{target_f1}): {m3_results['macro_f1']:.4f} {f1_status}")
print(f"Latency Target (≤{target_latency}ms): {m3_results['latency']:.2f}ms {latency_status}")

if m3_results['macro_f1'] >= target_f1 and m3_results['latency'] <= target_latency:
    print("\n🎉 M3 IS PRODUCTION READY!")
else:
    print("\n⚠️ M3 needs optimization for production deployment")



📈 MODEL COMPARISON
         Model  Macro-F1  Latency (ms)  Accuracy
   M1 BERT-CLS    0.7156       10.0700    0.9158
M2 RoBERTa-CLS    0.7396       72.3900    0.8912
M3 RoBERTa-GRU    0.7408       11.5653    0.8924

🏆 M3 BEATS M2! Improvement: +0.0012 Macro-F1 (0.2%)

🚀 PRODUCTION READINESS ASSESSMENT
Macro-F1 Target (≥0.3): 0.7408 ✅ PASS
Latency Target (≤15.0ms): 11.57ms ✅ PASS

🎉 M3 IS PRODUCTION READY!


In [42]:
# Save M3 results
os.makedirs('../results', exist_ok=True)

# Save detailed results
m3_detailed_results = {
    'model_name': 'M3_RoBERTa_GRU',
    'config': CONFIG,
    'training_history': training_history,
    'test_results': {
        'macro_f1': float(test_macro_f1),
        'auc': float(test_auc),
        'accuracy': float(np.mean(test_labels == test_preds)),
        'test_loss': float(test_loss)
    },
    'latency_results': {
        'avg_latency_ms': float(avg_latency),
        'median_latency_ms': float(median_latency),
        'p95_latency_ms': float(p95_latency),
        'p99_latency_ms': float(p99_latency),
        'throughput_samples_per_sec': float(1000/avg_latency)
    },
    'model_comparison': {
        'M1_BERT_CLS': m1_results,
        'M2_RoBERTa_CLS': m2_results,
        'M3_RoBERTa_GRU': m3_results
    },
    'training_time_seconds': training_time
}

with open('../results/M3_roberta_gru_results.json', 'w') as f:
    json.dump(m3_detailed_results, f, indent=2)

print("✅ M3 results saved to ../results/M3_roberta_gru_results.json")
print("✅ M3 model checkpoint saved to ../checkpoints/M3_roberta_gru/best_model.pt")
print("\n🎯 M3 IMPLEMENTATION COMPLETE!")


✅ M3 results saved to ../results/M3_roberta_gru_results.json
✅ M3 model checkpoint saved to ../checkpoints/M3_roberta_gru/best_model.pt

🎯 M3 IMPLEMENTATION COMPLETE!
