Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit ce4fa35
Showing
12 changed files
with
1,341 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
function [data] = get_data(file_name) | ||
load(file_name); | ||
|
||
%--------------- Required Fields ---------------------- | ||
% y : measurement in time domain | ||
|
||
%--------------- Optional Fields ---------------------- | ||
% x_true : true sparse sequence in time domain | ||
% h_true : true pulse sequence in time domain | ||
% f_max : upper limit for bandwidth of the pulse | ||
% H_subspace : subspace matric for pulse shape | ||
|
||
data.y = y_25dB; | ||
data.x_true = x; | ||
data.h_true = h; | ||
%data.f_max = fhi; | ||
%data.H_subspace = H_subspace; | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
function [params] = get_params(data) | ||
%------------------- Pulse Shape Model Parameters --------------------- | ||
% Set pulse_length to the length of true pulse shape, if known, | ||
% otherwise, set it to an arbitrary constant | ||
params.lambda_g = 10; | ||
if isfield(data,'h_true') | ||
params.pulse_length = length(data.h_true); | ||
else | ||
params.pulse_length = 23; | ||
end | ||
% Set upper limit for pulse bandwidth to an arbitrary constant, if it | ||
% is not provided within the data | ||
if isfield(data,'f_max') | ||
params.f_max = data.f_max; | ||
else | ||
params.f_max = 16e9; | ||
end | ||
% Construct subspace matrix using either dps sequences, or set it as | ||
% identity, if it is not provided within the data | ||
params.subspace_type = "identity"; % dpss | identity | ||
if isfield(data,'H_subspace') | ||
params.H_subspace = data.H_subspace; | ||
params.coeff_order = size(params.H_subspace,2); | ||
else | ||
params.H_subspace = get_pulse_subspace(params); | ||
params.coeff_order = size(params.H_subspace,2); | ||
end | ||
|
||
% General sampler parameters | ||
params.num_of_MCMC_iteration = 10000; | ||
params.burn_in_ratio = 0.75; | ||
params.solver = "NIGS2"; | ||
|
||
% NIGS2 parameters | ||
params.NIGS2_window_size = 10; | ||
|
||
% BGS parameters | ||
params.p0 = 0.9; | ||
params.lambda_x = 1; | ||
|
||
% BGS_Tuple parameters | ||
params.K_tuple = 3; | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
function [H_subspace] = get_pulse_subspace(params) | ||
T = params.pulse_length; | ||
if params.subspace_type == "dpss" | ||
fs = params.fs; | ||
f_max = params.f_max; | ||
[dps_seq,~] = dpss(T,T*f_max/fs,floor(2*T*f_max/fs - 1)); | ||
B = [eye(T-1);-1*ones(1,T-1)]; | ||
M = null([B,dps_seq]); | ||
H_subspace = B*M(1:T-1,:); | ||
elseif params.subspace_type == "identity" | ||
H_subspace = eye(T); | ||
else | ||
fprintf('Incorrect key!'); | ||
end | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
clear all; | ||
close all; | ||
|
||
addpath(strcat(pwd,'\solvers')); | ||
addpath(strcat(pwd,'\datasets')) | ||
|
||
% change file_name to the '.mat' file containing the data | ||
file_name = strcat(pwd,'\datasets\mendel_sequence_data'); | ||
|
||
% load data | this function must be updated/changed according to the | ||
% structure of the input data file | ||
data = get_data(file_name); | ||
|
||
% get parameters | ||
params = get_params(data); | ||
|
||
% run the sampler and get the samples from the posterior distribution | ||
[samples,diagnostics] = run_MCMC(data,params); | ||
|
||
% get estimates from the samples | ||
MMSE_estimates = get_estimates(samples,params); | ||
|
||
% correct time-shift and scaling ambiguities to match the true sequences | ||
if isfield(data,'x_true') && isfield(data,'h_true') | ||
[MMSE_estimates_corrected,samples_corrected] = correct_shift_and_scale(MMSE_estimates,samples,data,params); | ||
|
||
figure;plot(data.x_true);hold on;grid on;plot(MMSE_estimates_corrected.x); | ||
xlabel('n');ylabel('Amplitude');legend('True Sequence','Recovered Sequence'); | ||
title('Recovered Sparse Sequence'); | ||
|
||
figure;plot(data.h_true);hold on;grid on;plot(MMSE_estimates_corrected.h); | ||
xlabel('n');ylabel('Amplitude');legend('True Sequence','Recovered Sequence'); | ||
title('Recovered Pulse Sequence'); | ||
else | ||
figure;plot(MMSE_estimates.x);grid on; | ||
xlabel('n');ylabel('Amplitude');title('Recovered Sparse Sequence'); | ||
|
||
figure;plot(MMSE_estimates.h);grid on; | ||
xlabel('n');ylabel('Amplitude');title('Recovered Pulse Sequence'); | ||
end | ||
|
||
function [MMSE_estimates] = get_estimates(samples,params) | ||
burn_in_ratio = params.burn_in_ratio; | ||
discarded_samples = burn_in_ratio*params.num_of_MCMC_iteration; | ||
|
||
if params.solver == "NIGS1" || params.solver == "NIGS2" | ||
MMSE_estimates.x = mean(samples.x(:,discarded_samples:end),2); | ||
MMSE_estimates.h = mean(samples.h(:,discarded_samples:end),2); | ||
MMSE_estimates.lambda_x = mean(samples.lambda_x(:,discarded_samples:end),2); | ||
MMSE_estimates.lambda_v = mean(samples.lambda_v(:,discarded_samples:end),2); | ||
else | ||
MMSE_estimates.x = mean(samples.x(:,discarded_samples:end),2); | ||
MMSE_estimates.h = mean(samples.h(:,discarded_samples:end),2); | ||
MMSE_estimates.s = mean(samples.s(:,discarded_samples:end),2); | ||
MMSE_estimates.lambda_v = mean(samples.lambda_v(:,discarded_samples:end),2); | ||
end | ||
end | ||
|
||
function [MMSE_estimates_corrected,samples_corrected] = correct_shift_and_scale(MMSE_estimates,samples,data,params) | ||
h_true = data.h_true; | ||
|
||
h_est = MMSE_estimates.h; | ||
x_est = MMSE_estimates.x; | ||
if params.solver == "NIGS1" || params.solver == "NIGS2" | ||
lambda_x_est = MMSE_estimates.lambda_x; | ||
else | ||
s_est = MMSE_estimates.s; | ||
end | ||
lambda_v_est = MMSE_estimates.lambda_v; | ||
T = length(h_est); | ||
|
||
delay_array = -round(T/2):round(T/2); | ||
error = zeros(length(delay_array),1); | ||
alpha = zeros(length(delay_array),1); | ||
for n = 1:length(delay_array) | ||
h_shifted = circshift(h_est,delay_array(n)); | ||
alpha(n) = (h_shifted'*h_true)/(h_shifted'*h_shifted); | ||
h_shifted_scaled = alpha(n)*h_shifted; | ||
error(n) = sum(abs(h_shifted_scaled - h_true).^2); | ||
end | ||
[~,min_idx] = min(error); | ||
MMSE_estimates_corrected.h = circshift(h_est,delay_array(min_idx))*alpha(min_idx); | ||
MMSE_estimates_corrected.x = circshift(x_est,-delay_array(min_idx))/alpha(min_idx); | ||
if params.solver == "NIGS1" || params.solver == "NIGS2" | ||
MMSE_estimates_corrected.lambda_x = circshift(lambda_x_est,-delay_array(min_idx),1)*(alpha(min_idx)^2); | ||
else | ||
MMSE_estimates_corrected.s = circshift(s_est,-delay_array(min_idx),1)*(alpha(min_idx)^2); | ||
end | ||
MMSE_estimates_corrected.lambda_v = lambda_v_est; | ||
|
||
samples_corrected.h = circshift(samples.h,delay_array(min_idx))*alpha(min_idx); | ||
samples_corrected.x = circshift(samples.x,-delay_array(min_idx))/alpha(min_idx); | ||
if params.solver == "NIGS1" || params.solver == "NIGS2" | ||
samples_corrected.lambda_x = circshift(samples.lambda_x,-delay_array(min_idx),1)*(alpha(min_idx)^2); | ||
else | ||
samples_corrected.s = circshift(samples.s,-delay_array(min_idx),1); | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
function [samples,diagnostics] = run_MCMC(data,params) | ||
if params.solver == "NIGS1" | ||
[theta_chain,diagnostics] = sparse_NIGS1_solver(data,params); | ||
elseif params.solver == "NIGS2" | ||
[theta_chain,diagnostics] = sparse_NIGS2_solver(data,params); | ||
elseif params.solver == "BGS" | ||
[theta_chain,diagnostics] = sparse_BGS_solver(data,params); | ||
elseif params.solver == "BGS_Tuple" | ||
[theta_chain,diagnostics] = sparse_BGS_Ktuple_solver(data,params); | ||
else | ||
fprintf('No such solver!\n'); | ||
end | ||
|
||
H_subspace = params.H_subspace; | ||
% Extract original samples | ||
x_chain = theta_chain.x_chain; | ||
h_chain = H_subspace*theta_chain.gamma_chain; | ||
lambda_v_chain = theta_chain.lambda_v_chain; | ||
if params.solver == "NIGS1" || params.solver == "NIGS2" | ||
lambda_x_chain = theta_chain.lambda_x_chain; | ||
beta_x_chain = theta_chain.beta_x_chain; | ||
alpha_x_chain = theta_chain.alpha_x_chain; | ||
else | ||
s_chain = theta_chain.s_chain; | ||
end | ||
log_posterior = diagnostics.log_posterior; | ||
|
||
% Calculate the reference pulse shape | ||
[h_ref] = get_h_ref(h_chain,log_posterior); | ||
|
||
% Correct shift-scale ambiguities based on reference pulse shape | ||
if params.solver == "NIGS1" || params.solver == "NIGS2" | ||
[h_corrected,x_corrected,lambda_x_corrected] = correct_shift_and_scale_NIGS(h_ref,h_chain,x_chain,lambda_x_chain); | ||
else | ||
[h_corrected,x_corrected,s_corrected] = correct_shift_and_scale_BGS(h_ref,h_chain,x_chain,s_chain); | ||
end | ||
|
||
% Collect corrected samples | ||
samples.h = h_corrected; | ||
samples.x = x_corrected; | ||
samples.lambda_v = lambda_v_chain; | ||
if params.solver == "NIGS1" || params.solver == "NIGS2" | ||
samples.lambda_x = lambda_x_corrected; | ||
samples.beta_x = beta_x_chain; | ||
samples.alpha_x = alpha_x_chain; | ||
else | ||
samples.s = s_corrected; | ||
end | ||
end | ||
|
||
function [h_ref] = get_h_ref(h_samples,log_posterior_samples) | ||
% This function calculates the reference pulse shape that achieves the | ||
% highest posterior value. This will be used to correct different time | ||
% shift and scale configurations. | ||
|
||
[T,num_of_realizations] = size(h_samples); | ||
delay_array = -round(T/2):round(T/2); | ||
h_est = h_samples(:,end); | ||
|
||
h_delay = zeros(1,num_of_realizations); | ||
for i = 1:num_of_realizations | ||
h = h_samples(:,i); | ||
error = zeros(length(delay_array),1); | ||
alpha = zeros(length(delay_array),1); | ||
for j = 1:length(delay_array) | ||
h_shifted = circshift(h,delay_array(j)); | ||
alpha(j) = (h_shifted'*h_est)/(h_shifted'*h_shifted); | ||
h_shifted_scaled = alpha(j)*h_shifted; | ||
error(j) = sum(abs(h_shifted_scaled - h_est).^2); | ||
end | ||
[~,min_idx] = min(error); | ||
h_delay(i) = delay_array(min_idx); | ||
end | ||
[counts,bin_edges] = histcounts(h_delay); | ||
[~,max_idx] = max(counts); | ||
selected_delay = round(mean(bin_edges([max_idx,max_idx+1]))); | ||
same_delay_idx = find(h_delay==selected_delay); | ||
[~,max_idx] = max(log_posterior_samples(same_delay_idx)); | ||
h_ref = h_samples(:,same_delay_idx(max_idx)); | ||
end | ||
|
||
function [h_corrected,x_corrected,lambda_x_corrected] = correct_shift_and_scale_NIGS(h_ref,h_samples,x_samples,lambda_x_samples) | ||
num_of_iter = 20; | ||
[T,num_of_realizations] = size(h_samples); | ||
delay_array = -round(T/2):round(T/2); | ||
h_est = h_ref; | ||
|
||
h_est_array = zeros(size(h_samples,1),num_of_iter); | ||
x_est_array = zeros(size(x_samples,1),num_of_iter); | ||
lambda_x_est_array = zeros(size(lambda_x_samples,1),num_of_iter); | ||
for iter = 1:num_of_iter | ||
h_corrected = zeros(size(h_samples,1),num_of_realizations); | ||
x_corrected = zeros(size(x_samples,1),num_of_realizations); | ||
lambda_x_corrected = zeros(size(lambda_x_samples,1),num_of_realizations); | ||
for i = 1:num_of_realizations | ||
h = h_samples(:,i); | ||
x = x_samples(:,i); | ||
lambda_x = lambda_x_samples(:,i); | ||
error = zeros(length(delay_array),1); | ||
alpha = zeros(length(delay_array),1); | ||
for j = 1:length(delay_array) | ||
h_shifted = circshift(h,delay_array(j)); | ||
alpha(j) = (h_shifted'*h_est)/(h_shifted'*h_shifted); | ||
h_shifted_scaled = alpha(j)*h_shifted; | ||
error(j) = sum(abs(h_shifted_scaled - h_est).^2); | ||
end | ||
[~,min_idx] = min(error); | ||
h_corrected(:,i) = circshift(h,delay_array(min_idx))*alpha(min_idx); | ||
x_corrected(:,i) = circshift(x,-delay_array(min_idx))/alpha(min_idx); | ||
lambda_x_corrected(:,i) = circshift(lambda_x,-delay_array(min_idx))*(alpha(min_idx)^2); | ||
end | ||
h_est = mean(h_corrected,2); | ||
x_est = mean(x_corrected,2); | ||
lambda_x_est = mean(lambda_x_corrected,2); | ||
|
||
h_est_array(:,iter) = h_est; | ||
x_est_array(:,iter) = x_est; | ||
lambda_x_est_array(:,iter) = lambda_x_est; | ||
end | ||
end | ||
|
||
function [h_corrected,x_corrected,s_corrected] = correct_shift_and_scale_BGS(h_ref,h_samples,x_samples,s_samples) | ||
num_of_iter = 20; | ||
[T,num_of_realizations] = size(h_samples); | ||
delay_array = -round(T/2):round(T/2); | ||
h_est = h_ref; | ||
|
||
h_est_array = zeros(size(h_samples,1),num_of_iter); | ||
x_est_array = zeros(size(x_samples,1),num_of_iter); | ||
s_est_array = zeros(size(s_samples,1),num_of_iter); | ||
for iter = 1:num_of_iter | ||
h_corrected = zeros(size(h_samples,1),num_of_realizations); | ||
x_corrected = zeros(size(x_samples,1),num_of_realizations); | ||
s_corrected = zeros(size(s_samples,1),num_of_realizations); | ||
for i = 1:num_of_realizations | ||
h = h_samples(:,i); | ||
x = x_samples(:,i); | ||
s = s_samples(:,i); | ||
error = zeros(length(delay_array),1); | ||
alpha = zeros(length(delay_array),1); | ||
for j = 1:length(delay_array) | ||
h_shifted = circshift(h,delay_array(j)); | ||
alpha(j) = (h_shifted'*h_est)/(h_shifted'*h_shifted); | ||
h_shifted_scaled = alpha(j)*h_shifted; | ||
error(j) = sum(abs(h_shifted_scaled - h_est).^2); | ||
end | ||
[~,min_idx] = min(error); | ||
h_corrected(:,i) = circshift(h,delay_array(min_idx))*alpha(min_idx); | ||
x_corrected(:,i) = circshift(x,-delay_array(min_idx))/alpha(min_idx); | ||
s_corrected(:,i) = circshift(s,-delay_array(min_idx)); | ||
end | ||
h_est = mean(h_corrected,2); | ||
x_est = mean(x_corrected,2); | ||
s_est = mean(s_corrected,2); | ||
|
||
h_est_array(:,iter) = h_est; | ||
x_est_array(:,iter) = x_est; | ||
s_est_array(:,iter) = s_est; | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
function [xp] = self_slice_sampler_for_uv_case(initial,logpdf,width) | ||
maxiter = 200; | ||
e = exprnd(1,1,1); % needed for the vertical position of the slice. | ||
|
||
RW = rand(1,1); % factors of randomizing the width | ||
RD = rand(1,1); % uniformly draw the point within the slice | ||
x0 = initial; % current value | ||
|
||
inside = @(x,th) (logpdf(x) > th); | ||
|
||
z = logpdf(x0) - e; | ||
|
||
r = width.*RW; % random width/stepsize | ||
xl = x0 - r; | ||
xr = xl + width; | ||
|
||
iter = 0; | ||
% step out to the left. | ||
while inside(xl,z) && iter<maxiter | ||
xl = xl - width; | ||
iter = iter +1; | ||
end | ||
% step out to the right | ||
iter = 0; | ||
while (inside(xr,z)) && iter<maxiter | ||
xr = xr + width; | ||
iter = iter+1; | ||
end | ||
|
||
xp = RD.*(xr-xl) + xl; | ||
|
||
iter = 0; | ||
while(~inside(xp,z))&& iter<maxiter | ||
rshrink = (xp>x0); | ||
xr(rshrink) = xp(rshrink); | ||
lshrink = ~rshrink; | ||
xl(lshrink) = xp(lshrink); | ||
xp = rand(1,1).*(xr-xl) + xl; % draw again | ||
iter = iter+1; | ||
end | ||
end | ||
|
Oops, something went wrong.