## Setting up the environment

In [30]:
# Mount Google Drive into the Colab filesystem; a later cell changes into a
# project folder under /content/drive/MyDrive. force_remount=True is passed so
# the mount is redone even if the drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [31]:
import os
# Switch the working directory to the project folder on the mounted drive.
# NOTE(review): this is a hard-coded absolute path; presumably it is what makes
# `import flipgraph` below resolve to the local module — confirm.
os.chdir("/content/drive/MyDrive/MentorSpring2025/flip-graph")

import torch
import numpy as np
from tqdm import tqdm

import importlib
import flipgraph
# Reload so that edits made to flipgraph since the kernel first imported it
# are picked up without restarting the kernel.
importlib.reload(flipgraph)

# These from-imports run after the reload so the names bind to the reloaded module.
from flipgraph import int2bin, bin2int, reconstruct
from flipgraph import generate_triads_binary, generate_triads, build_tensor, check_uvw
from flipgraph import flip, flippable, reduce, reducible

from matplotlib import pyplot as plt

In [32]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

## Creating a probe task

In [4]:
# Problem sizes for the probe task; all three are 3 here.
n1 = n2 = n3 = 3

# Sizes of the three factor spaces derived from (n1, n2, n3).
dims = (n1 * n3, n1 * n2, n2 * n3)
uvw_dtype = torch.int16 # should be ≥ max(dims)

# Target tensor built by the flipgraph helper, moved to the selected device.
T = build_tensor(n1, n2, n3)
T = T.to(device)

In [11]:
# Build the dataset: N identical copies of the starting scheme on `device`.

N = int(1e5)

n1 = n2 = n3 = 3
dims = (n1 * n3, n1 * n2, n2 * n3)

# Single scheme with a leading batch axis of length 1.
UVW = generate_triads_binary(n1, n2, n3, dims, uvw_dtype).unsqueeze(0)
# r is the length of the last axis (27 in the recorded output).
r = UVW.shape[2]
ar_N = torch.arange(N, device=device)
print(f"{r = }")

# Tile the single scheme N times along the batch axis.
UVW = UVW.to(device)
UVW = UVW.repeat(N, 1, 1)

r = 27


In [8]:
# Scratch cell: flatten the (i, j, k) index grids of an n1 x n2 x n3 array and
# inspect the shape returned by generate_triads_binary.
I, J, K = (axis_grid.ravel() for axis_grid in np.indices((n1, n2, n3)))
print(I, J, K)
# Result is discarded: only the last expression of a cell is displayed.
torch.eye(n1 * n3, dtype=torch.int8)
generate_triads_binary(3, 3, 3, dims, torch.int16).shape

[0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2] [0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2 0 0 0 1 1 1 2 2 2] [0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]


torch.Size([3, 27])

In [12]:
# Apply `steps` flips to the whole batch and record, after every flip, how many
# schemes in the batch are flagged as reducible.
steps = 3000
N_reducible = np.zeros(steps)
for k in range(steps):
    # NOTE(review): the return value is ignored here, so this relies on flip()
    # modifying UVW in place — other cells assign flip()'s result; confirm.
    flip(UVW)
    # .item() converts the 0-d tensor to a plain Python number; a tensor that
    # lives on the GPU cannot be written into a NumPy array directly.
    N_reducible[k] = reducible(UVW).sum().item()

In [None]:
# Plot sqrt(number of reducible schemes) against the flip step (log-scaled x axis).
step_axis = np.arange(steps)
sqrt_reducible = np.sqrt(N_reducible)
plt.plot(step_axis, sqrt_reducible)

# Degree-1 least-squares fit that skips the first 25 steps: k is the slope and
# b the intercept. The fit is not drawn (the plotting line is commented out).
k, b = np.polyfit(step_axis[25:], sqrt_reducible[25:], 1)
#plt.plot(k*np.arange(steps) + b)
#plt.yscale("log")

plt.xscale("log")
plt.xlabel("steps")
plt.ylabel("$N_{red}^{1/2} $")
plt.savefig("30K.png")

In [None]:
# Keep only the schemes that reducible() flags and pass them to reduce().
# The name UVW_26 suggests the result has 26 triads — TODO confirm against reduce().
mask = reducible(UVW)
UVW_26 = reduce(UVW[mask])
# Display both shapes to compare the batch before and after the reduction.
UVW.shape, UVW_26.shape


In [None]:
# Flag which of the already-reduced schemes are reducible again and display the count.
# `mask` is overwritten here and reused by the next cell to index UVW_26.
mask = reducible(UVW_26)
mask.sum()

In [None]:
# Reduce the flagged schemes once more, then tile the survivors so the batch
# holds roughly 1e5 schemes again.
UVW_25 = reduce(UVW_26[mask])
copies_per_scheme = int(1e5 / UVW_25.shape[0])
UVW_25 = UVW_25.repeat(copies_per_scheme, 1, 1)
UVW_25.shape

In [None]:
# Apply `steps` flips to the UVW_25 batch, recording the number of reducible
# schemes after every flip and reporting each step where at least one exists.
steps = 1000
N_reducible = np.zeros(steps)
for k in range(steps):
    UVW_25 = flip(UVW_25)
    # .item() converts the 0-d tensor to a plain Python number; a tensor that
    # lives on the GPU cannot be written into a NumPy array directly.
    N_reducible[k] = reducible(UVW_25).sum().item()
    if N_reducible[k]:
        print(f"reducible at {k = }, number of reducible schemes: {int(N_reducible[k])}")

In [None]:
# Reduce the reducible members of the UVW_25 batch and report how many were found.
mask = reducible(UVW_25)
UVW_24 = reduce(UVW_25[mask])
print(UVW_24.shape[0])

# Tile the reduced schemes so the batch holds roughly 1e5 schemes again.
UVW_24 = UVW_24.repeat(int(1e5/UVW_24.shape[0]), 1, 1)

# Apply `steps` flips to the new batch, recording the number of reducible
# schemes after every flip and reporting each step where at least one exists.
steps = 1000
N_reducible = np.zeros(steps)
for k in range(steps):
    UVW_24 = flip(UVW_24)
    # .item() converts the 0-d tensor to a plain Python number; a tensor that
    # lives on the GPU cannot be written into a NumPy array directly.
    N_reducible[k] = reducible(UVW_24).sum().item()
    if N_reducible[k]:
        print(f"reducible at {k = }, number of reducible schemes: {int(N_reducible[k])}")

In [None]:
# Search for a 3-by-3 matrix-multiplication scheme with as few triads as
# possible (target: 23) by alternating batches of flips with reductions.
#
# NOTE(review): UVW_min starts as an alias of UVW, not a copy. If flip() modifies
# its argument in place (other cells call it without using the return value),
# UVW itself is changed by this cell — clone first if UVW must be preserved.
UVW_min = UVW
N_triads = 27
while True:
    steps = 1000
    N_reducible = np.zeros(steps)
    for k in range(steps):
        UVW_min = flip(UVW_min)
        # Evaluate reducible() once per step; the mask from the final step
        # selects the schemes to reduce below.
        mask = reducible(UVW_min)
        # .item() gives NumPy a plain Python number instead of a 0-d tensor.
        N_reducible[k] = mask.sum().item()
    print(f"number of triads:{N_triads}, number of reducible schemes: {int(N_reducible[-1])}")

    # Stop explicitly when nothing can be reduced, instead of relying on a
    # bare `except:` to catch the division by zero further down (which also
    # hid every unrelated error raised by reduce()/repeat()).
    if not mask.any():
        print("No reducible schemes found")
        break

    UVW_min = reduce(UVW_min[mask])
    # Only count the reduction once it has actually happened.
    N_triads -= 1
    # Tile the reduced schemes so the batch holds roughly 1e5 schemes again.
    UVW_min = UVW_min.repeat(int(1e5/UVW_min.shape[0]), 1, 1)
    print(UVW_min.shape)

## BFS (breadth-first search)

In [None]:
def find_nearest_reducible(UVW):
  """Unfinished draft; judging by its name it is meant to find a nearby reducible scheme.

  As written, the function only prepares per-batch bookkeeping, locates equal
  neighbouring entries after sorting along the last axis, and then attempts an
  in-place XOR update of UVW.

  Args:
    UVW: 3-D integer tensor unpacked as (N, _, r); the middle axis is
      presumably of size 3, matching the (N, 3, r-1) random tensor below —
      TODO confirm.

  Returns:
    Nothing; there is no return statement.

  NOTE(review): c1, c2, j1 and j2 are never assigned inside this function, so
  the two XOR lines raise NameError unless those names exist as globals.
  `steps`, `idx` and `flat_pos` are computed but never used.
  """
  steps = 0
  # init variables

  N, _, r = UVW.shape
  device = UVW.device
  # One index per batch element, used for per-scheme fancy indexing below.
  ar_N = torch.arange(N, device=device)

  # sort to find pairs with equal vectors in O(r ln(r))
  val, idx = torch.sort(UVW, dim=2)
  # True where two neighbouring sorted entries are equal; last axis has length r-1.
  mask = val[..., 1:] == val[..., :-1]

  # Random weight in [1, 2) at every True position of `mask`, 0 elsewhere,
  # flattened to one row per batch element. Presumably intended for picking a
  # random matching pair per scheme — not used yet.
  flat_pos = ((1+torch.rand((N,3,r-1), dtype=torch.float, device=device)) * mask
    ).view(N, -1)

  # inplace flip in ℤ₂
  # NOTE(review): c1, c2, j1, j2 are undefined at this point (see docstring).
  UVW[ar_N, c1, j1] ^= UVW[ar_N, c1, j2]
  UVW[ar_N, c2, j2] ^= UVW[ar_N, c2, j1]


## Checking the number of reducible schemes

In [None]:
# Batch of N copies of the 2x2x2 starting scheme, flipped 5000 times in place;
# the cell displays how many schemes end up flagged as reducible.
N = int(1e5)
UVW_222 = (
    generate_triads_binary(2, 2, 2, (4, 4, 4), torch.int8)
    .to(device)
    .repeat(N, 1, 1)
)
for i in tqdm(range(5000)):
    flip(UVW_222)

reducible(UVW_222).sum()

In [None]:
# Same experiment for the 3x3x3 starting scheme; N comes from an earlier cell.
UVW_333 = (
    generate_triads_binary(3, 3, 3, (9, 9, 9), torch.int16)
    .to(device)
    .repeat(N, 1, 1)
)
for i in tqdm(range(5000)):
    flip(UVW_333)

reducible(UVW_333).sum()

In [None]:
# Same experiment for the 4x4x4 starting scheme; N comes from an earlier cell.
UVW_444 = (
    generate_triads_binary(4, 4, 4, (16, 16, 16), torch.int16)
    .to(device)
    .repeat(N, 1, 1)
)
for i in tqdm(range(5000)):
    flip(UVW_444)

reducible(UVW_444).sum()

In [None]:
# Re-display the reducible count for the 4x4x4 batch (same expression that ends the previous cell).
reducible(UVW_444).sum()