# Height Prediction Model Training Script - Run 1 (Best Performing)

**Objective:** Beat validation log2-MSE of 0.374 using TensorFlow

**Performance:** Validation log2-MSE = **0.495** — the best result among train_model.py, train_model_v3.py, and train_model_v4.py, but still above the 0.374 target (lower is better).

## Key Improvements

1. Dataset rebalancing (~9,000 samples per (k,m) combination)
2. Predicting log2(m-height) instead of raw values
3. Advanced architecture with attention and residual blocks
4. Group-weighted loss (1.0-5.0x for hard cases)
5. Stronger regularization and early stopping

## Setup and Imports

In [None]:
import numpy as np
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import (
    Input, Dense, Embedding, Flatten, Concatenate,
    Dropout, BatchNormalization, Add, Reshape, MultiHeadAttention, Lambda
)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import AdamW
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Seed the NumPy and TensorFlow generators with the same value so runs are repeatable
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(42)

banner = "=" * 70
print(banner)
print("HEIGHT PREDICTION MODEL - TRAINING SCRIPT")
print(banner)

# Report the runtime environment
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {gpu_devices}")

## Step 1: Load Data

Using augmented dataset (properly merged DS-1 + DS-2 + DS-3 with balanced augmentation)

In [None]:
print("STEP 1: Loading Data")
print("-"*70)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load these pickles from a source you trust.
with open('augmented_n_k_m_P.pkl', 'rb') as f:
    inputs_raw = pickle.load(f)

with open('augmented_mHeights.pkl', 'rb') as f:
    outputs_raw = pickle.load(f)

print(f"Raw input samples: {len(inputs_raw)}")
print(f"Raw output samples: {len(outputs_raw)}")

# Every input record needs exactly one target; a mismatch would otherwise
# surface much later as a confusing shape error inside train_test_split.
if len(inputs_raw) != len(outputs_raw):
    raise ValueError(
        f"Input/output length mismatch: {len(inputs_raw)} inputs vs "
        f"{len(outputs_raw)} outputs"
    )

# Fail early with a clear message instead of letting np.min raise on an
# empty sequence further down.
if len(inputs_raw) == 0:
    raise ValueError("Loaded dataset is empty; nothing to train on")

# Each record is indexed as [n, k, m, P_matrix] by the rest of the script.
sample = inputs_raw[0]
print(f"Sample structure: [n={sample[0]}, k={sample[1]}, m={sample[2]}, P_matrix shape={sample[3].shape}]")
print(f"Output range: [{np.min(outputs_raw):.2f}, {np.max(outputs_raw):.2f}]")

## Step 2: Analyze Class Distribution

In [None]:
print("STEP 2: Analyzing Class Distribution (BEFORE Rebalancing)")
print("-"*70)

# Bucket every sample index under its (k, m) key
groups = defaultdict(list)
for idx, record in enumerate(inputs_raw):
    groups[(int(record[1]), int(record[2]))].append(idx)

print(f"Total unique (k,m) combinations: {len(groups)}")
print("\nDistribution by (k,m):")
total_samples = len(inputs_raw)
group_sizes = {key: len(members) for key, members in groups.items()}
for k, m in sorted(group_sizes):
    count = group_sizes[(k, m)]
    percentage = (count / total_samples) * 100
    print(f"  k={k}, m={m}: {count:6d} samples ({percentage:5.2f}%)")

# Ratio between the most and the least populated group
max_count = max(group_sizes.values())
min_count = min(group_sizes.values())
print(f"\nImbalance ratio: {max_count/min_count:.1f}x")

## Step 3: Verify Balanced Dataset

In [None]:
print("STEP 3: Verifying Dataset Balance")
print("-"*70)

# NOTE(review): nothing in this cell measures balance. The 12,000-per-group
# figure below is a hard-coded message, while the notebook header states
# ~9,000 samples per (k,m) combination. Confirm the real figure against the
# per-group counts printed in Step 2.
print("Using pre-balanced augmented dataset (12,000 samples per (k,m) group)")
# Total number of loaded samples (all groups combined)
print(f"Dataset already balanced at: {len(inputs_raw):,} total samples")

## Step 4: Prepare Data for Training

In [None]:
print("STEP 4: Preparing Data for Training")
print("-"*70)

inputs_rebalanced = inputs_raw
outputs_rebalanced = outputs_raw

# Scalar fields of each record ([n, k, m, P]) become (N, 1) column vectors
n_values = np.array([rec[0] for rec in inputs_rebalanced], dtype=np.float32).reshape(-1, 1)
k_values = np.array([rec[1] for rec in inputs_rebalanced], dtype=np.int32).reshape(-1, 1)
m_values = np.array([rec[2] for rec in inputs_rebalanced], dtype=np.int32).reshape(-1, 1)
outputs_array = np.array(outputs_rebalanced, dtype=np.float32)

# Flatten every P matrix, then right-pad with zeros up to the longest one
P_matrices_flattened = [rec[3].flatten() for rec in inputs_rebalanced]
max_p_size = max(len(p) for p in P_matrices_flattened)
P_matrices = np.zeros((len(P_matrices_flattened), max_p_size), dtype=np.float32)
for row, flat in zip(P_matrices, P_matrices_flattened):
    row[:len(flat)] = flat

# Column-wise standardization of the padded P features.
# NOTE(review): the scaler is fit on every row here, i.e. before any
# train/validation split has been made in this cell.
scaler = StandardScaler()
P_matrices_normalized = scaler.fit_transform(P_matrices)

# Clamp targets from below at 1.0
outputs_array = np.maximum(outputs_array, 1.0)

print(f"n_values shape: {n_values.shape}")
print(f"k_values shape: {k_values.shape}, range: [{k_values.min()}, {k_values.max()}]")
print(f"m_values shape: {m_values.shape}, range: [{m_values.min()}, {m_values.max()}]")
print(f"P_matrices shape: {P_matrices.shape}")
print(f"P matrices normalized: mean={P_matrices_normalized.mean():.4f}, std={P_matrices_normalized.std():.4f}")
print(f"Output (m-height) range: [{outputs_array.min():.2f}, {outputs_array.max():.2f}]")

## Step 5: Stratified Train-Val Split

In [None]:
print("STEP 5: Creating Stratified Train-Validation Split")
print("-"*70)

# Collapse each (k, m) pair into a single integer label for stratification.
# The stride stays at 10 while m is single-digit (identical labels to the
# previous k*10+m encoding) and widens automatically so that distinct pairs
# cannot collide if m ever reaches 10 or more.
label_stride = max(10, int(m_values.max()) + 1)
stratify_labels = k_values.flatten() * label_stride + m_values.flatten()

# Split data (85% train, 15% validation). The raw (un-normalized) P matrices
# are split here so that the scaler below can be fit on training rows only.
(n_train, n_val,
 k_train, k_val,
 m_train, m_val,
 P_train, P_val,
 y_train, y_val,
 strat_train, strat_val) = train_test_split(
    n_values, k_values, m_values, P_matrices, outputs_array,
    stratify_labels,
    test_size=0.15,
    random_state=42,
    stratify=stratify_labels
)

# Standardize P with statistics taken from the training split alone, then
# apply that same transform to the validation split. Fitting the scaler on
# the full dataset would leak validation statistics into the training
# features and make the validation score optimistic.
scaler = StandardScaler()
P_train = scaler.fit_transform(P_train)
P_val = scaler.transform(P_val)

print(f"Training samples: {len(y_train)}")
print(f"Validation samples: {len(y_val)}")

# Verify stratification: count validation rows per (k, m) pair
print("\nValidation set distribution:")
val_groups = defaultdict(int)
for k, m in zip(k_val.flatten(), m_val.flatten()):
    val_groups[(k, m)] += 1
for (k, m), count in sorted(val_groups.items()):
    percentage = (count / len(y_val)) * 100
    print(f"  k={k}, m={m}: {count:5d} samples ({percentage:5.2f}%)")

## Step 5B: Compute Sample Weights for Group-Weighted Loss

Higher k and m values get higher weights to force the model to focus on difficult groups.

In [None]:
print("STEP 5B: Computing Sample Weights for Group-Weighted Loss")
print("-"*70)

# Loss multiplier per (k, m) group; pairs not listed fall back to 1.0
group_weights = {
    (4, 2): 1.0,   # baseline
    (4, 3): 1.0,
    (5, 2): 1.0,
    (4, 4): 1.5,   # increase focus
    (5, 3): 1.5,
    (6, 2): 2.0,   # significant focus
    (5, 4): 3.0,   # high focus
    (6, 3): 3.0,
    (4, 5): 5.0,   # maximum focus on worst performer
}

# One weight per training row, looked up by that row's (k, m) pair
k_flat = k_train.flatten()
m_flat = m_train.flatten()
sample_weights_train = np.array(
    [group_weights.get(pair, 1.0) for pair in zip(k_flat.tolist(), m_flat.tolist())],
    dtype=np.float32,
)

print("Sample weights distribution:")
for k, m in sorted(group_weights):
    weight = group_weights[(k, m)]
    count = np.sum((k_flat == k) & (m_flat == m))
    print(f"  k={k}, m={m}: weight={weight:.1f} ({count:5d} samples)")

print(f"\nSample weights shape: {sample_weights_train.shape}")
print(f"Sample weights range: [{sample_weights_train.min():.1f}, {sample_weights_train.max():.1f}]")
print(f"Average weight: {sample_weights_train.mean():.2f}")

## Step 6: Build Model

Advanced model with:
- Embeddings for categorical k and m
- Deep processing of P matrix with attention
- Residual blocks
- Output: m-height, computed as 2^softplus(z) from the network's log2-space prediction z, so m-height ≥ 1.0 (equivalently log2(m-height) ≥ 0)

In [None]:
def build_model(p_shape, k_vocab_size=7, m_vocab_size=6):
    """
    Build the (uncompiled) Keras functional model mapping (n, k, m, P) to m-height.

    Args:
        p_shape: Length of the flat P feature vector; P_input has shape (p_shape,).
        k_vocab_size: Number of rows in the k embedding table. Every k index
            fed to the model must be smaller than this value.
        m_vocab_size: Number of rows in the m embedding table. Every m index
            fed to the model must be smaller than this value.

    Returns:
        A keras Model named 'height_prediction_model' taking
        [n_input, k_input, m_input, P_input] and producing one value per
        sample: 2 ** softplus(z), where z is the final linear Dense(1) output.
    """
    # Inputs
    n_input = Input(shape=(1,), name='n_input')
    k_input = Input(shape=(1,), name='k_input', dtype=tf.int32)
    m_input = Input(shape=(1,), name='m_input', dtype=tf.int32)
    P_input = Input(shape=(p_shape,), name='P_input')

    # Embeddings for categorical variables (32-dim each, flattened to vectors)
    k_embed = Flatten()(Embedding(k_vocab_size, 32, name='k_embedding')(k_input))
    m_embed = Flatten()(Embedding(m_vocab_size, 32, name='m_embedding')(m_input))

    # P matrix processing: p_shape -> 256 -> 512 features
    x = Dense(256, activation='gelu')(P_input)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    x = Dense(512, activation='gelu')(x)
    x = BatchNormalization()(x)

    # Multi-head attention on P features.
    # The Reshape turns the 512 features into a sequence of length 1, so the
    # self-attention below sees a single position per sample.
    # NOTE(review): with only one position the attention weights are
    # presumably trivial, making this behave like a learned projection —
    # confirm whether a longer sequence was intended.
    x_attn = Reshape((1, 512))(x)
    x_attn = MultiHeadAttention(num_heads=8, key_dim=64, dropout=0.1)(x_attn, x_attn)
    x_attn = Flatten()(x_attn)

    # Combine all features: raw n, both embeddings, and the attended P features
    combined = Concatenate()([n_input, k_embed, m_embed, x_attn])

    # Deep network with residual connections
    x = Dense(1024, activation='gelu')(combined)
    x = BatchNormalization()(x)
    x = Dropout(0.4)(x)

    # 4 residual blocks: Dense -> BN -> Dropout -> Dense -> BN, then add the
    # block input back (the width stays at 1024 so the shapes match)
    for i in range(4):
        residual = x
        x = Dense(1024, activation='gelu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.4)(x)
        x = Dense(1024, activation='gelu')(x)
        x = BatchNormalization()(x)
        x = Add()([x, residual])

    # Final dense layers
    x = Dense(512, activation='gelu')(x)
    x = BatchNormalization()(x)
    x = Dense(256, activation='gelu')(x)

    # Output layer: Predict log2(m-height), then convert to m-height
    log2_pred = Dense(1, activation='linear', name='log2_prediction')(x)

    # Ensure log2_pred ≥ 0 (so m-height ≥ 1) using softplus
    log2_positive = Lambda(lambda x: tf.nn.softplus(x), name='softplus_activation')(log2_pred)

    # Convert to m-height: 2^(log2_pred)
    output = Lambda(lambda x: tf.pow(2.0, x), name='m_height_output')(log2_positive)

    model = Model(
        inputs=[n_input, k_input, m_input, P_input],
        outputs=output,
        name='height_prediction_model'
    )

    return model

# Instantiate the network; each embedding table gets (largest observed index + 1) rows
vocab_sizes = dict(
    k_vocab_size=k_values.max()+1,
    m_vocab_size=m_values.max()+1,
)
p_shape = P_train.shape[1]
model = build_model(p_shape, **vocab_sizes)

print("Model built successfully!")
print(f"Total parameters: {model.count_params():,}")
model.summary()

## Step 7: Define Custom Loss Function

In [None]:
def log2_mse_loss(y_true, y_pred):
    """
    Custom loss function: squared error in log2 space, one value per sample.

    per_sample_loss = mean over the last axis of (log2(y_true) - log2(y_pred))^2

    Only the last axis is reduced, so for (batch, 1) inputs the result has
    shape (batch,). Keras then applies `sample_weight` element-wise and
    averages over the batch itself. Reducing to a scalar here would make
    Keras multiply one batch-wide number by the weights, which merely
    rescales the loss and removes any per-group emphasis.

    Args:
        y_true: Ground-truth m-heights.
        y_pred: Predicted m-heights, same shape as y_true.

    Returns:
        Tensor of per-sample log2-space squared errors (last axis reduced).
        When used unweighted, or as a metric, its batch mean equals the
        plain log2-MSE.
    """
    epsilon = 1e-7

    # Ensure positive values so the logarithm is defined
    y_true = tf.maximum(y_true, epsilon)
    y_pred = tf.maximum(y_pred, epsilon)

    # Convert to log2 via the change-of-base identity
    log2_true = tf.math.log(y_true) / tf.math.log(2.0)
    log2_pred = tf.math.log(y_pred) / tf.math.log(2.0)

    # Squared error in log2 space, kept per sample so sample weights apply
    return tf.reduce_mean(tf.square(log2_true - log2_pred), axis=-1)

print("Custom log2-MSE loss function defined")

## Step 8: Compile Model

In [None]:
print("STEP 8: Compiling Model")
print("-"*70)

# AdamW with learning rate 5e-4 (down from 1e-3) and weight decay 1e-3 (up from 1e-4)
optimizer = AdamW(weight_decay=1e-3, learning_rate=5e-4)

# The same log2-space function is used as the training loss and as the reported metric
compile_options = dict(
    optimizer=optimizer,
    loss=log2_mse_loss,
    metrics=[log2_mse_loss],
)
model.compile(**compile_options)

for message in (
    "Model compiled with AdamW optimizer and log2-MSE loss",
    "  Learning rate: 5e-4 (reduced from 1e-3)",
    "  Weight decay: 1e-3 (increased from 1e-4)",
):
    print(message)

## Step 9: Setup Callbacks

In [None]:
# Stop after 30 epochs without val_loss improvement (reduced from 50) and roll back to the best weights
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=30,
    restore_best_weights=True,
    verbose=1,
)

# Halve the learning rate after 20 stagnant epochs, never going below 1e-6
lr_schedule = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=20,
    min_lr=1e-6,
    verbose=1,
)

# Write the model to disk whenever val_loss reaches a new best
checkpoint = ModelCheckpoint(
    'best_model.h5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

callbacks = [early_stopping, lr_schedule, checkpoint]

for message in (
    "Callbacks configured:",
    "  - EarlyStopping (patience=30)",
    "  - ReduceLROnPlateau (patience=20, factor=0.5)",
    "  - ModelCheckpoint (saves best model)",
):
    print(message)

## Step 10: Train Model

In [None]:
rule = "="*70
print(rule)
print("STEP 10: TRAINING MODEL")
print(rule)
for message in (
    "Batch size: 256",
    "Max epochs: 200",
    "Early stopping patience: 30",
    "Group-weighted loss: ENABLED (weights 1.0-5.0x)",
):
    print(message)

# Feature lists follow the model's input order: n, k, m, P
train_features = [n_train, k_train, m_train, P_train]
val_features = [n_val, k_val, m_val, P_val]

# Weighted fit on the training split, monitored on the validation split
history = model.fit(
    x=train_features,
    y=y_train,
    sample_weight=sample_weights_train,
    validation_data=(val_features, y_val),
    epochs=200,
    batch_size=256,
    callbacks=callbacks,
    verbose=1,
)

print("\nTraining completed!")

## Step 11: Evaluate Model

In [None]:
rule = "="*70
print(rule)
print("STEP 11: EVALUATING MODEL")
print(rule)

# Restore the checkpointed weights before scoring
model.load_weights('best_model.h5')

# Predict on both splits and flatten each result to a 1-D vector
train_features = [n_train, k_train, m_train, P_train]
val_features = [n_val, k_val, m_val, P_val]
y_pred_train = model.predict(train_features, verbose=0).flatten()
y_pred_val = model.predict(val_features, verbose=0).flatten()

# Compute overall log2-MSE
def compute_log2_mse(y_true, y_pred):
    """Return the mean squared difference between log2(y_true) and log2(y_pred).

    Both inputs are clipped from below at 1e-7 before the logarithm is taken.
    """
    floor = 1e-7
    log2_gap = np.log2(np.maximum(y_true, floor)) - np.log2(np.maximum(y_pred, floor))
    return np.mean(log2_gap ** 2)

# Overall log2-space error on each split
train_log2_mse = compute_log2_mse(y_train, y_pred_train)
val_log2_mse = compute_log2_mse(y_val, y_pred_val)

print(f"Training log2-MSE: {train_log2_mse:.6f}")
print(f"Validation log2-MSE: {val_log2_mse:.6f}")

# Range of the predictions, plus the >= 1.0 check on the validation predictions
print("\nPrediction Statistics:")
for split_name, preds in (("Train", y_pred_train), ("Val", y_pred_val)):
    print(f"  {split_name} predictions - Min: {preds.min():.4f}, Max: {preds.max():.4f}")
val_floor_ok = y_pred_val.min() >= 1.0
print(f"  All predictions ≥ 1.0: {val_floor_ok}")

## Step 12: Per-Group Analysis

In [None]:
rule = "="*70
print(rule)
print("PER-GROUP PERFORMANCE ANALYSIS")
print(rule)

# Collect validation targets and predictions under their (k, m) pair
group_metrics = defaultdict(lambda: {'true': [], 'pred': []})
for k, m, actual, predicted in zip(k_val[:, 0], m_val[:, 0], y_val, y_pred_val):
    bucket = group_metrics[(k, m)]
    bucket['true'].append(actual)
    bucket['pred'].append(predicted)

# Mirror the per-group report to stdout and to a text file
with open('per_group_performance.txt', 'w') as f:
    print(rule, file=f)
    print("PER-GROUP PERFORMANCE BREAKDOWN", file=f)
    print(rule + "\n", file=f)

    print("\nValidation Log2-MSE by (k,m) combination:")
    print("Validation Log2-MSE by (k,m) combination:", file=f)
    print("-"*70, file=f)

    for k, m in sorted(group_metrics):
        data = group_metrics[(k, m)]
        true_vals = np.array(data['true'])
        pred_vals = np.array(data['pred'])
        group_log2_mse = compute_log2_mse(true_vals, pred_vals)

        output_line = f"  k={k}, m={m}: {group_log2_mse:.6f} (n={len(true_vals)} samples)"
        print(output_line)
        print(output_line, file=f)

print("\nPer-group performance saved to: per_group_performance.txt")

## Step 13: Generate Plots

In [None]:
# Plot 1: Training History
train_curve = history.history['loss']
val_curve = history.history['val_loss']


def _finish_loss_panel(title):
    """Apply the shared axis labels, title, legend and grid to the current loss panel."""
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Log2-MSE Loss', fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)


plt.figure(figsize=(12, 5))

# Left panel: the full loss curves
plt.subplot(1, 2, 1)
plt.plot(train_curve, label='Train Loss', linewidth=2)
plt.plot(val_curve, label='Val Loss', linewidth=2)
_finish_loss_panel('Training History')

# Right panel: skip the first 20% of epochs so the later part is easier to read
plt.subplot(1, 2, 2)
start_epoch = int(len(train_curve) * 0.2)
plt.plot(range(start_epoch, len(train_curve)), train_curve[start_epoch:],
         label='Train Loss', linewidth=2)
plt.plot(range(start_epoch, len(val_curve)), val_curve[start_epoch:],
         label='Val Loss', linewidth=2)
_finish_loss_panel('Training History (Last 80%)')

plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
plt.show()

# Plot 2: Predictions vs True Values
def _scatter_panel(position, x_label, y_label, title, log_axes):
    """Draw one true-vs-predicted validation scatter with the identity line as reference."""
    plt.subplot(1, 2, position)
    plt.scatter(y_val, y_pred_val, alpha=0.3, s=10)
    diagonal = [y_val.min(), y_val.max()]
    plt.plot(diagonal, diagonal, 'r--', linewidth=2, label='Perfect Prediction')
    plt.xlabel(x_label, fontsize=12)
    plt.ylabel(y_label, fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    if log_axes:
        plt.xscale('log')
        plt.yscale('log')
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)


plt.figure(figsize=(14, 6))

# Left: linear axes; right: the same data on log-log axes
_scatter_panel(1, 'True m-Height', 'Predicted m-Height',
               'Predictions vs True Values (Validation Set)', False)
_scatter_panel(2, 'True m-Height (log scale)', 'Predicted m-Height (log scale)',
               'Predictions vs True Values - Log Scale', True)

plt.tight_layout()
plt.savefig('predictions_scatter.png', dpi=150, bbox_inches='tight')
plt.show()

## Final Results Summary

In [None]:
rule = "="*70
print(rule)
print("FINAL RESULTS SUMMARY")
print(rule)
print(f"Training samples: {len(y_train):,}")
print(f"Validation samples: {len(y_val):,}")
print(f"Model parameters: {model.count_params():,}")
print()
print(f"Training log2-MSE:   {train_log2_mse:.6f}")
print(f"Validation log2-MSE: {val_log2_mse:.6f}")
print("Target to beat:      0.374000")
print()

# Compare the validation score against the 0.374 target (lower is better)
beat_target = val_log2_mse < 0.374
if not beat_target:
    deficit = ((val_log2_mse - 0.374) / 0.374) * 100
    print(f"❌ Did not beat target (worse by {deficit:.1f}%)")
    print(f"   Need to improve by: {val_log2_mse - 0.374:.6f}")
else:
    improvement = ((0.374 - val_log2_mse) / 0.374) * 100
    print(f"✅ SUCCESS! Beat target by {improvement:.1f}%")
    print(f"   Improvement: {0.374 - val_log2_mse:.6f}")

print()
# Lower-bound check on the validation predictions
floor_status = "✅ Yes" if y_pred_val.min() >= 1.0 else "❌ No"
print("All predictions ≥ 1.0:", floor_status)
print(f"Prediction range: [{y_pred_val.min():.2f}, {y_pred_val.max():.2f}]")
print()
print(rule)
print("DELIVERABLES SAVED:")
print(rule)
for entry in (
    "  1. best_model.h5 - Trained model weights",
    "  2. training_history.png - Loss curves",
    "  3. predictions_scatter.png - Prediction quality plots",
    "  4. per_group_performance.txt - Detailed per-(k,m) metrics",
):
    print(entry)
print(rule)