
# Making the Most of your Colab Subscription



## Faster GPUs

Users who have purchased one of Colab's paid plans have access to faster GPUs and more memory. You can upgrade your notebook's GPU settings in `Runtime > Change runtime type` in the menu to select from several accelerator options, subject to availability.

The free of charge version of Colab grants access to Nvidia's T4 GPUs subject to quota restrictions and availability.

You can see which GPU you've been assigned at any time by running the following cell. If it prints "Not connected to a GPU", go to `Runtime > Change runtime type` in the menu, enable a GPU accelerator, and then re-run the cell.


In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In order to use a GPU with your notebook, select the `Runtime > Change runtime type` menu, and then set the hardware accelerator to the desired option.
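The `!nvidia-smi` line in the cell above uses IPython's shell escape, so it only works inside a notebook. In a plain Python script, a similar check can be sketched with the standard library. `gpu_summary` is a hypothetical helper, and the sketch assumes `nvidia-smi` exits with a nonzero status when it cannot reach a GPU, instead of searching the output for "failed":

```python
import shutil
import subprocess


def gpu_summary():
    """Return the nvidia-smi report, or None if no NVIDIA driver/GPU is visible."""
    if shutil.which("nvidia-smi") is None:   # tool not installed: no NVIDIA runtime
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
    if result.returncode != 0:               # installed, but it could not reach a GPU
        return None
    return result.stdout


summary = gpu_summary()
print(summary if summary is not None else "Not connected to a GPU")
```

Checking the exit status avoids depending on the exact wording of `nvidia-smi` error messages.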

## More memory

Users who have purchased one of Colab's paid plans have access to high-memory VMs when they are available. More powerful GPUs are always offered with high-memory VMs.



You can see how much memory you have at any time by running the following code cell. If it prints "Not using a high-RAM runtime", enable a high-RAM runtime via `Runtime > Change runtime type` in the menu by selecting High-RAM under Runtime shape, and then re-run the cell.


In [None]:
import psutil

ram_gb = psutil.virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of total RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## Longer runtimes

All Colab runtimes are reset after some period of time (sooner if the runtime is idle and not executing code). Colab Pro and Pro+ users have access to longer runtimes than those who use Colab free of charge.

## Background execution

Colab Pro+ users have access to background execution, where notebooks will continue executing even after you've closed a browser tab. This is always enabled in Pro+ runtimes as long as you have compute units available.



## Relaxing resource limits in Colab Pro

Your resources are not unlimited in Colab. To make the most of Colab, avoid using resources when you don't need them. For example, only use a GPU when required and close Colab tabs when finished.



If you encounter limitations, you can relax those limitations by purchasing more compute units via Pay As You Go. Anyone can purchase compute units via [Pay As You Go](https://colab.research.google.com/signup); no subscription is required.

## Send us feedback!

If you have any feedback for us, please let us know. The best way to send feedback is through the Help > 'Send feedback...' menu. If you encounter usage limits in Colab Pro, consider subscribing to Pro+.

If you encounter errors or other issues with billing (payments) for Colab Pro, Pro+, or Pay As You Go, please email [colab-billing@google.com](mailto:colab-billing@google.com).

## More Resources

### Working with Notebooks in Colab
- [Overview of Colab](/notebooks/basic_features_overview.ipynb)
- [Guide to Markdown](/notebooks/markdown_guide.ipynb)
- [Importing libraries and installing dependencies](/notebooks/snippets/importing_libraries.ipynb)
- [Saving and loading notebooks in GitHub](https://colab.research.google.com/github/googlecolab/colabtools/blob/main/notebooks/colab-github-demo.ipynb)
- [Interactive forms](/notebooks/forms.ipynb)
- [Interactive widgets](/notebooks/widgets.ipynb)

<a name="working-with-data"></a>
### Working with Data
- [Loading data: Drive, Sheets, and Google Cloud Storage](/notebooks/io.ipynb)
- [Charts: visualizing data](/notebooks/charts.ipynb)
- [Getting started with BigQuery](/notebooks/bigquery.ipynb)

### Machine Learning Crash Course
These are a few of the notebooks from Google's online Machine Learning course. See the [full course website](https://developers.google.com/machine-learning/crash-course/) for more.
- [Intro to Pandas DataFrame](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/pandas_dataframe_ultraquick_tutorial.ipynb)
- [Linear regression with tf.keras using synthetic data](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/linear_regression_with_synthetic_data.ipynb)


<a name="using-accelerated-hardware"></a>
### Using Accelerated Hardware
- [TensorFlow with GPUs](/notebooks/gpu.ipynb)
- [TPUs in Colab](/notebooks/tpu.ipynb)

<a name="machine-learning-examples"></a>

## Machine Learning Examples

To see end-to-end examples of the interactive machine learning analyses that Colab makes possible, check out these tutorials using models from [TensorFlow Hub](https://tfhub.dev).

A few featured examples:

- [Retraining an Image Classifier](https://tensorflow.org/hub/tutorials/tf2_image_retraining): Build a Keras model on top of a pre-trained image classifier to distinguish flowers.
- [Text Classification](https://tensorflow.org/hub/tutorials/tf2_text_classification): Classify IMDB movie reviews as either *positive* or *negative*.
- [Style Transfer](https://tensorflow.org/hub/tutorials/tf2_arbitrary_image_stylization): Use deep learning to transfer style between images.
- [Multilingual Universal Sentence Encoder Q&A](https://tensorflow.org/hub/tutorials/retrieval_with_tf_hub_universal_encoder_qa): Use a machine learning model to answer questions from the SQuAD dataset.
- [Video Interpolation](https://tensorflow.org/hub/tutorials/tweening_conv3d): Predict what happened in a video between the first and the last frame.


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
from google.colab import files
uploaded = files.upload()

Saving making_the_most_of_your_colab_subscription (18).py to making_the_most_of_your_colab_subscription (18).py


In [None]:
# Polar code simulation framework: neural (ML) decoder vs. SCL-style list decoder
# over AWGN and Rayleigh fading channels
!pip install torch numpy scipy matplotlib scikit-learn

# Essential Libraries
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import scipy.special as sps
import logging, traceback, sys

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')


# Device Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🚀 Using Device: {DEVICE}")

#part two

class PolarCodeGenerator:
    def __init__(self, N, K, crc_type='CRC-7'):
        self.N = N
        self.K = K
        self.crc_type = crc_type
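        # A degree-r CRC generator needs r + 1 coefficients (MSB first).
        # CRC-7 here is x^7 + x^3 + 1, the CRC-7 used by MMC/SD cards.
        self.crc_polynomials = {
            'CRC-7': {'polynomial': [1, 0, 0, 0, 1, 0, 0, 1], 'length': 7}
        }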

    def generate_info_bits(self):
        return np.random.randint(2, size=self.K)

    def compute_crc(self, bits):
        poly_info = self.crc_polynomials.get(self.crc_type)
        if not poly_info:
            raise ValueError(f"Unsupported CRC type: {self.crc_type}")

        polynomial = poly_info['polynomial']
        crc_length = poly_info['length']
        message = bits.tolist() + [0] * crc_length
        for i in range(len(message) - crc_length):
            if message[i] == 1:
                for j in range(crc_length + 1):
                    message[i + j] ^= polynomial[j] if j < len(polynomial) else 0

        return np.array(message[-crc_length:], dtype=int)

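    def polar_encode(self, info_bits):
        # Placeholder encoder: places [info bits | CRC bits] at the start of the codeword and
        # zero-pads to length N. It does not select frozen positions or apply the polar
        # transform, so the result is not a true polar codeword.
        crc_bits = self.compute_crc(info_bits)
        extended_info_bits = np.concatenate([info_bits, crc_bits])
        codeword = np.zeros(self.N, dtype=int)
        codeword[:len(extended_info_bits)] = extended_info_bits
        return codeword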

    def verify_codeword(self, codeword):
        # Layout produced by polar_encode: [K info bits | CRC bits | zero padding]
        crc_length = self.crc_polynomials[self.crc_type]['length']
        info_bits = np.asarray(codeword[:self.K])
        received_crc = np.asarray(codeword[self.K:self.K + crc_length])
        computed_crc = self.compute_crc(info_bits)
        return np.array_equal(received_crc, computed_crc)
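
# Sketch (not part of the original pipeline): polar_encode above only places the info+CRC
# bits at the start of the codeword; it does not apply the polar transform. Assuming natural
# bit order (no bit-reversal permutation) and the Arikan kernel F = [[1, 0], [1, 1]], a polar
# encoder computes x = u * F^{(x)n} (mod 2), where u holds the info+CRC bits in the non-frozen
# positions and zeros elsewhere. `polar_transform` is a hypothetical helper name, and choosing
# the frozen set is not shown here.
import numpy as np


def polar_transform(u):
    """Return x = u * F^{(x)n} over GF(2) using in-place butterflies (len(u) must be a power of 2)."""
    x = np.array(u, dtype=np.uint8)
    n = len(x)
    if n & (n - 1):
        raise ValueError("Block length must be a power of two")
    step = 1
    while step < n:
        for start in range(0, n, 2 * step):
            # Each butterfly maps (a, b) -> (a XOR b, b)
            x[start:start + step] ^= x[start + step:start + 2 * step]
        step *= 2
    return x

# Usage: for N = 4, u = [0, 0, 0, 1] maps to the last row of F^{(x)2}, i.e. [1, 1, 1, 1].
# Because F^{(x)n} is its own inverse mod 2, applying the transform twice returns u.
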

class EnhancedChannelSimulator:
    def __init__(self, channel_type='AWGN'):
        self.channel_type = channel_type
        logging.info(f"Initializing {channel_type} Channel Simulator")

    def simulate(self, encoded_signal, snr_db):
        # BPSK mapping: bit 0 -> +1, bit 1 -> -1 (defined before `try` so the fallback below exists)
        bpsk_signal = 1 - 2 * np.asarray(encoded_signal, dtype=float)
        try:
            snr_linear = 10 ** (snr_db / 10)
            signal_power = np.mean(bpsk_signal**2)
            noise_power = signal_power / snr_linear  # N0, treating snr_db as Es/N0
            noise_std = np.sqrt(noise_power / 2.0)   # real-valued noise with variance N0/2

            if self.channel_type == 'AWGN':
                noise = np.random.normal(0, noise_std, bpsk_signal.shape)
                received_signal = bpsk_signal + noise
            elif self.channel_type == 'Rayleigh':
                # scale = 1/sqrt(2) gives E[h^2] = 1, matching the theoretical Rayleigh BER below
                fading = np.random.rayleigh(scale=np.sqrt(0.5), size=bpsk_signal.shape)
                noise = np.random.normal(0, noise_std, bpsk_signal.shape)
                received_signal = fading * bpsk_signal + noise
            else:
                raise ValueError(f"Unsupported channel type: {self.channel_type}")

            # Return the raw (soft) received signal rather than hard decisions, as decoder input
            return received_signal
        except Exception as e:
            logging.error(f"Channel simulation error: {e}")
            # Fall back to the noiseless BPSK signal if simulation fails
            return bpsk_signal

    def compute_theoretical_performance(self, block_length, snr_linear):
        try:
            if self.channel_type == 'AWGN':
                # Theoretical BER for BPSK in AWGN
                bep = 0.5 * sps.erfc(np.sqrt(snr_linear))
            elif self.channel_type == 'Rayleigh':
                # Theoretical BER for BPSK in Rayleigh fading (assuming ideal channel estimation)
                bep = 0.5 * (1 - np.sqrt(snr_linear / (1 + snr_linear)))
            else:
                raise ValueError(f"Unsupported channel type: {self.channel_type}")

            # Exact BLER for polar codes has no simple closed form. As a rough reference, this uses
            # the uncoded block error probability with independent bit errors: 1 - (1 - p)^n.
            bler = 1 - (1 - bep) ** block_length
            return bep, bler
        except Exception as e:
            logging.error(f"Theoretical performance computation error: {e}")
            return np.zeros_like(snr_linear), np.ones_like(snr_linear)
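
# Sanity check (sketch): the AWGN branch of EnhancedChannelSimulator.simulate treats snr_db
# as Es/N0 with unit-energy BPSK, so the noise standard deviation is sqrt(1 / (2 * snr_linear)).
# Under that assumption, uncoded hard decisions should match the closed form
# BER = 0.5 * erfc(sqrt(snr_linear)) used in compute_theoretical_performance.
# `empirical_bpsk_ber` is a hypothetical helper, independent of the classes above.
import math
import numpy as np


def empirical_bpsk_ber(snr_db, num_bits=200_000, seed=0):
    """Return (simulated_ber, theoretical_ber) for uncoded BPSK over AWGN."""
    rng = np.random.default_rng(seed)
    bits = rng.integers(0, 2, size=num_bits)
    symbols = 1.0 - 2.0 * bits                    # bit 0 -> +1, bit 1 -> -1
    snr_linear = 10 ** (snr_db / 10)
    noise_std = math.sqrt(1.0 / (2.0 * snr_linear))
    received = symbols + rng.normal(0.0, noise_std, size=num_bits)
    decisions = (received < 0).astype(int)        # negative sample -> bit 1
    simulated = float(np.mean(decisions != bits))
    theoretical = 0.5 * math.erfc(math.sqrt(snr_linear))
    return simulated, theoretical
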


#part three

def prepare_polar_dataset(polar_code_gen, num_samples, snr_db=5, channel_type='AWGN'):
    channel_simulator = EnhancedChannelSimulator(channel_type=channel_type)
    X, y = [], []

    for _ in range(num_samples):
        info_bits = polar_code_gen.generate_info_bits()
        encoded_signal = polar_code_gen.polar_encode(info_bits)
        # Simulate the channel and get the received signal (soft values)
        received_signal = channel_simulator.simulate(encoded_signal, snr_db)
        X.append(received_signal)
        y.append(info_bits) # Keep the original info bits as labels

    return np.array(X), np.array(y)

#part four

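class EnhancedRNNDecoder(nn.Module):
    # Despite the name, this is a fully connected (feed-forward) network, not a recurrent one:
    # it maps N received soft values to K per-bit probabilities.
    def __init__(self, input_size, output_size):
        super(EnhancedRNNDecoder, self).__init__()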
        self.model = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, output_size), # Output size is number of info bits (K)
            nn.Sigmoid()
        )

    def forward(self, x):
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() > 2:
            x = x.view(x.size(0), -1) # Flatten if input is not already 2D
        return self.model(x)

class DecoderTrainer:
    def __init__(self, model, learning_rate=1e-3):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.criterion = nn.BCELoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate, weight_decay=1e-5)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=5)
        self.train_losses = []
        self.val_losses = []

    def train(self, X, y, epochs=100, batch_size=32, validation_split=0.2):
        X_tensor = X.to(self.device)
        y_tensor = y.to(self.device)

        dataset = TensorDataset(X_tensor, y_tensor)
        train_size = int((1 - validation_split) * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        for epoch in range(epochs):
            self.model.train()
            train_loss = self._train_epoch(train_loader)
            self.train_losses.append(train_loss)

            self.model.eval()
            val_loss = self._validate(val_loader)
            self.val_losses.append(val_loss)

            self.scheduler.step(val_loss)
            print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        return self.train_losses, self.val_losses

    def _train_epoch(self, dataloader):
        total_loss = 0
        for batch_X, batch_y in dataloader:
            batch_X = batch_X.to(self.device)
            batch_y = batch_y.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(batch_X)
            loss = self.criterion(outputs, batch_y)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return total_loss / len(dataloader)

    def _validate(self, dataloader):
        total_loss = 0
        with torch.no_grad():
            for batch_X, batch_y in dataloader:
                batch_X = batch_X.to(self.device)
                batch_y = batch_y.to(self.device)
                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)
                total_loss += loss.item()
        return total_loss / len(dataloader)

    def predict(self, X):
        if not isinstance(X, torch.Tensor):
            X = torch.FloatTensor(X)
        if X.dim() > 2:
            X = X.view(X.size(0), -1)
        X = X.to(self.device)
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(X)
        return (outputs > 0.5).cpu().numpy().astype(int)

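# Traditional baseline: list decoder with CRC-based path selection
class SCLDecoder:
    # NOTE: because PolarCodeGenerator.polar_encode does not apply the polar transform, this is
    # not a true successive-cancellation list (SCL) decoder. It performs bit-by-bit list decoding
    # on the received soft values (bit 0 -> +1, so a positive value favours 0), keeps the
    # `list_size` most likely paths, and uses the CRC to pick the final candidate.
    def __init__(self, N, K, list_size, crc_type='CRC-7'):
        self.N = N
        self.K = K
        self.list_size = list_size
        self.crc_type = crc_type
        self.code_gen = PolarCodeGenerator(N, K, crc_type)
        self.crc_length = self.code_gen.crc_polynomials[crc_type]['length']

    def decode(self, received_signal):
        llrs = np.asarray(received_signal, dtype=float)  # proportional to AWGN LLRs (2y / sigma^2)
        num_coded = self.K + self.crc_length  # positions after this are zero padding
        paths = [np.zeros(self.N, dtype=int)]  # start from a single all-zero path
        path_metrics = [0.0]

        for i in range(num_coded):
            # Split every surviving path on bit i (0 or 1) and score both branches
            candidates = []
            for path, metric in zip(paths, path_metrics):
                for bit in (0, 1):
                    new_path = path.copy()  # copy so paths never share storage
                    new_path[i] = bit
                    candidates.append((metric + self.calculate_metric(bit, llrs[i]), new_path))

            # Keep the `list_size` paths with the smallest metric (most likely)
            candidates.sort(key=lambda c: c[0])
            candidates = candidates[:self.list_size]
            path_metrics = [c[0] for c in candidates]
            paths = [c[1] for c in candidates]

        # Return the most likely path whose CRC checks; otherwise fall back to the best path
        for path in paths:
            if np.array_equal(self.code_gen.compute_crc(path[:self.K]), path[self.K:num_coded]):
                return path[:self.K]
        return paths[0][:self.K]

    def make_decision(self, received_signal):
        # Hard decision for BPSK (bit 0 -> +1, bit 1 -> -1)
        return 0 if received_signal >= 0 else 1

    def calculate_metric(self, bit_decision, received_signal):
        # Approximate LLR path metric: add |LLR| when the bit disagrees with the hard decision
        if bit_decision != self.make_decision(received_signal):
            return abs(received_signal)
        return 0.0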
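
# Sketch of the LLR-domain path-metric update commonly used in SCL decoding: deciding bit u at
# a position with LLR L adds ln(1 + exp(-(1 - 2u) * L)) to the path metric, and a common
# hardware-friendly approximation adds |L| only when u disagrees with the sign of L.
# `path_metric_increment` is a hypothetical helper and is not called by the code above.
import numpy as np


def path_metric_increment(llr, bit, approximate=False):
    """Penalty added to a path metric when `bit` is decided at a position with LLR `llr`.

    Convention: llr = ln(P(bit=0) / P(bit=1)), so a positive LLR favours 0.
    """
    sign = 1 - 2 * bit                               # +1 for bit 0, -1 for bit 1
    if approximate:
        return 0.0 if sign * llr >= 0 else abs(llr)
    return float(np.logaddexp(0.0, -sign * llr))     # ln(1 + exp(-sign * llr)), overflow-safe
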

def run_scl_decoder(polar_code_gen, SNRS, list_size, channel_type, num_trials):
    results = []
    decoder = SCLDecoder(N=polar_code_gen.N, K=polar_code_gen.K, list_size=list_size,
                         crc_type=polar_code_gen.crc_type)
    for snr_db in SNRS:
        X, y = prepare_polar_dataset(polar_code_gen, num_samples=num_trials, snr_db=snr_db, channel_type=channel_type)
        decoded_bits = np.array([decoder.decode(x) for x in X])  # decode each received block

        # BER: bit errors / total info bits; BLER: fraction of blocks with any bit error
        ber = np.sum(np.abs(decoded_bits - y)) / (num_trials * polar_code_gen.K)
        bler = np.mean(np.any(decoded_bits != y, axis=1))

        results.append({'SNR': snr_db, 'BER': ber, 'BLER': bler})

    return results
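
# Worked example (sketch) of the BER/BLER definitions used in run_scl_decoder and
# performance_comparison: BER = bit errors / (blocks * K), and BLER = fraction of blocks with
# at least one bit error. `ber_bler` is a hypothetical helper name.
import numpy as np


def ber_bler(predicted, actual):
    """Return (BER, BLER) for arrays of decoded and true bits with shape [num_blocks, K]."""
    bit_errors = np.asarray(predicted) != np.asarray(actual)
    ber = bit_errors.sum() / bit_errors.size
    bler = bit_errors.any(axis=1).mean()
    return float(ber), float(bler)

# Example: two blocks of K = 2 bits with one wrong bit in the second block:
# ber_bler([[0, 1], [1, 1]], [[0, 1], [0, 1]])  -> (0.25, 0.5)
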

# part five


# Plot training/validation loss, then BER and BLER versus SNR for each list-size label
def plot_comprehensive_analysis(train_losses, val_losses, performance_results, snr_range, channel_name):
    plt.figure(figsize=(12, 15))

    # Plot Training and Validation Loss
    plt.subplot(3, 1, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title(f'{channel_name} Channel - Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()

    # Plot BER (from performance_results)
    plt.subplot(3, 1, 2)
    # Iterate through the decoder types (which are now just the list size labels)
    for list_size, results in performance_results.items():
        # Use the list_size as the label
        plt.plot(snr_range, results['BER'], label=f'RNN Decoder (List size {list_size})')

    plt.title(f'{channel_name} Channel - BER Performance')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.yscale('log')
    # Fix the y-range so low error rates stay visible on the log scale
    plt.ylim(1e-4, 1)
    plt.legend()
    plt.grid(True, which="both", ls="--") # Add grid

    # Plot BLER (from performance_results)
    plt.subplot(3, 1, 3)
    # Iterate through the decoder types (which are now just the list size labels)
    for list_size, results in performance_results.items():
        # Use the list_size as the label
        plt.plot(snr_range, results['BLER'], label=f'RNN Decoder (List size {list_size})')

    plt.title(f'{channel_name} Channel - BLER Performance')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.yscale('log')
    # Fix the y-range so low error rates stay visible on the log scale
    plt.ylim(1e-4, 1)
    plt.legend()
    plt.grid(True, which="both", ls="--") # Add grid


    plt.tight_layout() # Adjust subplot parameters for a tight layout
    plt.show()
####################################################
# Evaluate the trained ML decoder across the SNR range

def performance_comparison(rnn_trainer, polar_code_gen, snr_range, channel_name, list_sizes, num_trials):
    # The ML decoder has no list size; results are stored under each list_size key only so they
    # can be plotted alongside the SCL curves. Fresh test data is drawn for every (list_size, SNR) pair.
    performance_results = {list_size: {'BER': [], 'BLER': []} for list_size in list_sizes}

    for list_size in list_sizes:
        for snr_db in snr_range:
            X, y = prepare_polar_dataset(polar_code_gen, num_samples=num_trials, snr_db=snr_db, channel_type=channel_name)
            predictions = rnn_trainer.predict(X)  # shape [num_trials, K]

            actual_labels = y
            ber = np.sum(np.abs(predictions - actual_labels)) / (num_trials * polar_code_gen.K)
            block_errors = np.sum(np.any(predictions != actual_labels, axis=1))
            bler = block_errors / num_trials

            performance_results[list_size]['BER'].append(ber)
            performance_results[list_size]['BLER'].append(bler)
            print(f"List Size: {list_size}, SNR: {snr_db:.1f} dB, BER: {ber:.4e}, BLER: {bler:.4e}")

    return performance_results


#####################################################
# Compare the traditional SCL decoder with the ML decoder, using a separate SNR range per channel
######################################################################
def compare_decoders():
    BLOCK_LENGTH = 128
    INFO_BITS = 64
    NUM_SAMPLES = 10000
    NUM_TRIALS = 1000
    LIST_SIZES = [1, 8, 16]
    SNR_AWGN = np.linspace(0, 5, 11)
    SNR_RAYLEIGH = np.linspace(0, 10, 11)

    polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS)
    results = {}

    for channel_type in ['AWGN', 'Rayleigh']:
        X, y = prepare_polar_dataset(polar_code_gen, num_samples=NUM_SAMPLES, snr_db=3, channel_type=channel_type)
        X_tensor = torch.FloatTensor(X).view(-1, BLOCK_LENGTH)
        y_tensor = torch.FloatTensor(y).view(-1, INFO_BITS)

        # Choose SNR range based on channel type
        snr_range = SNR_AWGN if channel_type == 'AWGN' else SNR_RAYLEIGH

        # ML-based decoding: the network has no list size, so train it once per channel
        model = EnhancedRNNDecoder(input_size=BLOCK_LENGTH, output_size=INFO_BITS)
        trainer = DecoderTrainer(model, learning_rate=0.001)
        trainer.train(X_tensor, y_tensor, epochs=100, batch_size=32)

        ml_decoder_results = []
        for list_size in LIST_SIZES:
            print(f'Evaluating ML Decoder: List Size {list_size}')
            ml_results = performance_comparison(
                trainer, polar_code_gen, snr_range, channel_type, [list_size], NUM_TRIALS
            )
            ml_decoder_results.append(ml_results)

        # SCL decoding
        scl_decoder_results = []
        for list_size in LIST_SIZES:
            print(f'Evaluating SCL Decoder: List Size {list_size}')
            scl_results = run_scl_decoder(polar_code_gen, snr_range, list_size, channel_type, NUM_TRIALS)
            scl_decoder_results.append(scl_results)

        results[channel_type] = {
            'ML': ml_decoder_results,
            'SCL': scl_decoder_results
        }

    return results

###################################################
# Part six: ML vs. SCL comparison plots and main()
#####################################################
def plot_results(results, channel_type):
    # Structure produced by compare_decoders():
    #   results[channel_type]['ML'][i]  -> {list_size: {'BER': [...], 'BLER': [...]}}
    #   results[channel_type]['SCL'][i] -> [{'SNR': ..., 'BER': ..., 'BLER': ...}, ...]
    ml_runs = results[channel_type]['ML']
    scl_runs = results[channel_type]['SCL']
    snr_range = [res['SNR'] for res in scl_runs[0]]
    list_sizes = [next(iter(ml_res)) for ml_res in ml_runs]  # one list-size key per ML run

    plt.figure(figsize=(12, 8))

    # Plot BER
    plt.subplot(2, 1, 1)
    for i, list_size in enumerate(list_sizes):
        ml_ber = ml_runs[i][list_size]['BER']
        scl_ber = [res['BER'] for res in scl_runs[i]]

        plt.plot(snr_range, ml_ber, label=f'ML Decoder (List Size {list_size})')
        plt.plot(snr_range, scl_ber, '--', label=f'SCL Decoder (List Size {list_size})')

    plt.title(f'{channel_type} Channel - BER Performance')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BER')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    # Plot BLER
    plt.subplot(2, 1, 2)
    for i, list_size in enumerate(list_sizes):
        ml_bler = ml_runs[i][list_size]['BLER']
        scl_bler = [res['BLER'] for res in scl_runs[i]]

        plt.plot(snr_range, ml_bler, label=f'ML Decoder (List Size {list_size})')
        plt.plot(snr_range, scl_bler, '--', label=f'SCL Decoder (List Size {list_size})')

    plt.title(f'{channel_type} Channel - BLER Performance')
    plt.xlabel('SNR (dB)')
    plt.ylabel('BLER')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, which="both", ls="--")

    plt.tight_layout()
    plt.show()
#####################################################
def main():
    try:
        BLOCK_LENGTH = 32
        INFO_BITS = 16
        LEARNING_RATE = 1e-3
        EPOCHS = 50
        BATCH_SIZE = 32
        NUM_SAMPLES_TRAIN = 10000  # Number of training samples
        NUM_TRIALS_PERF = 1000  # Number of blocks simulated at each SNR for BER/BLER
        SNR_RANGE_AWGN = np.linspace(0, 5, 11)  # 0 to 5 dB in 0.5 dB steps
        SNR_RANGE_RAYLEIGH = np.linspace(0, 10, 11)  # 0 to 10 dB in 1 dB steps
        LIST_SIZES = [1, 8, 16]  # List-size labels used when plotting


        polar_code_gen = PolarCodeGenerator(N=BLOCK_LENGTH, K=INFO_BITS)
        results = {}
        channels = {
            'AWGN': EnhancedChannelSimulator(channel_type='AWGN'),
            'Rayleigh': EnhancedChannelSimulator(channel_type='Rayleigh')
        }

        for channel_name, channel in channels.items():
            logging.info(f"Analyzing {channel_name} Channel")
            # Prepare dataset for training and validation
            logging.info(f"Generating training data ({NUM_SAMPLES_TRAIN} samples) for {channel_name} at SNR=5dB")
            # Train at a fixed moderate SNR, evaluate performance across a range
            X, y = prepare_polar_dataset(polar_code_gen, num_samples=NUM_SAMPLES_TRAIN, snr_db=5.0, channel_type=channel_name)

            # Convert numpy arrays to PyTorch tensors
            X_tensor = torch.FloatTensor(X)
            y_tensor = torch.FloatTensor(y) # y_tensor shape: [num_samples, K]

            # Flatten input features for the FCNN-based decoder
            X_tensor_flat = X_tensor.view(X_tensor.shape[0], -1) # Shape [num_samples, N]

            # No need to split y_tensor into binary labels, keep its original shape [num_samples, K]
            # The BCELoss will expect predictions of shape [batch_size, K] and targets of shape [batch_size, K]

            # Verify tensor shapes before training
            print("\n🔬 Processed Tensor Shapes (Training):")
            print(f"X_tensor_flat shape: {X_tensor_flat.shape}")
            print(f"y_tensor shape: {y_tensor.shape}")

            # Calculate the input size for the RNN based on the flattened data
            input_feature_size = X_tensor_flat.size(1) # This will be N (BLOCK_LENGTH)
            output_size = INFO_BITS # The RNN should output K bits
            print(f"Calculated input feature size: {input_feature_size}")
            print(f"Calculated output size (info bits): {output_size}")


            # Feed-forward decoder that outputs K bit probabilities
            rnn_model = EnhancedRNNDecoder(input_size=input_feature_size, output_size=output_size)
            rnn_trainer = DecoderTrainer(rnn_model)

            logging.info(f"Starting training for {channel_name} Channel RNN Decoder")
            # Train the RNN Decoder with multi-bit labels
            # Pass the flattened X and original y tensors
            train_losses, val_losses = rnn_trainer.train(X_tensor_flat, y_tensor, epochs=EPOCHS, batch_size=BATCH_SIZE)
            logging.info(f"Finished training for {channel_name} Channel RNN Decoder")


            # Perform performance comparison across SNR range
            snr_range = SNR_RANGE_AWGN if channel_name == 'AWGN' else SNR_RANGE_RAYLEIGH
            logging.info(f"Evaluating performance for {channel_name} Channel across SNR range: {snr_range}")

            # Evaluate BER/BLER over the SNR range; the same trained decoder is evaluated
            # once per list-size label, on freshly generated data each time.
            performance_results = performance_comparison(
                rnn_trainer, polar_code_gen, snr_range, channel_name, LIST_SIZES, NUM_TRIALS_PERF
            )
            logging.info(f"Finished performance evaluation for {channel_name} Channel")


            # Plotting Confusion Matrix for the test set
            # First, prepare a separate test set for confusion matrix visualization
            # Use a moderate SNR, e.g., 3dB, and a reasonable number of samples
            logging.info(f"Generating test data ({NUM_TRIALS_PERF} samples) for Confusion Matrix at SNR=3dB for {channel_name}")
            X_test_cm, y_test_cm = prepare_polar_dataset(polar_code_gen, num_samples=NUM_TRIALS_PERF, snr_db=3.0, channel_type=channel_name)
            X_test_cm_tensor = torch.FloatTensor(X_test_cm).view(X_test_cm.shape[0], -1)
            y_test_cm_tensor = torch.FloatTensor(y_test_cm) # Keep original shape [num_samples, K]


            predictions_test = rnn_trainer.predict(X_test_cm_tensor) # predictions_test shape: [num_samples, K]
            actual_labels_test = y_test_cm_tensor.numpy() # actual_labels_test shape: [num_samples, K]

            # To plot a single confusion matrix, we need to flatten the predictions and actual labels
            # This treats each predicted bit as an independent classification outcome.
            predictions_flat = predictions_test.flatten()
            actual_labels_flat = actual_labels_test.flatten()

            # Calculate and display confusion matrix
            logging.info(f"Plotting Confusion Matrix for {channel_name} Channel Test Set")
            cm = confusion_matrix(actual_labels_flat, predictions_flat, labels=[0, 1])  # keep a 2x2 matrix even if one class is absent
            ConfusionMatrixDisplay(cm, display_labels=[0, 1]).plot() # Specify display_labels
            plt.title(f'Confusion Matrix - {channel_name}')
            plt.xlabel('Predicted label (All Info Bits)')
            plt.ylabel('True label (All Info Bits)')
            plt.show()


            # Plot comprehensive analysis (training loss, BER, BLER)
            logging.info(f"Plotting performance analysis for {channel_name} Channel")
            plot_comprehensive_analysis(
                 train_losses, val_losses, performance_results, snr_range, channel_name
            )


            results[channel_name] = {
                'decoder': rnn_trainer,
                'train_losses': train_losses,
                'val_losses': val_losses,
                'performance': performance_results
            }

        logging.info("🎉 Simulation Complete!")
        # The traditional SCL vs. ML comparison is slow; run it separately with
        # compare_decoders() and plot each channel with plot_results().
        return results
    except Exception as e:
        logging.error(f"🆘 Comprehensive Simulation Error: {e}")
        traceback.print_exc()
        return None


if __name__ == "__main__":
    main()