# Neural Receiver Training for 5G PUSCH

**Objective**: Train a deep learning model to replace conventional channel estimation + equalization + demapping

**Architecture**: ResNet-based neural receiver with attention mechanism

**Expected Performance**: 2-3 dB SNR gain at BLER = 10^-2

**Training Time**: ~2 hours on RTX 4090

---

## Neural Receiver Architecture

```
Input: y (received signal) [batch, num_rx, num_subcarriers, num_symbols, 2]
       ↓
  Spatial Processing (across RX antennas)
       ↓
  Frequency-Time Feature Extraction (ResNet blocks)
       ↓
  Attention Mechanism (focus on data subcarriers)
       ↓
  LLR Estimation (log-likelihood ratios)
       ↓
Output: Soft bits [batch, num_bits]
```

## Key Innovations

1. **Joint processing**: No separate channel estimation step
2. **SNR-agnostic**: Trained across multiple SNR points
3. **Attention**: Learns to focus on reliable subcarriers
4. **ResNet backbone**: Deep architecture with skip connections

In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import h5py
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json
import time
from datetime import datetime

# Global plot appearance
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100

# Report the framework versions in use
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")

# Turn on memory growth for every GPU TensorFlow can see
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("‚ö†Ô∏è  No GPU found!")
else:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(f"‚úÖ GPU configured: {gpus}")

## 1. Load and Prepare Dataset

In [None]:
# Path of the HDF5 file holding the PUSCH training set
dataset_path = '/opt/app-root/src/data/pusch_dataset.h5'

print(f"üìÇ Loading dataset: {dataset_path}")
print(f"   File size: {os.path.getsize(dataset_path) / 1024**3:.2f} GB")

# Read-only pass: only dataset shapes and file-level attributes are touched.
with h5py.File(dataset_path, 'r') as f:
    print("\nüìä Dataset Information:")
    print(f"  y_received: {f['y_received'].shape}")
    print(f"  h_channel: {f['h_channel'].shape}")
    print(f"  bits: {f['bits'].shape}")
    print(f"  snr_db: {f['snr_db'].shape}")

    # File-level attributes describing the dataset; kept as module-level
    # names because later cells read them (e.g. num_rx_antennas).
    num_samples, num_rx_antennas = f.attrs['num_samples'], f.attrs['num_rx_antennas']
    num_tx_antennas, num_subcarriers = f.attrs['num_tx_antennas'], f.attrs['num_subcarriers']
    num_ofdm_symbols, modulation_order = f.attrs['num_ofdm_symbols'], f.attrs['modulation_order']

    print(f"\nüìã Metadata:")
    print(f"  Total samples: {num_samples:,}")
    print(f"  RX antennas: {num_rx_antennas}")
    print(f"  Subcarriers: {num_subcarriers}")
    print(f"  OFDM symbols: {num_ofdm_symbols}")
    print(f"  Modulation: {modulation_order}-QAM")

In [None]:
class PUSCHDataset:
    """In-memory batch loader for PUSCH training data stored in HDF5.

    The constructor reads the complete ``y_received`` and ``bits`` datasets
    into NumPy arrays and draws a fixed (seed 42) random train/validation
    split. ``get_dataset`` wraps a Python generator in a ``tf.data.Dataset``
    that yields fixed-size batches; a trailing partial batch is dropped.
    """
    
    def __init__(self, h5_path, batch_size=16, validation_split=0.1):
        """Load the dataset into RAM and build the train/validation split.

        Args:
            h5_path: Path to the HDF5 file. It must provide the datasets
                ``y_received`` and ``bits`` and the attribute ``num_samples``.
            batch_size: Number of samples in every yielded batch.
            validation_split: Fraction of the samples held out for validation.
        """
        self.h5_path = h5_path
        self.batch_size = batch_size
        
        print("\nüì¶ Loading dataset into RAM for fast training...")
        
        with h5py.File(h5_path, 'r') as f:
            # Load metadata
            self.num_samples = f.attrs['num_samples']
            self.num_bits = f['bits'].shape[-1]
            self.input_shape = f['y_received'].shape[1:]
            
            # Load ALL data into RAM (faster than HDF5 reads during training).
            # NOTE: the size quoted in the message is a fixed label, it is not
            # computed from the file.
            print(f"   Loading {self.num_samples:,} samples (~43 GB)...")
            self.y_data = f['y_received'][:]  # Load to RAM
            self.bits_data = f['bits'][:]      # Load to RAM
            print(f"   ‚úÖ Data loaded into RAM")
        
        # Train/val split
        num_val = int(self.num_samples * validation_split)
        num_train = self.num_samples - num_val
        
        # Reproducible split. This reseeds NumPy's global RNG.
        np.random.seed(42)
        indices = np.random.permutation(self.num_samples)
        
        self.train_indices = indices[:num_train]
        self.val_indices = indices[num_train:]
        
        print(f"\n‚úÖ Dataset loader initialized")
        print(f"   Training samples: {len(self.train_indices):,}")
        print(f"   Validation samples: {len(self.val_indices):,}")
        print(f"   Batch size: {self.batch_size}")
        print(f"   Input shape: {self.input_shape}")
        print(f"   Output bits: {self.num_bits}")
    
    def _generator(self, indices, shuffle=False):
        """Yield ``(y_real, bits)`` batches for one pass over ``indices``.

        Args:
            indices: 1-D array of sample indices into the cached arrays.
            shuffle: When true, a fresh random permutation of ``indices`` is
                drawn for this pass. ``indices`` itself is never modified.

        Yields:
            Tuple ``(y_real, bits)`` of float32 arrays. ``y_real`` is the
            batch of received samples with real and imaginary parts stacked
            on a new last axis; ``bits`` is flattened to ``[batch, -1]``.
            Only full batches are produced.
        """
        # Permute a copy so the stored train/val index arrays stay untouched.
        order = np.random.permutation(indices) if shuffle else indices
        num_batches = len(order) // self.batch_size
        
        for i in range(num_batches):
            batch_indices = order[i * self.batch_size:(i + 1) * self.batch_size]
            
            # Fancy-index the in-memory arrays (no disk I/O, no sorting needed)
            y = self.y_data[batch_indices]
            bits = self.bits_data[batch_indices]
            
            # Reshape bits to [batch, bits]
            bits = bits.reshape(bits.shape[0], -1)
            
            # Convert complex to real (stack real/imag)
            y_real = np.stack([y.real, y.imag], axis=-1).astype(np.float32)
            
            yield y_real, bits.astype(np.float32)
    
    def get_dataset(self, training=True, repeat=None):
        """Return a ``tf.data.Dataset`` of training or validation batches.

        Args:
            training: Select the training split (and per-pass shuffling) when
                true, otherwise the validation split in its stored order.
            repeat: Whether the dataset restarts after each full pass. The
                default ``None`` means "repeat for training only".

        Returns:
            A prefetched dataset yielding ``(inputs, bits)`` batches of
            exactly ``batch_size`` samples.

        A repeating training dataset is required when ``Model.fit`` is called
        with ``steps_per_epoch``: Keras does not restart a dataset of unknown
        cardinality between epochs, so a single-pass generator would be
        exhausted after the first epoch. With ``repeat`` on, iterate the
        dataset only with an explicit step count.
        """
        indices = self.train_indices if training else self.val_indices
        if repeat is None:
            repeat = training
        
        output_signature = (
            tf.TensorSpec(shape=(self.batch_size, *self.input_shape, 2), dtype=tf.float32),
            tf.TensorSpec(shape=(self.batch_size, self.num_bits), dtype=tf.float32)
        )
        
        # The callable is invoked again for every pass, so each epoch of a
        # repeating training dataset sees a newly shuffled order.
        dataset = tf.data.Dataset.from_generator(
            lambda: self._generator(indices, shuffle=training),
            output_signature=output_signature
        )
        
        if repeat:
            dataset = dataset.repeat()
        
        # Aggressive prefetching for GPU utilization
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        
        return dataset

# Create dataset loader - data will be loaded into RAM
# (the PUSCHDataset constructor reads the full 'y_received' and 'bits' arrays,
# so this cell needs enough host memory to hold both).
print("="*70)
data_loader = PUSCHDataset(dataset_path, batch_size=16, validation_split=0.1)
print("="*70)

## 2. Build Neural Receiver Architecture

In [None]:
def residual_block(x, filters, kernel_size=(3, 3), activation='relu'):
    """Apply a two-convolution residual unit to a 2-D feature map.

    Main path: Conv2D -> BatchNorm -> activation -> Conv2D -> BatchNorm.
    The input is added back (through a 1x1 convolution when its channel
    count differs from ``filters``) and the sum is passed through the
    activation once more.

    Args:
        x: Input Keras tensor (channels last).
        filters: Number of output channels.
        kernel_size: Kernel size of both main-path convolutions.
        activation: Activation name used inside and after the block.

    Returns:
        Keras tensor with ``filters`` channels.
    """
    identity = x

    # Main path
    main = layers.Conv2D(filters, kernel_size, padding='same')(x)
    main = layers.BatchNormalization()(main)
    main = layers.Activation(activation)(main)
    main = layers.Conv2D(filters, kernel_size, padding='same')(main)
    main = layers.BatchNormalization()(main)

    # Project the skip path only when the channel counts disagree
    needs_projection = identity.shape[-1] != filters
    if needs_projection:
        identity = layers.Conv2D(filters, (1, 1), padding='same')(identity)

    merged = layers.Add()([main, identity])
    return layers.Activation(activation)(merged)

def attention_block(x, key_dim=64):
    """Run 4-head self-attention over the flattened spatial grid of ``x``.

    The two spatial axes are merged into one sequence axis, attended with
    query = value = the flattened map, reshaped back, added to the input and
    layer-normalized.

    Args:
        x: Keras tensor of rank 4 with statically known non-batch dimensions.
        key_dim: Per-head key/query size passed to MultiHeadAttention.

    Returns:
        Keras tensor with the same shape as ``x``.
    """
    # Static non-batch dimensions; the batch axis is left to Keras
    _, height, width, channels = x.shape

    # [batch, H, W, C] -> [batch, H*W, C]
    tokens = layers.Reshape((height * width, channels))(x)

    mha = layers.MultiHeadAttention(num_heads=4, key_dim=key_dim, dropout=0.1)
    attended = mha(tokens, tokens)

    # [batch, H*W, C] -> [batch, H, W, C]
    attended = layers.Reshape((height, width, channels))(attended)

    # Skip connection followed by normalization
    summed = layers.Add()([x, attended])
    return layers.LayerNormalization()(summed)

def build_neural_receiver(input_shape, num_bits, num_rx_antennas=4):
    """Assemble the neural receiver as a Keras functional model.

    Pipeline: per-antenna conv stem -> channel concatenation -> residual
    stages with pooling -> self-attention -> residual stages -> global
    average pooling -> two dense layers -> tanh output of width ``num_bits``.

    Args:
        input_shape: Per-sample shape (num_rx, num_subcarriers, num_symbols, 2).
        num_bits: Number of output units (one per bit).
        num_rx_antennas: Number of antenna slices (indices 0..n-1 along the
            first non-batch axis) that receive their own stem.

    Returns:
        ``keras.Model`` named ``'neural_receiver'``.
    """
    inputs = keras.Input(shape=input_shape, name='received_signal')

    def antenna_stem(rx_idx):
        # Slice out one antenna ([batch, num_sc, num_sym, 2]) and extract
        # initial features. ``idx`` is bound as a default argument so each
        # Lambda keeps its own antenna index.
        branch = layers.Lambda(lambda t, idx=rx_idx: t[:, idx, :, :, :])(inputs)
        branch = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(branch)
        return layers.BatchNormalization()(branch)

    antenna_features = [antenna_stem(rx_idx) for rx_idx in range(num_rx_antennas)]

    # Merge the antenna branches along the channel axis
    if num_rx_antennas > 1:
        features = layers.Concatenate(axis=-1)(antenna_features)
    else:
        features = antenna_features[0]

    # Stage 1: two 64-filter residual blocks, then halve the first spatial axis
    # (original author's note: [896,14] -> [448,14]; TODO confirm for other datasets)
    for _ in range(2):
        features = residual_block(features, filters=64, kernel_size=(5, 3))
    features = layers.MaxPooling2D((2, 1))(features)

    # Stage 2: two 128-filter residual blocks
    for _ in range(2):
        features = residual_block(features, filters=128, kernel_size=(5, 3))

    # Halve both spatial axes before attention to shorten the attended sequence
    # (original author's note: [448,14] -> [224,7], seq_len 6272 -> 1568,
    # sized to fit in 48GB memory)
    features = layers.MaxPooling2D((2, 2))(features)
    features = attention_block(features, key_dim=64)

    # Stage 3: 256-filter residual blocks around one more pooling step
    # (original author's note: [224,7] -> [112,7])
    features = residual_block(features, filters=256, kernel_size=(3, 3))
    features = layers.MaxPooling2D((2, 1))(features)
    features = residual_block(features, filters=256, kernel_size=(3, 3))

    # Collapse the spatial grid to a single feature vector
    features = layers.GlobalAveragePooling2D()(features)

    # Dense head: two identical Dense(512) + Dropout(0.2) pairs
    for _ in range(2):
        features = layers.Dense(512, activation='relu')(features)
        features = layers.Dropout(0.2)(features)

    # Output: one tanh-bounded value per bit, used downstream as an LLR
    outputs = layers.Dense(num_bits, activation='tanh', name='llr_output')(features)

    return keras.Model(inputs=inputs, outputs=outputs, name='neural_receiver')

# Build model
# The loader's per-sample shape gets a trailing axis of size 2 appended,
# matching the real/imag stacking done by the data generator.
model = build_neural_receiver(
    input_shape=(*data_loader.input_shape, 2),
    num_bits=data_loader.num_bits,
    num_rx_antennas=num_rx_antennas
)

print("\nüß† Neural Receiver Architecture:")
model.summary()

# Count parameters
# The size estimate below multiplies by 4 bytes per parameter (FP32).
total_params = model.count_params()
print(f"\nüìä Total parameters: {total_params:,}")
print(f"   Model size: ~{total_params * 4 / 1024**2:.1f} MB (FP32)")

## 3. Define Loss and Metrics

In [None]:
def binary_cross_entropy_with_llr(y_true, y_pred):
    """Mean sigmoid cross-entropy between bits and scaled LLR outputs.

    Args:
        y_true: Ground-truth bits (0 or 1).
        y_pred: Predicted LLRs (tanh output, -1 to +1).

    Returns:
        Scalar tensor: the loss averaged over all elements.
    """
    # Bits {0, 1} -> bipolar {-1, +1}: a positive LLR should mean bit 1,
    # a negative LLR bit 0.
    bipolar = 2.0 * y_true - 1.0

    # Map the bipolar symbols back to {0, 1} labels for the logits loss
    targets = (bipolar + 1.0) / 2.0

    # Scale the bounded model outputs by 5 before using them as logits
    scaled_logits = y_pred * 5.0

    per_bit_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=targets,
        logits=scaled_logits
    )
    return tf.reduce_mean(per_bit_loss)

def bit_error_rate(y_true, y_pred):
    """Fraction of bits whose hard decision differs from the ground truth.

    Args:
        y_true: Ground-truth bits.
        y_pred: Predicted LLRs.

    Returns:
        Scalar float32 tensor with the mean error rate.
    """
    # Hard decision: LLR > 0 -> bit = 1, otherwise bit = 0
    hard_bits = tf.cast(y_pred > 0, tf.float32)

    mismatches = tf.cast(tf.not_equal(y_true, hard_bits), tf.float32)
    return tf.reduce_mean(mismatches)

# Compile model
# Adam with a literal learning rate of 1e-3, the custom LLR cross-entropy
# as loss, and BER as the tracked metric (reported by Keras as
# 'bit_error_rate' / 'val_bit_error_rate').
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=binary_cross_entropy_with_llr,
    metrics=[bit_error_rate]
)

# The summary lines below are fixed text; they are not read back from the model.
print("‚úÖ Model compiled")
print(f"   Optimizer: Adam (lr=1e-3)")
print(f"   Loss: Binary Cross-Entropy with LLR")
print(f"   Metrics: Bit Error Rate")

## 4. Training Configuration

In [None]:
# Training hyper-parameters.
# NOTE: INITIAL_LR is only reported in this cell; it is not passed to an
# optimizer here.
EPOCHS = 20
INITIAL_LR = 1e-3
MIN_LR = 1e-6

# Keep only the model with the lowest validation BER seen so far
best_checkpoint = keras.callbacks.ModelCheckpoint(
    '/opt/app-root/src/models/neural_rx_best.h5',
    monitor='val_bit_error_rate',
    mode='min',
    save_best_only=True,
    verbose=1
)

# Unconditional snapshot after every epoch (for crash recovery)
epoch_checkpoint = keras.callbacks.ModelCheckpoint(
    '/opt/app-root/src/models/neural_rx_epoch_{epoch:02d}.h5',
    save_freq='epoch',
    verbose=1
)

# Halve the learning rate after 2 epochs without val_loss improvement
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=MIN_LR,
    verbose=1
)

# Stop after 5 epochs without validation-BER improvement and roll back
# to the best weights
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_bit_error_rate',
    mode='min',  # Lower BER is better
    patience=5,
    restore_best_weights=True,
    verbose=1
)

# TensorBoard logs go to a directory named after the current timestamp
tensorboard = keras.callbacks.TensorBoard(
    log_dir=f'/opt/app-root/src/results/logs/{datetime.now().strftime("%Y%m%d-%H%M%S")}',
    histogram_freq=1
)

# Callbacks, in the order they are handed to model.fit
callbacks = [
    best_checkpoint,
    epoch_checkpoint,
    lr_scheduler,
    early_stopping,
    tensorboard,
]

print("‚úÖ Training configuration:")
print(f"   Epochs: {EPOCHS}")
print(f"   Initial LR: {INITIAL_LR}")
print(f"   Checkpointing:")
print(f"     - Best model: neural_rx_best.h5 (when val_ber improves)")
print(f"     - Every epoch: neural_rx_epoch_XX.h5 (crash recovery)")
print(f"   Callbacks: LR Scheduler, Early Stopping, TensorBoard")

## 5. Train Neural Receiver

⏱️ **Expected Duration**: ~2 hours on RTX 4090

In [None]:
# Input pipelines for the two splits
train_dataset = data_loader.get_dataset(training=True)
val_dataset = data_loader.get_dataset(training=False)

# Number of full batches available per pass over each split
steps_per_epoch = len(data_loader.train_indices) // data_loader.batch_size
validation_steps = len(data_loader.val_indices) // data_loader.batch_size

print(f"\nüöÄ Starting training...")
print(f"   Steps per epoch: {steps_per_epoch}")
print(f"   Validation steps: {validation_steps}")
print(f"   Total epochs: {EPOCHS}")
print(f"\n{'='*70}\n")

# Run the fit and measure its wall-clock duration
start_time = time.time()
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=callbacks,
    verbose=1,
)
training_time = time.time() - start_time

# Report: "Final" is the last recorded epoch, "Best" the minimum over all epochs
print(f"\n{'='*70}")
print(f"‚úÖ Training complete!")
print(f"   Total time: {training_time/3600:.2f} hours")
print(f"   Final BER: {history.history['val_bit_error_rate'][-1]:.6f}")
print(f"   Best BER: {min(history.history['val_bit_error_rate']):.6f}")

## 6. Save Final Model

In [None]:
# Persist the model as it stands after training
model.save('/opt/app-root/src/models/neural_rx_final.h5')
print(f"‚úÖ Final model saved: /opt/app-root/src/models/neural_rx_final.h5")

# Architecture only (no weights), serialized by Keras to JSON
with open('/opt/app-root/src/models/neural_rx_architecture.json', 'w') as f:
    f.write(model.to_json())
print(f"‚úÖ Architecture saved: /opt/app-root/src/models/neural_rx_architecture.json")

# Per-epoch curves plus run-level summary values
history_dict = {
    key: history.history[key]
    for key in ('loss', 'val_loss', 'bit_error_rate', 'val_bit_error_rate')
}
history_dict['training_time_hours'] = training_time / 3600
history_dict['epochs'] = len(history.history['loss'])

with open('/opt/app-root/src/results/training_history.json', 'w') as f:
    json.dump(history_dict, f, indent=2)
print(f"‚úÖ Training history saved: /opt/app-root/src/results/training_history.json")

## 7. Visualize Training Progress

In [None]:
# One row, two panels: loss on the left, BER (log scale) on the right
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, ber_ax = axes

# Left panel: training vs. validation loss
loss_ax.plot(history.history['loss'], label='Training Loss', linewidth=2)
loss_ax.plot(history.history['val_loss'], label='Validation Loss', linewidth=2)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Right panel: training vs. validation BER on a logarithmic y-axis
ber_ax.semilogy(history.history['bit_error_rate'], label='Training BER', linewidth=2)
ber_ax.semilogy(history.history['val_bit_error_rate'], label='Validation BER', linewidth=2)
ber_ax.set_xlabel('Epoch')
ber_ax.set_ylabel('Bit Error Rate')
ber_ax.set_title('Training and Validation BER')
ber_ax.legend()
ber_ax.grid(True, alpha=0.3, which='both')

# Write the figure to disk before displaying it
plt.tight_layout()
plt.savefig('/opt/app-root/src/results/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n‚úÖ Training curves saved: /opt/app-root/src/results/training_curves.png")

## 8. Quick Performance Test

In [None]:
# Measure FP32 inference latency of the best checkpoint on one validation batch
print("\n‚ö° Testing inference performance...\n")

# Load best model; the custom loss and metric must be supplied by name so
# Keras can rebuild the compiled model.
best_model = keras.models.load_model(
    '/opt/app-root/src/models/neural_rx_best.h5',
    custom_objects={
        'binary_cross_entropy_with_llr': binary_cross_entropy_with_llr,
        'bit_error_rate': bit_error_rate
    }
)

# Get test batch (first batch of the validation pipeline)
test_batch = next(iter(val_dataset))
x_test, y_test = test_batch

# Warmup: the first calls are excluded from the measurement
for _ in range(10):
    _ = best_model.predict(x_test, verbose=0)

# Benchmark
num_runs = 100
latencies = []

for _ in tqdm(range(num_runs), desc="Benchmarking"):
    # perf_counter is a monotonic, high-resolution clock intended for timing
    # short intervals; time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    _ = best_model.predict(x_test, verbose=0)
    latencies.append(time.perf_counter() - start)

latencies = np.array(latencies)

# Latencies are per predict() call on one batch; per-slot numbers divide by
# the batch size.
print(f"\nüìä Inference Performance (FP32):")
print(f"   Batch size: {data_loader.batch_size}")
print(f"   Mean latency: {latencies.mean() * 1000:.2f} ms")
print(f"   Std latency: {latencies.std() * 1000:.2f} ms")
print(f"   Throughput: {data_loader.batch_size / latencies.mean():.1f} slots/sec")
print(f"   Per-slot latency: {latencies.mean() * 1000 / data_loader.batch_size:.3f} ms")

print(f"\n‚ö†Ô∏è  Note: TensorRT optimization (next notebook) will reduce latency to <1ms per slot")

## Summary

**✅ Neural receiver training complete!**

**Models saved:**
- Best model: `/opt/app-root/src/models/neural_rx_best.h5`
- Final model: `/opt/app-root/src/models/neural_rx_final.h5`
- Architecture: `/opt/app-root/src/models/neural_rx_architecture.json`

**Results:**
- Training history: `/opt/app-root/src/results/training_history.json`
- Training curves: `/opt/app-root/src/results/training_curves.png`

**Next Steps:**
1. Proceed to `03-optimize-tensorrt.ipynb` for FP16 optimization
2. Target: <1ms inference latency per slot
3. Then validate performance in `04-validate-performance.ipynb`