In [None]:
import numpy as np

In [None]:
def _cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D vectors (0 if either has zero norm)."""
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0
    return np.dot(a, b) / (norm_a * norm_b)


def get_robustK(thrs, args, params, d_comps):
    """Select components that are reproduced across sampling chains.

    Each component of each chain is compared (cosine similarity of the
    flattened data-space component) with the still-unmatched components of
    every chain. A component is declared robust when the number of chains
    whose best similarity exceeds ``thrs['cosineThr']`` is larger than
    ``thrs['matchThr'] * args.num_chains``. For robust components the
    posterior samples of the matched components are stacked chain after
    chain, and the matched components are removed from further matching.

    Parameters
    ----------
    thrs : dict
        Must provide ``'cosineThr'`` (minimum cosine similarity for a match)
        and ``'matchThr'`` (fraction of chains that must be exceeded).
    args : object
        Attributes read: ``K``, ``num_chains``, ``num_samples``,
        ``num_sources``, ``model`` and, when ``model == 'sparseGFA'``,
        ``reghsZ``.
    params : dict
        Per-chain lists of posterior samples. ``'W'`` and ``'Z'`` are always
        read and are indexed as (sample, row, component). Depending on the
        model, ``'lmbW'``/``'cW'`` (any model name containing
        ``'sparseGFA'``), ``'lmbZ'``/``'tauZ'``/``'cZ'`` (``'sparseGFA'``)
        or ``'alpha'`` (``'GFA'``) are read as well.
    d_comps : list of list
        ``d_comps[c][k]`` is the data-space representation of component
        ``k`` of chain ``c``; it is accumulated into an ``(N, D)`` array.

    Returns
    -------
    rob_params : dict
        Posterior means (over the retained samples) of ``W`` and ``Z`` plus
        the model-specific parameters. ``'cZ_inf'`` and ``'tauZ_inf'`` are
        returned as retained samples, not averaged. When nothing is robust
        the arrays were never filled, so the means are NaN.
    X_rob : list
        One single-element list ``[array]`` per robust slot holding the
        average data-space component over the matched chains. When no
        component is robust, the unfiltered list of empty lists is returned.
    success : bool
        ``False`` when no robust component was found or when no sample or
        component survived the NaN filtering.

    Notes
    -----
    NOTE(review): results are stored in slot ``k``, the component index in
    the chain that discovered the robust component. Two different chains can
    therefore write to the same slot, the later one overwriting the earlier;
    ``nrobcomp`` still counts both. Changing this would alter the output
    ordering, so it is left as is - confirm whether that is intended.
    """
    ncomps = args.K
    nchs = args.num_chains
    nsamples = args.num_samples
    t_nsamples = nsamples * nchs
    M = args.num_sources
    N = params['Z'][0].shape[1]
    D = params['W'][0].shape[1]

    # Model switches. 'GFA' never contains 'sparseGFA', so the second and
    # third flags are mutually exclusive (the original used if/elif).
    sparse_Z = args.model == 'sparseGFA'
    sparse_W = 'sparseGFA' in args.model
    plain_gfa = (not sparse_W) and args.model == 'GFA'
    # args.reghsZ is only consulted for the 'sparseGFA' model.
    reg_Z = sparse_Z and args.reghsZ

    # One (initially empty) entry per component slot.
    X_rob = [[] for _ in range(ncomps)]

    # NaN-filled containers for the stacked samples; NaN marks "never written"
    # and is what the filtering step below relies on.
    Z = np.full([t_nsamples, N, ncomps], np.nan)
    W = np.full([t_nsamples, D, ncomps], np.nan)
    if sparse_Z:
        lmbZ = np.full([t_nsamples, N, ncomps], np.nan)
        tauZ = np.full([t_nsamples, ncomps], np.nan)
        if reg_Z:
            cZ = np.full([t_nsamples, ncomps], np.nan)
    if sparse_W:
        lmbW = np.full([t_nsamples, D, ncomps], np.nan)
        cW = np.full([t_nsamples, M, ncomps], np.nan)
    elif plain_gfa:
        alpha = np.full([t_nsamples, M, ncomps], np.nan)

    # Component indices still available for matching, per chain.
    storecomps = [np.arange(ncomps) for _ in range(nchs)]

    cosThr = thrs['cosineThr']
    matchThr = thrs['matchThr']
    nrobcomp = 0

    # Posterior means per chain are loop invariant: compute them once instead
    # of once per component comparison.
    W_means = [np.mean(params['W'][c], axis=0) for c in range(nchs)]
    Z_means = [np.mean(params['Z'][c], axis=0) for c in range(nchs)]

    for c1 in range(nchs):

        # Best similarity of component k (row) within each chain (column).
        max_sim = np.zeros((ncomps, nchs))
        max_simW = np.zeros((ncomps, nchs))
        max_simZ = np.zeros((ncomps, nchs))

        # Index of the best-matching component per chain, -1 meaning "no
        # match". Starting at -1 (rather than 0) keeps chains that are skipped
        # because they have no components left from being treated as a match
        # with their component 0.
        matchInds = np.full((ncomps, nchs), -1, dtype=int)

        # Iterates over the array bound at loop start; reassigning
        # storecomps[c1] inside the loop does not affect this iteration.
        for k in storecomps[c1]:

            # Chains that still have unmatched components.
            nonempty_chs = [c for c, comps in enumerate(storecomps) if comps.size > 0]

            ref_X = np.ndarray.flatten(d_comps[c1][k])
            ref_W = W_means[c1][:, k]
            ref_Z = Z_means[c1][:, k]

            for c2 in nonempty_chs:
                candidates = storecomps[c2]
                cosine = np.zeros(candidates.size)
                cosW = np.zeros(candidates.size)
                cosZ = np.zeros(candidates.size)

                for cind, comp in enumerate(candidates):
                    cosine[cind] = _cosine_similarity(
                        np.ndarray.flatten(d_comps[c2][comp]), ref_X)
                    cosW[cind] = _cosine_similarity(W_means[c2][:, comp], ref_W)
                    cosZ[cind] = _cosine_similarity(Z_means[c2][:, comp], ref_Z)

                # The match is chosen on the data-space similarity only; the
                # W and Z similarities of that same candidate are kept to
                # decide the sign alignment later.
                best = np.argmax(cosine)
                max_sim[k, c2] = cosine[best]
                max_simW[k, c2] = cosW[best]
                max_simZ[k, c2] = cosZ[best]

                matchInds[k, c2] = candidates[best]
                if max_sim[k, c2] < cosThr:
                    matchInds[k, c2] = -1

            # Robust if enough chains exceed the similarity threshold.
            if np.sum(max_sim[k, :] > cosThr) > matchThr * nchs:

                good_chains = np.where(matchInds[k, :] >= 0)[0]
                X_acc = np.zeros((N, D))

                # Matched chains are packed contiguously from row 0.
                s = 0
                for c2 in good_chains:
                    m = int(matchInds[k, c2])
                    inds = np.arange(s, s + nsamples)

                    X_acc += d_comps[c2][m]

                    if sparse_Z:
                        lmbZ[inds, :, k] = params['lmbZ'][c2][:, :, m]
                        tauZ[inds, k] = params['tauZ'][c2][:, m]
                        if reg_Z:
                            cZ[inds, k] = params['cZ'][c2][:, m]

                    if sparse_W:
                        lmbW[inds, :, k] = params['lmbW'][c2][:, :, m]
                        cW[inds, :, k] = params['cW'][c2][:, :, m]
                    elif plain_gfa:
                        alpha[inds, :, k] = params['alpha'][c2][:, :, m]

                    # Flip the sign when the matched loadings / latent
                    # variables are not positively correlated with the
                    # reference component.
                    if max_simW[k, c2] > 0:
                        W[inds, :, k] = params['W'][c2][:, :, m]
                    else:
                        W[inds, :, k] = -params['W'][c2][:, :, m]

                    if max_simZ[k, c2] > 0:
                        Z[inds, :, k] = params['Z'][c2][:, :, m]
                    else:
                        Z[inds, :, k] = -params['Z'][c2][:, :, m]

                    # The matched component is no longer available.
                    storecomps[c2] = storecomps[c2][storecomps[c2] != m]

                    s += nsamples

                # Average over the matched chains (wrapped in a list, as
                # callers receive it).
                X_rob[k] = [X_acc / good_chains.size]

                nrobcomp += 1

    success = True

    if nrobcomp > 0:

        # Mean over axis 1 is NaN wherever a (sample, component) slot was
        # never written.
        Z_slot_mean = np.mean(Z, axis=1)

        # Component slots that received at least one sample.
        idx_cols = ~np.isnan(Z_slot_mean).all(axis=0)

        # Sample rows that are filled for every retained component.
        idx_rows = ~np.isnan(Z_slot_mean[:, idx_cols]).any(axis=1)

        def _keep(arr):
            # Rows are on the first axis and components on the last one.
            return arr[idx_rows][..., idx_cols]

        if sparse_Z:
            lmbZ = _keep(lmbZ)
            tauZ = _keep(tauZ)
            if reg_Z:
                cZ = _keep(cZ)

        W = _keep(W)
        Z = _keep(Z)

        if sparse_W:
            lmbW = _keep(lmbW)
            cW = _keep(cW)
        elif plain_gfa:
            alpha = _keep(alpha)

        # Previously only checked for sparseGFA with reghsZ; an empty result
        # is a failure for every model.
        if not idx_rows.any() or not idx_cols.any():
            print('No samples survived!')
            success = False

        # Drop the slots that never became robust.
        X_rob = [entry for entry in X_rob if len(entry) > 0]

    else:
        print('No robust components found!')
        success = False

    # Posterior means over the retained samples.
    rob_params = {'W': np.mean(W, axis=0),
                  'Z': np.mean(Z, axis=0)}

    if sparse_W:
        rob_params.update({'cW_inf': np.mean(cW, axis=0), 'lmbW': np.mean(lmbW, axis=0)})
    elif plain_gfa:
        rob_params.update({'alpha_inf': np.mean(alpha, axis=0)})

    if sparse_Z:
        if reg_Z:
            rob_params.update({'cZ_inf': cZ, 'tauZ_inf': tauZ, 'lmbZ': np.mean(lmbZ, axis=0)})
        else:
            rob_params.update({'tauZ_inf': tauZ, 'lmbZ': np.mean(lmbZ, axis=0)})

    return rob_params, X_rob, success

In [None]:
def get_infparams(samples, hypers, args):
    """Convert raw posterior samples into per-chain inferred parameters.

    The samples of all chains are expected to be stacked along the first
    axis (chain ``c`` occupies rows ``c * num_samples`` to
    ``(c + 1) * num_samples - 1``). For each chain the function rescales the
    raw ``Z`` and ``W`` draws according to the model and computes, for every
    component, the outer product of the posterior-mean latent vector and the
    posterior-mean loading vector.

    Parameters
    ----------
    samples : dict
        Posterior samples. ``'Z'``, ``'W'`` and ``'sigma'`` are always read.
        For ``args.model == 'sparseGFA'`` also ``'tauZ'``, ``'lmbZ'`` and,
        when ``args.reghsZ`` is set, ``'cZ'``. For any model name containing
        ``'sparseGFA'`` also ``'cW'``, ``'lmbW'`` and ``'tauW1'`` ...
        ``'tauW<num_sources>'``. For ``'GFA'`` also ``'alpha'``.
    hypers : dict
        ``'Dm'`` gives the number of features per data source; ``'slab_scale'``
        is read for the regularised-horseshoe scalings.
    args : object
        Attributes read: ``num_chains``, ``num_samples``, ``K``,
        ``num_sources``, ``model`` and (for ``'sparseGFA'``) ``reghsZ``.

    Returns
    -------
    params : dict
        ``'W'`` and ``'Z'`` (lists with one array per chain), ``'sigma'``
        (array over all stacked samples) and the model-specific entries
        (``'lmbW'``, ``'tauW'``, ``'cW'``, ``'lmbZ'``, ``'tauZ'``, ``'cZ'``
        or ``'alpha'``).
    d_comps : list of list of ndarray
        ``d_comps[c][k]`` is the ``(N, D)`` outer product of the mean ``Z``
        and mean ``W`` vectors of component ``k`` in chain ``c``.
    """
    nchs, nsamples = args.num_chains, args.num_samples
    N = samples['Z'].shape[1]
    K = args.K
    M = args.num_sources
    Dm = hypers['Dm']
    D = sum(Dm)

    print(f"Dm array: {hypers['Dm']}")

    # Model switches. 'GFA' never contains 'sparseGFA', so the second and
    # third flags are mutually exclusive (the original used if/elif).
    sparse_Z = args.model == 'sparseGFA'
    sparse_W = 'sparseGFA' in args.model
    plain_gfa = (not sparse_W) and args.model == 'GFA'
    # args.reghsZ is only consulted for the 'sparseGFA' model.
    reg_Z = sparse_Z and args.reghsZ

    # Convert the Z-related sample arrays once; previously the full arrays
    # were re-converted (copied) for every chain and component.
    Z_all = np.asarray(samples['Z'])
    if sparse_Z:
        tauZ_all = np.asarray(samples['tauZ'])
        lmbZ_all = np.asarray(samples['lmbZ'])

    ## Output containers (one array per chain unless noted otherwise)

    Z_inf = [np.zeros((nsamples, N, K)) for _ in range(nchs)]
    W_inf = [np.zeros((nsamples, D, K)) for _ in range(nchs)]

    if sparse_Z:
        lmbZ_inf = [np.zeros((nsamples, N, K)) for _ in range(nchs)]
        tauZ_inf = [np.zeros((nsamples, K)) for _ in range(nchs)]
        if reg_Z:
            cZ_inf = [np.zeros((nsamples, K)) for _ in range(nchs)]

    if sparse_W:
        lmbW_inf = [np.zeros((nsamples, D, K)) for _ in range(nchs)]
        # Each entry is replaced by a freshly computed array in the chain
        # loop, so only placeholders are needed here.
        cW_inf = [None] * nchs
        # Global shrinkage per data source, over all stacked samples.
        tauW_inf = np.zeros((nchs * nsamples, M))
    elif plain_gfa:
        # Entries are replaced in the chain loop as well.
        alpha_inf = [None] * nchs

    # Noise parameters over all stacked samples.
    sigma_inf = np.zeros((nchs * nsamples, M))

    d_comps = [[[] for _ in range(K)] for _ in range(nchs)]

    for c in range(nchs):
        # Rows of the stacked sample arrays that belong to chain c.
        first = c * nsamples
        inds = np.arange(first, first + nsamples)

        sigma_inf[inds, :] = samples['sigma'][inds, 0, :]

        ## Latent variables Z

        if sparse_Z:
            tauZ_inf[c][:, :] = tauZ_all[inds, 0, :K]
            # (nsamples, 1, K) so it broadcasts over the rows of Z.
            tauZ = tauZ_inf[c][:, np.newaxis, :]

            if reg_Z:
                cZ_inf[c][:, :] = hypers['slab_scale'] * np.sqrt(samples['cZ'][inds, 0, :K])
                cZ = cZ_inf[c][:, np.newaxis, :]

                # Regularised local scale:
                # sqrt(c^2 * lmb^2 / (c^2 + tau^2 * lmb^2)).
                lmbZ_sqr = np.square(lmbZ_all[inds, :, :K])
                lmbZ_inf[c][:, :, :] = np.sqrt(
                    lmbZ_sqr * cZ ** 2 / (cZ ** 2 + tauZ ** 2 * lmbZ_sqr))
            else:
                lmbZ_inf[c][:, :, :] = lmbZ_all[inds, :, :K]

            # Raw draws scaled by the local and global shrinkage.
            Z_inf[c][:, :, :] = Z_all[inds, :, :K] * lmbZ_inf[c] * tauZ
        else:
            Z_inf[c][:, :, :] = Z_all[inds, :, :K]

        ## Loading matrices W

        if sparse_W:
            cW_inf[c] = hypers['slab_scale'] * np.sqrt(samples['cW'][inds, :, :])

            d = 0
            for m in range(M):
                # Feature rows of data source m.
                rows = slice(d, d + Dm[m])

                lmbW_sqr = np.square(np.array(samples['lmbW'][inds, rows, :]))

                # The per-source global scale is stored under 'tauW1',
                # 'tauW2', ... (1-based).
                tauW_inf[inds, m] = samples[f'tauW{m + 1}'][inds]
                tauW = np.reshape(tauW_inf[inds, m], (nsamples, 1, 1))

                cW = np.reshape(cW_inf[c][:, m, :], (nsamples, 1, K))

                lmbW_inf[c][:, rows, :] = np.sqrt(
                    cW ** 2 * lmbW_sqr / (cW ** 2 + tauW ** 2 * lmbW_sqr))

                W_inf[c][:, rows, :] = (
                    np.array(samples['W'][inds, rows, :]) * lmbW_inf[c][:, rows, :] * tauW)

                d += Dm[m]

        elif plain_gfa:
            alpha_inf[c] = samples['alpha'][inds, :, :]

            d = 0
            for m in range(M):
                rows = slice(d, d + Dm[m])
                alpha = np.reshape(alpha_inf[c][:, m, :], (nsamples, 1, K))

                # Raw loadings scaled by 1 / sqrt(alpha) of their source.
                W_inf[c][:, rows, :] = (
                    np.array(samples['W'][inds, rows, :]) * (1 / np.sqrt(alpha)))

                d += Dm[m]

        ## Components in the data space

        for k in range(K):
            z = np.reshape(np.mean(Z_inf[c][:, :, k], axis=0), (N, 1))
            w = np.reshape(np.mean(W_inf[c][:, :, k], axis=0), (D, 1))
            # (N, 1) x (1, D) -> (N, D)
            d_comps[c][k] = np.dot(z, w.T)

    ## Consolidation of results

    params = {'W': W_inf, 'Z': Z_inf, 'sigma': sigma_inf}

    if sparse_W:
        params.update({'lmbW': lmbW_inf, 'tauW': tauW_inf, 'cW': cW_inf})
    elif plain_gfa:
        params.update({'alpha': alpha_inf})

    if sparse_Z:
        if reg_Z:
            params.update({'lmbZ': lmbZ_inf, 'tauZ': tauZ_inf, 'cZ': cZ_inf})
        else:
            params.update({'lmbZ': lmbZ_inf, 'tauZ': tauZ_inf})

    return params, d_comps