In [11]:
import numpy as np
import sympy as sp

In [12]:
# p and q are the number of axes of A and B; n is the largest allowed axis length.
p = 4
q = 4
n = 2

# Draw each axis length independently from 1..n (randint's `high` bound is exclusive).
dims_A = np.random.randint(low = 1, high = n+1, size = p)
dims_B = np.random.randint(low = 1, high = n+1, size = q)

# Fill both tensors with random integers from -10..9 (upper bound exclusive).
# NOTE(review): no seed is set in the visible cells, so shapes and values
# differ from run to run; the outputs recorded below are one particular draw.
A = np.random.randint(-10, 10, size = dims_A)
B = np.random.randint(-10, 10, size = dims_B)

In [13]:
# Pool of axis indices 0..3 to draw an axis order from.
T_comb = np.array([3, 0, 1, 2])
# Drawing p (resp. q) entries without replacement from the four-entry pool
# gives a random permutation of the axes when p = q = 4, as set above.
A_T = np.random.choice(T_comb, p, replace=False)
B_T = np.random.choice(T_comb, q, replace=False)
# Reorder the axes of each tensor according to its sampled permutation.
A_transposed = A.transpose(A_T)
B_transposed = B.transpose(B_T)

In [14]:
# Axis lengths of A before the axis permutation.
A.shape

(2, 1, 2, 2)

In [15]:
# Axis lengths of A after the axis permutation (same lengths, reordered).
A_transposed.shape

(2, 2, 2, 1)

In [16]:
# Number of axes in the `s` and `c` index groups that mat_mult slices out.
# NOTE(review): the calls visible in later cells pass both values explicitly,
# so these module-level values look unused there — confirm before relying on them.
lambda_value = 2
mu_value = 2

def mat_mult(A: np.ndarray, B: np.ndarray, lambda_value:int, mu_value:int) -> np.ndarray:
    """Return a zero tensor shaped like the result of contracting ``A`` with ``B``.

    This version only works out the output shape; no products are computed.
    The last ``lambda_value + mu_value`` axes of ``A`` and the first
    ``lambda_value + mu_value`` axes of ``B`` are dropped, and the remaining
    axis lengths of ``A`` followed by those of ``B`` form the output shape.
    A later cell in this notebook redefines ``mat_mult``.

    Args:
        A: Left operand.
        B: Right operand.
        lambda_value: Number of axes in the first dropped index group.
        mu_value: Number of axes in the second dropped index group.

    Returns:
        A float64 array of zeros whose shape is the kept axes of ``A``
        followed by the kept axes of ``B``.
    """
    dropped = lambda_value + mu_value

    # Leading axes of A that stay in the result.
    kept_from_A = A.shape[:len(A.shape) - dropped]
    # Trailing axes of B that stay in the result.
    kept_from_B = B.shape[dropped:]

    return np.zeros(kept_from_A + kept_from_B)




In [30]:
def mat_mult(A: np.ndarray, B: np.ndarray, lambda_value: int, mu_value: int) -> np.ndarray:
    """Contract the trailing axes of ``A`` with the leading axes of ``B``.

    ``A`` is read as having shape ``l + s + c`` and ``B`` as ``s + c + m``,
    where ``s`` spans ``lambda_value`` axes and ``c`` spans ``mu_value`` axes.
    The result is::

        D[l, m] = sum over (s, c) of A[l, s, c] * B[s, c, m]

    Both index groups are summed over, so only ``lambda_value + mu_value``
    affects the values; the split between the two does not.

    Args:
        A: Left operand with at least ``lambda_value + mu_value`` axes.
        B: Right operand with at least ``lambda_value + mu_value`` axes.
        lambda_value: Number of axes in the ``s`` index group (non-negative).
        mu_value: Number of axes in the ``c`` index group (non-negative).

    Returns:
        A float64 array of shape ``l + m`` with every length-1 axis removed
        (including length-1 free axes inherited from ``A`` or ``B``). When
        nothing is left, the result is a 0-d array.

    Raises:
        ValueError: If ``lambda_value`` or ``mu_value`` is negative, if either
            operand has fewer than ``lambda_value + mu_value`` axes, or if the
            contracted axis lengths of ``A`` and ``B`` differ.
    """
    if lambda_value < 0 or mu_value < 0:
        raise ValueError("lambda_value and mu_value must be non-negative")

    contracted = lambda_value + mu_value
    if contracted > A.ndim or contracted > B.ndim:
        raise ValueError(
            f"cannot contract {contracted} axes of operands with "
            f"{A.ndim} and {B.ndim} axes"
        )

    # Slice from an explicit start: A.shape[-0:] would return the whole shape.
    tail_of_A = A.shape[A.ndim - contracted:]
    head_of_B = B.shape[:contracted]
    if tail_of_A != head_of_B:
        raise ValueError(
            f"contracted axes differ: A ends with {tail_of_A}, "
            f"B starts with {head_of_B}"
        )

    # tensordot with an integer `axes` sums the last `axes` axes of the first
    # operand against the first `axes` axes of the second, in order. Casting
    # to float first keeps the accumulation (and the result dtype) in float64.
    product = np.tensordot(
        np.asarray(A, dtype=float),
        np.asarray(B, dtype=float),
        axes=contracted,
    )

    return np.squeeze(product)

In [31]:
# Contract the last axis of A with the first axis of B (lambda = 0, mu = 1),
# then list the operand shapes followed by the shape of the result.
D = mat_mult(A, B, mu_value=1, lambda_value=0)
print(A.shape, B.shape, D.shape, sep="\n")

(2, 1, 2, 2)
(2, 1, 1, 1)
(2, 2)


In [None]:

def create_eye(lambda_value: int, mu_value: int, n: int) -> np.ndarray:
    """Build the identity tensor for a contraction over ``lambda_value + mu_value`` axes.

    With ``r = lambda_value + mu_value`` the result ``E`` has ``2 * r`` axes,
    each of length ``n``, and ``E[i + j]`` is 1.0 exactly when the two
    length-``r`` multi-indices ``i`` and ``j`` are equal, 0.0 otherwise.
    Summing the trailing ``r`` axes of a tensor against the leading ``r`` axes
    of ``E`` therefore reproduces that tensor.

    Args:
        lambda_value: Number of axes in the first index group (non-negative).
        mu_value: Number of axes in the second index group (non-negative).
        n: Length of every axis (non-negative).

    Returns:
        A float64 array of shape ``(n,) * (2 * r)``; a 0-d array holding 1.0
        when ``r == 0``.

    Raises:
        ValueError: If ``lambda_value``, ``mu_value`` or ``n`` is negative.
    """
    if lambda_value < 0 or mu_value < 0 or n < 0:
        raise ValueError("lambda_value, mu_value, and n must be non-negative")

    rank = lambda_value + mu_value

    # An (n**rank x n**rank) identity matrix, with its row index and its column
    # index each unravelled into `rank` axes, is 1 exactly where the leading
    # multi-index equals the trailing one. Indexing with fewer than 2 * rank
    # entries (as a per-group loop can do) would fill whole slices instead.
    # rank == 0 yields eye(1) reshaped to a 0-d array containing 1.0.
    return np.eye(n ** rank).reshape((n,) * (2 * rank))

In [43]:
# Identity tensor for a contraction over a single axis of length 2.
E_l_m = create_eye(lambda_value=0, mu_value=1, n=2)

In [44]:
# Sanity check: contracting D with the identity tensor should give back D
# (compare with the value of D displayed in the next cell).
mat_mult(D, E_l_m, lambda_value=0, mu_value=1)

array([[ 55.,   3.],
       [-38., -16.]])

In [45]:
# Display D for comparison with the product computed in the previous cell.
D

array([[ 55.,   3.],
       [-38., -16.]])