In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import uuid
import time
import datetime
from scipy import stats
import matplotlib.ticker as ticker

In [None]:
def data_gen(D, noise_Std_Dev, k, p, rng):
    """Draw sparse codes and their (optionally noisy) linear measurements.

    Parameters
    ----------
    D : ndarray of shape (m, n)
        Dictionary / measurement matrix.
    noise_Std_Dev : float
        Standard deviation of the zero-mean Gaussian noise added to ``D @ x``.
    k : float
        Probability that any single entry of a code is non-zero.
    p : int
        Number of code vectors (columns) to generate.
    rng : object exposing ``normal`` and ``binomial``
        Random source, e.g. ``np.random.RandomState``.

    Returns
    -------
    x : ndarray of shape (n, p)
        Bernoulli-Gaussian codes.
    y : ndarray of shape (m, p)
        Measurements ``D @ x`` plus noise.
    """
    print(" noise level is ", noise_Std_Dev)
    _, n_atoms = D.shape
    code_shape = (n_atoms, p)

    # Amplitudes are drawn before the support mask so the random stream is
    # consumed in a fixed order (amplitudes, mask, then noise).
    amplitudes = rng.normal(0, 1, code_shape)
    support_mask = rng.binomial(1, k, code_shape)
    x = amplitudes * support_mask

    clean = D @ x
    y = clean + rng.normal(0, noise_Std_Dev, clean.shape)

    return x, y

def pes(x, x_est):
    """Per-column support-recovery error, summarised as (mean, std).

    For column ``i`` let ``S`` and ``S_est`` be the sets of indices where
    ``x[:, i]`` and ``x_est[:, i]`` are non-zero, and ``M = max(|S|, |S_est|)``.
    The score of that column is ``(M - |S & S_est|) / M``: 0 when the supports
    coincide, 1 when they are disjoint.

    Columns where both supports are empty (``M == 0``) have no defined score.
    They are reported on stdout and left out of the statistics; this is decided
    up front instead of dividing 0 by 0 and filtering the resulting NaN, so no
    NumPy ``RuntimeWarning`` is emitted.

    Parameters
    ----------
    x, x_est : ndarray of shape (n, p)
        Ground-truth and estimated codes, one signal per column.

    Returns
    -------
    (mean, std) of the per-column scores, or ``(nan, nan)`` when every column
    was skipped.
    """
    scores = []
    for col in range(x.shape[1]):
        true_support = x[:, col] != 0
        est_support = x_est[:, col] != 0
        M = max(np.sum(true_support), np.sum(est_support))
        if M == 0:
            # Same diagnostics as before for an undefined score.
            print(M)
            print("nan is found here")
            continue
        overlap = np.sum(true_support & est_support)
        scores.append((M - overlap) / M)

    if not scores:
        # Nothing to average; avoid the "mean of empty slice" warning.
        return np.nan, np.nan
    return np.mean(scores), np.std(scores)

In [None]:
def soft_thr(x, thr):
    """Soft-thresholding (shrinkage): ``sign(x) * max(|x| - thr, 0)``, element-wise."""
    # Magnitudes at or below the threshold collapse to zero; the rest shrink by thr.
    shrunk_magnitude = torch.relu(x.abs() - thr)
    return x.sign() * shrunk_magnitude


class LISTA(nn.Module):
    """Unrolled ISTA network whose two linear maps are shared by all layers.

    Starting from ``d = 0``, each of the ``numIter`` layers computes
    ``d <- soft_thr(W y + S d, thr[layer])``, or just ``W y + S d`` when
    ``no_activation`` is set.

    Parameters
    ----------
    m, n : int
        Measurement and code dimensions; ``_W`` maps m -> n, ``_S`` maps n -> n.
    Dict : ndarray of shape (m, n)
        Dictionary used by ``weights_init`` and to size the initial state.
    numIter : int
        Number of unrolled layers.
    device : str or torch.device
        Device onto which ``weights_init`` places the initial weights.
    thr : None, scalar, or array-like with ``numIter`` entries
        Initial shrinkage thresholds. ``None`` draws them uniformly from
        [0, 1); a scalar is replicated for every layer.
    no_activation : bool
        Skip the soft-threshold in every layer.

    The registered parameter names (``_W.weight``, ``_S.weight``, ``thr``) and
    the ``(numIter, 1)`` shape of ``thr`` are part of the state-dict contract.
    """

    def __init__(self, m, n, Dict, numIter, device='cpu', thr=None, no_activation=False):
        super(LISTA, self).__init__()
        self.numIter = numIter
        self.no_activation = no_activation
        # Squared spectral norm of Dict (largest eigenvalue of Dict^T Dict),
        # inflated by 0.1 %.
        self.alpha = (np.linalg.norm(Dict, 2) ** 2) * 1.001
        self.device = device

        # Setting up the linear layers with specific dimensions and without bias
        self._W = nn.Linear(in_features=m, out_features=n, bias=False)
        self._S = nn.Linear(in_features=n, out_features=n, bias=False)

        # Shrinkage thresholds, one per iteration. A caller-supplied `thr` is
        # honoured here; it used to be overwritten by the random draw.
        if thr is None:
            initial_thr = torch.rand(numIter, 1)
        else:
            initial_thr = torch.as_tensor(thr, dtype=torch.float32).detach().clone()
            if initial_thr.numel() == 1:
                initial_thr = initial_thr.reshape(1, 1).repeat(numIter, 1)
            elif initial_thr.numel() == numIter:
                initial_thr = initial_thr.reshape(numIter, 1)
            else:
                raise ValueError(
                    f"thr must be a scalar or contain numIter={numIter} values, "
                    f"got {initial_thr.numel()}"
                )
        self.thr = nn.Parameter(initial_thr, requires_grad=True)
        self.A = Dict

    def weights_init(self):
        """Set ``_S``, ``_W`` and ``thr`` to their ISTA values for dictionary ``A``.

        ``_S = I - A^T A / alpha``, ``_W = A^T / alpha`` and every threshold is
        ``0.1 / alpha``. All three are placed on ``self.device``.
        """
        A = self.A
        alpha = self.alpha

        # Compute the initial weights for _W and _S based on the dictionary and alpha
        S = torch.from_numpy(np.eye(A.shape[1]) - (1 / alpha) * np.matmul(A.T, A))
        S = S.float().to(self.device)
        B = torch.from_numpy((1 / alpha) * A.T)
        B = B.float().to(self.device)

        # Kept on the same device as the weights so the forward pass does not
        # mix devices.
        thr = (torch.ones(self.numIter, 1) * 0.1 / alpha).to(self.device)

        # Setting the weights of _S and _W layers
        self._S.weight = nn.Parameter(S)
        self._W.weight = nn.Parameter(B)
        self.thr.data = thr

    def forward(self, y):
        """Run all layers on ``y`` of shape (batch, m).

        Returns a list with the (batch, n) estimate after every layer; the
        last entry is the final estimate.
        """
        x = []
        # The initial state follows the input's device and dtype, so the module
        # keeps working after `.to(...)` regardless of the `device` it was
        # constructed with.
        d = torch.zeros(y.shape[0], self.A.shape[1], device=y.device, dtype=y.dtype)
        # W y does not change between layers; compute it once.
        Wy = self._W(y)

        for layer in range(self.numIter):
            pre_activation = Wy + self._S(d)
            if self.no_activation:
                d = pre_activation
            else:
                d = soft_thr(pre_activation, self.thr[layer])

            x.append(d)
        return x

def LISTA_test(net, Y, D, device):
    """Run ``net`` in evaluation mode on measurements ``Y`` without gradients.

    Parameters
    ----------
    net : torch.nn.Module
        Called as ``net(batch)`` with ``batch`` of shape (p, m); must return a
        sequence of per-layer tensors whose last entry is the final estimate.
    Y : ndarray of shape (m, p), or of shape (m,) for a single signal
        Measurements, one signal per column.
    D : unused
        Kept only so existing call sites keep working.
    device : str or torch.device
        Device the input batch is moved to.

    Returns
    -------
    X_final : ndarray
        Final estimate, one signal per column: (n, p), or (n,) for a 1-D ``Y``.
    X_lista : sequence of torch.Tensor
        Per-layer outputs as returned by ``net`` (each flattened to 1-D when
        ``Y`` was 1-D).

    Side effect: ``net`` is switched to evaluation mode and left there.
    """
    single_signal = Y.ndim <= 1

    # convert the data into tensors (signals become rows of the batch)
    Y_t = torch.from_numpy(Y.T)
    if single_signal:
        Y_t = Y_t.view(1, -1)
    Y_t = Y_t.float().to(device)

    with torch.no_grad():
        # Compute the output
        net.eval()
        X_lista = net(Y_t)
        if single_signal:
            # `net` returns a list of layer outputs, so flatten each entry;
            # calling `.view` on the list itself raised AttributeError.
            X_lista = [layer_out.reshape(-1) for layer_out in X_lista]
        X_final = X_lista[-1].cpu().numpy().T

    return X_final, X_lista

# Module-level experiment setup. Later cells read m, n, D, input_SNR,
# sparsity, X_test and Y_test as globals.
# A fixed seed makes the dictionary and the test set identical on every run;
# the order of the draws from `rng` below determines the generated values.
seed = 80
print("Seed: ", seed)
rng = np.random.RandomState(seed)

# Dictionary size: D has m rows and n columns.
m = 70
n = 100
# create the random matrix
D = rng.normal(0, 1 / np.sqrt(m), [m, n])
# Rescale every column of D to unit Euclidean norm.
D /= np.linalg.norm(D, 2, axis=0)


# NOTE: despite its name, this value is passed to data_gen as the noise
# standard deviation, so 0.0 produces noiseless measurements.
input_SNR = 0.0  # used as noise_Std_Dev in the data_gen call below
sparsity = 10  # divided by 100 below to give the per-entry non-zero probability
numTest = 100  # number of test signals (columns of X_test / Y_test)
X_test, Y_test = data_gen(D, input_SNR, sparsity / 100, numTest, rng)

In [None]:
def science_plot_squared_norm_of_recon(squared_norms, step=1, filename=None):
    """Stem plot of a per-layer quantity, drawing every ``step``-th layer.

    Parameters
    ----------
    squared_norms : sequence of float
        One value per layer, in layer order.
    step : int
        Stride applied to both the layer indices and the values.
    filename : str or None
        When given, the figure is also written to this path as a PDF.

    The x-axis numbers layers from 1. The figure is displayed with
    ``plt.show()``.
    """
    plt.figure(figsize=(12, 6))

    # Indices and values are strided identically so they stay aligned.
    layers_index = np.arange(1, len(squared_norms) + 1, step)
    plt.stem(
        layers_index,
        squared_norms[::step],
        linefmt="r-",
        markerfmt="bo",
        basefmt="gray",
    )

    # Labeling the axes
    plt.xlabel("Layer Index")
    plt.ylabel(r"$\|\hat{x}\|_2^2$")

    # No legend: nothing drawn here carries a label, so plt.legend() only
    # produced Matplotlib's "no artists with labels" warning.
    plt.grid(True, which="both", ls="--", antialiased=True)

    if filename:
        plt.savefig(filename, format="pdf", bbox_inches="tight")

    # Showing the plot
    plt.show()

In [None]:
def plot_squared_norm_of_recon(squared_norms, step=1, filename=None):
    """Line plot of a per-layer quantity via ``plot_signal``, every ``step``-th layer.

    Parameters
    ----------
    squared_norms : sequence of float
        One value per layer, in layer order (must be non-empty).
    step : int
        Stride applied to both the layer indices and the values.
    filename : str or None
        Output path for the PDF. A trailing ".pdf" is optional: ``plot_signal``
        appends the extension itself, so it is stripped here to avoid
        producing "name.pdf.pdf".

    The x-axis numbers layers from 0.
    """
    # Create the 10x5 figure and hand its axes to plot_signal; without `ax`
    # plot_signal opens a second figure and this one is left empty.
    plt.figure(figsize=(10, 5))
    target_ax = plt.gca()
    layers_index = np.arange(0, len(squared_norms), step)

    save_stem = filename
    if filename and filename.lower().endswith(".pdf"):
        save_stem = filename[: -len(".pdf")]

    plot_signal(
        layers_index,
        squared_norms[::step],
        ax=target_ax,
        xlimits=[-1, len(squared_norms)],
        # Upper limit is capped so diverging values do not blow up the axis.
        ylimits=[np.min(squared_norms) - 1, min(1e18, np.max(squared_norms) + 1)],
        xaxis_label="Layer Index",
        yaxis_label=r"$\|X_{recon}\|^2$",
        grid=True,
        save=save_stem,
        axis_formatter=None
    )

def plot_signal(
    x,
    y,
    ax=None,
    plot_colour="blue",
    alpha=1,
    xaxis_label=None,
    yaxis_label=None,
    title_text=None,
    legend_label=None,
    legend_show=True,
    legend_loc="lower left",
    n_col=2,
    line_style="-",
    line_width=None,
    xlimits=[-2, 2],
    ylimits=[-2, 2],
    axis_formatter="%0.1f",
    show=False,
    save=None,
    annotates=False,
    annotation=None,
    pos=None,
    marker=None,
    markersize=10,
    grid=None,
):
    """
    Plots signal with abscissa in x and ordinates in y

    All drawing is done on ``ax``. When ``ax`` is None a new 12x6 figure is
    created and its axes are used.

    Parameters
    ----------
    x, y : array-like
        Abscissa and ordinate values.
    ax : matplotlib Axes or None
        Target axes.
    plot_colour, alpha, line_style, line_width, marker, markersize
        Line appearance, forwarded to ``Axes.plot``.
    xaxis_label, yaxis_label, title_text : str or None
        Axis labels and title.
    legend_label : str or None
        Label of the line; a legend is drawn only if this is set and
        ``legend_show`` is true (``legend_loc`` / ``n_col`` control placement).
    xlimits, ylimits : pair of numbers
        Axis limits. The defaults are never modified.
    axis_formatter : str or None
        printf-style format for the y tick labels; falsy leaves the default.
    show : bool
        Call ``plt.show()`` at the end.
    save : str or None
        Output path; the figure owning ``ax`` is written as PDF. ".pdf" is
        appended unless the path already ends with it.
    annotates, annotation, pos
        When ``annotates`` is true, ``annotation`` is placed at ``pos``.
    grid : bool or None
        Draw a light dashed grid.
    """
    if ax is None:
        plt.figure(figsize=(12, 6))
        ax = plt.gca()

    # Use the Axes API throughout so a caller-supplied `ax` is really the
    # one drawn on, even when it is not pyplot's current axes.
    ax.plot(
        x,
        y,
        linestyle=line_style,
        linewidth=line_width,
        color=plot_colour,
        label=legend_label,
        zorder=0,
        alpha=alpha,
        marker=marker,
        markersize=markersize,
    )
    if legend_label and legend_show:
        ax.legend(
            ncol=n_col, loc=legend_loc, frameon=False, framealpha=0.8, facecolor="white"
        )

    if grid:
        ax.grid(True, ls="--", lw=0.5, c="k", alpha=0.2)

    ax.set_xlim(xlimits)
    ax.set_ylim(ylimits)
    ax.set_xlabel(xaxis_label)
    ax.set_ylabel(yaxis_label)
    ax.set_title(title_text)

    if annotates:
        ax.annotate(annotation, xy=pos, color=plot_colour)

    if axis_formatter:
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter(axis_formatter))

    if save:
        # Do not double the extension when the caller already supplied it.
        target = save if save.lower().endswith(".pdf") else save + ".pdf"
        ax.figure.savefig(target, format="pdf")

    if show:
        plt.show()

    return

def evaluate_model(model, save=False):
    """Evaluate ``model`` on the module-level test set and print summary statistics.

    Reads the globals ``Y_test``, ``X_test`` and ``D`` (and ``input_SNR`` /
    ``sparsity`` when ``save`` is true) and always runs on the CPU.

    Per test signal it computes the reconstruction SNR in dB and the squared
    L2 norms of the estimate and of the ground truth. If an SNR evaluates to
    NaN the loop stops early and the statistics cover only the signals
    processed so far. Support recovery is scored with ``pes``.

    Parameters
    ----------
    model : LISTA
        Network to evaluate; ``numIter``, ``no_activation`` and ``thr`` are
        read when ``save`` is true.
    save : bool
        Also write a summary to ``logs/01_test_stability_L<numIter>_act_<bool>.txt``
        (the ``logs`` directory is created if missing).

    Returns
    -------
    X_out : ndarray
        Final estimates, one signal per column.
    X_test : ndarray
        The module-level ground truth, returned unchanged.
    x_layer_i_list : list of float
        For every layer, the L2 norm of each sample's estimate averaged over
        the samples.
        NOTE(review): this is a mean of (unsquared) norms, while the plotting
        helpers label the y-axis as a squared norm - confirm which is intended.
    """
    # Sparse Estimation using LISTA (testing phase)
    X_l2_squared_norm_list = []
    X_l2_squared_norm_gt_list = []
    SNR_list = []

    X_out, X_LISTA = LISTA_test(model, Y_test, D, "cpu")

    N_test = Y_test.shape[1]
    for i in range(N_test):
        x_hat = X_out[:, i]
        x_true = X_test[:, i]

        err = np.linalg.norm(x_hat - x_true)
        RSNR = 20 * np.log10(np.linalg.norm(x_true) / err)
        if np.isnan(RSNR):
            print("!!! nan found")
            break
        SNR_list.append(RSNR)

        # Squared L2 norms of the estimate and of the ground truth
        X_l2_squared_norm_list.append(x_hat.T @ x_hat)
        X_l2_squared_norm_gt_list.append(x_true.T @ x_true)

    PES_mean, PES_std = pes(X_test, X_out)

    SNR_list_LISTA = np.array(SNR_list)
    L2_squared_list_LISTA = np.array(X_l2_squared_norm_list)

    # Calculate and format LISTA-related statistics
    avg_snr = round(np.mean(SNR_list_LISTA), 4)
    std_snr = round(np.std(SNR_list_LISTA), 4)
    max_snr = np.max(SNR_list_LISTA)

    avg_pes = round(PES_mean, 4)
    std_pes = round(PES_std, 4)

    # Averages use the Frobenius norm over all columns; std / max use the
    # per-signal values collected in the loop above.
    avg_l2_squared = round(np.linalg.norm(X_out, "fro") ** 2 / X_out.shape[1], 4)
    avg_l2_squared_GT = round(np.linalg.norm(X_test, "fro") ** 2 / X_out.shape[1], 4)
    std_l2_squared = round(np.std(L2_squared_list_LISTA), 4)
    max_l2_squared = round(np.max(L2_squared_list_LISTA), 4)

    x_layer_i_list = []
    for x_layer_i in X_LISTA:
        # L2 norm of every sample in this layer, then the average over samples
        l2_norms = torch.norm(x_layer_i, p=2, dim=1)
        x_layer_i_list.append(torch.mean(l2_norms).item())

    print(f"Testing: my LISTA average SNR: {avg_snr}")
    print(f"Testing: my LISTA standard deviation in SNR: {std_snr}")
    print(f"Testing: my LISTA peak SNR: {max_snr}")

    print(f"Testing: my LISTA average X_out L2^2: {avg_l2_squared :.3f}")
    print(f"Testing: my LISTA average X_test L2^2: {avg_l2_squared_GT :.3f}")
    print(f"Testing: my LISTA standard deviation in X_out L2^2: {std_l2_squared}")
    print(f"Testing: my LISTA peak X_out L2^2: {max_l2_squared}")
    print(f"Testing: my LISTA mode X_out L2^2: {stats.mode(X_l2_squared_norm_list)}")
    # This line reports the mode; it was previously mislabelled as "peak".
    print(
        f"Testing: my LISTA mode X_test L2^2: {stats.mode(X_l2_squared_norm_gt_list)}"
    )

    print(f"Testing: my LISTA average PES: {avg_pes}")
    print(f"Testing: my LISTA standard deviation in PES: {std_pes}")

    print("-" * 40)  # Divider line for better visual separation

    if save:
        import os

        # Make sure the target directory exists so the run is not lost to a
        # FileNotFoundError after all the work above.
        os.makedirs("logs", exist_ok=True)
        with open(
            f"logs/01_test_stability_L{model.numIter}_act_{not model.no_activation}.txt",
            "w",
        ) as f:
            f.write(f"Learned threshold: {model.thr.T}\n")
            f.write(f"Input SNR: {input_SNR}\n")
            f.write(f"Sparsity: {sparsity}\n")
            f.write("-" * 40 + "\n")
            f.write(
                f"Average SNR: {avg_snr}, Standard deviation SNR: {std_snr}, Max SNR: {max_snr}\n"
            )
            f.write(
                f"Average X_out^2: {avg_l2_squared}, Standard deviation X_out^2: {std_l2_squared}, Max X_out^2: {max_l2_squared}\n"
            )
            f.write(f"Average PES: {avg_pes}, Standard deviation PES: {std_pes}\n")

    return X_out, X_test, x_layer_i_list

In [None]:
# Experiment: per-layer norm of the estimate WITH the soft-threshold, for
# several network depths. Each depth is paired with one threshold value and
# one plotting stride.
numIter_list = [15, 100, 500, 1000]  # Number of iterations
steps = [1, 2, 10, 15]
# numIter = 15  # Number of iterations
# thr_list = [0.001, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04, 0.045, 0.05, 0.06, 0.07]
# thr_list = np.arange(0.04, 0.06, 0.001)

thr_list = [0.045, 0.045, 0.055, 0.055]


# Load the state dictionary. evaluate_model runs on the CPU, so map the
# tensors there; this also lets a checkpoint saved from a GPU load on a
# CPU-only machine.
state_dict = torch.load("../weights/model_parameters.pth", map_location="cpu")

for numIter, thr, step in zip(numIter_list, thr_list, steps):
    print("*" * 40)
    print(f"Num Iter: {numIter}")
    print(f"Thr: {thr}")
    # Overwrite the stored thresholds with one constant value per layer,
    # shaped (numIter, 1) to match the model's `thr` parameter.
    state_dict["thr"] = torch.ones(numIter, 1) * thr

    model = LISTA(m, n, D, numIter)
    model.load_state_dict(state_dict)

    # The second return value is the global X_test itself; it is not rebound.
    X_out, _, x_layer_i_list = evaluate_model(model, save=False)
    filename = f"../figures/03_Impact_of_Soft_Thr/squared_norm_of_recon_L{model.numIter}_act_{not model.no_activation}_2.pdf"
    science_plot_squared_norm_of_recon(x_layer_i_list, step, filename=filename)

In [None]:
# Experiment: per-layer norm of the estimate WITHOUT the soft-threshold
# (no_activation=True), for several network depths.
numIter_list = [15, 100, 500, 1000]  # Number of iterations
steps = [1, 2, 2, 2]
# steps = [1, 2, 10, 15]

# Load the state dictionary. evaluate_model runs on the CPU, so map the
# tensors there; this also lets a checkpoint saved from a GPU load on a
# CPU-only machine.
state_dict = torch.load("../weights/model_parameters.pth", map_location="cpu")

for numIter, step in zip(numIter_list, steps):
    print("*" * 40)
    print(f"Num Iter: {numIter}")
    # No Activation function
    model = LISTA(m, n, D, numIter, no_activation=True)
    # The thresholds are not applied when no_activation is set; the entry
    # only has to match the (numIter, 1) shape expected by load_state_dict.
    state_dict["thr"] = torch.ones(numIter, 1)
    model.load_state_dict(state_dict)
    # The second return value is the global X_test itself; it is not rebound.
    X_out, _, x_layer_i_list = evaluate_model(model, save=False)

    filename = f"../figures/03_Impact_of_Soft_Thr/squared_norm_of_recon_L{model.numIter}_act_{not model.no_activation}_2.pdf"
    science_plot_squared_norm_of_recon(x_layer_i_list, step, filename=filename)