In [1]:
import os

# Move from the notebook folder to its parent, presumably so that the `coin`
# package imported below resolves. ".." is resolved against the directory
# recorded on the first run, so re-running this cell does not keep climbing
# the directory tree.
_NOTEBOOK_DIR = globals().get("_NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(_NOTEBOOK_DIR, ".."))

import torch
import numpy as np

from coin.utils import build_tensor, generate_partitions
from coin.utils import generate_triads

In [2]:
# n = 4
# U, V, W = generate_triads(n,n,n)
# T = U.view(-1, n**2, 1, 1) * V.view(-1, 1, n**2, 1) * W.view(-1, 1, 1, n**2)

In [3]:
import matplotlib.pyplot as plt

def plot_tensor(T, n):
    """Draw the slices T[0], T[1], ... as heat maps on an n-by-n grid of axes.

    Args:
        T: iterable of 2-D arrays/tensors (e.g. a 3-D tensor iterated along axis 0).
            Only the first n*n slices are drawn, because zip stops at the shorter
            of the axes list and T.
        n: number of rows and columns in the grid.
    """
    # squeeze=False keeps `axs` a 2-D array even for n == 1; otherwise
    # plt.subplots returns a bare Axes object, which has no `.flat`.
    fig, axs = plt.subplots(n, n, figsize=(4, 4), squeeze=False)
    axes = list(axs.flat)
    for ax, m in zip(axes, T):
        ax.matshow(m)
    # Hide frames on every cell, including ones left empty when T has fewer than n*n slices.
    for ax in axes:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [4]:
def tensor_product(U, V, W):
    r = U.size(0)
    n1, n2, n3 = U.size(1), V.size(1), W.size(1)
    return U.view(r, n1, 1, 1) * V.view(r, 1, n2, 1) * W.view(r, 1, 1, n3)

def triads_decompose(T):
    """Split a 3-way tensor into one rank-one triad per nonzero entry.

    For the p-th nonzero position (i, j, k) of T, in the order returned by
    `T.nonzero()`, row p of U, V and W is the one-hot vector e_i, e_j, e_k.
    The value stored at T[i, j, k] is ignored, so the triads sum back to T
    only when T is 0/1-valued (e.g. a tensor reduced mod 2).

    Returns:
        (U, V, W): int8 tensors of shapes (r, T.size(0)), (r, T.size(1)) and
        (r, T.size(2)), where r is the number of nonzero entries of T.
    """
    coords = T.nonzero()
    rank = len(coords)
    rows = torch.arange(rank)
    factors = []
    for axis in range(3):
        factor = torch.zeros((rank, T.size(axis)), dtype=torch.int8)
        # Set a single 1 per row, at the coordinate of that nonzero along `axis`.
        factor[rows, coords[:, axis]] = 1
        factors.append(factor)
    return tuple(factors)

In [251]:
def get_unique_hashed_idx(hashed):
    """Filter unique hashed orbits by removing duplicates.

    Returns one index per distinct value of a 1-D hash tensor. The indices
    come out ordered by increasing hash value. For each value, the smallest
    original index is kept, because the sort is stable.

    Args:
        hashed: 1-D integer tensor of hashes.

    Returns:
        int64 tensor of indices into `hashed`, on the same device.
    """
    if hashed.numel() == 0:
        # The "first of run" mask below needs at least one element.
        return torch.empty(0, dtype=torch.int64, device=hashed.device)
    hashed_sorted, order = torch.sort(hashed, stable=True)
    # Keep the first element of every run of equal values. Compare with `!=`
    # rather than subtracting neighbours, which can overflow int64 and wrongly
    # merge distinct hashes that are far apart.
    is_first = torch.ones_like(hashed_sorted, dtype=torch.bool)
    is_first[1:] = hashed_sorted[1:] != hashed_sorted[:-1]
    return order[is_first]

## Starting Point

In [826]:
# Problem size and symmetry order. sym = 3 means the cyclic shifts of the three
# factors; the next cell also handles 6, which adds an extra Z2 map.
n = 3
sym = 3
# NOTE(review): presumably the <n,n,n> matrix-multiplication tensor from coin.utils,
# with axes of length n² (the permutation below indexes axis 0 with n² entries) — confirm.
Mn = build_tensor(n,n,n)

# Permutation of the n² row-major flat indices of an n×n matrix that maps (r, c) -> (c, r),
# i.e. it transposes the matrix carried by Mn's first axis. The name is kept as spelled
# because other cells may refer to `tranpose_c`.
tranpose_c = torch.arange(n**2).view(n,n).T.reshape(-1)
Mn = Mn[tranpose_c]

# Groups of 1-based diagonal indices; the saved output shows [[1], [2, 3]] for n = 3.
partition = generate_partitions(n, sym=sym)[1]
print(f"{partition = }")
# Row j of `vec` is the flattened n×n matrix with ones at (i-1, i-1) for every i in part j.
# n*(i-1) + (i-1) is the row-major flat index of that diagonal entry.
vec = torch.zeros((len(partition), n*n), dtype=torch.int8)
for (j, p) in enumerate(partition):
    for i in p:
        vec[j, n*(i-1) + (i-1)] = 1
# Sum over parts of v⊗v⊗v: one rank-one term per part, each unchanged by any
# permutation of the three factors.
ten_sym = tensor_product(vec, vec, vec).sum(dim=0)

# The part the symmetric search still has to cover; `% 2` keeps the arithmetic in GF(2).
rest = (Mn - ten_sym) % 2
print(f"#nonzero(Mn):   {len(Mn.nonzero())}")
print(f"#nonzero(rest): {len(rest.nonzero())}")

# plot_tensor(rest, n)

partition = [[1], [2, 3]]
#nonzero(Mn):   27
#nonzero(rest): 30


In [827]:
# Split the GF(2) residual into one rank-one triad per nonzero entry.
UVW = triads_decompose(rest)
U, V, W = UVW
UVW = torch.stack((U,V,W))  # (uvw, r, n²)
# Sanity check: the one-hot triads sum back to the residual (rest is 0/1 after `% 2`).
assert (tensor_product(*UVW).sum(dim=0) == rest).all()

# Images of every triad under the symmetry maps, stacked as (sym, uvw, r, n²).
# p00 / p10 / p20 are the cyclic shifts (u, v, w), (v, w, u), (w, u, v). For sym == 6
# the same three are also composed with `permute_z2`, which reverses the n² coordinates.
# The reshape and hashing below rely on UVW_mapped having exactly `sym` images.
assert sym in (3, 6), f"unsupported symmetry order: {sym = }"
UVW_p00 = UVW[[0,1,2]]
UVW_p10 = UVW[[1,2,0]]
UVW_p20 = UVW[[2,0,1]]
permute_z2 = torch.arange(n**2).flip(0)
if sym == 6:
    UVW_p01 = UVW_p00[:, :, permute_z2]
    UVW_p11 = UVW_p10[:, :, permute_z2]
    UVW_p21 = UVW_p20[:, :, permute_z2]
    UVW_mapped = torch.stack((UVW_p00, UVW_p10, UVW_p20, UVW_p01, UVW_p11, UVW_p21))
else:
    UVW_mapped = torch.stack((UVW_p00, UVW_p10, UVW_p20))

print(f"{UVW_mapped.shape = }")

# Hash each image (u, v, w concatenated, length 3n²) with random integer weights.
# A local, seeded generator makes the hashes reproducible across Restart & Run All
# without touching torch's global RNG. Distinct 0/1 vectors collide only by chance.
HASH_SEED = 0
hash_lim = int(1e9)
hash_gen = torch.Generator().manual_seed(HASH_SEED)
hash_mask = torch.randint(0, hash_lim, (1, 1, 3 * n**2), dtype=torch.int64, generator=hash_gen)
UVW_mapped_flat = UVW_mapped.permute(0,2,1,3).reshape(sym, -1, 3 * n**2)  # (sym, r, 3n²)
UVW_hashed = (UVW_mapped_flat * hash_mask).sum(dim=2)  # (sym, r)
# Sorting over the symmetry axis gives each triad a signature that, for a group action
# such as the cyclic shifts, is shared by every member of its orbit.
# Equal columns therefore mean the same orbit.
UVW_hashed_sort = torch.sort(UVW_hashed, dim=0).values
n_orbits = UVW_hashed_sort.unique(dim=1).size(1)
print(f"{n_orbits     = }")
print(f"{n_orbits*sym = }")

# One representative triad per orbit, keyed by the smallest hash among its images.
unique_idx = get_unique_hashed_idx(UVW_hashed_sort[0])
start_point = UVW_mapped.permute(2,0,1,3)[unique_idx]  # (orbits, sym, uvw, n²)
start_point = start_point[:, :, 0]  # (orbits, sym, n²): the u-factor of every image

UVW_mapped.shape = torch.Size([3, 3, 30, 9])
n_orbits     = 10
n_orbits*sym = 30


## Additional data

In [828]:
# (orbits, sym, n²) 
# print(f"{start_point.shape = }")

# pow2 = 1 << torch.arange(n**2, dtype=torch.int8).flip(0)
# pow2 = pow2.reshape(1, 1, -1)

# start_point_bin = (start_point.long() * pow2).sum(dim=2)

In [829]:
def get_eq_idx(scheme):
    """List pairs of distinct orbits that contain an identical factor.

    Args:
        scheme: tensor of shape (orbits, sym, n²); scheme[o, s] is the s-th factor of orbit o.

    Returns:
        Integer tensor with one row (orb1, sym1, orb2, sym2) per match
        scheme[orb1, sym1] == scheme[orb2, sym2] with orb1 < orb2. Matches inside one
        orbit, and the mirrored copy of each cross-orbit match, are dropped.
    """
    n_orb, n_sym, n_sq = scheme.shape
    # lhs[a, 0, 0, s1] = scheme[a, s1]; rhs[0, b, s2, 0] = scheme[b, s2].
    lhs = scheme.view(n_orb, 1, 1, n_sym, n_sq)
    rhs = scheme.view(1, n_orb, n_sym, 1, n_sq)
    # After broadcasting, axis 2 carries orb2's symmetry index and axis 3 carries orb1's,
    # so nonzero() yields columns in the order (orb1, orb2, sym2, sym1).
    hits = (lhs == rhs).all(dim=-1).nonzero()
    keep = hits[:, 0] < hits[:, 1]
    # Reorder the columns to (orb1, sym1, orb2, sym2).
    return hits[keep][:, [0, 3, 1, 2]]

def check_scheme(scheme, target=None):
    """Check that a cyclic (sym == 3) scheme reproduces `target` over GF(2).

    Each orbit row scheme[o] = (a, b, c) expands to the three triads a⊗b⊗c,
    b⊗c⊗a and c⊗a⊗b. All triads are summed, and the sum is compared with
    `target`, both taken modulo 2.

    Args:
        scheme: tensor of shape (orbits, 3, n²).
        target: tensor to compare against. Defaults to the notebook-level `rest`,
            looked up at call time as before, so existing calls are unchanged.

    Returns:
        bool: True when the expanded scheme equals `target` mod 2.
    """
    if target is None:
        target = rest
    cyclic_terms = torch.concat(
        (scheme[:, [0, 1, 2]], scheme[:, [1, 2, 0]], scheme[:, [2, 0, 1]]), dim=0
    )  # (3 * orbits, uvw, n²)
    U, V, W = cyclic_terms.permute(1, 0, 2)  # each (3 * orbits, n²)
    return (tensor_product(U, V, W).sum(dim=0) % 2 == target % 2).all().item()

from random import randint
def flip(scheme):
    """Apply one random flip to a cyclic (sym == 3) scheme, in place, over GF(2).

    A flip needs two orbits o1 < o2 that share a factor, i.e.
    scheme[o1, s1] == scheme[o2, s2] (such pairs come from `get_eq_idx`).
    Write the two triads through that shared factor x as x⊗b1⊗c1 and x⊗b2⊗c2.
    The flip rewrites them as x⊗(b1^b2)⊗c1 + x⊗b2⊗(c1^c2), which has the same
    sum mod 2.

    Args:
        scheme: integer tensor of shape (orbits, 3, n²); modified in place.

    Returns:
        bool: True if a flip was applied, False if no two orbits share a factor.
    """
    # Compute the (expensive) pairwise match table only once.
    eq_idx = get_eq_idx(scheme)
    if eq_idx.size(0) == 0:
        return False
    # `random.randint` includes both bounds.
    o1, s1, o2, s2 = eq_idx[randint(0, eq_idx.size(0) - 1)]
    scheme[o1, (s1 + 1) % 3] ^= scheme[o2, (s2 + 1) % 3]  # b1 ^= b2
    scheme[o2, (s2 + 2) % 3] ^= scheme[o1, (s1 + 2) % 3]  # c2 ^= c1
    return True

In [830]:
# r = start_point.size(0)
# eq_idx = (
#     start_point.view(r, 1, 1, sym, n**2) == start_point.view(1, r, sym, 1, n**2)
# ).all(dim=-1).nonzero() # (orb1, orb2, sym2, sym1) ? for some reason
# eq_idx = eq_idx[eq_idx[:, 0] < eq_idx[:, 1]] # filter eq in same orbit and purmutation
# eq_idx = eq_idx[:, [0,3,1,2]] # (orb1, sym1, orb2, sym2)
# eq_idx = get_eq_idx(start_point)
# eq_idx

In [831]:
from tqdm import tqdm

In [865]:
import random

FLIP_SEED = 0
MAX_FLIPS = 10_000

# Seed Python's `random` module (used by `flip` to pick a pair) so the walk is reproducible.
random.seed(FLIP_SEED)
scheme = start_point.clone()
for j in range(1, MAX_FLIPS + 1):
    flippable = flip(scheme)
    # An orbit with an all-zero factor contributes nothing to the sum, so drop it.
    zero_orbits = (scheme.sum(dim=2) == 0).any(dim=1)
    if zero_orbits.any():
        scheme = scheme[~zero_orbits]
    if not flippable:
        break

# Each remaining orbit contributes `sym` rank-one terms; each part of `partition` adds one more.
rank = scheme.size(0) * sym + len(partition)

print(f"step {j}, # orbits = {scheme.size(0)}")
print(f"{check_scheme(scheme) = }")
print(f"{rank = }")

step 10000, # orbits = 7
check_scheme(scheme) = True
rank = 23


In [641]:
# sym scheme -> Mn

In [504]:
# (orbits, sym, n²)
# if sym==3: (orbits, uvw, n²)

# restore rest from UVW in sym==3
# U, V, W = torch.concat((scheme[:, [0,1,2]], scheme[:, [1,2,0]], scheme[:, [2,0,1]]), dim=0).permute(1,0,2) # (uvw, tensors, n²)
# (tensor_product(U, V, W).sum(dim=0) == rest).all().item()

## Flip

In [349]:
# There will be many orbits, but a flip only touches one pair of them
# =>
# it makes sense to keep an index of the factors and where they occur,
# and after a flip re-run the equal-factor search only for the changed orbits.

In [350]:
# Shape of the starting scheme. After the "Starting Point" cell it is (orbits, sym, n²).
# NOTE(review): the saved output below comes from an older run in which start_point still
# had a uvw axis.
start_point.shape

torch.Size([4, 3, 3, 4])

In [357]:
# Manual walk-through of a single flip between orbits 0 and 1.
# start_point is already (orbits, sym, n²): the uvw axis is dropped in the
# "Starting Point" cell, so orbit o is simply start_point[o]. The old 4-D index
# start_point[o, :, 0, :] fails on a fresh kernel.
T1 = start_point[0].clone()
T2 = start_point[1].clone()

# Positions of the shared factor. A genuine flip requires T1[k1] == T2[k2]
# (what `get_eq_idx` finds); this demo does not check that.
k1 = 0
k2 = 0
# Same update as in `flip`: b1 ^= b2, then c2 ^= c1 (the +1 / +2 offsets can be swapped).
T1[(k1+1)%3] ^= T2[(k2+1)%3]
T2[(k2+2)%3] ^= T1[(k1+2)%3]

In [335]:
"""
a⊗b1⊗c1 + b1⊗c1⊗a + c1⊗a⊗b1 + a'⊗b1'⊗c1' + b1'⊗c1'⊗a' + c1'⊗a'⊗b1'
a⊗.. + b1⊗.. + c1⊗.. + a'⊗.. + b1'⊗.. + c1'⊗..
flip
a⊗b2⊗c2 + b2⊗c2⊗a + c2⊗a⊗b2 + a'⊗b2'⊗c2' + b2'⊗c2'⊗a' + c2'⊗a'⊗b2'
a⊗.. + b2⊗.. + c2⊗.. + a'⊗.. + b2'⊗.. + c2'⊗..
=
a⊗(b1^b2)⊗c1 + (b1^b2)⊗c1⊗a + c1⊗a⊗(b1^b2) + a'⊗(b1^b2)'⊗c1' + (b1^b2)'⊗c1'⊗a' + c1'⊗a'⊗(b1^b2)'
a⊗b2⊗(c2^c1) + b2⊗(c2^c1)⊗a + (c2^c1)⊗a⊗b2 + a'⊗b2'⊗(c2^c1)' + b2'⊗(c2^c1)'⊗a' + (c2^c1)'⊗a'⊗b2'
~
a⊗.. + (b1^b2)⊗.. + c1⊗.. + a'⊗.. + (b1^b2)'⊗.. + c1'⊗..
a⊗.. + b2⊗.. + (c2^c1)⊗.. + a'⊗.. + b2'⊗.. + (c2^c1)'⊗..
""";

In [352]:
"""
a1⊗.. + x⊗.. + c1⊗.. + ?...'
flip
a2⊗.. + b2⊗.. + x⊗.. + ?...'
=
x⊗.. + c1⊗.. + a1⊗.. + ?...'
flip
x⊗.. + a2⊗.. + b2⊗.. + ?...'
=
x⊗.. + (c1^a2)⊗.. + a1⊗.. + ?...'
x⊗.. + a2⊗.. + (b2^a1)⊗.. + ?...'
""";

In [None]:
"""
a⊗.. + b1⊗.. + c1⊗.. + a'⊗.. + b1'⊗.. + c1'⊗..
flip
a⊗.. + b2⊗.. + c2⊗.. + a'⊗.. + b2'⊗.. + c2'⊗..
=
a⊗.. + (b1^b2)⊗.. + c1⊗.. + a'⊗.. + (b1^b2)'⊗.. + c1'⊗..
a⊗.. + b2⊗.. + (c2^c1)⊗.. + a'⊗.. + b2'⊗.. + (c2^c1)'⊗..
""";

In [192]:
# Stacked factor matrices, shaped (uvw, r, n²). NOTE(review): the saved output below comes from
# an older run (r = 12, n² = 4). The current Starting Point cell yields r = 30, n² = 9,
# per its UVW_mapped output.
UVW.shape

torch.Size([3, 12, 4])

In [193]:
# The first two triads of every factor: shape (uvw, 2, n²).
UVW[:, :2, :].shape

torch.Size([3, 2, 4])

In [191]:
# torch.sort(UVW_hashed, dim=0).values.T

In [178]:
# U = torch.tensor([[1,0,0,1],[1,0,0,0],[0,0,1,1],[0,1,0,-1],[0,0,0,1],[1,1,0,0],[-1,0,1,0]])
# V = torch.tensor([[1,0,0,1],[0,1,0,-1],[1,0,0,0],[0,0,1,1],[-1,0,1,0],[0,0,0,1],[1,1,0,0]])
# W = torch.tensor([[1,0,0,1],[0,0,1,1],[0,1,0,-1],[1,0,0,0],[1,1,0,0],[-1,0,1,0],[0,0,0,1]])
# tensor_product(U,V,W).sum(dim=0)