# Strategy 1 (REFINED): Simplified Residual + Light Attention

**Refinements Based on Execution Analysis:**

## Issues Found in Original:
- **Severe overfitting**: Best epoch 17, val log2-MSE: 1.249
- **3x augmentation caused overfitting** instead of helping
- **Too complex**: 7M parameters, 2 residual blocks
- **Group imbalance**: high-(k, m) groups had a log2-MSE of 2.6+ vs ~0.3 for the low groups

## NEW Architecture (Simplified):
- **ONE residual block** (not 2)
- **2 attention heads** (not 4) with lighter key_dim
- **Reduced width**: 768 max (not 1024)
- **Higher dropout**: 0.4→0.3→0.2
- **NO augmentation** - use original data only
- **Group-weighted loss** (planned — NOT yet implemented; Step 6 trains with plain log2-MSE): penalize high-error groups more

## Strategy:
- Split train/val FIRST (prevent leakage)
- NO data augmentation (reduces overfitting)
- Stronger regularization (dropout + L2)
- Simpler architecture (~3M parameters vs 7M)

In [None]:
import numpy as np
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (
    Input, Dense, Embedding, Flatten, Concatenate,
    Dropout, BatchNormalization, LayerNormalization, Add, 
    Reshape, MultiHeadAttention, Lambda
)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras import regularizers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: pin the NumPy and TensorFlow global RNG seeds.
_RNG_SEED = 42
np.random.seed(_RNG_SEED)
tf.random.set_seed(_RNG_SEED)

_RULE = "=" * 70
print(_RULE)
print("STRATEGY 1 REFINED: SIMPLIFIED RESIDUAL + LIGHT ATTENTION")
print(_RULE)
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

## Step 1: Load Data

In [None]:
print("\nSTEP 1: Loading Data")
print("-" * 70)


def _read_pickle(path):
    """Open *path* in binary mode and return the unpickled object.

    NOTE: pickle can execute arbitrary code while loading; only use this on
    trusted, locally produced data files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


inputs_raw = _read_pickle('data/combined_final_n_k_m_P.pkl')
outputs_raw = _read_pickle('data/combined_final_mHeights.pkl')

print(f"Raw samples: {len(inputs_raw)}")
_first = inputs_raw[0]  # each sample is indexed as (n, k, m, P)
print(f"Sample: n={_first[0]}, k={_first[1]}, m={_first[2]}, P={_first[3].shape}")
print(f"Target range: [{np.min(outputs_raw):.2f}, {np.max(outputs_raw):.2f}]")

## Step 2: Split Data (NO AUGMENTATION)

In [None]:
print("\nSTEP 2: Split Data (NO AUGMENTATION - prevent overfitting)")
print("-" * 70)


def _group_label(sample):
    """Encode a raw sample's (k, m) pair as one stratification id, k*10 + m.

    NOTE(review): the id is unique only while m < 10.
    """
    return sample[1] * 10 + sample[2]


# One label per sample so the split keeps each (k, m) group's share roughly
# equal in train and validation.
stratify_labels = [_group_label(s) for s in inputs_raw]

inputs_train, inputs_val, outputs_train, outputs_val = train_test_split(
    inputs_raw,
    outputs_raw,
    stratify=stratify_labels,
    test_size=0.15,
    random_state=42,
)

print(f"Training samples: {len(inputs_train)}")
print(f"Validation samples: {len(inputs_val)}")
print("\n✅ NO DATA AUGMENTATION - Using original data only to prevent overfitting")

## Step 3: Prepare Data for Training

In [None]:
print("\nSTEP 3: Prepare Data for Training")
print("-"*70)


def prepare_data(inputs, outputs, pad_to=None):
    """Convert raw (n, k, m, P) samples and their targets into model-ready arrays.

    Args:
        inputs: sequence of samples indexed as sample[0]=n, sample[1]=k,
            sample[2]=m, sample[3]=P (array-like; it is flattened).
        outputs: one target (m-height) per sample.
        pad_to: optional width for the flattened P vectors. With the default
            None each call pads to the widest P in *inputs* (original
            behaviour). Pass an explicit width so that separately prepared
            splits end up with identical column counts.

    Returns:
        (n, k, m, P, y):
            n -- float32 array of shape (N, 1)
            k, m -- int32 arrays of shape (N, 1)
            P -- float32 array of shape (N, width): each row is the flattened
                 P followed by zeros
            y -- float32 targets floored at 1.0 (so log2(y) >= 0)

    Raises:
        ValueError: if pad_to is smaller than the largest flattened P.
    """
    n_vals = np.array([s[0] for s in inputs], dtype=np.float32).reshape(-1, 1)
    k_vals = np.array([s[1] for s in inputs], dtype=np.int32).reshape(-1, 1)
    m_vals = np.array([s[2] for s in inputs], dtype=np.int32).reshape(-1, 1)
    P_flat = [np.asarray(s[3]).flatten() for s in inputs]

    # default=0 lets an empty input produce an (0, width) array instead of
    # crashing inside max().
    longest = max((p.size for p in P_flat), default=0)
    width = longest if pad_to is None else int(pad_to)
    if longest > width:
        raise ValueError(
            f"pad_to={width} is smaller than the largest flattened P ({longest})"
        )

    # NOTE(review): padding the *flattened* P at the end means a given column
    # holds different matrix positions for differently shaped P -- confirm
    # this is intended, or pad in 2-D before flattening.
    P_arr = np.zeros((len(P_flat), width), dtype=np.float32)
    for row, p in zip(P_arr, P_flat):
        row[:p.size] = p  # row is a view into P_arr; remaining cells stay 0

    outputs_arr = np.maximum(np.array(outputs, dtype=np.float32), 1.0)

    return n_vals, k_vals, m_vals, P_arr, outputs_arr

n_train, k_train, m_train, P_train, y_train = prepare_data(inputs_train, outputs_train)
n_val, k_val, m_val, P_val, y_val = prepare_data(inputs_val, outputs_val)

# prepare_data pads each split to the widest P *in that split*. If the widest
# matrix lands in only one split, the widths differ and scaler.transform below
# fails with a feature-count mismatch. Zero-pad both splits to a shared width;
# this uses only array shapes (never targets), so it adds no label leakage.
_p_width = max(P_train.shape[1], P_val.shape[1])
P_train = np.pad(P_train, ((0, 0), (0, _p_width - P_train.shape[1])))
P_val = np.pad(P_val, ((0, 0), (0, _p_width - P_val.shape[1])))

# Normalize P matrices (fit on training only, then apply to both splits)
scaler = StandardScaler()
P_train = scaler.fit_transform(P_train)
P_val = scaler.transform(P_val)

print(f"Training: n={n_train.shape}, k={k_train.shape}, m={m_train.shape}, P={P_train.shape}, y={y_train.shape}")
print(f"Validation: n={n_val.shape}, k={k_val.shape}, m={m_val.shape}, P={P_val.shape}, y={y_val.shape}")
print(f"P train: mean={P_train.mean():.4f}, std={P_train.std():.4f}")
print(f"P val: mean={P_val.mean():.4f}, std={P_val.std():.4f}")

## Step 4: Data Augmentation (skipped — removed in this refinement)

In [None]:
# Augmentation was removed in this refinement; keep the step number for continuity.
print("\n" + "STEP 4: Skipped (removed augmentation step)")

## Step 5: Build Simplified Model

In [None]:
print("\nSTEP 5: Building Simplified Residual + Attention Model")
print("-" * 70)


def build_model(p_shape, k_vocab_size=7, m_vocab_size=6):
    """Build the simplified residual + light-attention m-height regressor.

    Compared with the original model: ONE residual block (not 2), 2 attention
    heads (not 4), max width 768 (not 1024), dropout 0.4/0.3/0.2 and an
    L2(1e-4) penalty on every kernel and embedding table.

    Args:
        p_shape: length of each flattened P vector fed to the 'P_flat' input.
        k_vocab_size: rows in the k embedding table (largest k id + 1).
        m_vocab_size: rows in the m embedding table (largest m id + 1).

    Returns:
        An uncompiled keras Model named 'strategy1_simplified' with inputs
        [n, k, m, P_flat] and one output, 2 ** softplus(z), which is >= 1.
    """
    reg = regularizers.l2(1e-4)  # a single L2 penalty object shared by all layers

    def dense_gelu(tensor, units):
        """GELU Dense layer with the shared L2 kernel penalty."""
        return Dense(units, activation='gelu', kernel_regularizer=reg)(tensor)

    # Inputs
    n_in = Input(shape=(1,), name='n')
    k_in = Input(shape=(1,), name='k', dtype=tf.int32)
    m_in = Input(shape=(1,), name='m', dtype=tf.int32)
    p_in = Input(shape=(p_shape,), name='P_flat')

    # 32-d embeddings for the integer k and m ids
    k_vec = Flatten()(Embedding(k_vocab_size, 32, embeddings_regularizer=reg)(k_in))
    m_vec = Flatten()(Embedding(m_vocab_size, 32, embeddings_regularizer=reg)(m_in))

    # P branch: 256 -> 384 with LayerNorm + BatchNorm on the first stage
    h = dense_gelu(p_in, 256)
    h = LayerNormalization()(h)
    h = BatchNormalization()(h)
    h = Dropout(0.4)(h)

    h = dense_gelu(h, 384)
    h = BatchNormalization()(h)
    h = Dropout(0.3)(h)

    # The features are reshaped into a sequence of length 1, so each query
    # attends only to itself: softmax over a single key is 1, and this layer
    # reduces to a learned value/output projection (plus attention dropout).
    seq = Reshape((1, 384))(h)
    seq = MultiHeadAttention(num_heads=2, key_dim=48, dropout=0.2)(seq, seq)
    p_vec = Flatten()(seq)

    merged = Concatenate()([n_in, k_vec, m_vec, p_vec])

    # Single residual block at width 768
    trunk = dense_gelu(merged, 768)
    trunk = BatchNormalization()(trunk)
    trunk = Dropout(0.4)(trunk)

    branch = dense_gelu(trunk, 768)
    branch = BatchNormalization()(branch)
    branch = Dropout(0.3)(branch)
    branch = dense_gelu(branch, 768)
    branch = BatchNormalization()(branch)
    h = Add()([branch, trunk])

    # Progressive taper 768 -> 384 -> 192 -> 96
    for width, rate in ((384, 0.3), (192, 0.2), (96, 0.2)):
        h = dense_gelu(h, width)
        h = BatchNormalization()(h)
        h = Dropout(rate)(h)

    # Head: predict log2(m_height), force it non-negative, then exponentiate
    raw_log2 = Dense(1, activation='linear', kernel_regularizer=reg)(h)
    pos_log2 = Lambda(lambda z: tf.nn.softplus(z))(raw_log2)
    height = Lambda(lambda z: tf.pow(2.0, z))(pos_log2)

    return Model(
        inputs=[n_in, k_in, m_in, p_in],
        outputs=height,
        name='strategy1_simplified',
    )

# Embedding tables are sized from the training ids only. A validation id
# >= table size would raise on CPU but (per tf.gather) silently read zeros on
# GPU, so fail loudly here instead of producing quietly wrong predictions.
k_vocab_size = int(k_train.max()) + 1
m_vocab_size = int(m_train.max()) + 1
if k_val.max() >= k_vocab_size or m_val.max() >= m_vocab_size:
    raise ValueError(
        "Validation contains k/m ids not seen in training "
        f"(k_val max={k_val.max()}, m_val max={m_val.max()}, "
        f"vocab sizes k={k_vocab_size}, m={m_vocab_size})"
    )

model = build_model(P_train.shape[1], k_vocab_size=k_vocab_size, m_vocab_size=m_vocab_size)
print(f"Parameters: {model.count_params():,} (vs 7M in original)")
model.summary()

## Step 6: Train (plain log2-MSE loss — group weighting not yet implemented)

In [None]:
print("\nSTEP 6: Compile and Train")
print("-"*70)


def log2_mse_loss(y_true, y_pred):
    """Mean squared error between log2(y_true) and log2(y_pred), as a scalar tensor.

    Both tensors are cast to float32 and flattened to 1-D first. Targets come
    in as (batch,) while the model emits (batch, 1), and subtracting those
    shapes directly would broadcast to a (batch, batch) matrix whenever the
    caller has not already aligned their ranks (e.g. a direct call).
    Values are clamped at 1e-7 before the log so zeros never give -inf.
    """
    epsilon = 1e-7
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    y_true = tf.maximum(y_true, epsilon)
    y_pred = tf.maximum(y_pred, epsilon)
    log2_true = tf.math.log(y_true) / tf.math.log(2.0)
    log2_pred = tf.math.log(y_pred) / tf.math.log(2.0)
    return tf.reduce_mean(tf.square(log2_true - log2_pred))

# Reduced learning rate and stronger weight decay. Note that AdamW's decoupled
# weight decay is applied on top of the per-layer L2 kernel penalties.
optimizer = AdamW(learning_rate=5e-4, weight_decay=2e-4, clipnorm=1.0)
# The Keras 'loss' / 'val_loss' also include the layers' L2 regularisation
# penalties; the metric below is the pure log2-MSE of the predictions.
model.compile(optimizer=optimizer, loss=log2_mse_loss, metrics=[log2_mse_loss])

# Stop, decay the LR and checkpoint on the pure validation log2-MSE (the
# number reported in Step 7 and compared to the original 1.249), not on
# val_loss. val_loss also contains the shrinking L2 penalty, which would bias
# model selection toward smaller weights rather than a better fit.
_monitor = 'val_log2_mse_loss'
callbacks = [
    EarlyStopping(monitor=_monitor, mode='min', patience=20, restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor=_monitor, mode='min', factor=0.5, patience=10, min_lr=1e-6, verbose=1),
    ModelCheckpoint('strategy1_refined_best.h5', monitor=_monitor, mode='min', save_best_only=True, verbose=1),
]

print("Training with NO augmentation, stronger regularization, simpler architecture...")
history = model.fit(
    [n_train, k_train, m_train, P_train], y_train,
    validation_data=([n_val, k_val, m_val, P_val], y_val),
    epochs=150, batch_size=256, callbacks=callbacks, verbose=1
)

print("\nTraining completed!")

## Step 7: Evaluate

In [None]:
print("\nSTEP 7: Evaluation")
print("-" * 70)

# Reload the best checkpoint written by ModelCheckpoint during training, then
# predict on the validation split and flatten to a 1-D vector.
model.load_weights('strategy1_refined_best.h5')
_val_inputs = [n_val, k_val, m_val, P_val]
_val_raw = model.predict(_val_inputs, verbose=0)
y_pred_val = _val_raw.flatten()

def compute_log2_mse(y_true, y_pred):
    """Return mean((log2(y_true) - log2(y_pred)) ** 2).

    Both inputs are converted with np.asarray and flattened, so an (N,) target
    and an (N, 1) prediction are compared element-wise instead of
    broadcasting to an (N, N) matrix. Values are clamped at 1e-7 before the
    log so zeros never give -inf.

    Raises:
        ValueError: if the inputs contain a different number of elements.
    """
    epsilon = 1e-7
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"size mismatch: y_true has {y_true.size} values, y_pred has {y_pred.size}"
        )
    y_true = np.maximum(y_true, epsilon)
    y_pred = np.maximum(y_pred, epsilon)
    return np.mean((np.log2(y_true) - np.log2(y_pred)) ** 2)

val_log2_mse = compute_log2_mse(y_val, y_pred_val)

print(f"\nValidation log2-MSE: {val_log2_mse:.6f}")
print(f"Prediction range: [{y_pred_val.min():.2f}, {y_pred_val.max():.2f}]")
print("\nCompare to original: 1.249 (67% worse than target)")

# Bucket validation targets/predictions by their (k, m) group.
group_metrics = defaultdict(lambda: {'true': [], 'pred': []})
for k, m, y_t, y_p in zip(k_val[:, 0], m_val[:, 0], y_val, y_pred_val):
    bucket = group_metrics[(int(k), int(m))]
    bucket['true'].append(y_t)
    bucket['pred'].append(y_p)

print("\nPer-Group Performance:")
print(f"{'Group':<12} {'n_val':<8} {'log2-MSE':<12} {'vs Original':<15}")
print("-"*55)

# Original per-group results for comparison
orig_results = {
    (4,2): 0.292, (4,3): 0.349, (4,4): 0.839, (4,5): 2.630,
    (5,2): 0.309, (5,3): 0.893, (5,4): 2.622,
    (6,2): 0.703, (6,3): 2.606
}

for (k, m), data in sorted(group_metrics.items()):
    group_mse = compute_log2_mse(np.array(data['true']), np.array(data['pred']))
    orig_mse = orig_results.get((k, m))
    # Positive delta = worse than the original model. Groups without an
    # original result get "n/a" rather than a misleading delta against 0.
    delta = "n/a" if orig_mse is None else f"{group_mse - orig_mse:+.3f}"
    group_name = f"k={k}, m={m}"
    print(f"{group_name:<12} {len(data['true']):<8d} {group_mse:<12.6f} {delta}")

print("\n" + "="*70)
print("STRATEGY 1 REFINED COMPLETE")
print("="*70)