<a href="https://colab.research.google.com/github/kumuds4/BCH/blob/master/Making_the_Most_of_your_Colab_Subscription.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Making the Most of your Colab Subscription



## Faster GPUs

Users who have purchased one of Colab's paid plans have access to faster GPUs and more memory. You can upgrade your notebook's GPU settings in `Runtime > Change runtime type` in the menu to select from several accelerator options, subject to availability.

The free of charge version of Colab grants access to Nvidia's T4 GPUs subject to quota restrictions and availability.

You can see what GPU you've been assigned at any time by executing the following cell. If the execution result of running the code cell below is "Not connected to a GPU", you can change the runtime by going to `Runtime > Change runtime type` in the menu to enable a GPU accelerator, and then re-execute the code cell.


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In order to use a GPU with your notebook, select the `Runtime > Change runtime type` menu, and then set the hardware accelerator to the desired option.

## More memory

Users who have purchased one of Colab's paid plans have access to high-memory VMs when they are available. More powerful GPUs are always offered with high-memory VMs.



You can see how much memory you have available at any time by running the following code cell. If the execution result of running the code cell below is "Not using a high-RAM runtime", then you can enable a high-RAM runtime via `Runtime > Change runtime type` in the menu. Then select High-RAM in the Runtime shape toggle button. After, re-execute the code cell.


In [None]:
import psutil

ram_gb = psutil.virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## Longer runtimes

All Colab runtimes are reset after some period of time (which is faster if the runtime isn't executing code). Colab Pro and Pro+ users have access to longer runtimes than those who use Colab free of charge.

## Background execution

Colab Pro+ users have access to background execution, where notebooks will continue executing even after you've closed a browser tab. This is always enabled in Pro+ runtimes as long as you have compute units available.



## Relaxing resource limits in Colab Pro

Your resources are not unlimited in Colab. To make the most of Colab, avoid using resources when you don't need them. For example, only use a GPU when required and close Colab tabs when finished.



If you encounter limitations, you can relax those limitations by purchasing more compute units via Pay As You Go. Anyone can purchase compute units via [Pay As You Go](https://colab.research.google.com/signup); no subscription is required.

## Send us feedback!

If you have any feedback for us, please let us know. The best way to send feedback is by using the Help > 'Send feedback...' menu. If you encounter usage limits in Colab Pro consider subscribing to Pro+.

If you encounter errors or other issues with billing (payments) for Colab Pro, Pro+, or Pay As You Go, please email [colab-billing@google.com](mailto:colab-billing@google.com).

## More Resources

### Working with Notebooks in Colab
- [Overview of Colab](/notebooks/basic_features_overview.ipynb)
- [Guide to Markdown](/notebooks/markdown_guide.ipynb)
- [Importing libraries and installing dependencies](/notebooks/snippets/importing_libraries.ipynb)
- [Saving and loading notebooks in GitHub](https://colab.research.google.com/github/googlecolab/colabtools/blob/main/notebooks/colab-github-demo.ipynb)
- [Interactive forms](/notebooks/forms.ipynb)
- [Interactive widgets](/notebooks/widgets.ipynb)

<a name="working-with-data"></a>
### Working with Data
- [Loading data: Drive, Sheets, and Google Cloud Storage](/notebooks/io.ipynb)
- [Charts: visualizing data](/notebooks/charts.ipynb)
- [Getting started with BigQuery](/notebooks/bigquery.ipynb)

### Machine Learning Crash Course
These are a few of the notebooks from Google's online Machine Learning course. See the [full course website](https://developers.google.com/machine-learning/crash-course/) for more.
- [Intro to Pandas DataFrame](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/pandas_dataframe_ultraquick_tutorial.ipynb)
- [Linear regression with tf.keras using synthetic data](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/linear_regression_with_synthetic_data.ipynb)


<a name="using-accelerated-hardware"></a>
### Using Accelerated Hardware
- [TensorFlow with GPUs](/notebooks/gpu.ipynb)
- [TPUs in Colab](/notebooks/tpu.ipynb)

<a name="machine-learning-examples"></a>

## Machine Learning Examples

To see end-to-end examples of the interactive machine learning analyses that Colab makes possible, check out these tutorials using models from [TensorFlow Hub](https://tfhub.dev).

A few featured examples:

- [Retraining an Image Classifier](https://tensorflow.org/hub/tutorials/tf2_image_retraining): Build a Keras model on top of a pre-trained image classifier to distinguish flowers.
- [Text Classification](https://tensorflow.org/hub/tutorials/tf2_text_classification): Classify IMDB movie reviews as either *positive* or *negative*.
- [Style Transfer](https://tensorflow.org/hub/tutorials/tf2_arbitrary_image_stylization): Use deep learning to transfer style between images.
- [Multilingual Universal Sentence Encoder Q&A](https://tensorflow.org/hub/tutorials/retrieval_with_tf_hub_universal_encoder_qa): Use a machine learning model to answer questions from the SQuAD dataset.
- [Video Interpolation](https://tensorflow.org/hub/tutorials/tweening_conv3d): Predict what happened in a video between the first and the last frame.


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from google.colab import files
uploaded = files.upload()


Saving making_the_most_of_your_colab_subscription (7).py to making_the_most_of_your_colab_subscription (7).py


In [51]:
!pip install torch numpy matplotlib scikit-learn seaborn
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import traceback
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
# Comprehensive Polar Code Simulation Framework

# Essential Scientific and Deep Learning Libraries
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Machine Learning and Data Handling
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# System and Utilities
import logging
import traceback
import sys
print(f"Python version: {sys.version}")
import torch
print(f"PyTorch version: {torch.__version__}")

# Configure Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s]: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

"""
Polar Code Simulation Framework
Version: 2.0
Date: [Current Date]
Block Length: N=32, K=16
Changes:
- Reduced block length
- Channel-specific SNR ranges
- Improved performance computation
"""

class SimulationConfig:
    # Configurable Polar Code Parameters
    BLOCK_LENGTH = 32  # Reduced from previous 128
    INFO_BITS = 16     # Reduced from previous 64
    CODE_RATE = INFO_BITS / BLOCK_LENGTH  # 0.5 rate

    # Channel-Specific SNR Configurations
   # Channel-Specific SNR Configurations
    AWGN_SNR_RANGE = np.arange(0, 5.5, 0.5)    # 0-5 dB, 0.5 dB steps
    RAYLEIGH_SNR_RANGE = np.linspace(0, 10, 10)  # Broader range for Rayleigh
    LEARNING_RATE = 1e-3
    EPOCHS = 50
    BATCH_SIZE = 32
    NUM_SAMPLES = 5000
    LIST_SIZES = [1, 4, 8]
    TEST_SPLIT = 0.2

    # Channel and Performance Parameters
   # SNR_RANGE = np.linspace(0, 10, 10)
    #LIST_SIZES = [1, 8, 16]
    #NUM_TRIALS = 2000

    # Device Configuration
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @classmethod
    def display_config(cls):
        """
        Display simulation configuration
        """
        logging.info("🚀 Simulation Configuration:")
        logging.info(f"Block Length: {cls.BLOCK_LENGTH}")
        logging.info(f"Information Bits: {cls.INFO_BITS}")
        logging.info(f"Code Rate: {cls.CODE_RATE:.2f}")
        logging.info(f"Device: {cls.DEVICE}")
        logging.info(f"Learning Rate: {cls.LEARNING_RATE}")
        logging.info(f"Epochs: {cls.EPOCHS}")
        logging.info(f"Batch Size: {cls.BATCH_SIZE}")















# Optional: Comprehensive channel capacity plotting function
def plot_channel_capacities(channels, snr_range):
    """
    Plot Channel Capacities for multiple channels

    Args:
        channels (dict): Dictionary of channel simulators
        snr_range (np.ndarray): Range of SNR values in dB
    """
    plt.figure(figsize=(12, 7))

    # Color and style configurations
    channel_styles = {
        'AWGN': {
            'color': 'blue',
            'linestyle': '-',
            'marker': 'o',
            'label': 'AWGN Channel'
        },
        'Rayleigh': {
            'color': 'green',
            'linestyle': '--',
            'marker': 's',
            'label': 'Rayleigh Channel'
        }
    }

    # Compute and plot capacities for each channel
    for channel_name, channel in channels.items():
        # Compute channel capacities
        capacities = [channel.compute_channel_capacity(snr) for snr in snr_range]

        # Plot with specified style
        plt.plot(
            snr_range,
            capacities,
            color=channel_styles[channel_name]['color'],
            linestyle=channel_styles[channel_name]['linestyle'],
            marker=channel_styles[channel_name]['marker'],
            label=channel_styles[channel_name]['label'],
            linewidth=2
        )

    # Formatting
    plt.title('Channel Capacity Comparison', fontsize=16)
    plt.xlabel('SNR (dB)', fontsize=12)
    plt.ylabel('Channel Capacity (bits/channel use)', fontsize=12)
    plt.grid(True, linestyle=':', alpha=0.7)
    plt.legend(loc='best')

    plt.tight_layout()
    plt.savefig('combined_channel_capacities.png', dpi=300)
    plt.close()

def compute_channel_performance(
    decoder,
    channel,
    snr_range,
    polar_code_gen,
    list_size=1,
    num_trials=1000
):
    """
    Compute Channel Performance Metrics with Enhanced Error Handling
    """
    # Initialize result arrays with default values
    ber_values = np.ones_like(snr_range, dtype=float) * 1e-5
    bler_values = np.ones_like(snr_range, dtype=float) * 1e-5

    for idx, snr in enumerate(snr_range):
        total_bit_errors = 0
        total_block_errors = 0
        total_bits_processed = 0
        total_blocks_processed = 0

        for trial in range(num_trials):
            try:
                # Generate information bits
                info_bits = polar_code_gen.generate_info_bits()

                # Encode polar code
                encoded_signal = polar_code_gen.polar_encode(info_bits)

                # Verify channel object and simulate method
                if not hasattr(channel, 'simulate'):
                    print(f"Error: Channel {channel.channel_type} lacks simulate method")
                    continue

                # Simulate channel
                received_signal = channel.simulate(encoded_signal, snr)

            except Exception as e:
                print(f"Safe simulate error: {e}")
                # Return default values instead of None
                return ber_values, bler_values


        # ... (Your existing performance calculation logic using received_signal) ...

    # Return the performance values
    return ber_values, bler_values



def compute_channel_performance(
    decoder,
    channel,
    snr_range,
    polar_code_gen,
    list_size=1,
    num_trials=1000
):
    """
    Compute Channel Performance Metrics with Enhanced Error Handling
    """
    # Initialize result arrays with default values
    ber_values = np.ones_like(snr_range, dtype=float) * 1e-5
    bler_values = np.ones_like(snr_range, dtype=float) * 1e-5

    for idx, snr in enumerate(snr_range):
        total_bit_errors = 0
        total_block_errors = 0
        total_bits_processed = 0
        total_blocks_processed = 0

        for trial in range(num_trials):
            try:
                # Generate information bits
                info_bits = polar_code_gen.generate_info_bits()

                # Encode polar code
                encoded_signal = polar_code_gen.polar_encode(info_bits)

                # Verify channel object and simulate method
                if not hasattr(channel, 'simulate'):
                    print(f"Error: Channel {channel.channel_type} lacks simulate method")
                    continue

            except Exception as e:
                print(f"Safe simulate error: {e}")
                # Return default values instead of None
                return ber_values, bler_values


class ChannelSimulator:
    def __init__(self, channel_type='AWGN'):
        """
        Channel Simulator for communication system simulations
        """
        self.channel_type = channel_type
        print(f"Channel Simulator initialized for {channel_type}")

    def simulate(self, encoded_signal, snr_db):
        """
        Simulate signal transmission through a specified channel
        """
        try:
            # Convert bits {0,1} to BPSK: {+1, -1}
            bpsk_signal = 1 - 2 * np.array(encoded_signal)

            # Convert SNR from dB to linear scale
            snr_linear = 10 ** (snr_db / 10)

            # Compute signal power
            signal_power = np.mean(bpsk_signal**2)

            # Noise power calculation
            noise_power = signal_power / snr_linear
            noise_std = np.sqrt(noise_power / 2.0)

            # Channel-specific simulation
            if self.channel_type == 'AWGN':
                # Additive White Gaussian Noise
                noise = np.random.normal(0, noise_std, bpsk_signal.shape)
                received_signal = bpsk_signal + noise

            elif self.channel_type == 'Rayleigh':
                # Rayleigh Fading Channel
                # Generate Rayleigh fading coefficient
                fading = np.random.rayleigh(scale=1.0, size=bpsk_signal.shape)

                # Generate noise
                noise = np.random.normal(0, noise_std, bpsk_signal.shape)

                # Apply fading and noise
                received_signal = fading * bpsk_signal + noise

            else:
                raise ValueError(f"Unsupported channel type: {self.channel_type}")

            return received_signal

        except Exception as e:
            print(f"Error in channel simulation: {e}")
            raise

    def compute_channel_capacity(self, snr_db):
        """
        Compute channel capacity

        Args:
            snr_db (float): Signal-to-Noise Ratio in dB

        Returns:
            float: Channel capacity in bits per channel use
        """
        # Convert SNR to linear scale
        snr_linear = 10 ** (snr_db / 10)

        if self.channel_type == 'AWGN':
            # Shannon's channel capacity for AWGN
            capacity = np.log2(1 + snr_linear)

        elif self.channel_type == 'Rayleigh':
            # Rayleigh fading channel capacity approximation
            capacity = np.log2(1 + snr_linear / (1 + snr_linear))

        else:
            raise ValueError(f"Unsupported channel type: {self.channel_type}")

        return capacity

    def plot_channel_response(self, snr_range):
        """
        Plot channel response and capacity

        Args:
            snr_range (np.ndarray): Range of SNR values in dB

        Returns:
            tuple: SNR values and corresponding channel capacities
        """
        plt.figure(figsize=(10, 6))

        # Compute capacities
        capacities = [self.compute_channel_capacity(snr) for snr in snr_range]

        # Plot capacity
        plt.plot(
            snr_range,
            capacities,
            label=f'{self.channel_type} Channel',
            color='blue' if self.channel_type == 'AWGN' else 'green',
            marker='o'
        )

        plt.title(f'Channel Capacity - {self.channel_type} Channel')
        plt.xlabel('SNR (dB)')
        plt.ylabel('Channel Capacity (bits/channel use)')
        plt.grid(True, linestyle=':', alpha=0.7)
        plt.legend()

        # Save and close
        plt.tight_layout()
        plt.savefig(f'{self.channel_type.lower()}_channel_capacity.png')
        plt.close()

        return snr_range, capacities

def safe_channel_simulate(channel, encoded_signal, snr_db):
    """
    Wrapper function for safe channel simulation.

    This function checks if the channel object has a simulate method before attempting
    to call it. This prevents errors if the channel object does not support simulation.

    Args:
        channel: The channel object.
        encoded_signal: The encoded signal to simulate.
        snr_db: The signal-to-noise ratio in dB.

    Returns:
        The received signal after simulation, or None if the channel object does
        not support simulation.
    """
    try:
        # Check if the channel object has a simulate method
        if not hasattr(channel, 'simulate'):
            print(f"Channel {channel.channel_type} lacks simulate method")  # Handle error
            return None  # Return None to indicate failure

        # Call the simulate method
        return channel.simulate(encoded_signal, snr_db)

    except Exception as e:
        print(f"Safe simulate error: {e}")
        return None  # Return None to indicate failure
############################


    def compute_channel_capacity(self, snr_db):
        """
        Compute channel capacity using Shannon's theorem

        Args:
            snr_db (float): Signal-to-Noise Ratio in decibels

        Returns:
            float: Channel capacity in bits per channel use
        """
        # Convert SNR to linear scale
        snr_linear = 10 ** (snr_db / 10)

        # Shannon's channel capacity formula
        capacity = np.log2(1 + snr_linear)

        return capacity

    def estimate_bit_error_rate(self, encoded_signal, received_signal):
        """
        Estimate Bit Error Rate (BER)

        Args:
            encoded_signal (np.ndarray): Original encoded signal
            received_signal (np.ndarray): Signal after channel transmission

        Returns:
            float: Estimated Bit Error Rate
        """
        # Convert received signal back to binary
        decoded_signal = (received_signal > 0).astype(int)

        # Compute bit errors
        bit_errors = np.sum(np.abs(encoded_signal - decoded_signal))
        total_bits = len(encoded_signal)

        # Compute Bit Error Rate
        ber = bit_errors / total_bits

        return ber

    @staticmethod
    def generate_noise(shape, std_dev):
        """
        Generate Gaussian noise

        Args:
            shape (tuple): Shape of noise array
            std_dev (float): Standard deviation of noise

        Returns:
            np.ndarray: Generated noise
        """
        return np.random.normal(0, std_dev, shape)

    def plot_channel_response(self, snr_range):
        """
        Plot channel capacity across SNR range

        Args:
            snr_range (np.ndarray): Range of SNR values in dB

        Returns:
            tuple: SNR values and corresponding channel capacities
        """
        import matplotlib.pyplot as plt

        # Compute channel capacities
        capacities = [self.compute_channel_capacity(snr) for snr in snr_range]

        # Plot
        plt.figure(figsize=(10, 6))
        plt.plot(snr_range, capacities, label=f'{self.channel_type} Channel')
        plt.title(f'Channel Capacity - {self.channel_type} Channel')
        plt.xlabel('SNR (dB)')
        plt.ylabel('Channel Capacity (bits/channel use)')
        plt.grid(True)
        plt.legend()
        plt.show()

        return snr_range, capacities

# Example usage
if __name__ == "__main__":
    # AWGN Channel
    awgn_channel = ChannelSimulator(channel_type='AWGN')

    # Rayleigh Channel
    rayleigh_channel = ChannelSimulator(channel_type='Rayleigh')

    # SNR Range
    snr_range = np.linspace(0, 20, 100)

    # Plot channel responses
    awgn_channel.plot_channel_response(snr_range)
    rayleigh_channel.plot_channel_response(snr_range)
# Polar Code Generator Class
class PolarCodeGenerator:
    def __init__(self, N, K):
        """
        Polar Code Generator

        Args:
            N (int): Block length
            K (int): Number of information bits
        """
        self.N = N
        self.K = K
        self.channel_type = 'AWGN'  # Default channel type


    def generate_info_bits(self):
        """
        Generate random information bits
        """
        return np.random.randint(2, size=self.K)

    def polar_encode(self, info_bits):
        """
        Simplistic approach: place info_bits at first K positions
        """
        codeword = np.zeros(self.N, dtype=int)
        codeword[:len(info_bits)] = info_bits
        return codeword

# Channel Simulator Class



# Latest with AWGN and Rayleih
def simulate_channel(encoded_signal, snr_db, channel_type='AWGN'):
    """
    Simulate transmission through a specified channel.

    Args:
        encoded_signal (np.ndarray): The encoded signal.
        snr_db (float): Signal-to-noise ratio in dB.
        channel_type (str): Type of channel ('AWGN' or 'Rayleigh').

    Returns:
        np.ndarray: The received signal after channel effects.
    """
    # Convert bits {0,1} to BPSK: {+1, -1}
    bpsk_signal = 1 - 2 * encoded_signal

    # SNR to linear scale
    snr_linear = 10 ** (snr_db / 10)
    signal_power = np.mean(bpsk_signal**2)
    noise_power = signal_power / snr_linear
    noise_std = math.sqrt(noise_power / 2.0)

    if channel_type == 'AWGN':
        noise = np.random.normal(0, noise_std, bpsk_signal.shape)
        received_signal = bpsk_signal + noise
    elif channel_type == 'Rayleigh':
        fading = np.random.rayleigh(scale=1.0, size=bpsk_signal.shape)
        noise = np.random.normal(0, noise_std, bpsk_signal.shape)
        received_signal = fading * bpsk_signal + noise
    else:
        raise ValueError("Unsupported channel type")

    return received_signal
class AWGNChannel:
    def __init__(self):
        self.channel_type = 'AWGN'

    def simulate(self, encoded_signal, snr_db):
        """Simulate AWGN channel."""
        bpsk_signal = 1 - 2 * encoded_signal
        snr_linear = 10 ** (snr_db / 10)
        signal_power = np.mean(bpsk_signal**2)
        noise_power = signal_power / snr_linear
        noise_std = np.sqrt(noise_power / 2.0)
        noise = np.random.normal(0, noise_std, bpsk_signal.shape)
        received_signal = bpsk_signal + noise
        return received_signal

class RayleighChannel:
    def __init__(self):
        self.channel_type = 'Rayleigh'

    def simulate(self, encoded_signal, snr_db):
        """Simulate Rayleigh channel."""
        bpsk_signal = 1 - 2 * encoded_signal
        snr_linear = 10 ** (snr_db / 10)
        signal_power = np.mean(bpsk_signal**2)
        noise_power = signal_power / snr_linear
        noise_std = np.sqrt(noise_power / 2.0)
        fading = np.random.rayleigh(scale=1.0, size=bpsk_signal.shape)
        noise = np.random.normal(0, noise_std, bpsk_signal.shape)
        received_signal = fading * bpsk_signal + noise
        return received_signal
class RNNDecoder(nn.Module):
    def __init__(self, input_size=SimulationConfig.BLOCK_LENGTH, hidden_size=64, num_layers=2):
        super(RNNDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Modify input layer to handle 2D input
        self.input_layer = nn.Linear(input_size, hidden_size)

        self.rnn = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)

        # Output layer
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Ensure input is 2D (batch, features)
        if x.dim() == 1:
            x = x.unsqueeze(0)  # Add batch dimension if missing

        # Input layer to transform input
        x = self.input_layer(x)

        # Prepare hidden state
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # RNN processing
        out, _ = self.rnn(x.unsqueeze(1), h0)

        # Take the last time step
        out = out[:, -1, :]

        # Final classification
        out = self.fc(out)

        return out.squeeze(-1)  # Ensure 1D output





class DecoderTrainer:
    def __init__(self, model, learning_rate=1e-3):
        self.model = model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Use BCELoss with logits
        self.criterion = nn.BCELoss()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5
        )

    def train(self, X_train, y_train, epochs=200, batch_size=64, validation_split=0.2):
        # Ensure correct tensor shapes
        X_train = X_train.to(self.device)
        y_train = y_train.to(self.device).squeeze()  # Ensure 1D tensor

        # Split into train and validation
        train_size = int((1 - validation_split) * len(X_train))
        X_val, y_val = X_train[train_size:], y_train[train_size:]
        X_train, y_train = X_train[:train_size], y_train[:train_size]

        # Initialize loss tracking
        self.train_losses, self.val_losses = [], []

        for epoch in range(epochs):
            # Training phase
            self.model.train()
            train_loss = self._train_epoch(X_train, y_train, batch_size)
            self.train_losses.append(train_loss)

            # Validation phase
            self.model.eval()
            val_loss = self._validate(X_val, y_val)
            self.val_losses.append(val_loss)

            # Learning rate scheduling
            self.scheduler.step(val_loss)

        return self.train_losses, self.val_losses

    def _train_epoch(self, X_train, y_train, batch_size):
        total_loss = 0

        # Create data loader
        dataset = torch.utils.data.TensorDataset(X_train, y_train)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True
        )

        for batch_X, batch_y in loader:
            # Move to device
            batch_X = batch_X.to(self.device)
            batch_y = batch_y.to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(batch_X)

            # Ensure compatible shapes
            outputs = outputs.view(-1)
            batch_y = batch_y.view(-1)

            # Compute loss
            loss = self.criterion(outputs, batch_y)

            # Backward pass
            loss.backward()

            # Optimize
            self.optimizer.step()

            # Accumulate loss
            total_loss += loss.item()

        return total_loss / len(loader)

    def _validate(self, X_val, y_val):
        with torch.no_grad():
            # Forward pass
            val_outputs = self.model(X_val)

            # Ensure compatible shapes
            val_outputs = val_outputs.view(-1)
            y_val = y_val.view(-1)

            # Compute validation loss
            val_loss = self.criterion(val_outputs, y_val)

        return val_loss.item()

    def predict(self, X):
        # Ensure input is a tensor
        if not isinstance(X, torch.Tensor):
            X = torch.FloatTensor(X).to(self.device)
        else:
            X = X.to(self.device)

        # Set model to evaluation mode
        self.model.eval()

        # Predict
        with torch.no_grad():
            outputs = self.model(X)

        return outputs.cpu().numpy()

def predict(self, X):
    """
    Ensure prediction always returns a valid result
    """
    try:
        # Ensure input is a tensor
        if not isinstance(X, torch.Tensor):
            X = torch.FloatTensor(X).to(self.device)
        else:
            X = X.to(self.device)

        # Set model to evaluation mode
        self.model.eval()

        # Predict with no gradient computation
        with torch.no_grad():
            outputs = self.model(X)

        # Ensure outputs are not None and have expected shape
        if outputs is None:
            print("Warning: Model prediction returned None")
            return np.zeros(X.shape[0])

        # Convert to numpy and ensure 1D array
        return outputs.cpu().numpy().flatten()

    except Exception as e:
        print(f"Prediction error: {e}")
        return np.zeros(X.shape[0])

 #Part 3: RNN/ML Decoders and Trainers






# Example usage:
#rnn_model = RNNDecoder()
#rnn_trainer = DecoderTrainer(rnn_model)

# Create some sample data
#X_train = torch.randn(100, SimulationConfig.BLOCK_LENGTH)
#y_train = torch.randint(0, 2, (100, 1)).float()

# Train the model
#rnn_train_losses, rnn_val_losses = rnn_trainer.train(X_train, y_train, epochs=10, batch_size=32)

# Example usage


class MLDecoder(nn.Module):
    def __init__(self, input_size=SimulationConfig.BLOCK_LENGTH): # Reference BLOCK_LENGTH from SimulationConfig
        """
        Multi-Layer Perceptron Decoder
        """
        super(MLDecoder, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),

            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        Forward pass through ML decoder
        """
        return self.model(x)






class DecoderTrainer:
    def __init__(self, model, learning_rate=1e-3, patience=5, factor=0.5):
        """
        Initialize Decoder Trainer

        Args:
            model (nn.Module): Neural network model to train
            learning_rate (float): Initial learning rate
            patience (int): Patience for learning rate scheduling
            factor (float): Reduction factor for learning rate
        """
        # Model and device configuration
        self.model = model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Loss function
        self.criterion = nn.BCELoss()

        # Optimizer
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=1e-5
        )

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=factor,
            patience=patience,
            verbose=True
        )

        # Training tracking
        self.train_losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')
        self.epochs_no_improve = 0

    def _prepare_data(self, X, y, validation_split=0.2):
        """
        Prepare training and validation data

        Args:
            X (torch.Tensor): Input features
            y (torch.Tensor): Target labels
            validation_split (float): Proportion of data for validation

        Returns:
            tuple: Prepared train and validation datasets
        """
        # Ensure tensors are on the correct device
        X = X.to(self.device)
        y = y.to(self.device)

        # Ensure y is 1D
        y = y.squeeze()

        # Split data
        train_size = int((1 - validation_split) * len(X))
        X_train, X_val = X[:train_size], X[train_size:]
        y_train, y_val = y[:train_size], y[train_size:]

        return X_train, X_val, y_train, y_val

    def _create_dataloader(self, X, y, batch_size, shuffle=True):
        """
        Create a DataLoader for batch processing

        Args:
            X (torch.Tensor): Input features
            y (torch.Tensor): Target labels
            batch_size (int): Number of samples per batch
            shuffle (bool): Whether to shuffle data

        Returns:
            torch.utils.data.DataLoader: Data loader
        """
        dataset = TensorDataset(X, y)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle
        )

    def _train_epoch(self, dataloader):
        """
        Train for one epoch

        Args:
            dataloader (torch.utils.data.DataLoader): Training data loader

        Returns:
            float: Average training loss
        """
        self.model.train()
        total_loss = 0

        for batch_X, batch_y in dataloader:
            # Move to device (redundant, but safe)
            batch_X = batch_X.to(self.device)
            batch_y = batch_y.to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(batch_X)

            # Ensure compatible shapes
            outputs = outputs.view(-1)
            batch_y = batch_y.view(-1)

            # Compute loss
            loss = self.criterion(outputs, batch_y)

            # Backward pass
            loss.backward()

            # Optimize
            self.optimizer.step()

            # Accumulate loss
            total_loss += loss.item()

        return total_loss / len(dataloader)

    def _validate(self, dataloader):
        """
        Validate model performance

        Args:
            dataloader (torch.utils.data.DataLoader): Validation data loader

        Returns:
            float: Validation loss
        """
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch_X, batch_y in dataloader:
                # Move to device
                batch_X = batch_X.to(self.device)
                batch_y = batch_y.to(self.device)

                # Forward pass
                outputs = self.model(batch_X)

                # Ensure compatible shapes
                outputs = outputs.view(-1)
                batch_y = batch_y.view(-1)

                # Compute loss
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()

        return total_loss / len(dataloader)

    def train(self, X, y, epochs=200, batch_size=64, validation_split=0.2):
        """
        Main training method

        Args:
            X (torch.Tensor): Input features
            y (torch.Tensor): Target labels
            epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            validation_split (float): Proportion of data for validation

        Returns:
            tuple: Training and validation losses
        """
        # Prepare data
        X_train, X_val, y_train, y_val = self._prepare_data(
            X, y, validation_split
        )

        # Create data loaders
        train_loader = self._create_dataloader(
            X_train, y_train, batch_size
        )
        val_loader = self._create_dataloader(
            X_val, y_val, batch_size, shuffle=False
        )

        # Training loop
        for epoch in range(epochs):
            # Train for one epoch
            train_loss = self._train_epoch(train_loader)
            self.train_losses.append(train_loss)

            # Validate
            val_loss = self._validate(val_loader)
            self.val_losses.append(val_loss)

            # Learning rate scheduling
            self.scheduler.step(val_loss)

            # Early stopping and model checkpointing
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.epochs_no_improve = 0
                # Optional: Save best model
                # torch.save(self.model.state_dict(), 'best_model.pth')
            else:
                self.epochs_no_improve += 1

            # Optional early stopping
            if self.epochs_no_improve >= 10:
                print("Early stopping triggered")
                break

        return self.train_losses, self.val_losses

    def predict(self, X):
        """
        Make predictions

        Args:
            X (torch.Tensor or np.ndarray): Input features

        Returns:
            np.ndarray: Predicted probabilities
        """
        # Ensure input is a tensor
        if not isinstance(X, torch.Tensor):
            X = torch.FloatTensor(X).to(self.device)
        else:
            X = X.to(self.device)

        # Set model to evaluation mode
        self.model.eval()

        # Predict
        with torch.no_grad():
            outputs = self.model(X)

        return outputs.cpu().numpy()

    def get_learning_curves(self):
        """
        Retrieve training and validation learning curves

        Returns:
            tuple: Training and validation losses
        """
        return self.train_losses, self.val_losses


# Part 4: Dataset Preparation and Utility Functions

def prepare_polar_dataset(polar_code_gen, num_samples): # Reference num_samples from SimulationConfig
    """
    Prepare dataset for Polar Code simulation

    Args:
        polar_code_gen (PolarCodeGenerator): Polar code generator
        num_samples (int): Number of samples to generate

    Returns:
        tuple: Features and labels
    """
    X = []
    y = []

    for _ in range(num_samples):
        # Generate info bits
        info_bits = polar_code_gen.generate_info_bits()

        # Encode
        codeword = polar_code_gen.polar_encode(info_bits)

        # Extract features
        X.append(codeword)
        y.append(1 if np.mean(codeword) > 0.5 else 0)

    return np.array(X), np.array(y)









def plot_ber_bler_performance(
    rnn_trainer,
    ml_trainer,
    polar_code_gen,
    channels,
    list_sizes=[1, 4, 8]
):
    """
    Enhanced BER and BLER Performance Plotting
    """
    plt.figure(figsize=(20, 16))

    # Channel-specific SNR ranges
    snr_ranges = {
        'AWGN': np.arange(0, 5.5, 0.5),    # 0-5 dB, 0.5 dB steps
        'Rayleigh': np.linspace(0, 10, 10)  # Broader range for Rayleigh
    }

    # Color and style configurations
    decoder_colors = {
        'RNN': {
            'AWGN': ['navy', 'blue', 'royalblue'],
            'Rayleigh': ['darkred', 'crimson', 'indianred']
        },
        'ML': {
            'AWGN': ['darkgreen', 'green', 'limegreen'],
            'Rayleigh': ['darkorchid', 'purple', 'mediumpurple']
        }
    }
    markers = ['o', 's', '^']

    channel_types = ['AWGN', 'Rayleigh']
    metric_types = ['BER', 'BLER']

    for channel_idx, channel_type in enumerate(channel_types):
        for metric_idx, metric_type in enumerate(metric_types):
            plt.subplot(2, 2, channel_idx * 2 + metric_idx + 1)

            # Use channel-specific SNR range
            snr_range = snr_ranges[channel_type]

            for decoder_idx, (decoder_name, decoder) in enumerate([('RNN', rnn_trainer), ('ML', ml_trainer)]):
                for list_size_idx, list_size in enumerate(list_sizes):
                    # Compute performance
                    ber, bler = compute_channel_performance(
                        decoder,
                        channels[channel_type],
                        snr_range,
                        polar_code_gen,
                        list_size=list_size,
                        num_trials=1000
                    )

                    # Plot the appropriate metric
                    performance = ber if metric_type == 'BER' else bler

                    plt.semilogy(
                        snr_range,
                        performance,
                        label=f'{decoder_name} {metric_type} (List Size {list_size})',
                        color=decoder_colors[decoder_name][channel_type][list_size_idx],
                        marker=markers[list_size_idx],
                        linestyle='-' if decoder_name == 'RNN' else '--',
                        linewidth=2,
                        markersize=8
                    )

            plt.title(f'{metric_type} - {channel_type} Channel', fontsize=14)
            plt.xlabel('SNR (dB)', fontsize=12)
            plt.ylabel(metric_type, fontsize=12)
            plt.ylim(1e-5, 1e0)
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, which='both', ls=':', alpha=0.7)

    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.savefig('comprehensive_ber_bler_performance.png', bbox_inches='tight', dpi=300)
    plt.close()

def plot_ber_bler_performance(
    rnn_trainer,
    ml_trainer,
    polar_code_gen,
    channels,
    list_sizes=[1, 4, 8]
):
    """
    Enhanced BER and BLER Performance Plotting
    """
    plt.figure(figsize=(20, 16))

    # Channel-specific SNR ranges
    snr_ranges = {
        'AWGN': np.arange(0, 5.5, 0.5),    # 0-5 dB, 0.5 dB steps
        'Rayleigh': np.linspace(0, 10, 10)  # Broader range for Rayleigh
    }

    # Color and style configurations
    decoder_colors = {
        'RNN': {
            'AWGN': ['navy', 'blue', 'royalblue'],
            'Rayleigh': ['darkred', 'crimson', 'indianred']
        },
        'ML': {
            'AWGN': ['darkgreen', 'green', 'limegreen'],
            'Rayleigh': ['darkorchid', 'purple', 'mediumpurple']
        }
    }
    markers = ['o', 's', '^']

    channel_types = ['AWGN', 'Rayleigh']
    metric_types = ['BER', 'BLER']

    for channel_idx, channel_type in enumerate(channel_types):
        for metric_idx, metric_type in enumerate(metric_types):
            plt.subplot(2, 2, channel_idx * 2 + metric_idx + 1)

            # Use channel-specific SNR range
            snr_range = snr_ranges[channel_type]

            for decoder_idx, (decoder_name, decoder) in enumerate([('RNN', rnn_trainer), ('ML', ml_trainer)]):
                for list_size_idx, list_size in enumerate(list_sizes):
                    # Compute performance
                    performance_result = compute_channel_performance(
                        decoder,
                        channels[channel_type],
                        snr_range,
                        polar_code_gen,
                        list_size=list_size,
                        num_trials=1000
                    )

                     # Check if the function returned valid values
                if performance is not None:
                    ber, bler = performance  # Unpack if valid
                else:
                    # Handle the case where performance is None
                    ber = None
                    bler = None
                    # Unpack the performance result
                    if isinstance(performance_result, tuple) and len(performance_result) == 2:
                        ber, bler = performance_result
                    else:
                        print(f"Unexpected performance result type: {type(performance_result)}")
                        continue

                    # Select the appropriate metric
                    performance = ber if metric_type == 'BER' else bler

                    plt.semilogy(
                        snr_range,
                        performance,
                        label=f'{decoder_name} {metric_type} (List Size {list_size})',
                        color=decoder_colors[decoder_name][channel_type][list_size_idx],
                        marker=markers[list_size_idx],
                        linestyle='-' if decoder_name == 'RNN' else '--',
                        linewidth=2,
                        markersize=8
                    )

            plt.title(f'{metric_type} - {channel_type} Channel', fontsize=14)
            plt.xlabel('SNR (dB)', fontsize=12)
            plt.ylabel(metric_type, fontsize=12)
            plt.ylim(1e-5, 1e0)
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.grid(True, which='both', ls=':', alpha=0.7)

    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.savefig('comprehensive_ber_bler_performance.png', bbox_inches='tight', dpi=300)
    plt.close()









# Debugging function to verify performance computation
def verify_channel_performance(
    rnn_trainer,
    ml_trainer,
    polar_code_gen,
    snr_range,
    channels
):
    """
    Verify channel performance computation
    """
    print("\nChannel Performance Verification:")

    for decoder_name, decoder in [('RNN', rnn_trainer), ('ML', ml_trainer)]:
        print(f"\n{decoder_name} Decoder Performance:")

        for channel_name, channel in channels.items():
            print(f"  {channel_name} Channel:")

            ber, bler = compute_channel_performance(
                decoder,
                channel,
                snr_range,
                polar_code_gen,
                num_trials=500  # Reduced for quick verification
            )

            print("    BER:")
            print("      Min:", np.min(ber))
            print("      Max:", np.max(ber))
            print("      Mean:", np.mean(ber))

            print("    BLER:")
            print("      Min:", np.min(bler))
            print("      Max:", np.max(bler))









performance_results = {
    'RNN': {
        'ber_awgn': {1: [], 8: [], 16: []},
        'bler_awgn': {1: [], 8: [], 16: []},
        'ber_rayleigh': {1: [], 8: [], 16: []},
        'bler_rayleigh': {1: [], 8: [], 16: []}
    },
    'ML': {
        'ber_awgn': {1: [], 8: [], 16: []},
        'bler_awgn': {1: [], 8: [], 16: []},
        'ber_rayleigh': {1: [], 8: [], 16: []},
        'bler_rayleigh': {1: [], 8: [], 16: []}
    }
}






def plot_training_performance(rnn_trainer, ml_trainer):
    """
    Plot Training and Validation Losses for RNN and ML Decoders

    Args:
        rnn_trainer (DecoderTrainer): RNN decoder trainer
        ml_trainer (DecoderTrainer): ML decoder trainer
    """
    plt.figure(figsize=(15, 6))

    # Color and style configurations
    colors = {
        'rnn': {
            'train': 'blue',
            'val': 'lightblue'
        },
        'ml': {
            'train': 'green',
            'val': 'lightgreen'
        }
    }

    # RNN Losses
    plt.subplot(1, 2, 1)

    # Ensure we have losses
    rnn_train_losses = rnn_trainer.train_losses
    rnn_val_losses = rnn_trainer.val_losses

    if len(rnn_train_losses) > 0:
        plt.plot(
            range(len(rnn_train_losses)),
            rnn_train_losses,
            label='RNN Training Loss',
            color=colors['rnn']['train'],
            linewidth=2
        )

    if len(rnn_val_losses) > 0:
        plt.plot(
            range(len(rnn_val_losses)),
            rnn_val_losses,
            label='RNN Validation Loss',
            color=colors['rnn']['val'],
            linestyle='--',
            linewidth=2
        )

    plt.title('RNN Decoder Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, linestyle=':', alpha=0.7)

    # ML Losses
    plt.subplot(1, 2, 2)

    # Ensure we have losses
    ml_train_losses = ml_trainer.train_losses
    ml_val_losses = ml_trainer.val_losses

    if len(ml_train_losses) > 0:
        plt.plot(
            range(len(ml_train_losses)),
            ml_train_losses,
            label='ML Training Loss',
            color=colors['ml']['train'],
            linewidth=2
        )

    if len(ml_val_losses) > 0:
        plt.plot(
            range(len(ml_val_losses)),
            ml_val_losses,
            label='ML Validation Loss',
            color=colors['ml']['val'],
            linestyle='--',
            linewidth=2
        )

    plt.title('ML Decoder Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, linestyle=':', alpha=0.7)

    plt.tight_layout()
    plt.savefig('training_losses.png', dpi=300)
    plt.close()



def plot_training_performance(rnn_trainer, ml_trainer):
    """
    Plot Training and Validation Losses for RNN and ML Decoders

    Args:
        rnn_trainer (DecoderTrainer): RNN decoder trainer
        ml_trainer (DecoderTrainer): ML decoder trainer
    """
    plt.figure(figsize=(15, 6))

    # RNN Losses
    plt.subplot(1, 2, 1)

    # Ensure we have losses for RNN
    rnn_train_losses = rnn_trainer.train_losses
    rnn_val_losses = rnn_trainer.val_losses

    if len(rnn_train_losses) > 0:
        plt.plot(
            range(len(rnn_train_losses)),
            rnn_train_losses,
            label='RNN Training Loss',
            color='blue'
        )

    if len(rnn_val_losses) > 0:
        plt.plot(
            range(len(rnn_val_losses)),
            rnn_val_losses,
            label='RNN Validation Loss',
            color='blue',
            linestyle='--'
        )

    plt.title('RNN Decoder Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # ML Losses
    plt.subplot(1, 2, 2)

    # Ensure we have losses for ML
    ml_train_losses = ml_trainer.train_losses
    ml_val_losses = ml_trainer.val_losses

    if len(ml_train_losses) > 0:
        plt.plot(
            range(len(ml_train_losses)),
            ml_train_losses,
            label='ML Training Loss',
            color='green'
        )

    if len(ml_val_losses) > 0:
        plt.plot(
            range(len(ml_val_losses)),
            ml_val_losses,
            label='ML Validation Loss',
            color='green',
            linestyle='--'
        )

    plt.title('ML Decoder Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_losses.png')
    plt.close()

# Part 5: Main Function and Plotting

def plot_training_performance(rnn_train_losses, ml_train_losses, rnn_val_losses, ml_val_losses):  # Add validation losses as arguments
    """
    Plot Training and Validation Losses for RNN and ML Decoders
    """


def plot_confusion_matrices(X_test, y_test, rnn_trainer, ml_trainer):
    """
    Plot Confusion Matrices for RNN and ML Decoders
    """
    plt.figure(figsize=(15, 6))

    # RNN Decoder Confusion Matrix
    plt.subplot(1, 2, 1)
    rnn_predictions = rnn_trainer.predict(X_test.cpu().numpy())
    rnn_pred_classes = (rnn_predictions > 0.5).astype(int).flatten()
    rnn_cm = confusion_matrix(y_test.cpu().numpy(), rnn_pred_classes)

    sns.heatmap(
        rnn_cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        square=True
    )
    plt.title('RNN Decoder Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    # ML Decoder Confusion Matrix
    plt.subplot(1, 2, 2)
    ml_predictions = ml_trainer.predict(X_test.cpu().numpy())
    ml_pred_classes = (ml_predictions > 0.5).astype(int).flatten()
    ml_cm = confusion_matrix(y_test.cpu().numpy(), ml_pred_classes)

    sns.heatmap(
        ml_cm,
        annot=True,
        fmt='d',
        cmap='Greens',
        square=True
    )
    plt.title('ML Decoder Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    plt.tight_layout()
    plt.savefig('confusion_matrices.png')
    plt.close()

 #Add to your plotting functions
def plot_channel_capacities(channels, snr_range):
    """
    Plot Channel Capacities for AWGN and Rayleigh Channels on a single plot

    Args:
        channels (dict): Dictionary of channel simulators
        snr_range (np.ndarray): Range of SNR values in dB
    """
    plt.figure(figsize=(12, 7))

    # Color and line style configuration
    channel_styles = {
        'AWGN': {
            'color': 'blue',
            'linestyle': '-',
            'marker': 'o',
            'label': 'AWGN Channel'
        },
        'Rayleigh': {
            'color': 'green',
            'linestyle': '--',
            'marker': 's',
            'label': 'Rayleigh Channel'
        }
    }

    # Compute and plot capacities for each channel
    for channel_name, channel in channels.items():
        # Compute channel capacities
        capacities = [channel.compute_channel_capacity(snr) for snr in snr_range]

        # Plot with specified style
        plt.plot(
            snr_range,
            capacities,
            color=channel_styles[channel_name]['color'],
            linestyle=channel_styles[channel_name]['linestyle'],
            marker=channel_styles[channel_name]['marker'],
            label=channel_styles[channel_name]['label'],
            linewidth=2
        )

    # Formatting
    plt.title('Channel Capacity Comparison', fontsize=16)
    plt.xlabel('SNR (dB)', fontsize=12)
    plt.ylabel('Channel Capacity (bits/channel use)', fontsize=12)
    plt.grid(True, linestyle=':', alpha=0.7)
    plt.legend(loc='best')

    # Annotate key points
    plt.annotate('AWGN: Theoretical Limit',
                 xy=(max(snr_range), max(capacities)),
                 xytext=(5, 5),
                 textcoords='offset points')

    plt.tight_layout()
    plt.savefig('combined_channel_capacities.png', dpi=300)
    plt.close()

# Enhance ChannelSimulator with more sophisticated capacity computation
class ChannelSimulator:
    def __init__(self, channel_type='AWGN'):
        """
        Initialize Channel Simulator

        Args:
            channel_type (str): Type of channel ('AWGN' or 'Rayleigh')
        """
        self.channel_type = channel_type

    def compute_channel_capacity(self, snr_db):
        """
        Compute channel capacity with advanced modeling

        Args:
            snr_db (float): Signal-to-Noise Ratio in decibels

        Returns:
            float: Channel capacity in bits per channel use
        """
        # Convert SNR to linear scale
        snr_linear = 10 ** (snr_db / 10)

        if self.channel_type == 'AWGN':
            # Shannon's channel capacity for AWGN
            # C = log2(1 + SNR)
            capacity = np.log2(1 + snr_linear)

        elif self.channel_type == 'Rayleigh':
            # Rayleigh fading channel capacity approximation
            # More complex model considering fading effects
            # C = log2(1 + γ / (1 + γ))
            capacity = np.log2(1 + snr_linear / (1 + snr_linear))

        else:
            raise ValueError(f"Unsupported channel type: {self.channel_type}")

        return capacity
def plot_channel_capacities(channels, snr_range):
    """
    Plot Channel Capacities for AWGN and Rayleigh Channels

    Args:
        channels (dict): Dictionary of channel simulators
        snr_range (np.ndarray): Range of SNR values in dB
    """
    plt.figure(figsize=(12, 6))

    # Plot for each channel type
    for channel_name, channel in channels.items():
        # Compute channel capacities
        capacities = [channel.compute_channel_capacity(snr) for snr in snr_range]

        # Plot with distinct style
        plt.plot(
            snr_range,
            capacities,
            label=f'{channel_name} Channel',
            marker='o'
        )

    # Formatting
    plt.title('Channel Capacity Comparison')
    plt.xlabel('SNR (dB)')
    plt.ylabel('Channel Capacity (bits/channel use)')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    # Tight layout and save
    plt.tight_layout()
    plt.savefig('channel_capacities.png')
    plt.close()


# Update the plotting function to accept these parameters
def plot_training_performance(
    rnn_train_losses,
    rnn_val_losses,
    ml_train_losses,
    ml_val_losses
):
    plt.figure(figsize=(15, 6))

    # RNN Losses
    plt.subplot(1, 2, 1)

    if len(rnn_train_losses) > 0:
        plt.plot(
            range(len(rnn_train_losses)),
            rnn_train_losses,
            label='RNN Training Loss',
            color='blue'
        )

    if len(rnn_val_losses) > 0:
        plt.plot(
            range(len(rnn_val_losses)),
            rnn_val_losses,
            label='RNN Validation Loss',
            color='blue',
            linestyle='--'
        )

    plt.title('RNN Decoder Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # ML Losses
    plt.subplot(1, 2, 2)

    if len(ml_train_losses) > 0:
        plt.plot(
            range(len(ml_train_losses)),
            ml_train_losses,
            label='ML Training Loss',
            color='green'
        )

    if len(ml_val_losses) > 0:
        plt.plot(
            range(len(ml_val_losses)),
            ml_val_losses,
            label='ML Validation Loss',
            color='green',
            linestyle='--'
        )

    plt.title('ML Decoder Losses')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('training_losses.png')
    plt.close()



def plot_confusion_matrices(X_test, y_test, rnn_trainer, ml_trainer):
    """
    Plot Confusion Matrices for RNN and ML Decoders

    Args:
        X_test (torch.Tensor): Test input features
        y_test (torch.Tensor): Test labels
        rnn_trainer (DecoderTrainer): RNN decoder trainer
        ml_trainer (DecoderTrainer): ML decoder trainer
    """
    plt.figure(figsize=(15, 6))

    # RNN Decoder Confusion Matrix
    plt.subplot(1, 2, 1)
    rnn_predictions = rnn_trainer.predict(X_test.cpu().numpy())
    rnn_pred_classes = (rnn_predictions > 0.5).astype(int).flatten()
    rnn_cm = confusion_matrix(y_test.cpu().numpy(), rnn_pred_classes)

    sns.heatmap(
        rnn_cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        square=True
    )
    plt.title('RNN Decoder Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    # ML Decoder Confusion Matrix
    plt.subplot(1, 2, 2)
    ml_predictions = ml_trainer.predict(X_test.cpu().numpy())
    ml_pred_classes = (ml_predictions > 0.5).astype(int).flatten()
    ml_cm = confusion_matrix(y_test.cpu().numpy(), ml_pred_classes)

    sns.heatmap(
        ml_cm,
        annot=True,
        fmt='d',
        cmap='Greens',
        square=True
    )
    plt.title('ML Decoder Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    plt.tight_layout()
    plt.savefig('confusion_matrices.png')
    plt.close()






#exec(code)

#def plot_ber_bler_performance(snr_range, rnn_trainer, ml_trainer, polar_code_gen):
#def plot_ber_bler_performance(rnn_trainer, ml_trainer, polar_code_gen):


    """
    Plot BER and BLER Performance for RNN and ML Decoders
    """
    # Channel Simulators
    channels = {
        'AWGN': ChannelSimulator(channel_type='AWGN'),
        'Rayleigh': ChannelSimulator(channel_type='Rayleigh')
    }



# Global variables to store trainers
#rnn_trainer = None
#ml_trainer = None

def run_simulation():
    rnn_trainer, ml_trainer = main()
    plot_channel_performance(
        SNR_RANGE,
        rnn_trainer,
        ml_trainer,
        polar_code_gen
    )

def main():
    try:
        # Simulation Parameters
        BLOCK_LENGTH = 32  # Reduced from 128
        INFO_BITS = 16     # Reduced from 64
        LEARNING_RATE = 1e-3
        EPOCHS = 50        # Reduced from 200
        BATCH_SIZE = 32    # Reduced from 64
        NUM_SAMPLES = 5000 # Reduced from 15000

        # Channel-specific SNR Ranges
        AWGN_SNR_RANGE = np.arange(0, 5.5, 0.5)    # 0-5 dB, 0.5 dB steps
        RAYLEIGH_SNR_RANGE = np.linspace(0, 10, 10)  # Broader range for Rayleigh

        LIST_SIZES = [1, 4, 8]  # Adjusted list sizes
        NUM_TRIALS = 1000        # Reduced from 2000
       # plot_ber_bler_performance(rnn_trainer, ml_trainer, polar_code_gen, channels)

        # Device Configuration
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Using Device: {device}")

        # 1. Polar Code Generator
        polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS)
        print("✅ Polar Code Generator Initialized")

        # 2. Prepare Dataset
        X, y = prepare_polar_dataset(polar_code_gen, num_samples=NUM_SAMPLES)
        print(f"Dataset Prepared: X shape {X.shape}, y shape {y.shape}")

        # 3. Split Dataset
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # Convert to torch tensors
        X_train = torch.FloatTensor(X_train).to(device)
        X_test = torch.FloatTensor(X_test).to(device)
        y_train = torch.FloatTensor(y_train).to(device).view(-1, 1)
        y_test = torch.FloatTensor(y_test).to(device).view(-1, 1)

        # 4. RNN Decoder Training
        rnn_model = RNNDecoder(input_size=BLOCK_LENGTH).to(device)
        rnn_trainer = DecoderTrainer(rnn_model, learning_rate=LEARNING_RATE)

        ml_model = MLDecoder()
        ml_trainer = DecoderTrainer(ml_model)

        rnn_train_losses, rnn_val_losses = rnn_trainer.train(
            X_train, y_train,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE
        )
        print("✅ RNN Decoder Training Completed")

        # 5. ML Decoder Training
        ml_model = MLDecoder(input_size=BLOCK_LENGTH).to(device)
        ml_trainer = DecoderTrainer(ml_model, learning_rate=LEARNING_RATE)

        ml_train_losses, ml_val_losses = ml_trainer.train(
            X_train, y_train,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE
        )
        print("✅ ML Decoder Training Completed")

        # Plot Training Losses
        plot_training_performance(
            rnn_train_losses,
            rnn_val_losses,
            ml_train_losses,
            ml_val_losses
        )

        # Plot Confusion Matrices
        plot_confusion_matrices(
            X_test,
            y_test,
            rnn_trainer,
            ml_trainer
        )

        # 6. Channel Simulators
        channels = {
            'AWGN': ChannelSimulator(channel_type='AWGN'),
            'Rayleigh': ChannelSimulator(channel_type='Rayleigh')
        }

        # 7a. Plot Channel Capacities
        plot_channel_capacities(channels, RAYLEIGH_SNR_RANGE)

        # 7b. Performance Results Dictionary
        performance_results = {
            'RNN': {
                'ber_awgn': {},
                'bler_awgn': {},
                'ber_rayleigh': {},
                'bler_rayleigh': {}
            },
            'ML': {
                'ber_awgn': {},
                'bler_awgn': {},
                'ber_rayleigh': {},
                'bler_rayleigh': {}
            }
        }

        # 8. Compute Performance for Decoders, Channels, and List Sizes
        def compute_performance():
            for decoder_name, decoder in [('RNN', rnn_trainer), ('ML', ml_trainer)]:
                for channel_name, channel in channels.items():
                    for list_size in LIST_SIZES:
                        # Use channel-specific SNR range
                        snr_range = AWGN_SNR_RANGE if channel_name == 'AWGN' else RAYLEIGH_SNR_RANGE

                        ber, bler = compute_channel_performance(
                            decoder,
                            channel,
                            snr_range,
                            polar_code_gen,
                            list_size=list_size,
                            num_trials=NUM_TRIALS
                        )

                        # Store results
                        performance_results[decoder_name][f'ber_{channel_name.lower()}'][list_size] = ber
                        performance_results[decoder_name][f'bler_{channel_name.lower()}'][list_size] = bler

            return performance_results

        # Compute performance
        performance_results = compute_performance()

        # Verify channel performance
        verify_channel_performance(
            rnn_trainer,
            ml_trainer,
            polar_code_gen,
            RAYLEIGH_SNR_RANGE,
            channels
        )

        # 9. Performance Plotting
        plot_ber_bler_performance(rnn_trainer, ml_trainer, polar_code_gen, channels, LIST_SIZES)


        print("🎉 Simulation Complete!")
        return rnn_trainer, ml_trainer, performance_results

    except Exception as e:
        print(f"🆘 Comprehensive Simulation Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None
    except TypeError as e:
        if "cannot unpack non-iterable NoneType object" in str(e):
            print("Error: compute_channel_performance returned None. Check its implementation.")
            traceback.print_exc()  # Print the full traceback for debugging
        else:
            raise e  # Re-raise other TypeErrors
# Optional: Execution
if __name__ == "__main__":
    rnn_trainer, ml_trainer, performance_results = main()





Python version: 3.11.12 (main, Apr  9 2025, 08:55:54) [GCC 11.4.0]
PyTorch version: 2.6.0+cu124
Channel Simulator initialized for AWGN
Channel Simulator initialized for Rayleigh
🚀 Using Device: cpu
✅ Polar Code Generator Initialized
Dataset Prepared: X shape (5000, 32), y shape (5000,)




Early stopping triggered
✅ RNN Decoder Training Completed
✅ ML Decoder Training Completed




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Channel AWGN lacks simulate method
Error: Chan

Traceback (most recent call last):
  File "<ipython-input-51-4a3680f09b73>", line 2018, in main
    performance_results = compute_performance()
                          ^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-51-4a3680f09b73>", line 2002, in compute_performance
    ber, bler = compute_channel_performance(
    ^^^^^^^^^
TypeError: cannot unpack non-iterable NoneType object
