# Strategy 1: Advanced Residual + Selective Attention

**Based on Professor's Top Submissions (Avg Score: 77.8)**

## Architecture Features:
- Residual connections with skip paths
- Selective multi-head attention (1 attention block)
- LayerNorm + BatchNorm combinations
- Log-space prediction (log2)
- Progressive dropout (0.3 → 0.2 → 0.1)
- AdamW optimizer with weight decay

## Data Strategy:
- Heavy augmentation: 3x multiplier (108K → 324K samples)
- Techniques: Gaussian noise, perturbations, SMOTE-like interpolation
- Applied uniformly to every sample (two extra copies each), so the relative size of each (k,m) group is preserved

In [None]:
import numpy as np
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (
    Input, Dense, Embedding, Flatten, Concatenate,
    Dropout, BatchNormalization, LayerNormalization, Add, 
    Reshape, MultiHeadAttention, Lambda
)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from tensorflow.keras.optimizers import AdamW
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Fix the global NumPy and TensorFlow seeds with one shared value
seed_value = 42
np.random.seed(seed_value)
tf.random.set_seed(seed_value)

# Title banner followed by a short environment report
rule = "="*70
print(rule)
print("STRATEGY 1: ADVANCED RESIDUAL + SELECTIVE ATTENTION")
print(rule)
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

## Step 1: Load Data

In [None]:
print("\nSTEP 1: Loading Data", "-"*70, sep="\n")


def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE(review): unpickling can execute arbitrary code, so these files
    # must come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Each input entry is indexed below as (n, k, m, P); outputs hold the targets
inputs_raw = _load_pickle('data/combined_final_n_k_m_P.pkl')
outputs_raw = _load_pickle('data/combined_final_mHeights.pkl')

# Quick sanity report on what was loaded
print(f"Raw samples: {len(inputs_raw)}")
print(f"Sample structure: n={inputs_raw[0][0]}, k={inputs_raw[0][1]}, m={inputs_raw[0][2]}, P shape={inputs_raw[0][3].shape}")
print(f"Target range: [{np.min(outputs_raw):.2f}, {np.max(outputs_raw):.2f}]")

## Step 2: Heavy Data Augmentation (3x)

In [None]:
# Section header for the augmentation step
print("\nSTEP 2: Heavy Data Augmentation (3x multiplier)", "-"*70, sep="\n")

def augment_sample_gaussian(n, k, m, P, target, noise_level=0.03):
    """Return a copy of the sample whose P matrix has additive Gaussian noise.

    The noise is zero-mean with standard deviation ``noise_level * std(P)``,
    so its magnitude scales with the spread of the matrix entries.

    Args:
        n, k, m: Sample parameters, passed through unchanged.
        P: Array-like matrix to perturb; it is not modified.
        target: Label of the sample, returned unchanged.
        noise_level: Noise standard deviation as a fraction of ``std(P)``.

    Returns:
        ``([n, k, m, P_aug], target)`` where ``P_aug`` is a float32 array
        with the same shape as ``P``.
    """
    # astype() already returns a new array, so no separate copy() is needed.
    P_aug = np.asarray(P).astype(np.float32)
    noise = np.random.normal(0, noise_level, P_aug.shape)
    # The noise is float64, so the sum is promoted to float64; cast the
    # result back so the augmented matrix really is float32 as intended.
    P_aug = np.asarray(P_aug + noise * np.std(P_aug), dtype=np.float32)
    return [n, k, m, P_aug], target

def augment_sample_perturbation(n, k, m, P, target, strength=0.02):
    """Return a copy of the sample with each P entry scaled by a random factor.

    Every entry is multiplied by ``1 + u`` with ``u`` drawn independently
    and uniformly from ``[-strength, strength)``.

    Args:
        n, k, m: Sample parameters, passed through unchanged.
        P: Array-like matrix to perturb; it is not modified.
        target: Label of the sample, returned unchanged.
        strength: Half-width of the relative perturbation.

    Returns:
        ``([n, k, m, P_aug], target)`` where ``P_aug`` is a float32 array
        with the same shape as ``P``.
    """
    # astype() already returns a new array, so no separate copy() is needed.
    P_aug = np.asarray(P).astype(np.float32)
    perturbation = np.random.uniform(-strength, strength, P_aug.shape)
    # The factors are float64, so the product is promoted to float64; cast
    # the result back so the augmented matrix really is float32 as intended.
    P_aug = np.asarray(P_aug * (1 + perturbation), dtype=np.float32)
    return [n, k, m, P_aug], target

def augment_sample_interpolation(sample1, sample2, target1, target2):
    """Blend two samples into a new one (SMOTE-like interpolation).

    A weight ``alpha`` is drawn uniformly from ``[0.3, 0.7)``; the new matrix
    and target are ``alpha * first + (1 - alpha) * second``. The result
    carries ``n``, ``k`` and ``m`` from ``sample1``.

    Args:
        sample1, sample2: Sequences indexed as ``(n, k, m, P)``.
        target1, target2: Labels belonging to the two samples.

    Returns:
        ``([n, k, m, P_new], target_new)`` with ``P_new`` as float32, or
        ``(None, None)`` when the samples differ in ``k`` or ``m`` or when
        their P matrices do not have the same shape.
    """
    # Drawn before any check, as before, so the global RNG stream advances
    # identically whether or not the interpolation goes ahead.
    alpha = np.random.uniform(0.3, 0.7)
    n, k, m = sample1[0], sample1[1], sample1[2]

    # Only interpolate within the same (k, m) group
    if k != sample2[1] or m != sample2[2]:
        return None, None

    P1 = np.asarray(sample1[3], dtype=np.float32)
    P2 = np.asarray(sample2[3], dtype=np.float32)
    # Matrices of different shape cannot be blended element-wise: NumPy would
    # either raise or silently broadcast. Report "no result" so the caller
    # can fall back to another augmentation.
    if P1.shape != P2.shape:
        return None, None

    P_new = alpha * P1 + (1 - alpha) * P2
    target_new = alpha * target1 + (1 - alpha) * target2
    return [n, k, m, P_new], target_new

# Group sample indices by (k, m). Also record where each sample sits inside
# its own group so it can be skipped in O(1) when a partner is drawn below.
groups = defaultdict(list)
position_in_group = []
for i, sample in enumerate(inputs_raw):
    members = groups[(sample[1], sample[2])]
    position_in_group.append(len(members))
    members.append(i)

inputs_augmented = []
outputs_augmented = []

# Every raw sample contributes exactly three rows, in this order: the
# original, a Gaussian-noise copy, and one perturbed-or-interpolated copy.
for i, (sample, target) in enumerate(zip(inputs_raw, outputs_raw)):
    n, k, m, P = sample

    # Keep original
    inputs_augmented.append(sample)
    outputs_augmented.append(target)

    # Augmentation 1: Gaussian noise
    aug1, tgt1 = augment_sample_gaussian(n, k, m, P, target)
    inputs_augmented.append(aug1)
    outputs_augmented.append(tgt1)

    # Augmentation 2: interpolation on roughly half of the samples,
    # perturbation on the rest and whenever interpolation is not possible.
    aug2, tgt2 = None, None
    if np.random.rand() >= 0.5:
        group_indices = groups[(k, m)]
        if len(group_indices) > 1:
            # Uniform draw over the *other* members of the group: pick one of
            # len-1 slots and step over this sample's own slot. This replaces
            # rebuilding a filtered copy of the whole group for every sample.
            # NOTE: the draw uses randint rather than choice, so a seeded run
            # may not reproduce a previously generated augmented set exactly.
            pick = int(np.random.randint(len(group_indices) - 1))
            if pick >= position_in_group[i]:
                pick += 1
            j = group_indices[pick]
            aug2, tgt2 = augment_sample_interpolation(
                sample, inputs_raw[j], target, outputs_raw[j]
            )
    if aug2 is None:
        # Single fallback for: coin flip chose perturbation, the group has no
        # other member, or interpolation declined the pair.
        aug2, tgt2 = augment_sample_perturbation(n, k, m, P, target)
    inputs_augmented.append(aug2)
    outputs_augmented.append(tgt2)

print(f"Original samples: {len(inputs_raw)}")
print(f"Augmented samples: {len(inputs_augmented)}")
print(f"Augmentation ratio: {len(inputs_augmented) / len(inputs_raw):.2f}x")

## Step 3: Prepare Data

In [None]:
print("\nSTEP 3: Preparing Data for Training", "-"*70, sep="\n")

# Pull the scalar parameters out of every augmented sample as column vectors
n_values = np.array([s[0] for s in inputs_augmented], dtype=np.float32).reshape(-1, 1)
k_values = np.array([s[1] for s in inputs_augmented], dtype=np.int32).reshape(-1, 1)
m_values = np.array([s[2] for s in inputs_augmented], dtype=np.int32).reshape(-1, 1)

# Each P matrix becomes a 1-D vector; lengths may differ between samples
P_matrices_flattened = [s[3].flatten() for s in inputs_augmented]
max_p_size = max(map(len, P_matrices_flattened))


def _zero_pad(vec, width):
    """Return `vec` right-padded with zeros to `width` (unchanged if already that long)."""
    if len(vec) >= width:
        return vec
    padded_vec = np.zeros(width, dtype=np.float32)
    padded_vec[:len(vec)] = vec
    return padded_vec


# Pad every vector to the longest length so they stack into one 2-D array
P_matrices_padded = [_zero_pad(vec, max_p_size) for vec in P_matrices_flattened]
P_matrices = np.array(P_matrices_padded, dtype=np.float32)

# Standardize each P column.
# NOTE(review): the scaler is fitted on all rows, including those that the
# Step 4 split later assigns to validation.
scaler = StandardScaler().fit(P_matrices)
P_matrices_normalized = scaler.transform(P_matrices)

# Targets are floored at 1.0
outputs_array = np.maximum(np.array(outputs_augmented, dtype=np.float32), 1.0)

print(f"Data shapes: n={n_values.shape}, k={k_values.shape}, m={m_values.shape}, P={P_matrices.shape}")
print(f"P normalized: mean={P_matrices_normalized.mean():.4f}, std={P_matrices_normalized.std():.4f}")

## Step 4: Train-Val Split

In [None]:
print("\nSTEP 4: Stratified Train-Validation Split", "-"*70, sep="\n")

# One stratification label per row, encoded as k*10 + m (unique while m < 10)
stratify_labels = k_values.flatten() * 10 + m_values.flatten()

# Split the row indices once, then slice every array with the same index sets.
# NOTE(review): augmentation (Step 2) runs before this split, so copies derived
# from one raw sample can land on both sides; validation scores may therefore
# look better than they would on unseen data.
train_idx, val_idx = train_test_split(
    np.arange(len(outputs_array)),
    test_size=0.15, random_state=42, stratify=stratify_labels
)

n_train, n_val = n_values[train_idx], n_values[val_idx]
k_train, k_val = k_values[train_idx], k_values[val_idx]
m_train, m_val = m_values[train_idx], m_values[val_idx]
P_train, P_val = P_matrices_normalized[train_idx], P_matrices_normalized[val_idx]
y_train, y_val = outputs_array[train_idx], outputs_array[val_idx]

print(f"Training: {len(y_train):,} samples")
print(f"Validation: {len(y_val):,} samples")

## Step 5: Build Advanced Residual + Attention Model

In [None]:
# Section header for the model-building step
print("\nSTEP 5: Building Advanced Residual + Attention Model", "-"*70, sep="\n")

def build_advanced_residual_attention_model(p_shape, k_vocab_size=7, m_vocab_size=6):
    """
    Build the Strategy 1 network (the notebook header reports a score of 77.8).

    Layout, in order:
    - two Dense/LayerNorm/BatchNorm/Dropout stages on the flattened P vector
    - one multi-head self-attention block
    - concatenation with n and the k / m embeddings
    - a 1024-wide Dense stage followed by two residual blocks
    - Dense stages of width 512, 256 and 128 with dropout 0.2, 0.1, 0.1
    - a single linear unit mapped through 2 ** softplus(.)

    Args:
        p_shape: Length of the flattened (padded) P vector.
        k_vocab_size: Number of rows in the k embedding table.
        m_vocab_size: Number of rows in the m embedding table.

    Returns:
        A Keras ``Model`` taking ``[n, k, m, P_flat]`` and producing one value
        per sample.
    """
    def embed(index_input, vocab_size):
        """Look up a 32-dim embedding and flatten it to a vector."""
        return Flatten()(Embedding(vocab_size, 32)(index_input))

    def dense_stage(tensor, units, drop_rate=None, batch_norm=False):
        """GELU Dense -> LayerNorm, optionally followed by BatchNorm and Dropout."""
        tensor = Dense(units, activation='gelu')(tensor)
        tensor = LayerNormalization()(tensor)
        if batch_norm:
            tensor = BatchNormalization()(tensor)
        if drop_rate is not None:
            tensor = Dropout(drop_rate)(tensor)
        return tensor

    def residual_block(tensor, units=1024, drop_rate=0.2):
        """Two dense stages (dropout after the first only) plus a skip connection."""
        shortcut = tensor
        tensor = dense_stage(tensor, units, drop_rate)
        tensor = dense_stage(tensor, units)
        return Add()([tensor, shortcut])

    # Inputs
    n_input = Input(shape=(1,), name='n')
    k_input = Input(shape=(1,), name='k', dtype=tf.int32)
    m_input = Input(shape=(1,), name='m', dtype=tf.int32)
    P_input = Input(shape=(p_shape,), name='P_flat')

    # Embeddings for the two integer inputs
    k_embed = embed(k_input, k_vocab_size)
    m_embed = embed(m_input, m_vocab_size)

    # P branch: 256 -> 512 with LayerNorm + BatchNorm and dropout 0.3 / 0.2
    p_features = dense_stage(P_input, 256, 0.3, batch_norm=True)
    p_features = dense_stage(p_features, 512, 0.2, batch_norm=True)

    # Single self-attention block over the P features.
    # NOTE(review): the reshape produces a sequence of length 1, so the
    # attention softmax is taken over a single key and is always 1; the block
    # then reduces to its value/output projections (plus attention dropout).
    # Confirm this is intended.
    tokens = Reshape((1, 512))(p_features)
    attended = MultiHeadAttention(num_heads=4, key_dim=64, dropout=0.1)(tokens, tokens)
    x_attn = Flatten()(attended)

    # Merge all feature sources
    combined = Concatenate()([n_input, k_embed, m_embed, x_attn])

    # Project to the trunk width, then apply the two residual blocks
    trunk = dense_stage(combined, 1024, 0.3, batch_norm=True)
    trunk = residual_block(trunk)
    trunk = residual_block(trunk)

    # Progressive width reduction with decreasing dropout
    for units, drop_rate in ((512, 0.2), (256, 0.1), (128, 0.1)):
        trunk = dense_stage(trunk, units, drop_rate)

    # Head: the linear unit is passed through softplus (positive exponent) and
    # then 2 ** exponent, so the prediction is never below 1.
    log2_pred = Dense(1, activation='linear')(trunk)
    log2_positive = Lambda(lambda z: tf.nn.softplus(z))(log2_pred)
    m_height = Lambda(lambda z: tf.pow(2.0, z))(log2_positive)

    return Model(
        inputs=[n_input, k_input, m_input, P_input],
        outputs=m_height,
        name='strategy1_advanced_residual_attention'
    )

# Size the network from the data: P width from the training matrix, and each
# embedding table one larger than the biggest index that occurs.
p_width = P_train.shape[1]
k_vocab = k_values.max()+1
m_vocab = m_values.max()+1

model = build_advanced_residual_attention_model(
    p_width, k_vocab_size=k_vocab, m_vocab_size=m_vocab
)

print(f"Model parameters: {model.count_params():,}")
model.summary()

## Step 6: Compile & Train

In [None]:
# Section header for the training step
print("\nSTEP 6: Compile and Train", "-"*70, sep="\n")

def log2_mse_loss(y_true, y_pred):
    """Return the mean squared difference between log2(y_true) and log2(y_pred).

    Both tensors are floored at 1e-7 before the logarithm so that zero or
    negative entries do not produce infinities or NaNs. The result is a
    single scalar (mean over all elements).
    """
    floor = 1e-7
    ln2 = tf.math.log(2.0)

    def to_log2(values):
        # log2(x) computed as ln(x) / ln(2)
        return tf.math.log(tf.maximum(values, floor)) / ln2

    return tf.reduce_mean(tf.square(to_log2(y_true) - to_log2(y_pred)))

# AdamW with weight decay and clipnorm=1.0; the loss doubles as the metric
optimizer = AdamW(learning_rate=1e-3, weight_decay=1e-4, clipnorm=1.0)
model.compile(optimizer=optimizer, loss=log2_mse_loss, metrics=[log2_mse_loss])

# All three callbacks watch the validation loss
stop_early = EarlyStopping(
    monitor='val_loss', patience=30, restore_best_weights=True, verbose=1
)
shrink_lr = ReduceLROnPlateau(
    monitor='val_loss', factor=0.7, patience=15, min_lr=1e-6, verbose=1
)
keep_best = ModelCheckpoint(
    'strategy1_best_model.h5', monitor='val_loss', save_best_only=True, verbose=1
)
callbacks = [stop_early, shrink_lr, keep_best]

train_inputs = [n_train, k_train, m_train, P_train]
val_inputs = [n_val, k_val, m_val, P_val]

history = model.fit(
    train_inputs, y_train,
    validation_data=(val_inputs, y_val),
    epochs=200, batch_size=256, callbacks=callbacks, verbose=1
)

print("\nTraining completed!")

## Step 7: Evaluate

In [None]:
print("\nSTEP 7: Evaluation", "-"*70, sep="\n")

# Reload the checkpoint file written during training, then predict on the
# validation inputs and flatten the predictions to a 1-D vector.
model.load_weights('strategy1_best_model.h5')
raw_val_predictions = model.predict([n_val, k_val, m_val, P_val], verbose=0)
y_pred_val = raw_val_predictions.flatten()

def compute_log2_mse(y_true, y_pred):
    """Return the mean squared difference between log2(y_true) and log2(y_pred).

    Both inputs are floored at 1e-7 first, so zero or negative entries give a
    finite logarithm.
    """
    floor = 1e-7
    log2_gap = np.log2(np.maximum(y_true, floor)) - np.log2(np.maximum(y_pred, floor))
    return np.mean(np.square(log2_gap))

# Overall validation error and the spread of the predictions
val_log2_mse = compute_log2_mse(y_val, y_pred_val)

print(f"\nValidation log2-MSE: {val_log2_mse:.6f}")
print(f"Prediction range: [{y_pred_val.min():.2f}, {y_pred_val.max():.2f}]")

# Bucket true and predicted values by their (k, m) pair
group_metrics = defaultdict(lambda: {'true': [], 'pred': []})
for k, m, actual, predicted in zip(k_val[:, 0], m_val[:, 0], y_val, y_pred_val):
    bucket = group_metrics[(k, m)]
    bucket['true'].append(actual)
    bucket['pred'].append(predicted)

# One row per group, ordered by (k, m)
print("\nPer-Group Performance:")
print(f"{'Group':<12} {'n_val':<8} {'log2-MSE':<12}")
print("-"*40)
for (k, m) in sorted(group_metrics):
    true_vals = np.array(group_metrics[(k, m)]['true'])
    pred_vals = np.array(group_metrics[(k, m)]['pred'])
    group_mse = compute_log2_mse(true_vals, pred_vals)
    print(f"k={k}, m={m}    {len(true_vals):6d}   {group_mse:.6f}")

print("\n" + "="*70)
print("STRATEGY 1 COMPLETE")
print("="*70)