In [11]:
from src.domain.model import Model
from src.domain.datamodels import DatasetConfig, ModelConfig
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from src.domain.dataset.bridge_dataset import BridgeDataset

# --- Configuration -----------------------------------------------------------
# Locations of the trained weights and of the dataset file, relative to the
# directory the notebook is run from.
# NOTE(review): the .pkl extension suggests a pickle file; only point this at
# data from a trusted source.
checkpoint_path = Path("model_artifacts/2025_01_26_run_005/model_epoch_19.pth")
dataset_path = Path("input_data.pkl")

# Dataset description: vocabulary sizes and maximum sequence lengths for the
# orthographic ("orth") and phonological ("phon") sides.
dataset_config = DatasetConfig(
    dataset_filepath=dataset_path,
    orthographic_vocabulary_size=31,
    phonological_vocabulary_size=34,
    dimension_phon_repr=31,
    max_orth_seq_len=16,
    max_phon_seq_len=14,
)

# Architecture hyperparameters.
# NOTE(review): these presumably have to match the values used to train the
# checkpoint above, otherwise loading its state dict would fail — confirm
# against the training run's config.
model_config = ModelConfig(
    d_model=128,
    nhead=16,
    d_embedding=1,
    num_orth_enc_layers=1,
    num_phon_enc_layers=8,
    num_mixing_enc_layers=2,
    num_orth_dec_layers=1,
    num_phon_dec_layers=4,
    seed=1337,
)

# Build the dataset wrapper and a freshly initialised model from the configs.
dataset = BridgeDataset(dataset_config)
model = Model(model_config, dataset_config)

# Deserialize the checkpoint onto the CPU so that a file saved from a GPU run
# can still be opened on a machine without CUDA. load_state_dict copies the
# values into the model's existing parameters, so the model's own device
# placement is unaffected.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints from trusted sources. If the checkpoint holds
# nothing but tensors and plain containers, consider passing
# weights_only=True as well; confirm against the code that saved it.
chkpt = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(chkpt["model_state_dict"])

# Switch to evaluation mode before generating.
model.eval()

# Run a handful of sample words through the model:
#   1. encode the raw strings with the dataset's encoder,
#   2. generate along the "p2p" pathway.
# NOTE(review): "p2p" presumably means phonology-to-phonology — confirm
# against Model.generate.
words = "special cat hat".split()
encodings = dataset.encode(words)
output = model.generate(encodings, pathway="p2p")

  chkpt = torch.load(checkpoint_path)
  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


In [15]:
# Display the GenerationOutput returned by model.generate above. Its repr
# prints the full global_encoding tensor, so the rendered output is long.
output

GenerationOutput(global_encoding=tensor([[[-1.4095e+00, -1.1995e+00, -1.8271e+00,  1.9730e+00,  1.6186e-01,
           1.0571e+00,  3.7346e-01,  7.4502e-02,  2.0194e+00, -1.5930e+00,
           3.9793e-01, -7.2384e-01, -2.7475e-01,  1.4110e-01, -7.2520e-01,
           1.3257e+00,  3.6148e-01, -9.4106e-01, -1.0011e+00,  1.4901e+00,
          -1.5038e+00, -4.2724e-01, -1.2157e+00, -5.1754e-01, -4.4586e-01,
          -3.6102e-01,  7.2623e-01, -2.5621e-01,  9.1356e-01,  1.2516e+00,
           4.7797e-01, -1.3797e+00,  5.3543e-01,  2.1893e+00,  6.7478e-01,
           2.6462e-04,  1.0731e+00, -2.1673e+00, -1.6657e+00,  2.9619e-01,
          -1.8022e+00,  5.6132e-01,  9.2281e-01, -3.3354e-01,  7.3111e-01,
           1.1474e+00, -3.8997e-01, -1.3576e+00, -6.5947e-01,  1.6759e+00,
          -8.1341e-01, -8.5633e-01, -1.5119e-01, -1.1080e+00,  7.2549e-01,
           1.6314e+00, -2.2890e-01, -3.0734e-01,  1.2726e+00,  2.5322e-01,
           6.6323e-01,  2.0086e-01, -4.2688e-01,  1.0025e+00, -1.57