# Model evaluation: BLEU on the QQP validation split

In [1]:
import os
import sys
from pathlib import Path

import lightning as pl
import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from datasets import Dataset

# Add parent directory to path to import shortcutfm modules
sys.path.append('..')
from shortcutfm.analysis.denoising import denoise_with_velocity_tracking
from shortcutfm.batch import collate
from shortcutfm.config import TrainingConfig
from shortcutfm.text_datasets import TextDataset
from shortcutfm.train.pl.trainer_factory import (
    create_criterion,
    load_unit_from_checkpoint,
)

In [2]:
import os
import sys
from pathlib import Path

import lightning as pl
import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from datasets import Dataset

# Add parent directory to path to import shortcutfm modules
sys.path.append('..')
from shortcutfm.analysis.denoising import denoise_with_velocity_tracking
from shortcutfm.batch import collate
from shortcutfm.config import TrainingConfig
from shortcutfm.text_datasets import TextDataset
from shortcutfm.train.pl.trainer_factory import (
    create_criterion,
    load_unit_from_checkpoint,
)

In [3]:
# Set the checkpoint directory.
# Only one run is active. Alternative runs previously referenced in this cell:
#   ../checkpoints/run_30me1xfs  -> last-v1.ckpt
#   ../checkpoints/run_baseline  -> last.ckpt
checkpoint_dir = Path("../checkpoints/run_q2qzjeso")
checkpoint_path = checkpoint_dir / "epoch=190-step=27100.ckpt"
training_config_path = checkpoint_dir / "training_config.yaml"

# Load training configuration. Passing the path (rather than an open() handle)
# lets OmegaConf read it as UTF-8 instead of the platform's locale encoding.
yaml_cfg = OmegaConf.load(training_config_path)

training_config = TrainingConfig(**OmegaConf.to_container(yaml_cfg, resolve=True))  # type: ignore
print(f"Loaded training config from {training_config_path}")

def create_test_dataloader(
    split: str = "test",
    batch_size: int = 128,
    num_workers: int = 4,
    data_root: Path | str | None = None,
) -> DataLoader:
    """Create a non-shuffled evaluation dataloader over a tokenized QQP split on disk.

    The dataset location is not taken from ``training_config``. By default it is
    ``<cwd>/../datasets/tokenized/bert-base-uncased/QQP-Official``.

    :param split: Name of the split sub-directory to load (e.g. ``"test"``, ``"valid"``).
    :param batch_size: Number of examples per batch.
    :param num_workers: DataLoader worker processes; ``0`` loads in the main process.
    :param data_root: Directory containing the split sub-directories; ``None`` uses the default above.
    :return: DataLoader over ``TextDataset`` whose batches are built by ``collate``.
    """
    if data_root is None:
        data_root = Path.cwd().parent / "datasets" / "tokenized" / "bert-base-uncased" / "QQP-Official"
    test_data_path = Path(data_root) / split

    test_ds = Dataset.load_from_disk(test_data_path)
    test_text_ds = TextDataset(test_ds)

    return DataLoader(
        test_text_ds,
        batch_size=batch_size,
        collate_fn=collate,
        shuffle=False,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
    )

Loaded training config from ..\checkpoints\run_q2qzjeso\training_config.yaml


In [4]:
from shortcutfm.train.pl.train_unit import TrainModule

# Reproducibility: seed the global RNGs through Lightning with the run's configured seed.
seed = training_config.seed
pl.seed_everything(seed)

Seed set to 44


44

In [5]:
from shortcutfm.criteria import CompositeCriterion
from shortcutfm.train.pl.train_unit import TrainModule
import torch
from pathlib import Path
from typing import Union

def load_unit_from_checkpoint(
    criterion: CompositeCriterion,
    checkpoint_path: Path | str,
    training_config: TrainingConfig,
    denoising_step_size: int | None = None,
    prediction_shortcut_size: int | None = None,
) -> TrainModule:
    """Load and configure a training unit from a checkpoint, remapping legacy criterion keys.

    This notebook-local definition shadows the ``load_unit_from_checkpoint``
    imported from ``shortcutfm.train.pl.trainer_factory`` in the first cell.

    Every state-dict key has ``criterion.nll`` and ``criterion.flow_matching_criterion``
    rewritten to ``criterion.embedding_criterion``. The result is then loaded
    non-strictly. Any missing or unexpected keys are printed, so skipped weights
    are visible instead of silently ignored.

    :param criterion: Criterion instance to use for training
    :type criterion: CompositeCriterion | FlowNllCriterion
    :param checkpoint_path: Path to the checkpoint file
    :type checkpoint_path: Path | str
    :param training_config: Training configuration; supplies fallback step sizes and
        the ``optimizer.scheduler`` settings passed to ``TrainModule``
    :type training_config: TrainingConfig
    :param denoising_step_size: Number of denoising steps; ``None`` falls back to
        ``training_config.denoising_step_size``
    :type denoising_step_size: int | None
    :param prediction_shortcut_size: Size of prediction shortcut; ``None`` falls back to
        ``training_config.prediction_shortcut_size``
    :type prediction_shortcut_size: int | None
    :return: Configured training unit with the checkpoint weights loaded
    :rtype: TrainModule
    :raises RuntimeError: if ``load_state_dict`` fails (e.g. a tensor shape mismatch)
    """
    if denoising_step_size is None:
        denoising_step_size = training_config.denoising_step_size
    if prediction_shortcut_size is None:
        prediction_shortcut_size = training_config.prediction_shortcut_size

    # NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(str(checkpoint_path), map_location=torch.device("cpu"), weights_only=False)

    # Use the nested 'state_dict' entry when present; otherwise treat the whole checkpoint as one.
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Renames are applied cumulatively, so a key keeps every substitution that matches.
    # (Previously the second replace ran on the original key and discarded the 'nll' rename.)
    # NOTE(review): the loaded unit is later accessed via criterion.flow_matching_criterion;
    # confirm that mapping those keys to embedding_criterion is intended.
    key_renames = (
        ("criterion.nll", "criterion.embedding_criterion"),
        ("criterion.flow_matching_criterion", "criterion.embedding_criterion"),
    )
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in key_renames:
            new_key = new_key.replace(old_prefix, new_prefix)
        new_state_dict[new_key] = value

    # NOTE(review): the scheduler sub-config is passed as optimizer_config -- confirm against TrainModule.
    unit = TrainModule(
        criterion=criterion,
        optimizer_config=training_config.optimizer.scheduler,
        prediction_shortcut_size=prediction_shortcut_size,
        denoising_step_size=denoising_step_size,
    )

    # Load the remapped state dictionary into the model
    try:
        incompatible = unit.load_state_dict(new_state_dict, strict=False)
    except RuntimeError as e:
        print(f"Error loading state dict: {e}")
        raise

    # strict=False would otherwise hide keys that failed to match.
    preview_len = 5
    for label, keys in (
        ("missing", incompatible.missing_keys),
        ("unexpected", incompatible.unexpected_keys),
    ):
        if keys:
            preview = ", ".join(keys[:preview_len])
            remainder = f", ... (+{len(keys) - preview_len} more)" if len(keys) > preview_len else ""
            print(f"Warning: {len(keys)} {label} key(s) when loading state dict: {preview}{remainder}")

    return unit

In [None]:
# Create criterion and load model from checkpoint
from itertools import islice  # also used by later cells

# Step sizes the unit is constructed with (the evaluation loop passes its own values to denoise()).
denoising_step_size = 128
shortcut_size = 128

criterion = create_criterion(training_config)
unit: TrainModule = load_unit_from_checkpoint(
    criterion,
    checkpoint_path,
    training_config,
    denoising_step_size,
    shortcut_size,
)
print(f"Loaded model from {checkpoint_path}")

# Set the model to evaluation mode
unit.eval()
print("UNIT loaded successfully")


word emebedding reuires grad: False
lm head requires grad: True
Loaded model from ..\checkpoints\run_q2qzjeso\epoch=190-step=27100.ckpt
UNIT lodaed succesfully


In [7]:
from itertools import islice

# Number of batches to evaluate; set to None to evaluate the whole split.
limit_test_batches = 16

# Load the dataset
split = "valid"
test_dataloader = create_test_dataloader(split, batch_size=8)
if limit_test_batches is not None:
    # islice yields a single-use iterator: re-run this cell before re-running the evaluation loop.
    test_dataloader = islice(test_dataloader, limit_test_batches)


In [8]:
# The tokenizer is taken from the flow-matching sub-criterion of the loaded unit.
flow_matching_criterion = unit.criterion.flow_matching_criterion
tokenizer = flow_matching_criterion.tokenizer

In [9]:
# Per-batch accumulators filled by the evaluation loop (two distinct lists).
inputs, predictions = [], []

In [10]:
from itertools import islice

# Step sizes passed to denoise() for this evaluation run.
denoising_step_size = 2048
shortcut_size = 2048

# An islice has no len(); fall back to the configured batch limit for the progress bar.
total_batches = limit_test_batches if isinstance(test_dataloader, islice) else len(test_dataloader)

# Evaluation needs no gradients; no_grad avoids building autograd graphs during denoising.
with torch.no_grad():
    for test_batch in tqdm(test_dataloader, desc="Evaluating", total=total_batches):
        test_batch = test_batch.to(unit.device)
        predicted_ids: Tensor = unit.criterion.denoise(
            test_batch,
            shortcut_size,
            step_size=denoising_step_size,
            probe_every_step=False,
            return_logits=False,
            use_ground_truth_embeddings=True,
        )  # type: ignore

        # Appends to the lists created in the previous cell; re-run that cell to start fresh.
        inputs.append(test_batch.seqs.detach().cpu())
        predictions.append(predicted_ids.detach().cpu())

Evaluating:   0%|          | 0/16 [00:00<?, ?it/s]

In [11]:
# torch.cat fails on an empty list; give an actionable message instead.
if not inputs:
    raise RuntimeError(
        "No batches were collected: re-create test_dataloader and re-run the evaluation loop."
    )

all_inputs = torch.cat(inputs, dim=0)
all_predictions = torch.cat(predictions, dim=0)

In [12]:
# Every input sequence must have exactly one prediction row.
assert all_inputs.shape[0] == all_predictions.shape[0], "inputs and predictions disagree on number of sequences"
all_inputs.shape

torch.Size([128, 128])

In [13]:
import evaluate

from scripts.evaluate_seq import process_prediction, process_sequence

# Decode ground-truth sequences and split each into a (source, reference) pair.
# Building the lists explicitly (instead of zip(*...)) also works when there are no sequences.
input_texts = [tokenizer.decode(seq, skip_special_tokens=False) for seq in all_inputs]
source_reference_pairs = [process_sequence(text, tokenizer) for text in input_texts]
sources = [source for source, _ in source_reference_pairs]
references = [reference for _, reference in source_reference_pairs]

# Decode predictions (the evaluation loop called denoise with probe_every_step=False).
pred_texts = [tokenizer.decode(seq, skip_special_tokens=False) for seq in all_predictions]
hypotheses = [process_prediction(text, tokenizer) for text in pred_texts]

assert len(hypotheses) == len(references), "hypothesis/reference count mismatch"

# Calculate metrics
metrics = {}

# BLEU score: each hypothesis is scored against a single reference.
bleu = evaluate.load("bleu")
bleu_score = bleu.compute(predictions=hypotheses, references=[[ref] for ref in references])
metrics["bleu"] = bleu_score
print(f"BLEU score: {bleu_score}")

BLEU score: {'bleu': 0.9937063061227939, 'precisions': [0.9964002879769619, 0.9944488501189532, 0.9929390997352162, 0.991044776119403], 'brevity_penalty': 1.0, 'length_ratio': 1.0021645021645023, 'translation_length': 1389, 'reference_length': 1386}


In [14]:
from itertools import islice

# Number of (source, reference, hypothesis) triples to print for qualitative inspection.
num_examples = 10

for src, ref, hyp in islice(zip(sources, references, hypotheses, strict=False), num_examples):
    print("*" * 30)
    print(f"Source: {src}")
    print(f"Reference: {ref}")
    print(f"Hypothesis: {hyp}")
    print("*" * 30)
    print("\n")

******************************
Source: why is tokyo so big?
Reference: why has tokyo grown to be a such large city?
Hypothesis: why has tokyo grown to be a such large city?
******************************


******************************
Source: why does he want to have sex with me not her?
Reference: why did he chose me to have sex with?
Hypothesis: why did he chose me to have sex with?
******************************


******************************
Source: what could be the effect of gst bill on indian economy?
Reference: how can the gst bill, passed by the rajyasabha yesterday, boost the indian economy?
Hypothesis: how can the gst bill, passed by the rajyasabha yesterday, boost the indian economy?
******************************


******************************
Source: how will the ban on 500 and 1000 rupee notes bring out the black money of the big shots who have lots of it in the swiss bank in a different currency?
Reference: how is demonetizing the rs 500 and 1000 currencies affect