In [83]:
from math import ceil

import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions.categorical import Categorical

# Device string passed as `device=` when creating tensors in this notebook (CPU only).
device = "cpu"

In [84]:
def vectors_to_demo(uu, vv, ww, device):
    """Sum rank-1 factors into a 3-way tensor and flatten them into step rows.

    Args:
        uu, vv, ww: 2-D tensors with one row per rank-1 factor; row i holds
            (u_i, v_i, w_i). The number of rows must match across the three.
            Column counts may differ; they set the output tensor's shape
            (4 each for the 2x2 Strassen case in this notebook).
        device: device on which the accumulated tensor is allocated.

    Returns:
        (mul_tensor, steps_wide):
        - mul_tensor: float tensor of shape (uu.shape[1], vv.shape[1], ww.shape[1])
          equal to sum_i u_i (x) v_i (x) w_i.
        - steps_wide: tensor of shape (R, uu.shape[1] + vv.shape[1] + ww.shape[1])
          holding each factor's concatenated (u, v, w), shifted by +1.
    """
    # Infer each mode's size from the factors instead of hard-coding 4.
    shape = (uu.shape[1], vv.shape[1], ww.shape[1])
    mul_tensor = torch.zeros(shape, device=device)
    # One contraction over the factor index i replaces the per-factor Python loop.
    # Factors are cast to the accumulator dtype so integer inputs are accepted.
    mul_tensor += torch.einsum(
        "ip,iq,ir->pqr",
        uu.to(mul_tensor.dtype),
        vv.to(mul_tensor.dtype),
        ww.to(mul_tensor.dtype),
    )
    # Convert to steps/actions: +1 maps coefficients {-1, 0, 1} to {0, 1, 2}
    # (presumably non-negative action ids — confirm against the consumer).
    steps_wide = torch.cat((uu, vv, ww), dim=1) + 1
    return mul_tensor, steps_wide


def steps_wide_to_uvw(steps_wide, n=4):
    """Cut each row of ``steps_wide`` into three consecutive column chunks of width ``n``.

    Args:
        steps_wide: 2-D tensor whose columns are the concatenated (u, v, w) pieces.
        n: chunk width along dim 1.

    Returns:
        Tuple (uu, vv, ww) of views into ``steps_wide``.

    Raises:
        ValueError: if splitting along dim 1 does not yield exactly three chunks.

    NOTE(review): this does not subtract the +1 offset that ``vectors_to_demo``
    adds when building step rows; callers must undo that shift themselves.
    """
    chunks = steps_wide.split(n, dim=1)
    first, second, third = chunks
    return first, second, third


def get_strassen(device: str):
    """Build the tensor and step rows for Strassen's seven rank-1 factors.

    The factors are listed one (u_i, v_i, w_i) triple per row. They are
    unzipped into three 7 x 4 integer tensors on ``device`` and passed to
    ``vectors_to_demo``.

    Args:
        device: torch device string for the factor tensors (e.g. "cpu").

    Returns:
        The result of ``vectors_to_demo(uu, vv, ww, device)`` for these factors.
    """
    factor_rows = (
        ([1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]),
        ([0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 1, -1]),
        ([1, 0, 0, 0], [0, 1, 0, -1], [0, 1, 0, 1]),
        ([0, 0, 0, 1], [-1, 0, 1, 0], [1, 0, 1, 0]),
        ([1, 1, 0, 0], [0, 0, 0, 1], [-1, 1, 0, 0]),
        ([-1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 0, 1]),
        ([0, 1, 0, -1], [0, 0, 1, 1], [1, 0, 0, 0]),
    )
    # zip(*rows) regroups the triples into the u column, v column and w column.
    uu, vv, ww = (
        torch.tensor(list(column), device=device) for column in zip(*factor_rows)
    )
    return vectors_to_demo(uu, vv, ww, device)


# strassen_tensor, strassen_steps = vectors_to_demo(uu_strassen, vv_strassen, ww_strassen, device)


In [85]:
# Strassen's seven rank-1 factors, written one (u, v, w) triple per row and
# unzipped into uu_strassen / vv_strassen / ww_strassen (each 7 x 4).
# NOTE(review): duplicates the table inside get_strassen(); keep the two in sync.
uu_strassen, vv_strassen, ww_strassen = (
    torch.tensor(list(column), device=device)
    for column in zip(
        ([1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1]),
        ([0, 0, 1, 1], [1, 0, 0, 0], [0, 0, 1, -1]),
        ([1, 0, 0, 0], [0, 1, 0, -1], [0, 1, 0, 1]),
        ([0, 0, 0, 1], [-1, 0, 1, 0], [1, 0, 1, 0]),
        ([1, 1, 0, 0], [0, 0, 0, 1], [-1, 1, 0, 0]),
        ([-1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 0, 1]),
        ([0, 1, 0, -1], [0, 0, 1, 1], [1, 0, 0, 0]),
    )
)

In [90]:
# Target tensor and +1-shifted step rows for the Strassen factors.
strassen_tensor, strassen_steps = vectors_to_demo(
    uu=uu_strassen, vv=vv_strassen, ww=ww_strassen, device=device
)

In [91]:
# Display the step rows: each Strassen factor's concatenated (u, v, w), shifted by +1.
strassen_steps

tensor([[2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2],
        [1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 0],
        [2, 1, 1, 1, 1, 2, 1, 0, 1, 2, 1, 2],
        [1, 1, 1, 2, 0, 1, 2, 1, 2, 1, 2, 1],
        [2, 2, 1, 1, 1, 1, 1, 2, 0, 2, 1, 1],
        [0, 1, 2, 1, 2, 2, 1, 1, 1, 1, 1, 2],
        [1, 2, 1, 0, 1, 1, 2, 2, 2, 1, 1, 1]])

In [92]:
# Display the 4 x 4 x 4 tensor summed from the seven Strassen rank-1 factors.
strassen_tensor

tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.]],

        [[0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]])

In [35]:
# Inspect a single subset of the 7 Strassen factors, encoded as a 7-bit string:
# positions holding '1' go to used_indexes, positions holding '0' to avail_indexes.
i_bits = 103
bitstring = format(i_bits, "07b")
print(bitstring)
used_indexes = [i for i, bit in enumerate(bitstring) if bit == "1"]
avail_indexes = [i for i, bit in enumerate(bitstring) if bit == "0"]
n_used = len(used_indexes)
n_avail = len(avail_indexes)
# Copy instead of aliasing, so an in-place update of target_tensor cannot
# silently modify strassen_tensor (matches the .clone() in the demo-building cell).
target_tensor = strassen_tensor.clone()


1100111


In [123]:
# Build one example per (subset of used factors, still-available factor) pair.
# A 7-bit string marks which Strassen factors are already used ('1'). The state
# is strassen_tensor minus the used factors' rank-1 terms, and each available
# factor's concatenated (u, v, w) is a target action.
n_demos = 0
state_tensor = []
target_action = []
reward = []
scalar = []
bit_info = []
for i_bits in range(2**7):
    bitstring = format(i_bits, "b").zfill(7)
    used_indexes = [i for i, bit in enumerate(bitstring) if bit == "1"]
    avail_indexes = [i for i, bit in enumerate(bitstring) if bit == "0"]
    n_used = len(used_indexes)
    n_avail = len(avail_indexes)
    # Work on a copy so strassen_tensor itself is never modified by the -= below.
    target_tensor = strassen_tensor.clone()
    for j in used_indexes:
        target_tensor -= (
            uu_strassen[j].view(-1, 1, 1) * vv_strassen[j].view(1, -1, 1) * ww_strassen[j].view(1, 1, -1)
        )

    for k in avail_indexes:
        # Every action for this subset shares the same residual tensor object.
        state_tensor.append(target_tensor)
        target_action.append(torch.cat((uu_strassen[k], vv_strassen[k], ww_strassen[k])))
        # Reward is minus the number of factors not yet subtracted.
        reward.append(-n_avail)
        scalar.append(torch.tensor(0))
        bit_info.append(bitstring)
        n_demos += 1
    

In [124]:
# Total number of (state, action) examples built above.
print(f"{n_demos}")

448


In [118]:
# Number of distinct subsets that contributed at least one example.
print(len({*bit_info}))

127


In [119]:
# Last 8 recorded bitstrings; each repeats once per available ('0') factor.
bit_info[-8:]

['1111001',
 '1111010',
 '1111010',
 '1111011',
 '1111100',
 '1111100',
 '1111101',
 '1111110']

In [120]:
# Last 10 rewards (minus the count of factors still available).
print(reward[slice(-10, None)])

[-3, -2, -2, -2, -2, -1, -2, -2, -1, -1]


In [122]:
# Last 10 target actions: concatenated (u, v, w) of an available factor (no +1 shift).
target_action[-10:]

[tensor([ 0,  1,  0, -1,  0,  0,  1,  1,  1,  0,  0,  0]),
 tensor([ 1,  1,  0,  0,  0,  0,  0,  1, -1,  1,  0,  0]),
 tensor([-1,  0,  1,  0,  1,  1,  0,  0,  0,  0,  0,  1]),
 tensor([ 1,  1,  0,  0,  0,  0,  0,  1, -1,  1,  0,  0]),
 tensor([ 0,  1,  0, -1,  0,  0,  1,  1,  1,  0,  0,  0]),
 tensor([ 1,  1,  0,  0,  0,  0,  0,  1, -1,  1,  0,  0]),
 tensor([-1,  0,  1,  0,  1,  1,  0,  0,  0,  0,  0,  1]),
 tensor([ 0,  1,  0, -1,  0,  0,  1,  1,  1,  0,  0,  0]),
 tensor([-1,  0,  1,  0,  1,  1,  0,  0,  0,  0,  0,  1]),
 tensor([ 0,  1,  0, -1,  0,  0,  1,  1,  1,  0,  0,  0])]

In [112]:
# Neighbouring examples from the same subset share an identical state tensor.
(state_tensor[-6] == state_tensor[-7]).all()

tensor(True)

In [43]:
# Rank-1 term u (x) v (x) w of the first Strassen factor, built by placing
# each vector on its own axis and broadcasting the product.
uu, vv, ww = uu_strassen[0], vv_strassen[0], ww_strassen[0]
tensor_update = uu[:, None, None] * vv[None, :, None] * ww[None, None, :]
print(tensor_update)

tensor([[[1, 0, 0, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [1, 0, 0, 1]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[1, 0, 0, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [1, 0, 0, 1]]])


In [None]:
# Repeats cell In[43] (this copy has no execution count): the rank-1 term of the
# first Strassen factor, here as u broadcast against the outer product of v and w.
uu = uu_strassen[0]
vv = vv_strassen[0]
ww = ww_strassen[0]
tensor_update = uu.reshape(-1, 1, 1) * torch.outer(vv, ww)
print(tensor_update)

In [55]:
# Recompute u (x) v (x) w for the same factor through einsum + outer product,
# as an independent cross-check of the broadcasting construction of tensor_update.
mul_tensor = torch.zeros((4, 4, 4), device=device)
mul_tensor += torch.einsum("p,qr->pqr", uu, torch.outer(vv, ww))
mul_tensor = mul_tensor.long()
# Display the result so the saved output below is reproduced on re-run
# (previously the print was commented out but its output was kept).
mul_tensor


tensor([[[1, 0, 0, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [1, 0, 0, 1]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[1, 0, 0, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [1, 0, 0, 1]]])


In [57]:
# Both constructions of the rank-1 term must agree element-wise.
(tensor_update == mul_tensor).all()

tensor(True)