In [1]:

from src.domain.model import Model
from src.domain.datamodels import DatasetConfig, ModelConfig
import torch
from src.application.handlers import (
    ModelConfigHandler,
    DatasetConfigHandler,
    TrainingConfigHandler,
    WandbConfigHandler,
    LoggingConfigHandler,
    TrainModelHandler,
)
import os
import sys

import torch
from src.application.training.phon_metrics import calculate_phoneme_wise_accuracy

from src.domain.datamodels.dataset_config import DatasetConfig
from src.domain.datamodels.encodings import BridgeEncoding
from src.domain.datamodels.model_config import ModelConfig
from src.domain.datamodels.training_config import TrainingConfig
from src.domain.dataset.bridge_dataset import BridgeDataset
from src.domain.model.model import Model

In [3]:

# Load the model checkpoint
model_config: ModelConfig = ModelConfigHandler(config_filepath='app/config/model_config.yaml').get_config()
dataset_config: DatasetConfig = DatasetConfigHandler(config_filepath='app/config/dataset_config.yaml').get_config()
training_config: TrainingConfig = TrainingConfigHandler(config_filepath='app/config/training_config.yaml').get_config()
bridge_dataset = BridgeDataset(dataset_config=dataset_config, device=training_config.device)
model=Model(
    model_config=model_config,
    dataset_config=dataset_config,
    device=training_config.device,
)
checkpoint = torch.load(training_config.checkpoint_path)
model.load_state_dict(checkpoint["model_state_dict"])

# Load the test dataset
dataset_config.dataset_filepath = 'data/tests/fry_1980.pkl'
dataset_config.max_orth_seq_len = None
dataset_config.max_phon_seq_len = None
test_dataset = BridgeDataset(dataset_config=dataset_config, device=training_config.device)

  checkpoint = torch.load(training_config.checkpoint_path)


In [5]:
output = model.generate(test_dataset.data, pathway='p2p')

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


In [8]:
# Prepare the true and predicted phoneme sequences
phon_pred = output.phon_vecs
phon_true = test_dataset.data.phonological.targets
phon_pred = torch.stack([torch.stack(i) for i in phon_pred])


In [9]:
phon_pred.shape, phon_true.shape

(torch.Size([299, 10, 33]), torch.Size([299, 10, 33]))

In [10]:
# Mask the padding tokens
phon_valid_mask = phon_true != 2
masked_phon_true = phon_true[phon_valid_mask]
masked_phon_pred = phon_pred[phon_valid_mask]
phoneme_wise_mask = phon_pred == phon_true

In [11]:


from src.application.training.phon_metrics import calculate_cosine_distance, calculate_euclidean_distance, calculate_phon_word_accuracy


phoneme_wise_accuracy = calculate_phoneme_wise_accuracy(
        phon_true, masked_phon_true, phoneme_wise_mask
    )
word_wise_accuracy = calculate_phon_word_accuracy(
    phon_true[:64], phoneme_wise_mask[:64]
)
cosine_accuracy = calculate_cosine_distance(phon_true, phon_pred)
euclidean_distance = calculate_euclidean_distance(phon_true, phon_pred)
phoneme_wise_accuracy, word_wise_accuracy, cosine_accuracy, euclidean_distance

(tensor(0.9992), tensor(1.), tensor(0.9999), tensor(0.0008))

In [4]:
logits = model.forward('p2p', phon_enc_input=test_dataset.data.phonological.enc_input_ids,
                               phon_enc_pad_mask=test_dataset.data.phonological.enc_pad_mask,
                               phon_dec_input=test_dataset.data.phonological.dec_input_ids,
                               phon_dec_pad_mask=test_dataset.data.phonological.dec_pad_mask)

In [None]:
phon_pred = torch.argmax(logits["phon"], dim=1)
phon_pred.shape