In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from dicodile.utils.csc import reconstruct
from dicodile.utils.viz import display_dictionaries

from load_dict import load_dict_and_activations

In [2]:
# Root of the local data tree; the tokamak pattern-detection inputs live below it.
data_dir_path = Path.home() / "data"
pattern_detection_path = data_dir_path.joinpath("pattern_detection_tokam")

In [None]:
# Which simulation run, and which snapshot of it, to analyse.
experiment = "interchange_nodriftwave"
frame = 1000

input_dir_path = pattern_detection_path.joinpath("input", experiment)
input_file_path = input_dir_path.joinpath(f"frame_{frame}.txt")

# Pre-computed dictionary (D_hat) and activations (z_hat) for this experiment/frame,
# loaded through the project helper.
D_hat, z_hat = load_dict_and_activations(experiment, frame, verbose=1)

In [None]:
# Raw input frame, read from its plain-text matrix file, shown without axes.
image_array = np.loadtxt(input_file_path)

fig, ax = plt.subplots()
ax.imshow(image_array)
ax.axis("off")
plt.show()

In [None]:
# Distribution of the raw pixel values of the frame.
pixel_values = image_array.ravel()
plt.hist(pixel_values, bins=100)
plt.show()

In [6]:
n_atoms = D_hat.shape[0]  # number of atoms = size of the dictionary's leading axis

In [None]:
# Dictionary shape: (n_atoms, n_channels, *atom_support)
# Print the shape together with the global value range of all atoms.
d_stats = (D_hat.shape, D_hat.min(), D_hat.max())
print(*d_stats)

In [None]:
# Squared L2 norm of every atom. This should print ~1 for each atom if the
# atoms are unit-norm. Note that it checks the sum of *squares*, not the plain
# sum of the entries.
# Reduce over every axis except the leading atom axis, so the check works for
# any atom dimensionality.
print(np.sum(D_hat**2, axis=tuple(range(1, D_hat.ndim))))

In [None]:
# Rescale each atom by its largest *absolute* value before display. Using the
# raw max would flip the sign of atoms whose peak is negative, and would divide
# by zero for an all-zero atom.
atom_axes = tuple(range(1, D_hat.ndim))
normalisation = np.abs(D_hat).max(axis=atom_axes, keepdims=True)
normalisation[normalisation == 0] = 1  # leave all-zero atoms unscaled instead of NaN

display_dictionaries(D_hat / normalisation)
plt.show()

In [None]:
# Activations: print their shape and global value range.
z_stats = (z_hat.shape, z_hat.min(), z_hat.max())
print(*z_stats)

In [None]:
# One panel per atom, showing its activation map scaled to a peak of 1.
# squeeze=False keeps `axes` 2-D even when n_atoms == 1 (otherwise subplots
# returns a single Axes, and zip() over it fails).
fig, axes = plt.subplots(n_atoms, 1, figsize=(6, 6 * n_atoms), squeeze=False)

for i_atom, (ax, z) in enumerate(zip(axes[:, 0], z_hat)):
    peak = z.max()
    # An unused atom has no positive activation; show it unscaled rather than
    # dividing by zero (which would render NaNs).
    ax.imshow(z / peak if peak > 0 else z, cmap="gray")
    ax.set_title(f"Activations of atom {i_atom}")
    ax.axis("off")
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

In [12]:
# Rebuild the image from the sparse code, clip it to [0, 1], then move the
# channel axis last: (C, H, W) -> (H, W, C).
# NOTE(review): the [0, 1] clip assumes images normalised to that range —
# confirm it suits this data before trusting the comparisons below.
X_hat = np.clip(reconstruct(z_hat, D_hat), 0, 1).transpose(1, 2, 0)

In [None]:
# Stack the original frame above its reconstruction (shifted by the image mean).
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=[6.4, 8])

panels = (
    (ax1, image_array, 'Original image'),
    (ax2, X_hat + image_array.mean(), 'Recovered image'),
)
for panel_ax, panel_img, panel_title in panels:
    panel_ax.imshow(panel_img, cmap='gray')
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')
plt.tight_layout()

In [None]:
# Compare the pixel distribution of the raw frame with that of the reconstruction.
fig, ax = plt.subplots()

# Reuse the image histogram for the mode instead of computing it twice.
counts, edges, _ = ax.hist(image_array.ravel(), bins=100, alpha=0.6, label="original image")
# Mode estimate = centre of the most populated bin. `edges` holds bin *edges*,
# so indexing it alone would return the left edge, biased low by half a bin.
i_mode = np.argmax(counts)
mode = 0.5 * (edges[i_mode] + edges[i_mode + 1])

# Skip pixels where the (clipped) reconstruction is exactly 0.
positive = X_hat > 0
ax.hist(X_hat[positive], bins=100, alpha=0.6, label="reconstruction (> 0)")
ax.hist((X_hat + np.mean(image_array))[positive], bins=100, alpha=0.6,
        label="reconstruction + image mean")
ax.axvline(image_array.mean(), color="k", label="image mean")
ax.axvline(mode, color="k", linestyle=":", label="image mode")
ax.legend()

plt.show()