# Load and Prompt a Checkpoint

This notebook demonstrates how to reconstruct a model from a checkpoint file.
To load the vocabulary and tokenizer used to train this model, the corresponding maze datasets must already be loaded into MongoDB, as outlined in [TokenDatasets.ipynb](TokenDatasets.ipynb).

A single checkpoint holds:
* The hyper-parameters used for training (excluding the token vocabulary, which is stored together with the corresponding token dataset)
* All model parameters
* The optimizer state
* The number of gradient steps at which the model was saved
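
The four items listed above can be pictured as fields of one container. The following is a minimal sketch only: `CheckpointSketch` and its field names are hypothetical and are not the layout of the `Checkpoint` class in this code base.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class CheckpointSketch:
    """Hypothetical container mirroring the four items listed above."""

    config: Dict[str, Any]           # training hyper-parameters (no token vocabulary)
    model_state: Dict[str, Any]      # parameter name -> weights
    optimizer_state: Dict[str, Any]  # internal state of the optimizer
    step: int                        # gradient step count recorded in the checkpoint


# Illustrative values only; a real checkpoint stores tensors, not Python lists.
sketch = CheckpointSketch(
    config={"_id": "example-run", "encoder": "enc-s", "decoder": "dec-s"},
    model_state={"layer.weight": [0.1, -0.2]},
    optimizer_state={"lr": 0.00025},
    step=400000,
)
print(sketch.step)
```

Keeping the optimizer state and step count next to the weights is what allows training to resume from such a file rather than only running inference.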

This notebook only outlines how a model can be reconstructed from a checkpoint file.
To use any of the other workflows included in this code base, the checkpoint files must first be imported into MongoDB.
In this notebook we focus on the checkpoint resulting from the run `maze-sweep-rep-nondet-small-0` and assume that the file `maze-sweep-rep-nondet-small-0.ckpt` is present in the project's root directory.
This checkpoint file can be downloaded [here](https://dl.fbaipublicfiles.com/searchformer/ckptDB/maze-sweep-rep-nondet-small-0.ckpt).
The file [`checkpoint_index.csv`](../doc.checkpoint_index.csv) lists all released checkpoints and their corresponding download links.
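
If the file is not yet present, it could be fetched with a few lines of standard-library Python. This is a sketch under assumptions: `checkpoint_url` and `download_checkpoint` are hypothetical helpers, and the URL pattern is inferred from the single link above, so links for other runs should be taken from `checkpoint_index.csv` rather than from this pattern.

```python
import urllib.request
from pathlib import Path

# Base of the download link given above.
BASE_URL = "https://dl.fbaipublicfiles.com/searchformer/ckptDB"


def checkpoint_url(run_id: str) -> str:
    """Build a download URL, ASSUMING every run follows the pattern of the link above."""
    return f"{BASE_URL}/{run_id}.ckpt"


def download_checkpoint(run_id: str, target_dir: str = "..") -> Path:
    """Download the checkpoint into target_dir unless the file already exists."""
    target = Path(target_dir) / f"{run_id}.ckpt"
    if not target.exists():
        urllib.request.urlretrieve(checkpoint_url(run_id), str(target))
    return target


print(checkpoint_url("maze-sweep-rep-nondet-small-0"))
```

Calling `download_checkpoint("maze-sweep-rep-nondet-small-0")` from the notebook's directory would place the file one level up, which matches the path `../maze-sweep-rep-nondet-small-0.ckpt` used in the cells that follow.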

First, the required modules are imported.

In [1]:
import sys; sys.path.append("..")

import logging
import torch
from searchformer.train import Checkpoint
from searchformer.transformer import EncoderDecoderConfig, sample_probability
from searchformer.trace import DictTokenizer, TokenizedDataset


logging.basicConfig(
    level=logging.DEBUG,
    format="%(levelname)s - %(asctime)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

Next, the checkpoint file `../maze-sweep-rep-nondet-small-0.ckpt` is loaded and its training configuration is printed.

In [2]:
ckpt = Checkpoint.from_file("../maze-sweep-rep-nondet-small-0.ckpt")
ckpt.config

{'_id': 'maze-sweep-rep-nondet-small-0',
 'data': {'train_name': 'maze.10-by-10-nondeterministic.simple',
  'test_name': 'maze.10-by-10-nondeterministic.simple',
  'batch_size': 8,
  'plan_only': False,
  'num_train_sequences': 50000,
  'num_test_sequences': 100000,
  'load_batch_size': 10000,
  'num_workers': 2},
 'encoder': 'enc-s',
 'decoder': 'dec-s',
 'optimizer': {'lr': 0.00025,
  'lr_schedule': 'cosine',
  'train_steps': 400000,
  'warmup': 2000,
  'beta_0': 0.9,
  'beta_1': 0.99,
  'cycle_length': 1.0,
  'cosine_theta': 1.0,
  'lr_min_ratio': 0.1},
 'log_interval': 1000,
 'eval_interval': 40000}

First, the tokenized dataset is loaded and a `DictTokenizer` object is constructed.
This object maps sequences of word tokens to lists of integers.
Next, an `EncoderDecoderConfig` object is constructed, which holds the hyper-parameters of the network architecture.
From this object the actual encoder-decoder Transformer is constructed, and the trained model parameters (the state dictionary) are loaded into it.
The example below uses the smallest model and the shortest sequences and runs inference on CPU to reduce compute requirements.

In [3]:
# Load vocabulary from tokenized dataset. This is needed to load the training token vocabulary and a test prompt.
tok_dataset = TokenizedDataset(ckpt.config_obj.data.train_name)
# Load tokenizer mapping tokens to indices.
tokenizer = DictTokenizer(tok_dataset.vocabulary)
# Construct model config object.
enc_dec_config = EncoderDecoderConfig.from_name(
    enc_name=ckpt.config_obj.encoder,
    dec_name=ckpt.config_obj.decoder,
    vocab_size=tokenizer.vocab_size,
)
# Construct model from config.
model = enc_dec_config.construct_model()
# Load trained weights into the model.
model.load_state_dict(ckpt.model_only_state_dict)

INFO - 2024-04-26 13:56:08 - root - Connecting to mongodb://localhost:27017/mongo
INFO - 2024-04-26 13:56:08 - root - Vocabulary size: 118
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_heads=3, dim=192.
DEBUG - 2024-04-26 13:56:08 - root - Creating block: n_hea

<All keys matched successfully>
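
To make the tokenizer's role concrete, the following sketch shows a dictionary-based mapping between word tokens and integers. `ToyDictTokenizer` is a hypothetical stand-in written for this illustration; the id assignment and the handling of `bos`/`eos` are assumptions and need not match `DictTokenizer`.

```python
from typing import Dict, List, Sequence


class ToyDictTokenizer:
    """Hypothetical stand-in for a dictionary-based tokenizer (not the searchformer class)."""

    def __init__(self, vocabulary: Sequence[str]):
        # Reserve the first two ids for the begin- and end-of-sequence markers,
        # then sort the remaining tokens so the mapping is reproducible.
        tokens = ["bos", "eos"] + sorted(set(vocabulary) - {"bos", "eos"})
        self.token_to_id: Dict[str, int] = {t: i for i, t in enumerate(tokens)}
        self.id_to_token: Dict[int, str] = {i: t for t, i in self.token_to_id.items()}

    @property
    def vocab_size(self) -> int:
        return len(self.token_to_id)

    def encode(self, tokens: Sequence[str]) -> List[int]:
        """Map word tokens to integer ids."""
        return [self.token_to_id[t] for t in tokens]

    def decode(self, ids: Sequence[int]) -> List[str]:
        """Map integer ids back to word tokens."""
        return [self.id_to_token[i] for i in ids]


toy = ToyDictTokenizer(["start", "goal", "wall", "0", "1", "2"])
ids = toy.encode(["start", "0", "1", "goal", "2", "2"])
print(ids, toy.decode(ids))
```

Because the model's embedding table is sized by `vocab_size`, the vocabulary used at inference time has to be the one the model was trained with, which is why it is loaded from the training dataset.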

The following code segment loads the first test prompt and prints it.

In [4]:
test_trace_id_list = tok_dataset.test_ids
test_trace_id_list.sort()
test_trace = next(iter(tok_dataset.test_it(test_trace_id_list[:1])))[0]

prompt_str = " ".join(test_trace.prompt)
prompt_str = prompt_str.replace("start", "\n\tstart")
prompt_str = prompt_str.replace("wall", "\n\twall ")
prompt_str = prompt_str.replace("goal", "\n\tgoal ")
print("Prompt: " + prompt_str)

DEBUG - 2024-04-26 13:56:08 - root - Loading all ids from Collection(Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True, sockettimeoutms=1800000, connecttimeoutms=1800000), 'tokenSeqDB'), 'maze.10-by-10-nondeterministic.simple.meta.test') ...
DEBUG - 2024-04-26 13:56:08 - root - Finished loading.
DEBUG - 2024-04-26 13:56:08 - root - Iterating over 1 ids.


Prompt: 
	start 3 6 
	goal  4 2 
	wall  0 0 
	wall  3 0 
	wall  4 0 
	wall  2 1 
	wall  4 1 
	wall  5 1 
	wall  9 1 
	wall  0 2 
	wall  1 2 
	wall  2 2 
	wall  6 2 
	wall  7 2 
	wall  5 3 
	wall  6 3 
	wall  7 3 
	wall  8 3 
	wall  9 3 
	wall  1 4 
	wall  2 4 
	wall  3 4 
	wall  9 4 
	wall  3 5 
	wall  4 5 
	wall  6 5 
	wall  5 6 
	wall  6 6 
	wall  9 6 
	wall  0 7 
	wall  2 7 
	wall  4 7 
	wall  6 8 
	wall  9 8 
	wall  2 9 
	wall  3 9 
	wall  4 9 
	wall  6 9 
	wall  8 9
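
A prompt of this form is easier to inspect when drawn as a grid. The sketch below assumes the prompt is a flat list of `(keyword, x, y)` triples, as the printout suggests, and that the first coordinate is the column. Both points are assumptions, and `parse_prompt` and `render` are hypothetical helpers, not part of this code base.

```python
from typing import List, Optional, Set, Tuple

Cell = Tuple[int, int]


def parse_prompt(tokens: List[str]) -> Tuple[Cell, Cell, Set[Cell]]:
    """Split a flat prompt token list into start, goal and wall cells.

    ASSUMPTION: the prompt is a sequence of (keyword, x, y) triples.
    """
    start: Optional[Cell] = None
    goal: Optional[Cell] = None
    walls: Set[Cell] = set()
    for i in range(0, len(tokens), 3):
        kind, x, y = tokens[i], int(tokens[i + 1]), int(tokens[i + 2])
        if kind == "start":
            start = (x, y)
        elif kind == "goal":
            goal = (x, y)
        elif kind == "wall":
            walls.add((x, y))
        else:
            raise ValueError(f"unexpected prompt token: {kind}")
    if start is None or goal is None:
        raise ValueError("prompt must contain a start and a goal")
    return start, goal, walls


def render(start: Cell, goal: Cell, walls: Set[Cell], size: int = 10) -> str:
    """Draw the maze with S (start), G (goal), # (wall) and . (free cell)."""
    rows = []
    for y in range(size):
        row = ""
        for x in range(size):
            cell = (x, y)
            row += "S" if cell == start else "G" if cell == goal else "#" if cell in walls else "."
        rows.append(row)
    return "\n".join(rows)


demo = "start 1 0 goal 2 2 wall 0 0 wall 1 1".split()
print(render(*parse_prompt(demo), size=3))
```

In the notebook, `print(render(*parse_prompt(test_trace.prompt)))` would draw the 10-by-10 maze shown above, provided `test_trace.prompt` is a list of string tokens.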


The following code segment maps the prompt to an integer tensor and then generates a response sequence.
This response sequence (an integer tensor) is then decoded into a token sequence and printed.
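
Conceptually, generating a response means sampling one token at a time until an end-of-sequence token appears or a length limit is reached. The following pure-Python sketch illustrates that loop under stated assumptions: `toy_rollout`, `sample_from` and `counting_model` are hypothetical names, the toy model stands in for the trained Transformer, and the way the length limit is counted here may differ from `model.rollout`.

```python
import random
from typing import Callable, Dict, List

Distribution = Dict[int, float]  # token id -> probability


def sample_from(dist: Distribution, rng: random.Random) -> int:
    """Draw one token id in proportion to its probability."""
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]


def toy_rollout(
    next_token_dist: Callable[[List[int], List[int]], Distribution],
    prompt: List[int],
    bos_idx: int,
    eos_idx: int,
    max_rollout_len: int,
    rng: random.Random,
) -> List[int]:
    """Generate a response token by token until eos or the length limit."""
    response = [bos_idx]
    while len(response) < max_rollout_len:
        token = sample_from(next_token_dist(prompt, response), rng)
        response.append(token)
        if token == eos_idx:
            break
    return response


def counting_model(prompt: List[int], response: List[int]) -> Distribution:
    # Hypothetical stand-in for a trained network:
    # emit token 2 three times, then the eos token (id 1).
    return {1: 1.0} if len(response) >= 4 else {2: 1.0}


out = toy_rollout(counting_model, [5, 6], bos_idx=0, eos_idx=1,
                  max_rollout_len=10, rng=random.Random(0))
print(out)
```

The length limit guards against responses that never emit an end-of-sequence token, which is the role `max_rollout_len=2000` appears to play in the cell below.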

In [5]:
prompt_tokens = tokenizer.encode(test_trace.prompt)
prompt_tokens_tensor = torch.tensor(prompt_tokens, dtype=torch.long)
response = model.rollout(
    prompt=prompt_tokens_tensor,
    bos_idx=tokenizer.bos,
    eos_idx=tokenizer.eos,
    max_rollout_len=2000,
    sample_fn=sample_probability,
)
response_token_list = tokenizer.decode(response[0].tolist())
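# Insert line breaks before each keyword so that the response prints one step per line.
response_str = (
    " ".join(response_token_list)
    .replace("bos ", "\n\tbos")
    .replace("eos", "\n\teos")
    .replace("create", "\n\tcreate")
    .replace("close", "\n\tclose ")
    .replace("plan ", "\n\tplan   ")
)
print("Response:" + response_str)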

DEBUG - 2024-04-26 13:56:13 - root - Rollout 200 steps, 0 seq. complete.
INFO - 2024-04-26 13:56:15 - root - Rollout length: 283


Response:
	bos
	create 3 6 c0 c5 
	close  3 6 c0 c5 
	create 3 7 c1 c6 
	create 4 6 c1 c4 
	create 2 6 c1 c6 
	close  4 6 c1 c4 
	close  3 7 c1 c6 
	create 3 8 c2 c7 
	close  2 6 c1 c6 
	create 2 5 c2 c5 
	create 1 6 c2 c7 
	close  2 5 c2 c5 
	create 1 5 c3 c6 
	close  1 5 c3 c6 
	create 0 5 c4 c7 
	close  1 6 c2 c7 
	create 1 7 c3 c8 
	create 0 6 c3 c8 
	close  3 8 c2 c7 
	create 2 8 c3 c8 
	create 4 8 c3 c6 
	close  4 8 c3 c6 
	create 5 8 c4 c7 
	close  2 8 c3 c8 
	create 1 8 c4 c9 
	close  1 7 c3 c8 
	close  0 5 c4 c7 
	create 0 4 c5 c6 
	close  5 8 c4 c7 
	create 5 9 c5 c8 
	create 5 7 c5 c6 
	close  0 4 c5 c6 
	create 0 3 c6 c5 
	close  0 6 c3 c8 
	close  0 3 c6 c5 
	create 1 3 c7 c4 
	close  5 7 c5 c6 
	create 6 7 c6 c7 
	close  1 3 c7 c4 
	create 2 3 c8 c3 
	close  2 3 c8 c3 
	create 3 3 c9 c2 
	close  3 3 c9 c2 
	create 4 3 c10 c1 
	create 3 2 c10 c1 
	close  4 3 c10 c1 
	create 4 2 c11 c0 
	create 4 4 c11 c2 
	close  2 2 c9 c0 
	plan   3 6 
	plan   2 6 
	plan   2 5 
	plan   1 