# TSInfluenceScoring Framework - Interactive Demo

This notebook demonstrates the key features of the TSInfluenceScoring framework for selecting influential timestamps from time-series data.

## What You'll Learn

1. **Basic Setup** - Installing and importing the framework
2. **Timestamp Selection** - Using attention-based selection mechanisms
3. **Different Selection Methods** - Top-K, Gumbel-Softmax, and Threshold
4. **Training with Selection** - Joint training of selector and predictor
5. **Counterfactual Generation** - Understanding model behavior
6. **Visualization** - Analyzing selected timestamps

Let's get started!

## 1. Setup and Installation

First, let's install the required packages and import the framework.

In [None]:
# Install the package (uncomment if running for the first time)
# !pip install -e .

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Import the framework
from tsinfluencescoring import (
    TimestampSelector,
    create_simple_framework,
    ModelAgnosticWrapper,
    CounterfactualGenerator,
    CounterfactualExplainer
)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✓ Imports successful!")
print(f"PyTorch version: {torch.__version__}")

## 2. Create Synthetic Time Series Data

We'll create synthetic time series in which a fixed window of timestamps (indices 20 through 29, i.e. the slice `20:30`) determines the prediction target.
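
The data-generation cell relies on the slice `20:30`, which in Python is half-open. As a small sanity check (a sketch only; `gt_mask` is a hypothetical helper name, not something the framework provides), the window can be written as a boolean mask:

```python
import torch

seq_len = 50
window = slice(20, 30)  # half-open: covers indices 20, 21, ..., 29

# Boolean ground-truth mask over timestamps, handy for later comparisons.
gt_mask = torch.zeros(seq_len, dtype=torch.bool)
gt_mask[window] = True

window_indices = torch.where(gt_mask)[0].tolist()
print(window_indices)       # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
print(int(gt_mask.sum()))   # 10
```

So "timestamps 20-30" in the plots and printouts refers to ten timestamps, with 30 itself excluded.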

In [None]:
def create_synthetic_timeseries(num_samples=100, seq_len=50, input_dim=5):
    """
    Create synthetic time series whose target depends only on timestamps 20-29.

    Returns X of shape (num_samples, seq_len, input_dim) and y of shape (num_samples, 1).
    """
    X = torch.randn(num_samples, seq_len, input_dim)
    
    # Inflate the variance inside the influential window (slice 20:30, i.e. timesteps 20-29)
    X[:, 20:30, :] += torch.randn(num_samples, 10, input_dim) * 2.0
    
    # Target depends heavily on this window
    influential_window = X[:, 20:30, :].mean(dim=(1, 2))
    noise = torch.randn(num_samples) * 0.1
    y = (influential_window + noise).unsqueeze(-1)
    
    return X, y

# Generate data
X_train, y_train = create_synthetic_timeseries(num_samples=200, seq_len=50, input_dim=5)
X_test, y_test = create_synthetic_timeseries(num_samples=50, seq_len=50, input_dim=5)

print(f"Training data shape: {X_train.shape}")
print(f"Training targets shape: {y_train.shape}")
print(f"Test data shape: {X_test.shape}")
print("\n✓ Data created! Ground truth: timestamps 20-29 (slice 20:30) are most influential")

### Visualize a Sample Time Series

In [None]:
# Plot one sample
sample_idx = 0
fig, axes = plt.subplots(5, 1, figsize=(14, 10))

for dim in range(5):
    axes[dim].plot(X_train[sample_idx, :, dim].numpy(), label=f'Dimension {dim}', alpha=0.8)
    axes[dim].axvspan(20, 30, alpha=0.2, color='red', label='Influential window')
    axes[dim].set_ylabel(f'Dim {dim}')
    axes[dim].legend(loc='upper right')
    axes[dim].grid(True, alpha=0.3)

axes[-1].set_xlabel('Timestamp')
plt.suptitle('Sample Time Series (Red region = Ground truth influential timestamps)', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Target value for this sample: {y_train[sample_idx].item():.4f}")

## 3. Create and Test the Timestamp Selector

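Let's create a selector module and see which timestamps it selects. The selector is untrained at this point, so its choices are not expected to match the ground-truth window yet.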
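
To make the idea of Top-K selection concrete, here is a minimal sketch in plain PyTorch of how per-timestamp scores can be turned into a binary mask. `topk_mask` and the random `scores` are hypothetical stand-ins and not part of `tsinfluencescoring`, whose internals may differ.

```python
import torch

def topk_mask(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Return a 0/1 mask of shape (batch, seq_len) marking the k highest scores per row."""
    top_idx = scores.topk(k, dim=-1).indices   # (batch, k)
    mask = torch.zeros_like(scores)
    mask.scatter_(-1, top_idx, 1.0)            # write 1.0 at the top-k positions
    return mask

torch.manual_seed(0)
scores = torch.randn(8, 50)                    # stand-in for learned importance scores
mask = topk_mask(scores, k=10)
print(mask.sum(dim=-1))                        # every row selects exactly 10 timestamps
```

A hard mask like this has no useful gradient with respect to the scores, which is one reason relaxed alternatives such as Gumbel-Softmax exist.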

In [None]:
# Create a timestamp selector with Top-K selection
selector = TimestampSelector(
    input_dim=5,
    hidden_dim=64,
    num_heads=4,
    selection_method="topk",
    k=10  # Select 10 most influential timestamps
)

print("✓ TimestampSelector created")
print("  - Selection method: Top-K")
print("  - Number to select: 10")
print("  - Hidden dimension: 64")
print("  - Attention heads: 4")

In [None]:
# Test the selector on a batch
with torch.no_grad():
    test_batch = X_test[:8]  # Use 8 samples
    mask, scores = selector(test_batch, return_scores=True)
    
    # Get statistics
    stats = selector.compute_selection_stats(mask)

print("Selection Statistics:")
print(f"  - Average timestamps selected: {stats['num_selected']:.2f}")
print(f"  - Selection ratio: {stats['selection_ratio']:.2%}")
print(f"  - Mean score: {stats['mean_score']:.4f}")

# Show which timestamps are selected for first sample
selected_indices = torch.where(mask[0] > 0.5)[0].tolist()
print(f"\nSelected timestamps for sample 0: {selected_indices}")
print("Ground truth influential window: 20-29")

## 4. Compare Different Selection Methods

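The framework supports three selection methods: Top-K (a fixed number `k` of timestamps), Gumbel-Softmax (a relaxed, sampled selection controlled by `temperature`), and Threshold (a cutoff given by `threshold`). The selectors below are freshly initialized, so the masks reflect random weights rather than learned importance.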
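
As a rough illustration of how the two non-Top-K methods behave, the sketch below applies a sigmoid threshold and `torch.nn.functional.gumbel_softmax` to random logits. The `scores` tensor is a hypothetical stand-in, and how the framework combines Gumbel samples into a multi-timestamp mask is not shown in this notebook, so treat this as a conceptual sketch only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(1, 50)   # hypothetical per-timestamp logits

# Threshold: keep every timestamp whose sigmoid score exceeds the cutoff.
# The number of selected timestamps varies from sample to sample.
threshold_mask = (torch.sigmoid(scores) > 0.5).float()

# Gumbel-Softmax: a noisy, temperature-controlled relaxation of a one-hot draw.
# Lower tau pushes the weights towards a single timestamp.
soft_weights = F.gumbel_softmax(scores, tau=1.0, hard=False, dim=-1)

# hard=True returns a one-hot sample in the forward pass while keeping
# the gradient of the soft sample (straight-through estimator).
hard_onehot = F.gumbel_softmax(scores, tau=1.0, hard=True, dim=-1)

print(int(threshold_mask.sum()), float(soft_weights.sum()))
```

Unlike Top-K, neither method guarantees a fixed number of selected timestamps, which is worth keeping in mind when comparing the plots.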

In [None]:
# Create selectors with different methods
selectors = {
    "Top-K": TimestampSelector(input_dim=5, selection_method="topk", k=10),
    "Gumbel-Softmax": TimestampSelector(input_dim=5, selection_method="gumbel", temperature=1.0),
    "Threshold": TimestampSelector(input_dim=5, selection_method="threshold", threshold=0.5)
}

# Test each method
test_sample = X_test[:1]
fig, axes = plt.subplots(3, 1, figsize=(14, 8))

for idx, (method_name, method_selector) in enumerate(selectors.items()):
    with torch.no_grad():
        mask, scores = method_selector(test_sample, return_scores=True)
        
    # Plot
    axes[idx].bar(range(50), mask[0].numpy(), alpha=0.6, label='Selection mask')
    axes[idx].axvspan(20, 30, alpha=0.2, color='red', label='Ground truth')
    axes[idx].set_title(f'{method_name} Selection Method', fontsize=12, fontweight='bold')
    axes[idx].set_ylabel('Selection Weight')
    axes[idx].legend()
    axes[idx].grid(True, alpha=0.3)
    
    num_selected = (mask[0] > 0.5).sum().item()
    axes[idx].text(0.02, 0.95, f'Selected: {num_selected} timestamps', 
                   transform=axes[idx].transAxes, fontsize=10,
                   verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

axes[-1].set_xlabel('Timestamp')
plt.tight_layout()
plt.show()

print("\n✓ Comparison complete! Each method has different selection behavior.")

## 5. Train a Model with Timestamp Selection

Now let's train a complete system that learns which timestamps are important.
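
Joint training means that a single prediction loss updates both the selector and the predictor. The sketch below shows that idea with plain PyTorch modules; `scorer`, `predictor`, and the sigmoid soft mask are assumptions made for illustration and are not the framework's actual components.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, input_dim = 50, 5

# Hypothetical stand-ins: a per-timestamp scorer and a small predictor.
scorer = nn.Linear(input_dim, 1)
predictor = nn.Sequential(nn.Flatten(), nn.Linear(seq_len * input_dim, 1))

# One optimizer over both modules, so the prediction loss updates both.
params = list(scorer.parameters()) + list(predictor.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

X = torch.randn(16, seq_len, input_dim)
y = X[:, 20:30, :].mean(dim=(1, 2)).unsqueeze(-1)

soft_mask = torch.sigmoid(scorer(X))   # (16, 50, 1), values in (0, 1)
pred = predictor(X * soft_mask)        # masked input keeps its shape
loss = nn.functional.mse_loss(pred, y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# The scorer receives a gradient through the mask, so it is trained too.
scorer_got_grad = scorer.weight.grad is not None and bool(scorer.weight.grad.abs().sum() > 0)
print(scorer_got_grad)
```

A soft mask is used here because it is differentiable; a hard Top-K mask would need a relaxation or a straight-through trick to pass gradients back to the scorer.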

In [None]:
# Create the framework
framework = create_simple_framework(
    input_dim=5,
    k=10,
    hidden_dim=64,
    task="regression"
)

# Create a simple prediction model
base_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(50 * 5, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 1)
)

# Wrap with the framework
model_wrapper = ModelAgnosticWrapper(
    base_model=base_model,
    framework=framework,
    use_selected_only=True
)

# Create optimizer
all_params = list(base_model.parameters()) + list(framework.parameters())
optimizer = torch.optim.Adam(all_params, lr=0.001)

print("✓ Model and framework initialized")
print(f"  - Base model parameters: {sum(p.numel() for p in base_model.parameters()):,}")
print(f"  - Framework parameters: {sum(p.numel() for p in framework.parameters()):,}")

In [None]:
# Training loop
num_epochs = 20
batch_size = 16
train_losses = []
test_losses = []

print("Training started...\n")

for epoch in range(num_epochs):
    # Mini-batch training (integer division drops the final partial batch)
    num_batches = len(X_train) // batch_size
    epoch_losses = []
    
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        
        X_batch = X_train[start_idx:end_idx]
        y_batch = y_train[start_idx:end_idx]
        
        # Training step
        loss_dict = model_wrapper.train_step(X_batch, y_batch, optimizer)
        epoch_losses.append(loss_dict["total_loss"])
    
    avg_train_loss = sum(epoch_losses) / len(epoch_losses)
    train_losses.append(avg_train_loss)
    
    # Evaluate on test set (switch off dropout first, then restore training mode)
    base_model.eval()
    with torch.no_grad():
        test_pred = model_wrapper(X_test)
        test_loss = nn.MSELoss()(test_pred, y_test).item()
        test_losses.append(test_loss)
    base_model.train()
    
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d} | Train Loss: {avg_train_loss:.4f} | Test Loss: {test_loss:.4f}")

print("\n✓ Training complete!")
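
The loop above slices `X_train` into `len(X_train) // batch_size` fixed batches. With 200 samples and a batch size of 16 that is 12 batches, so the last 8 samples are never used and the order is the same every epoch. One common alternative, sketched here with random stand-in tensors, is `torch.utils.data.DataLoader`, which shuffles each epoch and keeps the final partial batch by default:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
X = torch.randn(200, 50, 5)   # stand-in for X_train
y = torch.randn(200, 1)       # stand-in for y_train

loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

batch_sizes = [len(xb) for xb, yb in loader]
print(len(batch_sizes), sum(batch_sizes))   # 13 200
```

Whether `train_step` accepts a smaller final batch is not shown in this notebook; passing `drop_last=True` to the `DataLoader` restores the fixed batch size while still shuffling.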

In [None]:
# Plot training curves
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', linewidth=2)
plt.plot(test_losses, label='Test Loss', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress', fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(test_losses, color='orange', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Test Loss')
plt.title('Test Loss Over Time', fontweight='bold')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final test loss: {test_losses[-1]:.6f}")

## 6. Analyze Learned Selections

Let's see which timestamps the model learned to select after training.
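
Because the ground-truth window is known, the visual check can be backed by a number. The sketch below computes precision and recall of a binary mask against a half-open window; `window_overlap` is a hypothetical helper written for this demo, not a framework function.

```python
import torch

def window_overlap(mask: torch.Tensor, start: int, end: int) -> dict:
    """Precision/recall of a (batch, seq_len) 0/1 mask against the window [start, end)."""
    selected = mask > 0.5
    in_window = torch.zeros(mask.shape[-1], dtype=torch.bool)
    in_window[start:end] = True
    hits = (selected & in_window).sum().item()
    n_selected = selected.sum().item()
    n_window = in_window.sum().item() * mask.shape[0]
    return {
        "precision": hits / max(n_selected, 1),   # selected timestamps that are in the window
        "recall": hits / max(n_window, 1),        # window timestamps that were selected
    }

# Toy mask: 10 selections, 6 of them inside the window 20..29.
mask = torch.zeros(1, 50)
mask[0, [20, 21, 22, 23, 24, 25, 3, 7, 40, 45]] = 1.0
overlap = window_overlap(mask, 20, 30)
print(overlap)  # {'precision': 0.6, 'recall': 0.6}
```

With `k=10` and a 10-timestamp window, precision and recall coincide; they diverge for Threshold or Gumbel-based selection, where the number of selected timestamps varies.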

In [None]:
# Get selections for test samples
with torch.no_grad():
    outputs = framework(X_test[:5], return_details=True)
    mask = outputs['mask']
    scores = outputs['scores']
    stats = outputs['stats']

print("Selection Statistics After Training:")
print(f"  - Average timestamps selected: {stats['num_selected']:.2f}")
print(f"  - Selection ratio: {stats['selection_ratio']:.2%}")
print(f"  - Mean score: {stats['mean_score']:.4f}")

# Visualize selections for multiple samples
fig, axes = plt.subplots(5, 1, figsize=(14, 10))

for i in range(5):
    axes[i].bar(range(50), mask[i].numpy(), alpha=0.6, color='blue')
    axes[i].axvspan(20, 30, alpha=0.2, color='red', label='Ground truth')
    axes[i].set_ylabel(f'Sample {i}')
    axes[i].grid(True, alpha=0.3)
    
    selected_in_window = mask[i, 20:30].sum().item()
    selected_outside = mask[i, :20].sum().item() + mask[i, 30:].sum().item()
    
    axes[i].text(0.02, 0.95, f'In window: {selected_in_window:.0f} | Outside: {selected_outside:.0f}', 
                 transform=axes[i].transAxes, fontsize=9,
                 verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
    
    if i == 0:
        axes[i].legend(loc='upper right')

axes[-1].set_xlabel('Timestamp')
plt.suptitle('Learned Timestamp Selections (Red = Ground Truth Influential Window)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

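# Share of total selection mass that falls inside the ground-truth window
in_window_share = (mask[:, 20:30].sum() / mask.sum().clamp(min=1e-8)).item()
print(f"\n✓ Share of selection mass inside the ground-truth window (20-29): {in_window_share:.1%}")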

## 7. Counterfactual Generation

Generate counterfactuals to understand how selected timestamps affect predictions.
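
One simple way to build such a counterfactual is to replace the selected timestamps with a neutral value and compare predictions before and after. The sketch below shows zero and per-sample mean replacement; `replace_selected` is a hypothetical helper, and the framework's own `counterfactual` output may be produced differently.

```python
import torch

def replace_selected(x: torch.Tensor, mask: torch.Tensor, mode: str = "zero") -> torch.Tensor:
    """Return a copy of x (batch, seq_len, features) with masked timestamps replaced."""
    m = mask.unsqueeze(-1)   # (batch, seq_len, 1), broadcast over features
    if mode == "zero":
        fill = torch.zeros_like(x)
    elif mode == "mean":
        # Per-sample, per-feature mean over time, broadcast across timestamps.
        fill = x.mean(dim=1, keepdim=True).expand_as(x)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return x * (1 - m) + fill * m

torch.manual_seed(0)
x = torch.randn(3, 50, 5)
mask = torch.zeros(3, 50)
mask[:, 20:30] = 1.0         # pretend the selector picked the true window

cf_zero = replace_selected(x, mask, mode="zero")
cf_mean = replace_selected(x, mask, mode="mean")
print(cf_zero[:, 20:30].abs().sum().item())   # 0.0
```

Zero replacement is the bluntest option and can push inputs off the data distribution; mean replacement is a slightly gentler baseline.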

In [None]:
# Generate counterfactuals
sample_for_cf = X_test[:3]

base_model.eval()  # disable dropout so the two predictions are comparable

with torch.no_grad():
    # Get original predictions and selections
    original_pred = model_wrapper(sample_for_cf)
    outputs = framework(sample_for_cf, return_details=True)
    mask = outputs['mask']
    counterfactual = outputs['counterfactual']
    
    # Get counterfactual predictions
    cf_pred = model_wrapper(counterfactual)

print("Counterfactual Analysis:\n")
print("Sample | Original Pred | CF Pred | Change")
print("-" * 50)
for i in range(3):
    orig = original_pred[i].item()
    cf = cf_pred[i].item()
    change = abs(cf - orig)
    print(f"   {i}   |    {orig:7.4f}    | {cf:7.4f} | {change:7.4f}")

print("\n✓ Counterfactuals show how modifying selected timestamps affects predictions")

In [None]:
# Visualize counterfactual for one sample
sample_idx = 0
dim_to_plot = 0  # Plot first dimension

fig, axes = plt.subplots(2, 1, figsize=(14, 8))

# Original time series
axes[0].plot(sample_for_cf[sample_idx, :, dim_to_plot].numpy(), 
             label='Original', linewidth=2, alpha=0.8)
axes[0].scatter(range(50), sample_for_cf[sample_idx, :, dim_to_plot].numpy(),
                c=mask[sample_idx].numpy(), cmap='RdYlGn', s=50, 
                edgecolors='black', linewidth=0.5, zorder=5)
axes[0].set_ylabel('Value')
axes[0].set_title('Original Time Series (Color = Selection Strength)', fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].legend()

# Counterfactual time series
axes[1].plot(sample_for_cf[sample_idx, :, dim_to_plot].numpy(), 
             label='Original', linewidth=2, alpha=0.5, linestyle='--')
axes[1].plot(counterfactual[sample_idx, :, dim_to_plot].numpy(), 
             label='Counterfactual', linewidth=2, alpha=0.8, color='red')
axes[1].set_xlabel('Timestamp')
axes[1].set_ylabel('Value')
axes[1].set_title('Counterfactual Comparison', fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].legend()

plt.tight_layout()
plt.show()

print(f"Original prediction: {original_pred[sample_idx].item():.4f}")
print(f"Counterfactual prediction: {cf_pred[sample_idx].item():.4f}")
print(f"Prediction change: {abs(cf_pred[sample_idx].item() - original_pred[sample_idx].item()):.4f}")

## 8. Advanced: Attribution Scoring

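Estimate how much each timestamp contributes to the prediction by occlusion: zero out one timestamp at a time and measure the absolute change in the model output.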
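
The same procedure can be packaged as a reusable function. The sketch below is a generic occlusion routine for any `nn.Module`; `occlusion_attribution` and the toy `WindowMean` model are hypothetical and exist only to show that the scores land on the timestamps the model actually uses.

```python
import torch
import torch.nn as nn

def occlusion_attribution(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Absolute prediction change when each timestamp of x (1, seq_len, features) is zeroed."""
    was_training = model.training
    model.eval()   # disable dropout so differences are deterministic
    with torch.no_grad():
        base = model(x)
        scores = torch.zeros(x.shape[1])
        for t in range(x.shape[1]):
            occluded = x.clone()
            occluded[:, t, :] = 0.0
            scores[t] = (model(occluded) - base).abs().sum()
    model.train(was_training)   # restore the caller's mode
    return scores

# Toy model that only looks at timestamps 20..29, so only those can score > 0.
class WindowMean(nn.Module):
    def forward(self, x):
        return x[:, 20:30, :].mean(dim=(1, 2)).unsqueeze(-1)

torch.manual_seed(0)
x = torch.randn(1, 50, 5)
attr = occlusion_attribution(WindowMean(), x)
print(attr[:20].sum().item(), attr[30:].sum().item())   # 0.0 0.0
```

The cost is one forward pass per timestamp, so for long sequences it is common to occlude blocks of timestamps instead of single steps.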

In [None]:
# Create counterfactual explainer
def model_for_explanation(x):
    """Wrapper for the model that works with explainer."""
    with torch.no_grad():
        return model_wrapper(x)

cf_generator = CounterfactualGenerator(
    input_dim=5,
    generation_method="removal"
)

explainer = CounterfactualExplainer(cf_generator, model_for_explanation)

# Compute attribution for a single sample (one forward pass per timestamp)
sample_to_explain = X_test[:1]
base_model.eval()  # disable dropout so prediction differences are not noise

with torch.no_grad():
    # Get selection mask
    mask_for_attr = framework(sample_to_explain, return_details=False)['mask']
    
    print("Computing attribution scores (this may take a moment)...")
    
    # Occlusion attribution over all 50 timestamps: measure the prediction
    # change when each timestamp is zeroed out
    original_pred = model_for_explanation(sample_to_explain)
    attributions = torch.zeros(50)
    
    for t in range(50):
        # Create a version with timestamp t removed
        modified = sample_to_explain.clone()
        modified[:, t, :] = 0
        
        # Measure prediction change
        modified_pred = model_for_explanation(modified)
        attributions[t] = abs(modified_pred.item() - original_pred.item())

# Plot attributions
plt.figure(figsize=(14, 5))
plt.bar(range(50), attributions.numpy(), alpha=0.7, color='purple')
plt.axvspan(20, 30, alpha=0.2, color='red', label='Ground truth influential')
plt.xlabel('Timestamp')
plt.ylabel('Attribution Score (Prediction Change)')
plt.title('Timestamp Attribution Scores', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Show top influential timestamps
top_k = 10
top_indices = attributions.argsort(descending=True)[:top_k].tolist()
print(f"\nTop {top_k} most influential timestamps: {top_indices}")
print("Ground truth influential window: 20-29")
print("\n✓ Attribution analysis complete!")

## 9. Summary and Key Takeaways

In this notebook, we demonstrated:

1. ✅ **Multiple Selection Methods** - Top-K, Gumbel-Softmax, and Threshold-based selection
2. ✅ **End-to-End Training** - Joint learning of selector and predictor
3. ✅ **Interpretability** - Checking the learned selections against the known influential window (timestamps 20-29)
4. ✅ **Counterfactual Generation** - Understanding how selected timestamps affect predictions
5. ✅ **Attribution Scoring** - Quantifying each timestamp's importance

### Next Steps

- Try with your own time-series data
- Experiment with different selection methods and hyperparameters
- Use the diversity loss to encourage non-redundant selections
- Integrate with your existing PyTorch models using `ModelAgnosticWrapper`

### Additional Resources

- [GitHub Repository](https://github.com/marcell-nemeth/TSInfluenceScoring)
- [Documentation](../README.md)
- [Examples](../examples/)
- [Tests](../tests/)

## 10. Bonus: Quick Start Template

Here's a minimal template to get started with your own data:

In [None]:
# Quick Start Template
from tsinfluencescoring import create_simple_framework, ModelAgnosticWrapper
import torch
import torch.nn as nn

# Your data
# X = your_timeseries_data  # Shape: (batch, seq_len, features)
# y = your_targets          # Shape: (batch, output_dim)

# Create framework
framework = create_simple_framework(
    input_dim=5,      # Your feature dimension
    k=10,             # Number of timestamps to select
    task="regression" # or "classification"
)

# Your existing model
your_model = nn.Sequential(
    # Your architecture here
    nn.Flatten(),
    nn.Linear(50 * 5, 1)  # seq_len * input_dim; adjust to your data
)

# Wrap and train
wrapper = ModelAgnosticWrapper(your_model, framework)
optimizer = torch.optim.Adam(wrapper.parameters(), lr=0.001)

# Training loop
# loss_dict = wrapper.train_step(X_batch, y_batch, optimizer)

# Inference
# predictions, mask, selected = wrapper(X, return_selection=True)

print("✓ Template ready! Replace with your data and model.")