# Phase 4: Regression Training Validation (Task 8.5)
## January 27, 2023 11:40 AM

Validates that the multi-task regression head converges when trained on pre-computed DeBERTa embeddings.

**Prerequisites:**
- Upload `training_data_with_embeddings.parquet` to `/workspace/data/`
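- RunPod GPU instance with PyTorch
- A writable `/workspace/output/` directory (the best checkpoint, training-curve plot, and JSON summary are written there)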

In [None]:
# Install/upgrade dependencies - typing_extensions first to avoid import errors
!pip install -q --upgrade typing_extensions
!pip install -q torch pandas pyarrow scikit-learn scipy tqdm wandb matplotlib seaborn

# NOTE: If you get import errors after running this cell, restart the kernel and run again

In [None]:
import warnings
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise CPU
cuda_available = torch.cuda.is_available()
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Device: {device}")
if cuda_available:
    # Report the name and total memory of GPU 0
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Optional: Weights & Biases tracking. Later cells only log when USE_WANDB is True.
USE_WANDB = True

if USE_WANDB:
    import wandb

    wandb_run_settings = {
        "project": "white-regression-validation",
        "name": "phase4-regression-convergence",
        "tags": ["phase4", "regression", "runpod"],
    }
    wandb.init(**wandb_run_settings)

## Configuration

In [None]:
# Central configuration read by the data-loading, model, and training cells below.
CONFIG = {
    "data": {
        # Parquet file loaded with pd.read_parquet
        "parquet_path": "/workspace/data/training_data_with_embeddings.parquet",
        # Column the dataset reads each embedding vector from
        "embedding_column": "embedding",
        # Smoothing factor handed to SoftTargetGenerator via the dataset
        "label_smoothing": 0.1,
        # Passed as train_size to train_test_split; the remainder is validation
        "train_split": 0.8,
        # Passed as random_state to train_test_split
        "random_seed": 42,
    },
    "model": {
        # Input width of the first linear layer
        # NOTE(review): presumably must equal the stored embedding length — confirm
        "embedding_dim": 768,
        # Widths of the shared Linear -> ReLU -> Dropout layers, in order
        "hidden_dims": [256, 128],
        # Dropout probability applied after each shared layer
        "dropout": 0.3,
    },
    "training": {
        "batch_size": 64,
        # Maximum number of epochs; early stopping may end training sooner
        "epochs": 30,
        # lr and weight_decay are passed to torch.optim.AdamW
        "learning_rate": 1e-3,
        "weight_decay": 1e-4,
        # Consecutive epochs without an album_accuracy improvement before stopping
        "early_stopping_patience": 7,
        # Per-component multipliers used in the weighted loss sum
        "loss_weights": {
            "temporal": 1.0,
            "spatial": 1.0,
            "ontological": 1.0,
            "confidence": 0.5,
        },
    },
}

## Load Data

In [None]:
# Read the segment table (labels + pre-computed embeddings) from parquet
parquet_path = CONFIG['data']['parquet_path']
print(f"Loading data from {parquet_path}...")
df = pd.read_parquet(parquet_path)

n_segments = len(df)
column_names = list(df.columns)
print(f"Loaded {n_segments} segments with embeddings")
print(f"\nColumns: {column_names}")

In [None]:
# Check embedding shape
# Read the column name from CONFIG so this check inspects the same column
# the dataset will use, rather than a hard-coded 'embedding'.
embedding_col = CONFIG["data"]["embedding_column"]

# np.asarray handles lists, tuples and ndarrays alike (ndarrays pass through
# unchanged), so .shape/.dtype are always available.
sample_emb = np.asarray(df[embedding_col].iloc[0])
print(f"Embedding shape: {sample_emb.shape}")
print(f"Embedding dtype: {sample_emb.dtype}")

# Surface a width mismatch here instead of as a matmul shape error in training.
expected_dim = CONFIG["model"]["embedding_dim"]
if sample_emb.shape != (expected_dim,):
    print(
        f"WARNING: expected embeddings of shape ({expected_dim},), "
        f"got {sample_emb.shape}"
    )

In [None]:
# Print per-label counts for each Rainbow Table mode column present in df
print("\nRainbow Table Distribution:")
print("=" * 50)

mode_columns = (
    'rainbow_color_temporal_mode',
    'rainbow_color_objectional_mode',
    'rainbow_color_ontological_mode',
)
for mode_col in mode_columns:
    # Columns missing from this parquet are silently skipped
    if mode_col not in df.columns:
        continue
    print(f"\n{mode_col}:")
    print(df[mode_col].value_counts())

## Soft Target Generation

In [None]:
class SoftTargetGenerator:
    """Converts discrete Rainbow Table labels to continuous regression targets.

    Each label becomes a smoothed one-hot distribution over its mode list.
    A missing or unrecognised label becomes a uniform distribution.
    """

    TEMPORAL_MODES = ["Past", "Present", "Future"]
    SPATIAL_MODES = ["Thing", "Place", "Person"]
    ONTOLOGICAL_MODES = ["Imagined", "Forgotten", "Known"]

    def __init__(self, label_smoothing: float = 0.1):
        """Store the label-smoothing factor used by ``to_soft_target``."""
        self.smoothing = label_smoothing

    @staticmethod
    def _is_missing(label) -> bool:
        """Return True for None, NaN/NA, or the literal string "None"."""
        return label is None or pd.isna(label) or str(label) == "None"

    def to_soft_target(self, label: str, mode_list: List[str]) -> np.ndarray:
        """Convert a discrete label to a smoothed soft target.

        Args:
            label: Label value; may be None/NaN/"None" when absent.
            mode_list: Ordered class names; the output has one entry per name.

        Returns:
            Array of length ``len(mode_list)``. Uniform when the label is
            missing or not in ``mode_list``; otherwise
            ``(1 - smoothing) * one_hot + smoothing / len(mode_list)``.
        """
        n_modes = len(mode_list)
        # Sized from mode_list so the fallback always matches the requested
        # number of classes (previously hard-coded to three).
        uniform = np.full(n_modes, 1 / n_modes)

        if self._is_missing(label):
            return uniform

        try:
            idx = mode_list.index(str(label))
        except ValueError:
            # Label is present but not one of the known modes
            return uniform

        target = np.zeros(n_modes)
        target[idx] = 1.0
        return (1 - self.smoothing) * target + self.smoothing * (1 / n_modes)

    def generate_targets(self, row: pd.Series) -> Dict[str, np.ndarray]:
        """Generate all regression targets for a segment.

        Args:
            row: Segment row; the three ``rainbow_color_*_mode`` fields are
                read with ``row.get`` and may be absent.

        Returns:
            Dict with "temporal", "spatial" and "ontological" distributions
            plus a one-element "confidence" array: 0.0 when all three labels
            are missing, 1.0 otherwise.
        """
        temporal_label = row.get("rainbow_color_temporal_mode")
        spatial_label = row.get("rainbow_color_objectional_mode")
        ontological_label = row.get("rainbow_color_ontological_mode")

        # A segment with no label on any axis gets zero confidence
        is_black = all(
            self._is_missing(x)
            for x in [temporal_label, spatial_label, ontological_label]
        )

        return {
            "temporal": self.to_soft_target(temporal_label, self.TEMPORAL_MODES),
            "spatial": self.to_soft_target(spatial_label, self.SPATIAL_MODES),
            "ontological": self.to_soft_target(ontological_label, self.ONTOLOGICAL_MODES),
            "confidence": np.array([0.0 if is_black else 1.0]),
        }

## Dataset

In [None]:
class EmbeddingRegressionDataset(Dataset):
    """Map-style dataset pairing pre-computed embeddings with soft targets."""

    def __init__(self, df: pd.DataFrame, embedding_col: str, label_smoothing: float = 0.1):
        """Re-index ``df`` from zero and build every row's targets up front."""
        self.df = df.reset_index(drop=True)
        self.embedding_col = embedding_col
        self.target_generator = SoftTargetGenerator(label_smoothing)

        # Targets depend only on the label columns, so compute them once here
        print("Pre-computing soft targets...")
        precomputed = []
        for row_pos in tqdm(range(len(self.df)), desc="Generating targets"):
            precomputed.append(
                self.target_generator.generate_targets(self.df.iloc[row_pos])
            )
        self.targets = precomputed

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the embedding and the four target tensors for row ``idx``."""
        raw_embedding = self.df.iloc[idx][self.embedding_col]
        if isinstance(raw_embedding, list):
            raw_embedding = np.array(raw_embedding)

        sample = {"embedding": torch.tensor(raw_embedding, dtype=torch.float32)}

        row_targets = self.targets[idx]
        for name in ("temporal", "spatial", "ontological", "confidence"):
            sample[f"{name}_target"] = torch.tensor(
                row_targets[name], dtype=torch.float32
            )
        return sample

## Model

In [None]:
class RainbowRegressionHead(nn.Module):
    """Regression head for Rainbow Table ontological mode prediction.

    A shared stack of Linear -> ReLU -> Dropout layers feeds four linear
    heads: three 3-way softmax heads (temporal, spatial, ontological) and
    one sigmoid confidence head.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.3,
    ):
        """Build the shared trunk and the task heads.

        Args:
            embedding_dim: Width of the input embedding vectors.
            hidden_dims: Widths of the shared hidden layers, in order.
                ``None`` means ``[256, 128]``. An empty list yields no
                shared layers, so the heads read the embedding directly.
            dropout: Dropout probability after each shared layer.
        """
        super().__init__()

        # Avoid a mutable default argument shared across instances
        if hidden_dims is None:
            hidden_dims = [256, 128]

        # Shared layers
        layers = []
        in_dim = embedding_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_dim = hidden_dim

        self.shared = nn.Sequential(*layers)

        # Task-specific heads. in_dim is the trunk's output width; using it
        # instead of hidden_dims[-1] keeps an empty hidden_dims valid.
        self.temporal_head = nn.Linear(in_dim, 3)
        self.spatial_head = nn.Linear(in_dim, 3)
        self.ontological_head = nn.Linear(in_dim, 3)
        self.confidence_head = nn.Linear(in_dim, 1)

    def forward(self, embeddings: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass returning softmax distributions and confidence.

        Args:
            embeddings: Tensor whose last dimension is ``embedding_dim``.

        Returns:
            Dict with "temporal", "spatial" and "ontological" (softmax over
            a last dimension of 3) and "confidence" (sigmoid, last
            dimension of 1).
        """
        x = self.shared(embeddings)

        return {
            "temporal": F.softmax(self.temporal_head(x), dim=-1),
            "spatial": F.softmax(self.spatial_head(x), dim=-1),
            "ontological": F.softmax(self.ontological_head(x), dim=-1),
            "confidence": torch.sigmoid(self.confidence_head(x)),
        }

## Loss and Metrics

In [None]:
class MultiTaskRegressionLoss(nn.Module):
    """Weighted sum of three KL-divergence terms and one BCE term."""

    def __init__(self, weights: Dict[str, float]):
        """Keep the per-component weights and build the two criteria."""
        super().__init__()
        self.weights = weights
        self.kl_div = nn.KLDivLoss(reduction="batchmean")
        self.bce = nn.BCELoss()

    def forward(
        self, predictions: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return ``(total, per_component)`` for one batch.

        ``predictions`` is keyed by component name; ``targets`` uses the
        same names with a ``_target`` suffix.
        """
        # KLDivLoss takes log-probabilities; clamp first so log() stays finite
        losses = {
            dim: self.kl_div(
                predictions[dim].clamp(min=1e-8).log(),
                targets[f"{dim}_target"],
            )
            for dim in ("temporal", "spatial", "ontological")
        }

        # Confidence predictions are already probabilities, so plain BCE applies
        losses["confidence"] = self.bce(
            predictions["confidence"], targets["confidence_target"]
        )

        # Accumulate the weighted components
        total = 0
        for name, value in losses.items():
            total = total + self.weights[name] * value

        return total, losses


def compute_metrics(
    predictions: Dict[str, torch.Tensor],
    targets: Dict[str, torch.Tensor],
) -> Dict[str, float]:
    """Compute MAE, per-dimension mode accuracy and joint album accuracy.

    Args:
        predictions: Tensors keyed "temporal", "spatial", "ontological"
            and "confidence".
        targets: Tensors under the same names with a "_target" suffix.

    Returns:
        ``{dim}_mae`` and ``{dim}_mode_accuracy`` for each distribution,
        ``confidence_mae``, and ``album_accuracy`` (fraction of rows whose
        argmax matches on all three distributions at once).
    """
    # NOTE(review): a uniform target has no unique maximum, so argmax picks
    # index 0 for it. Rows without labels therefore look like they are
    # scored against the first mode — confirm this is intended.
    dims = ("temporal", "spatial", "ontological")
    metrics = {}

    for dim in dims:
        pred_np = predictions[dim].cpu().numpy()
        targ_np = targets[f"{dim}_target"].cpu().numpy()

        # Element-wise absolute error over the flattened distributions
        metrics[f"{dim}_mae"] = mean_absolute_error(
            targ_np.flatten(), pred_np.flatten()
        )

        # Fraction of rows where the most likely mode agrees
        mode_hits = pred_np.argmax(axis=1) == targ_np.argmax(axis=1)
        metrics[f"{dim}_mode_accuracy"] = mode_hits.mean()

    # Confidence is a single value per row
    metrics["confidence_mae"] = mean_absolute_error(
        targets["confidence_target"].cpu().numpy().flatten(),
        predictions["confidence"].cpu().numpy().flatten(),
    )

    # A row counts for album accuracy only when all three modes agree
    all_correct = None
    for dim in dims:
        dim_hits = predictions[dim].argmax(dim=-1) == targets[f"{dim}_target"].argmax(dim=-1)
        all_correct = dim_hits if all_correct is None else all_correct & dim_hits
    metrics["album_accuracy"] = all_correct.float().mean().item()

    return metrics

## Training Functions

In [None]:
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Dict[str, float]:
    """Run one optimisation pass over ``loader``.

    Returns:
        Batch-averaged ``train_loss`` plus one ``train_{name}_loss`` entry
        per loss component.
    """
    model.train()

    loss_names = ("temporal", "spatial", "ontological", "confidence")
    running_total = 0
    running_parts = dict.fromkeys(loss_names, 0)

    for batch in tqdm(loader, desc="Training", leave=False):
        embeddings = batch["embedding"].to(device)
        # Every batch entry whose key contains "target" is a target tensor
        targets = {}
        for key, tensor in batch.items():
            if "target" in key:
                targets[key] = tensor.to(device)

        predictions = model(embeddings)
        loss, part_losses = criterion(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_total += loss.item()
        for name, part in part_losses.items():
            running_parts[name] += part.item()

    n_batches = len(loader)
    epoch_stats = {"train_loss": running_total / n_batches}
    for name, summed in running_parts.items():
        epoch_stats[f"train_{name}_loss"] = summed / n_batches
    return epoch_stats


def validate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[Dict[str, float], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Batches with an "embedding" entry and "*_target" entries.
        criterion: Callable returning ``(total_loss, component_losses)``.
        device: Device the batch tensors are moved to.

    Returns:
        A 3-tuple ``(metrics, predictions, targets)``. ``metrics`` is the
        output of ``compute_metrics`` plus a batch-averaged ``val_loss``.
        The other two hold the per-batch tensors concatenated over the
        whole loader.
    """
    model.eval()
    total_loss = 0
    pred_keys = ["temporal", "spatial", "ontological", "confidence"]
    all_preds = {k: [] for k in pred_keys}
    all_targets = {f"{k}_target": [] for k in pred_keys}

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating", leave=False):
            embeddings = batch["embedding"].to(device)
            targets = {k: v.to(device) for k, v in batch.items() if "target" in k}

            predictions = model(embeddings)
            loss, _ = criterion(predictions, targets)

            total_loss += loss.item()

            # Keep per-batch outputs so metrics cover the whole split
            for k in all_preds:
                all_preds[k].append(predictions[k])
            for k in all_targets:
                all_targets[k].append(targets[k])

    # Concatenate along the batch dimension
    preds_cat = {k: torch.cat(v) for k, v in all_preds.items()}
    targets_cat = {k: torch.cat(v) for k, v in all_targets.items()}

    # Compute metrics
    metrics = compute_metrics(preds_cat, targets_cat)
    metrics["val_loss"] = total_loss / len(loader)

    return metrics, preds_cat, targets_cat

## Prepare Data

In [None]:
# Split row positions into train/validation using the configured ratio and seed
split_cfg = CONFIG["data"]
train_idx, val_idx = train_test_split(
    range(len(df)),
    train_size=split_cfg["train_split"],
    random_state=split_cfg["random_seed"],
)

train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

for split_name, split_df in (("Train", train_df), ("Val", val_df)):
    print(f"{split_name}: {len(split_df)} segments")

In [None]:
# Build train/validation datasets with the same embedding column and smoothing
dataset_args = (
    CONFIG["data"]["embedding_column"],
    CONFIG["data"]["label_smoothing"],
)
train_dataset = EmbeddingRegressionDataset(train_df, *dataset_args)
val_dataset = EmbeddingRegressionDataset(val_df, *dataset_args)

In [None]:
# Create dataloaders (num_workers=0 to avoid multiprocessing issues in Jupyter)
shared_loader_kwargs = dict(
    batch_size=CONFIG["training"]["batch_size"],
    num_workers=0,
    pin_memory=True,
)

# Only the training loader shuffles; validation keeps a fixed order
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_kwargs)

## Initialize Model

In [None]:
# Instantiate the regression head from CONFIG and move it to the chosen device
model_cfg = CONFIG["model"]
model = RainbowRegressionHead(
    embedding_dim=model_cfg["embedding_dim"],
    hidden_dims=model_cfg["hidden_dims"],
    dropout=model_cfg["dropout"],
)
model = model.to(device)

# Report the total parameter count and the layer layout
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")
print(model)

In [None]:
# Loss, optimizer and LR scheduler
train_cfg = CONFIG["training"]

criterion = MultiTaskRegressionLoss(train_cfg["loss_weights"])

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_cfg["learning_rate"],
    weight_decay=train_cfg["weight_decay"],
)

# The scheduler is stepped with album_accuracy, so "max" mode: higher is better
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    patience=3,
    factor=0.5,
)

## Training Loop

In [None]:
import os

print("=" * 80)
print("TRAINING (early stopping on album_accuracy)")
print("=" * 80)

# Create the output directory up front so the first checkpoint save cannot
# fail on a missing path after an epoch has already been trained.
CHECKPOINT_PATH = "/workspace/output/regression_validation_best.pt"
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

best_album_accuracy = 0.0
best_metrics = {}
patience_counter = 0
history = []

for epoch in range(CONFIG["training"]["epochs"]):
    print(f"\nEpoch {epoch + 1}/{CONFIG['training']['epochs']}")

    # Train
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_metrics, val_preds, val_targets = validate(model, val_loader, criterion, device)

    # Combine metrics
    all_metrics = {**train_metrics, **val_metrics}
    all_metrics["epoch"] = epoch
    history.append(all_metrics)

    # Log to wandb
    if USE_WANDB:
        wandb.log(all_metrics)

    # Print progress
    print(f"  Train Loss: {train_metrics['train_loss']:.4f}")
    print(f"  Val Loss: {val_metrics['val_loss']:.4f}")
    print(f"  Temporal Mode Acc: {val_metrics.get('temporal_mode_accuracy', 0):.3f}")
    print(f"  Spatial Mode Acc: {val_metrics.get('spatial_mode_accuracy', 0):.3f}")
    print(f"  Ontological Mode Acc: {val_metrics.get('ontological_mode_accuracy', 0):.3f}")
    print(f"  Album Accuracy: {val_metrics.get('album_accuracy', 0):.3f}")

    # Learning rate scheduling based on album accuracy
    current_album_acc = val_metrics.get("album_accuracy", 0)
    scheduler.step(current_album_acc)

    # Early stopping based on album accuracy (higher is better).
    # The first evaluated epoch always becomes the baseline. Otherwise a run
    # whose album accuracy never rises above 0.0 would save no checkpoint and
    # leave best_metrics empty, and the later report cells would show zeros.
    is_first_result = not best_metrics
    if is_first_result or current_album_acc > best_album_accuracy:
        best_album_accuracy = current_album_acc
        best_metrics = val_metrics.copy()
        patience_counter = 0

        # Save checkpoint
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "album_accuracy": best_album_accuracy,
            "metrics": best_metrics,
            "config": CONFIG,
        }, CHECKPOINT_PATH)
        print(f"  Saved best model (album_accuracy={best_album_accuracy:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= CONFIG["training"]["early_stopping_patience"]:
            print(f"\nEarly stopping triggered (patience={patience_counter})")
            break

print(f"\nBest album accuracy: {best_album_accuracy:.4f}")

## Results Visualization

In [None]:
# Plot training curves from the per-epoch `history` records
import os

# Make sure the destination exists before savefig; nothing else in this cell
# creates it.
CURVES_PATH = '/workspace/output/regression_training_curves.png'
os.makedirs(os.path.dirname(CURVES_PATH), exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss (train vs. validation)
ax = axes[0, 0]
ax.plot([h['train_loss'] for h in history], label='Train')
ax.plot([h['val_loss'] for h in history], label='Val')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.legend()
ax.grid(True)

# Mode Accuracies (one line per dimension)
ax = axes[0, 1]
for dim in ['temporal', 'spatial', 'ontological']:
    ax.plot([h[f'{dim}_mode_accuracy'] for h in history], label=dim.capitalize())
ax.set_xlabel('Epoch')
ax.set_ylabel('Mode Accuracy')
ax.set_title('Mode Accuracies')
ax.legend()
ax.grid(True)

# Album Accuracy (all three modes correct at once)
ax = axes[1, 0]
ax.plot([h['album_accuracy'] for h in history], color='purple')
ax.set_xlabel('Epoch')
ax.set_ylabel('Album Accuracy')
ax.set_title('Album Prediction Accuracy')
ax.grid(True)

# MAE (one line per dimension)
ax = axes[1, 1]
for dim in ['temporal', 'spatial', 'ontological']:
    ax.plot([h[f'{dim}_mae'] for h in history], label=dim.capitalize())
ax.set_xlabel('Epoch')
ax.set_ylabel('MAE')
ax.set_title('Mean Absolute Error')
ax.legend()
ax.grid(True)

plt.tight_layout()
plt.savefig(CURVES_PATH, dpi=150)
plt.show()

## Convergence Assessment

In [None]:
# Summarise the best epoch and grade it against fixed accuracy thresholds
banner = "=" * 80
print(banner)
print("CONVERGENCE ASSESSMENT (Task 8.5)")
print(banner)

print(f"\nBest Album Accuracy: {best_album_accuracy:.4f}")
print("\nBest Metrics:")
for metric_name, metric_value in sorted(best_metrics.items()):
    # Only float-typed values are printed
    if not isinstance(metric_value, float):
        continue
    print(f"  {metric_name}: {metric_value:.4f}")

# Convergence checks
rule = "-" * 40
print("\n" + rule)
print("CONVERGENCE CHECKS")
print(rule)

converged = True
issues = []

# Per-dimension mode accuracy: below 0.5 fails, below 0.7 only warns
for dim in ["temporal", "spatial", "ontological"]:
    acc = best_metrics.get(f"{dim}_mode_accuracy", 0)
    if acc >= 0.7:
        continue
    if acc < 0.5:
        converged = False
        issues.append(f"{dim} mode accuracy ({acc:.3f}) below 0.5")
    else:
        issues.append(f"{dim} mode accuracy ({acc:.3f}) below 0.7 (warn)")

# Album accuracy: below 0.3 fails, below 0.7 only warns
if best_album_accuracy < 0.3:
    converged = False
    issues.append(f"Album accuracy ({best_album_accuracy:.3f}) below 0.3")
elif best_album_accuracy < 0.7:
    issues.append(f"Album accuracy ({best_album_accuracy:.3f}) below 0.7 (warn)")

# Validation-loss reduction from first to last epoch (informational only,
# and only evaluated when more than five epochs were recorded)
if len(history) > 5:
    initial_loss = history[0]["val_loss"]
    final_loss = history[-1]["val_loss"]
    if initial_loss > 0:
        reduction = (initial_loss - final_loss) / initial_loss
    else:
        reduction = 0
    if reduction < 0.1:
        issues.append(f"Loss reduction ({reduction:.1%}) below 10% (info only)")

print()
if converged and not issues:
    print("CONVERGENCE VALIDATED - Model training successful!")
else:
    print("CONVERGENCE WITH WARNINGS:" if converged else "CONVERGENCE FAILED:")
    for issue in issues:
        print(f"   - {issue}")

In [None]:
# Push the final summary and the saved curve image to W&B, then close the run
if USE_WANDB:
    final_summary = {
        "converged": converged,
        "num_issues": len(issues),
        "best_album_accuracy": best_album_accuracy,
        "best_val_loss": best_metrics.get("val_loss", 0),
    }
    for mode_name in ("temporal", "spatial", "ontological"):
        metric_key = f"{mode_name}_mode_accuracy"
        final_summary[f"best_{metric_key}"] = best_metrics.get(metric_key, 0)
    wandb.log(final_summary)

    # Log images
    curves_png = "/workspace/output/regression_training_curves.png"
    wandb.log({"training_curves": wandb.Image(curves_png)})

    wandb.finish()
    print("\nResults logged to W&B")

## Save Final Results

In [None]:
# Save results summary as JSON next to the checkpoint and plots
import json
import os

RESULTS_PATH = "/workspace/output/regression_validation_results.json"
# Create the output directory if needed so the write below cannot fail on a
# missing path.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)

results = {
    "task": "8.5 - Regression Convergence Validation",
    "converged": converged,
    "issues": issues,
    "best_album_accuracy": best_album_accuracy,
    "best_val_loss": best_metrics.get("val_loss", None),
    # Cast numeric values (including numpy floats) to plain Python floats
    "best_metrics": {k: float(v) if isinstance(v, (int, float, np.floating)) else v for k, v in best_metrics.items()},
    "epochs_trained": len(history),
    "config": CONFIG,
}

# Explicit encoding keeps the file identical across platform defaults
with open(RESULTS_PATH, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print("\nResults saved to /workspace/output/regression_validation_results.json")
print("\nTask 8.5 Complete!")