# Model evaluation

Now that we have a model trained that seems to have learned well, we can examine it to see how well it really learned.

In [1]:
import torch
from ataarangi.train import TransformerModel
from ataarangi.data import TextTokenizer, WorldStateTokenizer, load_data

In [2]:
# Initialize tokenizers
world_state_tokenizer = WorldStateTokenizer('../data/worldstate_tokens.txt')
text_tokenizer = TextTokenizer('../data/tokens.txt')

In [3]:
world_state_tokenizer.token_map

{'<SOS>': 1,
 '<SELECTED>': 2,
 '<NOT_SELECTED>': 3,
 'color_red': 4,
 'color_blue': 5,
 'color_green': 6,
 'color_yellow': 7,
 'color_black': 8,
 'color_white': 9,
 'color_brown': 10,
 'color_pink': 11,
 'height_1': 12,
 'height_2': 13,
 'height_3': 14,
 'height_4': 15,
 'height_5': 16,
 'height_6': 17,
 'height_7': 18,
 'height_8': 19,
 'height_9': 20,
 'height_10': 21,
 '<CLS>': 22}

In [4]:
text_tokenizer.token_map

{'rākau': 23,
 'te': 24,
 'ngā': 25,
 'me': 26,
 'mā': 27,
 'kākāriki': 28,
 'kōwhai': 29,
 'kikorangi': 30,
 'parauri': 31,
 'pango': 32,
 'whero': 33,
 'māwhero': 34,
 'iti': 35,
 'nui': 36,
 'hāunga': 37,
 'katoa': 38,
 'rawa': 39,
 'taha': 40,
 'kei': 41,
 'mauī': 42,
 'matau': 43,
 'ki': 44,
 'tawhiti': 45,
 'e': 46,
 'rua': 47,
 'waenganui': 48,
 'i': 49,
 'toru': 50,
 'tuarua': 51,
 'mai': 52,
 '[END]': 53}

In [5]:
def load_model(path, params):
    model = TransformerModel(**params)
    model.load_state_dict(torch.load(path))
    return model

best_model_params = {
    'vocab_size': 54,
    'embed_size': 256,
    'nhead': 4,
    'num_encoder_layers': 3,
    'num_decoder_layers': 3,
    'dim_feedforward': 1024,
    'max_seq_length': 500,
    'dropout': 0.17962795808108917
}

best_model = load_model(
    '../models/lr=0.0003-num_layers=3-embed_size=256-nhead=4-dim_ff=1024-dropout=0.1796.pth',
    best_model_params
)

In [6]:
train_data, dev_data = load_data('../data/train_set.csv', '../data/dev_set.csv', text_tokenizer, world_state_tokenizer)

In [7]:
tokens = world_state_tokenizer.tokenize(train_data['rākau'][0])
tokens_tensor = torch.tensor(tokens, dtype=torch.long)  # Ensure the tensor is of type long
generated_sequence = best_model.generate(tokens_tensor)
print(generated_sequence)

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 53])


In [8]:
' '.join([text_tokenizer.id_map[tok] for tok in generated_sequence.cpu().tolist()[2:]])

KeyError: 0