## Evaluating AbLMs on the test set

In [None]:
from datasets import load_dataset
from transformers import (
    EsmTokenizer,
    EsmForMaskedLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)

In [None]:
# Load the held-out split from a local CSV. The result is keyed by split name,
# so the rows are reachable as dataset['test']. Later cells read the columns
# 'sequence_aa_heavy' and 'sequence_aa_light' from it.
csv_splits = {'test': './data/test/test_dataset.csv'}
dataset = load_dataset('csv', data_files=csv_splits)

In [None]:
# The tokenizer is taken from the public ESM-2 hub repo, while the weights come
# from a locally saved masked-LM checkpoint.
# NOTE(review): this assumes the checkpoint was trained with the same ESM-2
# vocabulary. Confirm this, or load the tokenizer from model_path if the
# checkpoint directory ships one.
tokenizer_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(tokenizer_name)

model_path = './01all_esm_models/deepspeed/esm/all_checkpoints_4good/m_150M_full_batch_128_2025-02-11/checkpoint-500000'
model = EsmForMaskedLM.from_pretrained(model_path)

In [None]:
# Tokenization settings. These should mirror the values used when the
# checkpoint was trained (train_config["max_length"] and
# train_config["separator_token"]); otherwise evaluation inputs differ in
# shape from what the model saw during training.
MAX_LEN: int = 320  # token budget per example passed to the tokenizer as max_length
SEPARATOR: str = "<cls><cls>"  # text placed between the heavy and light chains

In [None]:
# Tokenize the dataset
def preprocess_function(example):
    """Tokenize a single paired-chain record with the training-time settings.

    The heavy and light chain sequences are joined with ``SEPARATOR`` and
    encoded to exactly ``MAX_LEN`` tokens, truncated or padded as needed.

    Args:
        example: One dataset row. ``'sequence_aa_heavy'`` and
            ``'sequence_aa_light'`` are expected to be strings. A missing
            (``None``) value raises ``TypeError`` when the chains are joined.

    Returns:
        The tokenizer's encoding for the paired sequence, extended with a
        ``special_tokens_mask`` entry computed from the final ``input_ids``.
    """
    chains = (example['sequence_aa_heavy'], example['sequence_aa_light'])
    paired_sequence = SEPARATOR.join(chains)

    encoded = tokenizer(
        paired_sequence,
        add_special_tokens=True,
        max_length=MAX_LEN,
        truncation=True,
        padding='max_length',
    )

    # The mask is computed from the final ids (already_has_special_tokens=True),
    # so the separator and padding tokens should be flagged as special as well.
    # NOTE(review): Trainer's default remove_unused_columns likely drops this
    # column before collation, in which case the MLM collator recomputes an
    # equivalent mask itself. Confirm before relying on this column.
    encoded['special_tokens_mask'] = tokenizer.get_special_tokens_mask(
        encoded['input_ids'],
        already_has_special_tokens=True,
    )
    return encoded

# Encode every split one record at a time (preprocess_function takes a single
# row, hence batched=False), then keep the split used for evaluation.
tokenized_datasets = dataset.map(
    function=preprocess_function,
    batched=False,
)
eval_dataset = tokenized_datasets['test']

In [None]:
# MLM collator: builds each batch and randomly selects tokens for prediction at
# mlm_probability. Keep this equal to the training value. Because the selection
# is random, the resulting eval loss depends on the RNG state.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15,
)

In [None]:
# Evaluation-only setup: no training is run, so only the eval batch size and
# output locations matter.
#
# `evaluation_strategy` is intentionally NOT passed:
# - It was renamed to `eval_strategy` and later removed from TrainingArguments,
#   so passing it raises a TypeError on recent transformers releases.
# - The value used before ("no") is the library default, so omitting it is
#   equivalent on every version.
training_args = TrainingArguments(
    output_dir='./results',
    per_device_eval_batch_size=32,
    logging_dir='./logs',
    do_eval=True,
    report_to="none",  # Explicitly disable W&B (and other) logging integrations
)

# Only an eval_dataset is supplied. trainer.evaluate() batches it through the
# MLM collator defined above.
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

In [None]:
# Run evaluation.
import math

from transformers import set_seed

# The MLM collator masks tokens at random. Re-seeding with the same seed the
# Trainer used at construction makes re-runs of this cell reproduce the same
# masks, and therefore the same loss, instead of drifting between runs.
set_seed(training_args.seed)
eval_results = trainer.evaluate()

# eval_loss is the model's MLM loss, i.e. cross-entropy computed only over the
# randomly masked positions, averaged across evaluation batches. It is not a
# mean over every token in the sequences.
loss = eval_results['eval_loss']
print(f"Cross‑Entropy Loss (eval_loss): {loss:.4f}")

# Masked-token perplexity. Guard against overflow for a badly diverged model.
try:
    perplexity = math.exp(loss)
except OverflowError:
    perplexity = float("inf")
print(f"Perplexity: {perplexity:.4f}")