# Task 8.6: Multi-Task vs Single-Task Comparison
## January 27, 2023 11:15 AM
This notebook compares:
1. **Single-task classification** (already done: 100% accuracy)
2. **Single-task regression** (already done: 98.4% mode acc, 96.9% album acc)
3. **Multi-task** (both heads trained simultaneously)

The multi-task model is trained on pre-computed DeBERTa embeddings.

In [None]:
# STEP 1: Upgrade typing_extensions FIRST, then restart the kernel.
# NOTE(review): presumably the restart is needed so the kernel picks up the
# upgraded package instead of an older copy that is already loaded — confirm.
!pip install -q --upgrade 'typing_extensions>=4.12.0'

# Remind the user of the manual follow-up steps; this cell does not restart
# the kernel itself.
print("typing_extensions upgraded. Now restart kernel:")
print("  Kernel -> Restart Kernel")
print("Then SKIP this cell and run from cell 2")

In [None]:
# STEP 2: Install other dependencies (run this AFTER kernel restart).
# Installs the third-party packages listed below in one quiet (-q) pip call.
!pip install -q torch pandas pyarrow scikit-learn scipy tqdm wandb matplotlib seaborn

In [None]:
# Weights & Biases tracking is optional: set the flag to False to run the
# notebook without importing wandb or starting a run.
USE_WANDB = True
if USE_WANDB:
    import wandb

    wandb.init(
        name="task-8.6-comparison",
        project="white-multitask-comparison",
    )

## Configuration

In [None]:
# Experiment settings, grouped by concern. The "baseline" group holds the
# single-task results that the multi-task run is compared against.
CONFIG = dict(
    data=dict(
        parquet_path="/workspace/data/training_data_with_embeddings.parquet",
        embedding_column="embedding",
        label_smoothing=0.1,
        train_split=0.8,
        random_seed=42,
    ),
    model=dict(
        embedding_dim=768,
        hidden_dims=[256, 128],
        dropout=0.3,
        num_classes=8,  # one class per rebracketing type
    ),
    training=dict(
        batch_size=64,
        epochs=30,
        learning_rate=1e-3,
        weight_decay=1e-4,
        early_stopping_patience=7,
    ),
    baseline=dict(
        single_task_classification_accuracy=1.0,
        single_task_regression_temporal_acc=0.984,
        single_task_regression_spatial_acc=0.984,
        single_task_regression_ontological_acc=0.985,
        single_task_regression_album_acc=0.969,
    ),
)

# Rebracketing type name -> class index; the tuple order defines the indices.
CLASS_MAPPING = {
    type_name: class_index
    for class_index, type_name in enumerate(
        (
            "spatial", "temporal", "causal", "perceptual",
            "memory", "ontological", "narrative", "identity",
        )
    )
}
# Reverse lookup: class index -> rebracketing type name.
IDX_TO_CLASS = dict(zip(CLASS_MAPPING.values(), CLASS_MAPPING.keys()))

## Load Data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Read the segment table from the configured parquet file.
df = pd.read_parquet(CONFIG['data']['parquet_path'])
print(f"Loaded {len(df)} segments")

# The classification head needs a rebracketing_type column: report its
# distribution when present, otherwise fill in a constant placeholder label.
if 'rebracketing_type' in df.columns:
    print("\nRebracketing type distribution:")
    print(df['rebracketing_type'].value_counts())
else:
    print("WARNING: rebracketing_type not in data, will use 'temporal' as default")
    df['rebracketing_type'] = 'temporal'

## Dataset with Both Classification and Regression Targets

In [None]:
class SoftTargetGenerator:
    """Turn categorical mode labels into label-smoothed probability targets.

    Each of the three mode dimensions (temporal, spatial, ontological) has a
    fixed list of mode names. A recognised label becomes a smoothed one-hot
    vector over that list; a missing or unrecognised label becomes a uniform
    distribution over it.
    """

    TEMPORAL_MODES = ["Past", "Present", "Future"]
    SPATIAL_MODES = ["Thing", "Place", "Person"]
    ONTOLOGICAL_MODES = ["Imagined", "Forgotten", "Known"]

    def __init__(self, label_smoothing: float = 0.1):
        """Store the smoothing factor applied to recognised labels."""
        self.smoothing = label_smoothing

    @staticmethod
    def _is_missing(label) -> bool:
        """Return True for None, NaN/NA, or the literal string "None"."""
        return label is None or pd.isna(label) or str(label) == "None"

    def to_soft_target(self, label, mode_list):
        """Convert one label into a probability vector over ``mode_list``.

        Args:
            label: Mode name, or a missing value (None, NaN, or "None").
            mode_list: Ordered list of valid mode names.

        Returns:
            A numpy array of length ``len(mode_list)``. Missing and
            unrecognised labels give the uniform distribution; a recognised
            label gives ``(1 - smoothing) * one_hot + smoothing / len(mode_list)``.
        """
        n_modes = len(mode_list)
        # Sized from mode_list so the fallback also fits lists that do not
        # have exactly three modes.
        uniform = np.full(n_modes, 1.0 / n_modes)
        if self._is_missing(label) or str(label) not in mode_list:
            return uniform
        target = np.zeros(n_modes)
        target[mode_list.index(str(label))] = 1.0
        return (1 - self.smoothing) * target + self.smoothing * uniform

    def generate(self, row):
        """Build all regression targets for one data row.

        Args:
            row: Any object with a dict-style ``.get(key)`` (for example a
                pandas Series or a dict) holding the three
                ``rainbow_color_*_mode`` labels.

        Returns:
            Dict with "temporal", "spatial" and "ontological" probability
            vectors, plus "confidence": a length-1 array that is 0.0 when all
            three labels are missing and 1.0 otherwise.
        """
        temporal_label = row.get("rainbow_color_temporal_mode")
        spatial_label = row.get("rainbow_color_objectional_mode")
        ontological_label = row.get("rainbow_color_ontological_mode")

        # A row with no mode labels at all carries no supervision signal, so
        # its confidence target is zero. An unrecognised (but present) label
        # still counts as present here.
        is_black = all(
            self._is_missing(label)
            for label in (temporal_label, spatial_label, ontological_label)
        )
        return {
            "temporal": self.to_soft_target(temporal_label, self.TEMPORAL_MODES),
            "spatial": self.to_soft_target(spatial_label, self.SPATIAL_MODES),
            "ontological": self.to_soft_target(ontological_label, self.ONTOLOGICAL_MODES),
            "confidence": np.array([0.0 if is_black else 1.0]),
        }


class MultiTaskDataset(Dataset):
    """Dataset yielding an embedding plus classification and regression targets.

    The soft regression targets are computed once for every row at
    construction time; the embedding and class label are read from the
    dataframe on each item access.
    """

    def __init__(self, df, embedding_col, class_mapping, label_smoothing=0.1):
        """Keep a re-indexed copy of ``df`` and pre-compute soft targets.

        Args:
            df: Dataframe with the embedding column and label columns.
            embedding_col: Name of the column holding the embedding vector.
            class_mapping: Dict from lower-cased rebracketing type to class index.
            label_smoothing: Smoothing factor handed to SoftTargetGenerator.
        """
        self.df = df.reset_index(drop=True)
        self.embedding_col = embedding_col
        self.class_mapping = class_mapping
        self.target_gen = SoftTargetGenerator(label_smoothing)

        print("Pre-computing targets...")
        precomputed = []
        for position in tqdm(range(len(df))):
            precomputed.append(self.target_gen.generate(self.df.iloc[position]))
        self.regression_targets = precomputed

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the tensors for row ``idx`` as a dict."""
        record = self.df.iloc[idx]

        vector = record[self.embedding_col]
        if isinstance(vector, list):
            vector = np.array(vector)

        # Unknown or missing rebracketing types fall back to class index 1
        # (the "temporal" entry in the notebook's CLASS_MAPPING).
        type_name = str(record.get('rebracketing_type', 'temporal')).lower()
        label_index = self.class_mapping.get(type_name, 1)

        soft_targets = self.regression_targets[idx]

        sample = {
            "embedding": torch.tensor(vector, dtype=torch.float32),
            "class_label": torch.tensor(label_index, dtype=torch.long),
        }
        for name in ("temporal", "spatial", "ontological", "confidence"):
            sample[f"{name}_target"] = torch.tensor(soft_targets[name], dtype=torch.float32)
        return sample

## Multi-Task Model

In [None]:
class MultiTaskHead(nn.Module):
    """Combined classification + regression head over a shared MLP trunk.

    The trunk is a stack of Linear -> ReLU -> Dropout layers. Five linear
    heads read the trunk output: a class-logit head, three 3-way mode heads
    (softmax-normalised in ``forward``) and a scalar confidence head
    (sigmoid-normalised in ``forward``).
    """

    def __init__(self, embedding_dim=768, hidden_dims=(256, 128), num_classes=8, dropout=0.3):
        """Build the shared trunk and the task-specific heads.

        Args:
            embedding_dim: Width of the input embedding.
            hidden_dims: Sequence of hidden layer widths for the trunk. May be
                empty, in which case the heads read the embedding directly.
            num_classes: Number of classification classes.
            dropout: Dropout probability after each hidden layer.
        """
        super().__init__()

        # Shared layers
        layers = []
        in_dim = embedding_dim
        for h in hidden_dims:
            layers.extend([nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)])
            in_dim = h
        self.shared = nn.Sequential(*layers)

        # After the loop in_dim is the trunk's output width; using it instead
        # of hidden_dims[-1] keeps an empty hidden_dims from raising IndexError.
        trunk_dim = in_dim

        # Task-specific heads
        self.classifier = nn.Linear(trunk_dim, num_classes)
        self.temporal_head = nn.Linear(trunk_dim, 3)
        self.spatial_head = nn.Linear(trunk_dim, 3)
        self.ontological_head = nn.Linear(trunk_dim, 3)
        self.confidence_head = nn.Linear(trunk_dim, 1)

    def forward(self, x):
        """Run the trunk and all heads.

        Args:
            x: Input tensor whose last dimension is ``embedding_dim``.

        Returns:
            Dict with "class_logits" (raw logits), "temporal", "spatial" and
            "ontological" (softmax probabilities over the last dimension) and
            "confidence" (sigmoid output).
        """
        h = self.shared(x)
        return {
            "class_logits": self.classifier(h),
            "temporal": F.softmax(self.temporal_head(h), dim=-1),
            "spatial": F.softmax(self.spatial_head(h), dim=-1),
            "ontological": F.softmax(self.ontological_head(h), dim=-1),
            "confidence": torch.sigmoid(self.confidence_head(h)),
        }

## Multi-Task Loss

In [None]:
class MultiTaskLoss(nn.Module):
    """Weighted sum of the classification loss and the regression losses.

    Classification uses cross-entropy on logits. Each mode head uses KL
    divergence between the predicted probabilities and the soft target.
    Confidence uses binary cross-entropy on the sigmoid output.
    """

    def __init__(self, cls_weight=1.0, reg_weight=1.0, confidence_weight=0.5, eps=1e-8):
        """Configure the loss weights.

        Args:
            cls_weight: Multiplier for the classification loss.
            reg_weight: Multiplier for the combined regression loss.
            confidence_weight: Weight of the confidence BCE term inside the
                regression loss (previously a fixed 0.5).
            eps: Lower clamp applied to predicted probabilities before the
                log, so a zero probability cannot produce ``-inf``.
        """
        super().__init__()
        self.cls_weight = cls_weight
        self.reg_weight = reg_weight
        self.confidence_weight = confidence_weight
        self.eps = eps
        self.ce = nn.CrossEntropyLoss()
        self.kl = nn.KLDivLoss(reduction="batchmean")
        self.bce = nn.BCELoss()

    def _mode_kl(self, probs, target):
        """KL divergence for one mode head; KLDivLoss expects log-probabilities."""
        return self.kl(probs.clamp(min=self.eps).log(), target)

    def forward(self, preds, targets):
        """Compute the total loss and its components.

        Args:
            preds: Dict with "class_logits", "temporal", "spatial",
                "ontological" (probabilities) and "confidence".
            targets: Dict with "class_label" and the matching "*_target" keys.

        Returns:
            Tuple ``(total, parts)``. ``total`` is
            ``cls_weight * cls + reg_weight * (temporal + spatial + ontological
            + confidence_weight * confidence)``. ``parts`` maps "cls",
            "temporal", "spatial", "ontological" and "confidence" to the
            individual (unweighted) loss tensors.
        """
        # Classification loss
        cls_loss = self.ce(preds["class_logits"], targets["class_label"])

        # Regression losses
        t_loss = self._mode_kl(preds["temporal"], targets["temporal_target"])
        s_loss = self._mode_kl(preds["spatial"], targets["spatial_target"])
        o_loss = self._mode_kl(preds["ontological"], targets["ontological_target"])
        c_loss = self.bce(preds["confidence"], targets["confidence_target"])

        reg_loss = t_loss + s_loss + o_loss + self.confidence_weight * c_loss
        total = self.cls_weight * cls_loss + self.reg_weight * reg_loss

        return total, {"cls": cls_loss, "temporal": t_loss, "spatial": s_loss,
                       "ontological": o_loss, "confidence": c_loss}

## Training and Evaluation

In [None]:
def compute_metrics(preds, targets):
    """Compute classification and mode-prediction metrics.

    Args:
        preds: Dict with "class_logits" and the "temporal", "spatial" and
            "ontological" prediction tensors.
        targets: Dict with "class_label" and the matching "*_target" tensors.

    Returns:
        Dict with "classification_accuracy", "classification_f1" (macro),
        one "<dim>_mode_accuracy" per mode dimension, and "album_accuracy"
        (fraction of rows where all three predicted modes match).
    """
    predicted_class = preds["class_logits"].argmax(dim=-1).cpu().numpy()
    actual_class = targets["class_label"].cpu().numpy()
    metrics = {
        "classification_accuracy": accuracy_score(actual_class, predicted_class),
        "classification_f1": f1_score(actual_class, predicted_class, average="macro"),
    }

    # NOTE(review): soft targets are arg-maxed as they are, so a uniform
    # (missing-label) target resolves to whichever index argmax returns for
    # ties — confirm such rows are meant to be scored.
    all_dims_correct = None
    for dim in ("temporal", "spatial", "ontological"):
        predicted_mode = preds[dim].argmax(dim=-1)
        actual_mode = targets[f"{dim}_target"].argmax(dim=-1)

        hits = predicted_mode.cpu().numpy() == actual_mode.cpu().numpy()
        metrics[f"{dim}_mode_accuracy"] = hits.mean()

        # Accumulate the per-row conjunction across dimensions for album accuracy.
        row_match = predicted_mode == actual_mode
        all_dims_correct = row_match if all_dims_correct is None else all_dims_correct & row_match

    metrics["album_accuracy"] = all_dims_correct.float().mean().item()
    return metrics


def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Module called on the batch's "embedding" tensor.
        loader: Iterable of dict batches; every key except "embedding" is
            passed to ``criterion`` as a target.
        criterion: Callable returning ``(loss, components)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Mean of the per-batch loss values.
    """
    model.train()
    running_loss = 0
    for batch in tqdm(loader, desc="Training", leave=False):
        inputs = batch["embedding"].to(device)
        labels = {
            name: tensor.to(device)
            for name, tensor in batch.items()
            if name != "embedding"
        }

        batch_loss, _ = criterion(model(inputs), labels)

        optimizer.zero_grad()
        batch_loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(loader)


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Predictions and targets from every batch are concatenated and scored in
    one call to ``compute_metrics``.

    Returns:
        The metrics dict from ``compute_metrics`` plus "val_loss", the mean
        of the per-batch loss values.
    """
    model.eval()
    pred_keys = ("class_logits", "temporal", "spatial", "ontological", "confidence")
    target_keys = ("class_label", "temporal_target", "spatial_target",
                   "ontological_target", "confidence_target")
    collected_preds = {key: [] for key in pred_keys}
    collected_targets = {key: [] for key in target_keys}
    loss_sum = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating", leave=False):
            inputs = batch["embedding"].to(device)
            labels = {
                name: tensor.to(device)
                for name, tensor in batch.items()
                if name != "embedding"
            }

            outputs = model(inputs)
            batch_loss, _ = criterion(outputs, labels)
            loss_sum += batch_loss.item()

            for key, bucket in collected_preds.items():
                bucket.append(outputs[key])
            for key, bucket in collected_targets.items():
                bucket.append(labels[key])

    metrics = compute_metrics(
        {key: torch.cat(chunks) for key, chunks in collected_preds.items()},
        {key: torch.cat(chunks) for key, chunks in collected_targets.items()},
    )
    metrics["val_loss"] = loss_sum / len(loader)
    return metrics

## Prepare Data

In [None]:
# Split row positions (not labels) with a fixed seed so the partition is
# reproducible, then slice the dataframe by position.
train_idx, val_idx = train_test_split(
    range(len(df)),
    random_state=CONFIG["data"]["random_seed"],
    train_size=CONFIG["data"]["train_split"],
)

train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

print(f"Train: {len(train_df)}, Val: {len(val_df)}")

In [None]:
# Build the datasets with the label smoothing configured in CONFIG["data"]
# instead of relying on the dataset's own default value.
train_dataset = MultiTaskDataset(
    train_df,
    CONFIG["data"]["embedding_column"],
    CLASS_MAPPING,
    label_smoothing=CONFIG["data"]["label_smoothing"],
)
val_dataset = MultiTaskDataset(
    val_df,
    CONFIG["data"]["embedding_column"],
    CLASS_MAPPING,
    label_smoothing=CONFIG["data"]["label_smoothing"],
)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=CONFIG["training"]["batch_size"], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG["training"]["batch_size"], shuffle=False, num_workers=0)

## Train Multi-Task Model

In [None]:
# The keys of CONFIG["model"] (embedding_dim, hidden_dims, dropout,
# num_classes) are exactly the constructor's keyword arguments.
model = MultiTaskHead(**CONFIG["model"]).to(device)

print(f"Parameters: {sum(param.numel() for param in model.parameters()):,}")

# Equal weighting of the classification and regression objectives.
criterion = MultiTaskLoss(reg_weight=1.0, cls_weight=1.0)
optimizer = torch.optim.AdamW(
    model.parameters(),
    weight_decay=CONFIG["training"]["weight_decay"],
    lr=CONFIG["training"]["learning_rate"],
)
# Halve the learning rate once the monitored value has not improved for
# more than 3 consecutive epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)

In [None]:
import os

print("=" * 80)
print("MULTI-TASK TRAINING")
print("=" * 80)

# torch.save does not create missing parent directories, so make sure the
# output directory exists before the first checkpoint is written.
checkpoint_path = "/workspace/output/multitask_best.pt"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

best_val_loss = float("inf")
best_metrics = {}  # validation metrics from the epoch with the lowest val_loss
patience = 0       # consecutive epochs without a val_loss improvement
history = []       # one dict per epoch: train_loss plus all validation metrics

for epoch in range(CONFIG["training"]["epochs"]):
    print(f"\nEpoch {epoch+1}/{CONFIG['training']['epochs']}")

    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_metrics = validate(model, val_loader, criterion, device)

    history.append({"train_loss": train_loss, **val_metrics})

    if USE_WANDB:
        wandb.log({"train_loss": train_loss, **val_metrics, "epoch": epoch})

    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_metrics['val_loss']:.4f}")
    print(f"  Classification Acc: {val_metrics['classification_accuracy']:.3f}")
    print(f"  Temporal Mode Acc: {val_metrics['temporal_mode_accuracy']:.3f}")
    print(f"  Spatial Mode Acc: {val_metrics['spatial_mode_accuracy']:.3f}")
    print(f"  Ontological Mode Acc: {val_metrics['ontological_mode_accuracy']:.3f}")
    print(f"  Album Acc: {val_metrics['album_accuracy']:.3f}")

    # The scheduler monitors validation loss.
    scheduler.step(val_metrics["val_loss"])

    if val_metrics["val_loss"] < best_val_loss:
        # New best epoch: remember its metrics and checkpoint the weights.
        best_val_loss = val_metrics["val_loss"]
        best_metrics = val_metrics.copy()
        patience = 0
        torch.save(model.state_dict(), checkpoint_path)
        print("  Saved best model")
    else:
        patience += 1
        if patience >= CONFIG["training"]["early_stopping_patience"]:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

## Comparison Results

In [None]:
print("=" * 80)
print("TASK 8.6: MULTI-TASK vs SINGLE-TASK COMPARISON")
print("=" * 80)

baseline = CONFIG["baseline"]

# Column-oriented table; the three lists are index-aligned, one entry per metric.
comparison = {
    "Metric": [
        "Classification Accuracy",
        "Temporal Mode Accuracy",
        "Spatial Mode Accuracy",
        "Ontological Mode Accuracy",
        "Album Accuracy",
    ],
    "Single-Task": [
        baseline[f"single_task_{suffix}"]
        for suffix in (
            "classification_accuracy",
            "regression_temporal_acc",
            "regression_spatial_acc",
            "regression_ontological_acc",
            "regression_album_acc",
        )
    ],
    "Multi-Task": [
        best_metrics[metric_key]
        for metric_key in (
            "classification_accuracy",
            "temporal_mode_accuracy",
            "spatial_mode_accuracy",
            "ontological_mode_accuracy",
            "album_accuracy",
        )
    ],
}

# Positive delta means the multi-task model scored higher than the baseline.
comparison["Delta"] = [
    multi - single
    for single, multi in zip(comparison["Single-Task"], comparison["Multi-Task"])
]

comparison_df = pd.DataFrame(comparison)
print("\n" + comparison_df.to_string(index=False))

In [None]:
# Visualization: grouped bar chart with one pair of bars per metric.
fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(comparison["Metric"]))
width = 0.35

bars1 = ax.bar(x - width/2, comparison["Single-Task"], width, label="Single-Task", color="steelblue")
bars2 = ax.bar(x + width/2, comparison["Multi-Task"], width, label="Multi-Task", color="coral")

ax.set_ylabel("Accuracy")
ax.set_title("Task 8.6: Single-Task vs Multi-Task Performance Comparison")
ax.set_xticks(x)
ax.set_xticklabels([m.replace(" ", "\n") for m in comparison["Metric"]], fontsize=9)
ax.legend()

# Zoom in on the near-ceiling range, but never cut off a bar: when any score
# is low, drop the axis floor to sit 0.05 below the lowest score (not below 0).
lowest_score = min(min(comparison["Single-Task"]), min(comparison["Multi-Task"]))
y_floor = min(0.9, max(0.0, lowest_score - 0.05))
ax.set_ylim(y_floor, 1.02)
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)

# Add value labels just above every bar in both series.
for bar in [*bars1, *bars2]:
    ax.annotate(f'{bar.get_height():.3f}', xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.savefig("/workspace/output/multitask_comparison.png", dpi=150)
plt.show()

In [None]:
# Summary: average the per-metric scores and turn the mean delta into a verdict.
print("\n" + "=" * 80)
print("CONCLUSION")
print("=" * 80)

avg_single = np.mean(comparison["Single-Task"])
avg_multi = np.mean(comparison["Multi-Task"])
avg_delta = np.mean(comparison["Delta"])

print(f"\nAverage Single-Task Performance: {avg_single:.3f}")
print(f"Average Multi-Task Performance: {avg_multi:.3f}")
print(f"Average Delta: {avg_delta:+.3f}")

# Mean differences within +/- 0.01 are treated as a tie.
if avg_delta > 0.01:
    verdict = "Multi-task IMPROVES performance"
elif avg_delta < -0.01:
    verdict = "Multi-task HURTS performance (task interference)"
else:
    verdict = "Multi-task performs SIMILARLY to single-task"

print(f"\nVerdict: {verdict}")

# Only make the near-ceiling claim when this run's numbers support it.
lowest_score = min(comparison["Single-Task"] + comparison["Multi-Task"])
if lowest_score > 0.96:
    print("\nNote: Both approaches achieve near-ceiling performance (>96%).")
    print("The embeddings already capture the task-relevant information effectively.")
else:
    print(f"\nNote: The lowest accuracy across both approaches is {lowest_score:.3f}, "
          "so not every metric is near ceiling (>96%) in this run.")

In [None]:
# Persist the comparison as JSON and, when enabled, push it to wandb.
import json

# Plain-float views of each comparison column, keyed by metric name, so the
# values serialise cleanly whatever numeric type produced them.
results = {
    "task": "8.6 - Multi-Task vs Single-Task Comparison",
    "single_task": dict(zip(comparison["Metric"], map(float, comparison["Single-Task"]))),
    "multi_task": dict(zip(comparison["Metric"], map(float, comparison["Multi-Task"]))),
    "delta": dict(zip(comparison["Metric"], map(float, comparison["Delta"]))),
    "average_single": float(avg_single),
    "average_multi": float(avg_multi),
    "average_delta": float(avg_delta),
    "verdict": verdict,
    "epochs_trained": len(history),
}

with open("/workspace/output/multitask_comparison_results.json", "w") as f:
    json.dump(results, f, indent=2)

print("\nResults saved to /workspace/output/multitask_comparison_results.json")

if USE_WANDB:
    # Upload the saved chart first, then the numeric results, then close the run.
    wandb.log({"comparison_chart": wandb.Image("/workspace/output/multitask_comparison.png")})
    wandb.log(results)
    wandb.finish()

print("\nTask 8.6 Complete!")