In [1]:
import numpy as np

In [2]:
# First, define a class that holds the components of the MDP.
class MDP:
    
    def __init__(self, T, V, A, B, C, D, E, d, eta, omega, alpha, beta, num_policies, num_factors):
        '''
        Inputs:
            T: number of time steps
            V: allowable (deep) policies
            A: state-outcome mapping
            B: transition probabilities
            C: preferred states
            D: priors over initial states
            E: prior over policies
            eta: learning rate
            omega: forgetting rate
            alpha: action precision
            beta: expected free energy precision
            num_policies: number of policies
            num_factors: number of state factors
        '''
        self.T = T
        self.V = V
        
        self.A = A
        self.B = B
        self.C = C
        self.D = D
        self.E = E
        
        self.d = d
        
        self.eta = eta
        self.omega = omega
        self.alpha = alpha
        self.beta = beta
        
        self.num_policies = num_policies
        self.num_factors = num_factors
    
    @classmethod
    def explore_exploit_model(cls):
        
        # Here we specify 3 time points.
        T = 3
        
        # Here we specify prior probabilities about initial states in the generative process (D).
        # --------------------------------------------------------------------------
        # For the 'context' state factor, we specify that the 'left better' context is the true context.
        # For the 'behavior' state factor, we specify that the agent always starts a trial with 'start' state.
        D = [0, 0]
        D[0] = np.array([[1], [0]])
        D[1] = np.array([[1], [0], [0], [0]])
        
        # --------------------------------------------------------------------------
        # We can also specify prior beliefs about initial states (d)
        d = [0,0]
        # We assume that the agent starts out with equal beliefs in the 'context' state
        d[0] = np.array([
            [0.25], # left better
            [0.25], # right better
        ])
        
        # For behavior state, we always assume that the agent start out in the 'start' state.
        d[1] = np.array([
            [1], # start
            [0], # hint
            [0], # choose-left
            [0], # choose-right
        ])
               
        # Here we specify the probabilities of outcomes given each state in the generative process (A)
        # --------------------------------------------------------------------------
        # First we specify the mapping from states to observed hint (modality 1).
        A = [0,0,0]
        Ns = [0, 0]
        Ns[0] = D[0].shape[0]
        Ns[1] = D[1].shape[0]
        
        A0 = np.zeros((4, 3, 2))
        for i in range(Ns[1]):
            A0[i] = np.array([
                [1, 1],  # No hint
                [0, 0],  # Machine-left hint
                [0, 0],  # Machine-right hint
            ])
            
        # Notice that 'hint' behaviror states generates a hint that either the left or right slot machine is better.
        # In this case, the hints are accturate with a probability pHA.
        pHA = 1
        A0[1] = np.array([
            [0, 0],  # No hint
            [pHA, 1-pHA], # Left slot machine is better
            [1-pHA, pHA], # Right slot machine is better
        ])
        
        A[0] = A0
        
        # We then specify the mapping between states and wins/losses.
        # The first two behaviors states ('start' and 'hint') do not generate outcomes
        A1 = np.zeros((4, 3, 2))
        for i in range(2):
            A1[i] = np.array([
                [1, 1],  # Null
                [0, 0],  # Loss
                [0, 0],  # Win
            ])
        
        # Choosing the left machine (behavior state 3) generates wins with probabily p_win
        p_win = 0.8
        A1[2] = np.array([
            [0, 0],  # Null
            [1-p_win, p_win], # Loss
            [p_win, 1-p_win], # Win
        ])
        
        # Choosing the right machine (behavior state 4) generates wins with probability p_win,
        # with reverse mapping to context states from choosing the left machine
        A1[3] = np.array([
            [0, 0],  # Null
            [p_win, 1-p_win], # Loss
            [1-p_win, p_win], # Win
        ])
        
        A[1] = A1
        
        # Finally. we specify the mapping between behavior states and observed behaviors.
        A2 = np.zeros((4, 4, 2))
        for i in range(Ns[1]):
            A2[i] = np.zeros((4, 2))
            A2[i][i] = np.array([1, 1])
        
        A[2] = A2
        
        # Here we specify the probalistic transitions between hidden states under each action (B).
        # -------------------------------------------------------------------------
        # Columns are states at time t, rows are states at time t+1.
        B = [0, 0]
        
        # The agent cannot control the context state, so there is only 1 'action',
        # indicating that contexts remain stable within a trial.
        B0 = np.array([[1,0],[0,1]])
        B[0] = B0
        
        # The agent can control the behavior state, we have 4 possible actions:
        # 1. Move to the start state from any other state.
        # 2. Move to the Hint state from any other state.
        # 4. Move to the Choose Left state from any other state.
        # 5. Move to the Choose Right state from any other state.
        B1 = np.zeros((4,4,4))
        for i in range(Ns[1]):
            B1[i] = np.zeros((4,4))
            B1[i][i] = np.array([1,1,1,1])
        
        B[1] = B1
        
        # We here specify the 'prior preferences' (C), encoded here as log probabilities.
        # ---------------------------------------------------------------------------
        # One matrix per outcome modality. Each row is an observation, each column is a time point.
        No = [A[0].shape[1], A[1].shape[1], A[2].shape[1]] # number of outcomes in each outcome modality
        C = [0, 0, 0]
        
        # We start by setting 0 preference for all outcomes
        C[0] = np.zeros((No[0], T)) # hints
        C[1] = np.zeros((No[1], T)) # wins/losses
        C[2] = np.zeros((No[2], T)) # observed behaviors
        
        # Then we can specify a 'loss aversion' magnitude (la) at time points 2 
        # and 3, and a 'reward seeking' (or 'risk-seeking') magnitude (rs). Here,
        # rs is divided by 2 at the third time point to encode a smaller win ($2
        # instead of $4) if taking the hint before choosing a slot machine.
        la = 1
        rs = 4
        C[1] = np.array([
            [0, 0, 0], # null
            [0, -la, -la], # loss
            [0, rs, rs/2], # win
        ])

        # Here we specify the the policies (V) .
        # ------------------------------------------------------------------------------
        # Here we specify the the policies (V) .
        # Each policy is just a sequence of actions.
        # In our case, rows correspond to time points.
        num_policies = 5 # number of policies
        num_factors = 2 # number of factors
        
        V = [0, 0]
        
        V[0] = np.array([
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1], 
        ]) # context state is not controllable
        
        V[1] = np.array([
            [1, 2, 2, 3, 4],
            [1, 3, 4, 1, 1],
        ])
        
        # Here we specify the habits of the agent (E).
        # -------------------------------------------------------------------------------
        # Here we specify the habits of the agent (E).
        # We will not equip the agent with habits with any starting habits.
        E = [np.ones((5, 1))]
        
        ## Here we specify all other constants
        ## -------------------------------------------------------------------------------
        # Learning rate
        eta = 1
        
        # Forgetting rate
        omega = 1
        
        # Expected precision of expected free energy (G) over policies
        beta = 1
        
        # Alpha: An 'inverse temperature' or 'action precision' parameter that
        # controls how much randomness there is when selecting actions.
        alpha = 32
        
        return cls(T, V, A, B, C, D, E, d, eta, omega, alpha, beta, num_policies, num_factors)
    
    def message_passing_and_policy_selection(self):
        
        # Store initial parameter values of generative model for free energy calculations.
        # -------------------------------------------------------------------------------
        # 'Complexity' of d vector concentration parameters
        d_prior = [0 for i in range(self.num_factors)]
        d_complexity = [0 for i in range(self.num_factors)]
        if hasattr(self, 'd'):
            for i in range(self.num_factors):
                d_prior[i] = self.d[i]
                d_complexity[i] = self._spm_wnorm(d_prior[i])
        
        # We here normalize the matrices, so that they can actually be treated as probabilities
        # -------------------------------------------------------------------------------
        # Normalize A matrix
        self._col_norm(self.A)
        
        # Normalize B matrix
        self._col_norm(self.B)
        
        # Normalize C matrix
        for i in range(len(self.C)):
            self.C[i] = self.C[i] + 1/32;
            for t in range(self.T):
                self.C[i][:, t] = np.log(np.exp(self.C[i][:, t])/np.sum(np.exp(self.C[i][:, t]))+np.exp(-16))
        
        # Normalize D matrix 
        if hasattr(self, 'd'):
            self._col_norm(self.d)
        else:
            self._col_norm(self.D)
            
        # Normalize E vector
        self.E = self.E/np.sum(self.E)
        
        # We here initialize variables.
        # -------------------------------------------------------------------------------
        num_modalities = len(self.A)
        num_states = [0 for i in range(self.num_factors)] # the number of hidden states
        num_controllable_transitions = [0 for i in range(self.num_factors)] # number of hidden controllable hidden states for each factor
        for i in range(self.num_factors):
            if len(self.B[i].shape) == 2:
                num_states[i] = self.B[i].shape[0]
                num_controllable_transitions[i] = 1
            elif len(self.B[i].shape) > 2:
                num_states[i] = self.B[i].shape[1]
                num_controllable_transitions[i] = self.B[i].shape[0]
            else:
                print("The rank of matrix B is not correct")
        
        # Initialize the approximate posterior over states given policies for
        # each factor as a flat distribution over states at each time point.
        state_posterior = []
        for i in range(self.num_factors):
            state_posterior.append(np.ones((self.num_policies, num_states[i], self.T))/num_states[i])
    
        # Initialize the approximate posterior over policies as a flat distribution over policies at each time point
        policy_posteriors = np.ones((self.num_policies, self.T))/self.num_policies
        
        # Initialize posterior over actions
        chosen_action = np.zeros((len(self.B), self.T-1))
        
        # if there is only one policy
        for i in range(self.num_factors):
            if num_controllable_transitions[i] == 1:
                chosen_action[i,:] = np.ones((1, self.T-1))
        setattr(self, 'chosen_action', chosen_action)
        
        # Intialize expected free energy precision (beta)
        posterior_beta = 1
        gamma = []
        gamma.append(1/posterior_beta) # expected free energy precision
        
        # Messgae passing variables
        time_const = 4 # time constant for gradient descent
        num_iterations = 16 # number of message passing iterations
        
        # We here finally come to perform message passing and policy selection.
        # -------------------------------------------------------------------------------
        
        # Here we first initialize all the matrices we are going to use
        true_states = np.zeros((self.num_factors, self.T))
        outcomes = np.zeros((num_modalities, self.T))
        O = [[0 for i in range(num_modalities)] for j in range(self.T)]
        normalized_firing_rates = [np.zeros((self.T, self.T, self.num_policies, num_iterations, num_states[0])),np.zeros((self.T, self.T, self.num_policies, num_iterations, num_states[1]))]
        prediction_error = [np.zeros((self.T, self.T, self.num_policies, num_iterations, num_states[0])),np.zeros((self.T, self.T, self.num_policies, num_iterations, num_states[1]))]
        Ft = np.zeros((self.T, self.num_factors, self.T, num_iterations)) # variational free energy at each time point
        F = np.zeros((self.num_policies, self.T)) # varational free energy
        G = np.zeros((self.num_policies, self.T))
        expected_states = [0 for i in range(self.num_factors)]
        
        for t in range(self.T): # loop over time points
            
            # sample generative process
            # -------------------------------------------------------------------------------
            
            # Here we sample from the prior distribution over states to obtain the state at each time point.
            for factor in range(self.num_factors):
                prob_state = 0
                if t == 0:
                    prob_state = self.D[factor]
                elif t > 0:
                    if factor == 0:
                        j = int(self.chosen_action[factor, t-1])-1
                        prob_state = np.reshape(self.B[factor][:,j], (-1, 1))
                    elif factor > 0:
                        i = int(self.chosen_action[factor, t-1])-1
                        k = int(true_states[factor, t-1])-1
                        prob_state = np.reshape(self.B[factor][i,:,k], (-1, 1))        
                prob_state = np.reshape(np.cumsum(prob_state, axis=0), (-1))
                index_to_use = np.where(prob_state > np.random.random())[0][0]
                true_states[factor, t] = prob_state[index_to_use]
            
            # Here we sample observations
            num_modalities = len(self.A)
            for modality in range(num_modalities):
                i = int(true_states[1, t]) - 1
                k = int(true_states[0, t]) - 1
                index_to_use = np.where(np.cumsum(self.A[modality][i, :, k].reshape((-1,1)), axis=0).reshape(-1)>np.random.random())[0][0]
                outcomes[modality, t] = np.cumsum(self.A[modality][i, :, k].reshape((-1,1)), axis=0)[index_to_use]
            
            # Here we express observations as a structure containing 1 x observations vector for each
            # modality with a 1 in the position corresponding to the observation received on the trial
            for modality in range(num_modalities):
                vec = np.zeros((1, self.A[modality].shape[1]))
                index = int(outcomes[modality, t])-1
                vec[0, index] = 1
                O[modality][t] = vec
            
            # Marginal message passing (minimize F and infer posterior over states)
            # -------------------------------------------------------------------------------
            for policy in range(self.num_policies):
                for Ni in range(num_iterations):
                    for factor in range(self.num_factors):
                        lnAo = np.zeros(state_posterior[factor].shape) # initialise matrix containing the log likelihood of observations
                        for tau in range(self.T):
                            v_depolarization = self._nat_log(state_posterior[factor][policy, :, tau]).reshape(-1,1) # convert approximate posteriors into depolarisation variable v
                            if tau < t+1:
                                for modal in range(num_modalities):
                                    lnA = np.transpose(np.expand_dims(self._nat_log(self.A[modal][:,int(outcomes[modal,tau]-1),:]), axis=1), [1, 2, 0])
                                    for fj in range(self.num_factors):
                                        if fj != factor:
                                            lnAs = self._md_dot(np.squeeze(lnA, axis=0), state_posterior[fj][0, :, tau].reshape(-1, 1), fj)
                                            lnA = lnAs
                                    lnAo[0, :, tau] = lnAo[0, :, tau] + lnA.reshape(-1)
                            
                            # 'forwards' and 'backwards' messages at each tau
                            if tau == 0: # first tau
                                lnD = self._nat_log(self.d[factor]).reshape(-1,1) # forwards message
                                if factor == 0:
                                    lnBs = self._nat_log(np.matmul(self._B_norm(self.B[factor][:,:].T),state_posterior[factor][policy,:,tau+1])).reshape(-1,1) # backwards message
                                else:
                                    lnBs = self._nat_log(np.matmul(self._B_norm(self.B[factor][self.V[factor][tau,policy]-1,:,:].T), state_posterior[factor][policy,:,tau+1])).reshape(-1,1)
                            
                            elif tau == self.T-1: # last tau
                                if factor == 0:
                                    lnD = self._nat_log(np.matmul(self.B[0][:,:], state_posterior[factor][policy,:,tau-1])).reshape(-1,1) # forwards message
                                else:
                                    lnD = self._nat_log(np.matmul(self.B[factor][self.V[factor][tau-1,policy]-1,:,:], state_posterior[factor][policy,:,tau-1])).reshape(-1,1)
                                lnBs = np.zeros(self.d[factor].shape)
                            
                            else: # T-1 > tau > 0
                                if factor == 0:
                                    lnD = self._nat_log(np.matmul(self.B[0][:,:], state_posterior[factor][policy,:,tau-1])).reshape(-1,1) # forwards message
                                    lnBs = self._nat_log(np.matmul(self._B_norm(self.B[factor][:,:].T),state_posterior[factor][policy,:,tau+1])).reshape(-1,1) # backwards message
                                else:
                                    lnD = self._nat_log(np.matmul(self.B[factor][self.V[factor][tau-1,policy]-1,:,:], state_posterior[factor][policy,:,tau-1])).reshape(-1,1)
                                    lnBs = self._nat_log(np.matmul(self._B_norm(self.B[factor][self.V[factor][tau,policy]-1,:,:].T), state_posterior[factor][policy,:,tau+1])).reshape(-1,1)
                                    
                            # we then combine both the messages and do a gradient descent on the posterior
                            v_depolarization = v_depolarization + (0.5*lnD + 0.5*lnBs + lnAo[0,:,tau].reshape(-1,1) - v_depolarization)/time_const
                            #print(v_depolarization)
                            # variational free energy at each time point
                            a = state_posterior[factor][policy,:,tau]
                            b = (0.5*lnD+0.5*lnBs-lnAo[0,:,tau].reshape(-1,1)).reshape(-1)-self._nat_log(state_posterior[factor][policy,:,tau])
                            Ft[t, factor, tau, Ni] = np.dot(a,b)
                            # update posterior by running v through a softmax
                            state_posterior[factor][policy,:,tau] = ((np.exp(v_depolarization))/np.sum(np.exp(v_depolarization), axis=0)[0]).reshape(-1)
                            # store state_positerior from each epoch of gradient descent for each tau
                            normalized_firing_rates[factor][tau,t,policy,Ni,:] = state_posterior[factor][policy,:,tau]
                            # store v from each epoch of gradient descent for each tau
                            prediction_error[factor][tau,t,policy,Ni,:] = v_depolarization.reshape(-1)
                
                # variational free energy for each policy
                F_intermidiate = np.sum(Ft, axis=1) # sum over state factors
                F_intermidiate = np.squeeze(np.sum(F_intermidiate, 1)).T # sum over tau then squeeze into 16x3 matrix
                # store the value of the message pass at the last iteration into the variational free energy
                F_intermidiate_flatten = F_intermidiate.flatten()
                F[policy, t] = F_intermidiate_flatten[np.flatnonzero(F_intermidiate)[-1]]
            
            # Expected free energy (G) under each policy
            # -------------------------------------------------------------------------------
            
            # Initialize intermediate expected free energy variable for each policy
            G_intermediate = np.zeros((self.num_policies,1))
            # Policy horizon for 'counterfactual rollout' for deep policies
            horizon = self.T
            
            # Do the loop through policies 
            for policy in range(self.num_policies):
                
                # Bayesian superise about 'd'
                if hasattr(self, 'd'):
                    for factor in range(self.num_factors):
                        G_intermediate[policy] = G_intermediate[policy] - np.dot(d_complexity[factor].reshape(-1), state_posterior[factor][policy,:,0].reshape(-1))
                
                # We then come to calculate the expected free energy from time t to the policy horizon.
                for timestep in range(t, horizon):  
                    # store the expected states from each policy and time
                    for factor in range(self.num_factors):
                        expected_states[factor] = state_posterior[factor][policy,:,timestep].reshape(-1,1)
                        
                    # calculate Bayesian surprise then add it the expected free energy
                    G_intermediate[policy] = G_intermediate[policy] + self._G_epistemic_value(self.A[:], expected_states[:])
                    
                    for modality in range(num_modalities):
                        predictive_observations_posterior = self._cell_md_dot(self.A[modality], expected_states)
                        G_intermediate[policy] = G_intermediate[policy] + (predictive_observations_posterior.T).dot(self.C[modality][:,t])
            
            G[:,t] = G_intermediate.squeeze()
            print(G)
                            
                            
    # Normalize vector columns
    def _col_norm(self, input_matrices):
        num_factors = len(input_matrices)
        # output_matrices = [0 for i in range(num_factors)]
        for i in range(num_factors):
            num = input_matrices[i].shape
            if num[1] == 1:
                z = np.sum(input_matrices[i], axis=0)
                input_matrices[i] = input_matrices[i]/z
            else:
                for j in range(num[0]):
                    z = np.sum(input_matrices[i][j], axis=0)
                    input_shape = input_matrices[i][j].shape
                    if (input_matrices[i][j] - np.zeros(input_shape) == np.zeros(input_shape)).all():
                        continue
                    input_matrices[i][j] = input_matrices[i][j]/z
            # output_matrices[i] = input_matrices[i]
        return
    
    # This function subtracts the inverse of each column entry
    # from the inverse of the column sum and then divides the result by 2.
    def _spm_wnorm(self, input_matrix):
        input_matrix = input_matrix + np.exp(-16)
        input_matrix = (1/np.sum(input_matrix, axis=0) - 1/input_matrix)/2
        return input_matrix
    
    # Natural log that replaces zero values with very small values for numerical reasons.
    def _nat_log(self, x):
        return np.log(x+np.exp(-16))
    
    # Dot product along dimension f
    def _md_dot(self, A, s, f):
        if f == 0:
            matrix_to_return = np.matmul(A.T, s)
        elif f == 1:
            matrix_to_return = np.matmul(A, s)
        return matrix_to_return
    
    def _cell_md_dot(self, X, x):
        DIM = np.array([i for i in range(len(x))]).reshape(1,len(x)) + np.ndim(X) - len(x)
        for d in range(len(x)):
            s = [1 for i in range(np.ndim(X))]
            s[DIM[0][d]] = np.prod(list(x[d].shape))
            s = s[2:] + s[0:2] # because numpy and matlab has different arrangement for the shape of matrices,we have consider this fact when reshaping numpy arrays.
            x[d] = x[d].reshape(*s)
            X = X*(x[d].reshape(*s))
            if DIM[0][d]%2 != 0:
                X = np.expand_dims(np.sum(X, axis=DIM[0][d]+1), axis=DIM[0][d]+1)
            else:
                X = np.expand_dims(np.sum(X, axis=0), axis=0)
            #X = np.expand_dims(np.sum(X, axis=DIM[0][d]+1), axis=DIM[0][d]+1)
        X = X.squeeze()
        return X
        
    
    # Normalize the elements of B transpose as required by MMP
    def _B_norm(self, B):
        bb = B
        z = np.sum(bb, axis=0)
        bb = bb/z
        np.nan_to_num(bb, copy=False, nan=0)
        b = bb
        return b
    
    # Calculate the Bayesian surprise in expected free energy
    def _G_epistemic_value(self, A, s):
        qx = self._spm_cross(s)
        qx_flatten = qx.flatten()
        G = 0
        qo = 0
        for i in np.where(qx_flatten>np.exp(-16))[0]:
            po = 1
            for g in range(len(A)):
                A_list = [A[g][i][:,j] for i in range(A[g].shape[0]) for j in range(A[g].shape[-1])]
                po = np.kron(po,A_list[i])
            po = po[:]
            qo = qo + qx_flatten[i]*po
            G = G + qx_flatten[i]*po.T.dot(self._nat_log(po))
        G = G - qo.T.dot(self._nat_log(qo))
        return G
        
    # The outer product of matrices.
    # This method is buggy when dealing with 1-d arrays.
    def _spm_cross(self, *args):
        
        # dealing with the single case
        def helper_single(*args):
            if len(args) == 1:
                matrices = args[0]
                extension_dims = [1 for i in range(len(args[0]))]
                if isinstance(matrices, np.ndarray):
                    Y = matrices
                elif len(matrices) <= 1:
                    Y = matrices[0]
                else:
                    Z = matrices.pop(0)
                    Z_shape = list(Z.shape)
                    if sum(extension_dims) == fixed_length:
                        matrices = matrices[::-1]
                    W = helper_single(*[matrices])
                    W_shape = list(W.shape)
                    W = W.reshape(*W_shape,*list(np.ones(np.ndim(Z),dtype=int)))
                    Y = np.kron(W,Z)
                return Y.squeeze()
        
        # dealing with multiple inputs
        def helper_multiple(*args):
            Y_list = []
            for i in range(len(args)):
                matrices = args[i]
                extension_dims = [1 for i in range(len(matrices))]
                if isinstance(matrices, np.ndarray):
                    Y = matrices
                elif len(matrices) <= 1:
                    Y = matrices[0]
                else:
                    Z = matrices.pop(len(matrices)-1)
                    Z_shape = list(Z.shape)
                    if sum(extension_dims) == fixed_length_list[i]:
                        matrices = matrices[::-1]
                    W = helper_single(*[matrices])
                    Z = Z.reshape(*Z.shape,*list(np.ones(np.ndim(W),dtype=int)))
                    W_shape = list(W.shape)
                    Y = np.kron(W,Z)
                Y_list.append(Y)
            V = self._spm_cross(Y_list)
            return V

        if len(args) == 1:
            fixed_length = len(args[0])
            return helper_single(*args)
        else:
            fixed_length_list = [len(args[i]) for i in range(len(args))]
            return helper_multiple(*args)

In [3]:
# Build the example explore/exploit model.
# NOTE(review): this rebinds the name `MDP` from the class to an instance, so
# the class is no longer reachable under that name after this cell has run.
MDP = MDP.explore_exploit_model()

In [4]:
# Run the simulation on the model instance bound to `MDP`; the method prints
# the expected free energy matrix G once per time step (see the output below).
MDP.message_passing_and_policy_selection()

[[-9.72660102  0.          0.        ]
 [-8.84199649  0.          0.        ]
 [-8.84199649  0.          0.        ]
 [-9.53413628  0.          0.        ]
 [-9.53413628  0.          0.        ]]
[[ -9.72660102 -12.00513211   0.        ]
 [ -8.84199649 -10.31465811   0.        ]
 [ -8.84199649 -10.31465811   0.        ]
 [ -9.53413628 -12.00513206   0.        ]
 [ -9.53413628 -12.00513206   0.        ]]
[[ -9.72660102 -12.00513211  -3.65102613]
 [ -8.84199649 -10.31465811  -3.6510261 ]
 [ -8.84199649 -10.31465811  -3.6510261 ]
 [ -9.53413628 -12.00513206  -3.65102613]
 [ -9.53413628 -12.00513206  -3.65102613]]




In [5]:
# Scratch 2x2 integer array; nothing else in the visible notebook reads it.
a = np.array([[1,2],[3,4]])

In [None]:
# def spm_cross(*args):
    
#     def helper_single(*args):
#         if len(args) == 1:
#             matrices = args[0]
#             extension_dims = [1 for i in range(len(args[0]))]
#             if isinstance(matrices, np.ndarray):
#                 Y = matrices
#             elif len(matrices) <= 1:
#                 Y = matrices[0]
#             else:
#                 Z = matrices.pop(0)
#                 Z_shape = list(Z.shape)
#                 if sum(extension_dims) == fixed_length:
#                     matrices = matrices[::-1]
#                 W = helper_single(*[matrices])
#                 W_shape = list(W.shape)
#                 W = W.reshape(*W_shape,*list(np.ones(np.ndim(Z),dtype=int)))
#                 Y = np.kron(W,Z)
                
#             return Y.squeeze()

#     def helper_multiple(*args):
#         Y_list = []
#         for i in range(len(args)):
#             matrices = args[i]
#             extension_dims = [1 for i in range(len(matrices))]
#             if isinstance(matrices, np.ndarray):
#                 Y = matrices
#             elif len(matrices) <= 1:
#                 Y = matrices[0]
#             else:
#                 Z = matrices.pop(len(matrices)-1)
#                 Z_shape = list(Z.shape)
#                 if sum(extension_dims) == fixed_length_list[i]:
#                     matrices = matrices[::-1]
#                 W = helper_single(*[matrices])
#                 Z = Z.reshape(*Z.shape,*list(np.ones(np.ndim(W),dtype=int)))
#                 W_shape = list(W.shape)
#                 Y = np.kron(W,Z)
            
#             Y_list.append(Y)
 
#         V = spm_cross(Y_list)
#         return V
              
#     if len(args) == 1:
#         fixed_length = len(args[0])
#         return helper_single(*args)
#     else:
#         fixed_length_list = [len(args[i]) for i in range(len(args))]
#         return helper_multiple(*args)