<a href="https://colab.research.google.com/github/kumuds4/BCH/blob/master/latestPOLARML0612.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Making the Most of your Colab Subscription



In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Faster GPUs

Users who have purchased one of Colab's paid plans have access to faster GPUs and more memory. You can upgrade your notebook's GPU settings in `Runtime > Change runtime type` in the menu to select from several accelerator options, subject to availability.

The free of charge version of Colab grants access to Nvidia's T4 GPUs subject to quota restrictions and availability.

You can see what GPU you've been assigned at any time by executing the following cell. If the execution result of running the code cell below is "Not connected to a GPU", you can change the runtime by going to `Runtime > Change runtime type` in the menu to enable a GPU accelerator, and then re-execute the code cell.


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Jun 12 15:30:27 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P0             45W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In order to use a GPU with your notebook, select the `Runtime > Change runtime type` menu, and then set the hardware accelerator to the desired option.

## More memory

Users who have purchased one of Colab's paid plans have access to high-memory VMs when they are available. More powerful GPUs are always offered with high-memory VMs.



You can see how much memory you have available at any time by running the following code cell. If the execution result of running the code cell below is "Not using a high-RAM runtime", then you can enable a high-RAM runtime via `Runtime > Change runtime type` in the menu. Then select High-RAM in the Runtime shape toggle button. After, re-execute the code cell.


In [None]:
import psutil

ram_gb = psutil.virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 89.6 gigabytes of available RAM

You are using a high-RAM runtime!


## Longer runtimes

All Colab runtimes are reset after some period of time (which is faster if the runtime isn't executing code). Colab Pro and Pro+ users have access to longer runtimes than those who use Colab free of charge.

## Background execution

Colab Pro+ users have access to background execution, where notebooks will continue executing even after you've closed a browser tab. This is always enabled in Pro+ runtimes as long as you have compute units available.



## Relaxing resource limits in Colab Pro

Your resources are not unlimited in Colab. To make the most of Colab, avoid using resources when you don't need them. For example, only use a GPU when required and close Colab tabs when finished.



If you encounter limitations, you can relax those limitations by purchasing more compute units via Pay As You Go. Anyone can purchase compute units via [Pay As You Go](https://colab.research.google.com/signup); no subscription is required.

## Send us feedback!

If you have any feedback for us, please let us know. The best way to send feedback is by using the Help > 'Send feedback...' menu. If you encounter usage limits in Colab Pro consider subscribing to Pro+.

If you encounter errors or other issues with billing (payments) for Colab Pro, Pro+, or Pay As You Go, please email [colab-billing@google.com](mailto:colab-billing@google.com).

## More Resources

### Working with Notebooks in Colab
- [Overview of Colab](/notebooks/basic_features_overview.ipynb)
- [Guide to Markdown](/notebooks/markdown_guide.ipynb)
- [Importing libraries and installing dependencies](/notebooks/snippets/importing_libraries.ipynb)
- [Saving and loading notebooks in GitHub](https://colab.research.google.com/github/googlecolab/colabtools/blob/main/notebooks/colab-github-demo.ipynb)
- [Interactive forms](/notebooks/forms.ipynb)
- [Interactive widgets](/notebooks/widgets.ipynb)

<a name="working-with-data"></a>
### Working with Data
- [Loading data: Drive, Sheets, and Google Cloud Storage](/notebooks/io.ipynb)
- [Charts: visualizing data](/notebooks/charts.ipynb)
- [Getting started with BigQuery](/notebooks/bigquery.ipynb)

### Machine Learning Crash Course
These are a few of the notebooks from Google's online Machine Learning course. See the [full course website](https://developers.google.com/machine-learning/crash-course/) for more.
- [Intro to Pandas DataFrame](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/pandas_dataframe_ultraquick_tutorial.ipynb)
- [Linear regression with tf.keras using synthetic data](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/linear_regression_with_synthetic_data.ipynb)


<a name="using-accelerated-hardware"></a>
### Using Accelerated Hardware
- [TensorFlow with GPUs](/notebooks/gpu.ipynb)
- [TPUs in Colab](/notebooks/tpu.ipynb)

<a name="machine-learning-examples"></a>

## Machine Learning Examples

To see end-to-end examples of the interactive machine learning analyses that Colab makes possible, check out these tutorials using models from [TensorFlow Hub](https://tfhub.dev).

A few featured examples:

- [Retraining an Image Classifier](https://tensorflow.org/hub/tutorials/tf2_image_retraining): Build a Keras model on top of a pre-trained image classifier to distinguish flowers.
- [Text Classification](https://tensorflow.org/hub/tutorials/tf2_text_classification): Classify IMDB movie reviews as either *positive* or *negative*.
- [Style Transfer](https://tensorflow.org/hub/tutorials/tf2_arbitrary_image_stylization): Use deep learning to transfer style between images.
- [Multilingual Universal Sentence Encoder Q&A](https://tensorflow.org/hub/tutorials/retrieval_with_tf_hub_universal_encoder_qa): Use a machine learning model to answer questions from the SQuAD dataset.
- [Video Interpolation](https://tensorflow.org/hub/tutorials/tweening_conv3d): Predict what happened in a video between the first and the last frame.


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import logging
import pandas as pd
import traceback

import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatterMathtext
from matplotlib.ticker import LogFormatterMathtext
# Configure logging
logging.basicConfig(level=logging.INFO)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration parameters
BLOCK_LENGTH = 128
INFO_BITS = 64
LEARNING_RATE = 1e-3
EPOCHS = 50
BATCH_SIZE = 64
NUM_SAMPLES_TRAIN = 10000
NUM_SAMPLES=10000
NUM_TRIALS_PERF = 1000
#SNR_RANGE_AWGN = np.linspace(0, 5, 11)
SNR_RANGE_AWGN = np.linspace(0.5, 4.5, 13) # Suggested range for performance evaluation
#snr_range = [0, 5, 10]
LIST_SIZES = [1, 8, 16]
snr_db = 5.0     # <----- You can name this however you like!
#crc_poly = [1, 0, 0, 0, 1, 0, 0, 1]
# CRC-16-CCITT polynomial (coefficients of x^16 to x^0)
CRC_16_CCITT_POLY = np.array([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int)
#generator = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc_poly)
#G = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc7_poly)
#DATASET_SNR_DB = 5.0 # Added configuration for dataset SNR


#part 1 and 2

def compute_crc(data, polynomial):
    """
    Computes the CRC checksum for binary data using a given polynomial.
    Args:
        data (np.array): Input data as a binary numpy array (0s and 1s).
        polynomial (np.array): CRC polynomial as a binary numpy array (coefficients from highest degree).
    Returns:
        np.array: The CRC checksum bits (length = len(polynomial) - 1).
    """
    poly_len = len(polynomial)
    if poly_len < 2:
        raise ValueError("Polynomial must have degree at least 1.")

    crc_length = poly_len - 1

    # Ensure data is numpy array of integers
    data = np.array(data, dtype=int)

    # Append zeros for the remainder
    data_with_zeros = np.concatenate((data, np.zeros(crc_length, dtype=int)))

    # Use a copy to avoid modifying the input data_with_zeros array directly
    remainder = np.copy(data_with_zeros)

    # Perform polynomial division
    # The loop runs over the data bits that are aligned with the highest bit of the polynomial
    for i in range(len(remainder) - poly_len + 1):
        if remainder[i] == 1:
            # XOR the current segment with the polynomial
            remainder[i:i + poly_len] ^= polynomial

    # The remainder after the loop is the CRC checksum
    return remainder[-crc_length:]
###################################################################################
#Latest Polarcode generator
#Latest Polarcode generator
class PolarCodeGenerator:
    """
    Polar Code Generator for N=128 (with CRC option, reliability sequence for info/frozen placement).
    """
    # Reliability order for N=128 (first is most reliable, last is least reliable)
    RELIABILITY_SEQUENCE_128 = [
        0, 1, 2, 4, 8, 16, 3, 5, 6, 9, 10, 12, 17, 18, 20, 24,
        32, 7, 11, 13, 14, 19, 21, 22, 25, 26, 28, 33, 34, 36, 38, 40,
        15, 23, 27, 29, 30, 35, 37, 39, 41, 42, 44, 48, 31, 43, 45, 46,
        49, 50, 52, 56, 47, 51, 53, 54, 57, 58, 60, 62, 63, 55, 59, 61,
        64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
        80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
        96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127
    ]
     # CRC-16-CCITT polynomial (coefficients of x^16 to x^0)
    CRC_16_CCITT_POLY = np.array([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int)


    def __init__(self, N=128, K=64):
        """
        Args:
            N (int): Codeword/block length (should match length of RELIABILITY_SEQUENCE_128)
            K (int): Number of payload (info) bits (without CRC)
        """
        self.N = N
        self.K = K
        self.crc_poly = self.CRC_16_CCITT_POLY
        self.crc_length = len(self.crc_poly) - 1 # Should be 16 for CRC-16-CCITT

        # Compute total info bits (payload + CRC)
        self.KwCRC = self.K + self.crc_length

        # Ensure N is supported and KwCRC doesn't exceed N
        if N != 128:
             raise ValueError(f"PolarCodeGenerator is currently implemented for N=128. Got N={N}")
        if self.KwCRC > N:
             raise ValueError(f"K + CRC length ({self.KwCRC}) cannot be greater than N ({N})")

        # Info bits are placed at the KwCRC most reliable positions based on RELIABILITY_SEQUENCE_128
        self.info_set = sorted(self.RELIABILITY_SEQUENCE_128[:self.KwCRC])
        # Frozen bits are placed at the remaining less reliable positions
        self.frozen_set = sorted(set(range(N)) - set(self.info_set))

        logging.info(f"Generator initialized: N={self.N}, K={self.K}, CRC Length={self.crc_length}, KwCRC={self.KwCRC}")
        logging.info(f"Info set indices ({len(self.info_set)}): {self.info_set[:10]}...") # Print first few
        logging.info(f"Frozen set indices ({len(self.frozen_set)}): {self.frozen_set[:10]}...") # Print first few

    def generate_payload(self):
        """Generates K random information bits (payload)."""
        return np.random.randint(2, size=self.K)



    def encode(self, payload_bits):
        """
        Args:
            payload_bits (array-like): Length K (payload, i.e. without CRC).
        Returns:
            codeword (np.array): Encoded polar codeword (length N)
            info_bits_with_crc (np.array): Full info vector (payload + CRC if used, length KwCRC)
        """
        payload_bits = np.array(payload_bits, dtype=int)
        if len(payload_bits) != self.K:
            raise ValueError(f"Expected {self.K} payload bits, got {len(payload_bits)}")

        # Add CRC bits if needed
        if self.crc_poly is not None:
            # Use the global compute_crc or define it within the class if preferred
            crc_bits = compute_crc(payload_bits, self.crc_poly)
            info_bits_with_crc = np.concatenate([payload_bits, crc_bits])
        else:
            info_bits_with_crc = payload_bits

        if len(info_bits_with_crc) != self.KwCRC:
             raise RuntimeError(f"CRC appending resulted in {len(info_bits_with_crc)} bits, expected {self.KwCRC}")


        # Place info bits at most reliable positions, rest frozen (zero)
        u = np.zeros(self.N, dtype=int)
        # Place info_bits_with_crc into the u vector at info_set indices
        u[self.info_set] = info_bits_with_crc
        codeword = self._arikan_transform(u)
        return codeword, info_bits_with_crc

    def _arikan_transform(self, u):
        """
        Fast polar (Arıkan) transform (in-place butterfly).
        """
        N = len(u)
        x = u.copy()
        n = int(np.log2(N))
        for i in range(n):
            step = 2 ** i
            for j in range(0, N, 2*step):
                for k in range(step):
                    x[j+k] ^= x[j+k+step]
        return x

    # Removed redundant compute_crc method here, using the global one


    def generate_info_bits(self):
        # Generates K information bits (payload)
        return np.random.randint(2, size=self.K)

    # Removed _place_info_bits and _polar_encode_recursive, replaced by encode method

    # Removed append_crc here, now handled inside encode method using the global compute_crc


# --- Simulation Functions ---

def bpsk_modulate(bits):
    # Map 0 -> +1, 1 -> -1
    bits = np.array(bits, dtype=int)
    return 1 - 2*bits

def add_awgn_noise(x, snr_db, code_rate):
    """
    Adds AWGN noise. SNR here is Eb/N0, bit energy to noise ratio, in dB.
    Requires code_rate (K_info / N or KwCRC / N). Using KwCRC/N for rate.
    """
    snr_linear_eb = 10**(snr_db/10)
    # Noise variance sigma^2 = N0/2.
    # Eb/N0 = (Es/R) / N0, Es=1 for BPSK
    # N0 = 1 / (R * (Eb/N0))
    # sigma^2 = 1 / (2 * R * (Eb/N0)_linear)
    noise_var = 1 / (2 * code_rate * snr_linear_eb)
    noise = np.sqrt(noise_var) * np.random.randn(*x.shape)
    return x + noise, noise_var

def make_llr(received, noise_var):
    return 2*received/noise_var

def prepare_polar_dataset(gen, num_samples, snr_db, channel_type='AWGN'):
    """
    Prepares a dataset of noisy channel outputs (LLRs) and corresponding
    original information bits (payload + CRC).
    """
    X_data = np.zeros((num_samples, gen.N))
    # Correctly use gen.KwCRC for the size of information bits + CRC
    y_data = np.zeros((num_samples, gen.KwCRC), dtype=int)

    code_rate = gen.KwCRC / gen.N # Use the correct rate including CRC

    for i in range(num_samples):
        # 1. Generate K info bits (payload)
        payload = gen.generate_info_bits()

        # 2. Encode payload (handles CRC appending and polar transform internally)
        codeword, info_bits_with_crc = gen.encode(payload)

        # 3. BPSK modulate
        x = bpsk_modulate(codeword)

        # 4. Add noise and calculate LLRs
        if channel_type == 'AWGN':
            rx, noise_var = add_awgn_noise(x, snr_db, code_rate)
            llr = make_llr(rx, noise_var)
        # Add other channel types here if needed
        else:
            raise ValueError(f"Unsupported channel type: {channel_type}")

        # Store LLRs (received signal) and the original info_bits_with_crc
        X_data[i] = llr
        y_data[i] = info_bits_with_crc # y_data should be the bits we want to decode to

    return X_data, y_data

def save_dataset_to_csv(X, y, filename='dataset.csv'):
    # X shape: (num_samples, N)
    # y shape: (num_samples, KwCRC)
    # Data to save: (num_samples, N + KwCRC)
    data = np.hstack((X, y))
    columns = [f'received_{i}' for i in range(X.shape[1])] + [f'info_bit_{j}' for j in range(y.shape[1])]
    df = pd.DataFrame(data, columns=columns)
    df.to_csv(filename, index=False)
    logging.info(f"Dataset saved to {filename}")
####################################################################################
#part 3
#Rewrite
######################################################

def bpsk_modulate(bits):
    bits = np.array(bits, dtype=int)
    # Map 0 -> -1, 1 -> +1
    return 2 * bits - 1

class EnhancedChannelSimulator:
    def __init__(self, channel_type='AWGN'):
        self.channel_type = channel_type

    def simulate(self, signal, snr_db):
        # signal is assumed to be BPSK modulated (+1 or -1)
        # Calculate noise variance based on SNR_db and signal power (which is 1 for BPSK)
        # SNR_linear = E_s / N0, where E_s = 1 for BPSK
        # Noise variance sigma^2 = N0/2.
        # SNR_linear = 1 / N0 = 1 / (2 * sigma^2)
        # sigma^2 = 1 / (2 * SNR_linear)
        # Need to consider code rate R = K/N if SNR is E_b/N0
        # E_b/N0 = (E_s/R) / N0 = (1/R) * (E_s/N0) = (1/R) * SNR_linear_Es
        # SNR_linear_Es = R * SNR_linear_Eb
        # sigma^2 = 1 / (2 * SNR_linear_Es) = 1 / (2 * R * SNR_linear_Eb)
        # The PolarCodeGenerator stores K and N, but R is needed here.
        # We need the code rate from the generator. Let's pass it or the generator object.

        # Assuming snr_db provided is E_b/N0
        # We need the code rate K/N used for the encoding. Let's assume the generator object is available or pass R.
        # For now, let's assume R is available from a generator instance or a global var if constant.
        # If we simulate *per block*, we need the specific R for that block (though N, K are fixed here).
        # Let's assume R = INFO_BITS / BLOCK_LENGTH is constant.
        R = INFO_BITS / BLOCK_LENGTH # This might need adjustment if K_crc is used for rate calc

        snr_linear_eb = 10 ** (snr_db / 10)
        # Noise variance for E_b/N0
        noise_variance = 1 / (2 * R * snr_linear_eb)
        noise_std = np.sqrt(noise_variance)
        noise = noise_std * np.random.randn(*signal.shape)
        return signal + noise


def bpsk_modulate(bits):
    # Map 0 -> +1, 1 -> -1
    bits = np.array(bits, dtype=int)
    return 1 - 2*bits

def add_awgn_noise(x, snr_db, code_rate):
    """
    Adds AWGN noise. SNR here is Eb/N0, bit energy to noise ratio, in dB.
    Requires code_rate (K_info / N or KwCRC / N). Using KwCRC/N for rate.
    """
    snr_linear_eb = 10**(snr_db/10)
    # Noise variance sigma^2 = N0/2.
    # Eb/N0 = (Es/R) / N0, Es=1 for BPSK
    # N0 = 1 / (R * (Eb/N0))
    # sigma^2 = 1 / (2 * R * (Eb/N0)_linear)
    noise_var = 1 / (2 * code_rate * snr_linear_eb)
    noise = np.sqrt(noise_var) * np.random.randn(*x.shape)
    return x + noise, noise_var

def make_llr(received, noise_var):
    return 2*received/noise_var



def prepare_polar_dataset(gen, num_samples, snr_db, channel_type='AWGN'):
    """
    Prepares a dataset of noisy channel outputs (LLRs) and corresponding
    original information bits (payload + CRC).
    """
    X_data = np.zeros((num_samples, gen.N))
    # Correctly use gen.KwCRC for the size of information bits + CRC
    y_data = np.zeros((num_samples, gen.KwCRC), dtype=int)

    code_rate = gen.KwCRC / gen.N # Use the correct rate including CRC

    for i in range(num_samples):
        # 1. Generate K info bits (payload)
        payload = gen.generate_info_bits()

        # 2. Encode payload (handles CRC appending and polar transform internally)
        codeword, info_bits_with_crc = gen.encode(payload)

        # 3. BPSK modulate
        x = bpsk_modulate(codeword)

        # 4. Add noise and calculate LLRs
        if channel_type == 'AWGN':
            rx, noise_var = add_awgn_noise(x, snr_db, code_rate)
            llr = make_llr(rx, noise_var)
        # Add other channel types here if needed
        else:
            raise ValueError(f"Unsupported channel type: {channel_type}")

        # Store LLRs (received signal) and the original info_bits_with_crc
        X_data[i] = llr
        y_data[i] = info_bits_with_crc # y_data should be the bits we want to decode to

        if num_samples > 0:  # Only bother if dataset not empty
           print(f"DEBUG SNR={snr_db}: noise_var={noise_var:.4f}, example received[0]={rx[0]:.4f}, LLR[0]={llr[0]:.4f}")

    return X_data, y_data


def save_dataset_to_csv(X, y, filename='dataset.csv'):
    # X shape: (num_samples, N)
    # y shape: (num_samples, K_info)
    # Data to save: (num_samples, N + K_info)
    data = np.hstack((X, y))
    columns = [f'received_{i}' for i in range(X.shape[1])] + [f'info_bit_{j}' for j in range(y.shape[1])]
    df = pd.DataFrame(data, columns=columns)
    df.to_csv(filename, index=False)
    logging.info(f"Dataset saved to {filename}")

############################################################################################
#latest part 4
 #----- Model -----
#latest part 4
 #----- Model -----
class EnhancedRNNDecoder(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=128, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Dropout is applied after each LSTM (except last)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Input: [batch, code_length], reshape to [batch, seq=1, code_length]
        x_reshaped = x.unsqueeze(1)  # [B, 1, N]
        # LSTM expects sequence, but ours is just 1 timestep with full codeword
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x_reshaped, (h0, c0))   # [B, 1, H]
        out = self.fc(out[:, -1, :])              # [B, output_size]
        return self.sigmoid(out)

# ----- Trainer -----
#########################################################################################################
#add loop
crc_poly = [1,0,0,0,1,0,0,1]
BLOCK_LENGTH = 128
INFO_BITS = 64
#polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc_poly)
#polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=CRC_16_CCITT_POLY)
polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS) # Corrected line

print("Sanity check: LLR magnitudes at different SNRs")
code_rate = polar_code_gen.KwCRC / polar_code_gen.N
payload = polar_code_gen.generate_info_bits()
cw, _ = polar_code_gen.encode(payload)
x = bpsk_modulate(cw)
for snr in [0, 3, 6, 10]:
    rx, noise_var = add_awgn_noise(x, snr_db=snr, code_rate=code_rate)
    llr = make_llr(rx, noise_var)

###############################################################################################
print("KwCRC (encode):", polar_code_gen.KwCRC)
print("Encoder info set:", polar_code_gen.info_set)
print("Encoder frozen set:", polar_code_gen.frozen_set)
try:
    decoder_info_set = polar_code_gen.info_set    # Pass this into your decoder
    decoder_frozen_set = polar_code_gen.frozen_set # Pass this into your decoder (see below)
    print("Decoder info set:", decoder_info_set)
    print("Decoder frozen set:", decoder_frozen_set)
except:
    print("Decoder info_set attribute missing")

print("CRC poly (encoder):", polar_code_gen.crc_poly)
print("CRC poly (decoder):", crc_poly)
#########################################################################################################
#latest Def trainer decode on 06/12/2025
# --- Trainer ---
class DecoderTrainer:
    def __init__(self, model, learning_rate, early_stop_patience=10):
        super().__init__() # No need to call super for this class structure
        self.model = model
        self.criterion = nn.BCELoss()
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.early_stop_patience = early_stop_patience
        self.device = next(model.parameters()).device # Get device from model

    def train(self, train_X, train_y, val_X=None, val_y=None, epochs=50, batch_size=128, snr_min=1, snr_max=7, generator=None):
        """
        Trains the RNN decoder with dynamic noisy batch generation.

        Args:
            train_X (torch.Tensor): Training data LLRs (Initial, can be re-noised).
            train_y (torch.Tensor): Training data labels (info_bits_with_crc).
            val_X (torch.Tensor, optional): Validation data LLRs.
            val_y (torch.Tensor, optional): Validation data labels.
            epochs (int): Number of training epochs.
            batch_size (int): Batch size for training.
            snr_min (float): Minimum SNR (dB) for dynamic noise generation.
            snr_max (float): Maximum SNR (dB) for dynamic noise generation.
            generator (PolarCodeGenerator): Polar code generator instance for re-encoding.
        """
        train_losses, val_losses = [], []
        best_val_loss = float('inf')
        patience_counter = 0

        # Ensure labels are float and on the correct device
        train_y = train_y.float().to(self.device)
        if val_y is not None:
            val_y = val_y.float().to(self.device)

        # Ensure input data is on the correct device
        train_X = train_X.to(self.device)
        if val_X is not None:
             val_X = val_X.to(self.device)


        # --- Debugging LLRs (Optional - run before training) ---
        # Need generator for rate and verification
        if generator is not None and hasattr(self, 'verify_llr_calculation'):
             code_rate = generator.KwCRC / generator.N
             logging.info("--- LLR Verification Before Training ---")
             # Verify at min, max, and a point in between
             # Pass generator or its attributes if needed by verify_llr_calculation
             self.verify_llr_calculation(snr_min, code_rate, generator=generator)
             self.verify_llr_calculation(snr_max, code_rate, generator=generator)
             self.verify_llr_calculation((snr_min + snr_max) / 2.0, code_rate, generator=generator)
             logging.info("--- End LLR Verification ---")
        elif generator is None:
             logging.warning("Generator not provided. Cannot perform full LLR verification.")
        else:
             logging.warning("verify_llr_calculation method not found in DecoderTrainer.")


        for epoch in range(epochs):
            epoch_loss = 0.0

            # Shuffle training set
            indices = torch.randperm(train_X.shape[0])
            # IMPORTANT: We need the original *info_bits_with_crc* (y_batch) to re-generate noisy data.
            # So, shuffle X and y together.
            # train_X_shuffled = train_X[indices] # Not strictly needed as we regenerate X
            train_y_shuffled = train_y[indices]


            self.model.train()
            for i in range(0, len(train_X), batch_size):
                # Get batch of original info_bits_with_crc (y_batch)
                y_batch = train_y_shuffled[i:i+batch_size]
                batch_size_actual = y_batch.size(0) # Handle last batch potentially smaller

                # --- Dynamic SNR and Re-generation for each batch ---
                if generator is not None:
                    batch_snr = np.random.uniform(snr_min, snr_max)
                    code_rate = generator.KwCRC / generator.N

                    # Re-generate noisy LLRs for this batch
                    info_bits_with_crc_batch_np = y_batch.cpu().int().numpy() # Convert to numpy int bits

                    X_noisy_batch_np = np.zeros((batch_size_actual, generator.N), dtype=np.float32)

                    # This loop re-encodes and adds noise for each sample in the batch
                    # This is the correct way given the current structure.
                    for k in range(batch_size_actual):
                         # Get the info_bits_with_crc for this sample
                         info_bits_with_crc_k = info_bits_with_crc_batch_np[k]

                         # Place info_bits_with_crc into the u vector at info_set indices
                         u_vec = np.zeros(generator.N, dtype=int)
                         # Ensure info_bits_with_crc_k has the expected length
                         if len(info_bits_with_crc_k) != generator.KwCRC:
                             logging.error(f"Mismatch in info_bits_with_crc length: Expected {generator.KwCRC}, got {len(info_bits_with_crc_k)}. Skipping sample.")
                             continue # Skip this sample if data is malformed

                         try:
                             u_vec[generator.info_set] = info_bits_with_crc_k # Place info+crc bits
                             # Encode u_vec using the generator's internal transform (_arikan_transform expects 'u' vector)
                             codeword = generator._arikan_transform(u_vec) # Use internal method if public encode doesn't expose u->x
                             # BPSK modulate
                             x = bpsk_modulate(codeword) # Use global bpsk_modulate
                             # Add noise and calculate LLRs
                             rx, noise_var = add_awgn_noise(x, batch_snr, code_rate) # Use global noise functions
                             llr = make_llr(rx, noise_var) # Use global make_llr
                             X_noisy_batch_np[k] = llr
                         except Exception as e:
                             logging.error(f"Error during re-generation for sample {k}: {e}")
                             traceback.print_exc()
                             # Handle errors: e.g., log, set row to zeros, or break
                             # Setting to zero might introduce bias, but avoids crash.
                             X_noisy_batch_np[k] = np.zeros(generator.N, dtype=np.float32) # Use zeros on error


                    # Convert the re-generated batch back to a PyTorch tensor and move to device
                    X_noisy = torch.tensor(X_noisy_batch_np, dtype=torch.float32).to(self.device)

                    # --- Optional: Debugging Re-generated Batch LLRs ---
                    # if i == 0 and epoch % 10 == 0: # Debug first batch every few epochs
                    #      logging.info(f"Epoch {epoch+1}, Batch {i//batch_size + 1}: Re-generated Batch SNR: {batch_snr:.2f}")
                    #      # Add verification calls here if needed, requires getting original bits from y_batch
                    #      pass


                else:
                    # If no generator, use the original LLRs for the batch
                    logging.warning("Generator not provided. Training on static initial LLRs.")
                    # Use the batch from the shuffled initial dataset
                    X_noisy = train_X[indices[i:i+batch_size]]


                # --- Forward Pass ---
                outputs = self.model(X_noisy)
                loss = self.criterion(outputs, y_batch) # y_batch is already float

                # --- Backward Pass and Optimization ---
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item() * batch_size_actual # Accumulate weighted by actual batch size

            train_losses.append(epoch_loss / len(train_X)) # Average over total training samples

            # ----- Validation -----
            if val_X is not None and val_y is not None and generator is not None:
                self.model.eval()
                # Evaluate on validation data at a fixed SNR (e.g., median)
                val_snr = (snr_min + snr_max) / 2.0

                # Re-generate validation batch with noise at val_snr
                info_bits_with_crc_val_np = val_y.cpu().int().numpy() # Convert to numpy int bits
                X_noisy_val_batch_np = np.zeros((val_X.size(0), generator.N), dtype=np.float32)
                code_rate = generator.KwCRC / generator.N

                for k in range(val_X.size(0)):
                      info_bits_with_crc_k = info_bits_with_crc_val_np[k]
                      u_vec = np.zeros(generator.N, dtype=int)

                      if len(info_bits_with_crc_k) != generator.KwCRC:
                          logging.error(f"Validation: Mismatch in info_bits_with_crc length for sample {k}: Expected {generator.KwCRC}, got {len(info_bits_with_crc_k)}. Skipping sample.")
                          continue

                      try:
                          u_vec[generator.info_set] = info_bits_with_crc_k
                          codeword = generator._arikan_transform(u_vec) # Use internal method if public encode doesn't expose u->x
                          x = bpsk_modulate(codeword) # Use global bpsk_modulate
                          rx, noise_var = add_awgn_noise(x, val_snr, code_rate) # Use global noise functions
                          llr = make_llr(rx, noise_var) # Use global make_llr
                          X_noisy_val_batch_np[k] = llr
                      except Exception as e:
                          logging.error(f"Validation: Error during re-generation for sample {k}: {e}")
                          traceback.print_exc()
                          X_noisy_val_batch_np[k] = np.zeros(generator.N, dtype=np.float32) # Use zeros on error

                X_noisy_val = torch.tensor(X_noisy_val_batch_np, dtype=torch.float32).to(self.device)

                with torch.no_grad():
                    val_output = self.model(X_noisy_val)
                    val_loss = self.criterion(val_output, val_y).item()
                val_losses.append(val_loss)

                print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_loss:.4f}")

                # Early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    torch.save(self.model.state_dict(), "best_rnn_decoder.pt")
                    logging.info(f"Epoch {epoch+1}: Best validation loss improved, saving model.")
                else:
                    patience_counter += 1
                    if patience_counter >= self.early_stop_patience:
                        print(f"Early stopping activated after {self.early_stop_patience} epochs without improvement.")
                        break
            elif val_X is not None and val_y is not None and generator is None:
                 logging.warning("Generator not provided. Validation will use static initial LLRs.")
                 self.model.eval()
                 with torch.no_grad():
                     val_output = self.model(val_X)
                     val_loss = self.criterion(val_output, val_y).item()
                 val_losses.append(val_loss)
                 print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_loss:.4f}")
                 # Early stopping logic here would need to be based on training loss or just disabled
                 # Assuming early stopping is less useful without dynamic validation data
            else:
                # No validation set provided
                print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f}")
                # If no validation, save model based on train loss or periodically
                if len(train_losses) == 1 or train_losses[-1] < train_losses[-2]:
                     torch.save(self.model.state_dict(), "last_rnn_decoder.pt") # Save latest if improving train loss


        # Load the best model weights if early stopping occurred and validation was used
        if patience_counter >= self.early_stop_patience and val_X is not None:
             try:
                 self.model.load_state_dict(torch.load("best_rnn_decoder.pt"))
                 print("Loaded best model based on validation performance.")
             except FileNotFoundError:
                 logging.warning("Best model file not found after early stopping. Using last epoch weights.")
        elif val_X is None:
            # If no validation, you might want to load the last saved model
             try:
                 self.model.load_state_dict(torch.load("last_rnn_decoder.pt"))
                 print("Loaded last saved model (no validation used for early stopping).")
             except FileNotFoundError:
                 logging.warning("Last model file not found. Using final epoch weights.")


        return train_losses, val_losses if len(val_losses)>0 else None

    # --- Moved evaluate method INSIDE the DecoderTrainer class ---
    def evaluate(self, X_test, y_test, threshold=0.5):
        """
        Evaluates the decoder performance (BER and BLER).
        Assumes X_test is LLRs (for the codeword) and y_test are original info_bits_with_crc (labels).
        """
        self.model.eval()
        # Ensure y_test is int for comparison and on the correct device
        y_test_int = y_test.int().to(self.device) # These are the true info_bits_with_crc
        X_test = X_test.to(self.device) # These are the LLRs for the test codewords

        with torch.no_grad():
            outputs = self.model(X_test) # Model predicts KwCRC bits (probabilities)

            # Threshold probabilities to get hard bit predictions for KwCRC bits
            preds_kwcrc = (outputs > threshold).int() # Predicted info_bits_with_crc

            # --- Calculate BER ---
            total_bits_kwcrc = y_test_int.numel()
            bit_errors_kwcrc = torch.sum(preds_kwcrc != y_test_int).item()
            ber_kwcrc = bit_errors_kwcrc / total_bits_kwcrc if total_bits_kwcrc > 0 else 0.0

            # --- Calculate BLER ---
            block_error_flags_kwcrc = torch.any(preds_kwcrc != y_test_int, dim=1)
            block_errors_kwcrc = torch.sum(block_error_flags_kwcrc).item()
            total_blocks = X_test.size(0)
            bler_kwcrc = block_errors_kwcrc / total_blocks if total_blocks > 0 else 0.0

        # Return ONLY two values (BER, BLER) for the KwCRC bits
        return ber_kwcrc, bler_kwcrc

    # --- Moved verify_llr_calculation method INSIDE the DecoderTrainer class ---
    def verify_llr_calculation(self, snr_db, code_rate, num_samples=10000, generator=None):
        """Generates random bits, encodes, adds noise, calculates LLRs,
           and compares with expected LLR behavior. Requires generator."""
        if generator is None:
            logging.warning("Cannot verify LLR calculation without a generator.")
            return

        logging.info(f"Verifying LLR calculation for SNR={snr_db} dB, Rate={code_rate}")

        # Generate random bits (equivalent of info_bits_with_crc for this test)
        # We need N bits to simulate the codeword for LLR verification purposes
        # This is a simplified test - real LLRs are on channel output of codeword
        # A better verification would encode random payload bits using the generator
        # and then check LLRs of the *encoded* codeword.
        # For this simplified check, let's generate N random bits and treat them as BPSK symbols
        random_bits_for_bpsk = np.random.randint(2, size=(num_samples, generator.N))

        # Simulate channel using global functions
        # BPSK map 0->+1, 1->-1 for AWGN LLR formula derivation (as per bpsk_modulate function)
        bpsk_signal = bpsk_modulate(random_bits_for_bpsk) # Use the global bpsk_modulate
        rx_signal, noise_var = add_awgn_noise(bpsk_signal, snr_db, code_rate) # Use global add_awgn_noise
        calculated_llr = make_llr(rx_signal, noise_var) # Use global make_llr

        # Expected LLR for transmitted bit 0 (+1 in BPSK) should be positive
        # Expected LLR for transmitted bit 1 (-1 in BPSK) should be negative
        # The sign of the LLR should match the sign of the *transmitted* BPSK symbol (1 for 0, -1 for 1)
        expected_llr_signs = np.sign(1 - 2 * random_bits_for_bpsk)
        actual_llr_signs = np.sign(calculated_llr)

        # Check proportion of correct signs
        correct_sign_proportion = np.mean(expected_llr_signs == actual_llr_signs)
        logging.info(f"  Proportion of LLRs with correct sign: {correct_sign_proportion:.4f}")

        # Check mean absolute LLR magnitude (rough check)
        mean_abs_llr = np.mean(np.abs(calculated_llr))
        logging.info(f"  Mean absolute LLR magnitude: {mean_abs_llr:.4f}")

        # Example: Mean LLR for transmitted 0s vs 1s
        mean_llr_zeros = np.mean(calculated_llr[random_bits_for_bpsk == 0]) if np.any(random_bits_for_bpsk == 0) else np.nan
        mean_llr_ones = np.mean(calculated_llr[random_bits_for_bpsk == 1]) if np.any(random_bits_for_bpsk == 1) else np.nan
        logging.info(f"  Mean LLR for transmitted 0s: {mean_llr_zeros:.4f}")
        logging.info(f"  Mean LLR for transmitted 1s: {mean_llr_ones:.4f}")

        if not (np.isnan(mean_llr_zeros) or mean_llr_zeros > 0):
             logging.warning(f"  Warning: Mean LLR for transmitted 0s is not positive ({mean_llr_zeros:.4f}).")
        if not (np.isnan(mean_llr_ones) or mean_llr_ones < 0):
             logging.warning(f"  Warning: Mean LLR for transmitted 1s is not negative ({mean_llr_ones:.4f}).")

# The rest of the code (performance_comparison, plotting, main) remains the same.
# Ensure the global functions bpsk_modulate, add_awgn_noise, make_llr are defined before DecoderTrainer.
# Also ensure PolarCodeGenerator is defined and an instance polar_code_gen exists before DecoderTrainer.

#######################################################################################################








 ############################################################################################
# --- Example Usage (Assuming other parts of your notebook are run) ---

# Assuming you have a PolarCodeGenerator instance, e.g.:
# crc7_poly = [1, 0, 0, 0, 1, 0, 0, 1] # Example CRC-7
# generator = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc7_poly)

# Assuming you have generated initial training and validation data:
# # Define a fixed SNR for initial dataset generation
# DATASET_SNR_DB = 5.0 # You can choose an appropriate SNR
# num_train_samples = 10000
# num_val_samples = 2000 # Optional validation set

# train_X_np, train_y_np = prepare_polar_dataset(generator, num_train_samples, DATASET_SNR_DB)
# train_X = torch.FloatTensor(train_X_np)
# train_y = torch.FloatTensor(train_y_np) # Keep labels as float for training

# If using validation:
# val_X_np, val_y_np = prepare_polar_dataset(generator, num_val_samples, DATASET_SNR_DB)
# val_X = torch.FloatTensor(val_X_np)
# val_y = torch.FloatTensor(val_y_np)

# Assuming you have initialized the RNN model:
# input_size = generator.N # LLRs length
# output_size = generator.KwCRC # Info bits + CRC length
# rnn_decoder = EnhancedRNNDecoder(input_size, output_size).to(device)

# Initialize the Trainer:
# learning_rate = 1e-3
# trainer = DecoderTrainer(rnn_decoder, learning_rate=learning_rate, early_stop_patience=10)

# Start Training:
# snr_train_min = 1 # Min SNR for dynamic training noise
# snr_train_max = 7 # Max SNR for dynamic training noise

# train_losses, val_losses = trainer.train(
#     train_X, train_y,
#     val_X=val_X, val_y=val_y, # Pass validation data if available
#     epochs=EPOCHS,
#     batch_size=BATCH_SIZE,
#     snr_min=snr_train_min,
#     snr_max=snr_train_max,
#     generator=generator # Pass the


# Part 5
############################################################################
class PolarCodeDecoder:
    """
    Polar Code Decoder supporting both SC and SCL (list) decoding with optional CRC checking.
    """

    # >>>>>> THIS LINE MUST BE: <<<<<<
    def __init__(self, N, K, list_size=1, crc_poly=None, info_set=None, frozen_set=None):
    # >>>>>> NOT: def __init__(self, N, K_decoder, ...) <<<<<<

        self.N = N
        self.K = K # Make sure it's assigned to self.K
        self.list_size = list_size
        self.crc_poly = crc_poly

        # Use the sets passed in if available
        if info_set is not None and frozen_set is not None:
            self.info_set = list(info_set)
            self.frozen_set = list(frozen_set)
        else:
             # Fallback - still likely incorrect for reliability sequence, but use self.K
             logging.warning("PolarCodeDecoder initialized without explicit info/frozen sets. Using default sequential assignment, which may be incorrect for reliability sequence.")
             self.info_set = sorted(range(self.K)) # Use self.K here
             self.frozen_set = sorted(set(range(self.N)) - set(self.info_set))

        # CRC properties
        if self.crc_poly is not None:
            self.crc_length = len(self.crc_poly) - 1
        else:
            self.crc_length = 0

        # Debug print using self.K
        logging.info(f"Decoder Initialized: N={self.N}, K (KwCRC)={self.K}, List Size={self.list_size}, CRC Length={self.crc_length}")
        # ... rest of __init__



        # Add debug prints here to confirm the sets and K value
      #  logging.info(f"Decoder Initialized: N={self.N}, K (KwCRC)={self.K}, List Size={self.list_size}, CRC Length={self.crc_length}")
        logging.info(f"Decoder Info set indices ({len(self.info_set)}): {self.info_set[:min(len(self.info_set), 10)]}...") # Use min to avoid error on small sets
        logging.info(f"Decoder Frozen set indices ({len(self.frozen_set)}): {self.frozen_set[:min(len(self.frozen_set), 10)]}...") # Use min
        if self.crc_poly is not None:
             logging.info(f"Decoder CRC poly: {self.crc_poly}")

    def encode_with_crc(self, payload_bits):
        """
        Attach CRC to the payload_bits if CRC polynomial is specified.
        Args:
            payload_bits (array-like): Length = K - crc_length  # This K refers to INFO_BITS, not KwCRC
        Returns:
            code_bits (np.array): Length = K (payload + CRC) # This K refers to KwCRC
        """
        # NOTE: This method seems like it belongs in the PolarCodeGenerator, not the Decoder.
        # The decoder should not be encoding. It should only be decoding received LLRs.
        # This method is likely unused or misplaced. If it's used, review its purpose.
        logging.warning("PolarCodeDecoder.encode_with_crc called. This method seems misplaced in a decoder class.")
        payload_bits = np.array(payload_bits, dtype=int)
        # The expected payload length for this method is ambiguous.
        # Should it be INFO_BITS (K) or KwCRC? Assuming it's the payload bits *before* CRC (length INFO_BITS).
        # If so, it should append crc_length bits to get K_decoder (KwCRC) total bits.
        # This method's logic needs review based on its intended use.
        # For now, assuming it takes INFO_BITS and appends CRC.

        # Check input length consistency with INFO_BITS
        # You would need INFO_BITS defined or passed to the decoder if this method is used.
        # Since it's a decoder class, it likely shouldn't have this method.

        # Reverting to the original logic but noting the potential issue
        payload_bits = np.array(payload_bits, dtype=int)
        if self.crc_poly is not None:
            # Need to know the original K (INFO_BITS) here to check payload_bits length
            # Assuming the input payload_bits is of length INFO_BITS
            # This makes this method depend on INFO_BITS, which isn't a class attribute here.
            # This method is almost certainly misplaced or buggy as currently written in the decoder.
            # For the purpose of fixing the TypeError, we leave it as is, but note it's suspicious.
            crc_bits = self.compute_crc(payload_bits, self.crc_poly)
            code_bits = np.concatenate([payload_bits, crc_bits])
            return code_bits
        else:
            return payload_bits # If no CRC, it returns the input payload_bits as is.


    def decode(self, llr):
        """
        Decodes the input LLRs using SC or SCL decoder.
        Args:
            llr (np.array): Array of length N (codeword LLRs)
        Returns:
            info_bits_with_crc (np.array): Decoded information bits (including CRC bits if present)
        """
        if self.list_size == 1:
            u_hat = self._sc_decode(llr)
        else:
            u_hat = self._scl_decode(llr) # SCL decoder should also return the full N-length u_hat

        # Extract info_bits_with_crc from the decoded u_hat vector using the info_set indices
        # The length of decoded_info_bits_with_crc will be len(self.info_set) == self.K (KwCRC)
        decoded_info_bits_with_crc = u_hat[self.info_set]

        # CRC check is performed on the extracted info_bits_with_crc *after* selecting the best path (for SCL)
        # or on the single SC path result. The crc_check method is called within this decode method or by the caller.
        # The warning about "No CRC-valid candidate" is printed *inside* _scl_decode
        # This main decode method should probably not print the warning again.

        # The crc_check result is used by the caller (e.g., SCL logic in _scl_decode)
        # This main decode method just returns the decoded bits from the best path (or only path).

        return decoded_info_bits_with_crc # Return the extracted bits of length KwCRC

    # Placeholder SC decode (Incorrect implementation)
    def _sc_decode(self, llr):
        """Basic SC decoding (Placeholder - likely incorrect). Returns the full N-length u_hat vector."""
        logging.error("Placeholder _sc_decode used. Implement a correct SC decoder.")
        u_hat = np.zeros(self.N, dtype=int)
        # This original simple thresholding is NOT a correct SC decoder
        for i in range(self.N):
             if i in self.frozen_set:
                 u_hat[i] = 0
             else:
                 # This should use calculated LLRs after propagation, not raw input LLRs
                 u_hat[i] = 0 if llr[i] >= 0 else 1 # Incorrect thresholding of raw LLR

        return u_hat # Return N-length u_hat


    # Placeholder SCL decode (Incorrect implementation)
    def _scl_decode(self, llr):
        """
        Placeholder for a correct SCL decoder implementation.
        The provided implementation is simplified and likely incorrect.
        Implementing a correct SCL decoder is a non-trivial task.
        Returns the full N-length u_hat vector of the best path.
        """
        logging.error("Placeholder _scl_decode called. A proper SCL implementation is required for correct performance.")

        # Reverting to the structure of the user's original _scl_decode but confirming it's likely incorrect
        paths = [([], 0.0)]  # Each entry: (path_bits, path_metric)
        for i in range(self.N):
            new_paths = []
            for bits, metric in paths:
                if i in self.frozen_set:
                    # Only extend with 0
                    new_bits = bits + [0]
                    # This metric update is incorrect for polar decoding LLRs
                    new_metric = metric + self._metric(llr[i], 0) # Uses raw LLR[i] - wrong
                    new_paths.append((new_bits, new_metric))
                else:
                    # Extend with both 0 and 1
                    for bit in [0, 1]:
                        ext_bits = bits + [bit]
                        # This metric update is incorrect for polar decoding LLRs
                        ext_metric = metric + self._metric(llr[i], bit) # Uses raw LLR[i] - wrong
                        new_paths.append((ext_bits, ext_metric))
            # Prune paths: keep only best list_size based on *this incorrect metric*
            new_paths.sort(key=lambda x: x[1])
            paths = new_paths[:self.list_size]

        # At the end: Check CRC if needed on the info+CRC bits for each path in the list
        # The info_indices are self.info_set
        info_indices = list(self.info_set) # These are indices within the N-length u_hat vector

        best_u_hat = None
        best_metric_crc_valid = float('inf')
        best_u_hat_overall = None # To store the best metric path if no CRC valid path found

        for bits, metric in paths:
            # Ensure the path length matches N
            if len(bits) == self.N:
                 # Extract info bits from the full N-length path
                potential_info_bits_with_crc = np.array([bits[i] for i in info_indices], dtype=int)

                # Keep track of the overall best metric path (needed if no CRC path found)
                if best_u_hat_overall is None or metric < best_u_hat_overall[1]:
                     best_u_hat_overall = (np.array(bits, dtype=int), metric) # Store the full u_hat and metric

                # Check CRC on the extracted info+CRC bits
                if self.crc_poly is not None and self.crc_check(potential_info_bits_with_crc):
                     # Found a CRC-valid path, keep the one with the best metric among CRC-valid paths
                     if best_u_hat is None or metric < best_metric_crc_valid:
                         best_u_hat = np.array(bits, dtype=int) # Store the full N-length u_hat
                         best_metric_crc_valid = metric

            else:
                logging.error(f"SCL Decoding: Path length mismatch during final processing: Expected {self.N}, got {len(bits)}")


        # Selection Logic:
        if best_u_hat is not None:
            # Return the u_hat from the best CRC-valid path
            return best_u_hat
        elif self.crc_poly is not None:
            # No CRC-valid path found among candidates, log a warning and fall back to best metric overall
            logging.warning("No CRC-valid candidate found in list; using lowest metric path (CRC failed).")
            if best_u_hat_overall is not None:
                 return best_u_hat_overall[0] # Return the u_hat from the best overall metric path
            else:
                 logging.error("SCL Decoder finished with no paths and no CRC-valid path.")
                 return np.zeros(self.N, dtype=int) # Return zeros if no paths were even generated

        else:
            # No CRC used, return the u_hat from the best metric path
            if best_u_hat_overall is not None:
                 return best_u_hat_overall[0]
            else:
                 logging.error("SCL Decoder finished with no paths (no CRC used).")
                 return np.zeros(self.N, dtype=int)


    @staticmethod
    def _metric(llr, bit):
        """Path metric increment for bit decision 0/1 given LLR."""
        # log(1+exp(-llr)) for bit=0, log(1+exp(llr)) for bit=1
        # This calculation is standard for LLRs but applying it directly to raw
        # channel LLRs without polar decoding updates is the issue in _scl_decode.
        return np.log1p(np.exp(-llr)) if bit == 0 else np.log1p(np.exp(llr))

    def crc_check(self, info_bits_with_crc):
        """
        Checks CRC on info_bits_with_crc (expects last crc_length bits are CRC).
        Returns True if CRC matches, False otherwise.
        """
        if self.crc_poly is None or self.crc_length == 0:
            return True
        poly = np.array(self.crc_poly, dtype=int)
        data = np.array(info_bits_with_crc, dtype=int) # Input is info+CRC bits
        # Check if data length is at least crc_length
        if len(data) < self.crc_length:
            logging.error(f"CRC check failed: input data length ({len(data)}) is less than CRC length ({self.crc_length}).")
            return False # Cannot perform CRC check
        payload = data[:-self.crc_length]
        received_crc = data[-self.crc_length:]
        calc_crc = self.compute_crc(payload, poly)
        return np.array_equal(calc_crc, received_crc)

    @staticmethod
    def compute_crc(data, polynomial):
        """
        Computes CRC bits (classic binary modulo-2 division).
        Args:
            data (np.array): Info bits (payload), shape (payload length,)
            polynomial (array-like): Polynomial, e.g., [1,0,1,1]
        Returns:
            np.array: CRC bits, length len(polynomial)-1
        """
        # This static method is used by both the generator (in encode) and the decoder (in crc_check)
        # Ensure it's consistent with the global compute_crc if that's also used.
        data = np.array(data, dtype=int)
        polynomial = np.array(polynomial, dtype=int)
        crc_length = len(polynomial) - 1
        if crc_length == 0:
             return np.array([], dtype=int) # Handle case of polynomial degree 0 (shouldn't happen for CRC)

        # Need to handle potential ValueError if data is empty and crc_length > 0
        if len(data) == 0 and crc_length > 0:
             logging.warning("compute_crc called with empty data but non-zero CRC length. Returning zeros.")
             return np.zeros(crc_length, dtype=int)


        data_padded = np.concatenate([data, np.zeros(crc_length, dtype=int)])
        # The division loop should iterate over the data bits *before* the zero padding.
        # The original loop "range(len(data_padded) - len(polynomial) + 1)" is correct for polynomial division.
        # Let's stick to the original loop logic.
        temp_data = np.copy(data_padded) # Operate on a copy

        for i in range(len(data)): # Iterate over the actual data bits
            if temp_data[i] == 1:
                temp_data[i:i+len(polynomial)] ^= polynomial

        return temp_data[-crc_length:]




############################################################################



###############################################################################
# latest Performance comparison
def performance_comparison(
        rnn_trainer,
        polar_code_gen,
        snr_range,
        channel_type,
        list_sizes,
        num_trials,
    ):
    device = next(rnn_trainer.model.parameters()).device
    num_info_bits = polar_code_gen.KwCRC  # Must use total info+CRC for error computation!
    rnn_results = {'BER_RNN': [], 'BLER_RNN': []}
    sc_results = {'BER_SC': [], 'BLER_SC': []}
    scl_results = {L: [] for L in list_sizes}

    for snr_db in snr_range:
        # === Generate new data for THIS SNR ===
        X_eval, y_eval = prepare_polar_dataset(
            polar_code_gen, num_samples=num_trials, snr_db=snr_db, channel_type=channel_type
        )
        X_tensor = torch.FloatTensor(X_eval).to(device)
        y_tensor = torch.FloatTensor(y_eval).to(device)

        # --- RNN Decoder ---
        ber_rnn, bler_rnn = rnn_trainer.evaluate(X_tensor, y_tensor)

        rnn_results['BER_RNN'].append(ber_rnn)
        rnn_results['BLER_RNN'].append(bler_rnn)

        # --- SC Decoder (list_size=1) ---
        sc_decoder = PolarCodeDecoder(
        N=BLOCK_LENGTH,
        K=num_info_bits,  # Should be .KwCRC
        list_size=1,
        crc_poly=crc_poly,
        info_set=polar_code_gen.info_set,
        frozen_set=polar_code_gen.frozen_set
        )
#And for SCL:
       # sc_decoder = PolarCodeDecoder(polar_code_gen.N, num_info_bits, list_size=1, crc_poly=polar_code_gen.crc_poly)
        bit_errors, blk_errors = 0, 0
        for i in range(num_trials):
            llr = X_eval[i]          # <--- This LLR was generated at THIS SNR
            true_bits = y_eval[i]
            decoded_bits = sc_decoder.decode(llr)
            bit_errors += np.sum(decoded_bits != true_bits)
            blk_errors += int(np.any(decoded_bits != true_bits))
        sc_results['BER_SC'].append(bit_errors / (num_trials * num_info_bits))
        sc_results['BLER_SC'].append(blk_errors / num_trials)

        # --- SCL Decoders ---
        for L in list_sizes:
            scl_decoder = PolarCodeDecoder(
            N=BLOCK_LENGTH,
            K=num_info_bits,
            list_size=L,
            crc_poly=crc_poly,
            info_set=polar_code_gen.info_set,
            frozen_set=polar_code_gen.frozen_set
            )
          #  scl_decoder = PolarCodeDecoder(polar_code_gen.N, num_info_bits, list_size=L, crc_poly=polar_code_gen.crc_poly)
            bit_errors, blk_errors = 0, 0
            for i in range(num_trials):
                llr = X_eval[i]
                true_bits = y_eval[i]
                decoded_bits = scl_decoder.decode(llr)
                bit_errors += np.sum(decoded_bits != true_bits)
                blk_errors += int(np.any(decoded_bits != true_bits))
            scl_results[L].append({'BER': bit_errors / (num_trials * num_info_bits), 'BLER': blk_errors / num_trials})

    return rnn_results, scl_results, sc_results
########################################################################




############################################################################
#part 6 Plotting fuctions

def plot_ber_bler_comparison(snr_range, rnn_results, scl_results, sc_results, list_sizes):
    plt.figure(figsize=(18, 6))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('Bit Error Rate (BER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BLER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BLER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BLER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('Block Error Rate (BLER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()





# for RNN and SCL


def plot_rnn_ber_bler(snr_range, rnn_results):
    plt.figure(figsize=(12, 5))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BER_RNN'], marker='o', label='RNN BER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('RNN Bit Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BLER_RNN'], marker='o', label='RNN BLER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('RNN Block Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    plt.tight_layout()
    plt.show()
#2. SCL: Specialized BER/BLER Plot (for a single list size at a time)


def plot_sc_ber_bler(snr_range, sc_results):
    plt.figure(figsize=(12, 5))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, sc_results['BER_SC'], marker='s', label='SC BER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('SC Bit Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, sc_results['BLER_SC'], marker='s', label='SC BLER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('SC Block Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    plt.tight_layout()
    plt.show()

#SCL Plotting Function (ALL list sizes in ONE figure)

def plot_scl_ber_bler(snr_range, scl_results, list_sizes):
    plt.figure(figsize=(12, 5))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    for L in list_sizes:
        ber_list = [res['BER'] for res in scl_results[L]]
        plt.plot(snr_range, ber_list, marker='^', label=f'SCL L={L}')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('SCL Bit Error Rate (All List Sizes)')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    for L in list_sizes:
        bler_list = [res['BLER'] for res in scl_results[L]]
        plt.plot(snr_range, bler_list, marker='^', label=f'SCL L={L}')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('SCL Block Error Rate (All List Sizes)')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    plt.tight_layout()
    plt.show()


######################################################################
def plot_training_validation(train_losses, val_losses):
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label='Training Loss')
    if val_losses:
        plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


def plot_confusion_matrix(y_true, y_pred, title='Confusion Matrix'):
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    plt.title(title)
    plt.show()

def plot_ber_bler_comparison(snr_range, rnn_results, scl_results, sc_results, list_sizes):
    plt.figure(figsize=(18, 6))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('Bit Error Rate (BER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BLER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BLER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BLER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('Block Error Rate (BLER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()
#Part 7: Main Function
###############################################################################

#latest main()
#latest main() on 06/12/2025
def main():

    try:
        # --- Config ---
        BLOCK_LENGTH = 128
        INFO_BITS = 64
        CRC_16_CCITT_POLY = np.array([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int)
        LEARNING_RATE = 1e-3
        EPOCHS = 50
        BATCH_SIZE = 64
        NUM_SAMPLES=10000
        NUM_SAMPLES_TRAIN = 10000
        NUM_TRIALS_PERF = 1000
        SNR_RANGE_AWGN= np.linspace(0.5, 4.5, 13) # Suggested range for performance evaluation
        LIST_SIZES = [1, 8, 16]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- Codec Generator ---
        # Use the corrected instantiation without the crc_poly keyword argument
        polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS)
        KwCRC = polar_code_gen.KwCRC
        info_set = polar_code_gen.info_set
        frozen_set = polar_code_gen.frozen_set

        print("KwCRC =", KwCRC)
        print("info_set[:10] =", info_set[:10])
        print("frozen_set[:10] =", frozen_set[:10])
        # The crc_poly variable needs to be defined locally if you still want to print it or use it later
        # For instance, if you want to pass it to the PolarCodeDecoder constructor, it needs to exist.
        # Let's define the local crc_poly variable here, using the same value as the internal one for consistency.
        # It's crucial that this matches polar_code_gen.crc_poly if used for decoding.
        crc_poly_local = np.array([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], dtype=int)
        print("CRC poly:", crc_poly_local)


        # --- Debugging Encoding and Decoding for One Block (Place inside try) ---
        print("\n--- Debugging Encoding and Decoding for One Block ---")
        try: # Added an internal try block for the debug section itself
             # 1. Generate payload
            payload_debug = polar_code_gen.generate_info_bits()
            print(f"Original payload ({len(payload_debug)} bits): {payload_debug[:10]}...")

            # 2. Encode (includes CRC and polar transform)
            # Temporarily modify encode to capture u_vec
            original_encode = PolarCodeGenerator.encode
            # Define the debug_encode function within this scope or ensure it's accessible
            def debug_encode(self_gen, payload_bits): # Renamed self to self_gen to avoid conflict if debug_encode is inside a method
                 payload_bits = np.array(payload_bits, dtype=int)
                 if len(payload_bits) != self_gen.K:
                     raise ValueError(f"Expected {self_gen.K} payload bits, got {len(payload_bits)}")

                 # Add CRC bits if needed
                 if self_gen.crc_poly is not None:
                     # Use the global compute_crc here
                     crc_bits = compute_crc(payload_bits, self_gen.crc_poly)
                     info_bits_with_crc = np.concatenate([payload_bits, crc_bits])
                 else:
                     info_bits_with_crc = payload_bits

                 if len(info_bits_with_crc) != self_gen.KwCRC:
                      raise RuntimeError(f"CRC appending resulted in {len(info_bits_with_crc)} bits, expected {self_gen.KwCRC}")

                 # Place info bits at most reliable positions, rest frozen (zero)
                 u = np.zeros(self_gen.N, dtype=int)
                 # Place info_bits_with_crc into the u vector at info_set indices
                 u[self_gen.info_set] = info_bits_with_crc
                 # --- Capture u_vec here ---
                 global u_vec_before_encode
                 u_vec_before_encode = u.copy()
                 # ------------------------
                 codeword = self_gen._arikan_transform(u)
                 return codeword, info_bits_with_crc

            PolarCodeGenerator.encode = debug_encode # Replace encode with debug version

            # Call encode to capture u_vec
            codeword_debug, info_bits_with_crc_debug = polar_code_gen.encode(payload_debug)

            # Restore original encode method
            PolarCodeGenerator.encode = original_encode

            print(f"Info bits with CRC ({len(info_bits_with_crc_debug)} bits): {info_bits_with_crc_debug[:10]}...")
            print(f"Encoded codeword ({len(codeword_debug)} bits): {codeword_debug[:10]}...")
            print(f"u vector before encoding ({len(u_vec_before_encode)} bits): {u_vec_before_encode[:10]}...")

            # Verify CRC calculation on the generated info_bits_with_crc
            # Need to handle case where crc_length is 0
            if polar_code_gen.crc_length > 0:
                payload_part = info_bits_with_crc_debug[:-polar_code_gen.crc_length]
                crc_part = info_bits_with_crc_debug[-polar_code_gen.crc_length:]
                recalculated_crc = compute_crc(payload_part, polar_code_gen.crc_poly) # Use the global compute_crc
                print(f"Original CRC part: {crc_part}")
                print(f"Recalculated CRC part: {recalculated_crc}")
                print(f"CRC calculation match: {np.array_equal(crc_part, recalculated_crc)}")
            else:
                 print("No CRC used, skipping CRC verification.")


            # Verify Polar Transform (simple check: run inverse transform)
            decoded_u_vec = polar_code_gen._arikan_transform(codeword_debug)
            print(f"Arikan transform (u vector) after decoding codeword: {decoded_u_vec[:10]}...")
            # Need to compare decoded_u_vec with u_vec_before_encode
            print(f"Arikan transform inverse check: {np.array_equal(u_vec_before_encode, decoded_u_vec)}")


            # 3. Simulate channel at a low SNR (where errors are expected)
            snr_debug = 1.0 # Low SNR for debugging errors
            code_rate_debug = polar_code_gen.KwCRC / polar_code_gen.N
            x_debug = bpsk_modulate(codeword_debug)
            rx_debug, noise_var_debug = add_awgn_noise(x_debug, snr_debug, code_rate_debug)
            llr_debug = make_llr(rx_debug, noise_var_debug)
            print(f"Debugging at SNR={snr_debug} dB. Noise Variance: {noise_var_debug:.4f}")
            print(f"Sample LLRs: {llr_debug[:10]}...")

            # 4. Attempt decoding with SC and SCL (List 8)
            print("\n--- Decoding with SC ---")
            sc_decoder_debug = PolarCodeDecoder(
                 N=BLOCK_LENGTH, K_decoder=KwCRC, list_size=1, crc_poly=crc_poly_local, # Use KwCRC as K_decoder, use local crc_poly
                 info_set=info_set, frozen_set=frozen_set # Use generator's sets
            )
            # Ensure decode returns the full N-length u_hat for consistent comparison
            # OR ensure decode returns KwCRC length and compare accordingly.
            # Based on PolarCodeDecoder.decode, it returns KwCRC length bits.
            decoded_bits_sc = sc_decoder_debug.decode(llr_debug)
            print(f"SC Decoded info+CRC bits ({len(decoded_bits_sc)}): {decoded_bits_sc[:10]}...")
            print(f"SC Decoding match with true info+CRC: {np.array_equal(decoded_bits_sc, info_bits_with_crc_debug)}")
            print(f"SC CRC check passed: {sc_decoder_debug.crc_check(decoded_bits_sc)}") # crc_check operates on info+CRC bits


            print("\n--- Decoding with SCL (List 8) ---")
            scl_decoder_debug = PolarCodeDecoder(
                 N=BLOCK_LENGTH, K_decoder=KwCRC, list_size=8, crc_poly=crc_poly_local, # Use KwCRC as K_decoder, use local crc_poly
                 info_set=info_set, frozen_set=frozen_set # Use generator's sets
            )
            # Ensure decode returns KwCRC length bits.
            decoded_bits_scl = scl_decoder_debug.decode(llr_debug)
            print(f"SCL Decoded info+CRC bits ({len(decoded_bits_scl)}): {decoded_bits_scl[:10]}...")
            print(f"SCL Decoding match with true info+CRC: {np.array_equal(decoded_bits_scl, info_bits_with_crc_debug)}")
            print(f"SCL CRC check passed (on returned bits): {scl_decoder_debug.crc_check(decoded_bits_scl)}") # crc_check operates on info+CRC bits

        except Exception as e:
             print(f"Debugging Block Error: {e}")
             traceback.print_exc()
        print("\n--- End Debugging Encoding and Decoding ---")


        # --- Prepare dataset for RNN training ---
        X_raw, y_raw = prepare_polar_dataset(polar_code_gen, NUM_SAMPLES_TRAIN, snr_db=2.5)
        X_tensor = torch.FloatTensor(X_raw).to(device)
        y_tensor = torch.FloatTensor(y_raw).to(device) # y_tensor holds info_bits_with_crc
        train_size = int(0.8 * X_tensor.shape[0])
        train_X = X_tensor[:train_size]
        train_y = y_tensor[:train_size]
        val_X = X_tensor[train_size:]
        val_y = y_tensor[train_size:]

        # --- RNN Model/Trainer ---
        # Output size should be KwCRC
        rnn_model = EnhancedRNNDecoder(BLOCK_LENGTH, KwCRC).to(device)
        rnn_trainer = DecoderTrainer(rnn_model, LEARNING_RATE)

        # --- Train the RNN ---
        # Pass generator to trainer for dynamic re-noising
        train_losses, val_losses = rnn_trainer.train(
            train_X, train_y,
            val_X=val_X, val_y=val_y,
            epochs=EPOCHS, batch_size=BATCH_SIZE,
            snr_min=1.0, snr_max=7.0, # Dynamic SNR range
            generator=polar_code_gen # Pass the generator
        )

        # --- SNR BER/BLER Sweep and Plotting ---
        rnn_results = {'BER_RNN': [], 'BLER_RNN': []}
        sc_results = {'BER_SC': [], 'BLER_SC': []}
        scl_results = {L: [] for L in LIST_SIZES}

        for snr_db in SNR_RANGE_AWGN:
            print(f"Evaluating at SNR: {snr_db} dB") # Debug print
            # === Generate new data for THIS SNR ===
            X_eval, y_eval = prepare_polar_dataset(
                polar_code_gen, num_samples=NUM_TRIALS_PERF, snr_db=snr_db, channel_type='AWGN'
            )
            X_tensor = torch.FloatTensor(X_eval).to(device)
            y_tensor = torch.FloatTensor(y_eval).to(device) # True info+CRC bits

            # --- RNN Eval ---
            # rnn_trainer.evaluate expects LLRs (X_tensor) and true info+CRC bits (y_tensor)
            ber_rnn, bler_rnn = rnn_trainer.evaluate(X_tensor, y_tensor)
            rnn_results['BER_RNN'].append(ber_rnn)
            rnn_results['BLER_RNN'].append(bler_rnn)

            # --- SC Eval ---
            # Pass the local crc_poly_local and generator's sets to the decoder constructors
            sc_decoder = PolarCodeDecoder(
                N=BLOCK_LENGTH, K_decoder=KwCRC, list_size=1, crc_poly=crc_poly_local,
                info_set=info_set, frozen_set=frozen_set
            )
            bit_errors, blk_errors = 0, 0
            for i in range(NUM_TRIALS_PERF):
                llr = X_eval[i]
                true_bits = y_eval[i].int().cpu().numpy() # Get true info+CRC bits as numpy int
                decoded_bits = sc_decoder.decode(llr) # Should return info+CRC bits (length KwCRC)

                # Check lengths match for comparison
                if len(decoded_bits) != len(true_bits):
                     logging.error(f"SC Decoder output length mismatch for sample {i}: Expected {len(true_bits)}, got {len(decoded_bits)}")
                     # Skip this sample or handle error appropriately
                     continue

                bit_errors += np.sum(decoded_bits != true_bits)
                blk_errors += int(np.any(decoded_bits != true_bits))

            # Ensure division by total possible bits (num_trials * KwCRC)
            total_bits_sc = NUM_TRIALS_PERF * KwCRC
            sc_results['BER_SC'].append(bit_errors / total_bits_sc if total_bits_sc > 0 else 0.0)
            sc_results['BLER_SC'].append(blk_errors / NUM_TRIALS_PERF if NUM_TRIALS_PERF > 0 else 0.0)


            # --- SCL Eval ---
            for L in LIST_SIZES:
                scl_decoder = PolarCodeDecoder(
                    N=BLOCK_LENGTH,
                    K=KwCRC, # >>>>>> Use K=KwCRC here <<<<<<
                    list_size=L,
                    crc_poly=crc_poly_local,
                    info_set=info_set,
                    frozen_set=frozen_set
                )
                bit_errors, blk_errors = 0, 0
                for i in range(NUM_TRIALS_PERF):
                    llr = X_eval[i]
                    true_bits = y_eval[i].int().cpu().numpy() # Get true info+CRC bits as numpy int
                    decoded_bits = scl_decoder.decode(llr) # Should return info+CRC bits (length KwCRC)

                    # Check lengths match for comparison
                    if len(decoded_bits) != len(true_bits):
                        logging.error(f"SCL Decoder output length mismatch for sample {i} (List {L}): Expected {len(true_bits)}, got {len(decoded_bits)}")
                        # Skip this sample or handle error appropriately
                        continue


                    bit_errors += np.sum(decoded_bits != true_bits)
                    blk_errors += int(np.any(decoded_bits != true_bits))

                # Ensure division by total possible bits (num_trials * KwCRC)
                total_bits_scl = NUM_TRIALS_PERF * KwCRC
                scl_results[L].append({'BER': bit_errors / total_bits_scl if total_bits_scl > 0 else 0.0, 'BLER': blk_errors / NUM_TRIALS_PERF if NUM_TRIALS_PERF > 0 else 0.0})

        # --- Plots ---
        plot_training_validation(train_losses, val_losses)
        plot_rnn_ber_bler(SNR_RANGE_AWGN, rnn_results)
        plot_sc_ber_bler(SNR_RANGE_AWGN, sc_results)
        plot_scl_ber_bler(SNR_RANGE_AWGN, scl_results, LIST_SIZES)
        plot_ber_bler_comparison(SNR_RANGE_AWGN, rnn_results, scl_results, sc_results, LIST_SIZES)


        # Confusion Matrix Plot (example on validation set)
        # Need to ensure val_X and val_y are available from the training setup
        val_X_cpu = val_X.cpu()
        val_y_cpu = val_y.cpu() # True info+CRC bits on CPU

        with torch.no_grad():
            # Ensure model is on CPU for numpy conversion if needed, or handle device within plot_confusion_matrix
            rnn_model_cpu = rnn_model.cpu()
            rnn_output_prob_example = rnn_model_cpu(val_X_cpu).numpy()
            rnn_output_example = (rnn_output_prob_example > 0.5).astype(int)
            y_true_example = val_y_cpu.int().numpy() # y_true needs to be int for confusion_matrix

        # Ensure all necessary plotting functions are defined
        plot_confusion_matrix(
            y_true_example.flatten(),
            rnn_output_example.flatten(),
            title="RNN Confusion Matrix"
        )

    except Exception as e:
        print(f"Simulation Error: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    main()


##############################################################################



















Sanity check: LLR magnitudes at different SNRs
KwCRC (encode): 80
Encoder info set: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
Encoder frozen set: [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]
Decoder info set: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
Decoder frozen set: [80, 81, 82, 83

Traceback (most recent call last):
  File "<ipython-input-8-1814983089>", line 1481, in main
    sc_decoder_debug = PolarCodeDecoder(
                       ^^^^^^^^^^^^^^^^^
TypeError: PolarCodeDecoder.__init__() got an unexpected keyword argument 'K_decoder'


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.1978, LLR[0]=5.3249
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.4826, LLR[0]=6.5912
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.3659, LLR[0]=6.0723
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.0711, LLR[0]=4.7619
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.4373, LLR[0]=6.3900
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=-0.4391, LLR[0]=-1.9522
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.1137, LLR[0]=4.9514
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=0.4218, LLR[0]=1.8751
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=2.0672, LLR[0]=9.1901
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=0.8607, LLR[0]=3.8266
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.3282, LLR[0]=5.9047
DEBUG SNR=2.5: noise_var=0.4499, example received[0]=1.6990, LLR[0]=7.5533
DEBUG SNR=2.5: noise_var=0.4499, 

Traceback (most recent call last):
  File "<ipython-input-8-1814983089>", line 1558, in main
    sc_decoder = PolarCodeDecoder(
                 ^^^^^^^^^^^^^^^^^
TypeError: PolarCodeDecoder.__init__() got an unexpected keyword argument 'K_decoder'
