
# Making the Most of your Colab Subscription



## Faster GPUs

Users who have purchased one of Colab's paid plans have access to faster GPUs and more memory. You can upgrade your notebook's GPU settings in `Runtime > Change runtime type` in the menu to select from several accelerator options, subject to availability.

The free-of-charge version of Colab grants access to Nvidia's T4 GPUs, subject to quota restrictions and availability.

You can check which GPU you've been assigned at any time by running the following cell. If it prints "Not connected to a GPU", enable a GPU accelerator via `Runtime > Change runtime type` in the menu, and then re-run the cell.


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In order to use a GPU with your notebook, select the `Runtime > Change runtime type` menu, and then set the hardware accelerator to the desired option.

## More memory

Users who have purchased one of Colab's paid plans have access to high-memory VMs when they are available. More powerful GPUs are always offered with high-memory VMs.



You can check how much memory your runtime has at any time by running the following cell. If it prints "Not using a high-RAM runtime", enable a high-RAM runtime via `Runtime > Change runtime type` in the menu, select High-RAM in the Runtime shape toggle, and then re-run the cell.


In [None]:
import psutil

ram_gb = psutil.virtual_memory().total / 1e9  # total physical RAM in GB
print('Your runtime has {:.1f} gigabytes of RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## Longer runtimes

All Colab runtimes are reset after some period of time (sooner if the runtime isn't executing code). Colab Pro and Pro+ users have access to longer runtimes than those who use Colab free of charge.

## Background execution

Colab Pro+ users have access to background execution, where notebooks will continue executing even after you've closed a browser tab. This is always enabled in Pro+ runtimes as long as you have compute units available.



## Relaxing resource limits in Colab Pro

Your resources are not unlimited in Colab. To make the most of Colab, avoid using resources when you don't need them. For example, only use a GPU when required and close Colab tabs when finished.



If you encounter limitations, you can relax those limitations by purchasing more compute units via Pay As You Go. Anyone can purchase compute units via [Pay As You Go](https://colab.research.google.com/signup); no subscription is required.

## Send us feedback!

If you have any feedback for us, please let us know. The best way to send feedback is through the Help > 'Send feedback...' menu. If you encounter usage limits in Colab Pro, consider subscribing to Pro+.

If you encounter errors or other issues with billing (payments) for Colab Pro, Pro+, or Pay As You Go, please email [colab-billing@google.com](mailto:colab-billing@google.com).

## More Resources

### Working with Notebooks in Colab
- [Overview of Colab](/notebooks/basic_features_overview.ipynb)
- [Guide to Markdown](/notebooks/markdown_guide.ipynb)
- [Importing libraries and installing dependencies](/notebooks/snippets/importing_libraries.ipynb)
- [Saving and loading notebooks in GitHub](https://colab.research.google.com/github/googlecolab/colabtools/blob/main/notebooks/colab-github-demo.ipynb)
- [Interactive forms](/notebooks/forms.ipynb)
- [Interactive widgets](/notebooks/widgets.ipynb)

<a name="working-with-data"></a>
### Working with Data
- [Loading data: Drive, Sheets, and Google Cloud Storage](/notebooks/io.ipynb)
- [Charts: visualizing data](/notebooks/charts.ipynb)
- [Getting started with BigQuery](/notebooks/bigquery.ipynb)

### Machine Learning Crash Course
These are a few of the notebooks from Google's online Machine Learning course. See the [full course website](https://developers.google.com/machine-learning/crash-course/) for more.
- [Intro to Pandas DataFrame](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/pandas_dataframe_ultraquick_tutorial.ipynb)
- [Linear regression with tf.keras using synthetic data](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/linear_regression_with_synthetic_data.ipynb)


<a name="using-accelerated-hardware"></a>
### Using Accelerated Hardware
- [TensorFlow with GPUs](/notebooks/gpu.ipynb)
- [TPUs in Colab](/notebooks/tpu.ipynb)

<a name="machine-learning-examples"></a>

## Machine Learning Examples

To see end-to-end examples of the interactive machine learning analyses that Colab makes possible, check out these tutorials using models from [TensorFlow Hub](https://tfhub.dev).

A few featured examples:

- [Retraining an Image Classifier](https://tensorflow.org/hub/tutorials/tf2_image_retraining): Build a Keras model on top of a pre-trained image classifier to distinguish flowers.
- [Text Classification](https://tensorflow.org/hub/tutorials/tf2_text_classification): Classify IMDB movie reviews as either *positive* or *negative*.
- [Style Transfer](https://tensorflow.org/hub/tutorials/tf2_arbitrary_image_stylization): Use deep learning to transfer style between images.
- [Multilingual Universal Sentence Encoder Q&A](https://tensorflow.org/hub/tutorials/retrieval_with_tf_hub_universal_encoder_qa): Use a machine learning model to answer questions from the SQuAD dataset.
- [Video Interpolation](https://tensorflow.org/hub/tutorials/tweening_conv3d): Predict what happened in a video between the first and the last frame.


In [None]:
# === Imports ===
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# ===========================
# Configuration Parameters
# ===========================
N = 128                # Block length
K = 64                 # Number of information bits
CRC_LEN = 8            # CRC length (bits)
LIST_SIZE = [1, 4, 8, 16]  # SCL list sizes
NUM_FRAMES_PER_SNR = 1000
SNR_RANGE_DB = np.arange(0.5, 4.5, 0.5)

BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 1e-3
# GRU decoder size (used by main)
RNN_HIDDEN_SIZE = 128
RNN_NUM_LAYERS = 3

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===========================
# CRC Class (CRC-8)
# ===========================
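class CRC:
    """MSB-first bitwise CRC with zero initial value (poly 0x07, length 8 = CRC-8, x^8 + x^2 + x + 1)."""
    def __init__(self, poly=0x07, length=8):
        self.poly = poly                    # generator polynomial without the leading x^length term
        self.length = length
        self.mask = (1 << length) - 1

    def compute(self, bits):
        # Remainder of bits(x) * x^length divided by the generator polynomial
        reg = 0
        for bit in bits:
            feedback = ((reg >> (self.length - 1)) & 1) ^ int(bit)
            reg = (reg << 1) & self.mask
            if feedback:
                reg ^= self.poly
        return reg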

    def check(self, bits):
        # bits includes info + crc
        crc_calc = self.compute(bits[:-self.length])
        crc_received = 0
        for i in range(self.length):
            crc_received |= bits[-self.length + i] << (self.length - 1 - i)
        return crc_calc == crc_received

    def append_crc(self, bits):
        crc = self.compute(bits)
        crc_bits = [(crc >> (self.length - 1 - i)) & 1 for i in range(self.length)]
        return np.concatenate([bits, crc_bits])

# ===========================
# Polar Code Generator Class
# ===========================
class PolarCodeGenerator:
    def __init__(self, N, K, crc=None, design_snr_db=0.5):
        self.N = N
        self.K = K
        self.crc = crc
        self.design_snr_db = design_snr_db

        self.K_total = self.K + (self.crc.length if self.crc else 0)

        # Rank the N synthetic channels with a Bhattacharyya-parameter construction
        # at design_snr_db: start from Z = exp(-Es/N0) of the BPSK-AWGN channel and
        # apply Z -> 2Z - Z^2 ("minus") or Z -> Z^2 ("plus") once per stage.
        # Index bit order (MSB = first stage) matches _polar_transform / sc_decode.
        z = np.array([np.exp(-10 ** (self.design_snr_db / 10))])
        for _ in range(int(np.log2(self.N))):
            z_next = np.empty(2 * len(z))
            z_next[0::2] = 2 * z - z ** 2
            z_next[1::2] = z ** 2
            z = z_next

        # Information (+ CRC) bits go on the K_total channels with the smallest Z.
        self.frozen_bits = np.ones(self.N, dtype=bool)
        self.information_bits_positions = np.sort(np.argsort(z, kind='stable')[:self.K_total])
        self.frozen_bits[self.information_bits_positions] = False
        self.R = self.K / self.N


    def encode(self, info_bits):
        if self.crc is not None:
            info_bits = self.crc.append_crc(info_bits)
        u = np.zeros(self.N, dtype=int)
        u[self.information_bits_positions] = info_bits[:self.K_total]
        x = self._polar_transform(u)
        return x

    def _polar_transform(self, u):
        N = len(u)
        n = int(np.log2(N))
        x = u.copy()
        for i in range(n):
            step = 2 ** i
            for j in range(0, N, 2 * step):
                for k in range(step):
                    x[j + k] ^= x[j + k + step]
        return x

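    def decode(self, llr, decoder_type='SC', list_size=1):
        """Return (u_hat, bits at the information positions: K info bits followed by the CRC)."""
        if decoder_type == 'SC':
            decoded = self.sc_decode(llr)
        elif decoder_type == 'SCL':
            scl_decoder = FastSCLDecoder(self.N, self.K_total, self.frozen_bits, list_size,
                                         crc=self.crc,
                                         info_positions=self.information_bits_positions)
            decoded = scl_decoder.decode(llr)
        else:
            raise ValueError(f"Unsupported decoder type: {decoder_type}")

        decoded_info = decoded[self.information_bits_positions]
        return decoded, decoded_info

    def sc_decode(self, llr):
        """Successive-cancellation decoding; returns the estimated u vector (length N)."""
        u_hat = np.zeros(self.N, dtype=int)

        def recursive_sc_decode(llr_sub, start_idx):
            # Decides leaves start_idx .. start_idx+len-1 (stored in u_hat) and
            # returns their re-encoded partial sums for the parent's g-step.
            n = len(llr_sub)
            if n == 1:
                bit = 0 if (self.frozen_bits[start_idx] or llr_sub[0] >= 0) else 1
                u_hat[start_idx] = bit
                return np.array([bit], dtype=int)

            h = n // 2
            a, b = llr_sub[:h], llr_sub[h:]
            # f-step (min-sum): LLRs of the upper-half bits v_a = x_a XOR x_b
            llr_left = np.sign(a) * np.sign(b) * np.minimum(np.abs(a), np.abs(b))
            x_left = recursive_sc_decode(llr_left, start_idx)
            # g-step: LLRs of v_b given the decided left partial sums
            llr_right = b + (1 - 2 * x_left) * a
            x_right = recursive_sc_decode(llr_right, start_idx + h)

            return np.concatenate([x_left ^ x_right, x_right])

        recursive_sc_decode(np.asarray(llr, dtype=float), 0)
        return u_hat

# ===========================
# SC List (SCL) decoder, optionally CRC-aided
# ===========================
class FastSCLDecoder:
    """SC list decoder with LLR-based path metrics.

    All surviving paths are processed together as rows of a 2-D array using the
    same recursive f/g structure as sc_decode. If a CRC and the information
    positions are given, the best-metric path that passes the CRC is returned.
    """
    def __init__(self, N, K_total, frozen_bits, list_size=8, crc=None, info_positions=None):
        self.N = N
        self.K_total = K_total
        self.frozen_bits = frozen_bits
        self.list_size = list_size
        self.crc = crc
        self.info_positions = info_positions

    def _decode_node(self, llr, start, metrics):
        # llr: (P, n) LLRs of P paths; metrics: (P,) path metrics (lower is better).
        # Returns (x_hat, u_hat, parent, metrics) for the P' surviving paths.
        P, n = llr.shape
        if n == 1:
            l = llr[:, 0]
            if self.frozen_bits[start]:
                # Frozen bit is 0: penalize paths whose LLR favours 1
                metrics = metrics + np.logaddexp(0.0, -l)
                bits = np.zeros(P, dtype=int)
                parent = np.arange(P)
            else:
                # Fork every path into bit=0 / bit=1 and keep the list_size best
                cand = np.concatenate([metrics + np.logaddexp(0.0, -l),
                                       metrics + np.logaddexp(0.0, l)])
                keep = np.argsort(cand, kind='stable')[:self.list_size]
                metrics = cand[keep]
                bits = (keep >= P).astype(int)
                parent = keep % P
            col = bits[:, None]
            return col, col, parent, metrics

        h = n // 2
        a, b = llr[:, :h], llr[:, h:]
        llr_left = np.sign(a) * np.sign(b) * np.minimum(np.abs(a), np.abs(b))
        x_l, u_l, par_l, metrics = self._decode_node(llr_left, start, metrics)
        a, b = a[par_l], b[par_l]
        llr_right = b + (1 - 2 * x_l) * a
        x_r, u_r, par_r, metrics = self._decode_node(llr_right, start + h, metrics)
        x_l, u_l = x_l[par_r], u_l[par_r]
        x = np.concatenate([x_l ^ x_r, x_r], axis=1)
        u = np.concatenate([u_l, u_r], axis=1)
        return x, u, par_l[par_r], metrics

    def decode(self, llr):
        llr = np.asarray(llr, dtype=float)[None, :]
        _, u, _, metrics = self._decode_node(llr, 0, np.zeros(1))
        order = np.argsort(metrics, kind='stable')  # best (smallest) metric first
        if self.crc is not None and self.info_positions is not None:
            for idx in order:
                if self.crc.check(u[idx, self.info_positions]):
                    return u[idx]
        return u[order[0]]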


# ===========================
# GRU-based neural decoder
# ===========================
class NNNDecoder(nn.Module):
    """Maps per-position channel LLRs (batch, N, 1) to bit probabilities (batch, N).

    The GRU is unidirectional, so the output at position t depends only on
    llr[0..t]; training compares the first K + CRC_LEN outputs with the info + CRC bits.
    """
    def __init__(self, seq_len, hidden_size=128, num_layers=3):
        super().__init__()
        self.gru = nn.GRU(input_size=1,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.seq_len = seq_len

    def forward(self, x):
        out, _ = self.gru(x)            # (batch, seq_len, hidden_size)
        out = self.fc(out)              # (batch, seq_len, 1)
        out = self.sigmoid(out)         # per-position probability that the bit is 1
        return out.squeeze(2)           # (batch, seq_len)

# ===========================
# Helper Functions
# ===========================
def bpsk_modulation(bits):
    return 1 - 2 * bits  # 0 -> +1, 1 -> -1

def calculate_llr(y, sigma):
    return 2 * y / (sigma ** 2)

def bit_errors(true_bits, decoded_bits, info_positions):
    decoded_info_bits = decoded_bits[info_positions]
    return np.sum(true_bits != decoded_info_bits[:len(true_bits)])

def block_error(true_bits, decoded_bits, info_positions):
    decoded_info_bits = decoded_bits[info_positions]
    return int(np.any(true_bits != decoded_info_bits[:len(true_bits)]))

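def compute_mutual_information(llr, bits):
    # Estimate of I(X; L) in bits: 1 - E[log2(1 + exp(-(1 - 2x) * L))],
    # where L = log P(x=0)/P(x=1) and x are the transmitted code bits.
    signed_llr = (1 - 2 * np.asarray(bits)) * np.asarray(llr)
    return 1.0 - np.mean(np.logaddexp(0.0, -signed_llr)) / np.log(2)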

# ===========================
# RNN Training Function and Decoder
# ===========================
def _make_batch(polar, crc, batch_size, snr_db=1.0):
    """Generate (LLR, target-bit) pairs at a fixed training SNR (default 1 dB)."""
    sigma = np.sqrt(1 / (2 * 10 ** (snr_db / 10)))
    inputs, targets = [], []
    for _ in range(batch_size):
        info_bits = np.random.randint(0, 2, polar.K)
        full_bits = crc.append_crc(info_bits) if crc else info_bits  # info + CRC targets
        codeword = polar.encode(info_bits)
        y = bpsk_modulation(codeword) + sigma * np.random.randn(polar.N)
        inputs.append(calculate_llr(y, sigma))
        targets.append(full_bits)
    return np.array(inputs), np.array(targets)


def train_nnn(model, polar, crc, device, epochs, batch_size, learning_rate, num_train_frames=10000):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.BCELoss()
    steps_per_epoch = num_train_frames // batch_size
    train_losses, val_losses = [], []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for _ in range(steps_per_epoch):
            x, t = _make_batch(polar, crc, batch_size)
            inputs = torch.tensor(x, dtype=torch.float32).unsqueeze(2).to(device)  # (batch, N, 1)
            targets = torch.tensor(t, dtype=torch.float32).to(device)             # (batch, K+CRC)

            optimizer.zero_grad()
            outputs = model(inputs)[:, :targets.shape[1]]  # first K+CRC positions
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / steps_per_epoch
        train_losses.append(avg_train_loss)

        # Validation on one freshly generated batch
        model.eval()
        with torch.no_grad():
            x, t = _make_batch(polar, crc, batch_size)
            val_inputs = torch.tensor(x, dtype=torch.float32).unsqueeze(2).to(device)
            val_targets = torch.tensor(t, dtype=torch.float32).to(device)
            val_outputs = model(val_inputs)[:, :val_targets.shape[1]]
            val_loss = criterion(val_outputs, val_targets).item()
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f}")

    return train_losses, val_losses


def nnn_decode(model, llr, device, K):
    model.eval()
    with torch.no_grad():
        llr_tensor = torch.tensor(llr, dtype=torch.float32).unsqueeze(0).unsqueeze(2).to(device)
        output = model(llr_tensor).squeeze(0).cpu().numpy()
        decoded_bits = (output[:K] > 0.5).astype(int)  # threshold at 0.5
    return decoded_bits





# ===========================
# Plotting Function
# ===========================
def plot_results(SNR_RANGE_DB, ber_sc, bler_sc, ber_scl, bler_scl,
                 ber_rnn, bler_rnn, train_losses, val_losses, mi_list, epochs):
    # SC BER/BLER
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 2, 1)
    plt.semilogy(SNR_RANGE_DB, ber_sc, marker='o', label='SC')
    plt.ylim(1e-5, 1)
    plt.xlabel('SNR (dB)')
    plt.ylabel('Bit Error Rate (BER)')
    plt.title('SC Decoder BER vs SNR')
    plt.grid(True, which='both')
    plt.legend()

    plt.subplot(2, 2, 2)
    plt.semilogy(SNR_RANGE_DB, bler_sc, marker='o', label='SC')
    plt.ylim(1e-5, 1)
    plt.xlabel('SNR (dB)')
    plt.ylabel('Block Error Rate (BLER)')
    plt.title('SC Decoder BLER vs SNR')
    plt.grid(True, which='both')
    plt.legend()

    # SCL BER/BLER
    plt.subplot(2, 2, 3)
    for L in ber_scl:
        plt.semilogy(SNR_RANGE_DB, ber_scl[L], marker='o', label=f'SCL L={L}')
    plt.ylim(1e-5, 1)
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('SCL Decoder BER vs SNR')
    plt.grid(True, which='both')
    plt.legend()

    plt.subplot(2, 2, 4)
    for L in bler_scl:
        plt.semilogy(SNR_RANGE_DB, bler_scl[L], marker='o', label=f'SCL L={L}')
    plt.ylim(1e-5, 1)
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('SCL Decoder BLER vs SNR')
    plt.grid(True, which='both')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # RNN BER/BLER
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.semilogy(SNR_RANGE_DB, ber_rnn, marker='o', color='purple', label='RNN')
    plt.ylim(1e-5, 1)
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.title('RNN Decoder BER vs SNR')
    plt.grid(True, which='both')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.semilogy(SNR_RANGE_DB, bler_rnn, marker='o', color='purple', label='RNN')
    plt.ylim(1e-5, 1)
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.title('RNN Decoder BLER vs SNR')
    plt.grid(True, which='both')
    plt.legend()

    plt.tight_layout()
    plt.show()

    # Training and Validation Loss
    plt.figure()
    plt.plot(range(1, epochs + 1), train_losses, label='Training Loss')
    plt.plot(range(1, epochs + 1), val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('RNN Training and Validation Loss')
    plt.grid(True)
    plt.legend()
    plt.show()

    # Mutual Information
    plt.figure()
    plt.plot(SNR_RANGE_DB, mi_list, marker='o')
    plt.xlabel('SNR (dB)')
    plt.ylabel('Mutual Information (bits)')
    plt.title('Mutual Information vs SNR')
    plt.grid(True)
    plt.show()


# ===========================
# Main Simulation Function
# ===========================
def main():
    # Simulation settings come from the module-level configuration above.
    device = DEVICE

    crc = CRC(poly=0x07, length=CRC_LEN)
    polar = PolarCodeGenerator(N, K, crc=crc)

    rnn_decoder = NNNDecoder(seq_len=N, hidden_size=RNN_HIDDEN_SIZE,
                             num_layers=RNN_NUM_LAYERS).to(device)

    print("Training RNN decoder...")
    train_losses, val_losses = train_nnn(
        model=rnn_decoder,
        polar=polar,
        crc=crc,
        device=device,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        num_train_frames=5000
    )

    ber_sc, bler_sc = [], []
    ber_rnn, bler_rnn = [], []
    ber_scl = {L: [] for L in LIST_SIZE}
    bler_scl = {L: [] for L in LIST_SIZE}
    mi_list = []

    for snr_db in SNR_RANGE_DB:
        print(f"Simulating at SNR = {snr_db} dB")
        snr_linear = 10 ** (snr_db / 10)
        noise_variance = 1 / (2 * snr_linear)

        bit_err_sc = block_err_sc = 0
        bit_err_rnn = block_err_rnn = 0
        bit_err_scl = {L: 0 for L in LIST_SIZE}
        block_err_scl = {L: 0 for L in LIST_SIZE}

        mi_total = 0

        for frame_idx in range(NUM_FRAMES_PER_SNR):
            info_bits = np.random.randint(0, 2, K)
            encoded_bits = polar.encode(info_bits)
            tx = 1 - 2 * encoded_bits  # BPSK modulation
            noise = np.random.randn(N) * np.sqrt(noise_variance)
            y = tx + noise
            llr = 2 * y / noise_variance

            decoded_sc, _ = polar.decode(llr, decoder_type='SC')
            bit_err_sc += bit_errors(info_bits, decoded_sc, polar.information_bits_positions)
            block_err_sc += block_error(info_bits, decoded_sc, polar.information_bits_positions)

            for L in LIST_SIZE:
                decoded_scl, _ = polar.decode(llr, decoder_type='SCL', list_size=L)
                bit_err_scl[L] += bit_errors(info_bits, decoded_scl, polar.information_bits_positions)
                block_err_scl[L] += block_error(info_bits, decoded_scl, polar.information_bits_positions)

            decoded_rnn = nnn_decode(rnn_decoder, llr, device, K)
            bit_err_rnn += np.sum(info_bits != decoded_rnn)
            block_err_rnn += int(np.any(info_bits != decoded_rnn))

            mi_total += compute_mutual_information(llr, encoded_bits)

        ber_sc.append(bit_err_sc / (NUM_FRAMES_PER_SNR * K))
        bler_sc.append(block_err_sc / NUM_FRAMES_PER_SNR)

        for L in LIST_SIZE:
            ber_scl[L].append(bit_err_scl[L] / (NUM_FRAMES_PER_SNR * K))
            bler_scl[L].append(block_err_scl[L] / NUM_FRAMES_PER_SNR)

        ber_rnn.append(bit_err_rnn / (NUM_FRAMES_PER_SNR * K))
        bler_rnn.append(block_err_rnn / NUM_FRAMES_PER_SNR)

        mi_list.append(mi_total / NUM_FRAMES_PER_SNR)

    plot_results(SNR_RANGE_DB, ber_sc, bler_sc, ber_scl, bler_scl,
                 ber_rnn, bler_rnn, train_losses, val_losses, mi_list, EPOCHS)


if __name__ == "__main__":
    main()






Training RNN decoder...
Epoch 1/10 - Train Loss: 0.6922 | Val Loss: 0.6926
Epoch 2/10 - Train Loss: 0.6920 | Val Loss: 0.6928
Epoch 3/10 - Train Loss: 0.6919 | Val Loss: 0.6912
Epoch 4/10 - Train Loss: 0.6917 | Val Loss: 0.6913
Epoch 5/10 - Train Loss: 0.6911 | Val Loss: 0.6920
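
The CRC, channel ranking, encoder and SC decoder in the cell above can be sanity-checked without NumPy or PyTorch by a noiseless round trip. The following is a minimal pure-Python sketch written for that purpose, not the notebook's own code: the function names `crc_bits`, `polar_transform`, `info_positions` and `sc_decode` are hypothetical, and it assumes CRC-8 with polynomial 0x07 (zero initial value), the natural-order transform used by `_polar_transform`, and an assumed Bhattacharyya-parameter ranking of the synthetic channels at a 0.5 dB design SNR.

```python
import math
import random


def crc_bits(bits, poly=0x07, length=8):
    """MSB-first bitwise CRC (zero initial value); returns the CRC as a bit list."""
    reg, mask = 0, (1 << length) - 1
    for bit in bits:
        feedback = ((reg >> (length - 1)) & 1) ^ bit
        reg = (reg << 1) & mask
        if feedback:
            reg ^= poly
    return [(reg >> (length - 1 - i)) & 1 for i in range(length)]


def polar_transform(u):
    """x = u * F^(kron n), F = [[1, 0], [1, 1]], natural order (no bit reversal)."""
    x, step = list(u), 1
    while step < len(x):
        for j in range(0, len(x), 2 * step):
            for k in range(j, j + step):
                x[k] ^= x[k + step]
        step *= 2
    return x


def info_positions(n_bits, k_total, design_snr_db=0.5):
    """Indices of the k_total channels with the smallest Bhattacharyya Z (assumed construction)."""
    z = [math.exp(-10 ** (design_snr_db / 10))]
    while len(z) < n_bits:
        # Each channel splits into a worse "minus" (2Z - Z^2) and a better "plus" (Z^2) channel
        z = [v for t in z for v in (2 * t - t * t, t * t)]
    return sorted(sorted(range(n_bits), key=lambda i: z[i])[:k_total])


def sc_decode(llr, frozen):
    """Successive-cancellation decoding with a min-sum f-step; returns the decided u bits."""
    u_hat = [0] * len(llr)

    def rec(l, start):
        if len(l) == 1:
            u_hat[start] = 0 if (frozen[start] or l[0] >= 0) else 1
            return [u_hat[start]]
        h = len(l) // 2
        a, b = l[:h], l[h:]
        left = [(1 if p * q >= 0 else -1) * min(abs(p), abs(q)) for p, q in zip(a, b)]
        x_l = rec(left, start)
        right = [q + (1 - 2 * s) * p for p, q, s in zip(a, b, x_l)]
        x_r = rec(right, start + h)
        return [s ^ t for s, t in zip(x_l, x_r)] + x_r

    rec(list(llr), 0)
    return u_hat


# Noiseless round trip: N = 32, 8 info bits + 8 CRC bits
random.seed(0)
N, K = 32, 8
info = [random.randint(0, 1) for _ in range(K)]
payload = info + crc_bits(info)
pos = info_positions(N, len(payload))
u = [0] * N
for p, bit in zip(pos, payload):
    u[p] = bit
x = polar_transform(u)
llr = [4.0 * (1 - 2 * b) for b in x]          # BPSK 0 -> +1, 1 -> -1, no noise
frozen = [i not in pos for i in range(N)]
u_hat = sc_decode(llr, frozen)
decoded = [u_hat[p] for p in pos]
print(decoded[:K] == info, crc_bits(decoded[:K]) == decoded[K:])  # True True
```

With noiseless LLRs every SC decision is correct, so the round trip recovers the info bits and the CRC matches; adding Gaussian noise to `llr` turns the same script into a small Monte Carlo check.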
