# Tutorial: LSTM Inference Analysis & Memory Visualization

This notebook shows how to run inference with an LSTM model (ideally one that has already been trained) and how to visualize the model's *internal state* (its memory) to understand how it makes decisions.

### Step 1: Environment Setup

Import the required libraries and set the basic plotting configuration.

In [None]:
import sys, os
# Add the parent directory to sys.path (e.g. so local helper modules can be imported)
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from time import process_time
from IPython import display

import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np

from sequential_tasks import TemporalOrderExp6aSequence as QRSU

import model_utils as mu
from plot_lib import plot_results, set_default, print_colourbar, plot_state

from tqdm import tqdm

set_default(figsize=(20, 10))

### Step 2: Experiment Parameters

Choose the model type and the difficulty level of the dataset to analyze.

In [None]:
# Constants
model_type = "lstm"
difficulty = "moderate"
SEED = 42
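
The `if/elif` chain in Step 3 sends any unrecognized `difficulty` string (a typo such as `"moderat"`, or a different capitalization such as `"Moderate"`) to `NIGHTMARE` without any warning. A stricter alternative is a lookup table that fails loudly. This is a minimal sketch under assumptions: `DifficultyLevel` below is a hypothetical stand-in enum so the snippet runs on its own, and in the notebook the table values would be the `QRSU.DifficultyLevel` members instead.

```python
from enum import Enum


# Hypothetical stand-in for QRSU.DifficultyLevel so this sketch is self-contained.
class DifficultyLevel(Enum):
    EASY = "easy"
    NORMAL = "normal"
    MODERATE = "moderate"
    HARD = "hard"
    NIGHTMARE = "nightmare"


# Explicit name -> level table; unknown names raise instead of silently
# falling back to the hardest level.
DIFFICULTY_BY_NAME = {
    "easy": DifficultyLevel.EASY,
    "normal": DifficultyLevel.NORMAL,
    "moderate": DifficultyLevel.MODERATE,
    "hard": DifficultyLevel.HARD,
    "nightmare": DifficultyLevel.NIGHTMARE,
}


def resolve_difficulty(name: str) -> DifficultyLevel:
    """Map a difficulty string to its level, rejecting unknown names."""
    key = name.strip().lower()
    if key not in DIFFICULTY_BY_NAME:
        raise ValueError(
            f"Unknown difficulty {name!r}; expected one of {sorted(DIFFICULTY_BY_NAME)}"
        )
    return DIFFICULTY_BY_NAME[key]


difficulty_level = resolve_difficulty("moderate")
print(difficulty_level)  # DifficultyLevel.MODERATE
```

Normalizing with `strip().lower()` is a design choice: it accepts harmless variations like `" HARD "` while still rejecting genuine typos.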

### Step 3: Dataset Generation & Decoding

Create the QRSU data generator and decode a sample symbol sequence into a human-readable form.

In [None]:
# Create a data generator
if difficulty == "easy":
    difficulty_level = QRSU.DifficultyLevel.EASY
elif difficulty == "normal":
    difficulty_level = QRSU.DifficultyLevel.NORMAL
elif difficulty == "moderate":
    difficulty_level = QRSU.DifficultyLevel.MODERATE
elif difficulty == "hard":
    difficulty_level = QRSU.DifficultyLevel.HARD
else:
    difficulty_level = QRSU.DifficultyLevel.NIGHTMARE

example_generator = QRSU.get_predefined_generator(
    difficulty_level=difficulty_level,
    batch_size=32,
)

example_batch = example_generator[1]
print(f'The return type is a {type(example_batch)} with length {len(example_batch)}.')
print(f'The first item in the tuple is the batch of sequences with shape {example_batch[0].shape}.')
print(f'The first element in the batch of sequences is:\n {example_batch[0][0, :, :]}')
print(f'The second item in the tuple is the corresponding batch of class labels with shape {example_batch[1].shape}.')
print(f'The first element in the batch of class labels is:\n {example_batch[1][0, :]}')


# Decoding the first sequence
sequence_decoded = example_generator.decode_x(example_batch[0][0, :, :])
print(f'The sequence is: {sequence_decoded}')

# Decoding the class label of the first sequence
class_label_decoded = example_generator.decode_y(example_batch[1][0])
print(f'The class label is: {class_label_decoded}')
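
Conceptually, decoding turns each row of the sequence matrix back into the symbol it encodes. The printed batch element is a 2-D array whose last dimension is used as the LSTM input size (`n_symbols`) in Step 4, which suggests a one-hot encoding, but how `decode_x` works internally is not shown in this notebook. The minimal sketch below assumes one-hot rows and uses a hypothetical 8-symbol alphabet (`SYMBOLS`); the real alphabet and its ordering come from the generator.

```python
import numpy as np

# Hypothetical alphabet; the real one is defined by the QRSU generator.
SYMBOLS = ["a", "b", "c", "d", "B", "E", "X", "Y"]


def decode_one_hot(seq: np.ndarray, symbols=SYMBOLS) -> str:
    """Turn a (seq_len, n_symbols) one-hot matrix into a symbol string."""
    if seq.ndim != 2 or seq.shape[1] != len(symbols):
        raise ValueError(f"expected shape (seq_len, {len(symbols)}), got {seq.shape}")
    # argmax along the symbol axis picks the active symbol at each time step
    return "".join(symbols[i] for i in np.argmax(seq, axis=1))


# Round trip: encode "BaXbE" as one-hot rows, then decode it back
indices = [SYMBOLS.index(s) for s in "BaXbE"]
one_hot = np.eye(len(SYMBOLS))[indices]  # shape (5, 8)
print(decode_one_hot(one_hot))  # BaXbE
```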

### Step 4: Model Initialization

Build the LSTM. In this tutorial the model starts from random weights; ideally, you would load an already trained model instead.

In [None]:
# Set up the LSTM model
input_size = example_generator.n_symbols
hidden_size = 64
output_size = example_generator.n_classes

rngs = nnx.Rngs(SEED)
model = mu.SimpleLSTM(
    input_size=input_size, 
    hidden_size=hidden_size,
    output_size=output_size,
    num_layers=1,
    rngs=rngs
)

# For visual exploration, we ideally need a trained model. 
# For now, we initialize it randomly. 
# If we had a trained model path, we would load it here.
print("Model initialized with random weights.")
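
This notebook does not show how `mu.SimpleLSTM` implements its cell, so the sketch below uses the standard LSTM formulation with a forget gate, written in plain NumPy (the same operations exist in `jax.numpy`). It shows where the cell state $C_t$ plotted in Step 5 comes from. The fused weight layout and the gate order `[i, f, g, o]` are assumptions for illustration, and `hidden=64` simply mirrors `hidden_size` above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_cell_step(x_t, h_prev, c_prev, W, U, b):
    """One step of a standard LSTM cell (assumed gate order [i, f, g, o]).

    x_t: (batch, input_size); h_prev, c_prev: (batch, hidden)
    W: (input_size, 4 * hidden); U: (hidden, 4 * hidden); b: (4 * hidden,)
    """
    z = x_t @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # input, forget, output gates
    g = np.tanh(g)                                # candidate values to write
    c_t = f * c_prev + i * g                      # keep part of the old memory, add new content
    h_t = o * np.tanh(c_t)                        # gated view of the memory passed onward
    return h_t, c_t


rng = np.random.default_rng(42)
batch, input_size, hidden = 2, 8, 64
W = rng.normal(scale=0.1, size=(input_size, 4 * hidden))
U = rng.normal(scale=0.1, size=(hidden, 4 * hidden))
b = np.zeros(4 * hidden)

x = np.eye(input_size)[[0, 3]]  # two one-hot input symbols
h0 = np.zeros((batch, hidden))
c0 = np.zeros((batch, hidden))
h1, c1 = lstm_cell_step(x, h0, c0, W, U, b)
print(h1.shape, c1.shape)  # (2, 64) (2, 64)
```

Because $C_t$ is updated additively rather than squashed, its values can drift outside $[-1, 1]$ over many steps, while $H_t$ always stays inside that range; keep this in mind when reading the colour scale in Step 5.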

### Step 5: Visualizing the State over Time

This is the most important part: we run one batch of data through the model, collect the *cell state* ($C_t$) at every time step, and visualize it to see which parts of the sequence the LSTM treats as important. With the random weights from Step 4, the plot reflects no learned behavior; it becomes informative once a trained model is loaded.

In [None]:
# Get states across time for evaluation/visualization
data, _ = example_batch
X = jnp.array(data)
H_t, C_t = model.get_states_across_time(X)

print("Color range is as follows:")
print_colourbar()

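# plot_state expects the state with shape (seq_len, batch, hidden_size).
# In the current model_utils implementation, H_t and C_t come back flattened
# as (batch * seq_len, hidden_size), so restore the time and batch axes.
# This reshape assumes time-major row order (all batch items for t=0, then t=1, ...).
# If the rows are batch-major instead, use
# C_t.reshape(batch_size, seq_len, -1).transpose(1, 0, 2).
seq_len = X.shape[1]
batch_size = X.shape[0]
C_t_reshaped = C_t.reshape(seq_len, batch_size, -1)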

plot_state(X, C_t_reshaped, b=0, decoder=example_generator.decode_x)
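
Reshaping flattened states is only correct if you know the order of the rows. The cell above describes the shape of `C_t` as `(batch * seq_len, hidden_size)` but does not say whether rows are ordered by time step first or by sequence first, and `reshape(seq_len, batch_size, -1)` silently produces scrambled plots for the wrong order. The sketch below collects per-step states the way a recurrent model iterates (over time, for the whole batch at once) and checks which reshape recovers them. It uses a toy linear recurrence `c_t = decay * c_{t-1} + x_t @ W` as a stand-in for the LSTM cell; `states_across_time` and `decay` are hypothetical names, not part of `model_utils`.

```python
import numpy as np


def states_across_time(X, W, decay=0.9):
    """Collect one state per time step for every sequence in the batch.

    X: (batch, seq_len, input_size). Toy stand-in recurrence, not an LSTM.
    Returns states with shape (seq_len, batch, hidden).
    """
    batch, seq_len, _ = X.shape
    c = np.zeros((batch, W.shape[1]))
    states = []
    for t in range(seq_len):  # iterate over time; the batch is processed in parallel
        c = decay * c + X[:, t, :] @ W
        states.append(c)
    return np.stack(states, axis=0)  # (seq_len, batch, hidden)


rng = np.random.default_rng(0)
batch, seq_len, n_symbols, hidden = 3, 5, 8, 4
X = np.eye(n_symbols)[rng.integers(0, n_symbols, size=(batch, seq_len))]  # one-hot
W = rng.normal(size=(n_symbols, hidden))
C = states_across_time(X, W)  # (5, 3, 4)

# Two ways a model could flatten states to (batch * seq_len, hidden):
flat_time_major = C.reshape(seq_len * batch, hidden)                      # rows: t, then b
flat_batch_major = C.transpose(1, 0, 2).reshape(batch * seq_len, hidden)  # rows: b, then t

# reshape(seq_len, batch, -1) recovers C only from the time-major layout...
ok_time = np.allclose(flat_time_major.reshape(seq_len, batch, -1), C)
# ...a batch-major layout must be reshaped to (batch, seq_len, -1) and transposed.
ok_batch = np.allclose(flat_batch_major.reshape(batch, seq_len, -1).transpose(1, 0, 2), C)
naive = np.allclose(flat_batch_major.reshape(seq_len, batch, -1), C)
print(ok_time, ok_batch, naive)  # True True False
```

Since both layouts have the same shape, the shape alone cannot tell them apart; comparing against states computed for a single sequence (batch size 1) is one way to confirm the real order.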