In [None]:
import os

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import scipy.io
import scipy.linalg as LA

In [None]:
# Weight matrix W is read from the 'X' entry of the analytic-weights file,
# and the dictionary D from the 'D' entry of its own file.
W = np.array(scipy.io.loadmat('../weights/analytic_D_m70_n100.mat')['X'])
D = np.array(scipy.io.loadmat('../weights/D_m70_n100.mat')['D'])
m, n = D.shape

# GAP:    S = I - D^T (D D^T)^{-1} D
row_gram_inverse = LA.inv(D @ D.T)
S_GAP = np.eye(n) - D.T @ row_gram_inverse @ D

# ISTA:   S = I - (1 / L) D^T D, with L the spectral norm (ord=2) of D^T D
gram_spectral_norm = LA.norm(D.T @ D, 2)
S_ISTA = np.eye(n) - (1 / gram_spectral_norm) * D.T @ D

# ALISTA: S = I - W^T D
S_ALISTA = np.eye(n) - W.T @ D

# LISTA:  only the last matrix stored in S.npy is analysed.
# NOTE(review): presumably the array holds one matrix per layer/iteration -- confirm.
S_LISTA_list = np.load(r'../weights/S.npy')
S_LISTA = S_LISTA_list[-1]

# Label -> matrix; the insertion order is the order the figures are produced in.
s_matrix_dict = {
    'ISTA': S_ISTA,
    'GAP': S_GAP,
    'ALISTA': S_ALISTA,
    'LISTA': S_LISTA,
}

# Path prefix under which the figures are saved.
FILE_PATH = '../figures/02_S_SVD/'

In [None]:
def plot_diracs(tk, ak, ax=None, plot_colour='blue', alpha=1,
                line_width=2, marker_style='o', marker_size=4, line_style='-',
                legend_show=True, legend_loc='lower left', legend_label=None, ncols=2,
                title_text=None, xaxis_label=None, yaxis_label=None, xlimits=[0, 1],
                ylimits=[-1, 1], show=False, save=None):
    ''' Plots Diracs at tk, ak as a stem plot.

    All drawing is done on ``ax`` (and saving on the figure that owns ``ax``),
    so a caller-supplied axes is honoured even when it is not the current
    pyplot axes.

    Parameters
    ----------
    tk : array-like
        Stem positions along the x-axis.
    ak : array-like
        Stem heights.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new 12x6 figure with one axes is created.
    plot_colour : colour
        Colour of the stems, the marker faces and the marker edges.
    alpha : float
        Opacity applied to stems and markers.
    line_width : float
        Line width of the stems and of the marker line.
    marker_style, marker_size :
        Marker symbol and size of the stem heads.
    line_style : str
        Passed to ``stem`` as ``linefmt``.
    legend_show : bool
        A legend is drawn only when this is true AND ``legend_label`` is truthy.
    legend_loc : str
        Legend location.
    legend_label : str, optional
        Label attached to the stem container.
    ncols : int
        Number of legend columns.
    title_text, xaxis_label, yaxis_label : str, optional
        Title and axis labels, always applied to the axes (also when None).
    xlimits, ylimits : sequence of two numbers
        Axis limits. The list defaults are only read, never mutated.
    show : bool
        When true, ``plt.show()`` is called at the end.
    save : str, optional
        Output path without extension; ``'.pdf'`` is appended and the figure
        is written in PDF format. The figure is saved before it is shown.

    Returns
    -------
    None
    '''
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 6))

    markerline, stemlines, baseline = ax.stem(tk, ak, label=legend_label,
                                              linefmt=line_style)
    plt.setp(stemlines, linewidth=line_width, color=plot_colour, alpha=alpha)
    plt.setp(markerline, marker=marker_style, linewidth=line_width, alpha=alpha,
             markersize=marker_size, markerfacecolor=plot_colour, mec=plot_colour)
    # Zero line width hides the baseline.
    plt.setp(baseline, linewidth=0)

    if legend_label and legend_show:
        ax.legend(ncol=ncols, loc=legend_loc, frameon=True, framealpha=0.8,
                  facecolor='white')

    ax.set_xlim(xlimits)
    ax.set_ylim(ylimits)
    ax.set_xlabel(xaxis_label)
    ax.set_ylabel(yaxis_label)
    ax.set_title(title_text)

    # One decimal place on the y tick labels.
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%0.1f'))

    if save:
        # Save the figure that owns `ax`, not whichever figure is current.
        ax.figure.savefig(save + '.pdf', format='pdf')

    if show:
        plt.show()

In [None]:
# savefig does not create missing directories, so make sure the output
# location exists before the first figure is written.
os.makedirs(FILE_PATH, exist_ok=True)

for title_text, matrix in s_matrix_dict.items():

    filename = FILE_PATH + "S_Matrix_SVD_" + title_text

    # Only the singular values are plotted, so U and V are not computed.
    sv = LA.svd(matrix, compute_uv=False)

    # Stem positions are the singular-value indices; stem heights the values.
    tk = np.arange(len(sv))
    ak = sv

    # No legend_label or title_text is passed, so the figure carries neither
    # a legend nor a title; it is saved as <filename>.pdf and then shown.
    plot_diracs(tk, ak, plot_colour='blue', alpha=0.7, line_width=2, marker_style='o',
                marker_size=8, line_style='-',
                xaxis_label='Index'.upper(), yaxis_label='Singular Value'.upper(),
                xlimits=[-1, len(sv)], ylimits=[0, np.max(sv) * 1.1],
                show=True, save=filename)