In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import datetime
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D


import volumembo
from volumembo.utils import assign_clusters
from volumembo.utils import onehot_to_labels
from _volumembo import fit_median_cpp
from volumembo.legacy import fit_median as fit_median_legacy

In [None]:
def build_transform_heat_kernel_2D(N, M, t):
    """Build the Fourier-space multiplier exp(t * lambda) for an N x M grid.

    ``lambda[i, j] = 2 cos(2 pi i / N) + 2 cos(2 pi j / M) - 4``, i.e. the
    eigenvalues of the 5-point discrete Laplacian with periodic boundaries
    indexed in FFT order. Multiplying an ``np.fft.fft2`` spectrum by the result
    applies ``t`` units of discrete heat flow.

    Args:
        N: number of grid rows.
        M: number of grid columns.
        t: diffusion time.

    Returns:
        float64 array of shape (N, M).
    """
    # Column / row vectors broadcast to the full (N, M) grid instead of meshgrid.
    row_phase = 2.0 * np.pi * np.arange(N)[:, np.newaxis] / N
    col_phase = 2.0 * np.pi * np.arange(M)[np.newaxis, :] / M
    laplacian_symbol = 2.0 * np.cos(row_phase) + 2.0 * np.cos(col_phase) - 4.0
    return np.exp(t * laplacian_symbol)


def diffuse_on_2Dgrid(image, kernel):
    """Apply a Fourier-space multiplier to a 2-D image.

    Computes ``real(ifft2(fft2(image) * kernel))``. With ``kernel`` from
    ``build_transform_heat_kernel_2D`` this is one periodic heat-diffusion step.

    Args:
        image: 2-D real array.
        kernel: array broadcastable against ``fft2(image)`` (same shape here).

    Returns:
        Real part of the filtered image; the imaginary part is discarded.
    """
    filtered_spectrum = np.fft.fft2(image) * kernel
    filtered = np.fft.ifft2(filtered_spectrum)
    return filtered.real


def diffused_to_onehot(u):
    """Convert diffused per-label scores ``u`` into a one-hot assignment.

    Thin notebook alias for the private helper
    ``volumembo.MBO._diffused_to_onehot``; behaviour is entirely that helper's.
    """
    to_onehot = volumembo.MBO._diffused_to_onehot
    return to_onehot(u)

In [None]:
# Grid size (N rows x M columns), diffusion time t used by the kernel, number of labels P.
N, M = 400, 300
t = 1
P = 4

kernel = build_transform_heat_kernel_2D(N, M, t)
image = np.zeros((N, M))
# One indicator channel per label on the last axis.
image_one_hot = np.zeros((N, M, P))

for _caption, _arr in (
    ("kernel:\t\t", kernel),
    ("image:\t\t", image),
    ("image_one_hot:\t", image_one_hot),
):
    print(_caption, _arr.shape)

In [None]:
# Random initial partition: shuffle the pixel ranks 0..N*M-1 and split them into P
# equal bands, so each label receives N*M/P pixels (exactly, when P divides N*M).
# A fixed seed makes the run reproducible, and building a fresh array (instead of
# writing into the existing image_one_hot) keeps the cell safe to re-run.
SEED = 0
rng = np.random.default_rng(SEED)

line = np.arange(N * M)
print("line: ", line.shape)
rng.shuffle(line)
line = line.reshape((N, M))
print("line: ", line.shape)

# Integer arithmetic: floor(rank * P / (N*M)) == int(rank / (N*M/P)) without float rounding.
initial_labels = (line * P) // (N * M)
image_one_hot = np.eye(P)[initial_labels]  # (N, M, P) float64 one-hot

In [None]:
# Target number of pixels per label, fixed from the initial partition. It is passed to
# fit_median_cpp every step; presumably the returned per-label offset ("median") makes
# argmax(diffused - median) reproduce these volumes — TODO confirm against _volumembo.
volume = image_one_hot.reshape(N * M, P).sum(axis=0).astype(int)
print("volumes:", volume, volume.shape, flush=True)
print("", flush=True)

# Frames are written into a timestamped directory (`filename` is a directory path).
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"kelvin_conjecture/animation_{timestamp}"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(filename, exist_ok=True)

diffused = np.zeros_like(image_one_hot, dtype=np.float64)

for i in range(1000):
    # Diffuse each label's indicator channel independently.
    for p in range(P):
        diffused[:, :, p] = diffuse_on_2Dgrid(image_one_hot[:, :, p], kernel)
    # One row of P scores per pixel.
    scores = diffused.reshape((N * M, P))
    median = np.array(fit_median_cpp(scores, volume))

    if i % 10 == 0:
        index_str = f"{i:04d}"
        # Explicit figure: plt.imshow would draw onto whatever figure happens to be current.
        frame_fig, frame_ax = plt.subplots()
        frame_ax.imshow(np.argmax(image_one_hot, axis=2), vmin=0, vmax=P - 1)
        frame_fig.savefig(os.path.join(filename, index_str + ".png"))
        plt.close(frame_fig)

        # Volumes of the argmax assignment after subtracting the fitted median.
        labels = np.argmax(scores - median, axis=1)
        volumes = np.bincount(labels, minlength=P)
        print(volumes)

    # Threshold the shifted scores into the next one-hot image.
    image_one_hot = diffused_to_onehot(scores - median).reshape((N, M, P))

In [None]:
# Per-label pixel counts of the argmax assignment for the last diffused state and median.
labels = (diffused.reshape(N * M, P) - median).argmax(axis=1)
volumes = np.bincount(labels, minlength=P)
print(volumes)
print(median)

In [None]:
# Scratch (disabled): recompute one diffusion step from the current one-hot image.
# diffused = np.zeros_like(image_one_hot, dtype=np.float64)
# for p in range(P):
#    diffused[:, :, p] = diffuse_on_2Dgrid(image_one_hot[:, :, p], kernel)

# print(diffused.shape)
# print(diffused.reshape((N * M, P)).shape)

In [None]:
# Re-check the volumes induced by the current median. minlength must be P (number of
# labels) rather than a hard-coded 3, otherwise an empty trailing label is dropped from
# the counts instead of being reported as 0.
print("m =", median)
test_labels = np.argmax(diffused.reshape((N * M, P)) - median, axis=1)
new_sizes = np.bincount(test_labels, minlength=P)
print(new_sizes)

In [None]:
#############################################################################################
fig = plt.figure(figsize=(14, 5))
gs = gridspec.GridSpec(nrows=1, ncols=1)
#############################################################################################
ax0 = fig.add_subplot(gs[0, 0])
ax0.tick_params(
    direction="in", which="both", bottom=True, top=True, left=True, right=True
)
ax0.minorticks_on()

simplex = volumembo.plot.SimplexPlotter(ax=ax0)
simplex.plot_simplex_outline(lw=2)
simplex.add_grid_lines()
simplex.add_ticks(n=10, show_labels=True)
simplex.set_axis_labels()
simplex.plot_points(
    points=diffused.reshape((N * M, P)), labels=test_labels, ec="k", s=45
)
simplex.plot_median(point=median, s=50, color="lime")


#############################################################################################
plt.subplots_adjust(left=0.05, right=0.95, top=1.1, bottom=0.025)
# fig.savefig("./clustering0.png", transparent=False, dpi=300)

In [None]:
# Re-display `volumes` (per-label pixel counts from np.bincount of argmax labels) as last set by an earlier cell.
print(volumes)

In [None]:
def precompute_directions(M):
    """Return an (M, M) float array of per-label search directions.

    Row ``i`` has ``-1.0`` at column ``i`` and ``1 / (M - 1)`` in every other
    column, so for ``M >= 2`` each row sums to zero.

    Args:
        M: number of labels (unrelated to the grid width ``M`` used elsewhere).
    """
    # Fill everything with the off-diagonal weight, then overwrite the diagonal.
    directions = np.ones((M, M)) / (M - 1)
    np.fill_diagonal(directions, -1.0)
    return directions

In [None]:
precompute_directions(M=3)  # three-label example; each row sums to zero

In [None]:
def precompute_other_labels(M: int) -> list[list[int]]:
    """For each label ``i`` in ``range(M)``, list every other label in ascending order.

    Example: ``M=3`` gives ``[[1, 2], [0, 2], [0, 1]]``.
    """
    return [[other for other in range(M) if other != label] for label in range(M)]

In [None]:
precompute_other_labels(M=4)  # four-label example matching P