## SBRT 2024 - **An Introduction to Generative Artificial Intelligence with Applications in Telecommunications**

### This Python notebook is based on the code at https://github.com/benediktfesl/Diffusion_channel_est, which accompanies the paper [Diffusion-based Generative Prior for Low-Complexity MIMO Channel Estimation](https://arxiv.org/abs/2403.03545).

##### ***Heads-up:*** this notebook is purely pedagogical. It showcases code snippets from the original repository that highlight the steps to generate a channel estimate using a diffusion model; the cells are not meant to be executed as-is.

# Observations

1) This notebook is based on the script `load_and_eval_dm.py`, available at https://github.com/benediktfesl/Diffusion_channel_est/blob/master/load_and_eval_dm.py

2) The `modules` package imported in the cell below is located in the folder https://github.com/benediktfesl/Diffusion_channel_est/tree/master/modules

3) The `DMCE` package imported in the cell below is located in the folder https://github.com/benediktfesl/Diffusion_channel_est/tree/master/DMCE

4) To run this and other scripts from this repository it is also necessary to download the channel data from https://syncandshare.lrz.de/getlink/fi93y1AnwmsvHrAGNqq5zX/ (password: `Diffusion2024`) and move it into a folder named `bin`.

# Importing packages

In [5]:
"""
Train and test script for the DMCE.
"""
import os
import argparse
import modules.utils as ut
import datetime
import csv
import matplotlib.pyplot as plt
import numpy as np
import DMCE
import torch
from DMCE.utils import cmplx2real

CUDA_DEFAULT_ID = 0  # presumably the index of the default CUDA device — TODO confirm


# Parameters configurations

In [6]:
# Use the default CUDA device when one is available and fall back to the CPU otherwise,
# instead of hard-coding 'cuda:0' (which breaks model loading on CPU-only machines).
device = f'cuda:{CUDA_DEFAULT_ID}' if torch.cuda.is_available() else 'cpu'

date_time_now = datetime.datetime.now()
date_time = date_time_now.strftime('%Y-%m-%d_%H-%M-%S')  # convert to str compatible with all OSs

n_dim = 64  # RX antennas
n_dim2 = 16  # TX antennas
num_train_samples = 100_000
num_val_samples = 10_000  # must not exceed size of training set
num_test_samples = 10_000

# enforce the constraint stated above instead of relying on the comment alone
if num_val_samples > num_train_samples:
    raise ValueError('num_val_samples must not exceed num_train_samples')

return_all_timesteps = False  # evaluates all intermediate MSEs
fft_pre = True  # learn channel distribution in angular domain through Fourier transform
reverse_add_random = False  # re-sampling in the reverse process

In [None]:
# set data params
ch_type = '3gpp'  # {quadriga_LOS, 3gpp}
n_path = 3
# more than one TX antenna means every channel is a 2D (RX x TX) matrix
mode = '2D' if n_dim2 > 1 else '1D'

# only the test split is used for evaluation; the train/val splits are thrown away
_, _, data_test = ut.load_or_create_data(
    ch_type=ch_type,
    n_path=n_path,
    n_antennas_rx=n_dim,
    n_antennas_tx=n_dim2,
    n_train_ch=num_train_samples,
    n_val_ch=num_val_samples,
    n_test_ch=num_test_samples,
    return_toep=False,
)
del _

if ch_type.startswith('3gpp') and n_dim2 > 1:
    # restore the (RX, TX) matrix layout using column-major (Fortran) order;
    # presumably the 3GPP loader returns flattened channels — TODO confirm
    data_test = np.reshape(data_test, (-1, n_dim, n_dim2), 'F')

# add a singleton channel axis, convert to torch and split complex values into real/imag parts on dim 1
data_test = cmplx2real(torch.from_numpy(np.asarray(data_test[:, None, :])), dim=1, new_dim=False).float()

if ch_type.startswith('3gpp'):
    # the number of paths is part of the model-directory name used when loading the model
    ch_type = f'{ch_type}_path={n_path}'

In [10]:
# load the model parameter dictionaries
cwd = os.getcwd()
model_dir = os.path.join(cwd, './results/best_models_dm_paper', ch_type)
sim_params = DMCE.utils.load_params(os.path.join(model_dir, 'sim_params'))
num_timesteps = sim_params['diff_model_dict']['num_timesteps']
cnn_dict = sim_params['unet_dict']
diff_model_dict = sim_params['diff_model_dict']

# manually set the correct device for this simulation
cnn_dict['device'] = device

# instantiate the neural network
cnn = DMCE.CNN(**cnn_dict)

# instantiate the diffusion model and give it a reference to the unet model
diffusion_model = DMCE.DiffusionModel(cnn, **diff_model_dict)

# load the parameters of the pre-trained model into the DiffusionModel instance.
# os.listdir() returns entries in arbitrary order, so the listing is sorted to make the
# choice of the checkpoint deterministic (the lexicographically last file name is used).
model_path = os.path.join(model_dir, 'train_models')
model_list = sorted(os.listdir(model_path))
if not model_list:
    raise FileNotFoundError(f'No model checkpoint found in {model_path}')
model_path = os.path.join(model_path, model_list[-1])
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model_params = torch.load(model_path, map_location=device)

diffusion_model.load_state_dict(model_params['model'])

diffusion_model.reverse_add_random = reverse_add_random

In [11]:
# Keyword arguments forwarded to DMCE.Tester
# (per the original comment, these are saved in 'sim_params.json').
tester_dict = dict(
    batch_size=512,
    criteria=['nmse'],
    complex_data=False,
    return_all_timesteps=return_all_timesteps,
    fft_pre=fft_pre,
    mode=mode,
)

# Testing the neural network

In [None]:
# Build the Tester around the pre-trained diffusion model and the prepared test set;
# the remaining settings come from tester_dict.
tester = DMCE.Tester(
    diffusion_model,
    data=data_test,
    **tester_dict,
)

In [None]:
# call the test() function. This returns a dictionary with the testing stats.
# Depending on the size of the test set, this might take a while.
test_dict = tester.test()

# create the output directory for the results (no-op if it already exists);
# nothing is written to it in this cell
os.makedirs('./results/dm_est/', exist_ok=True)

When `tester.test()` is executed, the code below is called. In it, we can observe the steps described in the slides for the channel estimation phase.

The following code implements step 3, the first step of the channel estimation phase in the slides. This happens at line `29` of the function below:

`y = functional.awgn(data_batch, snr, multiplier=self.model.noise_multiplier)`

where an initial estimate of the channel is obtained by implementing 

<img src="./figures/step3.png" width=240 height=40 />

which in code, translates to:

`x + (1 / snr ** 0.5) * multiplier * torch.randn_like(x)`

where `x` is the input data matrix. It can therefore be rewritten as

`H + (1 / snr ** 0.5) * multiplier * torch.randn_like(H)`

with `Ñ` being

`(1 / snr ** 0.5) * torch.randn_like(H)`

This adds AWGN with variance `1/snr` (i.e., standard deviation `1/snr^0.5`) to `H`, before the additional scaling by `multiplier`.

obs: as mentioned in the code, `multiplier` is a scaling factor applied when "the input data is complex but real and imaginary parts are split up"

In [None]:
def _test_nmse(self) -> dict:
    """
    Test function for the NMSE criterion. For SNR values from -10 dB to 40 dB (in 5 dB steps), the test data is
    corrupted with noise and the DiffusionModel estimates the original data from the noisy input. For each SNR
    value, the MSE normalized by the average power of the whole dataset is calculated.

    Returns
    -------
    test_dict: dict
        Dictionary with the tested SNRs in dB ('SNRs') and, for each SNR, the MSE normalized by the average
        data power ('NMSEs_total_power').
    """

    # specify which SNRs should be evaluated: -10, -5, ..., 40 dB
    snr_db_range = torch.arange(-10, 45, 5, dtype=torch.float32, device=self.device)
    snr_range = 10 ** (snr_db_range / 10)

    # the reference data does not depend on the SNR, so prepare it once instead of once per SNR
    is_4d = len(self.data.shape) == 4
    if is_4d:
        # flatten each (rows x cols) sample column-major so it lines up with the flattened estimates
        dim = int(self.data.shape[-1] * self.data.shape[-2])
        x_ref = ut.reshape_fortran(torch.squeeze(self.data), (-1, dim))
    else:
        x_ref = torch.squeeze(self.data)

    nmse_total_power_list = []

    with torch.no_grad():
        for snr in tqdm(iterable=snr_range):
            # test each SNR value
            x_hat = []
            for data_batch in self.dataloader:
                data_batch = data_batch.to(device=self.device)

                # NOTE from the workshop team: the code below implements the
                # step 3. The first one in the channel estimation phase from the slides
                # add noise to the test data
                y = functional.awgn(data_batch, snr, multiplier=self.model.noise_multiplier)

                # calculate channel estimate
                x_est = self.model.generate_estimate(y.to(device=self.device), snr,
                                                     return_all_timesteps=self.return_all_timesteps)
                if self.fft_pre:
                    x_est = ut.complex_1d_fft(x_est, ifft=True, mode=self.mode)
                # move each batch to the CPU right away so the estimates for the whole
                # test set are not accumulated in device memory
                x_hat.append(x_est.cpu())
            x_hat = torch.cat(x_hat, dim=0)

            # calculate NMSE from estimated channels
            if is_4d:
                x_hat = ut.reshape_fortran(x_hat, (-1, dim))
            else:
                x_hat = torch.squeeze(x_hat)
            nmse_total_power_list.append(functional.nmse_torch(x_ref, x_hat, norm_per_sample=False))

    return {'SNRs': snr_db_range.tolist(),
            'NMSEs_total_power': nmse_total_power_list,
            }

In the code above, `generate_estimate()` is called with the noisy input data `y` and the known SNR `snr`. Inside it, the diffusion model step `t` is computed (step 6 in the slides) and passed to `reverse_sample_loop()` to obtain the channel estimate `x_hat`.

In other words, step 6 comes first in the code, estimating the diffusion model step

<img src="./figures/step6.png" width=300 height=40 />.

After that, the data is normalized in step 4:

<img src="./figures/step4.png" width=240 height=40 />

and the DM reverse process is initialized (step 7):

<img src="./figures/step7.png" width=120 height=40 />

Let's recall the steps until now:

<img src="./figures/slide.png" />

In [None]:
def generate_estimate(self, y: torch.Tensor, snr: float, *, add_random: bool = None,
                      return_all_timesteps: bool = False) -> torch.Tensor:
    """
    Implements the estimation algorithm for channel data, but can also be used for other data types. Requires the DM
    to already be trained in order to work properly. It scales the input and performs the reverse process starting
    at the timestep that corresponds to the correct SNR value. Intended for public use.

    Parameters
    ----------
    y : Tensor of shape [batch_size, *self.data_shape]
        batch_size noisy data samples
    snr : float
        Estimated or known SNR of the noisy data sample
    add_random : optional bool
        Specifies whether the reverse_step should be deterministic or include a noise sampling step.
        Resolved via utils.default(add_random, self.reverse_add_random); presumably None falls back to
        self.reverse_add_random — TODO confirm.
    return_all_timesteps : optional bool
        specifies whether to return the data samples of all timesteps or only the final one.

    Returns
    -------
    x_hat : Tensor of shape [batch_size, *self.data_shape]
        The denoised data samples after the whole reverse process
    OR
    x_ts : Tensor of shape [batch_size, t + 1, *self.data_shape]
        Collection of the data samples in all timesteps, stacked along dim 1 by reverse_sample_loop (t is the
        timestep matched to the SNR). x_ts[:, -1] contains the fully denoised data samples.
    """

    add_random = utils.default(add_random, self.reverse_add_random)

    # NOTE from the workshop team: notice the step 6 happening below:
    # estimate t_hat, the time step whose SNR is closest to the given SNR
    t = int(torch.abs(self.snrs - snr).argmin())

    # NOTE from the workshop team: the code below implements the step 4:
    # normalize the input data accordingly (this might differ for other data than normalized channels)
    norm_multiplier = (snr / (1 + snr)) ** 0.5

    # With snr = 1/n2 (n2 being the noise variance), the expression above reads
    #   (snr / (1 + snr)) ** 0.5
    #   = ((1/n2) / (1 + 1/n2)) ** 0.5
    #   = (1 / (n2 + 1)) ** 0.5        (multiply numerator and denominator by n2)
    #   = (n2 + 1) ** -0.5             (the equation shown in step 4)

    # NOTE from the workshop team: the code below implements the step 7:
    x_t = norm_multiplier * y

    # NOTE from the workshop team: the function reverse_sample_loop
    # implements the step 8:
    x_hat = self.reverse_sample_loop(x_t, t, return_all_timesteps=return_all_timesteps, add_random=add_random)
    return x_hat

Below is the code corresponding to step 8 of the channel estimation algorithm shown in the paper and the slides.

Step 8 is the DM loop:

<img src="./figures/step8.png" width=240 height=80 />

which yields the final estimate generated by the model.

In [None]:
def reverse_sample_loop(self, x_t: torch.Tensor, t_start: int,
                        *, return_all_timesteps: bool = False, add_random: bool = False) -> torch.Tensor:
    """
    Implements the whole reverse process down to t=0 by iteratively calling 'reverse_step()'.

    Parameters
    ----------
    x_t : Tensor of shape [batch_size, *self.data_shape]
        batch_size different data samples
    t_start : int
        starting time step of the reverse process
    return_all_timesteps : optional bool
        specifies whether to return the data samples of all timesteps or only the final one.
    add_random : optional bool
        Specifies whether the reverse_step should be deterministic or include a noise sampling step.

    Returns
    -------
    x_0 : Tensor of shape [batch_size, *self.data_shape]
        The denoised data sample after the whole reverse process
    OR
    x_ts : Tensor of shape [batch_size, t_start + 1, *self.data_shape]
        Collection of data samples in all timesteps, stacked along dim 1 (the batch stays on dim 0).
        x_ts[:, 0] is the input x_t and x_ts[:, -1] contains the fully denoised data samples.
    """

    assert t_start <= self.num_timesteps
    assert utils.equal_iterables(x_t.shape[1:], self.data_shape)
    x_all = [x_t]

    # NOTE from the workshop team: the code below implements the loop seen in step 8:
    for t in reversed(range(t_start)):
        x_t = self.reverse_step(x_t, t, add_random=add_random)
        if return_all_timesteps:
            x_all.append(x_t)

    if not return_all_timesteps:
        # clip the final samples for image data to the range [-1, 1]
        # (only the returned tensor is clipped; the intermediate list is not needed here)
        return torch.clamp(x_t, -1, 1) if self.clipping else x_t

    # clip all collected samples for image data to the range [-1, 1]
    if self.clipping:
        x_all = [torch.clamp(x, -1, 1) for x in x_all]
    return torch.stack(x_all, dim=1)