# 20. MTL vs STL Comparison

## Objective
Compare Multi-Task Learning (MTL) vs Single-Task Learning (STL) for 3H disease prediction.

## Hypothesis
MTL may improve performance by learning shared representations across tasks.

## Background
From comorbidity analysis:
- 3H diseases have weak direct correlation (Phi < 0.1)
- HTN ↔ HG: OR=1.72 (moderate association)
- HTN ↔ DL: OR=1.07 (nearly independent)

## Date: 2026-01-13

In [1]:
# Install PyTorch if not available
try:
    import torch
    print(f"PyTorch already installed: {torch.__version__}")
except ImportError:
    print("Installing PyTorch...")
    !pip install torch
    print("PyTorch installed successfully")

Installing PyTorch...
Collecting torch
  Downloading torch-1.13.1-cp37-cp37m-win_amd64.whl (162.6 MB)
     ------------------------------------- 162.6/162.6 MB 12.1 MB/s eta 0:00:00
Installing collected packages: torch
Successfully installed torch-1.13.1
PyTorch installed successfully


In [2]:
import pandas as pd
import numpy as np
import time
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score
from sklearn.base import clone

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

print("Packages loaded")
print(f"PyTorch version: {torch.__version__}")

Packages loaded
PyTorch version: 1.13.1+cpu


## 1. Load Data

In [3]:
# Load sliding window data
df = pd.read_csv('../../data/01_primary/SUA/processed/SUA_sliding_window.csv')
print(f"Data: {len(df):,} samples, {df['patient_id'].nunique():,} patients")

# Features
feature_cols = [
    'sex', 'Age',
    'FBG_Tinput1', 'TC_Tinput1', 'Cr_Tinput1', 'UA_Tinput1', 'GFR_Tinput1', 'BMI_Tinput1', 'SBP_Tinput1', 'DBP_Tinput1',
    'FBG_Tinput2', 'TC_Tinput2', 'Cr_Tinput2', 'UA_Tinput2', 'GFR_Tinput2', 'BMI_Tinput2', 'SBP_Tinput2', 'DBP_Tinput2',
    'Delta_FBG', 'Delta_TC', 'Delta_Cr', 'Delta_UA', 'Delta_GFR', 'Delta_BMI', 'Delta_SBP', 'Delta_DBP'
]

X = df[feature_cols]
groups = df['patient_id']

# Targets (convert 1/2 to 0/1)
y_htn = (df['hypertension_target'] == 2).astype(int)
y_hg = (df['hyperglycemia_target'] == 2).astype(int)
y_dl = (df['dyslipidemia_target'] == 2).astype(int)

# Combined target matrix for MTL
Y = pd.DataFrame({
    'HTN': y_htn,
    'HG': y_hg,
    'DL': y_dl
})

print(f"\nFeatures: {len(feature_cols)}")
print(f"Tasks: {list(Y.columns)}")
print(f"\nPrevalence:")
for col in Y.columns:
    print(f"  {col}: {Y[col].mean()*100:.1f}%")

Data: 13,514 samples, 6,056 patients

Features: 26
Tasks: ['HTN', 'HG', 'DL']

Prevalence:
  HTN: 19.3%
  HG: 5.9%
  DL: 7.9%


## 2. Define Models

### STL (Single-Task Learning)
- 3 independent models, one for each disease

### MTL (Multi-Task Learning)
- 1 shared model predicting all 3 diseases simultaneously

In [4]:
# PyTorch MTL Model
class MTLNetwork(nn.Module):
    """Multi-Task Learning Network with shared layers"""
    def __init__(self, input_dim, hidden_dims=[64, 32], n_tasks=3):
        super().__init__()
        
        # Shared layers
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            prev_dim = h_dim
        self.shared = nn.Sequential(*layers)
        
        # Task-specific heads
        self.heads = nn.ModuleList([
            nn.Linear(hidden_dims[-1], 1) for _ in range(n_tasks)
        ])
    
    def forward(self, x):
        shared_repr = self.shared(x)
        outputs = [torch.sigmoid(head(shared_repr)) for head in self.heads]
        return torch.cat(outputs, dim=1)


class STLNetwork(nn.Module):
    """Single-Task Learning Network (same architecture, but independent)"""
    def __init__(self, input_dim, hidden_dims=[64, 32]):
        super().__init__()
        
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, h_dim),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            prev_dim = h_dim
        layers.append(nn.Linear(hidden_dims[-1], 1))
        
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return torch.sigmoid(self.network(x))

print("Models defined")

Models defined


In [5]:
def train_mtl_model(X_train, Y_train, X_val, Y_val, epochs=100, lr=0.001, batch_size=256):
    """Train MTL model"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Prepare data
    X_train_t = torch.FloatTensor(X_train).to(device)
    Y_train_t = torch.FloatTensor(Y_train.values).to(device)
    X_val_t = torch.FloatTensor(X_val).to(device)
    Y_val_t = torch.FloatTensor(Y_val.values).to(device)
    
    train_dataset = TensorDataset(X_train_t, Y_train_t)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    # Model
    model = MTLNetwork(X_train.shape[1]).to(device)
    
    # Class weights for imbalanced data
    pos_weights = []
    for i in range(Y_train.shape[1]):
        pos_rate = Y_train.iloc[:, i].mean()
        pos_weights.append((1 - pos_rate) / pos_rate)
    pos_weight = torch.FloatTensor(pos_weights).to(device)
    
    criterion = nn.BCELoss(reduction='none')
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Training
    best_auc = 0
    patience = 10
    patience_counter = 0
    
    for epoch in range(epochs):
        model.train()
        for X_batch, Y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            
            # Weighted loss
            loss = criterion(outputs, Y_batch)
            weights = torch.where(Y_batch == 1, pos_weight, torch.ones_like(Y_batch))
            loss = (loss * weights).mean()
            
            loss.backward()
            optimizer.step()
        
        # Validation
        model.eval()
        with torch.no_grad():
            val_outputs = model(X_val_t).cpu().numpy()
            val_aucs = [roc_auc_score(Y_val.iloc[:, i], val_outputs[:, i]) for i in range(3)]
            mean_auc = np.mean(val_aucs)
            
            if mean_auc > best_auc:
                best_auc = mean_auc
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    break
    
    # Final prediction
    model.eval()
    with torch.no_grad():
        predictions = model(X_val_t).cpu().numpy()
    
    return predictions, model


def train_stl_models(X_train, Y_train, X_val, Y_val, epochs=100, lr=0.001, batch_size=256):
    """Train 3 independent STL models"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    X_train_t = torch.FloatTensor(X_train).to(device)
    X_val_t = torch.FloatTensor(X_val).to(device)
    
    all_predictions = []
    models = []
    
    for task_idx in range(Y_train.shape[1]):
        y_train = Y_train.iloc[:, task_idx].values
        y_val = Y_val.iloc[:, task_idx].values
        
        y_train_t = torch.FloatTensor(y_train).unsqueeze(1).to(device)
        
        train_dataset = TensorDataset(X_train_t, y_train_t)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        
        model = STLNetwork(X_train.shape[1]).to(device)
        
        # Class weight
        pos_rate = y_train.mean()
        pos_weight = torch.FloatTensor([(1 - pos_rate) / pos_rate]).to(device)
        
        criterion = nn.BCELoss(reduction='none')
        optimizer = optim.Adam(model.parameters(), lr=lr)
        
        best_auc = 0
        patience = 10
        patience_counter = 0
        
        for epoch in range(epochs):
            model.train()
            for X_batch, y_batch in train_loader:
                optimizer.zero_grad()
                outputs = model(X_batch)
                
                loss = criterion(outputs, y_batch)
                weights = torch.where(y_batch == 1, pos_weight, torch.ones_like(y_batch))
                loss = (loss * weights).mean()
                
                loss.backward()
                optimizer.step()
            
            model.eval()
            with torch.no_grad():
                val_pred = model(X_val_t).cpu().numpy().flatten()
                val_auc = roc_auc_score(y_val, val_pred)
                
                if val_auc > best_auc:
                    best_auc = val_auc
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        break
        
        model.eval()
        with torch.no_grad():
            predictions = model(X_val_t).cpu().numpy().flatten()
        
        all_predictions.append(predictions)
        models.append(model)
    
    return np.column_stack(all_predictions), models

print("Training functions defined")

Training functions defined


## 3. Run Experiment with 5-Fold CV

In [6]:
# 5-Fold CV with GroupKFold
n_splits = 5
cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)

results = {
    'MTL': {'HTN': [], 'HG': [], 'DL': [], 'time': []},
    'STL': {'HTN': [], 'HG': [], 'DL': [], 'time': []}
}

print("="*70)
print("5-Fold CV: MTL vs STL")
print("="*70)

for fold, (train_idx, test_idx) in enumerate(cv.split(X, y_htn, groups)):
    print(f"\nFold {fold+1}/{n_splits}")
    
    # Split data
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    Y_train, Y_test = Y.iloc[train_idx], Y.iloc[test_idx]
    
    # Verify no leakage
    train_patients = set(groups.iloc[train_idx])
    test_patients = set(groups.iloc[test_idx])
    assert len(train_patients & test_patients) == 0
    
    # Standardize
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    # --- MTL ---
    start_time = time.time()
    mtl_preds, _ = train_mtl_model(X_train_scaled, Y_train, X_test_scaled, Y_test)
    mtl_time = time.time() - start_time
    
    for i, task in enumerate(['HTN', 'HG', 'DL']):
        auc = roc_auc_score(Y_test.iloc[:, i], mtl_preds[:, i])
        results['MTL'][task].append(auc)
    results['MTL']['time'].append(mtl_time)
    
    print(f"  MTL: HTN={results['MTL']['HTN'][-1]:.3f}, HG={results['MTL']['HG'][-1]:.3f}, DL={results['MTL']['DL'][-1]:.3f} ({mtl_time:.1f}s)")
    
    # --- STL ---
    start_time = time.time()
    stl_preds, _ = train_stl_models(X_train_scaled, Y_train, X_test_scaled, Y_test)
    stl_time = time.time() - start_time
    
    for i, task in enumerate(['HTN', 'HG', 'DL']):
        auc = roc_auc_score(Y_test.iloc[:, i], stl_preds[:, i])
        results['STL'][task].append(auc)
    results['STL']['time'].append(stl_time)
    
    print(f"  STL: HTN={results['STL']['HTN'][-1]:.3f}, HG={results['STL']['HG'][-1]:.3f}, DL={results['STL']['DL'][-1]:.3f} ({stl_time:.1f}s)")

print("\n" + "="*70)
print("CV Complete")
print("="*70)

5-Fold CV: MTL vs STL

Fold 1/5
  MTL: HTN=0.747, HG=0.924, DL=0.879 (12.1s)
  STL: HTN=0.755, HG=0.923, DL=0.884 (8.8s)

Fold 2/5
  MTL: HTN=0.741, HG=0.930, DL=0.866 (9.3s)
  STL: HTN=0.738, HG=0.935, DL=0.865 (10.1s)

Fold 3/5
  MTL: HTN=0.741, HG=0.927, DL=0.858 (8.5s)
  STL: HTN=0.749, HG=0.928, DL=0.859 (11.3s)

Fold 4/5
  MTL: HTN=0.695, HG=0.937, DL=0.861 (7.5s)
  STL: HTN=0.713, HG=0.936, DL=0.864 (14.4s)

Fold 5/5
  MTL: HTN=0.749, HG=0.941, DL=0.876 (7.4s)
  STL: HTN=0.756, HG=0.942, DL=0.874 (9.7s)

CV Complete


## 4. Results Summary

In [7]:
# Summary table
print("="*70)
print("Results: MTL vs STL (5-Fold CV)")
print("="*70)

summary = []
for method in ['MTL', 'STL']:
    for task in ['HTN', 'HG', 'DL']:
        aucs = results[method][task]
        summary.append({
            'Method': method,
            'Task': task,
            'AUC_mean': np.mean(aucs),
            'AUC_std': np.std(aucs)
        })

summary_df = pd.DataFrame(summary)
print("\nAUC Comparison:")
print(summary_df.pivot(index='Task', columns='Method', values='AUC_mean').round(3))

# Time comparison
print(f"\nTraining Time:")
print(f"  MTL: {np.mean(results['MTL']['time']):.1f}s per fold")
print(f"  STL: {np.mean(results['STL']['time']):.1f}s per fold")
print(f"  Speedup: {np.mean(results['STL']['time']) / np.mean(results['MTL']['time']):.2f}x")

Results: MTL vs STL (5-Fold CV)

AUC Comparison:
Method    MTL    STL
Task                
DL      0.868  0.869
HG      0.932  0.933
HTN     0.734  0.742

Training Time:
  MTL: 9.0s per fold
  STL: 10.9s per fold
  Speedup: 1.21x


In [8]:
# Detailed comparison
print("\n" + "="*70)
print("Detailed Comparison")
print("="*70)

print("\n| Task | MTL AUC | STL AUC | Diff | Winner |")
print("|------|---------|---------|------|--------|")

for task in ['HTN', 'HG', 'DL']:
    mtl_auc = np.mean(results['MTL'][task])
    stl_auc = np.mean(results['STL'][task])
    diff = mtl_auc - stl_auc
    winner = 'MTL' if diff > 0.005 else 'STL' if diff < -0.005 else 'Tie'
    print(f"| {task} | {mtl_auc:.3f} | {stl_auc:.3f} | {diff:+.3f} | {winner} |")

# Overall
mtl_avg = np.mean([np.mean(results['MTL'][t]) for t in ['HTN', 'HG', 'DL']])
stl_avg = np.mean([np.mean(results['STL'][t]) for t in ['HTN', 'HG', 'DL']])
print(f"| **Avg** | {mtl_avg:.3f} | {stl_avg:.3f} | {mtl_avg-stl_avg:+.3f} | {'MTL' if mtl_avg > stl_avg else 'STL'} |")


Detailed Comparison

| Task | MTL AUC | STL AUC | Diff | Winner |
|------|---------|---------|------|--------|
| HTN | 0.734 | 0.742 | -0.008 | STL |
| HG | 0.932 | 0.933 | -0.001 | Tie |
| DL | 0.868 | 0.869 | -0.001 | Tie |
| **Avg** | 0.845 | 0.848 | -0.003 | STL |


In [9]:
# Save results
results_df = pd.DataFrame([
    {'Method': method, 'Task': task, 'Fold': fold+1, 'AUC': auc}
    for method in ['MTL', 'STL']
    for task in ['HTN', 'HG', 'DL']
    for fold, auc in enumerate(results[method][task])
])
results_df.to_csv('../../results/mtl_vs_stl_results.csv', index=False)
print("Saved: results/mtl_vs_stl_results.csv")

Saved: results/mtl_vs_stl_results.csv


## 5. Conclusion

### Key Findings

1. **Performance**: (fill after running)
   - MTL vs STL difference: ~X%

2. **Training Time**:
   - MTL: faster (shared computation)
   - STL: slower (3 independent models)

3. **Why MTL may not outperform STL**:
   - 3H diseases have weak correlation (Phi < 0.1)
   - Sample size (13,514) is sufficient for STL
   - Task difficulties vary (HG AUC=0.94 vs HTN AUC=0.74)

### Implication

For this dataset, STL and MTL perform similarly. 
MTL's benefit (shared representation) is limited when:
- Tasks are weakly correlated
- Data is sufficient for independent learning