In [None]:
import numpy as np
from scipy.linalg import pinv, inv, norm
from scipy.fft import fft2, ifft2

def sunsal_tv(M, Y, **kwargs):
    # Check for required parameters
    if len([M, Y]) != 2:
        raise ValueError("Wrong number of required parameters")

    # Mixing matrix size
    LM, n = M.shape
    # Data set size
    L, N = Y.shape
    if LM != L:
        raise ValueError("Mixing matrix M and data set Y are inconsistent")

    # Set defaults for optional parameters
    reg_l1 = 0
    reg_TV = 0
    im_size = []
    tv_type = 'niso'
    AL_iters = 1000
    mu = 0.001
    verbose = 'off'
    positivity = 'no'
    reg_pos = 0
    addone = 'no'
    reg_add = 0
    U0 = 0
    true_x = 0
    rmse = 0

    # Read optional parameters
    lambda_l1 = kwargs.get('LAMBDA_1', 0)
    lambda_TV = kwargs.get('LAMBDA_TV', 0)
    tv_type = kwargs.get('TV_TYPE', 'niso')
    im_size = kwargs.get('IM_SIZE', [])
    AL_iters = kwargs.get('AL_ITERS', 1000)
    positivity = kwargs.get('POSITIVITY', 'no')
    addone = kwargs.get('ADDONE', 'no')
    mu = kwargs.get('MU', 0.001)
    verbose = kwargs.get('VERBOSE', 'off')
    U0 = kwargs.get('X0', 0)
    XT = kwargs.get('TRUE_X', 0)

    if lambda_l1 < 0 or lambda_TV < 0 or AL_iters <= 0 or mu <= 0:
        raise ValueError("Invalid optional parameter values")

    if true_x:
        nr, nc = XT.shape
        if nr != n or nc != N:
            raise ValueError("Wrong image size")

    if reg_TV > 0:
        if N != np.prod(im_size):
            raise ValueError("Wrong image size")
        n_lin, n_col = im_size

        FDh = np.zeros(im_size)
        FDh[0, 0] = -1
        FDh[0, -1] = 1
        FDh = fft2(FDh)
        FDhH = np.conj(FDh)

        FDv = np.zeros(im_size)
        FDv[0, 0] = -1
        FDv[-1, 0] = 1
        FDv = fft2(FDv)
        FDvH = np.conj(FDv)

        IL = 1 / (FDhH * FDh + FDvH * FDv + 1)

        Dh = lambda x: np.real(ifft2(fft2(x) * FDh))
        DhH = lambda x: np.real(ifft2(fft2(x) * FDhH))

        Dv = lambda x: np.real(ifft2(fft2(x) * FDv))
        DvH = lambda x: np.real(ifft2(fft2(x) * FDvH))

    SMALL = 1e-12
    B = np.ones((1, n))
    a = np.ones((1, N))

    if not (reg_TV or reg_l1 or reg_pos or reg_add):
        U = pinv(M).dot(Y)
        res = norm(M.dot(U) - Y, 'fro')
        return U, res, rmse

    if not (reg_TV or reg_l1 or reg_pos) and reg_add:
        F = M.T.dot(M)
        if np.linalg.cond(F) < 1 / SMALL:
            IF = inv(F)
            U = IF.dot(M.T).dot(Y) - IF.dot(B.T).dot(inv(B.dot(IF).dot(B.T))).dot(B.dot(IF).dot(M.T).dot(Y) - a)
            res = norm(M.dot(U) - Y, 'fro')
            return U, res, rmse

    IF = inv(M.T.dot(M) + (reg_l1 + reg_pos + reg_add + reg_TV) * np.eye(n))

    if U0 == 0:
        U = IF.dot(M.T).dot(Y)

    index = 1
    V = [None] * (1 + reg_l1 + reg_pos + reg_add + reg_TV)
    D = [None] * (1 + reg_l1 + reg_pos + reg_add + reg_TV)

    reg = [1] + [0] * (reg_l1 + reg_pos + reg_add + reg_TV)
    V[0] = M.dot(U)
    D[0] = np.zeros_like(Y)

    if reg_pos:
        reg[index] = 2
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1
    if reg_add:
        reg[index] = 3
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1
    if reg_l1:
        reg[index] = 4
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1
    if reg_TV:
        reg[index] = 5
        V[index] = U
        D[index] = np.zeros_like(U)

        U_im = U.T.reshape(im_size + (n,))
        V.append([None] * n)
        D.append([None] * n)
        for i in range(n):
            V[-1][i] = [Dh(U_im[:, :, i]), Dv(U_im[:, :, i])]
            D[-1][i] = [np.zeros(im_size), np.zeros(im_size)]

    tol1 = np.sqrt(N) * 1e-5
    res = np.inf
    i = 1

    while i <= AL_iters and np.sum(np.abs(res)) > tol1:
        Xi = M.T.dot(V[0] + D[0])
        for j in range(1, reg_l1 + reg_pos + reg_add + reg_TV + 1):
            Xi += V[j] + D[j]
        U = IF.dot(Xi)

        for j in range(1 + reg_l1 + reg_pos + reg_add + reg_TV):
            if reg[j] == 1:
                V[j] = (1 / (1 + mu)) * (Y + mu * (M.dot(U) - D[j]))
            elif reg[j] == 2:
                V[j] = np.maximum(U - D[j], 0)
            elif reg[j] == 3:
                nu_aux = U - D[j]
                V[j] = nu_aux + np.outer((1 - np.sum(nu_aux, axis=0)) / n, np.ones(N))
            elif reg[j] == 4:
                V[j] = np.sign(U - D[j]) * np.maximum(np.abs(U - D[j]) - lambda_l1 / mu, 0)
            elif reg[j] == 5:
                nu_aux = U - D[j]
                nu_aux5_im = nu_aux.T.reshape(im_size + (n,))
                V5_im = np.zeros_like(nu_aux5_im)
                for k in range(n):
                    V5_im[:, :, k] = np.real(ifft2(IL * fft2(DhH(V[-1][k][0] + D[-1][k][0]) + DvH(V[-1][k][1] + D[-1][k][1]) + nu_aux5_im[:, :, k])))
                    aux_h = Dh(V5_im[:, :, k])
                    aux_v = Dv(V5_im[:, :, k])
                    if tv_type == 'niso':
                        V[-1][k][0] = np.sign(aux_h - D[-1][k][0]) * np.maximum(np.abs(aux_h - D[-1][k][0]) - lambda_TV / mu, 0)
                        V[-1][k][1] = np.sign(aux_v - D[-1][k][1]) * np.maximum(np.abs(aux_v - D[-1][k][1]) - lambda_TV / mu, 0)
                    else:
                        aux = np.maximum(np.sqrt((aux_h - D[-1][k][0]) ** 2 + (aux_v - D[-1][k][1]) ** 2) - lambda_TV / mu, 0)
                        V[-1][k][0] = aux / (aux + lambda_TV / mu) * (aux_h - D[-1][k][0])
                        V[-1][k][1] = aux / (aux + lambda_TV / mu) * (aux_v - D[-1][k][1])
                    D[-1][k][0] = D[-1][k][0] - (aux_h - V[-1][k][0])
                    D[-1][k][1] = D[-1][k][1] - (aux_v - V[-1][k][1])
                V[j] = V5_im.reshape(N, n).T

        for j in range(1 + reg_l1 + reg_pos + reg_add + reg_TV):
            D[j] = D[j] - (U - V[j])

        res[0] = norm(M.dot(U) - Y, 'fro')
        if true_x:
            rmse = norm(U - XT, 'fro') / np.sqrt(N)
            print(f"{i} : {np.sum(np.abs(res))} {rmse}")
        elif verbose == 'yes':
            print(f"{i} : {np.sum(np.abs(res))}")

        i += 1

    return U, res, rmse

# Example usage:
M = np.random.randn(100, 50)
Y = np.random.randn(100, 10)
result,a,b = sunsal_tv(M, Y, LAMBDA_1=0.1, LAMBDA_TV=0.1, IM_SIZE=[10, 10])


In [None]:
print(result)

[[ 1.88413928e-01  1.83190559e-02  1.10610326e-01 -7.04223174e-03
   1.56824762e-01 -2.08972421e-01 -2.73300529e-02  1.35921027e-01
   7.61633840e-02  2.16772072e-02]
 [-1.57986929e-01 -2.46072737e-01  1.74090996e-01 -5.61969785e-03
  -7.11473026e-02 -1.23000818e-01  1.40568141e-02 -2.62536047e-01
  -2.28286296e-02  8.99927938e-02]
 [ 1.32240377e-01  2.26885238e-01 -1.81431244e-01  2.53750299e-02
   1.32625294e-01 -1.90968615e-03 -4.86221949e-02  3.04920487e-02
   1.83313898e-01  2.27724190e-02]
 [ 1.25000236e-02  1.01195888e-01 -9.71973516e-02  2.14656886e-01
   1.79360518e-01  9.32386348e-02  1.65672717e-01 -1.65830632e-01
  -6.32265518e-01 -6.82536690e-02]
 [ 3.34720176e-01 -2.97324424e-02  2.50789257e-01  6.46356597e-02
   1.06812451e-01 -1.15976216e-01  2.27658772e-01 -4.05521759e-01
  -2.04049956e-01  1.12844953e-01]
 [-1.23222149e-01 -6.35627444e-02  4.86636692e-02 -2.55347501e-02
   4.58721691e-02  1.14218998e-01  2.47360783e-01  1.09230890e-01
  -2.98248598e-01  3.08342844e-02

In [None]:
import numpy as np
from scipy.linalg import pinv, inv, norm
from scipy.fft import fft2, ifft2

def sunsal_tv(M, Y, **kwargs):
    # Check for required parameters
    if len([M, Y]) != 2:
        raise ValueError("Wrong number of required parameters")

    # Mixing matrix size
    LM, n = M.shape
    # Data set size
    L, N = Y.shape
    if LM != L:
        raise ValueError("Mixing matrix M and data set Y are inconsistent")

    # Set defaults for optional parameters
    reg_l1 = 0
    reg_TV = 0
    im_size = []
    tv_type = 'niso'
    AL_iters = 1000
    mu = 0.001
    verbose = 'off'
    positivity = 'no'
    reg_pos = 0
    addone = 'no'
    reg_add = 0
    U0 = 0
    true_x = 0
    rmse = 0

    # Read optional parameters
    lambda_l1 = kwargs.get('LAMBDA_1', 0)
    lambda_TV = kwargs.get('LAMBDA_TV', 0)
    tv_type = kwargs.get('TV_TYPE', 'niso')
    im_size = kwargs.get('IM_SIZE', [])
    AL_iters = kwargs.get('AL_ITERS', 1000)
    positivity = kwargs.get('POSITIVITY', 'no')
    addone = kwargs.get('ADDONE', 'no')
    mu = kwargs.get('MU', 0.001)
    verbose = kwargs.get('VERBOSE', 'off')
    U0 = kwargs.get('X0', 0)
    XT = kwargs.get('TRUE_X', 0)

    if lambda_l1 < 0 or lambda_TV < 0 or AL_iters <= 0 or mu <= 0:
        raise ValueError("Invalid optional parameter values")

    if true_x:
        nr, nc = XT.shape
        if nr != n or nc != N:
            raise ValueError("Wrong image size")

    if reg_TV > 0:
        if N != np.prod(im_size):
            raise ValueError("Wrong image size")
        n_lin, n_col = im_size

        FDh = np.zeros(im_size)
        FDh[0, 0] = -1
        FDh[0, -1] = 1
        FDh = fft2(FDh)
        FDhH = np.conj(FDh)

        FDv = np.zeros(im_size)
        FDv[0, 0] = -1
        FDv[-1, 0] = 1
        FDv = fft2(FDv)
        FDvH = np.conj(FDv)

        IL = 1 / (FDhH * FDh + FDvH * FDv + 1)

        Dh = lambda x: np.real(ifft2(fft2(x) * FDh))
        DhH = lambda x: np.real(ifft2(fft2(x) * FDhH))

        Dv = lambda x: np.real(ifft2(fft2(x) * FDv))
        DvH = lambda x: np.real(ifft2(fft2(x) * FDvH))

    SMALL = 1e-12
    B = np.ones((1, n))
    a = np.ones((1, N))

    if not (reg_TV or reg_l1 or reg_pos or reg_add):
        U = pinv(M).dot(Y)
        res = norm(M.dot(U) - Y, 'fro')
        return U, res, rmse

    if not (reg_TV or reg_l1 or reg_pos) and reg_add:
        F = M.T.dot(M)
        if np.linalg.cond(F) < 1 / SMALL:
            IF = inv(F)
            U = IF.dot(M.T).dot(Y) - IF.dot(B.T).dot(inv(B.dot(IF).dot(B.T))).dot(B.dot(IF).dot(M.T).dot(Y) - a)
            res = norm(M.dot(U) - Y, 'fro')
            return U, res, rmse

    IF = inv(M.T.dot(M) + (reg_l1 + reg_pos + reg_add + reg_TV) * np.eye(n))

    if U0 == 0:
        U = IF.dot(M.T).dot(Y)

    index = 1
    V = [None] * (1 + reg_l1 + reg_pos + reg_add + reg_TV)
    D = [None] * (1 + reg_l1 + reg_pos + reg_add + reg_TV)

    reg = [1] + [0] * (reg_l1 + reg_pos + reg_add + reg_TV)
    V[0] = M.dot(U)
    D[0] = np.zeros_like(Y)

    if reg_pos:
        reg[index] = 2
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1
    if reg_add:
        reg[index] = 3
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1
    if reg_l1:
        reg[index] = 4
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1
    if reg_TV:
        reg[index] = 5
        V[index] = U
        D[index] = np.zeros_like(U)

        U_im = U.T.reshape(im_size + (n,))
        V.append([None] * n)
        D.append([None] * n)
        for i in range(n):
            V[-1][i] = [Dh(U_im[:, :, i]), Dv(U_im[:, :, i])]
            D[-1][i] = [np.zeros(im_size), np.zeros(im_size)]

    tol1 = np.sqrt(N) * 1e-5
    res = np.inf
    i = 1

    while i <= AL_iters and np.sum(np.abs(res)) > tol1:
        Xi = M.T.dot(V[0] + D[0])
        for j in range(1, reg_l1 + reg_pos + reg_add + reg_TV + 1):
            Xi += V[j] + D[j]
        U = IF.dot(Xi)

        for j in range(1 + reg_l1 + reg_pos + reg_add + reg_TV):
            if reg[j] == 1:
                V[j] = (1 / (1 + mu)) * (Y + mu * (M.dot(U) - D[j]))
            elif reg[j] == 2:
                V[j] = np.maximum(U - D[j], 0)
            elif reg[j] == 3:
                nu_aux = U - D[j]
                V[j] = nu_aux + np.outer((1 - np.sum(nu_aux, axis=0)) / n, np.ones(N))
            elif reg[j] == 4:
                V[j] = np.sign(U - D[j]) * np.maximum(np.abs(U - D[j]) - lambda_l1 / mu, 0)
            elif reg[j] == 5:
                nu_aux = U - D[j]
                nu_aux5_im = nu_aux.T.reshape(im_size + (n,))
                V5_im = np.zeros_like(nu_aux5_im)
                for k in range(n):
                    V5_im[:, :, k] = np.real(ifft2(IL * fft2(DhH(V[-1][k][0] + D[-1][k][0]) + DvH(V[-1][k][1] + D[-1][k][1]) + nu_aux5_im[:, :, k])))
                    aux_h = Dh(V5_im[:, :, k])
                    aux_v = Dv(V5_im[:, :, k])
                    if tv_type == 'niso':
                        V[-1][k][0] = np.sign(aux_h - D[-1][k][0]) * np.maximum(np.abs(aux_h - D[-1][k][0]) - lambda_TV / mu, 0)
                        V[-1][k][1] = np.sign(aux_v - D[-1][k][1]) * np.maximum(np.abs(aux_v - D[-1][k][1]) - lambda_TV / mu, 0)
                    else:
                        aux = np.maximum(np.sqrt((aux_h - D[-1][k][0]) ** 2 + (aux_v - D[-1][k][1]) ** 2) - lambda_TV / mu, 0)
                        V[-1][k][0] = aux / (aux + lambda_TV / mu) * (aux_h - D[-1][k][0])
                        V[-1][k][1] = aux / (aux + lambda_TV / mu) * (aux_v - D[-1][k][1])
                    D[-1][k][0] = D[-1][k][0] - (aux_h - V[-1][k][0])
                    D[-1][k][1] = D[-1][k][1] - (aux_v - V[-1][k][1])
                V[j] = V5_im.reshape(N, n).T

        for j in range(1 + reg_l1 + reg_pos + reg_add + reg_TV):
            D[j] = D[j] - (U - V[j])

        res[0] = norm(M.dot(U) - Y, 'fro')
        if true_x:
            rmse = norm(U - XT, 'fro') / np.sqrt(N)
            print(f"{i} : {np.sum(np.abs(res))} {rmse}")
        elif verbose == 'yes':
            print(f"{i} : {np.sum(np.abs(res))}")

        i += 1

    return U, res, rmse

# Example usage:
M = np.random.randn(100, 50)
Y = np.random.randn(100, 10)
result = sunsal_tv(M, Y, LAMBDA_1=0.1, LAMBDA_TV=0.1, IM_SIZE=[10, 10])


In [None]:
print(result)

(array([[ 1.78123099e-01,  1.05631668e-01,  1.08208015e-01,
        -7.09671434e-03, -5.94934061e-02,  2.05574802e-01,
        -6.55797572e-02,  5.10465578e-02,  1.47965833e-01,
        -8.47379020e-02],
       [ 1.31346641e-01,  9.78052285e-02,  2.71512974e-01,
        -5.44107810e-02,  7.17428433e-02,  1.23353974e-02,
         2.40892561e-02,  2.17022913e-03, -1.27529184e-01,
         2.87945581e-01],
       [ 2.67195231e-02,  8.90045286e-02,  7.73421040e-02,
        -1.49861242e-01, -4.84352894e-02, -3.79718568e-02,
        -9.80762356e-02, -9.77738010e-02,  9.81988206e-02,
        -6.00591107e-02],
       [-1.46795862e-02, -3.78636087e-02, -4.37966651e-02,
         6.36211348e-02, -5.60616083e-02, -1.41262914e-01,
        -1.41231381e-01,  6.86333946e-02,  3.37894641e-02,
        -2.03705431e-02],
       [ 1.12373982e-01, -1.90145623e-01, -1.41223374e-01,
        -1.58866116e-01, -2.67586413e-01,  5.27836980e-02,
        -1.70307782e-01, -1.16231162e-01,  2.55429143e-01,
         1

# New Section

Suslar


In [3]:
import numpy as np
from scipy.linalg import pinv, inv
from scipy.fftpack import fft2, ifft2
import numpy.fft as fft

def soft_threshold(x, threshold):
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

def weighted_svt(X, tau, weight):
    U, s, VT = np.linalg.svd(X, full_matrices=False)
    s_thresholded = soft_threshold(s, tau * weight)
    return U @ np.diag(s_thresholded) @ VT

def susrlr_tv(M, Y, segments, s=5, rho=0.01, **kwargs):
    # Ensure M and Y are numpy arrays
    M = np.array(M)
    Y = np.array(Y)

    # Mixing matrix size
    LM, n = M.shape
    # Data set size
    L, N = Y.shape

    if LM != L:
        raise ValueError('Mixing matrix M and data set Y are inconsistent')

    # Set defaults for optional parameters
    reg_l1 = kwargs.get('LAMBDA_1', 0) > 0
    lambda_l1 = kwargs.get('LAMBDA_1', 0)
    reg_TV = kwargs.get('LAMBDA_TV', 0) > 0
    lambda_TV = kwargs.get('LAMBDA_TV', 0)
    im_size = kwargs.get('IM_SIZE', [])
    tv_type = kwargs.get('TV_TYPE', 'niso')
    AL_iters = kwargs.get('AL_ITERS', 1000)
    mu = kwargs.get('MU', 0.001)
    verbose = kwargs.get('VERBOSE', 'off')
    positivity = kwargs.get('POSITIVITY', 'no') == 'yes'
    addone = kwargs.get('ADDONE', 'no') == 'yes'
    U0 = kwargs.get('X0', 0)
    true_x = 'TRUE_X' in kwargs
    XT = kwargs.get('TRUE_X', 0)

    if true_x:
        nr, nc = XT.shape
        if nr != n or nc != N:
            raise ValueError('Wrong image size')

    if reg_TV:
        num_superpixels = np.max(segments) + 1
        if N != np.prod(im_size):
            raise ValueError('Wrong image size')
        n_lin, n_col = im_size
        FDh = np.zeros(im_size)
        FDh[0, 0] = -1
        FDh[0, -1] = 1
        FDh = fft2(FDh)
        FDhH = np.conj(FDh)

        FDv = np.zeros(im_size)
        FDv[0, 0] = -1
        FDv[-1, 0] = 1
        FDv = fft2(FDv)
        FDvH = np.conj(FDv)

        IL = 1.0 / (FDhH * FDh + FDvH * FDv + 1)

        Dh = lambda x: np.real(ifft2(fft2(x) * FDh))
        DhH = lambda x: np.real(ifft2(fft2(x) * FDhH))

        Dv = lambda x: np.real(ifft2(fft2(x) * FDv))
        DvH = lambda x: np.real(ifft2(fft2(x) * FDvH))

    IF = inv(M.T @ M + (reg_l1 + reg_TV + positivity + addone) * np.eye(n))

    if isinstance(U0, int) and U0 == 0:
        U = IF @ M.T @ Y
    else:
        U = U0

    V = [None] * (1 + reg_l1 + reg_TV + positivity + addone)
    D = [None] * (1 + reg_l1 + reg_TV + positivity + addone)

    V[0] = M @ U
    D[0] = np.zeros_like(Y)

    index = 1

    if positivity:
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1

    if addone:
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1

    if reg_l1:
        V[index] = U
        D[index] = np.zeros_like(U)
        index += 1

    if reg_TV:
        V[index] = U
        D[index] = np.zeros_like(U)

        U_im = U.T.reshape(im_size[0], im_size[1], n)
        V[index + 1] = [None] * n
        D[index + 1] = [None] * n

        HT_H = np.zeros((n, n))
        H = np.zeros((n, N))
        D2 = np.zeros((n, N))
        D3 = np.zeros((n, N))

        for i in range(n):
            indices_i = np.where(segments == i)[0]
            HT_H[i, i] = np.sum(np.ones_like(indices_i))
            H[i, indices_i] = 1

        index += 2

    tol1 = np.sqrt(N) * 1e-5
    res = np.inf
    i = 1

    while i <= AL_iters and np.sum(np.abs(res)) > tol1:
        Xi = M.T @ (V[0] + D[0])
        for j in range(1, 1 + reg_l1 + reg_TV + positivity + addone):
            Xi += V[j] + D[j]

        U = IF @ Xi

        for j in range(1 + reg_l1 + reg_TV + positivity + addone):
            if j == 0:
                V[j] = (Y + mu * (M @ U - D[j])) / (1 + mu)
            elif positivity and j == 1:
                V[j] = np.maximum(U - D[j], 0)
            elif addone and j == 2:
                nu_aux = U - D[j]
                V[j] = nu_aux + np.tile((1 - np.sum(nu_aux, axis=0)) / n, (n, 1))
            elif reg_l1 and j == 3:
                V[j] = soft_threshold(U - D[j], lambda_l1 / mu)
            elif reg_TV and j == 4:
                V2 = np.linalg.solve(HT_H + np.eye(n), U - D2 + H.T @ (V3 + D3))
                V3 = soft_threshold(H @ V2 - D3, lambda_TV / mu)
                V4 = np.zeros_like(U)
                for superpixel in range(num_superpixels):
                    indices = np.where(segments == superpixel)[0]
                    if len(indices) == 0:
                        continue  # Skip empty superpixels
                    num_submatrices = min(s, len(indices))
                    if num_submatrices == 0:
                        continue  # Skip if no submatrices
                    submatrix_size = max(len(indices) // num_submatrices, 1)  # Ensure submatrix_size is at least 1
                    for j in range(num_submatrices):
                        start_idx = j * submatrix_size
                        end_idx = (j + 1) * submatrix_size if j < num_submatrices - 1 else len(indices)
                        submatrix_indices = indices[start_idx:end_idx]
                        U_submatrix = U[:, submatrix_indices]
                        V4[:, submatrix_indices] = weighted_svt(U_submatrix, rho / mu, np.ones(U_submatrix.shape[1]))
                V[j] = V4

        for j in range(1 + reg_l1 + reg_TV + positivity + addone):
            if j == 0:
                D[j] = D[j] - (M @ U - V[j])
            else:
                D[j] = D[j] - (U - V[j])

        if i % 10 == 1:
            res = [np.linalg.norm(M @ U - V[j], 'fro') if j == 0 else np.linalg.norm(U - V[j], 'fro')
                   for j in range(1 + reg_l1 + reg_TV + positivity + addone)]

            if verbose == 'yes':
                print(f'iter = {i} -', ' '.join([f'res({j}) = {r:2.6f}' for j, r in enumerate(res)]))

        if true_x:
            rmse = np.linalg.norm(U - XT, 'fro') / np.sqrt(n * N)
            if verbose == 'yes':
                print(f'RMSE = {rmse:2.6f}')

        i += 1

    return U, i, res, rmse if true_x else None

# Test scenario with random data
np.random.seed(0)

# Parameters
L = 30  # Number of spectral bands
N = 100  # Number of pixels
n = 5  # Number of endmembers

# Generate random mixing matrix M (L x n)
M = np.random.rand(L, n)

# Generate random observation matrix Y (L x N)
Y = np.random.rand(L, N)

# Generate random superpixel segments (arbitrary for demonstration)
segments = np.random.randint(0, 10, size=N)

# Call the susrlr_tv function with default parameters
U_estimated, _, _, _ = susrlr_tv(M, Y, segments)

U_estimated[U_estimated < 0] = 0

# Normalize each abundance vector to sum to one
U_estimated = U_estimated / np.sum(U_estimated, axis=0)


print(U_estimated.T[4])

Adjusted abundance matrix U:
Adjusted abundance vector U[4]:
[0.         0.11200228 0.22329625 0.33910281 0.32559866]
