In [2]:
import numpy as np
import copy
from scipy import optimize


class TransitionRateModel:
    """
    A model for learning non-interactive and interactive transition rate matrices
    in a multi-layer system for a classification task.
    """

    def __init__(self, total_layers: int) -> None:
        """
        Initializes the model with the specified number of layers.

        The value is stored as-is (no validation here); the other methods read
        it to size their Kronecker products and to reshape parameter vectors.

        Args:
            total_layers (int): Total number of layers in the system.
        """
        self.total_layers = total_layers

    @staticmethod
    def create_layer_rate_matrix(rate_constants):
        """
        Creates a rate matrix for a layer given rate constants.

        Args:
            rate_constants (tuple): A tuple (k1, k2) representing rate constants.

        Returns:
            np.ndarray: The layer rate matrix.
        """

        k1, k2 = rate_constants
        return np.array([[-k2, k1], [k2, -k1]])

    def kronecker_product_with_identity(self, matrix, position):
        """
        Places a matrix in a specific position in a Kronecker product with identity matrices.

        Args:
            matrix (np.ndarray): The matrix to place in the Kronecker product.
            position (int): Position in the Kronecker product (0 = rightmost).

        Returns:
            np.ndarray: Resulting matrix after the Kronecker product.
        """
        result = np.eye(1)
        for i in reversed(range(self.total_layers)):
            if i == position:
                result = np.kron(result, matrix)
            else:
                result = np.kron(result, np.eye(matrix.shape[0]))
        
        return result

    def compute_non_interactive_w(self, layer_matrices):
        """
        Computes the non-interactive transition rate matrix.

        Args:
            layer_matrices (list): List of rate matrices for each layer.

        Returns:
            np.ndarray: The non-interactive transition rate matrix.
        """
        layer_dim = layer_matrices[0].shape[0]
        w_matrix = np.zeros((layer_dim ** self.total_layers, layer_dim ** self.total_layers))

        for i in range(len(layer_matrices)):
            w_matrix += self.kronecker_product_with_identity(layer_matrices[i], i)
        return w_matrix

    def kronecker_with_two_matrices(self, matrix1, pos1, matrix2, pos2):
        """
        Places two matrices at specified positions within a Kronecker product.

        Args:
            matrix1 (np.ndarray): First matrix.
            pos1 (int): Position for the first matrix.
            matrix2 (np.ndarray): Second matrix.
            pos2 (int): Position for the second matrix.

        Returns:
            np.ndarray: Resulting matrix after the Kronecker product.
        """
        result = np.eye(1)
        for i in reversed(range(self.total_layers)):
            if i == pos1:
                result = np.kron(result, matrix1)
            elif i == pos2:
                result = np.kron(result, matrix2)
            else:
                result = np.kron(result, np.eye(matrix1.shape[0]))
        return result

    def compute_interactive_w(self, target_layer=None, target_edge=None, influencing_layer=None, influencing_state=None, interaction_strength=None):
        """
        Computes the interactive transition rate matrix with specified interactions.

        Args:
            target_layer (int): Layer influenced by the interaction.
            target_edge (str): Edge ("AB" or "BA") influenced by the interaction.
            influencing_layer (int): Layer that provides the influence.
            influencing_state (str): State ("A" or "B") of the influencing layer.
            interaction_strength (float): Interaction strength.

        Returns:
            np.ndarray: The interactive transition rate matrix.
        """
        if interaction_strength == None:
          return np.zeros((2 ** self.total_layers, 2 ** self.total_layers))
        else:
          influence_B = np.array([[0, 0], [0, interaction_strength]])
          influence_A = np.array([[interaction_strength, 0], [0, 0]])
          edge_AB = np.array([[-1, 0], [1, 0]])
          edge_BA = np.array([[0, 1], [0, -1]])

          w_matrix = np.zeros((2 ** self.total_layers, 2 ** self.total_layers))

          if influencing_state == "B" and target_edge == "AB":
              w_matrix += self.kronecker_with_two_matrices(edge_AB, target_layer, influence_B, influencing_layer)
          elif influencing_state == "B" and target_edge == "BA":
              w_matrix += self.kronecker_with_two_matrices(edge_BA, target_layer, influence_B, influencing_layer)
          elif influencing_state == "A" and target_edge == "AB":
              w_matrix += self.kronecker_with_two_matrices(edge_AB, target_layer, influence_A, influencing_layer)
          elif influencing_state == "A" and target_edge == "BA":
              w_matrix += self.kronecker_with_two_matrices(edge_BA, target_layer, influence_A, influencing_layer)

        return w_matrix


    def compute_omega_matrix(self, layer_matrices, interaction_data):
        # Initialize lists to store the non-interactive and interactive omega matrices.
        non_int_omega_matrices = []
        int_omega_matrices = []
        

        # Iterate over each layer to compute the corresponding omega matrices.
        for i in range(self.total_layers):
            # Compute the non-interactive omega matrix for the current layer using a Kronecker product.
            non_omega = self.kronecker_product_with_identity(layer_matrices[i], i)
            
            # Initialize the interactive omega matrix as a zero matrix with the same shape as non_omega.
            int_omega = np.zeros_like(non_omega)

            # Accumulate the interactive contributions for the current layer.
            # Each `idata` contains interaction data relevant to this layer.
            for idata in interaction_data[i]:
                # Compute the contribution from the interaction data and add it to the interactive omega matrix.
                # print(int_omega)
                int_omega += self.compute_interactive_w(*idata)

            # Append the computed non-interactive and interactive omega matrices to their respective lists.
            non_int_omega_matrices.append(non_omega)
            int_omega_matrices.append(int_omega)

        # Return both the non-interactive and interactive omega matrices.
        return non_int_omega_matrices, int_omega_matrices

    def get_random_network(self, interval, interaction_terms, factor):
        # Non-interective rate constants
        constants = (interval[1] - interval[0]) * np.random.rand(self.total_layers, 2) + interval[0]
        # Randomly choose interaction layers, edges and interaction strengths.
        all_possible_pairs = [(i, j) for i in range(self.total_layers) for j in range(self.total_layers) if i != j]
        np.random.shuffle(all_possible_pairs)  # shuffle to randomize
        # print(len(all_possible_pairs))
        selected_pairs = all_possible_pairs[:interaction_terms] # select the pairs of layers.
        edge = np.random.choice(["AB", "BA"], size=interaction_terms) # select the influenced edges.
        state_influencing = np.random.choice(["A", "B"], size=interaction_terms) # select the states that are influencing.
        interaction_strength = (interval[1] - interval[0]) * factor * np.random.rand(interaction_terms) + interval[0] * factor
        # randomly select time-scales.
        time_scales = 20 * (interval[1] - interval[0]) * np.random.rand(self.total_layers) + interval[1]
        time_scales = np.sort(time_scales)
        #time_scales = np.random.choice([])
        print(f'Time scales: {time_scales}')
        # Accumulate the interaction data in the format that would be used to calculate matrices.
        interaction_data = []
        for i in range(self.total_layers):
            layer_interactions = []
            for j in range(interaction_terms):
                target_layer, influencing_layer = selected_pairs[j]
                if target_layer == i:
                    layer_interactions.append([target_layer, edge[j], influencing_layer, state_influencing[j], interaction_strength[j] / time_scales[i]]) #Example interaction parameters
            interaction_data.append(layer_interactions)
        print(interaction_data)
        # Calculate the free-layers matrices.
        layer_matrices = [self.create_layer_rate_matrix(constants[i] / time_scales[i]) for i in range(self.total_layers)]
        # Calculate the omega and interaction matrices.
        omega_matrices, interaction_matrices = self.compute_omega_matrix(layer_matrices, interaction_data)

        return omega_matrices, interaction_matrices

    def get_fully_connected_network(self, interval):
          # Non-interective rate constants
          constants = (interval[1] - interval[0]) * np.random.rand(self.total_layers, 2) + interval[0]
          all_possible_pairs = [(i, j) for i in range(self.total_layers) for j in range(self.total_layers) if i != j]
          edges = [["AB", "BA"]] * len(all_possible_pairs)
          state_influencing = [["A", "B"]] * len(all_possible_pairs)
          interaction_strengths = ((interval[1] - interval[0]) * np.random.rand(4 * len(all_possible_pairs)) + interval[0]).reshape(-1, 4)
          
          time_scales = 20 * (interval[1] - interval[0]) * np.random.rand(self.total_layers) + interval[1]
          time_scales = np.sort(time_scales)
          
          interaction_data = []
          for i in range(self.total_layers):
              layer_interactions = []
              for j in range(len(all_possible_pairs)):
                  target_layer, influencing_layer = all_possible_pairs[j]
                  #print(target_layer)
                  if target_layer == i:
                    m=0
                    for state in state_influencing[j]:
                      for edge in edges[j]:
                        layer_interactions.append([target_layer, edge, influencing_layer, state, interaction_strengths[j][m] / time_scales[i]]) #Example interaction parameters
                        m+=1

              interaction_data.append(layer_interactions)
          print(interaction_data)
          layer_matrices = [self.create_layer_rate_matrix(constants[i] / time_scales[i]) for i in range(self.total_layers)]
          # Calculate the omega and interaction matrices.
          omega_matrices, interaction_matrices = self.compute_omega_matrix(layer_matrices, interaction_data)

          return layer_matrices, interaction_data, omega_matrices, interaction_matrices



    @staticmethod
    def eigenvector_for_zero_eigenvalue(matrix):
      # Compute eigenvalues and eigenvectors
      eigenvalues, eigenvectors = np.linalg.eig(matrix)

      # Find the index of eigenvalue 0
      index = np.where(np.isclose(eigenvalues, 0))[0]

      if index.size > 0:
          # Retrieve the eigenvector corresponding to eigenvalue 0
          normalized_zero_eigenvector = eigenvectors[:, index[0]].flatten() / np.sum(eigenvectors[:, index[0]].flatten())
          return normalized_zero_eigenvector
      else:
          print("No eigenvalue 0 found in the matrix.")
          return None

    def get_marginal_probability(self, layer_number, omega_matrices, interaction_matrices):
      # layer numbers are labeled from 0 to n-1
      transition_rate_matrix = np.sum(omega_matrices, axis=0) + np.sum(interaction_matrices, axis=0)
    
      p_stationary = self.eigenvector_for_zero_eigenvalue(transition_rate_matrix)
    
      if (self.total_layers - layer_number) == 1:
        p_marginal_for_state_A = np.sum(p_stationary[: 2**(self.total_layers-1)])
        return p_marginal_for_state_A
      elif (self.total_layers - layer_number) == 2:
        p_marginal_for_state_A = np.sum(p_stationary[: 2**(self.total_layers-2)]) + np.sum(p_stationary[2**(self.total_layers-1): 2**(self.total_layers-1) + 2**(self.total_layers-2)])
        return p_marginal_for_state_A

      return "error"



    def generate_constraints(self, k_data, time_scales):
        # First, we extract k_intra
        k_intra = list(k_data[:2 * self.total_layers].reshape(self.total_layers, 2))

        # Then, we extract k_inter, assuming each layer interacts with 2 other layers
        k_inter = list(k_data[2 * self.total_layers:].reshape(self.total_layers, self.total_layers - 1, 4))

        def get_total_interaction_for_layer_i(k_inter, i):
            k_AB = 0
            k_BA = 0
            k_inter_copy = copy.deepcopy(k_inter)
            k_inter_copy.pop(i)  # Remove the i-th layer interaction
            for j, layer_inter_data in enumerate(k_inter_copy):
                if j < i:
                    k_AB += layer_inter_data[i-1][0] + layer_inter_data[i-1][2]
                    k_BA += layer_inter_data[i-1][1] + layer_inter_data[i-1][3]
                else:
                    k_AB += layer_inter_data[i][0] + layer_inter_data[i][2]
                    k_BA += layer_inter_data[i][1] + layer_inter_data[i][3]
            return k_AB, k_BA

        def get_total_inter_plus_intra_for_layer_i(k_intra, k_inter, i):
            k_AB_inter, k_BA_inter = get_total_interaction_for_layer_i(k_inter, i)
            k_AB_intra, k_BA_intra = k_intra[i]
            k_AB_total = k_AB_inter + k_AB_intra
            k_BA_total = k_BA_inter + k_BA_intra
            return k_AB_total, k_BA_total

        def interaction_constraint(i, k_AB_total, k_BA_total, time_scales):
            return time_scales[i] - (k_AB_total + k_BA_total)

        constraints = []
        for i in range(len(k_intra)):
            # Get the total interaction for layer i
            k_AB_total, k_BA_total = get_total_inter_plus_intra_for_layer_i(k_intra, k_inter, i)

            # Apply the constraint: total interaction should be less than or equal to time_scales for each layer
            constraints.append({
                'type': 'ineq',  # 'ineq' means the constraint should be <=
                'fun': lambda x, i=i, k_AB_total=k_AB_total, k_BA_total=k_BA_total: interaction_constraint(i, k_AB_total, k_BA_total, time_scales)
            })

        # # Add positivity constraint for all elements in k_data
        # for idx in range(len(k_data)):
        #     constraints.append({
        #         'type': 'ineq',  # Ensure k_data[idx] >= 0
        #         'fun': lambda x, idx=idx: x[idx]  # This ensures x[idx] >= 0
        #     })

        return constraints







    def optimize_the_network(self, inputs, time_scales):
        constants =  np.random.rand(self.total_layers, 2) + 0.0001
        k_intra =[constants[i] / time_scales[i] for i in range(self.total_layers)]
        all_possible_pairs = [(i, j) for i in range(self.total_layers) for j in range(self.total_layers) if i != j]
        k_inter =   np.random.rand(4 * len(all_possible_pairs)).reshape(-1, 4)
       
        k_data = np.concatenate((np.array(k_intra).flatten(), k_inter.flatten()))
       
        constraints = self.generate_constraints(k_data, time_scales)
        # Step: Define the loss function to be minimized (wrap the loss function to work with k_data)
        def loss_function_wrapper(k_data):
            # Reshape k_data back into k_intra and k_inter
            k_intra = k_data[:2 * self.total_layers].reshape(self.total_layers, 2)
           # k_inter = k_data[2 * self.total_layers:].reshape(len(all_possible_pairs), 4)
            k_inter = k_data[2 * self.total_layers:].reshape(self.total_layers, 4, self.total_layers - 1)
       
            # Calculate and return the loss
            return self.loss_function(inputs, k_data, time_scales)


        def verbose_callback(xk):
            print(f"Current solution: {xk}")

        result = optimize.minimize(
            fun=loss_function_wrapper,
            x0=k_data,
            constraints=constraints,
            method='SLSQP',
            options={'disp': True},
            callback=verbose_callback,  # Log at each iteration
            tol=1e-1
            )


        # Step: Return the optimized k_intra and k_inter
        optimized_k_data = result.x
        optimized_k_intra = optimized_k_data[:2 * self.total_layers].reshape(self.total_layers, 2)
        optimized_k_inter = optimized_k_data[2 * self.total_layers:].reshape(self.total_layers, 4, self.total_layers - 1)

        return optimized_k_intra, optimized_k_inter, result.fun  # Optimized parameters and final los




    def get_interaction_data(self, k_inter, time_scales):
        state_edge =  [["A", 'AB'], ["A", "BA"], ["B", "AB"], ["B", "BA"]]
        layers = np.arange(self.total_layers)
        interaction_data = []
        # print(k_inter)
        for layer, layer_inter_k in enumerate(k_inter):
          relevent_layers = layers[layers != layer]
        
          layer_interactions = []
          for enum0, k_target_layer in enumerate(layer_inter_k):
            # k_target_layer_k = [0., 0.2,.0.3, 0.8]
            # the layer number of targeted layer is relevent_layers[enum0]
            targeted_layer_num = relevent_layers[enum0]
            for enum1, k_value in enumerate(k_target_layer):
              interaction_strength = k_value
              state = state_edge[enum1][0]
              edge = state_edge[enum1][1]

              influencing_layer = layer
              layer_interactions.append([targeted_layer_num, edge, influencing_layer, state, interaction_strength / time_scales[layer]])

          interaction_data.append(layer_interactions)

        return interaction_data


    def loss_function(self, inputs, k_data,  time_scales):
        # first n inputs are for one class and N-n observations are for the second class in the input data.
        # k_intra = [[k12_layer0, k21_layer0], [k12_layer1, k21_layer1], ..., [k12_layer_last, k21_layer_last]]
        # k_inter = [
                      #     # Interactions of layer 0 with layers 1 and 2
                      #     [[k_A_AB_layer0_1, k_A_BA_layer0_1, k_B_AB_layer0_1, k_B_BA_layer0_1],
                      #     [k_A_AB_layer0_2, k_A_BA_layer0_2, k_B_AB_layer0_2, k_B_BA_layer0_2]],

                      #     # Interactions of layer 1 with layers 0 and 2
                      #     [[k_A_AB_layer1_0, k_A_BA_layer1_0, k_B_AB_layer1_0, k_B_BA_layer1_0],
                      #     [k_A_AB_layer1_2, k_A_BA_layer1_2, k_B_AB_layer1_2, k_B_BA_layer1_2]],

                      #     # Interactions of layer 2 with layers 0 and 1
                      #     [[k_A_AB_layer2_0, k_A_BA_layer2_0, k_B_AB_layer2_0, k_B_BA_layer2_0],
                      #     [k_A_AB_layer2_1, k_A_BA_layer2_1, k_B_AB_layer2_1, k_B_BA_layer2_1]]
                      # ]

        # First, we extract k_intra
        k_intra = np.exp(k_data[:2 * self.total_layers].reshape(self.total_layers, 2))

        # Then, we extract k_inter, assuming each layer interacts with 2 other layers
        k_inter = np.exp(k_data[2 * self.total_layers:].reshape(self.total_layers, self.total_layers - 1, 4))

        interaction_data = self.get_interaction_data(k_inter, time_scales)

        input_plus_k_intra_data = []
        layer_observed_matrices_all_observ = []
        
        for observation, k_input_layer_int in enumerate(inputs):
            input_plus_k_intra = k_intra + np.exp(k_input_layer_int)
            input_plus_k_intra_data.append(input_plus_k_intra)
        

        for j in range(len(inputs)):
          layer_observed_matrices = [self.create_layer_rate_matrix(input_plus_k_intra_data[j][i] / time_scales[i]) for i in range(self.total_layers)]
          layer_observed_matrices_all_observ.append(layer_observed_matrices)

        p_marg_observed_last_layer_data = []
        p_marg_observed_second_last_layer_data = []
        entropy = []
        for observation, k_input_layer_int in enumerate(inputs):
          omega_matrices, interaction_matrices = self.compute_omega_matrix(layer_observed_matrices_all_observ[observation], interaction_data)
          p_marg_last_observed1 = self.get_marginal_probability(self.total_layers-1, omega_matrices, interaction_matrices)
          p_marg_last_observed2 = self.get_marginal_probability(self.total_layers-2, omega_matrices, interaction_matrices)
          p_marg_observed_last_layer_data.append(p_marg_last_observed1)
          p_marg_observed_second_last_layer_data.append(p_marg_last_observed2)
          # ENTROPY TERM
          ent = - p_marg_last_observed1 * np.log2(p_marg_last_observed1) - (1-p_marg_last_observed1) * np.log2((1-p_marg_last_observed1))
          entropy.append(ent)
        # Length of the data
        n = len(p_marg_observed_last_layer_data)
        
        # First loss computation (for last layer)
        loss_last = (np.array(p_marg_observed_last_layer_data) - np.array([1] * (n // 2) + [0] * (n - n // 2))) ** 2
        # # loss_lastB = (np.array(p_marg_observed_last_layer_data) - np.array([0] * (n // 2) + [1] * (n - n // 2))) ** 2
        # loss = np.sum(np.dot(np.array(p_marg_observed_last_layer_data)[: n//2], np.array(p_marg_observed_last_layer_data)[n//2::]))
        loss_dot_product  = np.array(p_marg_observed_last_layer_data[0]) * np.array(p_marg_observed_last_layer_data[1]) +  np.array(1 - p_marg_observed_last_layer_data[0]) * np.array(1 - p_marg_observed_last_layer_data[1])

        loss = loss_dot_product

        # Second loss computation (for second last layer)
        loss_second_last = (np.array(p_marg_observed_second_last_layer_data) - np.array([0] * (n // 2) + [1] * (n - n // 2))) ** 2
        print("p_marg_observed_last_layer_data")

        print(p_marg_observed_last_layer_data)

        # loss = np.sum(loss_last)  #+ 0.0 * np.sum(entropy) #+ (k_inter.flatten() ** 2).sum()
        # + np.sum(loss_second_last) + (k_inter.flatten() ** 2).sum()
        # print(loss)
        return loss



# Example Usage
# Build a two-layer model and run the optimizer on two observations.
model = TransitionRateModel(total_layers=2)

# Earlier hand-written interaction data, kept for reference (unused):
# int_data = [[[1,"AB", 0, "B", 0.0]], [[1, "AB", 0, "B", 0.2], [1, "AB", 0, "A", 0.3]]]

# Two observations; each holds one [value, value] pair per layer. loss_function
# exponentiates these values and adds them to the intra-layer rates.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = [[[10, 0], [0,0]], [[0, 10],  [0, 0]]] #, [[0, 4],  [0, 0]] ]

# One time scale per layer.
time_scales = [1, 1]

# Runs the SLSQP optimization; prints progress and returns
# (optimized_k_intra, optimized_k_inter, final_loss). The result is not stored.
model.optimize_the_network(input, time_scales)

p_marg_observed_last_layer_data
[0.5490288221541963, 0.42039571766457917]
p_marg_observed_last_layer_data
[0.5490288221541972, 0.42039571766471456]
p_marg_observed_last_layer_data
[0.5490288221540648, 0.4203957176645775]
p_marg_observed_last_layer_data
[0.5490288234713759, 0.4203957193194056]
p_marg_observed_last_layer_data
[0.5490288206156015, 0.420395716512947]
p_marg_observed_last_layer_data
[0.5490288209340214, 0.4203957176644181]
p_marg_observed_last_layer_data
[0.5490288237622276, 0.42039571766502815]
p_marg_observed_last_layer_data
[0.5490288221537806, 0.4203957158818227]
p_marg_observed_last_layer_data
[0.5490288221543278, 0.4203957186801429]
p_marg_observed_last_layer_data
[0.5490288221541252, 0.42039571766457917]
p_marg_observed_last_layer_data
[0.5490288221541982, 0.4203957176646271]
p_marg_observed_last_layer_data
[0.5490288221541587, 0.4203957176645803]
p_marg_observed_last_layer_data
[0.5490288221541987, 0.4203957176646829]
p_marg_observed_last_layer_data
[0.5531464656532

TypeError: 'NoneType' object is not subscriptable