<a href="https://colab.research.google.com/github/kumuds4/BCH/blob/master/Making_the_Most_of_your_Colab_Subscription.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Making the Most of your Colab Subscription



## Faster GPUs

Users who have purchased one of Colab's paid plans have access to faster GPUs and more memory. You can upgrade your notebook's GPU settings in `Runtime > Change runtime type` in the menu to select from several accelerator options, subject to availability.

The free-of-charge version of Colab grants access to NVIDIA T4 GPUs, subject to quota restrictions and availability.

You can see which GPU you've been assigned at any time by running the following cell. If it prints "Not connected to a GPU", go to `Runtime > Change runtime type` in the menu to enable a GPU accelerator, and then re-run the cell.


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)
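The `!nvidia-smi` line above is IPython shell syntax, so it only runs inside a notebook. A minimal plain-Python sketch of the same check, using only the standard library, is below. `describe_gpu` is a hypothetical helper name, and the sketch assumes that `nvidia-smi` is on the `PATH` whenever an NVIDIA GPU is attached; `nvidia-smi -L` prints one line per visible GPU.

```python
import shutil
import subprocess


def describe_gpu():
    """Return the GPUs listed by `nvidia-smi -L`, or a fallback message.

    Hypothetical helper: assumes `nvidia-smi` ships with the NVIDIA driver
    whenever a GPU is attached to the runtime.
    """
    if shutil.which("nvidia-smi") is None:
        return "Not connected to a GPU"
    result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True)
    if result.returncode != 0:
        return "Not connected to a GPU"
    return result.stdout.strip()


print(describe_gpu())
```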

To use a GPU with your notebook, select `Runtime > Change runtime type` in the menu and set the hardware accelerator to the desired option.

## More memory

Users who have purchased one of Colab's paid plans have access to high-memory VMs when they are available. More powerful GPUs are always offered with high-memory VMs.



You can see how much memory your runtime has at any time by running the following cell. If it prints "Not using a high-RAM runtime", you can enable a high-RAM runtime via `Runtime > Change runtime type` in the menu: select High-RAM under the Runtime shape toggle, then re-run the cell.


In [None]:
import psutil

ram_gb = psutil.virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')
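`psutil.virtual_memory().total` reports the installed RAM, not what is currently free (psutil exposes the latter as `psutil.virtual_memory().available`). As a sketch for Linux runtimes such as Colab's, both figures can also be read from `/proc/meminfo`; the helper names are hypothetical, and the parser assumes the kernel's usual `Key:   value kB` layout with values in KiB.

```python
import os


def parse_meminfo_gb(text):
    """Return (total_gb, available_gb) parsed from /proc/meminfo-style text.

    Assumes lines such as 'MemTotal:  13290452 kB' (values in KiB); 1 GB is
    taken as 1e9 bytes, matching the cell above.
    """
    values_kib = {}
    for line in text.splitlines():
        key, _, rest = line.partition(":")
        fields = rest.split()
        if fields and fields[0].isdigit():
            values_kib[key.strip()] = int(fields[0])
    kib_to_gb = 1024 / 1e9
    return values_kib["MemTotal"] * kib_to_gb, values_kib["MemAvailable"] * kib_to_gb


def ram_summary_gb(path="/proc/meminfo"):
    """Read total and currently available RAM on a Linux runtime."""
    with open(path) as f:
        return parse_meminfo_gb(f.read())


if os.path.exists("/proc/meminfo"):
    total_gb, available_gb = ram_summary_gb()
    print(f"Total RAM: {total_gb:.1f} GB, currently available: {available_gb:.1f} GB")
```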

## Longer runtimes

All Colab runtimes are reset after some period of time, and sooner if the runtime isn't executing code. Colab Pro and Pro+ users have access to longer runtimes than those who use Colab free of charge.

## Background execution

Colab Pro+ users have access to background execution, where notebooks will continue executing even after you've closed a browser tab. This is always enabled in Pro+ runtimes as long as you have compute units available.



## Relaxing resource limits in Colab Pro

Your resources are not unlimited in Colab. To make the most of Colab, avoid using resources when you don't need them. For example, only use a GPU when required and close Colab tabs when finished.



If you run into these limits, you can relax them by purchasing more compute units via [Pay As You Go](https://colab.research.google.com/signup). Anyone can purchase compute units this way; no subscription is required.

## Send us feedback!

If you have any feedback, please let us know. The best way to send it is through the `Help > Send feedback...` menu. If you encounter usage limits in Colab Pro, consider subscribing to Pro+.

If you encounter errors or other issues with billing (payments) for Colab Pro, Pro+, or Pay As You Go, please email [colab-billing@google.com](mailto:colab-billing@google.com).

## More Resources

### Working with Notebooks in Colab
- [Overview of Colab](/notebooks/basic_features_overview.ipynb)
- [Guide to Markdown](/notebooks/markdown_guide.ipynb)
- [Importing libraries and installing dependencies](/notebooks/snippets/importing_libraries.ipynb)
- [Saving and loading notebooks in GitHub](https://colab.research.google.com/github/googlecolab/colabtools/blob/main/notebooks/colab-github-demo.ipynb)
- [Interactive forms](/notebooks/forms.ipynb)
- [Interactive widgets](/notebooks/widgets.ipynb)

<a name="working-with-data"></a>
### Working with Data
- [Loading data: Drive, Sheets, and Google Cloud Storage](/notebooks/io.ipynb)
- [Charts: visualizing data](/notebooks/charts.ipynb)
- [Getting started with BigQuery](/notebooks/bigquery.ipynb)

### Machine Learning Crash Course
These are a few of the notebooks from Google's online Machine Learning course. See the [full course website](https://developers.google.com/machine-learning/crash-course/) for more.
- [Intro to Pandas DataFrame](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/pandas_dataframe_ultraquick_tutorial.ipynb)
- [Linear regression with tf.keras using synthetic data](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/linear_regression_with_synthetic_data.ipynb)


<a name="using-accelerated-hardware"></a>
### Using Accelerated Hardware
- [TensorFlow with GPUs](/notebooks/gpu.ipynb)
- [TPUs in Colab](/notebooks/tpu.ipynb)

<a name="machine-learning-examples"></a>

## Machine Learning Examples

To see end-to-end examples of the interactive machine learning analyses that Colab makes possible, check out these tutorials using models from [TensorFlow Hub](https://tfhub.dev).

A few featured examples:

- [Retraining an Image Classifier](https://tensorflow.org/hub/tutorials/tf2_image_retraining): Build a Keras model on top of a pre-trained image classifier to distinguish flowers.
- [Text Classification](https://tensorflow.org/hub/tutorials/tf2_text_classification): Classify IMDB movie reviews as either *positive* or *negative*.
- [Style Transfer](https://tensorflow.org/hub/tutorials/tf2_arbitrary_image_stylization): Use deep learning to transfer style between images.
- [Multilingual Universal Sentence Encoder Q&A](https://tensorflow.org/hub/tutorials/retrieval_with_tf_hub_universal_encoder_qa): Use a machine learning model to answer questions from the SQuAD dataset.
- [Video Interpolation](https://tensorflow.org/hub/tutorials/tweening_conv3d): Predict what happened in a video between the first and the last frame.


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from google.colab import files
uploaded = files.upload()

Saving 610ltstpolarML.py to 610ltstpolarML.py


In [None]:
# Polar transform, CRC, RNN decoder, and SC/SCL decoders
# Latest version from Gemini (evening/night of 06/08/25, 11:40 PM; revised early morning 06/09/25)
# TODO: the plots need to be fixed; they are currently flat.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import logging
import pandas as pd
import traceback # Import the traceback module

# Configure logging
logging.basicConfig(level=logging.INFO)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration parameters
BLOCK_LENGTH = 128
INFO_BITS = 64
LEARNING_RATE = 1e-3
EPOCHS = 50
BATCH_SIZE = 32
NUM_SAMPLES_TRAIN = 50000
NUM_TRIALS_PERF = 1500
SNR_RANGE_AWGN = np.linspace(0, 5, 11)
LIST_SIZES = [1, 8, 16]

# Part 2: Polar code generator and modulation
####################################################

# Centralized CRC calculation function (3GPP style; revised evening of 06/10/25)
def compute_crc(data, polynomial):
    """
    Computes CRC for data using the given polynomial (3GPP style).
    Args:
        data: Numpy array of binary data bits (0s and 1s).
        polynomial: Numpy array of binary polynomial coefficients (e.g., [1, 0, 0, 0, 1, 0, 0, 1] for x^7 + x^3 + 1).
    Returns:
        Numpy array of CRC bits.
    """
    poly_len = len(polynomial)
    crc_length = poly_len - 1 # CRC length is polynomial degree
    # Append crc_length zeros
    data_with_zeros = np.concatenate((data, np.zeros(crc_length, dtype=int)))

    remainder = np.copy(data_with_zeros)

    # Perform polynomial division from MSB
    for i in range(len(remainder) - poly_len + 1):
        if remainder[i] == 1:
            remainder[i : i + poly_len] ^= polynomial

    return remainder[-crc_length:]
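# Sketch (hypothetical helpers, not used elsewhere in this notebook): a
# receiver checks a CRC by dividing the whole received word (data bits
# followed by CRC bits) by the same generator polynomial; the word passes
# when the remainder is all zeros. This assumes the same MSB-first bit order
# and zero-initialised division as compute_crc above; the decoder's own
# _check_crc may be implemented differently.
import numpy as np


def crc_remainder_sketch(bits, polynomial):
    """GF(2) remainder of bits(x) divided by polynomial(x), MSB first."""
    poly = np.asarray(polynomial, dtype=int)
    rem = np.asarray(bits, dtype=int).copy()
    for i in range(len(rem) - len(poly) + 1):
        if rem[i] == 1:
            rem[i:i + len(poly)] ^= poly
    return rem[-(len(poly) - 1):]


def crc_is_valid_sketch(bits_with_crc, polynomial):
    """True when data followed by its CRC leaves a zero remainder."""
    return not np.any(crc_remainder_sketch(bits_with_crc, polynomial))

# Usage: crc_is_valid_sketch(np.concatenate((data, compute_crc(data, poly))), poly) is True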


class PolarCodeGenerator:
    def __init__(self, N, K, channel_snr_db, crc_type='CRC-7'):
        self.N = N
        self.K = K
        self.R = K / N  # Code rate
        self.channel_snr_db = channel_snr_db  # Stored but currently unused: the info set comes from a fixed reliability sequence
        self.crc_type = crc_type
        self.crc_polynomials = {'CRC-7': (np.array([1, 0, 0, 0, 1, 0, 0, 1], dtype=int), 7)}

        if crc_type in self.crc_polynomials:
            self._crc_polynomial = self.crc_polynomials[crc_type][0]
            self._crc_length = self.crc_polynomials[crc_type][1]
        else:
            self._crc_polynomial = None
            self._crc_length = 0

        self.K_crc = self.K + self._crc_length

        # Determine frozen/info sets from the precomputed N=128 reliability sequence
        self.frozen_set, self.info_set = self._get_frozen_and_info_sets_mi()

        logging.info(f"PolarCodeGenerator: N={self.N}, K={self.K}, R={self.R:.4f}, CRC Length={self._crc_length}, K_crc={self.K_crc}")
        logging.info(f"Info Set (first 5, last 5): {self.info_set[:5]} ... {self.info_set[-5:]}")
        logging.info(f"Frozen Set (first 5, last 5): {self.frozen_set[:5]} ... {self.frozen_set[-5:]}")

    def _get_frozen_and_info_sets_mi(self):
        # Channel polarization: for N = 2 with x = (u_0 XOR u_1, u_1), the bit
        # u_0 (decoded first, u_1 unknown) sees a channel worse than W, while
        # u_1 (decoded given u_0) sees a better one. Repeating this split yields
        # synthetic channels whose reliability can be ranked by the Bhattacharyya
        # parameter Z(W) (lower is more reliable) or the mutual information I(W)
        # (higher is more reliable). For a BEC with erasure probability p,
        # Z(W) = p and I(W) = 1 - p.
        #
        # Exact Z(W) or I(W) for AWGN is involved to compute, so for N = 128 this
        # method uses a precomputed reliability sequence; the method name reflects
        # the underlying principle. Other N would need a proper recursive
        # reliability calculation and are not supported.
        if self.N != 128:
            raise NotImplementedError(f"Mutual information-based frozen set calculation is not implemented for N={self.N}. Only N=128 is supported with the 3GPP sequence.")

        reliability_sequence = self._get_3gpp_reliability_sequence_128()

        # The sequence runs from least to most reliable, so the last K_crc
        # entries carry the information (and CRC) bits.
        if len(reliability_sequence) >= self.K_crc:
            info_channel_indices = sorted(reliability_sequence[-self.K_crc:])
        else:
            logging.error(f"Reliability sequence (length {len(reliability_sequence)}) is shorter than K_crc ({self.K_crc}). Cannot determine info set correctly.")
            info_channel_indices = sorted(reliability_sequence)

        # All remaining channels are frozen (fixed to 0)
        frozen_channel_indices = sorted(set(range(self.N)) - set(info_channel_indices))

        if len(info_channel_indices) != self.K_crc:
            logging.warning(f"Mismatch: Expected {self.K_crc} info indices, but got {len(info_channel_indices)}")

        if len(frozen_channel_indices) != self.N - self.K_crc:
            logging.warning(f"Mismatch: Expected {self.N - self.K_crc} frozen indices, but got {len(frozen_channel_indices)}")

        return frozen_channel_indices, info_channel_indices


    def _get_3gpp_reliability_sequence_128(self):
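        # Reliability sequence for N=128, AWGN, intended to follow 3GPP TS 38.212
        # (Table 5.3.1.2-1), listing bit indices from least to most reliable.
        # NOTE: this list has only 124 entries (indices 55, 64, 69 and 71 are
        # missing), so it cannot be the complete standard sequence, which is a
        # permutation of all 128 indices. The missing indices always end up in
        # the frozen set. Verify the list against the specification before use.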
        return [
            0, 1, 2, 4, 8, 16, 3, 5, 9, 6, 17, 10, 18, 32, 12, 33,
            20, 24, 34, 36, 40, 7, 11, 19, 21, 13, 22, 25, 26, 28,
            48, 35, 37, 38, 41, 42, 44, 56, 14, 15, 23, 27, 29, 30,
            31, 39, 43, 45, 46, 49, 50, 52, 57, 58, 60, 63, 47, 51,
            53, 54, 59, 61, 62, 65, 66, 67, 68, 70, 72, 73, 74, 75,
            76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 77,
            79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 102, 103,
            104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115,
            116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127
        ]


    def generate_info_bits(self):
        return np.random.randint(2, size=self.K)

    def polar_encode(self, info_bits, verbose=False):
        info_bits_with_crc = self.append_crc(info_bits)

        if verbose:
             logging.info(f"polar_encode: info_bits_with_crc (len {len(info_bits_with_crc)}): {info_bits_with_crc[:10]}...")

        if len(info_bits_with_crc) != len(self.info_set):
            raise ValueError(f"Length of info_bits_with_crc ({len(info_bits_with_crc)}) must match length of info_set ({len(self.info_set)})")

        u = np.zeros(self.N, dtype=int)
        u[np.array(list(self.info_set))] = info_bits_with_crc

        encoded = self._polar_transform(u)

        return encoded

    def _polar_transform(self, u):
        N = len(u)
        if N == 1:
            return u
        else:
            half_N = N // 2
            u_np = np.array(u)
            x_upper = self._polar_transform(u_np[:half_N])
            x_lower = self._polar_transform(u_np[half_N:])
            codeword = np.concatenate([(x_upper + x_lower) % 2, x_lower])
            return codeword

    def append_crc(self, info_bits):
        if self.crc_type not in self.crc_polynomials:
            return info_bits
        polynomial, length = self.crc_polynomials[self.crc_type]
        data_for_crc = np.copy(info_bits)
        crc_bits = compute_crc(data_for_crc, polynomial)
        if len(crc_bits) != length:
             logging.warning(f"Calculated CRC length ({len(crc_bits)}) does not match expected length ({length})")
        return np.concatenate((info_bits, crc_bits))
############################################
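# Sketch (hypothetical helper names): _polar_transform computes x = u F^(x n)
# over GF(2), with F = [[1, 0], [1, 1]] and no bit-reversal permutation.
# Building the N x N generator matrix with np.kron makes this explicit and
# gives an independent check of the recursive implementation for small N.
import numpy as np


def polar_generator_matrix_sketch(N):
    """Return the n-fold Kronecker power of F for N = 2**n."""
    if N < 1 or N & (N - 1):
        raise ValueError("N must be a power of two")
    F = np.array([[1, 0], [1, 1]], dtype=int)
    G = np.array([[1]], dtype=int)
    while G.shape[0] < N:
        G = np.kron(G, F)
    return G


def polar_encode_kron_sketch(u):
    """Encode u (length N) as (u @ G) mod 2."""
    u = np.asarray(u, dtype=int)
    return (u @ polar_generator_matrix_sketch(len(u))) % 2

# Usage: for a PolarCodeGenerator g and u of length g.N,
#        polar_encode_kron_sketch(u) should equal g._polar_transform(u).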



# ========== CRC Polynomial (CRC-8) ==========
#crc_poly = np.array([1, 0, 0, 0, 1, 1, 0, 1, 1], dtype=int)  # x^8 + x^4 + x^3 + x + 1
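# Sketch (hypothetical helper names; an alternative construction, not the
# 3GPP sequence): _get_frozen_and_info_sets_mi only supports N = 128. A common
# general-N substitute is the Bhattacharyya recursion for a binary erasure
# channel, where each split maps Z to 2Z - Z^2 (worse channel) and Z^2
# (better channel). It is exact for the BEC and only an approximation for
# AWGN; the design value z0 = 0.5 is an assumption. Indices follow this
# notebook's natural-order transform: the most significant index bit picks
# the first split (0 -> worse, 1 -> better).
import numpy as np


def bhattacharyya_z_sketch(N, z0=0.5):
    """Bhattacharyya parameters of the N synthetic channels (lower is better)."""
    if N < 1 or N & (N - 1):
        raise ValueError("N must be a power of two")
    z = np.array([z0], dtype=float)
    while len(z) < N:
        nxt = np.empty(2 * len(z))
        nxt[0::2] = 2 * z - z ** 2  # degraded channel of each split
        nxt[1::2] = z ** 2          # upgraded channel of each split
        z = nxt
    return z


def info_set_from_z_sketch(N, K_crc, z0=0.5):
    """Sorted indices of the K_crc most reliable (smallest Z) channels."""
    z = bhattacharyya_z_sketch(N, z0)
    order = np.argsort(-z, kind="stable")  # least reliable first
    return sorted(order[-K_crc:].tolist())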

# ========== Functions ==========
#######################################






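def add_awgn(signal, snr_db):
    # Unit-energy BPSK symbols are assumed, so the noise variance per real
    # dimension is N0/2 = 1 / (2 * SNR).
    snr_linear = 10 ** (snr_db / 10)
    noise_var = 1 / (2 * snr_linear)
    noise = np.sqrt(noise_var) * np.random.randn(*np.shape(signal))
    return signal + noise, noise_var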
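# Sketch (hypothetical helper names): the SC and SCL decoders below take
# channel LLRs, not raw samples. With the mapping bit 0 -> +1, bit 1 -> -1
# used by bpsk_modulate and real AWGN of variance sigma^2 = noise_var,
#   LLR(y) = ln p(y | x=+1) / p(y | x=-1) = 2 y / sigma^2,
# so a positive LLR favours bit 0, matching the decoders' "0 if llr >= 0"
# rule. add_awgn's snr_db acts as Es/N0 for these unit-energy symbols; if
# SNR_RANGE_AWGN is meant as Eb/N0, it would first be shifted by 10*log10(R).
import numpy as np


def bpsk_awgn_llr_sketch(received, noise_var):
    """Channel LLRs for BPSK over real AWGN."""
    return 2.0 * np.asarray(received, dtype=float) / noise_var


def ebn0_to_esn0_db_sketch(ebn0_db, code_rate):
    """Es/N0 (dB) = Eb/N0 (dB) + 10 log10(R), one coded bit per BPSK symbol."""
    return ebn0_db + 10 * np.log10(code_rate)

# Usage: y, var = add_awgn(bpsk_modulate(codeword), snr_db)
#        llrs = bpsk_awgn_llr_sketch(y, var)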






#Part 3: Dataset Preparation and Channel Simulation
############################################################
def bpsk_modulate(codeword):
  """
  Performs BPSK modulation on a binary codeword.

  Args:
    codeword: A numpy array of binary bits (0s and 1s).

  Returns:
    A numpy array of BPSK symbols (+1s and -1s).
  """
  return 1 - 2 * codeword

###########################################################
class EnhancedChannelSimulator:
    def __init__(self, channel_type='AWGN'):
        self.channel_type = channel_type

    def simulate(self, signal, snr_db):
        snr_linear = 10 ** (snr_db / 10)
        noise_std = np.sqrt(1 / (2 * snr_linear))
        noise = noise_std * np.random.randn(*signal.shape)
        return signal + noise

def prepare_polar_dataset(polar_code_gen, num_samples, snr_db=5, channel_type='AWGN'):
    channel_simulator = EnhancedChannelSimulator(channel_type=channel_type)
    X, y = [], []

    for _ in range(num_samples):
        info_bits = polar_code_gen.generate_info_bits()
        encoded_signal = polar_code_gen.polar_encode(info_bits)
        modulated_signal = bpsk_modulate(encoded_signal)
        received_signal = channel_simulator.simulate(modulated_signal, snr_db)
        X.append(received_signal)
        y.append(info_bits)

    return np.array(X), np.array(y)
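# Sketch (hypothetical helper): prepare_polar_dataset returns X with shape
# (num_samples, N) and y with shape (num_samples, K). DecoderTrainer.train
# accepts an optional validation set, and nn.BCELoss needs floating-point
# targets, so a shuffled hold-out split with float32 casts is one way to
# prepare the arrays before wrapping them with torch.from_numpy.
import numpy as np


def split_dataset_sketch(X, y, val_fraction=0.2, seed=0):
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))
    n_val = int(round(len(X) * val_fraction))
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)  # BCELoss targets must be float
    return X[train_idx], y[train_idx], X[val_idx], y[val_idx]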

def save_dataset_to_csv(X, y, filename='dataset.csv'):
    data = np.hstack((X, y))
    columns = [f'received_{i}' for i in range(X.shape[1])] + [f'bit_{j}' for j in range(y.shape[1])]
    df = pd.DataFrame(data, columns=columns)
    df.to_csv(filename, index=False)
    logging.info(f"Dataset saved to {filename}")
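# Sketch (hypothetical helper): reading a file written by save_dataset_to_csv
# back into arrays. It relies on the column prefixes used there ('received_'
# for channel outputs, 'bit_' for information bits); np.hstack stored the bit
# columns as floats, so they are cast back to integers here.
import numpy as np
import pandas as pd


def load_dataset_from_csv_sketch(filename='dataset.csv'):
    df = pd.read_csv(filename)
    X = df.filter(regex=r'^received_').to_numpy(dtype=np.float32)
    y = df.filter(regex=r'^bit_').to_numpy(dtype=np.int64)
    return X, y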

# Part 4: RNN decoder and trainer
class EnhancedRNNDecoder(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=128, num_layers=2):
        super(EnhancedRNNDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers  # Needed to size h0/c0 in forward()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x_reshaped = x.unsqueeze(1)  # (batch, N) -> (batch, 1, N): the whole block is a single time step

        # Ensure h0 and c0 are created with the correct device and shape
        # Use self.num_layers for multiple RNN layers
        batch_size = x.size(0)
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)

        out, _ = self.rnn(x_reshaped, (h0, c0))
        out = self.fc(out[:, -1, :])
        return self.sigmoid(out)

class DecoderTrainer:
    def __init__(self, model, learning_rate):
        self.model = model
        self.criterion = nn.BCELoss()
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    def train(self, X_train, y_train, X_val=None, y_val=None, epochs=50, batch_size=32):
        dataset = torch.utils.data.TensorDataset(X_train, y_train)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

        train_losses, val_losses = [], []

        for epoch in range(epochs):
            epoch_loss = 0
            self.model.train()

            for X_batch, y_batch in loader:
                X_batch = X_batch.view(-1, BLOCK_LENGTH)
                self.optimizer.zero_grad()
                outputs = self.model(X_batch)
                loss = self.criterion(outputs, y_batch)
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item()

            train_loss = epoch_loss / len(loader)
            train_losses.append(train_loss)
            logging.info(f"Epoch {epoch+1}/{epochs}, Training Loss: {train_loss:.4f}")

            if X_val is not None and y_val is not None:
                self.model.eval()
                with torch.no_grad():
                    val_output = self.model(X_val.view(-1, BLOCK_LENGTH))
                    val_loss = self.criterion(val_output, y_val).item()
                    val_losses.append(val_loss)
                    logging.info(f"Epoch {epoch+1}/{epochs}, Validation Loss: {val_loss:.4f}")

        return train_losses, (val_losses if X_val is not None else None)

    def evaluate(self, X_test, y_test):
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(X_test.view(-1, BLOCK_LENGTH))
            predicted = (outputs > 0.5).int()
            correct = (predicted == y_test).sum().item()
            total = y_test.numel()
            accuracy = correct / total
            # Calculate BER and BLER
            bit_errors = torch.sum(predicted != y_test).item()
            block_errors = torch.sum(torch.any(predicted != y_test, dim=1)).item()
            ber = bit_errors / total
            bler = block_errors / X_test.size(0)

        return ber, bler
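# Sketch (hypothetical helper): DecoderTrainer.evaluate thresholds the
# sigmoid outputs at 0.5 and computes
#   BER  = (bit errors) / (number of blocks * K)
#   BLER = (blocks with at least one bit error) / (number of blocks).
# The same definitions in NumPy, e.g. for scoring SC/SCL decoder outputs:
import numpy as np


def ber_bler_sketch(predicted_bits, true_bits):
    errors = np.asarray(predicted_bits) != np.asarray(true_bits)  # (num_blocks, K)
    ber = errors.sum() / errors.size
    bler = np.any(errors, axis=1).sum() / errors.shape[0]
    return float(ber), float(bler)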


# Parts 5 and 6: SC and SCL decoders
# (revised evening/night of 06/08/25; further SCL decoder changes on 06/10/25)
#########################################################
# SC decoder

class SCDecoder:
    """Successive-cancellation (SC) decoder matching PolarCodeGenerator._polar_transform,
    which encodes without bit reversal as x = [x_upper XOR x_lower, x_lower]."""

    def __init__(self, N, K, info_set, frozen_set):
        self.N = N
        self.K = K
        self.info_set = set(info_set)  # Convert to set for faster lookups
        self.frozen_set = set(frozen_set) # Convert to set for faster lookups
        self.u_hat = np.zeros(N, dtype=int) # Decoded bits
        self.llrs = None # Channel LLRs of the most recent decode call

    def decode(self, received_llrs):
        received_llrs = np.asarray(received_llrs, dtype=float)
        if len(received_llrs) != self.N:
            raise ValueError(f"Input LLR length ({len(received_llrs)}) does not match code length N ({self.N}).")

        self.llrs = received_llrs
        self.u_hat = np.zeros(self.N, dtype=int)
        self._recursive_decode(received_llrs, 0)

        # Bits at the info-set positions: K message bits followed by the CRC bits (if any)
        decoded_info_bits = self.u_hat[sorted(self.info_set)]

        return decoded_info_bits

    @staticmethod
    def _f(L1, L2):
        # Min-sum approximation of the LLR of x_upper XOR x_lower
        return np.sign(L1) * np.sign(L2) * np.minimum(np.abs(L1), np.abs(L2))

    @staticmethod
    def _g(L1, L2, a_hat):
        # LLR of the lower half given the re-encoded upper-half bits a_hat
        return L2 + (1 - 2 * a_hat) * L1

    def _recursive_decode(self, llrs, bit_index):
        """Decode u_hat[bit_index : bit_index + len(llrs)] from this sub-block's LLRs
        and return the re-encoded sub-block bits (partial sums) for the parent's g step."""
        block_size = len(llrs)
        if block_size == 1:
            # Decision step at the leaf nodes
            if bit_index in self.frozen_set:
                bit = 0  # Frozen bit is fixed to 0
            else:
                bit = 0 if llrs[0] >= 0 else 1  # Hard decision on the LLR
            self.u_hat[bit_index] = bit
            return np.array([bit], dtype=int)

        half_size = block_size // 2
        L1, L2 = llrs[:half_size], llrs[half_size:]

        # Decode the upper half using the f (check-node) update
        a_hat = self._recursive_decode(self._f(L1, L2), bit_index)

        # Decode the lower half using the g (variable-node) update. g needs the
        # re-encoded upper-half bits and the *original* upper-half LLRs L1,
        # not the raw u decisions or the f-updated LLRs.
        b_hat = self._recursive_decode(self._g(L1, L2, a_hat), bit_index + half_size)

        # Re-encode this sub-block for the parent's g step
        return np.concatenate(((a_hat + b_hat) % 2, b_hat))
#############################################
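# Sketch (hypothetical helpers): PolarCodeDecoder keeps the path with the
# smallest metric, i.e. metrics act as penalties. One widely used choice is
# the LLR-based metric: deciding bit u at a leaf with LLR L adds
#   ln(1 + exp(-(1 - 2u) L)),
# approximated by |L| if u disagrees with the sign of L, else 0. Under this
# metric a frozen bit (u = 0) still adds a penalty when its LLR is negative.
# After each split, only the list_size paths with the smallest metrics are kept.
import numpy as np


def path_metric_update_sketch(metric, llr, bit, exact=False):
    if exact:
        # ln(1 + exp(-(1 - 2u) L)) computed stably
        return metric + float(np.logaddexp(0.0, -(1 - 2 * bit) * llr))
    hard = 0 if llr >= 0 else 1
    return metric + (abs(llr) if bit != hard else 0.0)


def prune_paths_sketch(metrics, list_size):
    """Indices of the list_size paths with the smallest metrics."""
    order = np.argsort(np.asarray(metrics), kind="stable")
    return order[:list_size].tolist()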

# SCL decoder (successive-cancellation list decoding, optionally CRC-aided)
class PolarCodeDecoder:
    def __init__(self, N, K, list_size, crc_poly=None):
        self.N = N
        self.K = K
        self.list_size = list_size
        self._crc_polynomial = None
        self._crc_length = 0

        if crc_poly is not None and isinstance(crc_poly, tuple) and len(crc_poly) == 2:
             self._crc_polynomial = crc_poly[0]
             self._crc_length = crc_poly[1]
             self.K_crc = self.K + self._crc_length
        else:
             self.K_crc = self.K

        # Get frozen and info sets using the 3GPP sequence method
        self.frozen_set, self.info_set = self._get_frozen_and_info_sets_from_3gpp_sequence()

        # Debug prints for info and frozen sets in Decoder
        logging.info(f"PolarCodeDecoder: N={self.N}, K={self.K}, CRC Length={self._crc_length}, K_crc={self.K_crc}")
        logging.info(f"Decoder Info Set (first 5, last 5): {self.info_set[:5]} ... {self.info_set[-5:]}")
        logging.info(f"Decoder Frozen Set (first 5, last 5): {self.frozen_set[:5]} ... {self.frozen_set[-5:]}")


    def _get_3gpp_reliability_sequence_128(self):
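        # Reliability sequence for N=128, AWGN, intended to follow 3GPP TS 38.212
        # (Table 5.3.1.2-1), listing bit indices from least to most reliable.
        # NOTE: this list has only 124 entries (indices 55, 64, 69 and 71 are
        # missing), so it cannot be the complete standard sequence, which is a
        # permutation of all 128 indices. The missing indices always end up in
        # the frozen set. Verify the list against the specification before use.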
        return [
            0, 1, 2, 4, 8, 16, 3, 5, 9, 6, 17, 10, 18, 32, 12, 33,
            20, 24, 34, 36, 40, 7, 11, 19, 21, 13, 22, 25, 26, 28,
            48, 35, 37, 38, 41, 42, 44, 56, 14, 15, 23, 27, 29, 30,
            31, 39, 43, 45, 46, 49, 50, 52, 57, 58, 60, 63, 47, 51,
            53, 54, 59, 61, 62, 65, 66, 67, 68, 70, 72, 73, 74, 75,
            76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 77,
            79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 102, 103,
            104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115,
            116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127
        ]


    def _get_frozen_and_info_sets_from_3gpp_sequence(self):
        reliability_sequence = self._get_3gpp_reliability_sequence_128()

        if self.N != 128:
             raise ValueError(f"Code length N={self.N} is not supported by the hardcoded 3GPP sequence.")

        # The most reliable channels are used for information bits (including CRC)
        # The sequence is ordered from least reliable to most reliable,
        # so we take the last K_crc elements for the info set.
        if len(reliability_sequence) >= self.K_crc:
            info_channel_indices = sorted(reliability_sequence[-self.K_crc:])
        else:
             logging.error(f"Reliability sequence (length {len(reliability_sequence)}) is shorter than K_crc ({self.K_crc}). Cannot determine info set correctly.")
             info_channel_indices = sorted(reliability_sequence[-len(reliability_sequence):])

        # The remaining channels are frozen
        frozen_channel_indices = sorted(list(set(range(self.N)) - set(info_channel_indices)))

        if len(info_channel_indices) != self.K_crc:
            logging.warning(f"Mismatch: Expected {self.K_crc} info indices, but got {len(info_channel_indices)}")

        if len(frozen_channel_indices) != self.N - self.K_crc:
             logging.warning(f"Mismatch: Expected {self.N - self.K_crc} frozen indices, but got {len(frozen_channel_indices)}")

        return frozen_channel_indices, info_channel_indices


    def decode(self, received_llrs):
        # Initialize list of active path indices. Initially, only one path (index 0) is active.
        active_path_indices = [0]
        # Initialize lists to store LLRs, hard decisions, and path metrics for each path.
        # We start with one path, which is a copy of the input LLRs and an empty hard decision array.
        self.llrs = [np.copy(received_llrs)]
        self.hard_decisions = [np.zeros(self.N, dtype=int)]
        # Initialize path metrics. A common initialization is 0.0.
        self.path_metrics = [0.0]

        # Start the recursive decoding process.
        final_active_path_indices = self._recursive_decode(active_path_indices, 0, self.N)

        # Select the best path among the final active paths.
        # If CRC is used, we prioritize paths with valid CRC.
        if self._crc_polynomial is not None:
             valid_paths = []
             # logging.info(f"Checking CRC for {len(final_active_path_indices)} paths.") # Optional logging
             for path_idx in final_active_path_indices:
                 # Extract the bits corresponding to info set indices for CRC check
                 decoded_bits_at_info_indices = self.hard_decisions[path_idx][list(sorted(self.info_set))]

                 # Ensure the extracted bits have the expected length before checking CRC
                 if len(decoded_bits_at_info_indices) == self.K_crc and self._check_crc(decoded_bits_at_info_indices):
                     valid_paths.append(path_idx)
                     # logging.info(f"Path {path_idx} has a valid CRC.") # Optional logging
                 # else: # Optional logging for CRC failures
                     # logging.info(f"Path {path_idx} CRC check failed or length mismatch.")
                     # if len(decoded_bits_at_info_indices) != self.K_crc:
                     #    logging.warning(f"Path {path_idx}: Length mismatch for CRC check. Got {len(decoded_bits_at_info_indices)}, Expected {self.K_crc}")


             if valid_paths:
                 # If valid paths exist, choose the one with the minimum path metric.
                 logging.info(f"Found {len(valid_paths)} paths with valid CRC. Selecting the one with the best metric.")
                 best_path_index_in_valid = np.argmin([self.path_metrics[i] for i in valid_paths])
                 best_path_index = valid_paths[best_path_index_in_valid]
                 # logging.info(f"Selected valid path {best_path_index} with metric {self.path_metrics[best_path_index]:.4f}") # Optional logging
             else:
                 # If no valid CRC path, choose the path with the overall best metric among the final active paths.
                 # This indicates a decoding failure where no path satisfies the CRC,
                 # but we return the "most likely" result according to the path metric.
                 logging.warning("No valid CRC path found among the best paths. Choosing the path with the best metric.")
                 best_path_index_in_final = np.argmin([self.path_metrics[i] for i in final_active_path_indices])
                 best_path_index = final_active_path_indices[best_path_index_in_final]
                 # logging.warning(f"Selected path {best_path_index} with metric {self.path_metrics[best_path_index]:.4f}") # Optional logging

        else:
            # If no CRC is used, simply choose the path with the minimum path metric among the final active paths.
            logging.info(f"No CRC used. Selecting the path with the best metric among {len(final_active_path_indices)} final paths.")
            best_path_index_in_final = np.argmin([self.path_metrics[i] for i in final_active_path_indices])
            best_path_index = final_active_path_indices[best_path_index_in_final]
            # logging.info(f"Selected path {best_path_index} with best metric {self.path_metrics[best_path_index]:.4f}.") # Optional logging


        # Extract and return the decoded information bits (first K bits of the K_crc bits)
        decoded_info_bits_with_crc = self.hard_decisions[best_path_index][list(sorted(self.info_set))]

        # Return exactly K information bits (any CRC bits follow them in the sorted info set)
        if len(decoded_info_bits_with_crc) < self.K:
            logging.error(f"Decoded bits at info indices (length {len(decoded_info_bits_with_crc)}) are shorter than K ({self.K}). Returning all available bits.")
        return decoded_info_bits_with_crc[:self.K]


    @staticmethod
    def _polar_transform(u):
        """
        Re-encode a block of u-domain decisions into codeword bits (partial sums) with
        F^{(x)n}, F = [[1, 0], [1, 1]], in natural (non-bit-reversed) order. This assumes
        polar_encode uses the same convention, which matches the f/g split used below.
        """
        x = np.array(u, dtype=int) % 2
        step = 1
        while step < len(x):
            for i in range(0, len(x), 2 * step):
                x[i:i + step] ^= x[i + step:i + 2 * step]
            step *= 2
        return x

    def _recursive_decode(self, active_path_indices, bit_index, block_size):
        """
        Decode the u-bits [bit_index, bit_index + block_size) for every active path.
        On entry, self.llrs[p][bit_index:bit_index + block_size] holds the LLRs of this block
        for path p. Returns the indices of the surviving paths (at most self.list_size).
        """
        if not hasattr(self, "_parent"):
            self._parent = {}  # Maps a cloned path's index to the index of the path it was copied from

        # Base case: a single bit u_i
        if block_size == 1:
            candidate_paths = []
            for path_idx in active_path_indices:
                llr = self.llrs[path_idx][bit_index]

                if bit_index in self.frozen_set:
                    # Frozen bit: the decision is fixed to 0, so there is nothing to branch on.
                    # The path still pays log(1 + exp(-LLR)); this penalty is what lets frozen
                    # bits push wrong paths down the list.
                    self.hard_decisions[path_idx][bit_index] = 0
                    self.path_metrics[path_idx] += np.logaddexp(0.0, -llr)
                    candidate_paths.append(path_idx)
                else:
                    # Information (or CRC) bit: extend the path with both decisions.
                    # np.logaddexp(0, x) = log(1 + exp(x)) without overflow for large |LLR|.
                    path_metric_0 = self.path_metrics[path_idx] + np.logaddexp(0.0, -llr)
                    path_metric_1 = self.path_metrics[path_idx] + np.logaddexp(0.0, llr)

                    # Clone the path for decision 1; the original path keeps decision 0.
                    # New indices come from len(self.path_metrics), which grows with every clone.
                    new_path_idx = len(self.path_metrics)
                    self.llrs.append(np.copy(self.llrs[path_idx]))
                    self.hard_decisions.append(np.copy(self.hard_decisions[path_idx]))
                    self.path_metrics.append(path_metric_1)
                    self._parent[new_path_idx] = path_idx

                    self.hard_decisions[path_idx][bit_index] = 0
                    self.path_metrics[path_idx] = path_metric_0
                    self.hard_decisions[new_path_idx][bit_index] = 1

                    candidate_paths.extend([path_idx, new_path_idx])

            # Prune: keep the list_size paths with the smallest metric (the most likely paths)
            candidate_paths.sort(key=lambda i: self.path_metrics[i])
            return candidate_paths[:self.list_size]

        # Recursive step: split the block into a left half (u_1) and a right half (u_2)
        half_size = block_size // 2
        left = slice(bit_index, bit_index + half_size)
        right = slice(bit_index + half_size, bit_index + block_size)

        # The f step overwrites the left half in place, but the g step needs the original block
        # LLRs, so save them per path. Paths cloned during the recursion get index >= first_new_path.
        saved_llrs = {p: np.copy(self.llrs[p][bit_index:bit_index + block_size]) for p in active_path_indices}
        first_new_path = len(self.path_metrics)

        # f step: LLRs for u_1
        for path_idx in active_path_indices:
            block = saved_llrs[path_idx]
            self.llrs[path_idx][left] = self._f(block[:half_size], block[half_size:])

        active_paths_after_u1 = self._recursive_decode(active_path_indices, bit_index, half_size)

        # g step: LLRs for u_2, using each path's own u_1 decisions re-encoded into partial sums
        for path_idx in active_paths_after_u1:
            # A path cloned inside the recursion uses the block LLRs of the entry path it descends from
            ancestor = path_idx
            while ancestor >= first_new_path:
                ancestor = self._parent[ancestor]
            block = saved_llrs[ancestor]
            partial_sums = self._polar_transform(self.hard_decisions[path_idx][left])
            self.llrs[path_idx][right] = self._g(block[:half_size], block[half_size:], partial_sums)

        # Decode u_2; the base case already limits the list to list_size paths
        return self._recursive_decode(active_paths_after_u1, bit_index + half_size, half_size)


    def _f(self, L1, L2):
        # f (check-node) step: LLR of the XOR of two bits with LLRs L1 and L2.
        # Exact form (box-plus): 2 * atanh(tanh(L1 / 2) * tanh(L2 / 2))
        #   = sign(L1) * sign(L2) * min(|L1|, |L2|) + log(1 + exp(-|L1 + L2|)) - log(1 + exp(-|L1 - L2|)).
        # This uses the min-sum approximation, which drops the two correction terms; it is
        # numerically stable and widely used in SC/SCL decoders.
        return np.sign(L1) * np.sign(L2) * np.minimum(np.abs(L1), np.abs(L2))

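    def _g(self, L1, L2, u1):
        # g (variable-node) step: LLR of the second-half bits given the first half.
        # Formula: L2 + (1 - 2 * u1) * L1, where u1 holds the hard values of the first-half
        # codeword bits (the re-encoded partial sums, not the raw u-domain decisions).
        return L2 + (1 - 2 * u1) * L1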

    def _check_crc(self, bits):
         """
         Checks the CRC for the given bits (assumed to be info bits + CRC).
         Uses the centralized compute_crc function.
         """
         if self._crc_polynomial is None or self._crc_length == 0:
             return True # No CRC to check or CRC length is 0

         # Ensure the input bits have the expected length (K + crc_length)
         if len(bits) != self.K_crc:
             # logging.warning(f"_check_crc: Input bits length ({len(bits)}) does not match expected K_crc ({self.K_crc}). CRC check skipped.") # Optional logging
             return False # Cannot check CRC if the length is wrong

         # Extract data and received CRC bits
         data_bits = bits[:self.K]
         received_crc = bits[self.K:]

         # Compute CRC for the data bits
         computed_crc = compute_crc(data_bits, self._crc_polynomial)

         # Check if computed CRC matches the received CRC
         return np.array_equal(computed_crc, received_crc)
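
# --- Illustrative sketch (standalone; not called by PolarCodeDecoder) ---
# One bit step of an LLR-based SCL decoder on plain Python lists, showing the standard
# branch-and-prune rule: every path is extended with each allowed decision u, pays the
# penalty log(1 + exp(-(1 - 2u) * LLR)), and only the list_size extensions with the
# smallest metric survive. Frozen bits allow u = 0 only but still pay the penalty.
# The function name and the (metric, bits) path representation are hypothetical.
import numpy as np


def scl_bit_step_sketch(paths, llrs, list_size, frozen=False):
    """paths: list of (metric, bits) tuples; llrs: this bit's LLR for each path, same order."""
    decisions = (0,) if frozen else (0, 1)
    candidates = []
    for (metric, bits), llr in zip(paths, llrs):
        for u in decisions:
            # np.logaddexp(0, x) = log(1 + exp(x)), computed without overflow for large |x|
            penalty = np.logaddexp(0.0, -(1 - 2 * u) * llr)
            candidates.append((metric + penalty, bits + [u]))
    # Smaller metric = more likely path; keep the best list_size candidates
    candidates.sort(key=lambda c: c[0])
    return candidates[:list_size]


# Example: scl_bit_step_sketch([(0.0, [])], [2.0], list_size=2)
# keeps [0] with metric log(1 + e^-2) and [1] with metric log(1 + e^2).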

####################################################
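
# --- Illustrative sketch (standalone; not called by the decoders above) ---
# The f step combines two LLRs with the box-plus operation 2 * atanh(tanh(L1 / 2) * tanh(L2 / 2));
# PolarCodeDecoder._f keeps only its min-sum part. The g step needs the partial sums of the
# first half, i.e. the first-half u decisions re-encoded with F^{(x)n}, F = [[1, 0], [1, 1]], in
# natural (non-bit-reversed) order. That encoder convention is an assumption about polar_encode,
# which is not shown here. All function names below are hypothetical.
import numpy as np


def f_exact_sketch(L1, L2):
    """Exact box-plus in the stable form: min-sum term plus two correction terms."""
    L1 = np.asarray(L1, dtype=float)
    L2 = np.asarray(L2, dtype=float)
    min_sum = np.sign(L1) * np.sign(L2) * np.minimum(np.abs(L1), np.abs(L2))
    correction = np.log1p(np.exp(-np.abs(L1 + L2))) - np.log1p(np.exp(-np.abs(L1 - L2)))
    return min_sum + correction


def g_sketch(L1, L2, partial_sums):
    """g step: L2 + (1 - 2 * v1) * L1, where v1 are the re-encoded first-half bits."""
    v1 = np.asarray(partial_sums, dtype=int)
    return np.asarray(L2, dtype=float) + (1 - 2 * v1) * np.asarray(L1, dtype=float)


def polar_transform_sketch(u):
    """Compute x = u * F^{(x)n} mod 2 with in-place butterflies (len(u) must be a power of 2)."""
    x = np.array(u, dtype=int) % 2
    step = 1
    while step < len(x):
        for i in range(0, len(x), 2 * step):
            x[i:i + step] ^= x[i + step:i + 2 * step]
        step *= 2
    return x


# Example: polar_transform_sketch([0, 0, 0, 1]) gives [1, 1, 1, 1], the last row of F^{(x)2}.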

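# --- Illustrative sketch (standalone; compute_crc is defined elsewhere and may use other conventions) ---
# A plain CRC is the remainder of data(x) * x^r divided by the generator g(x) over GF(2), with no
# initial register value and no final XOR. For the polynomial defined in main(),
# [1, 0, 0, 0, 1, 0, 0, 1] = x^7 + x^3 + 1, r = 7. A receiver-side check like _check_crc
# recomputes this remainder from the data bits and compares it with the received CRC bits.
# Function names are hypothetical.
import numpy as np


def crc_remainder_sketch(data_bits, poly):
    poly = np.asarray(poly, dtype=int)
    r = len(poly) - 1
    # Multiply by x^r (append r zeros), then run bitwise long division
    work = np.concatenate([np.asarray(data_bits, dtype=int) % 2, np.zeros(r, dtype=int)])
    for i in range(len(data_bits)):
        if work[i]:
            work[i:i + r + 1] ^= poly
    return work[-r:]


def crc_check_sketch(bits_with_crc, poly):
    # Same split as _check_crc: data bits first, CRC bits last
    r = len(poly) - 1
    data, received = bits_with_crc[:-r], bits_with_crc[-r:]
    return np.array_equal(crc_remainder_sketch(data, poly), np.asarray(received, dtype=int))


# Example: crc_remainder_sketch([1], [1, 0, 0, 0, 1, 0, 0, 1]) gives [0, 0, 0, 1, 0, 0, 1] (x^3 + 1).
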
#Part 6: Plotting Functions

def plot_training_validation(train_losses, val_losses):
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label='Training Loss')
    if val_losses:
        plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()

def plot_ber_bler_comparison(snr_range, rnn_results, scl_results_all, list_sizes):
    plt.figure(figsize=(12, 6))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BER_RNN'], label='RNN')
    for size, scl_results in scl_results_all.items():
        plt.plot(snr_range, [result['BER'] for result in scl_results], label=f'SCL, List Size {size}')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('Bit Error Rate (BER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BLER_RNN'], label='RNN')
    for size, scl_results in scl_results_all.items():
        plt.plot(snr_range, [result['BLER'] for result in scl_results], label=f'SCL, List Size {size}')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('Block Error Rate (BLER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(y_true, y_pred, title='Confusion Matrix'):
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    plt.title(title)
    plt.show()

def plot_ber_bler_scl(snr_range, scl_results_all, list_sizes):
    plt.figure(figsize=(12, 6))

    # BER Plot for SCL
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)  # Set y-limits
    for size in list_sizes:
        if size not in scl_results_all or not scl_results_all[size]:
            continue
        plt.plot(snr_range, [result['BER'] for result in scl_results_all[size]], label=f'SCL, List Size {size}', marker='x', linestyle='--')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('SCL Bit Error Rate (BER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # BLER Plot for SCL
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)  # Set y-limits
    for size in list_sizes:
        if size not in scl_results_all or not scl_results_all[size]:
            continue
        plt.plot(snr_range, [result['BLER'] for result in scl_results_all[size]], label=f'SCL, List Size {size}', marker='x', linestyle='--')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('SCL Block Error Rate (BLER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()

# RNN decoder plots
def plot_ber_bler_rnn(snr_range, rnn_results):
    plt.figure(figsize=(12, 6))

    # BER Plot for RNN
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)  # Set y-limits
    plt.plot(snr_range, rnn_results['BER_RNN'], label='RNN', marker='o', linestyle='-')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('RNN Bit Error Rate (BER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # BLER Plot for RNN
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)  # Set y-limits
    plt.plot(snr_range, rnn_results['BLER_RNN'], label='RNN', marker='o', linestyle='-')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('RNN Block Error Rate (BLER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()
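
# --- Illustrative sketch ---
# plot_ber_bler_scl and plot_ber_bler_comparison expect SCL results shaped like the output of
# performance_comparison: {list_size: [{'BER': ber, 'BLER': bler}, ...]}, one dict per SNR point.
# The SC/SCL loop in main() instead collects separate {list_size: [ber, ...]} and
# {list_size: [bler, ...]} dicts. This hypothetical helper converts the latter into the former, e.g.
#   plot_ber_bler_scl(SNR_RANGE_AWGN, to_scl_plot_format_sketch(scl_ber_results, scl_bler_results), LIST_SIZES)
def to_scl_plot_format_sketch(ber_by_size, bler_by_size):
    return {
        size: [{'BER': ber, 'BLER': bler} for ber, bler in zip(ber_by_size[size], bler_by_size[size])]
        for size in ber_by_size
    }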

# Part 6 continued: Performance Evaluation Function

# Latest performance comparison from ChatGPT (06/09/25 evening)

def performance_comparison(
    rnn_trainer, polar_code_gen, snr_range_db, channel_type, list_sizes, num_trials, device, info_bits
):
    """
    Evaluate performance (BER and BLER) of RNN and SCL decoders over a range of SNR values.

    Args:
        rnn_trainer: Trained RNN decoder trainer object.
        polar_code_gen: Polar code generator object.
        snr_range_db: Iterable of SNR values in dB.
        channel_type: Channel model type string (e.g., 'AWGN').
        list_sizes: List of integers for SCL decoder list sizes.
        num_trials: Number of random trials per SNR.
        device: PyTorch device (cpu or cuda).
        info_bits: Number of info bits in code.

    Returns:
        rnn_perf_results: dict with 'BER_RNN' and 'BLER_RNN' lists.
        scl_perf_results: dict keyed by list size with list of dicts {'BER', 'BLER'}.
    """
    logging.info("Starting performance comparison...")
    channel_simulator = EnhancedChannelSimulator(channel_type=channel_type)
    rnn_perf_results = {'BER_RNN': [], 'BLER_RNN': []}
    scl_perf_results = {size: [] for size in list_sizes}

    rnn_model = rnn_trainer.model.to(device)
    rnn_model.eval()

    # Initialize one PolarCodeDecoder per list size outside trials loop (assuming stateless)
    scl_decoders = {
        size: PolarCodeDecoder(
            N=polar_code_gen.N,
            K=polar_code_gen.K,
            list_size=size,
            crc_poly=polar_code_gen.crc_polynomials.get(polar_code_gen.crc_type)
        ) for size in list_sizes
    }

    for snr_db in snr_range_db:
        logging.info(f"Simulating at SNR: {snr_db} dB")

        # Initialize counters
        total_bits_rnn = 0
        bit_errors_rnn = 0
        total_blocks_rnn = 0
        block_errors_rnn = 0

        total_bits_scl = {size: 0 for size in list_sizes}
        bit_errors_scl = {size: 0 for size in list_sizes}
        total_blocks_scl = {size: 0 for size in list_sizes}
        block_errors_scl = {size: 0 for size in list_sizes}

        snr_linear = 10 ** (snr_db / 10)
        noise_variance = 1 / (2 * snr_linear)  # Assumes unit signal power for BPSK

        for trial in range(num_trials):
            # Generate and encode bits
            info_bits_array = polar_code_gen.generate_info_bits()
            encoded_signal = polar_code_gen.polar_encode(info_bits_array)
            modulated_signal = bpsk_modulate(encoded_signal)
            received_signal = channel_simulator.simulate(modulated_signal, snr_db)

            # --- RNN Decoding ---
            with torch.no_grad():
                received_tensor = torch.FloatTensor(received_signal).view(1, -1).to(device)
                rnn_output_prob = rnn_model(received_tensor).cpu().numpy().squeeze()
                rnn_decoded_bits = (rnn_output_prob > 0.5).astype(int)

            # Update RNN stats
            total_bits_rnn += info_bits
            bit_errors_rnn += np.sum(rnn_decoded_bits != info_bits_array)
            total_blocks_rnn += 1
            if not np.array_equal(rnn_decoded_bits, info_bits_array):
                block_errors_rnn += 1

            # --- SCL Decoding ---
            received_llrs = (2 * received_signal) / noise_variance

            for size in list_sizes:
                scl_decoder = scl_decoders[size]
                scl_decoded_bits = scl_decoder.decode(received_llrs)

                total_bits_scl[size] += info_bits
                bit_errors_scl[size] += np.sum(scl_decoded_bits != info_bits_array)
                total_blocks_scl[size] += 1
                if not np.array_equal(scl_decoded_bits, info_bits_array):
                    block_errors_scl[size] += 1

        # Calculate RNN BER and BLER
        rnn_ber = bit_errors_rnn / total_bits_rnn if total_bits_rnn > 0 else 0
        rnn_bler = block_errors_rnn / total_blocks_rnn if total_blocks_rnn > 0 else 0
        rnn_perf_results['BER_RNN'].append(rnn_ber)
        rnn_perf_results['BLER_RNN'].append(rnn_bler)
        logging.info(f"SNR: {snr_db} dB, RNN BER: {rnn_ber:.4e}, BLER: {rnn_bler:.4e}")

        # Calculate SCL BER and BLER for each list size
        for size in list_sizes:
            scl_ber = bit_errors_scl[size] / total_bits_scl[size] if total_bits_scl[size] > 0 else 0
            scl_bler = block_errors_scl[size] / total_blocks_scl[size] if total_blocks_scl[size] > 0 else 0
            scl_perf_results[size].append({'BER': scl_ber, 'BLER': scl_bler})
            logging.info(f"SNR: {snr_db} dB, SCL List Size {size} BER: {scl_ber:.4e}, BLER: {scl_bler:.4e}")

    return rnn_perf_results, scl_perf_results
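
# --- Illustrative sketch (standalone; does not use EnhancedChannelSimulator or bpsk_modulate) ---
# How the channel LLRs and error rates in performance_comparison fit together, assuming BPSK maps
# bit 0 -> +1 and bit 1 -> -1 and the SNR is Es/N0 with unit symbol energy, so the noise variance
# per real dimension is sigma^2 = 1 / (2 * SNR_linear). Then LLR = 2 * y / sigma^2, a positive LLR
# favors bit 0, BER = bit errors / bits sent and BLER = blocks with any error / blocks sent.
# No channel code is applied here (hard decisions on the raw symbols); names are hypothetical.
import numpy as np


def uncoded_bpsk_awgn_sketch(snr_db, num_blocks, block_len, seed=0):
    rng = np.random.default_rng(seed)
    bits = rng.integers(0, 2, size=(num_blocks, block_len))
    symbols = 1.0 - 2.0 * bits  # 0 -> +1, 1 -> -1
    sigma2 = 1.0 / (2.0 * 10 ** (snr_db / 10))
    received = symbols + rng.normal(0.0, np.sqrt(sigma2), size=symbols.shape)
    llrs = 2.0 * received / sigma2
    decided = (llrs < 0).astype(int)  # negative LLR -> bit 1
    bit_errors = decided != bits
    ber = bit_errors.sum() / bits.size
    bler = np.any(bit_errors, axis=1).mean()
    return llrs, ber, bler


# At 0 dB the uncoded BER should come out close to Q(sqrt(2)) = 0.5 * erfc(1), about 0.079.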



# Part 7: main() function

#####################################################################

# Latest main() from ChatGPT (06/09/25 evening)


def main():
    """
    Main workflow: simulate SC and SCL decoding over an AWGN SNR sweep, then train the
    RNN decoder and compare it with the SCL decoder.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {device}")

        # Constants
        BLOCK_LENGTH = 128
        INFO_BITS = 64
        LEARNING_RATE = 1e-3
        EPOCHS = 50
        BATCH_SIZE = 32
        NUM_SAMPLES_TRAIN = 50000
        NUM_TRIALS_PERF = 1500
        SNR_RANGE_AWGN = np.linspace(0, 5, 11)
        LIST_SIZES = [1, 8, 16]

        # Initialize Polar code generator and RNN model
        polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS)
        rnn_model = EnhancedRNNDecoder(BLOCK_LENGTH, INFO_BITS).to(device)
        rnn_trainer = DecoderTrainer(rnn_model, LEARNING_RATE)

        logging.info(f"Polar code rate: {polar_code_gen.R}")

        # Per-SNR BER and BLER results: a list for SC, and a dict keyed by list size for SCL
        sc_ber_results = []
        sc_bler_results = []
        scl_ber_results = {ls: [] for ls in LIST_SIZES}
        scl_bler_results = {ls: [] for ls in LIST_SIZES}

        # CRC polynomial and length (assuming CRC-7, x^7 + x^3 + 1).
        # This must match what PolarCodeGenerator and PolarCodeDecoder use.
        crc_polynomial = np.array([1, 0, 0, 0, 1, 0, 0, 1], dtype=int)
        crc_length = 7
        crc_poly_tuple = (crc_polynomial, crc_length)  # Tuple format for PolarCodeDecoder

        channel_simulator = EnhancedChannelSimulator(channel_type='AWGN')

        # --- SC / SCL simulation over the SNR sweep ---
        for snr_db in SNR_RANGE_AWGN:
            logging.info(f"Simulating at SNR = {snr_db:.2f} dB")

            # Generator for the current SNR. The reliability order for N=128 over AWGN is fixed,
            # but the generator still takes the SNR at construction. A separate name keeps
            # polar_code_gen (used later for RNN training) unchanged.
            snr_code_gen = PolarCodeGenerator(BLOCK_LENGTH, INFO_BITS, snr_db, crc_type='CRC-7')

            # Frozen and information sets, fixed at construction from the reliability sequence
            frozen_set = snr_code_gen.frozen_set
            info_set = snr_code_gen.info_set

            # SC decoder: uses the info and frozen sets to know which bits to decode
            sc_decoder = SCDecoder(BLOCK_LENGTH, INFO_BITS, info_set, frozen_set)

            # SCL decoders, one per list size. They derive the info/frozen sets internally with the
            # same logic as the generator and use the CRC polynomial to check candidate paths.
            scl_decoders = {ls: PolarCodeDecoder(BLOCK_LENGTH, INFO_BITS, ls, crc_poly=crc_poly_tuple) for ls in LIST_SIZES}

            total_info_bits = 0  # Information bits transmitted at this SNR
            total_block_errors_sc = 0  # Blocks with at least one SC bit error
            total_bit_errors_sc = 0  # SC bit errors
            total_block_errors_scl = {ls: 0 for ls in LIST_SIZES}  # SCL block errors per list size
            total_bit_errors_scl = {ls: 0 for ls in LIST_SIZES}  # SCL bit errors per list size
            total_simulations = 0  # Blocks simulated at this SNR

            # Channel LLRs for BPSK over AWGN: LLR = 2 * y / noise_var, assuming unit symbol
            # energy (Es = 1) so that noise_var = 1 / (2 * SNR_linear)
            snr_linear = 10 ** (snr_db / 10)
            noise_var = 1 / (2 * snr_linear)

            for trial in range(NUM_TRIALS_PERF):
                # Generate, encode (including CRC if configured), BPSK-modulate and transmit
                info_bits = snr_code_gen.generate_info_bits()
                encoded_signal = snr_code_gen.polar_encode(info_bits)
                modulated_signal = bpsk_modulate(encoded_signal)
                received_signal = channel_simulator.simulate(modulated_signal, snr_db)
                received_llrs = (2 * received_signal) / noise_var

                # --- Decode with SC ---
                try:
                    decoded_info_sc = sc_decoder.decode(received_llrs)
                    if not np.array_equal(decoded_info_sc, info_bits):
                        total_block_errors_sc += 1
                        total_bit_errors_sc += np.sum(decoded_info_sc != info_bits)
                except Exception as e:
                    logging.error(f"SC decoder error during trial {trial} at SNR {snr_db:.2f} dB: {e}")
                    traceback.print_exc()
                    total_block_errors_sc += 1  # Count a decoding failure as a block error
                    total_bit_errors_sc += INFO_BITS  # and assume all bits are wrong

                # --- Decode with SCL for each list size ---
                for ls in LIST_SIZES:
                    try:
                        decoded_info_scl = scl_decoders[ls].decode(received_llrs)
                        if not np.array_equal(decoded_info_scl, info_bits):
                            total_block_errors_scl[ls] += 1
                            total_bit_errors_scl[ls] += np.sum(decoded_info_scl != info_bits)
                    except Exception as e:
                        logging.error(f"SCL decoder (L={ls}) error during trial {trial} at SNR {snr_db:.2f} dB: {e}")
                        traceback.print_exc()
                        total_block_errors_scl[ls] += 1  # Count a decoding failure as a block error
                        total_bit_errors_scl[ls] += INFO_BITS  # and assume all bits are wrong

                total_info_bits += INFO_BITS
                total_simulations += 1

            # BER and BLER for SC at this SNR (computed once, after all trials)
            sc_ber = total_bit_errors_sc / total_info_bits if total_info_bits > 0 else 0
            sc_bler = total_block_errors_sc / total_simulations if total_simulations > 0 else 0
            sc_ber_results.append(sc_ber)
            sc_bler_results.append(sc_bler)
            logging.info(f"  SC Decoder: BER = {sc_ber:.6f}, BLER = {sc_bler:.6f} (based on {total_simulations} trials)")

            # BER and BLER for SCL, per list size, at this SNR
            for ls in LIST_SIZES:
                scl_ber = total_bit_errors_scl[ls] / total_info_bits if total_info_bits > 0 else 0
                scl_bler = total_block_errors_scl[ls] / total_simulations if total_simulations > 0 else 0
                scl_ber_results[ls].append(scl_ber)
                scl_bler_results[ls].append(scl_bler)
                logging.info(f"  SCL Decoder (L={ls}): BER = {scl_ber:.6f}, BLER = {scl_bler:.6f} (based on {total_simulations} trials)")

        logging.info("SC/SCL performance evaluation finished.")

        # --- Plot SC/SCL results ---
        plt.figure(figsize=(10, 6))
        plt.semilogy(SNR_RANGE_AWGN, sc_ber_results, marker='o', linestyle='-', label='SC Decoder')
        for ls in LIST_SIZES:
            plt.semilogy(SNR_RANGE_AWGN, scl_ber_results[ls], marker='x', linestyle='--', label=f'SCL Decoder (L={ls})')
        plt.xlabel('SNR (dB)')
        plt.ylabel('Bit Error Rate (BER)')
        plt.title('BER Performance Comparison (Polar Code)')
        plt.grid(True, which="both", linestyle='--')
        plt.legend()
        plt.show()

        plt.figure(figsize=(10, 6))
        plt.semilogy(SNR_RANGE_AWGN, sc_bler_results, marker='o', linestyle='-', label='SC Decoder')
        for ls in LIST_SIZES:
            plt.semilogy(SNR_RANGE_AWGN, scl_bler_results[ls], marker='x', linestyle='--', label=f'SCL Decoder (L={ls})')
        plt.xlabel('SNR (dB)')
        plt.ylabel('Block Error Rate (BLER)')
        plt.title('BLER Performance Comparison (Polar Code)')
        plt.grid(True, which="both", linestyle='--')
        plt.legend()
        plt.show()

        # --- RNN decoder: dataset generation, training, and comparison with SCL ---
        X_raw, y_raw = prepare_polar_dataset(
            polar_code_gen, num_samples=NUM_SAMPLES_TRAIN, snr_db=5.0, channel_type='AWGN'
        )
        save_dataset_to_csv(X_raw, y_raw, 'awgn_dataset.csv')

        # Prepare tensors and split into train/val sets (80/20)
        X_tensor = torch.FloatTensor(X_raw).view(-1, BLOCK_LENGTH).to(device)
        y_tensor = torch.FloatTensor(y_raw).view(-1, INFO_BITS).to(device)
        train_size = int(0.8 * X_tensor.shape[0])
        train_X, train_y = X_tensor[:train_size], y_tensor[:train_size]
        val_X, val_y = X_tensor[train_size:], y_tensor[train_size:]

        # Train RNN model
        train_losses, val_losses = rnn_trainer.train(
            train_X, train_y, X_val=val_X, y_val=val_y, epochs=EPOCHS, batch_size=BATCH_SIZE
        )

        # Evaluate performance
        rnn_perf_results, scl_perf_results = performance_comparison(
            rnn_trainer, polar_code_gen, SNR_RANGE_AWGN, 'AWGN', LIST_SIZES, NUM_TRIALS_PERF, device, INFO_BITS
        )

        # Plot training and evaluation results
        plot_training_validation(train_losses, val_losses)
        plot_ber_bler_rnn(SNR_RANGE_AWGN, rnn_perf_results)
        plot_ber_bler_comparison(SNR_RANGE_AWGN, rnn_perf_results, scl_perf_results, LIST_SIZES)
        plot_ber_bler_scl(SNR_RANGE_AWGN, scl_perf_results, LIST_SIZES)

        # Bit-level confusion matrix for the first 100 training samples
        y_true_example = train_y[:100].cpu().numpy()
        rnn_input_example = train_X[:100]
        with torch.no_grad():
            rnn_output_prob_example = rnn_trainer.model(rnn_input_example).cpu().numpy()
        rnn_output_example = (rnn_output_prob_example > 0.5).astype(int)
        plot_confusion_matrix(y_true_example.flatten(), rnn_output_example.flatten(), title='Confusion Matrix')

        logging.info("🎉 AWGN Channel Simulation Complete!")

    except Exception as e:
        logging.error(f"Simulation Error: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()

