We have a matrix split into $N^2$ blocks.
We then create $M$ matrix products of linear combinations of individual blocks.
We then evaluate each of the final blocks as a linear combinations
of these $M$ matrix products.

In [118]:
import numpy as np
from scipy.optimize import minimize, Bounds


class Model:
    def __init__(self, blocks_per_side=2, multiplications_count=7):
        self.n = blocks_per_side
        self.m = multiplications_count

    def expected_cs(self):
        cs = np.zeros((self.n**2, self.n**4))
        for i in np.arange(self.n):
            for j in np.arange(self.n):
                for k in np.arange(self.n):
                    # c[i, k] += a[i, j] * b[j, k]
                    cs[i * self.n + j, i * self.n**3 + j * self.n**2 + j * self.n + k] = 1
        return cs

    def diff(self, ls, rs, qs):
        ms = np.zeros((self.m, self.n**4))
        for i in np.arange(self.m):
            ms[i] = np.outer(ls[i], rs[i]).flatten()
        #print(ms)

        ds = np.zeros((self.n**2, self.n**4))
        for i in np.arange(self.n**2):
            ds[i] = qs[i].dot(ms)
        #print(ds)

        cs = self.expected_cs()
        res = np.linalg.norm(cs - ds)
        return res

    def norm(self, x):
        return np.linalg.norm(x, ord=1)

    def target(self, x):
        ls, rs, qs = self.unflatten_x(x)
        return self.diff(ls, rs, qs) + 0.02 * self.norm(x)

    def unflatten_x(self, x):
        ls_raw = x[:self.m * (self.n ** 2)]
        rs_raw = x[self.m * (self.n ** 2):self.m * (self.n ** 2) + self.m * (self.n ** 2)]
        qs_raw = x[self.m * (self.n ** 2) * 2:]
        ls = np.reshape(ls_raw, (self.m, self.n ** 2))
        rs = np.reshape(rs_raw, (self.m, self.n ** 2))
        qs = np.reshape(qs_raw, (self.n ** 2, self.m))
        return ls, rs, qs

    def rand_x0(self):
        return np.random.rand(self.m*(self.n**2)*2 + self.n**2 * self.m)

model = Model(2, 3)

ls = np.array([[1, 0, 0, 0], [0, 1, 0, 1], [1, 2, 0, 0]])
rs = np.array([[1, 0, 0, 0], [0, 1, 0, 1], [1, 2, 0, 0]])
qs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]])

model.diff(ls, rs, qs)

print(np.reshape(np.array([1,2,3,4]), (2,2)))
print(model.rand_x0())
print(model.target(model.rand_x0()))

[[1 2]
 [3 4]]
[0.91498552 0.59913503 0.90286203 0.12203732 0.59619017 0.66789541
 0.79670486 0.27500439 0.36144453 0.32521796 0.53848903 0.81570697
 0.53178401 0.62097164 0.97942108 0.87594823 0.42096903 0.0130436
 0.01978726 0.96856409 0.22852638 0.40576923 0.03188876 0.75602626
 0.36947713 0.85356683 0.01944917 0.32151361 0.35719585 0.15546641
 0.63761261 0.19931915 0.48440623 0.66803621 0.24747403 0.93476971]
3.775291049965919


In [119]:
model = Model(2, 8)

result = minimize(
    model.target,
    model.rand_x0(),
    bounds=Bounds(-1, 1),
    options={"disp": True, "gtol": 1e-8, "maxfun": 1e10,},
)
print(result)
ls, rs, qs = model.unflatten_x(result['x'])
print(np.around(ls, 2))
print(np.around(rs, 2))
print(np.around(qs, 2))

      fun: 0.41804011366849264
 hess_inv: <96x96 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 0.23382278, -0.07826043, -0.27157915, -0.02825604, -0.07726841,
       -0.02812641, -0.15111885, -0.06013128, -0.019833  ,  0.03909088,
        0.18135071, -0.07778555,  0.15405894,  0.10679092,  0.25392767,
        0.16369014,  0.02000003,  0.01999976,  0.02000018,  0.02000014,
       -0.01407942, -0.02250291,  0.01191295,  0.02055701,  0.06032639,
        0.18063706, -0.05611638,  0.02072166, -0.03570338, -0.30441186,
        0.03241485, -0.03728237, -0.06476601,  0.21207471,  0.04670129,
        0.05509143,  0.06325954,  0.06341789, -0.0096786 ,  0.02038772,
        0.07607305,  0.08529452, -0.19348138,  0.15899715,  0.11092112,
        0.11136718, -0.06613695, -0.01920178, -0.0200001 , -0.02000003,
       -0.01999998,  0.02000005, -0.01806968,  0.00228639, -0.02293325,
       -0.02554544, -0.02098349,  0.03324456, -0.01346462, -0.01345551,
       -0.07345048, -0.00263751, -0.