## Imports

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
import sys
import os
import copy
import argparse
import scipy.io as sio
from pathlib import Path
from matplotlib import pyplot as plt
from tqdm import tqdm as tqdm
sys.path.append('./')

from ncsnv2.models.ncsnv2 import NCSNv2Deepest
from data.loaders          import Channels
from torch.utils.data import DataLoader
from utils.logger import get_logger
from utils.util import *
from data.sample_generator import *

%load_ext autoreload
%autoreload 2

### Parameters

In [None]:
class Args:
    """Experiment options, used as a plain attribute namespace (see `args = Args()`)."""
    # Hardware selection
    use_GPU = True
    gpu = 0
    # Channel / data options
    channel = '3gpp'
    save_channels = 0
    pilot_alpha = [32 / 64]
    # Sampler options
    sample_joint = True
    noise_boost = 1e-3

args = Args()

# logger
logger = get_logger()

# Device setting.
# CUDA reads CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES only when it is first initialised, so they must be
# set before any torch.cuda call (is_available, empty_cache, ...); otherwise `args.gpu` has no effect.
# NOTE(review): in a kernel where CUDA was already initialised, a re-run cannot change the visible GPU.
if args.use_GPU:
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

if args.use_GPU and torch.cuda.is_available():
    torch.cuda.empty_cache()
    # With the mask above, the selected GPU is visible as index 0
    device = 'cuda:0'
    # Always !!! -- keep matmuls/convolutions in full float32 precision (no TF32)
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    # Sometimes
    torch.backends.cudnn.benchmark = True
else:
    device = 'cpu'

logger.info(f"Device set to {device}.")

train_seed, val_seed = 1234, 4321
# Seed the global RNGs so the torch.rand/randn draws in the sampling loop are reproducible
# for the seed that names the result directory.
torch.manual_seed(val_seed)
np.random.seed(val_seed)
result_dir = 'results_seed%d' % val_seed
os.makedirs(result_dir, exist_ok=True)
our_dir = 'results_paper_seed4321'
logger.info(f"Results will be saved to {result_dir}.")
  
# Load the pretrained score model onto the selected device.
# map_location lets a checkpoint saved on GPU be loaded when `device` fell back to 'cpu'.
# NOTE(review): the checkpoint also stores a config object; on torch versions where torch.load
# defaults to weights_only=True this call may need weights_only=False (only for trusted files).
target_weights = './models/numLambdas1_lambdaMin0.5_lambdaMax0.5_sigmaT30.0/final_model_3gpp_64.pt'
contents = torch.load(target_weights, map_location=device)
# Extract config and load model
config = contents['config']
# Build the network on the same device as the rest of the computation
diffuser = NCSNv2Deepest(config).to(device)
diffuser.load_state_dict(contents['model_state'])
diffuser.eval()

# SNR grid in dB (wider sweep used before: np.arange(-10, 17.5, 2.5)) and matching noise powers 10^(-SNR/10)
snr_range = np.arange(15, 22.5, 2.5)
noise_range = 10 ** (-snr_range / 10.)

# Sampler settings: the base config is edited in place, then cloned into the validation config
config.sampling.steps_each = 3
val_config = copy.deepcopy(config)
val_config.data.channel = args.channel
val_config.data.mod_n = 4
val_config.model.step_size = 1e-10

# Set some parameters derived from the validation config
NR, NT = val_config.data.image_size[0], val_config.data.image_size[1]
M = int(np.sqrt(val_config.data.mod_n))
num_channels = 50
# Total Langevin iterations = number of noise levels x steps spent at each level
total_iter = int(config.model.num_classes * config.sampling.steps_each)
noise_boost = args.noise_boost
logger.info(f"Size of the channels: {NR}x{NT}.")
logger.info(f"Total number of iterations: {total_iter}.")
logger.info(f"Total number of iterations: {total_iter}.")

### Load channels

In [None]:
# Load the 3GPP channel bank and keep a fixed validation slice of `num_channels` realisations
mat_contents = sio.loadmat('data/H_bank_64.mat')
H = mat_contents['H_bank']
# Realisations 9500 .. 9500 + num_channels - 1, as an independent copy of the slice
H_val_complex = torch.tensor(H[9500:9500 + num_channels, :, :]).detach().numpy()

# Sample generator used for symbol modulation / detection
generator = sample_generator(num_channels, val_config.data.mod_n, val_config.data.image_size[0])
aux = torch.tensor(H_val_complex)

# Real-valued equivalent of each complex channel:  [[Re(H), -Im(H)], [Im(H), Re(H)]],
# stored in the default floating dtype (as torch.empty would allocate it)
H_symbols_batch = torch.cat(
    (torch.cat((torch.real(aux), -torch.imag(aux)), dim=2),
     torch.cat((torch.imag(aux), torch.real(aux)), dim=2)),
    dim=1,
).to(torch.get_default_dtype())
logger.info(f"Channels loaded.")

### Loop for joint estimation of H and X

In [None]:
# Experiment grid: number of data symbols per channel realisation and number of pilots
batch_size_x_list = [50]
pilots_list = [30]

for batch_size_x in batch_size_x_list:
    for pilots in pilots_list:
        logger.info(f"Starting experiment with this number of pilots: {pilots}.")
        logger.info(f"Starting experiment with this number of symbols: {batch_size_x}.")

        # Per-experiment results, one entry per SNR point (initialised here, NOT inside the SNR loop,
        # so that every SNR's channel estimate is kept)
        SER_langevin = []
        H_list = []
        oracle_log = np.zeros((len(snr_range), total_iter))
        val_config.data.num_pilots = pilots
        logger.info(f"Pilots set in val_config: {val_config.data.num_pilots}.")

        # Load pilots and ground-truth (Hermitian) channels for this pilot count
        dataset_pilots = Channels(val_seed, val_config, H=H_val_complex, norm="global")
        batch_size = len(dataset_pilots)
        loader = DataLoader(dataset_pilots, batch_size=num_channels,
                            shuffle=False, num_workers=0, drop_last=True)
        samples_pilots = next(iter(loader))
        # Separate name for the pilot tensor so it does not overwrite the loop variable `pilots`
        pilot_matrix = samples_pilots['P'].to(device)
        pilots_conj = torch.conj(torch.transpose(pilot_matrix, -1, -2))
        H_herm = samples_pilots['H_herm'].to(device)
        H_herm_complex = H_herm[:, 0] + 1j * H_herm[:, 1]

        # Batched identity for the x-likelihood covariance. Its batch size must match H_current_x,
        # which has `num_channels` entries; it is SNR-independent, so build it once.
        Id = torch.as_tensor(batch_identity_matrix(2 * NR, 2 * NR, num_channels), device=device)

        # Start the loop for all SNRs
        for snr_idx, local_noise in enumerate(noise_range):

            # Setting parameters for each SNR
            iter_lang = 0
            if snr_range[snr_idx] < 5:
                temp_x = 0.5
                sigmas_x = np.linspace(0.6, 0.01, config.model.num_classes)
                epsilon = 1E-4
            else:
                temp_x = 0.1
                sigmas_x = np.linspace(0.8, 0.01, config.model.num_classes)
                epsilon = 4E-5

            # Noisy pilot observations and random initial channel estimate
            y_pilots = torch.matmul(pilots_conj, H_herm_complex)
            y_pilots = y_pilots + np.sqrt(local_noise) * torch.randn_like(y_pilots)
            H_current = torch.randn_like(H_herm_complex)
            oracle = H_herm_complex

            # Prepare data associated to the symbols.
            # NOTE(review): (1 + 1) * rand + 1 draws from U[1, 3); U[-1, 1) would be `2 * rand - 1` - confirm intent.
            x_current = ((1 + 1) * torch.rand(num_channels, 2 * NT, batch_size_x) + 1).to(device=device)
            indices = generator.random_indices(NT, batch_size_x * num_channels)
            j_indices = generator.joint_indices(indices)
            x_true = generator.modulate(indices)
            x_true = torch.reshape(x_true, (num_channels, batch_size_x, 2 * NT))
            x_true = torch.transpose(x_true, -1, -2)
            y_x = torch.matmul(H_symbols_batch.double(), x_true.double()).to(device=device).float()
            y_x = y_x + np.sqrt(local_noise) * torch.randn_like(y_x)
            # Real-valued channel estimate used by the x-update; kept on `device` to avoid per-step copies
            H_current_x = torch.zeros([num_channels, 2 * NR, 2 * NT], device=device)

            # Create joint vector of measurements (pilot rows followed by conjugated symbol rows)
            y_x_complex = y_x.chunk(2, dim=1)
            y_x_complex = y_x_complex[0] - 1j * y_x_complex[1]
            y_x_complex = torch.transpose(y_x_complex, -1, -2)
            y_H = torch.cat((y_pilots, y_x_complex.to(device=device)), dim=1)

            with torch.no_grad():
                for step_idx in tqdm(range(config.model.num_classes)):
                    # Compute current step size and noise power
                    current_sigma = diffuser.sigmas[step_idx].item()
                    current_sigma_x = sigmas_x[step_idx]

                    # Noise-level labels for the diffusion model
                    labels = torch.full((H_current.shape[0],), step_idx, dtype=torch.long, device=device)

                    # Step size for each dynamic
                    step_H = val_config.model.step_size * \
                        (current_sigma / val_config.model.sigma_end) ** 2
                    step_x = epsilon * \
                        (current_sigma_x / sigmas_x[-1]) ** 2

                    # For each step spent at that noise level
                    for inner_idx in range(config.sampling.steps_each):

                        # Real block form [[Re, Im], [-Im, Re]] of the non-Hermitian channel estimate.
                        # NOTE(review): H_symbols_batch uses [[Re, -Im], [Im, Re]]; confirm the two sign
                        # conventions are intended to agree with how Channels returns H_herm.
                        H_current_nonHerm = torch.transpose(torch.conj(H_current), 2, 1)
                        H_current_x[:, 0:NR, 0:NT] = torch.real(H_current_nonHerm)
                        H_current_x[:, 0:NR, NT:] = torch.imag(H_current_nonHerm)
                        H_current_x[:, NR:, 0:NT] = -torch.imag(H_current_nonHerm)
                        H_current_x[:, NR:, NT:] = torch.real(H_current_nonHerm)

                        #------------------------#
                        # Compute Langevin for x #
                        #------------------------#

                        # Score of the prior
                        x_gaussian = torch.transpose(x_current, 2, 1)
                        Zi_hat = gaussian(x_gaussian.reshape(batch_size_x * num_channels, 2 * NT),
                                          generator, current_sigma_x**2, NT, M, device)
                        Zi_hat = torch.transpose(torch.reshape(Zi_hat, (num_channels, batch_size_x, 2 * NT)), 2, 1)
                        prior = (Zi_hat - x_current) / current_sigma_x**2

                        # Score of the likelihood: H^T (sigma_x^2 H H^T + noise I)^-1 (y - H x)
                        H_current_x_T = torch.transpose(H_current_x, 2, 1)
                        diff = y_x - torch.matmul(H_current_x, x_current)
                        cov_matrix = (current_sigma_x**2) * torch.bmm(H_current_x, H_current_x_T) + local_noise * Id
                        cov_matrix = torch.inverse(cov_matrix)
                        grad_likelihood = torch.matmul(cov_matrix, diff.float())
                        grad_likelihood = torch.matmul(H_current_x_T, grad_likelihood)
                        del cov_matrix

                        # Score of the posterior
                        grad = grad_likelihood + prior

                        # Noise generation
                        noise = np.sqrt(2 * temp_x * step_x) * \
                            torch.randn(num_channels, 2 * NT, batch_size_x).to(device=device)

                        # Update
                        x_current = x_current + step_x * grad + noise

                        #------------------------#
                        # Compute Langevin for H #
                        #------------------------#

                        # Score of the prior (diffusion model works on real/imag channels)
                        current_real = torch.view_as_real(H_current).permute(0, 3, 1, 2)
                        score = diffuser(current_real, labels)
                        score = torch.view_as_complex(score.permute(0, 2, 3, 1).contiguous())

                        # Score of the likelihood
                        if args.sample_joint:
                            # Use pilots and the current symbol estimates as a joint forward operator
                            forward_complex = x_current.chunk(2, dim=1)
                            forward_complex = forward_complex[0] - 1j * forward_complex[1]
                            forward_complex = torch.transpose(forward_complex, -1, -2)
                            forward_H = torch.cat((pilots_conj, forward_complex), dim=1)
                            forward_herm = torch.conj(torch.transpose(forward_H, -1, -2))
                            meas_grad = torch.matmul(forward_herm,
                                                     torch.matmul(forward_H, H_current) - y_H)
                        else:
                            meas_grad = torch.matmul(pilot_matrix,
                                                     torch.matmul(pilots_conj, H_current) - y_pilots)

                        # Noise generation
                        grad_noise = np.sqrt(2 * step_H * noise_boost) * torch.randn_like(H_current)

                        # Update
                        H_current = H_current \
                            + step_H * (score - meas_grad / (local_noise / 2. + current_sigma ** 2)) \
                            + grad_noise

                        # Store NMSE against the ground truth, averaged over channels
                        nmse = torch.sum(torch.square(torch.abs(H_current - oracle)), dim=(-1, -2)) / \
                            torch.sum(torch.square(torch.abs(oracle)), dim=(-1, -2))
                        oracle_log[snr_idx, iter_lang] = torch.mean(nmse).item()
                        iter_lang = iter_lang + 1

            # Keep a CPU copy so the saved results can be loaded without a GPU
            H_list.append(H_current_x.cpu())
            x_hat = torch.transpose(x_current, -1, -2).reshape(num_channels * batch_size_x, 2 * NT).cpu()
            SER_langevin.append(1 - sym_detection(x_hat, j_indices, generator.real_QAM_const, generator.imag_QAM_const))
            # Final NMSE (dB) of the SNR point just finished (later rows of oracle_log are still zero)
            print(snr_range[snr_idx], 10 * np.log10(oracle_log[snr_idx, -1]))

        torch.cuda.empty_cache()

        # Results of this (batch_size_x, pilots) configuration.
        # NOTE(review): only the last configuration's save_dict survives the loop; the next cell saves it.
        save_dict = {
            'snr_range': snr_range,
            'val_config': val_config,
            'oracle_log': oracle_log,
            'H_val_complex': H_val_complex,
            'H_symbols_batch': H_symbols_batch,
            'H_current_x': H_list,
            'SER_langevin': SER_langevin
        }

In [None]:
# Save the results of the last (batch_size_x, pilots) configuration run above.
save_path = result_dir + '/3GPP/%2sx%2s/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step9.pt' % \
    (NR, NT, args.channel, val_config.data.num_pilots, batch_size_x, config.sampling.steps_each)
# Only result_dir itself is created during setup; create the 3GPP/<NR>x<NT> sub-directory here,
# otherwise torch.save raises FileNotFoundError.
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(save_dict, save_path)

## Plots

#### Different number of symbols

In [None]:

dir_path = 'results_seed4321'
num_symbs_list = [2, 10, 20, 30, 40, 50, 70, 100, 150]
NR = 64
NT = 32
num_pilots = 30

# Result-file template; the last %s selects the run type: '' = single Langevin, 'reuse_' = joint Langevin
result_fmt = ('/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_'
              '%1sjointstepxnoiselevel_%sbatch_temp_adapt_step_adapt.pt')

# Curves to draw (insertion order = legend order): single Langevin with 50 symbols,
# then joint Langevin for num_symbs_list[1:8]
pilot_dic_num_symbs = {
    dir_path + result_fmt % (NR, NT, args.channel, num_pilots, 50, config.sampling.steps_each, ''):
        r'Single Langevin',
}
for n_symbs in num_symbs_list[1:8]:
    joint_path = dir_path + result_fmt % (NR, NT, args.channel, num_pilots, n_symbs,
                                          config.sampling.steps_each, 'reuse_')
    pilot_dic_num_symbs[joint_path] = r'Joint Langevin - %2s symb' % n_symbs
# On this curve the second-to-last SNR point is blanked out (NaN points are not drawn)
masked_key = dir_path + result_fmt % (NR, NT, args.channel, num_pilots, num_symbs_list[5],
                                      config.sampling.steps_each, 'reuse_')

import matplotlib as mpl

plt.rcParams['figure.figsize'] = [6, 5]

plt.rc('font', size=11)          # controls default text sizes
plt.rc('axes', titlesize=12)     # fontsize of the axes title
plt.rc('axes', labelsize=15)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=15)    # fontsize of the tick labels
plt.rc('ytick', labelsize=15)    # fontsize of the tick labels
plt.rc('legend', fontsize=13)    # legend fontsize

mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['figure.facecolor'] = 'white'
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

linewidth = 1.5
markersize = 5
fig, ax = plt.subplots(figsize=(7, 5))
markers_list = ['o', '.', ',', 'x', '+', 's', '>', 'v']
colors = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:brown', 'tab:cyan',
          'tab:purple', 'tab:pink', 'tab:orange', 'tab:brown']

snr_offset_db = 10 * np.log10(config.data.image_size[1])
for idx, (key, label) in enumerate(pilot_dic_num_symbs.items()):
    data = torch.load(key)
    # Plot each run against the SNR grid stored with it; the kernel's `snr_range` may belong to a
    # different sweep (falls back to it for files that do not store one)
    snr_axis = np.asarray(data.get('snr_range', snr_range)) + snr_offset_db
    log_SER = list(data['SER_langevin'])  # copy, so the loaded results are not modified
    if key == masked_key:
        log_SER[-2] = np.nan
    ax.semilogy(snr_axis, log_SER,
                linewidth=linewidth,
                linestyle='solid' if 'reuse' in key else 'dashed',
                label=label,
                markersize=markersize,
                marker=markers_list[idx],
                color=colors[idx])

# L1 baselines (loaded for reference; they are not plotted in this figure)
data_lasso = torch.load(str(Path(dir_path).parent.absolute())
                        + '/results_l1_baseline_lifted1/model_kronecker_channel_kronecker/l1_results_Nt32_Nr64_fineAlpha_32pilots.pt')
complete_log = data_lasso['complete_log'][0, 0, 0, 0, :, -1, :]
data_fsad = torch.load(str(Path(dir_path).parent.absolute())
                       + '/results_l1_baseline_lifted4/model_kronecker_channel_kronecker/l1_results_Nt32_Nr64_fineAlpha_32pilots.pt')
complete_log_fsad = data_fsad['complete_log'][0, 0, 0, 0, :, -1, :]

ax.grid(True)
ax.set_xlabel('SNR [dB]')
ax.set_ylabel('SER')
ax.legend(bbox_to_anchor=(0.53, .62))
ax.set_xlim([3, 36])

## Save
fig.savefig('recon_SER_64x32.pdf', bbox_inches='tight', pad_inches=0.05, dpi=300)

In [None]:
# SER of the last file loaded by the plot above, in log10 and linear scale
ser_values = data['SER_langevin']
np.log10(ser_values), ser_values

In [None]:
# Final-iteration NMSE (dB) per SNR for the 37-pilot / 50-symbol runs:
# single-Langevin file (step3 suffix) vs joint-Langevin file ('reuse', step_adapt suffix)
nmse_fmt = '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_%s.pt'
single_run = torch.load(dir_path + nmse_fmt % (NR, NT, args.channel, 37, 50, config.sampling.steps_each,
                                               'batch_temp_adapt_step3'))
joint_run = torch.load(dir_path + nmse_fmt % (NR, NT, args.channel, 37, 50, config.sampling.steps_each,
                                              'reuse_batch_temp_adapt_step_adapt'))
(10 * np.log10(single_run['oracle_log'][0, 0, 0, :, :, 0])[:, -1],
 10 * np.log10(joint_run['oracle_log'][0, 0, 0, :, :, 0])[:, -1])


#### Different number of pilots

In [None]:

dir_path = 'results_seed4321'
pilots_val = [24, 27, 30, 32, 37, 64]
NR = 64
NT = 32
num_symbs = 50

# Result-file template; the last %s selects the run type: '' = single Langevin, 'reuse_' = joint Langevin
result_fmt = ('/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_'
              '%1sjointstepxnoiselevel_%sbatch_temp_adapt_step_adapt.pt')
# Pilot counts to plot (pilots_val[3] = 32 is left out)
plotted_pilots = [pilots_val[i] for i in (0, 1, 2, 4, 5)]

# Insertion order = legend order: all single-Langevin curves, then all joint-Langevin curves
pilot_dic_num_symbs = {}
for run_tag, label_fmt in (('', r'Single Langevin - %2s pilots'),
                           ('reuse_', r'Joint Langevin - %2s pilot')):
    for n_pilots in plotted_pilots:
        result_path = dir_path + result_fmt % (NR, NT, args.channel, n_pilots, num_symbs,
                                               config.sampling.steps_each, run_tag)
        pilot_dic_num_symbs[result_path] = label_fmt % n_pilots

import matplotlib as mpl

plt.rcParams['figure.figsize'] = [6, 5]

plt.rc('font', size=11)          # controls default text sizes
plt.rc('axes', titlesize=12)     # fontsize of the axes title
plt.rc('axes', labelsize=15)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=15)    # fontsize of the tick labels
plt.rc('ytick', labelsize=15)    # fontsize of the tick labels
plt.rc('legend', fontsize=11)    # legend fontsize

mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['figure.facecolor'] = 'white'
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

linewidth = 1.5
markersize = 5
fig, ax = plt.subplots(figsize=(7, 5))

# Same colour/marker for the single and joint curve of a given pilot count
colors = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:brown',
          'tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:brown']
markers_list = ['o', 'x', '+', 's', '>', 'o', 'x', '+', 's', '>']

snr_offset_db = 10 * np.log10(config.data.image_size[1])
for idx, (key, label) in enumerate(pilot_dic_num_symbs.items()):
    data = torch.load(key)
    # Plot each run against the SNR grid stored with it (falls back to the kernel's `snr_range`)
    snr_axis = np.asarray(data.get('snr_range', snr_range)) + snr_offset_db
    log_NMSE = 10 * np.log10(data['oracle_log'][0, 0, 0, :, :, 0])
    is_joint = 'reuse' in key
    # Final-iteration NMSE per SNR; joint curves solid and semi-transparent, single curves dashed
    ax.plot(snr_axis, log_NMSE[:, -1],
            linewidth=linewidth,
            linestyle='solid' if is_joint else 'dashed',
            marker=markers_list[idx],
            label=label,
            markersize=markersize,
            alpha=0.7 if is_joint else None,
            color=colors[idx])

ax.grid(True, which="both")
ax.set_xlabel('SNR [dB]')
ax.set_ylabel('NMSE [dB]')
ax.legend(bbox_to_anchor=(0.47, .62))
ax.set_xlim([4, 36])

## Save
fig.savefig('recon_pilots_64x32.pdf', bbox_inches='tight', pad_inches=0.05, dpi=300)

In [None]:
dir_path = 'results_seed4321'
pilots_val = [20, 22, 25, 27, 30, 32, 37]
NR = 64
NT = 32
num_symbs = 50
pp = 2
pilot_dic_num_symbs = {
        #     dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step3.pt' % 
        #             (NR, NT, args.channel, pilots_val[0], num_symbs, config.sampling.steps_each): 
        #             r'Single Langevin - %2s pilots' % pilots_val[0],
        #     dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step3.pt' % 
        #             (NR, NT, args.channel, pilots_val[1], num_symbs, config.sampling.steps_each): 
        #             r'Single Langevin - %2s pilots' % pilots_val[1],
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[2], num_symbs, config.sampling.steps_each): 
                    r'Single Langevin - %2s pilots' % pilots_val[2],
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[3], num_symbs, config.sampling.steps_each): 
                    r'Single Langevin - %2s pilots' % pilots_val[3],
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[4], num_symbs, config.sampling.steps_each): 
                    r'Single Langevin - %2s pilots' % pilots_val[4],
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[5], num_symbs, config.sampling.steps_each): 
                    r'Single Langevin - %2s pilots' % pilots_val[5],
        #     dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step3.pt' % 
        #             (NR, NT, args.channel, pilots_val[6], num_symbs, config.sampling.steps_each): 
        #             r'Single Langevin - %2s pilots' % pilots_val[6],
        #     dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step3.pt' % 
        #             (NR, NT, args.channel, pilots_val[0], num_symbs, config.sampling.steps_each): 
        #             r'Joint Langevin - %2s pilot' % (pilots_val[0] ),
        #     dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step3.pt' % 
        #             (NR, NT, args.channel, pilots_val[1], num_symbs, config.sampling.steps_each): 
        #             r'Joint Langevin - %2s pilot' % (pilots_val[1] ),
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[2], num_symbs, config.sampling.steps_each): 
                    r'Joint Langevin - %2s pilot' % (pilots_val[2] ),
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[3], num_symbs, config.sampling.steps_each): 
                    r'Joint Langevin - %2s pilot' % (pilots_val[3] ),
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[4], num_symbs, config.sampling.steps_each): 
                    r'Joint Langevin - %2s pilot' % (pilots_val[4] ),
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step3.pt' % 
                    (NR, NT, args.channel, pilots_val[5], num_symbs, config.sampling.steps_each): 
                    r'Joint Langevin - %2s pilot' % (pilots_val[5] ),
        #     dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step3.pt' % 
        #             (NR, NT, args.channel, pilots_val[6], num_symbs, config.sampling.steps_each): 
        #             r'Joint Langevin - %2s pilot' % (pilots_val[6] )
                }


import matplotlib as mpl

# --- Figure styling -----------------------------------------------------------
# usetex=True needs a working LaTeX installation; rendering/savefig fails
# without one.
plt.rcParams['figure.figsize'] = [6, 5]

plt.rc('font', size=11)          # controls default text sizes
plt.rc('axes', titlesize=12)     # fontsize of the axes title
plt.rc('axes', labelsize=15)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=15)    # fontsize of the tick labels
plt.rc('ytick', labelsize=15)    # fontsize of the tick labels
plt.rc('legend', fontsize=13)    # legend fontsize

mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['figure.facecolor'] = 'white'
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

linewidth  = 3.5
markersize = 5
fig, ax = plt.subplots(figsize=(7, 5))

# Palette for manual colouring. The curves below currently use matplotlib's
# default colour cycle. 'red' and 'blue' used to be fused into a single
# malformed entry 'red, blue'.
colors = ['red', 'blue', 'gray', 'pink', 'green', 'orange', 'cyan']

# Common x-axis: snr_range shifted by 10*log10(config.data.image_size[1]).
# NOTE(review): presumably converts the simulated SNR into the convention used
# in the paper's plots — confirm against the experiment script.
snr_axis = snr_range + 10 * np.log10(config.data.image_size[1])

for key, curve_label in pilot_dic_num_symbs.items():
    data = torch.load(key)
    # NOTE(review): the layout of 'oracle_log' is not visible here. The fixed
    # [0, 0, 0, :, :, 0] slice yields a 2-D array whose last column is plotted;
    # presumably rows are SNR points and the last column is the final
    # Langevin step — confirm.
    NMSE     = data['oracle_log'][0, 0, 0, :, :, 0]
    log_NMSE = 10 * np.log10(NMSE)
    # SER is not plotted in this cell. It is still computed so the global
    # `log_SER` stays available to later cells, as before.
    log_SER  = 10 * np.log10(data['SER_langevin'])

    if 'reuse' in key:
        # Result files with 'reuse' in their name: solid line, star markers.
        ax.plot(snr_axis, log_NMSE[:, -1],
                linewidth=linewidth,
                linestyle='solid',
                marker='*',
                label=curve_label,
                markersize=markersize,
                alpha=0.7)
    else:
        ax.plot(snr_axis, log_NMSE[:, -1],
                linewidth=linewidth,
                linestyle='dashed',
                marker='o',
                label=curve_label,
                markersize=markersize)

ax.grid(True, which="both")
ax.set_xlabel('SNR [dB]')
ax.set_ylabel('NMSE [dB]')
ax.legend(bbox_to_anchor=(0.53, .62))
ax.set_xlim([3, 40])

## Save
fig.savefig('recon_pilots_64x32.pdf', bbox_inches='tight',
            pad_inches=0.05, dpi=300)

#### Comparison with baselines (Lasso, fsAd, L-DAMP, LMMSE)

In [None]:

dir_path = 'results_seed4321'
# num_symbs_list = [2, 5, 20]
num_symbs_list = [2, 10, 20, 32, 40, 80, 100, 150, 200, 50]  # NOTE(review): unused in this cell
NR = 64
NT = 32
num_pilots = 30
# Langevin result files for 50 symbols, mapped to their legend labels.
pilot_dic_num_symbs = {
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_batch_temp_adapt_step_adapt.pt' %
                    (NR, NT, args.channel, num_pilots, 50, config.sampling.steps_each):
                    r'Single Langevin',
            dir_path + '/final_experiments/%2sx%2s/batch/%s_numpilots%.1f_numsymbols%.1f_%1sjointstepxnoiselevel_reuse_batch_temp_adapt_step_adapt.pt' %
                    (NR, NT, args.channel, num_pilots, 50, config.sampling.steps_each):
                    r'Joint Langevin',
                }


import matplotlib as mpl

# --- Figure styling -----------------------------------------------------------
# usetex=True needs a working LaTeX installation; rendering/savefig fails
# without one.
plt.rcParams['figure.figsize'] = [6, 5]

plt.rc('font', size=11)          # controls default text sizes
plt.rc('axes', titlesize=12)     # fontsize of the axes title
plt.rc('axes', labelsize=15)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=15)    # fontsize of the tick labels
plt.rc('ytick', labelsize=15)    # fontsize of the tick labels
plt.rc('legend', fontsize=13)    # legend fontsize

mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['figure.facecolor'] = 'white'
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath}')

linewidth  = 1.5
markersize = 5
fig, ax = plt.subplots(figsize=(7, 5))
# NOTE(review): currently unused; every curve below sets its colour explicitly
# or takes it from the default cycle.
colors = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange', 'tab:brown', 'tab:cyan', 'tab:purple', 'tab:pink', 'tab:orange', 'tab:brown']

# Common x-axis: snr_range shifted by 10*log10(config.data.image_size[1]).
# NOTE(review): the baseline files below are assumed to be evaluated on the
# same SNR grid as snr_range — confirm.
snr_axis = snr_range + 10 * np.log10(config.data.image_size[1])

# --- Langevin curves ------------------------------------------------------------
for key, curve_label in pilot_dic_num_symbs.items():
    data = torch.load(key)
    # NOTE(review): the layout of 'oracle_log' is not visible here. The
    # [0, 0, 0, :, :, 0] slice is 2-D and its last column is plotted — confirm.
    NMSE     = data['oracle_log'][0, 0, 0, :, :, 0]
    log_NMSE = 10 * np.log10(NMSE)
    # Kept so the global `log_SER` remains available to later cells.
    log_SER  = 10 * np.log10(data['SER_langevin'])

    # Both Langevin curves are red and are told apart by line style and marker.
    # Passing an explicit colour does not advance the default colour cycle, so
    # the baselines below still start at C0.
    if 'reuse' in key:
        ax.plot(snr_axis, log_NMSE[:, -1],
                linewidth=linewidth,
                linestyle='solid',
                label=curve_label,
                markersize=markersize,
                marker='o',
                alpha=0.7,
                color='tab:red')
    else:
        ax.plot(snr_axis, log_NMSE[:, -1],
                linewidth=linewidth,
                linestyle='dashed',
                label=curve_label,
                markersize=markersize,
                marker='*',
                color='tab:red')

# --- Baselines ------------------------------------------------------------------
# Baseline result directories sit next to dir_path, in its parent directory.
baseline_root = Path(dir_path).parent.absolute()

# NOTE(review): Lasso and fsAd are loaded from 'model_kronecker_channel_kronecker'
# directories, while the Langevin, LMMSE and L-DAMP curves use 3GPP channels.
# Confirm these baselines were run on the same channel model before comparing.
data_lasso = torch.load(baseline_root / 'results_l1_baseline_lifted1' / 'model_kronecker_channel_kronecker'
                        / 'l1_results_Nt32_Nr64_fineAlpha_30pilots_lr.pt')
complete_log = data_lasso['complete_log'][0, 0, 0, 0, :, -1, :]
data_fsad = torch.load(baseline_root / 'results_l1_baseline_lifted4' / 'model_kronecker_channel_kronecker'
                       / 'l1_results_Nt32_Nr64_fineAlpha_30pilots_lr.pt')
complete_log_fsad = data_fsad['complete_log'][0, 0, 0, 0, :, -1, :]

data_ml = torch.load(baseline_root / 'results_ml_baseline' / 'model_3gpp_channel_3gpp'
                     / 'results_Nt32_Nr64_30pilots_lr.pt')
complete_log_ml = data_ml['oracle_log'][0, 0, :, :]

data_ldamp = torch.load(baseline_root / 'results' / 'ldamp' / 'train-CDL-C_test-3gpp' / 'results.pt')
complete_log_ldamp = data_ldamp['avg_nmse'][0, 0, :]

# Plot order matters: Lasso and fsAd take the first two default-cycle colours.
ax.plot(snr_axis, 10 * np.log10(np.mean(complete_log, axis=-1)),
        label='Lasso',
        linewidth=linewidth,
        markersize=markersize,
        marker='o')
ax.plot(snr_axis, 10 * np.log10(np.mean(complete_log_fsad, axis=-1)),
        label='fsAd',
        linewidth=linewidth,
        marker='o',
        markersize=markersize)
ax.plot(snr_axis, 10 * np.log10(complete_log_ldamp),
        label='L-DAMP',
        linewidth=linewidth,
        markersize=markersize,
        marker='o',
        color='tab:cyan')
ax.plot(snr_axis, 10 * np.log10(np.mean(complete_log_ml, axis=-1)),
        label='LMMSE',
        linewidth=linewidth,
        markersize=markersize,
        marker='o',
        color='tab:green')

ax.grid(True, which="both")
ax.set_xlabel('SNR [dB]')
ax.set_ylabel('NMSE [dB]')
ax.legend(bbox_to_anchor=(0.38, 0.49))
ax.set_xlim([4, 36])

## Save
fig.savefig('recon_baseline_64x32.pdf', bbox_inches='tight',
            pad_inches=0.05, dpi=300)