In [1]:
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

from time import process_time

import numpy as np
import json
import matplotlib.pyplot as plt

from sequential_tasks import TemporalOrderExp6aSequence as QRSU

from torch import nn
from torch.nn import functional as F
import torch

import model_utils as mu
from plot_lib import plot_results, set_default, print_colourbar, plot_state

from tqdm import tqdm

In [2]:
# Apply plot_lib's default plotting style for the figures below.
# NOTE(review): figsize=(20, 10) is presumably a matplotlib-style (width, height)
# in inches — confirm in plot_lib.set_default.
set_default(figsize=(20, 10))

In [3]:
# Constants
# Directory holding the trained checkpoints. Set the IF5281_MODEL_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
model_dir = os.environ.get(
    "IF5281_MODEL_DIR", "/Users/mghifary/Work/Code/AI/IF5281/2024/models"
)
# Use Apple's Metal (MPS) backend when it is available, otherwise the CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
# Checkpoint family to load: "rnn" or "lstm" (used in the checkpoint file name).
model_type = "lstm"
# Task difficulty: "easy", "normal", "moderate", "hard" or "nightmare".
# Selects both the QRSU difficulty level and the checkpoint file name.
difficulty = "moderate"

In [4]:
# Create a data generator.
# Map the `difficulty` setting onto the generator's difficulty level. An explicit
# table (instead of an if/elif chain ending in a catch-all `else`) makes a typo
# such as "hrad" fail loudly instead of silently selecting NIGHTMARE.
difficulty_levels = {
    "easy": QRSU.DifficultyLevel.EASY,
    "normal": QRSU.DifficultyLevel.NORMAL,
    "moderate": QRSU.DifficultyLevel.MODERATE,
    "hard": QRSU.DifficultyLevel.HARD,
    "nightmare": QRSU.DifficultyLevel.NIGHTMARE,
}
if difficulty not in difficulty_levels:
    raise ValueError(
        f"Unknown difficulty {difficulty!r}; expected one of {list(difficulty_levels)}"
    )
difficulty_level = difficulty_levels[difficulty]

example_generator = QRSU.get_predefined_generator(
    difficulty_level=difficulty_level,
    batch_size=32,
)

# Indexing the generator yields one batch as a (sequences, labels) tuple.
example_batch = example_generator[1]
print(f'The return type is a {type(example_batch)} with length {len(example_batch)}.')
print(f'The first item in the tuple is the batch of sequences with shape {example_batch[0].shape}.')
print(f'The first element in the batch of sequences is:\n {example_batch[0][0, :, :]}')
print(f'The second item in the tuple is the corresponding batch of class labels with shape {example_batch[1].shape}.')
print(f'The first element in the batch of class labels is:\n {example_batch[1][0, :]}')


# Decoding the first sequence
sequence_decoded = example_generator.decode_x(example_batch[0][0, :, :])
print(f'The sequence is: {sequence_decoded}')

# Decoding the class label of the first sequence
class_label_decoded = example_generator.decode_y(example_batch[1][0])
print(f'The class label is: {class_label_decoded}')


The return type is a <class 'tuple'> with length 2.
The first item in the tuple is the batch of sequences with shape (32, 81, 8).
The first element in the batch of sequences is:
 [[0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 1 0 0 0]
 [1 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 0 1 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 

In [5]:
# Build the training and test generators from the same predefined QRSU
# configuration (identical difficulty level and batch size).
batch_size = 32
train_data_gen = QRSU.get_predefined_generator(
    difficulty_level=difficulty_level,
    batch_size=batch_size,
)
test_data_gen = QRSU.get_predefined_generator(
    difficulty_level=difficulty_level,
    batch_size=batch_size,
)

In [6]:
# Build the LSTM and load its trained weights.
# Only SimpleLSTM is constructed here, and the state inspection below relies on
# LSTM hidden/cell states, so fail clearly if another checkpoint family is selected
# (loading e.g. an "rnn" checkpoint into an LSTM would fail with an opaque
# state-dict key error instead).
if model_type != "lstm":
    raise ValueError(
        f"This notebook only supports model_type='lstm', got {model_type!r}"
    )

input_size = train_data_gen.n_symbols
# Must match the hidden size the checkpoint was trained with; load_state_dict
# raises a size-mismatch error otherwise. 64 is the value this notebook ran with
# for the current model_name. Other sizes used in experiments: 8, 16, 256, 512.
hidden_size = 64
output_size = train_data_gen.n_classes

model = mu.SimpleLSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_layers=1
)
# Load model
model_name = f"{model_type}_qrsu-{difficulty}"
model_path = os.path.join(model_dir, f"{model_name}.pth")
# map_location remaps tensors saved on another device (e.g. mps/cuda) onto the
# current one; weights_only restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)

In [7]:
# Feed the first test batch through the model without tracking gradients and
# collect the per-timestep states (presumably hidden H_t and cell C_t).
model.eval()
with torch.no_grad():
    data = test_data_gen[0][0]  # item 0 of the batch tuple: the sequences (item 1 holds labels)
    X = torch.from_numpy(data).float().to(device)
    H_t, C_t = model.get_states_across_time(X)

print("Color range is as follows:")
print_colourbar()

# Which sequence of the batch to visualise (other indices tried: 3, 6).
sample_idx = 9
plot_state(X.cpu(), C_t, b=sample_idx, decoder=test_data_gen.decode_x)

Color range is as follows:


In [8]:
# Decode sequence `b` of the test batch back into its symbol string.
b = 4
actual_data = test_data_gen.decode_x(X[b].cpu().numpy())
actual_data

'BacccdbaacdddcbaabaXadddbcddcdbaddabaacdbbdaddcccaXbddaaaadcdbbcdaccdacbcacabcaE'

In [9]:
# Shape of sequence `b` alone (batch dimension dropped).
X[b].shape

torch.Size([81, 8])

In [10]:
# Shape of the whole test batch tensor (moving to CPU does not change it).
X.shape

torch.Size([32, 81, 8])

In [11]:
# Length of the decoded string for sequence `b`.
# NOTE(review): this shows 80 characters while X[b] has 81 timesteps (see the
# shape cells above) — confirm how decode_x maps timesteps to characters.
len(actual_data)

80

In [12]:
# Display the full test batch on the CPU (torch abbreviates large tensors).
X.to("cpu")

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0., 