# Opposing Signals Analysis for Delayed Generalization

This notebook provides a framework for analyzing opposing gradient signals during training, inspired by Rosenfeld & Risteski (2023). It flags training examples whose per-example gradient points against the epoch's mean gradient (negative cosine similarity) and tracks them across epochs to study their role in delayed generalization.

## Key Objectives:
1. **Detect opposing signal examples** - Find training examples with gradients opposing the majority
2. **Track loss dynamics** - Monitor examples with significant loss changes over time
3. **Visualize gradient patterns** - Create interactive plots showing training dynamics
4. **Analyze delayed generalization** - Connect opposing signals to generalization patterns
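
The first objective can be illustrated on toy data before any model is involved. The sketch below is a minimal NumPy example under stated assumptions, not the tracker's implementation: it assumes the per-example gradients are already stacked as rows of a matrix, and it flags rows whose cosine similarity to the mean gradient falls below a threshold. The function name `flag_opposing` and the toy vectors are made up for illustration.

```python
import numpy as np

def flag_opposing(gradients: np.ndarray, threshold: float = -0.1) -> np.ndarray:
    """Return row indices whose cosine similarity to the mean gradient is below threshold."""
    mean_dir = gradients.mean(axis=0)
    mean_dir = mean_dir / (np.linalg.norm(mean_dir) + 1e-8)
    norms = np.linalg.norm(gradients, axis=1, keepdims=True) + 1e-8
    cosines = (gradients / norms) @ mean_dir
    return np.flatnonzero(cosines < threshold)

# Toy data: four gradients roughly along +x, one pointing the other way.
toy_gradients = np.array([
    [1.0, 0.1],
    [0.9, -0.2],
    [1.1, 0.0],
    [0.8, 0.3],
    [-1.0, 0.1],
])
opposing = flag_opposing(toy_gradients)
print(opposing)
```

Only the last row points against the mean direction here, so it is the only index flagged.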

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, deque
from typing import Dict, List, Tuple, Optional, Any
import warnings
warnings.filterwarnings('ignore')

# Interactive plotting
try:
    import plotly.graph_objects as go
    import plotly.express as px
    from plotly.subplots import make_subplots
    PLOTLY_AVAILABLE = True
except ImportError:
    PLOTLY_AVAILABLE = False
    print("Plotly not available - using matplotlib only")

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 1. Gradient Tracking Infrastructure

We implement a comprehensive gradient tracking system that monitors:
- Individual example gradients
- Gradient directions and magnitudes
- Loss trajectories for each example
- Opposing signal detection
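
The tracker in the next cell obtains per-example gradients by calling `backward` once per example. As a possible alternative, and assuming PyTorch 2.0 or later where `torch.func` is available, the same quantities can be computed in one vectorized call. The sketch below is illustrative only: `toy_model` and `example_loss` are hypothetical names, and it uses a logit output with `binary_cross_entropy_with_logits` rather than the `Sigmoid` plus `BCELoss` combination used later in this notebook.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical toy model; layers with batch-dependent state (e.g. BatchNorm) need extra care.
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
params = {name: p.detach() for name, p in toy_model.named_parameters()}

def example_loss(params, x, y):
    """Loss of a single example (x: shape (4,), y: scalar tensor)."""
    logit = functional_call(toy_model, params, (x.unsqueeze(0),)).squeeze()
    return nn.functional.binary_cross_entropy_with_logits(logit, y)

inputs = torch.randn(6, 4)
targets = torch.randint(0, 2, (6,)).float()

# vmap maps the single-example gradient over the batch dimension of inputs and targets.
grads = vmap(grad(example_loss), in_dims=(None, 0, 0))(params, inputs, targets)

# Flatten into one gradient vector per example: shape (batch, n_parameters).
per_example_grads = torch.cat(
    [g.reshape(inputs.size(0), -1) for g in grads.values()], dim=1
)
print(per_example_grads.shape)
```

The rows follow the same parameter order as `model.parameters()`, so they could be fed to the same cosine-similarity test the tracker uses.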

In [None]:
class GradientTracker:
    """Enhanced gradient tracker for opposing signals analysis"""
    
    def __init__(self, model: nn.Module, track_individual_examples: bool = True):
        self.model = model
        self.track_individual_examples = track_individual_examples
        
        # Storage for analysis
        self.gradient_history = defaultdict(list)  # epoch -> [gradients]
        self.loss_history = defaultdict(list)      # epoch -> [losses]
        self.example_gradients = defaultdict(lambda: defaultdict(list))  # example_id -> epoch -> gradient
        self.example_losses = defaultdict(lambda: defaultdict(float))    # example_id -> epoch -> loss
        
        # Opposing signal tracking
        self.opposing_examples = defaultdict(set)  # epoch -> {positions in gradient_history[epoch]}
        self.gradient_directions = defaultdict(list)  # epoch -> [direction_vectors] (currently unused)
        
        # Configuration
        self.opposing_threshold = 0.1  # Currently unused; detect_opposing_signals takes its own threshold
        self.loss_change_threshold = 0.5  # Relative loss change that counts as significant
        
        print(f"Initialized GradientTracker for model with {sum(p.numel() for p in model.parameters()):,} parameters")
    
    def track_batch(self, epoch: int, batch_idx: int, inputs: torch.Tensor, 
                   targets: torch.Tensor, individual_losses: torch.Tensor,
                   example_ids: Optional[List[int]] = None):
        """Track gradients for a single batch"""
        
        if example_ids is None:
            example_ids = list(range(batch_idx * inputs.size(0), (batch_idx + 1) * inputs.size(0)))
        
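        # Compute per-example gradients with one backward pass per example.
        # This is simple but costs O(batch_size) backward passes per batch.
        batch_gradients = []
        
        for example_id, loss in zip(example_ids, individual_losses):
            # Clear gradients left over from the previous example
            self.model.zero_grad()
            
            # retain_graph=True keeps the graph alive for the remaining examples
            # and for the caller's own backward pass on the batch loss
            loss.backward(retain_graph=True)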
            
            # Collect gradients
            example_grad = []
            for param in self.model.parameters():
                if param.grad is not None:
                    example_grad.append(param.grad.clone().flatten())
            
            if example_grad:
                grad_vector = torch.cat(example_grad)
                batch_gradients.append(grad_vector)
                
                # Store individual example data
                if self.track_individual_examples:
                    self.example_gradients[example_id][epoch] = grad_vector.cpu().numpy()
                    self.example_losses[example_id][epoch] = loss.item()
        
        # Store batch-level data
        if batch_gradients:
            self.gradient_history[epoch].extend([g.cpu().numpy() for g in batch_gradients])
            self.loss_history[epoch].extend(individual_losses.cpu().numpy().tolist())
    
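    def detect_opposing_signals(self, epoch: int, similarity_threshold: float = -0.1) -> List[int]:
        """Detect examples whose gradient opposes the epoch's mean gradient.

        Returns positions in gradient_history[epoch]; these equal example IDs only
        when every example is tracked in dataset order (no shuffling).
        """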
        
        if epoch not in self.gradient_history:
            return []
        
        gradients = np.array(self.gradient_history[epoch])
        if len(gradients) == 0:
            return []
        
        # Compute mean gradient direction
        mean_gradient = np.mean(gradients, axis=0)
        mean_gradient = mean_gradient / (np.linalg.norm(mean_gradient) + 1e-8)
        
        # Find opposing examples
        opposing_indices = []
        
        for i, grad in enumerate(gradients):
            grad_norm = grad / (np.linalg.norm(grad) + 1e-8)
            similarity = np.dot(mean_gradient, grad_norm)
            
            if similarity < similarity_threshold:  # Opposing direction
                opposing_indices.append(i)
        
        self.opposing_examples[epoch] = set(opposing_indices)
        return opposing_indices
    
    def analyze_loss_dynamics(self, window_size: int = 10) -> Dict[str, List[int]]:
        """Analyze examples with significant loss changes"""
        
        results = {
            'increasing_loss': [],
            'decreasing_loss': [],
            'volatile_loss': []
        }
        
        for example_id, loss_dict in self.example_losses.items():
            epochs = sorted(loss_dict.keys())
            if len(epochs) < 2 * window_size:  # Need two non-overlapping windows
                continue
            
            losses = [loss_dict[epoch] for epoch in epochs]
            
            # Compute trends
            recent_avg = np.mean(losses[-window_size:])
            early_avg = np.mean(losses[:window_size])
            volatility = np.std(losses)
            
            change_ratio = (recent_avg - early_avg) / (early_avg + 1e-8)
            
            if change_ratio > self.loss_change_threshold:
                results['increasing_loss'].append(example_id)
            elif change_ratio < -self.loss_change_threshold:
                results['decreasing_loss'].append(example_id)
            
            if volatility > 2 * np.mean(losses):
                results['volatile_loss'].append(example_id)
        
        return results
    
    def get_statistics(self) -> Dict[str, Any]:
        """Get comprehensive statistics about gradient patterns"""
        
        stats = {
            'total_epochs_tracked': len(self.gradient_history),
            'total_examples_tracked': len(self.example_losses),
            'opposing_signals_by_epoch': {},
            'gradient_similarity_stats': {},
            'loss_dynamics': self.analyze_loss_dynamics()
        }
        
        # Opposing signals statistics
        for epoch in self.opposing_examples:
            stats['opposing_signals_by_epoch'][epoch] = len(self.opposing_examples[epoch])
        
        return stats

## 2. Visualization Utilities

Create comprehensive visualizations for understanding gradient dynamics and opposing signals.

In [None]:
class OpposingSignalsVisualizer:
    """Visualization tools for opposing signals analysis"""
    
    def __init__(self, tracker: GradientTracker):
        self.tracker = tracker
    
    def plot_opposing_signals_timeline(self, figsize=(15, 8)):
        """Plot timeline of opposing signals detection"""
        
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)
        
        epochs = sorted(self.tracker.opposing_examples.keys())
        opposing_counts = [len(self.tracker.opposing_examples[epoch]) for epoch in epochs]
        
        # Opposing signals count over time
        ax1.plot(epochs, opposing_counts, 'ro-', alpha=0.7, linewidth=2)
        ax1.set_title('Opposing Signals Over Time', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Number of Opposing Examples')
        ax1.grid(True, alpha=0.3)
        
        # Loss distribution by epoch
        if epochs:
            recent_epoch = epochs[-1]
            recent_losses = self.tracker.loss_history[recent_epoch]
            
            ax2.hist(recent_losses, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
            ax2.set_title(f'Loss Distribution (Epoch {recent_epoch})', fontsize=14, fontweight='bold')
            ax2.set_xlabel('Loss Value')
            ax2.set_ylabel('Frequency')
            ax2.grid(True, alpha=0.3)
        
        # Gradient magnitude evolution
        grad_magnitudes = []
        for epoch in epochs:
            gradients = self.tracker.gradient_history[epoch]
            if gradients:
                magnitudes = [np.linalg.norm(grad) for grad in gradients]
                grad_magnitudes.append(np.mean(magnitudes))
            else:
                grad_magnitudes.append(np.nan)  # 0 cannot be drawn on a log axis; NaN leaves a gap
        
        ax3.plot(epochs, grad_magnitudes, 'g-', alpha=0.7, linewidth=2)
        ax3.set_title('Average Gradient Magnitude', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Gradient L2 Norm')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)
        
        # Loss dynamics analysis
        loss_dynamics = self.tracker.analyze_loss_dynamics()
        categories = list(loss_dynamics.keys())
        counts = [len(loss_dynamics[cat]) for cat in categories]
        
        bars = ax4.bar(categories, counts, alpha=0.7, 
                      color=['red', 'green', 'orange'])
        ax4.set_title('Loss Dynamics Categories', fontsize=14, fontweight='bold')
        ax4.set_ylabel('Number of Examples')
        ax4.tick_params(axis='x', rotation=45)
        
        # Add value labels on bars
        for bar, count in zip(bars, counts):
            ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
                    str(count), ha='center', va='bottom', fontweight='bold')
        
        plt.tight_layout()
        plt.show()
    
    def plot_example_trajectories(self, example_ids: List[int], max_examples: int = 10):
        """Plot loss trajectories for specific examples"""
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        
        example_ids = example_ids[:max_examples]  # Limit for readability
        
        colors = plt.cm.tab10(np.linspace(0, 1, len(example_ids)))
        
        # Loss trajectories
        for i, example_id in enumerate(example_ids):
            if example_id not in self.tracker.example_losses:
                continue
                
            loss_dict = self.tracker.example_losses[example_id]
            epochs = sorted(loss_dict.keys())
            losses = [loss_dict[epoch] for epoch in epochs]
            
            ax1.plot(epochs, losses, 'o-', color=colors[i], 
                    label=f'Example {example_id}', alpha=0.7, linewidth=2)
        
        ax1.set_title('Individual Example Loss Trajectories', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax1.grid(True, alpha=0.3)
        
        # Gradient norm trajectories  
        for i, example_id in enumerate(example_ids):
            if example_id not in self.tracker.example_gradients:
                continue
                
            grad_dict = self.tracker.example_gradients[example_id]
            epochs = sorted(grad_dict.keys())
            grad_norms = [np.linalg.norm(grad_dict[epoch]) for epoch in epochs]
            
            ax2.plot(epochs, grad_norms, 's-', color=colors[i], 
                    label=f'Example {example_id}', alpha=0.7, linewidth=2)
        
        ax2.set_title('Individual Example Gradient Norms', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Gradient L2 Norm')
        ax2.set_yscale('log')
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
    
    def create_interactive_dashboard(self):
        """Create interactive dashboard using Plotly (if available)"""
        
        if not PLOTLY_AVAILABLE:
            print("Plotly not available. Install with: pip install plotly")
            return
        
        # Create subplots
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Opposing Signals Timeline', 'Loss Distribution', 
                          'Gradient Magnitudes', 'Loss Dynamics'),
            specs=[[{"secondary_y": False}, {"secondary_y": False}],
                   [{"secondary_y": True}, {"type": "bar"}]]
        )
        
        epochs = sorted(self.tracker.opposing_examples.keys())
        opposing_counts = [len(self.tracker.opposing_examples[epoch]) for epoch in epochs]
        
        # Opposing signals timeline
        fig.add_trace(
            go.Scatter(x=epochs, y=opposing_counts, mode='lines+markers',
                      name='Opposing Examples', line=dict(color='red', width=3)),
            row=1, col=1
        )
        
        # Add more interactive elements...
        # (This would be expanded with more interactive features)
        
        fig.update_layout(height=800, showlegend=True, 
                         title_text="Opposing Signals Analysis Dashboard")
        fig.show()
    
    def generate_report(self) -> str:
        """Generate a text report of the analysis"""
        
        stats = self.tracker.get_statistics()
        
        report = [
            "# Opposing Signals Analysis Report",
            "",
            f"## Overview",
            f"- Total epochs tracked: {stats['total_epochs_tracked']}",
            f"- Total examples tracked: {stats['total_examples_tracked']}",
            "",
            "## Opposing Signals Detection"
        ]
        
        if stats['opposing_signals_by_epoch']:
            total_opposing = sum(stats['opposing_signals_by_epoch'].values())
            avg_opposing = total_opposing / len(stats['opposing_signals_by_epoch'])
            report.extend([
                f"- Total opposing signals detected: {total_opposing}",
                f"- Average per epoch: {avg_opposing:.2f}",
                f"- Peak opposing signals: {max(stats['opposing_signals_by_epoch'].values())}"
            ])
        
        report.extend([
            "",
            "## Loss Dynamics",
            f"- Examples with increasing loss: {len(stats['loss_dynamics']['increasing_loss'])}",
            f"- Examples with decreasing loss: {len(stats['loss_dynamics']['decreasing_loss'])}",
            f"- Examples with volatile loss: {len(stats['loss_dynamics']['volatile_loss'])}"
        ])
        
        return "\n".join(report)

## 3. Example Usage: Synthetic Data Experiment

Let's demonstrate the opposing signals analysis with a simple synthetic example.
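
Flipped labels are a natural source of opposing gradients. For a logistic model with binary cross-entropy, the gradient with respect to the weights for one example is $(\sigma(w \cdot x) - y)\,x$, so flipping $y$ reverses its direction. The NumPy sketch below checks this for a single linear unit; the helper `bce_gradient` is a hypothetical name, and the exact reversal holds only for this linear case, not necessarily for the MLP used in this notebook.

```python
import numpy as np

def bce_gradient(w: np.ndarray, x: np.ndarray, y: float) -> np.ndarray:
    """Gradient of binary cross-entropy w.r.t. w for the model p = sigmoid(w @ x)."""
    p = 1.0 / (1.0 + np.exp(-(w @ x)))
    return (p - y) * x

rng = np.random.default_rng(0)
w = rng.normal(size=5)
x = rng.normal(size=5)

grad_clean = bce_gradient(w, x, y=1.0)    # original label
grad_flipped = bce_gradient(w, x, y=0.0)  # same input, flipped label

cosine = grad_clean @ grad_flipped / (
    np.linalg.norm(grad_clean) * np.linalg.norm(grad_flipped)
)
print(round(float(cosine), 6))
```

The two gradients are scalar multiples of `x` with opposite signs, so their cosine similarity is -1 up to rounding, while their magnitudes generally differ.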

In [None]:
# Create synthetic data with some "opposing" examples
def create_synthetic_data(n_samples=1000, n_features=10, noise_ratio=0.1):
    """Create synthetic dataset with some opposing examples"""
    
    np.random.seed(42)
    
    # Generate base data
    X = np.random.randn(n_samples, n_features)
    true_weights = np.random.randn(n_features)
    y = (X @ true_weights > 0).astype(np.float32)
    
    # Add opposing examples (mislabeled)
    n_opposing = int(n_samples * noise_ratio)
    opposing_indices = np.random.choice(n_samples, n_opposing, replace=False)
    y[opposing_indices] = 1 - y[opposing_indices]  # Flip labels
    
    return torch.FloatTensor(X), torch.FloatTensor(y), opposing_indices

# Simple model
class SimpleModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.layers(x).squeeze(-1)  # Drop only the output dim so a batch of one stays 1-D

# Create data and model
X, y, true_opposing_indices = create_synthetic_data()
model = SimpleModel(X.shape[1])
criterion = nn.BCELoss(reduction='none')  # Important: no reduction for individual losses
optimizer = optim.Adam(model.parameters(), lr=0.01)

print(f"Created synthetic dataset with {len(X)} examples")
print(f"True opposing examples: {len(true_opposing_indices)}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Initialize tracker and run training with gradient tracking
tracker = GradientTracker(model, track_individual_examples=True)
visualizer = OpposingSignalsVisualizer(tracker)

n_epochs = 50
batch_size = 32

print("Starting training with gradient tracking...")

for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    
    # Process in batches
    for batch_idx in range(0, len(X), batch_size):
        batch_X = X[batch_idx:batch_idx+batch_size]
        batch_y = y[batch_idx:batch_idx+batch_size]
        
        # Forward pass
        outputs = model(batch_X)
        individual_losses = criterion(outputs, batch_y)
        batch_loss = individual_losses.mean()
        
        # Track gradients before optimizer step
        example_ids = list(range(batch_idx, min(batch_idx + batch_size, len(X))))
        tracker.track_batch(epoch, batch_idx // batch_size, batch_X, batch_y, 
                          individual_losses, example_ids)
        
        # Optimizer step
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        
        epoch_loss += batch_loss.item()
    
    # Detect opposing signals
    opposing_indices = tracker.detect_opposing_signals(epoch, similarity_threshold=-0.2)
    
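    if epoch % 10 == 0:
        # epoch_loss sums per-batch mean losses, so average over the number of batches
        n_batches = int(np.ceil(len(X) / batch_size))
        print(f"Epoch {epoch}: Loss = {epoch_loss / n_batches:.4f}, "
              f"Opposing signals = {len(opposing_indices)}")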

print("Training completed!")

In [None]:
# Analyze results and create visualizations
print("=== Opposing Signals Analysis Results ===")
print()

# Generate and print report
report = visualizer.generate_report()
print(report)
print()

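# Re-run detection for the final epoch with a looser threshold (-0.1 instead of the
# -0.2 used during training); this overwrites the set stored for that epoch.
# The returned indices are positions in gradient_history[final_epoch], which match
# dataset indices here only because batches are visited in order without shuffling.
final_epoch = n_epochs - 1
detected_opposing = tracker.detect_opposing_signals(final_epoch, similarity_threshold=-0.1)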

print("=== Comparison with Ground Truth ===")
# Convert NumPy integers to plain ints so the sets print and compare cleanly
true_set = set(true_opposing_indices.tolist())
detected_set = set(detected_opposing)
print(f"True opposing examples: {sorted(true_set)}")
print(f"Detected opposing examples: {sorted(detected_set)}")

# Calculate detection accuracy

intersection = true_set.intersection(detected_set)
precision = len(intersection) / len(detected_set) if detected_set else 0
recall = len(intersection) / len(true_set) if true_set else 0

print(f"Detection precision: {precision:.3f}")
print(f"Detection recall: {recall:.3f}")
print(f"Correctly identified: {len(intersection)} / {len(true_set)}")

In [None]:
# Create comprehensive visualizations
print("Creating visualizations...")

# Main timeline plot
visualizer.plot_opposing_signals_timeline(figsize=(16, 10))

# Plot trajectories for interesting examples
loss_dynamics = tracker.analyze_loss_dynamics(window_size=5)
interesting_examples = []

# Add some opposing examples
interesting_examples.extend(list(true_opposing_indices)[:5])

# Add some volatile loss examples
if loss_dynamics['volatile_loss']:
    interesting_examples.extend(loss_dynamics['volatile_loss'][:3])

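# Add some normal (correctly labeled) examples for comparison.
# Exclude every flipped-label example, not only those already selected above.
all_examples = set(range(len(X)))
excluded = set(true_opposing_indices.tolist()) | {int(e) for e in interesting_examples}
normal_examples = sorted(all_examples - excluded)
interesting_examples.extend(normal_examples[:2])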

visualizer.plot_example_trajectories(interesting_examples, max_examples=10)

# Interactive dashboard (if available)
if PLOTLY_AVAILABLE:
    print("Creating interactive dashboard...")
    visualizer.create_interactive_dashboard()
else:
    print("Install plotly for interactive visualizations: pip install plotly")

## 4. Advanced Analysis: Gradient Flow Patterns

Let's dive deeper into understanding how gradient directions evolve over time.
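
The analysis function in the next cell reduces the gradients with PCA in two stages (first to at most 50 components, then to 2). Because PCA scores are uncorrelated and ordered by variance, the second stage should return the first two components of the first stage, up to sign. The NumPy-only sketch below checks this on synthetic data using an SVD-based PCA; `pca_scores` and `fake_gradients` are hypothetical stand-ins, not part of the notebook's code.

```python
import numpy as np

def pca_scores(data: np.ndarray, n_components: int) -> np.ndarray:
    """Project centered data onto its top principal directions via SVD."""
    centered = data - data.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T

rng = np.random.default_rng(0)
# Stand-in for a (n_gradients, n_parameters) matrix with decaying variance per axis.
fake_gradients = rng.normal(size=(200, 30)) * np.linspace(5.0, 0.5, 30)

stage_one = pca_scores(fake_gradients, n_components=10)
two_stage = pca_scores(stage_one, n_components=2)
direct = pca_scores(fake_gradients, n_components=2)

# Principal directions are defined up to sign, so compare absolute values.
same_projection = np.allclose(np.abs(two_stage), np.abs(direct), atol=1e-8)
print(same_projection)
```

Under these assumptions the two-stage projection matches the direct one, so the first stage mainly serves to shrink the data before plotting.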

In [None]:
def analyze_gradient_flow_patterns(tracker: GradientTracker, n_components: int = 2):
    """Analyze gradient flow patterns using dimensionality reduction"""
    
    from sklearn.decomposition import PCA
    
    # Collect all gradients
    all_gradients = []
    epoch_labels = []
    opposing_labels = []
    
    for epoch in sorted(tracker.gradient_history.keys()):
        gradients = tracker.gradient_history[epoch]
        opposing_set = tracker.opposing_examples.get(epoch, set())
        
        for i, grad in enumerate(gradients):
            all_gradients.append(grad)
            epoch_labels.append(epoch)
            opposing_labels.append(i in opposing_set)
    
    if len(all_gradients) == 0:
        print("No gradients to analyze")
        return
    
    gradients_array = np.array(all_gradients)
    
    # Apply PCA for dimensionality reduction
    print(f"Applying PCA to {gradients_array.shape[0]} gradients with {gradients_array.shape[1]} dimensions...")
    # n_components cannot exceed min(n_samples, n_features)
    pca = PCA(n_components=min(50, *gradients_array.shape))  # Reduce to manageable size first
    gradients_pca = pca.fit_transform(gradients_array)
    
    # Further reduce for visualization
    if gradients_pca.shape[1] > n_components:
        final_pca = PCA(n_components=n_components)
        gradients_2d = final_pca.fit_transform(gradients_pca)
    else:
        gradients_2d = gradients_pca
    
    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Plot by epoch
    scatter1 = ax1.scatter(gradients_2d[:, 0], gradients_2d[:, 1], 
                          c=epoch_labels, cmap='viridis', alpha=0.6, s=20)
    ax1.set_title('Gradient Flow by Epoch', fontsize=14, fontweight='bold')
    ax1.set_xlabel('PCA Component 1')
    ax1.set_ylabel('PCA Component 2')
    plt.colorbar(scatter1, ax=ax1, label='Epoch')
    
    # Plot by opposing signal status
    colors = ['blue' if not opposing else 'red' for opposing in opposing_labels]
    ax2.scatter(gradients_2d[:, 0], gradients_2d[:, 1], 
               c=colors, alpha=0.6, s=20)
    ax2.set_title('Gradient Flow: Normal vs Opposing', fontsize=14, fontweight='bold')
    ax2.set_xlabel('PCA Component 1')
    ax2.set_ylabel('PCA Component 2')
    
    # Add legend
    from matplotlib.lines import Line2D
    legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=8, label='Normal'),
                      Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=8, label='Opposing')]
    ax2.legend(handles=legend_elements)
    
    plt.tight_layout()
    plt.show()
    
    # Print PCA analysis
    print(f"PCA explained variance ratio: {pca.explained_variance_ratio_[:5]}")
    print(f"Total variance explained by first 5 components: {np.sum(pca.explained_variance_ratio_[:5]):.3f}")

# Run gradient flow analysis
print("Analyzing gradient flow patterns...")
analyze_gradient_flow_patterns(tracker, n_components=2)

## 5. Integration with Real Experiments

Here's how to integrate this opposing signals analysis with real delayed generalization experiments.
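
The wrapper in the next cell detects a generalization event by comparing mean accuracy over two adjacent windows. The sketch below isolates that rule as a standalone function and applies it to a held-out accuracy curve. The function name `detect_jump` and the accuracy values are invented for illustration; the window and threshold defaults are arbitrary choices, not tuned settings.

```python
import numpy as np

def detect_jump(accuracies, window: int = 3, threshold: float = 0.1) -> bool:
    """True if the mean of the last `window` values exceeds the mean of the
    preceding `window` values by more than `threshold`."""
    if len(accuracies) < 2 * window:
        return False
    recent = np.mean(accuracies[-window:])
    earlier = np.mean(accuracies[-2 * window:-window])
    return bool(recent - earlier > threshold)

# Invented validation-accuracy curve: a long plateau, then a sharp rise.
val_acc = [0.50, 0.51, 0.50, 0.52, 0.51, 0.50, 0.70, 0.90, 0.95]

# Replay the curve epoch by epoch and record the first epoch that triggers.
first_detection = next(
    (epoch for epoch in range(len(val_acc)) if detect_jump(val_acc[: epoch + 1])),
    None,
)
print(first_detection)
```

Feeding held-out rather than training accuracy into a rule like this is what ties the detected event to generalization instead of memorization.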

In [None]:
# Integration example for real experiments
class DelayedGeneralizationExperiment:
    """Integration wrapper for delayed generalization experiments"""
    
    def __init__(self, model, data_loader, criterion, optimizer):
        self.model = model
        self.data_loader = data_loader
        self.criterion = criterion
        self.optimizer = optimizer
        
        # Initialize opposing signals tracking
        self.tracker = GradientTracker(model, track_individual_examples=False)  # Skips per-example dicts; gradient_history is still stored
        self.visualizer = OpposingSignalsVisualizer(self.tracker)
        
        # Experiment tracking
        self.epoch_metrics = []
        self.generalization_events = []
    
    def train_epoch(self, epoch: int, track_gradients: bool = True) -> Dict[str, float]:
        """Train for one epoch with optional gradient tracking"""
        
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(self.data_loader):
            # Forward pass
            output = self.model(data)
            
            # Compute individual losses for gradient tracking
            if hasattr(self.criterion, 'reduction') and track_gradients:
                # Temporarily disable reduction for individual losses
                original_reduction = self.criterion.reduction
                self.criterion.reduction = 'none'
                individual_losses = self.criterion(output, target)
                self.criterion.reduction = original_reduction
                
                # Track gradients
                self.tracker.track_batch(epoch, batch_idx, data, target, individual_losses)
                
                loss = individual_losses.mean()
            else:
                loss = self.criterion(output, target)
            
            # Backward pass and optimization
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
            # Accumulate metrics
            total_loss += loss.item()
            if output.dim() > 1:  # Multi-class outputs of shape (batch, n_classes)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)
        
        # Detect opposing signals
        if track_gradients:
            opposing_count = len(self.tracker.detect_opposing_signals(epoch))
        else:
            opposing_count = 0
        
        metrics = {
            'loss': total_loss / len(self.data_loader),
            'accuracy': correct / total if total > 0 else 0.0,
            'opposing_signals': opposing_count
        }
        
        self.epoch_metrics.append(metrics)
        return metrics
    
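    def detect_generalization_event(self, window_size: int = 10, threshold: float = 0.1) -> bool:
        """Detect a sudden accuracy improvement between two adjacent windows.

        Note: epoch_metrics holds training accuracy from train_epoch; to detect
        delayed generalization, log held-out accuracy and compare that instead.
        Epochs run with track_gradients=False record 0 opposing signals, which
        lowers the before/after averages stored in the event.
        """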
        
        if len(self.epoch_metrics) < window_size * 2:
            return False
        
        # Compare recent performance to earlier performance
        recent_acc = np.mean([m['accuracy'] for m in self.epoch_metrics[-window_size:]])
        earlier_acc = np.mean([m['accuracy'] for m in self.epoch_metrics[-2*window_size:-window_size]])
        
        improvement = recent_acc - earlier_acc
        
        if improvement > threshold:
            event = {
                'epoch': len(self.epoch_metrics) - 1,
                'improvement': improvement,
                'opposing_signals_before': np.mean([m['opposing_signals'] for m in self.epoch_metrics[-2*window_size:-window_size]]),
                'opposing_signals_after': np.mean([m['opposing_signals'] for m in self.epoch_metrics[-window_size:]])
            }
            self.generalization_events.append(event)
            return True
        
        return False
    
    def generate_experiment_report(self) -> str:
        """Generate comprehensive experiment report"""
        
        report = [
            "# Delayed Generalization Experiment Report",
            "",
            f"## Training Summary",
            f"- Total epochs: {len(self.epoch_metrics)}",
            f"- Final accuracy: {self.epoch_metrics[-1]['accuracy']:.4f}",
            f"- Final loss: {self.epoch_metrics[-1]['loss']:.4f}",
            "",
            f"## Generalization Events",
            f"- Number of detected events: {len(self.generalization_events)}"
        ]
        
        for i, event in enumerate(self.generalization_events):
            report.append(f"  - Event {i+1}: Epoch {event['epoch']}, improvement {event['improvement']:.4f}")
            report.append(f"    Opposing signals before: {event['opposing_signals_before']:.1f}")
            report.append(f"    Opposing signals after: {event['opposing_signals_after']:.1f}")
        
        # Add opposing signals analysis
        opposing_report = self.visualizer.generate_report()
        report.append("")
        report.append(opposing_report)
        
        return "\n".join(report)

print("Integration wrapper created for real experiments!")
print("")
print("Usage example:")
print("experiment = DelayedGeneralizationExperiment(model, train_loader, criterion, optimizer)")
print("")
print("for epoch in range(num_epochs):")
print("    metrics = experiment.train_epoch(epoch, track_gradients=(epoch % 10 == 0))")
print("    if experiment.detect_generalization_event():")
print("        print(f'Generalization event detected at epoch {epoch}!')")
print("")
print("report = experiment.generate_experiment_report()")
print("experiment.visualizer.plot_opposing_signals_timeline()")

## 6. Conclusions and Next Steps

This notebook provides a framework for analyzing opposing gradient signals in delayed generalization experiments. Its capabilities, and the observations it is designed to test, are summarized below:

### Key Capabilities:
1. **Gradient Tracking**: Monitor individual example gradients throughout training
2. **Opposing Signal Detection**: Identify examples with gradients opposing the majority direction
3. **Loss Dynamics Analysis**: Track examples with significant loss changes over time
4. **Visualization Tools**: Comprehensive plots and interactive dashboards
5. **Integration Framework**: Easy integration with existing training pipelines

### Insights from Analysis:
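- Opposing signals are flagged by negative cosine similarity between a per-example gradient and the epoch's mean gradient; the precision and recall printed in Section 3 indicate how well this works on the synthetic data
- Flagged examples may correspond to mislabeled or difficult examples, which the synthetic experiment tests with flipped labels
- The PCA projection in Section 4 shows whether normal and opposing gradients separate into distinct clusters
- Loss dynamics sort examples into increasing, decreasing, and volatile categories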
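
The loss-dynamics categories follow a simple rule: compare the average loss in an early window with the average in a late window, and separately compare the standard deviation with the mean. The sketch below mirrors that rule for a single trajectory. `categorize_trajectory` is a hypothetical helper and the trajectories are invented values chosen to hit each category.

```python
import numpy as np

def categorize_trajectory(losses, window: int = 3, change_threshold: float = 0.5):
    """Label one loss trajectory using early-vs-late averages and relative volatility."""
    losses = np.asarray(losses, dtype=float)
    early = losses[:window].mean()
    recent = losses[-window:].mean()
    change_ratio = (recent - early) / (early + 1e-8)

    labels = []
    if change_ratio > change_threshold:
        labels.append("increasing_loss")
    elif change_ratio < -change_threshold:
        labels.append("decreasing_loss")
    # For non-negative losses this only fires when a few spikes dominate.
    if losses.std() > 2 * losses.mean():
        labels.append("volatile_loss")
    return labels

learned = categorize_trajectory([1.0, 0.9, 0.8, 0.3, 0.2, 0.1])
forgotten = categorize_trajectory([0.2, 0.2, 0.3, 0.9, 1.1, 1.2])
spiky = categorize_trajectory([0.01, 0.01, 0.01, 0.01, 0.01, 6.0])
print(learned, forgotten, spiky)
```

The volatility condition is strict, so on smooth training curves the volatile category may stay empty unless the factor of 2 is lowered.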

### Next Steps:
1. **Scale to larger models**: Optimize memory usage for transformer-scale models
2. **Real dataset validation**: Test on Colored MNIST, Waterbirds, etc.
3. **Causal analysis**: Determine if opposing signals cause delayed generalization
4. **Intervention strategies**: Develop methods to handle opposing signals
5. **Theoretical understanding**: Connect to optimization theory and generalization bounds
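
For the memory item, one possible direction (not evaluated in this notebook) is to store a low-dimensional random projection of each gradient instead of the full vector, since Gaussian random projections approximately preserve cosine similarities. The NumPy sketch below uses invented dimensions and synthetic vectors to show the idea; `projection` and `cosine` are hypothetical names.

```python
import numpy as np

rng = np.random.default_rng(0)
n_params, n_projected = 5000, 256

# Fixed Gaussian projection, scaled so squared norms are preserved in expectation.
projection = rng.normal(size=(n_params, n_projected)) / np.sqrt(n_projected)

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Two synthetic "gradients": g2 partly opposes g1.
g1 = rng.normal(size=n_params)
g2 = -0.5 * g1 + rng.normal(size=n_params)

full = cosine(g1, g2)
compressed = cosine(g1 @ projection, g2 @ projection)
print(f"full: {full:.3f}, projected: {compressed:.3f}")
```

The dense projection matrix itself needs `n_params * n_projected` floats, so a model with many more parameters would need a cheaper projection scheme than this sketch.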

### Usage in Research:
This framework can be applied to any delayed generalization experiment to understand the role of opposing signals in phenomena like:
- Grokking in algorithmic tasks
- Simplicity bias in vision tasks  
- Phase transitions in large language models
- Continual learning forgetting patterns