In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from dicodile.utils.csc import reconstruct
from dicodile.utils.viz import display_dictionaries

from load_dict import load_dict_and_activations

In [2]:
# Root of the local data tree, and the project sub-folder holding inputs/outputs.
data_dir_path = Path.home().joinpath("data")
pattern_detection_path = data_dir_path.joinpath("pattern_detection_tokam")

In [None]:
# Which simulation run / frame to analyse, and which saved result to load.
experiment = "interchange_nodriftwave"
frame = 1000
run_timestamp = "241025_123151"  # passed as mode_or_timestamp; presumably selects a saved run

input_dir_path = pattern_detection_path.joinpath("input", experiment)
input_file_path = input_dir_path.joinpath(f"frame_{frame}.txt")

D_hat, z_hat = load_dict_and_activations(
    experiment,
    frame,
    verbose=1,
    mode_or_timestamp=run_timestamp,
)

In [None]:
# Load the raw frame with np.loadtxt and take a first look at it.
image_array = np.loadtxt(input_file_path)

fig, ax = plt.subplots()
ax.imshow(image_array)
ax.set_axis_off()
plt.show()

In [5]:
# Estimate the most frequent pixel value (mode) from a 100-bin histogram.
# np.histogram returns bin *edges* (len == bins + 1) as its second output, so
# take the centre of the fullest bin; its left edge alone would bias the
# estimate low by half a bin width.
counts, values = np.histogram(image_array, bins=100)
modal_bin = np.argmax(counts)
mode = 0.5 * (values[modal_bin] + values[modal_bin + 1])

# Candidate foreground masks: pixels at or above each central-tendency statistic.
mean_mask = image_array >= image_array.mean()
median_mask = image_array >= np.median(image_array)
mode_mask = image_array >= mode

In [6]:
# Disabled exploration: preview one of the threshold masks computed above
# (mean_mask / median_mask / mode_mask) and the image restricted to it.
# TODO(review): re-enable or delete this cell before sharing the notebook.
# mask = mean_mask

# plt.imshow(mask, cmap="gray")
# plt.axis("off")
# plt.show()

# plt.imshow(image_array * mask, cmap="gray")
# plt.axis("off")
# plt.show()

In [None]:
# Distribution of raw pixel values (100 bins), for comparison with the
# mean / median / mode thresholds computed above.
fig, ax = plt.subplots()
ax.hist(image_array.ravel(), bins=100)
ax.set_xlabel("pixel value")
ax.set_ylabel("count")
ax.set_title("Histogram of input frame")
plt.show()

In [8]:
n_atoms = D_hat.shape[0]  # the first dictionary axis indexes the atoms

In [None]:
# Dictionary shape: (n_atoms, n_channels, *atom_support)
# Print the shape followed by the global min and max of the learned dictionary.
dict_min, dict_max = D_hat.min(), D_hat.max()
print(D_hat.shape, dict_min, dict_max)

In [None]:
# Check that every atom has unit L2 norm, i.e. its *squared* entries sum to 1
# (each printed value should be ~1). Reduce over every axis except the first
# (atom) axis so this works for any atom-support rank, not only 2-D atoms.
atom_axes = tuple(range(1, D_hat.ndim))
print(np.sum(D_hat**2, axis=atom_axes))

In [None]:
# Rescale each atom independently for display so its largest magnitude is 1.
# Reduce over all non-atom axes (any atom-support rank) and keep dims so the
# scale broadcasts against D_hat. Using the max *absolute* value rather than
# the signed max avoids flipping the sign of (or dividing by zero for) atoms
# whose entries are all <= 0; an all-zero atom is left unscaled.
atom_axes = tuple(range(1, D_hat.ndim))
atom_scale = np.max(np.abs(D_hat), axis=atom_axes, keepdims=True)
normalisation = np.where(atom_scale > 0, atom_scale, 1.0)

display_dictionaries(D_hat / normalisation)
plt.show()

In [None]:
# Shape of the activations followed by their global min and max.
activation_stats = (z_hat.shape, z_hat.min(), z_hat.max())
print(*activation_stats)

In [13]:
# Per-atom maximum activation; keepdims leaves singleton axes so each entry
# broadcasts against its activation map (assumes z_hat is (n_atoms, H, W) — TODO confirm).
max_vec = z_hat.max(axis=(1, 2), keepdims=True)

In [None]:
# For each atom: left, its activation map; right, the support of its
# "significant" activations (entries >= 1% of that atom's maximum).
# imshow rescales grayscale data to its own min/max, so no manual division by
# z.max() is needed (which also avoids 0/0 for an all-zero map).
activation_rel_threshold = 0.01

# squeeze=False keeps `axes` 2-D even when n_atoms == 1.
fig, axes = plt.subplots(n_atoms, 2, figsize=(12, 6 * n_atoms), squeeze=False)

for atom_idx, (ax, z, max_val) in enumerate(zip(axes, z_hat, max_vec)):
    support = z >= activation_rel_threshold * max_val
    print(f"Atom {atom_idx}")
    # Entry-wise L1 norm. np.linalg.norm(z, ord=1) on a 2-D array is the matrix
    # 1-norm (max absolute column sum), not the sum over all entries.
    print("L1 Norm", np.abs(z).sum())
    # Not a true L0 norm: counts entries above the relative threshold.
    print("L0 norm", np.sum(support))

    ax[0].imshow(z, cmap="gray")
    ax[1].imshow(support, cmap="gray")
    ax[0].axis("off")
    ax[1].axis("off")
plt.show()

In [15]:
from scipy.signal import convolve2d

In [None]:
# For each atom: the atom itself, the support of its significant activations
# (>= 1% of that atom's max), and its contribution convolve2d(z, D).
# D_hat is (n_atoms, n_channels, *atom_support); this notebook works on a
# single-channel image, so take channel 0 explicitly. .squeeze() would also
# drop the atom axis when n_atoms == 1.
assert D_hat.shape[1] == 1, "expected a single-channel dictionary"
activation_rel_threshold = 0.01

# squeeze=False keeps `axes` 2-D even when n_atoms == 1.
fig, axes = plt.subplots(n_atoms, 3, figsize=(18, 6 * n_atoms), squeeze=False)

for ax, z, D in zip(axes, z_hat, D_hat[:, 0]):
    ax[0].imshow(D, cmap="gray")
    ax[1].imshow(z >= activation_rel_threshold * z.max(), cmap="gray")
    # Default mode="full": output size is z's size plus the atom size minus 1.
    ax[2].imshow(convolve2d(z, D), cmap="gray")
    for panel in ax:
        panel.axis("off")

plt.show()

In [None]:
def normalise(array_2d):
    """Linearly rescale an array to the range [0, 1].

    Parameters
    ----------
    array_2d : numpy.ndarray
        Input values (any shape, despite the name).

    Returns
    -------
    numpy.ndarray
        ``(array_2d - min) / (max - min)``. A constant array has no range to
        stretch, so it maps to all zeros instead of NaN (0 / 0).
    """
    min_val = array_2d.min()
    max_range = array_2d.max() - min_val
    if max_range == 0:
        return np.zeros_like(array_2d, dtype=float)
    return (array_2d - min_val) / max_range

def to_channel(grayscale_image, channel: str):
    """Place a grayscale image into one colour channel of an RGB image.

    ``channel`` must be ``"red"``, ``"green"`` or ``"blue"``; any other value
    raises ``KeyError``. For a 2-D ``(H, W)`` input the result is ``(H, W, 3)``,
    with the input copied into the chosen channel and zeros in the others.
    """
    channel_index = {"red": 0, "green": 1, "blue": 2}[channel]
    # One-hot integer weights, e.g. [1, 0, 0] for red, shaped (1, 1, 3) so they
    # broadcast along the new trailing colour axis.
    one_hot = np.eye(3, dtype=int)[channel_index].reshape(1, 1, 3)
    return np.expand_dims(grayscale_image, axis=-1) * one_hot

# Overlay each atom's contribution (red) on the normalised input frame (blue).
# D_hat is (n_atoms, n_channels, *atom_support); take the single channel
# explicitly. .squeeze() would also drop the atom axis when n_atoms == 1.
assert D_hat.shape[1] == 1, "expected a single-channel dictionary"

# The blue background does not depend on the atom, so build it once.
blue_image = to_channel(normalise(image_array), "blue")

# squeeze=False keeps `axes` 2-D even when n_atoms == 1.
fig, axes = plt.subplots(n_atoms, 2, figsize=(12, 6 * n_atoms), squeeze=False)

for ax, z, D in zip(axes, z_hat, D_hat[:, 0]):
    atom_contribution = convolve2d(z, D)
    # Float RGB images must lie in [0, 1]; imshow would clip anyway (with a
    # warning), so clip explicitly. Red and blue occupy different channels,
    # so their sum stays in range.
    red_contribution = to_channel(np.clip(atom_contribution, 0, 1), "red")

    ax[0].imshow(D, cmap="gray")
    ax[1].imshow(blue_image + red_contribution)
    for panel in ax:
        panel.axis("off")

plt.show()

In [18]:
# Rebuild the image from the sparse code and dictionary, clip it to [0, 1],
# then move the leading axis last, presumably (n_channels, H, W) -> (H, W, n_channels), for imshow.
# NOTE(review): the clip assumes image values live in [0, 1] — confirm against image_array's range.
X_hat = np.clip(reconstruct(z_hat, D_hat), 0, 1).transpose(1, 2, 0)

In [None]:
# Stacked comparison of the input frame (top) and its reconstruction (bottom).
fig, axes = plt.subplots(2, 1, figsize=[6.4, 8])

panels = zip(axes, (image_array, X_hat), ('Original image', 'Recovered image'))
for ax, img, title in panels:
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()

In [None]:
# Compare pixel-value distributions of the input frame and the reconstruction.
# Both histograms share the same bin edges so the overlaid bars are directly
# comparable; two independent bins=100 calls would give different bin widths.
# NOTE(review): X_hat was clipped to [0, 1]; if image_array is not in that
# range, the two distributions are on different scales.
shared_bins = np.histogram_bin_edges(
    np.concatenate([image_array.ravel(), X_hat.ravel()]), bins=100
)

fig, ax = plt.subplots()
ax.hist(image_array.ravel(), bins=shared_bins, label="original")
ax.hist(X_hat.ravel(), bins=shared_bins, alpha=0.8, label="reconstruction")
# Cap the y-axis at 5000 counts so tall bins don't flatten the rest of the plot.
ax.set_ylim([0, 5000])
ax.set_xlabel("pixel value")
ax.set_ylabel("count")
ax.legend()

plt.show()