<a href="https://colab.research.google.com/github/kumuds4/BCH/blob/master/Making_the_Most_of_your_Colab_Subscription.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Making the Most of your Colab Subscription



## Faster GPUs

Users who have purchased one of Colab's paid plans have access to faster GPUs and more memory. You can upgrade your notebook's GPU settings in `Runtime > Change runtime type` in the menu to select from several accelerator options, subject to availability.

The free of charge version of Colab grants access to Nvidia's T4 GPUs subject to quota restrictions and availability.

You can see what GPU you've been assigned at any time by executing the following cell. If the execution result of running the code cell below is "Not connected to a GPU", you can change the runtime by going to `Runtime > Change runtime type` in the menu to enable a GPU accelerator, and then re-execute the code cell.


In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Jun 12 07:33:04 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   35C    P0             52W /  400W |     631MiB /  40960MiB |      1%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In order to use a GPU with your notebook, select the `Runtime > Change runtime type` menu, and then set the hardware accelerator to the desired option.

## More memory

Users who have purchased one of Colab's paid plans have access to high-memory VMs when they are available. More powerful GPUs are always offered with high-memory VMs.



You can see how much memory you have available at any time by running the following code cell. If the execution result of running the code cell below is "Not using a high-RAM runtime", then you can enable a high-RAM runtime via `Runtime > Change runtime type` in the menu. Then select High-RAM in the Runtime shape toggle button. After, re-execute the code cell.


In [3]:
import psutil

ram_gb = psutil.virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 89.6 gigabytes of available RAM

You are using a high-RAM runtime!


## Longer runtimes

All Colab runtimes are reset after some period of time (which is faster if the runtime isn't executing code). Colab Pro and Pro+ users have access to longer runtimes than those who use Colab free of charge.

## Background execution

Colab Pro+ users have access to background execution, where notebooks will continue executing even after you've closed a browser tab. This is always enabled in Pro+ runtimes as long as you have compute units available.



## Relaxing resource limits in Colab Pro

Your resources are not unlimited in Colab. To make the most of Colab, avoid using resources when you don't need them. For example, only use a GPU when required and close Colab tabs when finished.



If you encounter limitations, you can relax those limitations by purchasing more compute units via Pay As You Go. Anyone can purchase compute units via [Pay As You Go](https://colab.research.google.com/signup); no subscription is required.

## Send us feedback!

If you have any feedback for us, please let us know. The best way to send feedback is by using the Help > 'Send feedback...' menu. If you encounter usage limits in Colab Pro consider subscribing to Pro+.

If you encounter errors or other issues with billing (payments) for Colab Pro, Pro+, or Pay As You Go, please email [colab-billing@google.com](mailto:colab-billing@google.com).

## More Resources

### Working with Notebooks in Colab
- [Overview of Colab](/notebooks/basic_features_overview.ipynb)
- [Guide to Markdown](/notebooks/markdown_guide.ipynb)
- [Importing libraries and installing dependencies](/notebooks/snippets/importing_libraries.ipynb)
- [Saving and loading notebooks in GitHub](https://colab.research.google.com/github/googlecolab/colabtools/blob/main/notebooks/colab-github-demo.ipynb)
- [Interactive forms](/notebooks/forms.ipynb)
- [Interactive widgets](/notebooks/widgets.ipynb)

<a name="working-with-data"></a>
### Working with Data
- [Loading data: Drive, Sheets, and Google Cloud Storage](/notebooks/io.ipynb)
- [Charts: visualizing data](/notebooks/charts.ipynb)
- [Getting started with BigQuery](/notebooks/bigquery.ipynb)

### Machine Learning Crash Course
These are a few of the notebooks from Google's online Machine Learning course. See the [full course website](https://developers.google.com/machine-learning/crash-course/) for more.
- [Intro to Pandas DataFrame](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/pandas_dataframe_ultraquick_tutorial.ipynb)
- [Linear regression with tf.keras using synthetic data](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/linear_regression_with_synthetic_data.ipynb)


<a name="using-accelerated-hardware"></a>
### Using Accelerated Hardware
- [TensorFlow with GPUs](/notebooks/gpu.ipynb)
- [TPUs in Colab](/notebooks/tpu.ipynb)

<a name="machine-learning-examples"></a>

## Machine Learning Examples

To see end-to-end examples of the interactive machine learning analyses that Colab makes possible, check out these tutorials using models from [TensorFlow Hub](https://tfhub.dev).

A few featured examples:

- [Retraining an Image Classifier](https://tensorflow.org/hub/tutorials/tf2_image_retraining): Build a Keras model on top of a pre-trained image classifier to distinguish flowers.
- [Text Classification](https://tensorflow.org/hub/tutorials/tf2_text_classification): Classify IMDB movie reviews as either *positive* or *negative*.
- [Style Transfer](https://tensorflow.org/hub/tutorials/tf2_arbitrary_image_stylization): Use deep learning to transfer style between images.
- [Multilingual Universal Sentence Encoder Q&A](https://tensorflow.org/hub/tutorials/retrieval_with_tf_hub_universal_encoder_qa): Use a machine learning model to answer questions from the SQuAD dataset.
- [Video Interpolation](https://tensorflow.org/hub/tutorials/tweening_conv3d): Predict what happened in a video between the first and the last frame.


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
from google.colab import files
uploaded = files.upload()

KeyboardInterrupt: 

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import logging
import pandas as pd
import traceback

import matplotlib.pyplot as plt
from matplotlib.ticker import LogFormatterMathtext
from matplotlib.ticker import LogFormatterMathtext
# Configure logging
logging.basicConfig(level=logging.INFO)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration parameters
BLOCK_LENGTH = 128
INFO_BITS = 64
LEARNING_RATE = 1e-3
EPOCHS = 50
BATCH_SIZE = 64
NUM_SAMPLES_TRAIN = 10000
NUM_TRIALS_PERF = 1000
SNR_RANGE_AWGN = np.linspace(0, 5,10)
#snr_range = [0, 5, 10]
LIST_SIZES = [1, 8, 16]
snr_db = 5.0     # <----- You can name this however you like!
crc_poly = [1, 0, 0, 0, 1, 0, 0, 1]
#generator = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc_poly)
#G = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc7_poly)
#DATASET_SNR_DB = 5.0 # Added configuration for dataset SNR


#part 2

def compute_crc(data, polynomial):
    poly_len = len(polynomial)
    crc_length = poly_len - 1
    data_with_zeros = np.concatenate((data, np.zeros(crc_length, dtype=int)))
    remainder = np.copy(data_with_zeros)
    for i in range(len(remainder) - poly_len + 1):
        if remainder[i] == 1:
            remainder[i:i + poly_len] ^= polynomial
    return remainder[-crc_length:]
###################################################################################
#Latest Polarcode generator
#Latest Polarcode generator
class PolarCodeGenerator:
    """
    Polar Code Generator for N=128 (with CRC option, reliability sequence for info/frozen placement).
    """
    # Reliability order for N=128 (first is most reliable, last is least reliable)
    RELIABILITY_SEQUENCE_128 = [
        0, 1, 2, 4, 8, 16, 3, 5, 6, 9, 10, 12, 17, 18, 20, 24,
        32, 7, 11, 13, 14, 19, 21, 22, 25, 26, 28, 33, 34, 36, 38, 40,
        15, 23, 27, 29, 30, 35, 37, 39, 41, 42, 44, 48, 31, 43, 45, 46,
        49, 50, 52, 56, 47, 51, 53, 54, 57, 58, 60, 62, 63, 55, 59, 61,
        64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
        80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
        96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127
    ]

    def __init__(self, N=128, K=64, crc_poly=None):
        """
        Args:
            N (int): Codeword/block length (should match length of RELIABILITY_SEQUENCE_128)
            K (int): Number of payload (info) bits (without CRC)
            crc_poly (list, optional): CRC polynomial as a binary list, e.g. [1,0,0,0,1,0,0,1] for CRC-7
        """
        self.N = N
        self.K = K
        self.crc_poly = np.array(crc_poly, dtype=int) if crc_poly is not None else None

        # Compute CRC length, total info bits, and info/frozen positions
        self.crc_length = 0 if self.crc_poly is None else (len(self.crc_poly) - 1)
        self.KwCRC = K + self.crc_length  # Total info bits (with CRC if used)

        # Ensure RELIABILITY_SEQUENCE_128 is large enough for N
        if N > len(self.RELIABILITY_SEQUENCE_128):
             raise ValueError(f"Reliability sequence only defined for N=128. Got N={N}")
        if self.KwCRC > N:
             raise ValueError(f"K + CRC length ({self.KwCRC}) cannot be greater than N ({N})")

        # Info bits are placed at the self.KwCRC most reliable positions
        self.info_set = sorted(self.RELIABILITY_SEQUENCE_128[:self.KwCRC])
        self.frozen_set = sorted(set(range(N)) - set(self.info_set))

        logging.info(f"Generator initialized: N={self.N}, K={self.K}, CRC Length={self.crc_length}, KwCRC={self.KwCRC}")
        # logging.info(f"Info set indices: {self.info_set}")
        # logging.info(f"Frozen set indices: {self.frozen_set}")


    def encode(self, payload_bits):
        """
        Args:
            payload_bits (array-like): Length K (payload, i.e. without CRC).
        Returns:
            codeword (np.array): Encoded polar codeword (length N)
            info_bits_with_crc (np.array): Full info vector (payload + CRC if used, length KwCRC)
        """
        payload_bits = np.array(payload_bits, dtype=int)
        if len(payload_bits) != self.K:
            raise ValueError(f"Expected {self.K} payload bits, got {len(payload_bits)}")

        # Add CRC bits if needed
        if self.crc_poly is not None:
            # Use the global compute_crc or define it within the class if preferred
            crc_bits = compute_crc(payload_bits, self.crc_poly)
            info_bits_with_crc = np.concatenate([payload_bits, crc_bits])
        else:
            info_bits_with_crc = payload_bits

        if len(info_bits_with_crc) != self.KwCRC:
             raise RuntimeError(f"CRC appending resulted in {len(info_bits_with_crc)} bits, expected {self.KwCRC}")


        # Place info bits at most reliable positions, rest frozen (zero)
        u = np.zeros(self.N, dtype=int)
        # Place info_bits_with_crc into the u vector at info_set indices
        u[self.info_set] = info_bits_with_crc
        codeword = self._arikan_transform(u)
        return codeword, info_bits_with_crc

    def _arikan_transform(self, u):
        """
        Fast polar (Arıkan) transform (in-place butterfly).
        """
        N = len(u)
        x = u.copy()
        n = int(np.log2(N))
        for i in range(n):
            step = 2 ** i
            for j in range(0, N, 2*step):
                for k in range(step):
                    x[j+k] ^= x[j+k+step]
        return x

    # Removed redundant compute_crc method here, using the global one


    def generate_info_bits(self):
        # Generates K information bits (payload)
        return np.random.randint(2, size=self.K)

    # Removed _place_info_bits and _polar_encode_recursive, replaced by encode method

    # Removed append_crc here, now handled inside encode method using the global compute_crc


# --- Simulation Functions ---

def bpsk_modulate(bits):
    # Map 0 -> +1, 1 -> -1
    bits = np.array(bits, dtype=int)
    return 1 - 2*bits

def add_awgn_noise(x, snr_db, code_rate):
    """
    Adds AWGN noise. SNR here is Eb/N0, bit energy to noise ratio, in dB.
    Requires code_rate (K_info / N or KwCRC / N). Using KwCRC/N for rate.
    """
    snr_linear_eb = 10**(snr_db/10)
    # Noise variance sigma^2 = N0/2.
    # Eb/N0 = (Es/R) / N0, Es=1 for BPSK
    # N0 = 1 / (R * (Eb/N0))
    # sigma^2 = 1 / (2 * R * (Eb/N0)_linear)
    noise_var = 1 / (2 * code_rate * snr_linear_eb)
    noise = np.sqrt(noise_var) * np.random.randn(*x.shape)
    return x + noise, noise_var

def make_llr(received, noise_var):
    return 2*received/noise_var

def prepare_polar_dataset(gen, num_samples, snr_db, channel_type='AWGN'):
    """
    Prepares a dataset of noisy channel outputs (LLRs) and corresponding
    original information bits (payload + CRC).
    """
    X_data = np.zeros((num_samples, gen.N))
    # Correctly use gen.KwCRC for the size of information bits + CRC
    y_data = np.zeros((num_samples, gen.KwCRC), dtype=int)

    code_rate = gen.KwCRC / gen.N # Use the correct rate including CRC

    for i in range(num_samples):
        # 1. Generate K info bits (payload)
        payload = gen.generate_info_bits()

        # 2. Encode payload (handles CRC appending and polar transform internally)
        codeword, info_bits_with_crc = gen.encode(payload)

        # 3. BPSK modulate
        x = bpsk_modulate(codeword)

        # 4. Add noise and calculate LLRs
        if channel_type == 'AWGN':
            rx, noise_var = add_awgn_noise(x, snr_db, code_rate)
            llr = make_llr(rx, noise_var)
        # Add other channel types here if needed
        else:
            raise ValueError(f"Unsupported channel type: {channel_type}")

        # Store LLRs (received signal) and the original info_bits_with_crc
        X_data[i] = llr
        y_data[i] = info_bits_with_crc # y_data should be the bits we want to decode to

    return X_data, y_data

def save_dataset_to_csv(X, y, filename='dataset.csv'):
    # X shape: (num_samples, N)
    # y shape: (num_samples, KwCRC)
    # Data to save: (num_samples, N + KwCRC)
    data = np.hstack((X, y))
    columns = [f'received_{i}' for i in range(X.shape[1])] + [f'info_bit_{j}' for j in range(y.shape[1])]
    df = pd.DataFrame(data, columns=columns)
    df.to_csv(filename, index=False)
    logging.info(f"Dataset saved to {filename}")
####################################################################################
#part 3
#Rewrite
######################################################

def bpsk_modulate(bits):
    bits = np.array(bits, dtype=int)
    # Map 0 -> -1, 1 -> +1
    return 2 * bits - 1

class EnhancedChannelSimulator:
    def __init__(self, channel_type='AWGN'):
        self.channel_type = channel_type

    def simulate(self, signal, snr_db):
        # signal is assumed to be BPSK modulated (+1 or -1)
        # Calculate noise variance based on SNR_db and signal power (which is 1 for BPSK)
        # SNR_linear = E_s / N0, where E_s = 1 for BPSK
        # Noise variance sigma^2 = N0/2.
        # SNR_linear = 1 / N0 = 1 / (2 * sigma^2)
        # sigma^2 = 1 / (2 * SNR_linear)
        # Need to consider code rate R = K/N if SNR is E_b/N0
        # E_b/N0 = (E_s/R) / N0 = (1/R) * (E_s/N0) = (1/R) * SNR_linear_Es
        # SNR_linear_Es = R * SNR_linear_Eb
        # sigma^2 = 1 / (2 * SNR_linear_Es) = 1 / (2 * R * SNR_linear_Eb)
        # The PolarCodeGenerator stores K and N, but R is needed here.
        # We need the code rate from the generator. Let's pass it or the generator object.

        # Assuming snr_db provided is E_b/N0
        # We need the code rate K/N used for the encoding. Let's assume the generator object is available or pass R.
        # For now, let's assume R is available from a generator instance or a global var if constant.
        # If we simulate *per block*, we need the specific R for that block (though N, K are fixed here).
        # Let's assume R = INFO_BITS / BLOCK_LENGTH is constant.
        R = INFO_BITS / BLOCK_LENGTH # This might need adjustment if K_crc is used for rate calc

        snr_linear_eb = 10 ** (snr_db / 10)
        # Noise variance for E_b/N0
        noise_variance = 1 / (2 * R * snr_linear_eb)
        noise_std = np.sqrt(noise_variance)
        noise = noise_std * np.random.randn(*signal.shape)
        return signal + noise


def bpsk_modulate(bits):
    # Map 0 -> +1, 1 -> -1
    bits = np.array(bits, dtype=int)
    return 1 - 2*bits

def add_awgn_noise(x, snr_db, code_rate):
    """
    Adds AWGN noise. SNR here is Eb/N0, bit energy to noise ratio, in dB.
    Requires code_rate (K_info / N or KwCRC / N). Using KwCRC/N for rate.
    """
    snr_linear_eb = 10**(snr_db/10)
    # Noise variance sigma^2 = N0/2.
    # Eb/N0 = (Es/R) / N0, Es=1 for BPSK
    # N0 = 1 / (R * (Eb/N0))
    # sigma^2 = 1 / (2 * R * (Eb/N0)_linear)
    noise_var = 1 / (2 * code_rate * snr_linear_eb)
    noise = np.sqrt(noise_var) * np.random.randn(*x.shape)
    return x + noise, noise_var

def make_llr(received, noise_var):
    return 2*received/noise_var



def prepare_polar_dataset(gen, num_samples, snr_db, channel_type='AWGN'):
    """
    Prepares a dataset of noisy channel outputs (LLRs) and corresponding
    original information bits (payload + CRC).
    """
    X_data = np.zeros((num_samples, gen.N))
    # Correctly use gen.KwCRC for the size of information bits + CRC
    y_data = np.zeros((num_samples, gen.KwCRC), dtype=int)

    code_rate = gen.KwCRC / gen.N # Use the correct rate including CRC

    for i in range(num_samples):
        # 1. Generate K info bits (payload)
        payload = gen.generate_info_bits()

        # 2. Encode payload (handles CRC appending and polar transform internally)
        codeword, info_bits_with_crc = gen.encode(payload)

        # 3. BPSK modulate
        x = bpsk_modulate(codeword)

        # 4. Add noise and calculate LLRs
        if channel_type == 'AWGN':
            rx, noise_var = add_awgn_noise(x, snr_db, code_rate)
            llr = make_llr(rx, noise_var)
        # Add other channel types here if needed
        else:
            raise ValueError(f"Unsupported channel type: {channel_type}")

        # Store LLRs (received signal) and the original info_bits_with_crc
        X_data[i] = llr
        y_data[i] = info_bits_with_crc # y_data should be the bits we want to decode to

        if num_samples > 0:  # Only bother if dataset not empty
           print(f"DEBUG SNR={snr_db}: noise_var={noise_var:.4f}, example received[0]={rx[0]:.4f}, LLR[0]={llr[0]:.4f}")

    return X_data, y_data


def save_dataset_to_csv(X, y, filename='dataset.csv'):
    # X shape: (num_samples, N)
    # y shape: (num_samples, K_info)
    # Data to save: (num_samples, N + K_info)
    data = np.hstack((X, y))
    columns = [f'received_{i}' for i in range(X.shape[1])] + [f'info_bit_{j}' for j in range(y.shape[1])]
    df = pd.DataFrame(data, columns=columns)
    df.to_csv(filename, index=False)
    logging.info(f"Dataset saved to {filename}")

############################################################################################
#latest part 4
 #----- Model -----
#latest part 4
 #----- Model -----
class EnhancedRNNDecoder(nn.Module):
    def __init__(self, input_size, output_size, hidden_size=128, num_layers=2, dropout=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Dropout is applied after each LSTM (except last)
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Input: [batch, code_length], reshape to [batch, seq=1, code_length]
        x_reshaped = x.unsqueeze(1)  # [B, 1, N]
        # LSTM expects sequence, but ours is just 1 timestep with full codeword
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x_reshaped, (h0, c0))   # [B, 1, H]
        out = self.fc(out[:, -1, :])              # [B, output_size]
        return self.sigmoid(out)

# ----- Trainer -----
#########################################################################################################
#add loop
crc_poly = [1,0,0,0,1,0,0,1]
BLOCK_LENGTH = 128
INFO_BITS = 64
polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc_poly)

print("Sanity check: LLR magnitudes at different SNRs")
code_rate = polar_code_gen.KwCRC / polar_code_gen.N
payload = polar_code_gen.generate_info_bits()
cw, _ = polar_code_gen.encode(payload)
x = bpsk_modulate(cw)
for snr in [0, 3, 6, 10]:
    rx, noise_var = add_awgn_noise(x, snr_db=snr, code_rate=code_rate)
    llr = make_llr(rx, noise_var)
#########################################################################################################3

#latest on 06/11/25
# --- Trainer ---
class DecoderTrainer:
    def __init__(self, model, learning_rate, early_stop_patience=10):
        super().__init__() # No need to call super for this class structure
        self.model = model
        self.criterion = nn.BCELoss()
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        self.early_stop_patience = early_stop_patience
        self.device = next(model.parameters()).device # Get device from model

    def train(self, train_X, train_y, val_X=None, val_y=None, epochs=50, batch_size=128, snr_min=1, snr_max=7, generator=None):
        """
        Trains the RNN decoder with dynamic noisy batch generation.

        Args:
            train_X (torch.Tensor): Training data LLRs (Initial, can be re-noised).
            train_y (torch.Tensor): Training data labels (info_bits_with_crc).
            val_X (torch.Tensor, optional): Validation data LLRs.
            val_y (torch.Tensor, optional): Validation data labels.
            epochs (int): Number of training epochs.
            batch_size (int): Batch size for training.
            snr_min (float): Minimum SNR (dB) for dynamic noise generation.
            snr_max (float): Maximum SNR (dB) for dynamic noise generation.
            generator (PolarCodeGenerator): Polar code generator instance for re-encoding.
        """
        train_losses, val_losses = [], []
        best_val_loss = float('inf')
        patience_counter = 0

        # Ensure labels are float and on the correct device
        train_y = train_y.float().to(self.device)
        if val_y is not None:
            val_y = val_y.float().to(self.device)

        # Ensure input data is on the correct device
        train_X = train_X.to(self.device)
        if val_X is not None:
             val_X = val_X.to(self.device)


        # --- Debugging LLRs (Optional - run before training) ---
        if generator is not None: # Need generator for rate
             code_rate = generator.KwCRC / generator.N
             logging.info("--- LLR Verification Before Training ---")
             # Verify at min, max, and a point in between
             verify_llr_calculation(snr_min, code_rate)
             verify_llr_calculation(snr_max, code_rate)
             verify_llr_calculation((snr_min + snr_max) / 2.0, code_rate)
             logging.info("--- End LLR Verification ---")
        else:
             logging.warning("Cannot perform full LLR verification without a generator.")


        for epoch in range(epochs):
            epoch_loss = 0.0

            # Shuffle training set
            indices = torch.randperm(train_X.shape[0])
            # IMPORTANT: We need the original *info_bits_with_crc* (y_batch) to re-generate noisy data.
            # So, shuffle X and y together.
            train_X_shuffled = train_X[indices]
            train_y_shuffled = train_y[indices]


            self.model.train()
            for i in range(0, len(train_X), batch_size):
                # Get batch of original info_bits_with_crc (y_batch)
                y_batch = train_y_shuffled[i:i+batch_size]
                batch_size_actual = y_batch.size(0) # Handle last batch potentially smaller

                # --- Dynamic SNR and Re-generation for each batch ---
                if generator is not None:
                    batch_snr = np.random.uniform(snr_min, snr_max)
                    code_rate = generator.KwCRC / generator.N

                    # Re-generate noisy LLRs for this batch
                    # We need the info_bits_with_crc (which is y_batch) to create the 'u' vector
                    # and then encode/modulate/noise.
                    info_bits_with_crc_batch_np = y_batch.cpu().int().numpy() # Convert to numpy int bits

                    X_noisy_batch_np = np.zeros((batch_size_actual, generator.N), dtype=np.float32)

                    # This loop re-encodes and adds noise for each sample in the batch
                    # This is the correct but potentially slow way.
                    # Vectorization might be possible if generator.encode can handle batches.
                    # Assuming generator.encode works on single samples:
                    for k in range(batch_size_actual):
                         # Get the info_bits_with_crc for this sample
                         info_bits_with_crc_k = info_bits_with_crc_batch_np[k]

                         # Place info_bits_with_crc into the u vector at info_set indices
                         u_vec = np.zeros(generator.N, dtype=int)
                         # Ensure info_bits_with_crc_k has the expected length
                         if len(info_bits_with_crc_k) != generator.KwCRC:
                             logging.error(f"Mismatch in info_bits_with_crc length: Expected {generator.KwCRC}, got {len(info_bits_with_crc_k)}")
                             # You might need to handle this error or fix data generation
                             continue # Skip this sample

                         try:
                             u_vec[generator.info_set] = info_bits_with_crc_k # Place info+crc bits
                             # Encode u_vec using the generator's internal transform
                             codeword = generator._arikan_transform(u_vec)
                             # BPSK modulate
                             x = bpsk_modulate(codeword)
                             # Add noise and calculate LLRs
                             rx, noise_var = add_awgn_noise(x, batch_snr, code_rate)
                             llr = make_llr(rx, noise_var)
                             X_noisy_batch_np[k] = llr
                         except Exception as e:
                             logging.error(f"Error during re-generation for sample {k}: {e}")
                             traceback.print_exc()
                             # Decide how to handle errors: skip batch, use original, etc.
                             # For now, we'll log and potentially have a zero row if error occurred before assignment


                    # Convert the re-generated batch back to a PyTorch tensor and move to device
                    X_noisy = torch.tensor(X_noisy_batch_np, dtype=torch.float32).to(self.device)

                    # --- Optional: Debugging Re-generated Batch LLRs ---
                    # if i == 0 and epoch % 10 == 0: # Debug first batch every few epochs
                    #      logging.info(f"Epoch {epoch+1}, Batch {i//batch_size + 1}: Re-generated Batch SNR: {batch_snr:.2f}")
                    #      # You can add verification calls here for the re-generated batch LLRs
                    #      # This requires knowing the original transmitted bits, which are encoded from y_batch
                    #      # For debugging, you could manually trace a sample's encoding to get its codeword.
                    #      pass


                else:
                    # If no generator, use the original LLRs for the batch
                    logging.warning("Generator not provided. Training on static initial LLRs.")
                    X_noisy = train_X_shuffled[i:i+batch_size] # Use original LLR batch

                # --- Forward Pass ---
                outputs = self.model(X_noisy)
                loss = self.criterion(outputs, y_batch) # y_batch is already float

                # --- Backward Pass and Optimization ---
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                epoch_loss += loss.item() * batch_size_actual # Accumulate weighted by actual batch size

            train_losses.append(epoch_loss / len(train_X)) # Average over total training samples

            # ----- Validation -----
            if val_X is not None and val_y is not None:
                self.model.eval()
                # Evaluate on validation data at a fixed SNR (e.g., median)
                val_snr = (snr_min + snr_max) / 2.0

                # Re-generate validation batch with noise at val_snr
                if generator is not None:
                     info_bits_with_crc_val_np = val_y.cpu().int().numpy() # Convert to numpy int bits
                     X_noisy_val_batch_np = np.zeros((val_X.size(0), generator.N), dtype=np.float32)
                     code_rate = generator.KwCRC / generator.N

                     for k in range(val_X.size(0)):
                          info_bits_with_crc_k = info_bits_with_crc_val_np[k]
                          u_vec = np.zeros(generator.N, dtype=int)

                          if len(info_bits_with_crc_k) != generator.KwCRC:
                              logging.error(f"Validation: Mismatch in info_bits_with_crc length for sample {k}: Expected {generator.KwCRC}, got {len(info_bits_with_crc_k)}")
                              continue

                          try:
                              u_vec[generator.info_set] = info_bits_with_crc_k
                              codeword = generator._arikan_transform(u_vec)
                              x = bpsk_modulate(codeword)
                              rx, noise_var = add_awgn_noise(x, val_snr, code_rate)
                              llr = make_llr(rx, noise_var)
                              X_noisy_val_batch_np[k] = llr
                          except Exception as e:
                              logging.error(f"Validation: Error during re-generation for sample {k}: {e}")
                              traceback.print_exc()
                              # Handle error

                     X_noisy_val = torch.tensor(X_noisy_val_batch_np, dtype=torch.float32).to(self.device)
                else:
                     X_noisy_val = val_X # Use original validation LLRs

                with torch.no_grad():
                    val_output = self.model(X_noisy_val)
                    val_loss = self.criterion(val_output, val_y).item()
                val_losses.append(val_loss)

                print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f} | Val Loss: {val_loss:.4f}")

                # Early stopping
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    torch.save(self.model.state_dict(), "best_rnn_decoder.pt")
                    logging.info(f"Epoch {epoch+1}: Best validation loss improved, saving model.")
                else:
                    patience_counter += 1
                    if patience_counter >= self.early_stop_patience:
                        print(f"Early stopping activated after {self.early_stop_patience} epochs without improvement.")
                        break
            else:
                # No validation set provided
                print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f}")
                # If no validation, save model based on train loss or periodically
                if len(train_losses) == 1 or train_losses[-1] < train_losses[-2]:
                     torch.save(self.model.state_dict(), "last_rnn_decoder.pt") # Save latest if improving train loss

        # Load the best model weights if early stopping occurred and validation was used
        if patience_counter >= self.early_stop_patience and val_X is not None:
             try:
                 self.model.load_state_dict(torch.load("best_rnn_decoder.pt"))
                 print("Loaded best model based on validation performance.")
             except FileNotFoundError:
                 logging.warning("Best model file not found after early stopping. Using last epoch weights.")
        elif val_X is None:
            # If no validation, you might want to load the last saved model
             try:
                 self.model.load_state_dict(torch.load("last_rnn_decoder.pt"))
                 print("Loaded last saved model (no validation used for early stopping).")
             except FileNotFoundError:
                 logging.warning("Last model file not found. Using final epoch weights.")


        return train_losses, val_losses if len(val_losses)>0 else None



############################################################################

    def evaluate(self, X_test, y_test, threshold=0.5):
        """
        Evaluates the decoder performance (BER and BLER).
        Assumes X_test is LLRs (for the codeword) and y_test are original info_bits_with_crc (labels).
        """
        self.model.eval()
        # Ensure y_test is int for comparison and on the correct device
        y_test_int = y_test.int().to(self.device) # These are the true info_bits_with_crc
        X_test = X_test.to(self.device) # These are the LLRs for the test codewords

        with torch.no_grad():
            outputs = self.model(X_test) # Model predicts KwCRC bits (probabilities)

            # Threshold probabilities to get hard bit predictions for KwCRC bits
            preds_kwcrc = (outputs > threshold).int() # Predicted info_bits_with_crc

            # --- Calculate BER ---
            # Compare predicted KwCRC bits with true KwCRC bits (y_test_int)
            total_bits_kwcrc = y_test_int.numel()
            bit_errors_kwcrc = torch.sum(preds_kwcrc != y_test_int).item()
            ber_kwcrc = bit_errors_kwcrc / total_bits_kwcrc if total_bits_kwcrc > 0 else 0.0

            # --- Calculate BLER ---
            # A block error occurs if ANY bit in the KwCRC block is incorrect
            block_error_flags_kwcrc = torch.any(preds_kwcrc != y_test_int, dim=1)
            block_errors_kwcrc = torch.sum(block_error_flags_kwcrc).item()
            total_blocks = X_test.size(0)
            bler_kwcrc = block_errors_kwcrc / total_blocks if total_blocks > 0 else 0.0

            # --- Optional: Calculate BER/BLER for Payload Bits (K) ---
            # This requires knowing which indices in KwCRC correspond to payload vs CRC.
            # Assuming payload bits are the first K bits of KwCRC:
            # This assumption might be wrong depending on how CRC was appended.
            # If CRC is appended *after* payload:
            # true_payload_bits = y_test_int[:, :self.model.output_size - generator.crc_length] # Needs generator access
            # predicted_payload_bits = preds_kwcrc[:, :self.model.output_size - generator.crc_length] # Needs generator access

            # A safer approach is to define K_payload explicitly and slice y_test_int and preds_kwcrc
            # Assuming you have K_payload available (which is your INFO_BITS):
            K_payload = INFO_BITS # Use the global or pass it

            # Ensure slicing is valid
            if K_payload > preds_kwcrc.size(1):
                 logging.error(f"K_payload ({K_payload}) is larger than predicted KwCRC size ({preds_kwcrc.size(1)}). Cannot calculate payload BER/BLER.")
                 ber_payload = np.nan
                 bler_payload = np.nan
            else:
                 true_payload_bits = y_test_int[:, :K_payload]
                 predicted_payload_bits = preds_kwcrc[:, :K_payload]

                 total_payload_bits = true_payload_bits.numel()
                 bit_errors_payload = torch.sum(predicted_payload_bits != true_payload_bits).item()
                 ber_payload = bit_errors_payload / total_payload_bits if total_payload_bits > 0 else 0.0

                 # A block error occurs if ANY bit in the K payload block is incorrect
                 block_error_flags_payload = torch.any(predicted_payload_bits != true_payload_bits, dim=1)
                 block_errors_payload = torch.sum(block_error_flags_payload).item()
                 bler_payload = block_errors_payload / total_blocks if total_blocks > 0 else 0.0


        # Decide which BER/BLER to return based on your desired metric
        # Returning BER/BLER for the full KwCRC block might be the intended behavior
        # if the RNN is trained on the full info+crc vector.
        # If you only care about payload, return ber_payload, bler_payload.

        # Let's return both for clarity during debugging
        return ber_kwcrc, bler_kwcrc, ber_payload, bler_payload

# --- How to use the updated evaluate ---
# ber_kwcrc, bler_kwcrc, ber_payload, bler_payload = trainer.evaluate(test_X, test_y)
# print(f"Test BER (KwCRC): {ber_kwcrc:.6f}, BLER (KwCRC): {bler_kwcrc:.6f}")
# print(f"Test BER (Payload K): {ber_payload:.6f}, BLER (Payload K): {bler_payload:.6f}")

# --- Update your plotting code to use the correct BER/BLER ---
# When plotting, use ber_payload and bler_payload if your performance
# metric is specifically for the original K info bits.
# Use ber_kwcrc and bler_kwcrc if your metric includes the CRC bits.




###########################################################################

    #    return ber, bler

# Removed KwCRC definition here, it's now in the enerator

 ############################################################################################
# --- Example Usage (Assuming other parts of your notebook are run) ---

# Assuming you have a PolarCodeGenerator instance, e.g.:
# crc7_poly = [1, 0, 0, 0, 1, 0, 0, 1] # Example CRC-7
# generator = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc7_poly)

# Assuming you have generated initial training and validation data:
# # Define a fixed SNR for initial dataset generation
# DATASET_SNR_DB = 5.0 # You can choose an appropriate SNR
# num_train_samples = 10000
# num_val_samples = 2000 # Optional validation set

# train_X_np, train_y_np = prepare_polar_dataset(generator, num_train_samples, DATASET_SNR_DB)
# train_X = torch.FloatTensor(train_X_np)
# train_y = torch.FloatTensor(train_y_np) # Keep labels as float for training

# If using validation:
# val_X_np, val_y_np = prepare_polar_dataset(generator, num_val_samples, DATASET_SNR_DB)
# val_X = torch.FloatTensor(val_X_np)
# val_y = torch.FloatTensor(val_y_np)

# Assuming you have initialized the RNN model:
# input_size = generator.N # LLRs length
# output_size = generator.KwCRC # Info bits + CRC length
# rnn_decoder = EnhancedRNNDecoder(input_size, output_size).to(device)

# Initialize the Trainer:
# learning_rate = 1e-3
# trainer = DecoderTrainer(rnn_decoder, learning_rate=learning_rate, early_stop_patience=10)

# Start Training:
# snr_train_min = 1 # Min SNR for dynamic training noise
# snr_train_max = 7 # Max SNR for dynamic training noise

# train_losses, val_losses = trainer.train(
#     train_X, train_y,
#     val_X=val_X, val_y=val_y, # Pass validation data if available
#     epochs=EPOCHS,
#     batch_size=BATCH_SIZE,
#     snr_min=snr_train_min,
#     snr_max=snr_train_max,
#     generator=generator # Pass the


# Part 5
############################################################################
#Rewrite
class PolarCodeDecoder:
    """
    Polar Code Decoder supporting both SC and SCL (list) decoding with optional CRC checking.
    """
    def __init__(self, N, K, list_size=1, crc_poly=None):
        """
        Args:
            N (int): Block length (codeword length)
            K (int): Number of information bits (including CRC bits if any)
            list_size (int): Number of SCL paths (set to 1 for regular SC)
            crc_poly (list or np.array, optional): Binary coefficients of CRC polynomial, e.g. [1,0,1,1]
        """
        self.N = N
        self.K = K
        self.list_size = list_size
        self.crc_poly = crc_poly
        # By default: frozen set is the last N-K bits (can be changed for specific design)
        self.frozen_set = set(range(self.K, self.N))
        self.info_set = sorted(set(range(self.N)) - self.frozen_set)
        # CRC properties
        if crc_poly is not None:
            self.crc_length = len(crc_poly) - 1
        else:
            self.crc_length = 0

    def encode_with_crc(self, payload_bits):
        """
        Attach CRC to the payload_bits if CRC polynomial is specified.
        Args:
            payload_bits (array-like): Length = K - crc_length
        Returns:
            code_bits (np.array): Length = K (payload + CRC)
        """
        payload_bits = np.array(payload_bits, dtype=int)
        if self.crc_poly is not None:
            crc_bits = self.compute_crc(payload_bits, self.crc_poly)
            code_bits = np.concatenate([payload_bits, crc_bits])
            return code_bits
        else:
            return payload_bits

    def decode(self, llr):
        """
        Decodes the input LLRs using SC or SCL decoder.
        Args:
            llr (np.array): Array of length N (codeword LLRs)
        Returns:
            info_bits (np.array): Decoded information bits (including CRC bits if present)
        """
        if self.list_size == 1:
            # SC decoder
            u_hat = self._sc_decode(llr)
            info_bits = u_hat[list(self.info_set)]
            # CRC check if present
            if self.crc_poly is not None:
                if not self.crc_check(info_bits):
                    print('CRC failed in SC!')
            return info_bits
        else:
            # SCL decoder
            info_bits = self._scl_decode(llr)
            return info_bits

    def _sc_decode(self, llr):
        """Basic SC decoding."""
        u_hat = np.zeros(self.N, dtype=int)
        for i in range(self.N):
            if i in self.frozen_set:
                u_hat[i] = 0
            else:
                u_hat[i] = 0 if llr[i] >= 0 else 1
        return u_hat

    def _scl_decode(self, llr):
        """
        Simple SCL decoding. For each unfrozen index, splits the path,
        keeps most likely list_size paths. CRC is checked on valid codewords.
        """
        paths = [([], 0.0)]  # Each entry: (path_bits, path_metric)
        for i in range(self.N):
            new_paths = []
            for bits, metric in paths:
                if i in self.frozen_set:
                    # Only extend with 0
                    new_bits = bits + [0]
                    new_metric = metric + self._metric(llr[i], 0)
                    new_paths.append((new_bits, new_metric))
                else:
                    # Extend with both 0 and 1
                    for bit in [0, 1]:
                        ext_bits = bits + [bit]
                        ext_metric = metric + self._metric(llr[i], bit)
                        new_paths.append((ext_bits, ext_metric))
            # Prune paths: keep only best list_size
            new_paths.sort(key=lambda x: x[1])
            paths = new_paths[:self.list_size]

        # At the end: check CRC if needed
        info_indices = list(self.info_set)
        if self.crc_poly is not None:
            for bits, metric in paths:
                info_bits = np.array([bits[i] for i in info_indices], dtype=int)
                if self.crc_check(info_bits):
                    return info_bits
            # No CRC-valid path found; fall back to best metric path
            print("Warning: No CRC-valid candidate; using lowest metric path.")
        # Return info bits from best path
        best_bits = paths[0][0]
        return np.array([best_bits[i] for i in info_indices], dtype=int)

    @staticmethod
    def _metric(llr, bit):
        """Path metric increment for bit decision 0/1 given LLR."""
        # log(1+exp(-llr)) for bit=0, log(1+exp(llr)) for bit=1
        return np.log1p(np.exp(-llr)) if bit == 0 else np.log1p(np.exp(llr))

    def crc_check(self, info_bits):
        """
        Checks CRC on info_bits (expects last crc_length bits are CRC).
        Returns True if CRC matches, False otherwise.
        """
        if self.crc_poly is None or self.crc_length == 0:
            return True
        poly = np.array(self.crc_poly, dtype=int)
        data = np.array(info_bits, dtype=int)
        payload = data[:-self.crc_length]
        crc_bits = data[-self.crc_length:]
        calc_crc = self.compute_crc(payload, poly)
        return np.array_equal(calc_crc, crc_bits)

    @staticmethod
    def compute_crc(data, polynomial):
        """
        Computes CRC bits (classic binary modulo-2 division).
        Args:
            data (np.array): Info bits, shape (payload length,)
            polynomial (array-like): Polynomial, e.g., [1,0,1,1]
        Returns:
            np.array: CRC bits, length len(polynomial)-1
        """
        data = np.array(data, dtype=int)
        polynomial = np.array(polynomial, dtype=int)
        crc_length = len(polynomial) - 1
        data_padded = np.concatenate([data, np.zeros(crc_length, dtype=int)])
        for i in range(len(data)):
            if data_padded[i] == 1:
                data_padded[i:i+len(polynomial)] ^= polynomial
        return data_padded[-crc_length:] if crc_length > 0 else np.array([], dtype=int)


# --- Verification Helper Function for LLR ---
def verify_llr_calculation(snr_db, code_rate, num_samples=1000):
    """Generates random bits, encodes, adds noise, calculates LLRs,
       and compares with expected LLR behavior."""
    logging.info(f"Verifying LLR calculation for SNR={snr_db} dB, Rate={code_rate}")

    # Create a simple mock generator just for rate calculation
    # In a real scenario, you'd use the actual generator instance.
    # For verification, we just need K and N for code rate.
    mock_gen = type('MockGen', (object,), {'K': int(code_rate * BLOCK_LENGTH), 'N': BLOCK_LENGTH, 'KwCRC': int(code_rate * BLOCK_LENGTH)})()

    # Generate random bits (equivalent of info_bits_with_crc for this test)
    # We need N bits to simulate the codeword for LLR verification purposes
    # This is a simplified test - real LLRs are on channel output of codeword
    random_codeword = np.random.randint(2, size=(num_samples, mock_gen.N))

    # Simulate channel
    # BPSK map 0->1, 1->-1 for AWGN LLR formula derivation
    bpsk_signal = 1 - 2 * random_codeword
    rx_signal, noise_var = add_awgn_noise(bpsk_signal, snr_db, code_rate)
    calculated_llr = make_llr(rx_signal, noise_var)

    # Expected LLR for transmitted bit 0 (+1 in BPSK) should be positive
    # Expected LLR for transmitted bit 1 (-1 in BPSK) should be negative
    correct_llr_signs = np.sign(1 - 2 * random_codeword) # Sign of +1 for 0, -1 for 1
    actual_llr_signs = np.sign(calculated_llr)

    # Check proportion of correct signs
    correct_sign_proportion = np.mean(correct_llr_signs == actual_llr_signs)
    logging.info(f"  Proportion of LLRs with correct sign: {correct_sign_proportion:.4f}")

    # Check mean absolute LLR magnitude (rough check)
    # Expected mean magnitude increases with SNR
    mean_abs_llr = np.mean(np.abs(calculated_llr))
    # Theoretical mean LLR magnitude for AWGN BPSK, sign known: E[|LLR|] = sqrt(8 * R * Eb/N0 / pi)
    # This theoretical value is for LLR *before* sign, which is related to received signal magnitude
    # E[|LLR|] ~ 2 * E[|received|]/sigma^2. Let's skip strict theoretical check and just look at trend.
    logging.info(f"  Mean absolute LLR magnitude: {mean_abs_llr:.4f}")

    # You can add more detailed checks here, e.g., looking at the distribution
    # or comparing mean LLR for bits 0 vs 1.

    # Example: Mean LLR for transmitted 0s vs 1s
    mean_llr_zeros = np.mean(calculated_llr[random_codeword == 0]) if np.any(random_codeword == 0) else np.nan
    mean_llr_ones = np.mean(calculated_llr[random_codeword == 1]) if np.any(random_codeword == 1) else np.nan
    logging.info(f"  Mean LLR for transmitted 0s: {mean_llr_zeros:.4f}")
    logging.info(f"  Mean LLR for transmitted 1s: {mean_llr_ones:.4f}")

    # A key check: the mean LLR for transmitted 0 should be positive, for 1 negative.
    # And their magnitudes should be similar.
    if not (np.isnan(mean_llr_zeros) or mean_llr_zeros > 0):
         logging.warning(f"  Warning: Mean LLR for transmitted 0s is not positive ({mean_llr_zeros:.4f}).")
    if not (np.isnan(mean_llr_ones) or mean_llr_ones < 0):
         logging.warning(f"  Warning: Mean LLR for transmitted 1s is not negative ({mean_llr_ones:.4f}).")



###############################################################################
# latest Performance comparison
def performance_comparison(
        rnn_trainer,
        polar_code_gen,
        snr_range,
        channel_type,
        list_sizes,
        num_trials,
    ):
    device = next(rnn_trainer.model.parameters()).device
    num_info_bits = polar_code_gen.KwCRC  # Must use total info+CRC for error computation!
    rnn_results = {'BER_RNN': [], 'BLER_RNN': []}
    sc_results = {'BER_SC': [], 'BLER_SC': []}
    scl_results = {L: [] for L in list_sizes}

    for snr_db in snr_range:
        # === Generate new data for THIS SNR ===
        X_eval, y_eval = prepare_polar_dataset(
            polar_code_gen, num_samples=num_trials, snr_db=snr_db, channel_type=channel_type
        )
        X_tensor = torch.FloatTensor(X_eval).to(device)
        y_tensor = torch.FloatTensor(y_eval).to(device)

        # --- RNN Decoder ---
        ber_rnn, bler_rnn = rnn_trainer.evaluate(X_tensor, y_tensor)
        rnn_results['BER_RNN'].append(ber_rnn)
        rnn_results['BLER_RNN'].append(bler_rnn)

        # --- SC Decoder (list_size=1) ---
        sc_decoder = PolarCodeDecoder(polar_code_gen.N, num_info_bits, list_size=1, crc_poly=polar_code_gen.crc_poly)
        bit_errors, blk_errors = 0, 0
        for i in range(num_trials):
            llr = X_eval[i]          # <--- This LLR was generated at THIS SNR
            true_bits = y_eval[i]
            decoded_bits = sc_decoder.decode(llr)
            bit_errors += np.sum(decoded_bits != true_bits)
            blk_errors += int(np.any(decoded_bits != true_bits))
        sc_results['BER_SC'].append(bit_errors / (num_trials * num_info_bits))
        sc_results['BLER_SC'].append(blk_errors / num_trials)

        # --- SCL Decoders ---
        for L in list_sizes:
            scl_decoder = PolarCodeDecoder(polar_code_gen.N, num_info_bits, list_size=L, crc_poly=polar_code_gen.crc_poly)
            bit_errors, blk_errors = 0, 0
            for i in range(num_trials):
                llr = X_eval[i]
                true_bits = y_eval[i]
                decoded_bits = scl_decoder.decode(llr)
                bit_errors += np.sum(decoded_bits != true_bits)
                blk_errors += int(np.any(decoded_bits != true_bits))
            scl_results[L].append({'BER': bit_errors / (num_trials * num_info_bits), 'BLER': blk_errors / num_trials})

    return rnn_results, scl_results, sc_results
########################################################################

#add on
def sc_sanity_waterfall():
    crc_poly = [1,0,0,0,1,0,0,1]
    BLOCK_LENGTH = 128
    INFO_BITS = 64
    NUM_TRIALS_PERF = 10           # << ONLY TEN!
    polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc_poly)
    num_info_bits = polar_code_gen.KwCRC
    snr_range = [0, 5, 10]         # << Only three SNR points
    print("Testing SC only (classical):")
    for snr_db in snr_range:
        X_eval, y_eval = prepare_polar_dataset(
            polar_code_gen, num_samples=NUM_TRIALS_PERF, snr_db=snr_db, channel_type='AWGN'
        )
        sc_decoder = PolarCodeDecoder(BLOCK_LENGTH, num_info_bits, list_size=1, crc_poly=crc_poly)
        bit_errors, blk_errors = 0, 0
        for i in range(NUM_TRIALS_PERF):
            llr = X_eval[i]
            true_bits = y_eval[i]
            if i == 0:
                print(f"SNR={snr_db}: Sample LLR[:10]={llr[:10]}")
            decoded_bits = sc_decoder.decode(llr)
            bit_errors += np.sum(decoded_bits != true_bits)
            blk_errors += int(np.any(decoded_bits != true_bits))
        ber = bit_errors / (NUM_TRIALS_PERF * num_info_bits)
        bler = blk_errors / NUM_TRIALS_PERF
        print(f"SNR={snr_db:2d}dB: BER={ber:.3f}, BLER={bler:.3f}")

sc_sanity_waterfall()


############################################################################
#part 6 Plotting fuctions

def plot_ber_bler_comparison(snr_range, rnn_results, scl_results, sc_results, list_sizes):
    plt.figure(figsize=(18, 6))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('Bit Error Rate (BER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BLER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BLER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BLER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('Block Error Rate (BLER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()





# for RNN and SCL


def plot_rnn_ber_bler(snr_range, rnn_results):
    plt.figure(figsize=(12, 5))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BER_RNN'], marker='o', label='RNN BER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('RNN Bit Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BLER_RNN'], marker='o', label='RNN BLER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('RNN Block Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    plt.tight_layout()
    plt.show()
#2. SCL: Specialized BER/BLER Plot (for a single list size at a time)


def plot_sc_ber_bler(snr_range, sc_results):
    plt.figure(figsize=(12, 5))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, sc_results['BER_SC'], marker='s', label='SC BER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('SC Bit Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, sc_results['BLER_SC'], marker='s', label='SC BLER')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('SC Block Error Rate')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    plt.tight_layout()
    plt.show()

#SCL Plotting Function (ALL list sizes in ONE figure)

def plot_scl_ber_bler(snr_range, scl_results, list_sizes):
    plt.figure(figsize=(12, 5))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    for L in list_sizes:
        ber_list = [res['BER'] for res in scl_results[L]]
        plt.plot(snr_range, ber_list, marker='^', label=f'SCL L={L}')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('SCL Bit Error Rate (All List Sizes)')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    for L in list_sizes:
        bler_list = [res['BLER'] for res in scl_results[L]]
        plt.plot(snr_range, bler_list, marker='^', label=f'SCL L={L}')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('SCL Block Error Rate (All List Sizes)')
    plt.legend()
    plt.grid(True, which="both", ls="--")
    ax = plt.gca()
    ax.yaxis.set_major_formatter(LogFormatterMathtext())

    plt.tight_layout()
    plt.show()


######################################################################
def plot_training_validation(train_losses, val_losses):
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label='Training Loss')
    if val_losses:
        plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


def plot_confusion_matrix(y_true, y_pred, title='Confusion Matrix'):
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot()
    plt.title(title)
    plt.show()

def plot_ber_bler_comparison(snr_range, rnn_results, scl_results, sc_results, list_sizes):
    plt.figure(figsize=(18, 6))

    # BER Plot
    plt.subplot(1, 2, 1)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('Bit Error Rate (BER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # BLER Plot
    plt.subplot(1, 2, 2)
    plt.yscale('log')
    plt.ylim(1e-4, 1)
    plt.plot(snr_range, rnn_results['BLER_RNN'], label='RNN')
    for size in list_sizes:
        plt.plot(snr_range, [result['BLER'] for result in scl_results[size]], label=f'SCL, List Size {size}')
    plt.plot(snr_range, sc_results['BLER_SC'], label='SC')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('Block Error Rate (BLER)')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()
#Part 7: Main Function
###############################################################################
#latest main()



##############################################################################
def main():


    try:
        # Set up the device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Configuration parameters
        BLOCK_LENGTH = 128
        INFO_BITS = 64
        LEARNING_RATE = 1e-3
        EPOCHS = 50
        BATCH_SIZE = 64
        NUM_SAMPLES_TRAIN = 10000
        NUM_TRIALS_PERF = 1000
        SNR_RANGE_AWGN = np.linspace(0, 5,10)
      #  snr_range = [0, 5, 10]
        LIST_SIZES = [1, 8, 16]
        snr_db = 5.0     # <----- You can name this however you like!
       # crc7_poly = [1, 0, 0, 0, 1, 0, 0, 1]
       # G = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS, crc_poly=crc7_poly)
       # DATASET_SNR_DB = 5.0 # Added configuration for dataset SNR


        # Initialize polar code generator and RNN model/trainer
        polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS)
        rnn_model = EnhancedRNNDecoder(BLOCK_LENGTH, INFO_BITS).to(device)
        rnn_trainer = DecoderTrainer(rnn_model, LEARNING_RATE)

        # Generate dataset and save (if needed)
        X_raw, y_raw = prepare_polar_dataset(
            polar_code_gen, num_samples=NUM_SAMPLES_TRAIN, snr_db=5.0, channel_type='AWGN'
        )
        save_dataset_to_csv(X_raw, y_raw, 'awgn_dataset.csv')

        # Tensor conversion and splitting
        X_tensor = torch.FloatTensor(X_raw).view(-1, BLOCK_LENGTH).to(device)
        y_tensor = torch.FloatTensor(y_raw).view(-1, INFO_BITS).to(device)
        train_size = int(0.8 * X_tensor.shape[0])
        train_X = X_tensor[:train_size]
        train_y = y_tensor[:train_size]
        val_X = X_tensor[train_size:]
        val_y = y_tensor[train_size:]

        # Train RNN model
        train_losses, val_losses = rnn_trainer.train(
           train_X, train_y, val_X=val_X, val_y=val_y, epochs=EPOCHS, batch_size=BATCH_SIZE, generator=polar_code_gen
        )

        # ---- Performance Evaluation ----

        # Performance comparison: obtain RNN, SCL, and SC results
        # You must make sure your performance_comparison returns these three!
        rnn_results, scl_results, sc_results = performance_comparison(
            rnn_trainer, polar_code_gen, SNR_RANGE_AWGN, 'AWGN', LIST_SIZES, NUM_TRIALS_PERF
        )

        # ---- Plot losses ----
        plot_training_validation(train_losses, val_losses)

        # Plot RNN results
        plot_rnn_ber_bler(SNR_RANGE_AWGN, rnn_results)

# Plot SC results
        plot_sc_ber_bler(SNR_RANGE_AWGN, sc_results)

# Plot SCL results (all list sizes in one plot)
        plot_scl_ber_bler(SNR_RANGE_AWGN, scl_results, LIST_SIZES)
        # ---- (Optional) Plot all for comparison ----
        plot_ber_bler_comparison(
            SNR_RANGE_AWGN,
            rnn_results,
            scl_results,
            sc_results,
            LIST_SIZES
        )

        # ---- (Optional) Example Confusion Matrix for RNN on validation data ----
        from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
        y_true_example = val_y[:100].cpu().numpy()
        rnn_input_example = val_X[:100]
        rnn_output_prob_example = rnn_trainer.model(rnn_input_example).cpu().detach().numpy()
        rnn_output_example = (rnn_output_prob_example > 0.5).astype(int)
        y_pred_example = rnn_output_example.squeeze()
        plot_confusion_matrix(
            y_true_example.flatten(),
            y_pred_example.flatten(),
            title='RNN Confusion Matrix'
        )

    except Exception as e:
        print(f"Simulation Error: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    main()



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.7172, LLR[0]=-4.5361
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.5905, LLR[0]=-3.7347
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=1.6395, LLR[0]=10.3689
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-2.0900, LLR[0]=-13.2184
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-1.2640, LLR[0]=-7.9944
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.2717, LLR[0]=-1.7182
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.9458, LLR[0]=-5.9818
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.1688, LLR[0]=-1.0674
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.6547, LLR[0]=-4.1408
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.8564, LLR[0]=-5.4162
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-0.7945, LLR[0]=-5.0248
DEBUG SNR=5.0: noise_var=0.3162, example received[0]=-1.3595, LLR[0]=-8.5979
DEBUG SNR=5

Traceback (most recent call last):
  File "<ipython-input-1-3016808852>", line 1300, in main
    rnn_results, scl_results, sc_results = performance_comparison(
                                           ^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-1-3016808852>", line 976, in performance_comparison
    ber_rnn, bler_rnn = rnn_trainer.evaluate(X_tensor, y_tensor)
    ^^^^^^^^^^^^^^^^^
ValueError: too many values to unpack (expected 2)
