# Testing the model

This notebook tests the `PLMTaskModel` from `plft.models.model`. We create a model instance, pass input data through it, and compute the loss. We also verify that LoRA layers are correctly integrated when specified.

In [1]:
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
from transformers import AutoTokenizer
from plft.configs.registries import PREPROC_REGISTRY
from plft.datamodule import ProteinDataModule
from plft.pipeline import get_task_head, get_full_model, get_trainer

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @main(config_path="configs", config_name="protbert_seqcls.yaml")


In [2]:
# Helper function to patch config with overrides
def patch_config(cfg_path: str, **overrides):
    """
    Read the configuration at ``cfg_path`` and overwrite selected entries.

    Every keyword in ``overrides`` is forwarded as the key argument of
    ``OmegaConf.update`` together with its value, using ``merge=False``.

    Returns the patched configuration object.
    """
    patched = OmegaConf.load(cfg_path)
    for dotted_key in overrides:
        OmegaConf.update(patched, dotted_key, overrides[dotted_key], merge=False)
    return patched

In [3]:
def load_dataset(
    project_root: Path,
    cfg: DictConfig,
):
    """
    Build the data module and tokenizer described by a configuration.

    Args:
        project_root (Path): The root directory of the project. Relative data
            paths from the config are resolved against it; absolute paths are
            used unchanged.
        cfg (DictConfig): An already-loaded configuration. Reads
            ``cfg["tokenizer"]["name"]`` and, from ``cfg["data"]``, the
            ``train_file``, ``val_file`` and ``test_file`` entries plus the
            optional ``preprocess``, ``max_length``, ``sequence_column``,
            ``label_column`` and ``optional_features`` entries.

    Returns:
        tuple: ``(dm, tokenizer)`` - the constructed ``ProteinDataModule`` and
        the tokenizer loaded via ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: If ``train_file``, ``val_file`` or ``test_file`` is absent
            (or ``None``) in ``cfg["data"]``.
    """
    tok_cfg = cfg["tokenizer"]
    data_cfg = cfg["data"]

    # Assume data paths are relative to the **project root**
    def resolve_path(key):
        p = data_cfg.get(key)
        if p is None:
            # Fail with the offending key instead of the opaque TypeError
            # that Path(None) would raise.
            raise ValueError(f"Missing required config entry: data.{key}")
        path = Path(p)
        return str(path if path.is_absolute() else project_root / path)

    train_file = resolve_path("train_file")
    val_file   = resolve_path("val_file")
    test_file  = resolve_path("test_file")

    print("Train:", train_file)
    print("Valid:", val_file)
    print("Test :", test_file)

    print("Loading tokenizer:", tok_cfg["name"])
    tokenizer = AutoTokenizer.from_pretrained(tok_cfg["name"])
    print("Tokenizer uses_fast:", getattr(tokenizer, "is_fast", False))

    # Preprocess function via registry; unknown names fall back to identity
    preproc_name = data_cfg.get("preprocess", "none")
    preprocess_fn = PREPROC_REGISTRY.get(preproc_name, lambda s: s)

    dm = ProteinDataModule(
        train_file=train_file,
        val_file=val_file,
        test_file=test_file,
        tokenizer=tokenizer,
        preprocess_fn=preprocess_fn,
        max_length=data_cfg.get("max_length", 512),
        sequence_column=data_cfg.get("sequence_column", "sequence"),
        label_column=data_cfg.get("label_column", "label"),
        optional_features=data_cfg.get("optional_features", []),
    )
    return dm, tokenizer

In [4]:
# Project root is the parent of the current working directory;
# the template config lives under it.
project_root = Path("..").resolve()
cfg_path = project_root / "plft/configs/template.yaml"

In [5]:
def crop_dataset(train_file, val_file, test_file, n=20):
    """
    Crop dataset and save as temporary files for testing.

    For each of the three input files, the first ``n`` raw text lines are
    copied to a sibling file named ``Cropped_<original name>`` in the same
    directory. Lines are counted as-is, so a header row counts toward ``n``.

    Args:
        train_file, val_file, test_file: Paths (str or Path) of the files to crop.
        n (int | None): Maximum number of lines to keep per file. ``None``
            keeps every line. Files shorter than ``n`` are copied whole.

    Returns:
        tuple[str, str, str]: Paths of the cropped train, validation and test
        files, in that order.

    Raises:
        ValueError: If ``n`` is negative.
    """
    # Local import keeps this cell self-contained.
    from itertools import islice

    if n is not None and n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    cropped_files = []
    for file in (train_file, val_file, test_file):
        src = Path(file)
        with open(src, 'r') as f:
            # islice stops after n lines, so the rest of a large file is never read.
            cropped_lines = list(islice(f, n))
        temp_file = src.parent / f"Cropped_{src.name}"
        cropped_files.append(str(temp_file))
        with open(temp_file, 'w') as f:
            f.writelines(cropped_lines)
        # Report what was actually written; it can be fewer than n lines.
        print(f"cropped {file} to {temp_file} with {len(cropped_lines)} lines.")
    return tuple(cropped_files)

def cleanup_cropped_files(cropped_files):
    """
    Delete each temporary cropped file, reporting the outcome per file.

    Removal is best-effort: a failure for one path is printed and the loop
    continues with the remaining paths. Nothing is returned.
    """
    for tmp_path in cropped_files:
        try:
            Path(tmp_path).unlink()
        except Exception as err:
            print(f"Error removing file {tmp_path}: {err}")
        else:
            print(f"Removed temporary file: {tmp_path}")

# Source CSVs for the three chezod token-regression splits
chezod_train_csv = project_root/"data/training_data/chezod/chezod_token_regression_train.csv"
chezod_val_csv = project_root/"data/training_data/chezod/chezod_token_regression_validation.csv"
chezod_test_csv = project_root/"data/training_data/chezod/chezod_token_regression_test.csv"

# Keep only the first 20 lines of each split so the run below stays small
cropped_train, cropped_val, cropped_test = crop_dataset(
    chezod_train_csv,
    chezod_val_csv,
    chezod_test_csv,
    n=20,
)

cropped /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/chezod_token_regression_train.csv to /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/Cropped_chezod_token_regression_train.csv with 20 lines.
cropped /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/chezod_token_regression_validation.csv to /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/Cropped_chezod_token_regression_validation.csv with 20 lines.
cropped /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/chezod_token_regression_test.csv to /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/Cropped_chezod_token_regression_test.csv with 20 lines.


In [6]:
# Config overrides per example task; keys are the dotted config keys to patch
example_data = {
    "chezod_token_regression": {
        "data.train_file": cropped_train,
        "data.val_file": cropped_val,
        "data.test_file": cropped_test,
        "model.task_type": "TOKEN_REGRESSION",
        "trainer.epochs": 2,
        "trainer.batch_size": 4,
    },
}

In [9]:
# Load the template config and apply the chezod token-regression overrides
chezod_overrides = example_data["chezod_token_regression"]
cfg = patch_config(cfg_path, **chezod_overrides)

In [10]:
# 1. Build the data module and tokenizer from the patched config
dm, tokenizer = load_dataset(project_root, cfg)
datasets = dm.get_datasets()

# 2. Define the task head
head = get_task_head(cfg, datasets)

# 3. Define the full model
model = get_full_model(head, cfg)

# 4. Define the trainer
# NOTE(review): the saved output of this cell ends with
# "ValueError: False is not a valid SaveStrategy" - presumably a save-strategy
# value in the config is being read as boolean False; confirm against
# template.yaml / get_trainer.
trainer = get_trainer(model=model, tokenizer=tokenizer, datasets=datasets, cfg=cfg)

Train: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/Cropped_chezod_token_regression_train.csv
Valid: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/Cropped_chezod_token_regression_validation.csv
Test : /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/Cropped_chezod_token_regression_test.csv
Loading tokenizer: Rostlab/prot_bert
Tokenizer uses_fast: True


Map:   0%|          | 0/19 [00:00<?, ? examples/s]

Applied preprocess_fn to sequences.
Sample preprocessed sequence: M G H H H H H H L E E F T A E Q L S Q Y N G T D E S K P I Y V A I K G R V F D V T T G K S F Y G S G G D Y S M F A G K D A S R A L G K M S K N E E D V S P S L E G L T E K E I N T L N D W E T K F E A K Y P V V G R V V S
Injecting LoRA into target modules: ['query', 'key', 'value']


ValueError: False is not a valid SaveStrategy, please select one of ['no', 'steps', 'epoch', 'best']

In [None]:
# Print the module tree; LoRA-wrapped layers show up as `lora.Linear` entries
print(model)

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): PLMTaskModel(
      (backbone): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30, 1024, padding_idx=0)
          (position_embeddings): Embedding(40000, 1024)
          (token_type_embeddings): Embedding(2, 1024)
          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-29): 30 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Line

In [None]:
# Run training with the trainer built above
trainer.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


dict_keys([])
tensor([[[-7.4477e-02],
         [-2.5808e-02],
         [-8.2666e-02],
         [-3.1672e-02],
         [-4.6815e-02],
         [ 3.8480e-03],
         [-7.8226e-02],
         [-9.3170e-02],
         [-5.3363e-02],
         [-6.8815e-02],
         [-7.3863e-02],
         [-7.9088e-02],
         [-9.9615e-02],
         [-8.4405e-02],
         [-7.9381e-02],
         [-5.6838e-02],
         [-7.5134e-02],
         [-5.8090e-02],
         [-9.3162e-02],
         [-5.5908e-02],
         [-7.7884e-02],
         [-6.1547e-02],
         [-6.7229e-02],
         [-5.4547e-02],
         [-6.5771e-02],
         [-8.2336e-02],
         [-8.7147e-02],
         [-9.9004e-02],
         [-8.9468e-02],
         [-4.9970e-02],
         [-6.4501e-02],
         [-4.2046e-02],
         [-5.2171e-02],
         [-5.9748e-02],
         [-6.3150e-02],
         [-8.1511e-02],
         [-1.3104e-02],
         [-3.5575e-02],
         [-1.2794e-01],
         [-6.5150e-02],
         [ 4.9568e-03],
  

Epoch,Training Loss,Validation Loss,Mse,Mae,Rmse
1,No log,126.875618,2162.011719,25.70701,46.497438
2,No log,126.696388,2162.254883,25.705103,46.500053


dict_keys([])
tensor([[[-3.6596e-02],
         [-6.5100e-02],
         [ 3.3747e-02],
         [-3.6727e-03],
         [-5.9837e-02],
         [-5.3349e-02],
         [-4.4049e-02],
         [-3.3782e-02],
         [-1.4346e-02],
         [-5.5837e-02],
         [-8.9123e-02],
         [-3.1987e-02],
         [-3.2710e-02],
         [ 1.2137e-03],
         [-6.9907e-02],
         [-6.6525e-03],
         [ 3.5730e-02],
         [-2.1202e-02],
         [-7.6419e-02],
         [-1.0592e-02],
         [ 7.8045e-03],
         [-5.9313e-02],
         [-5.7107e-02],
         [-3.1643e-02],
         [-8.2394e-02],
         [-1.1714e-02],
         [-3.5564e-02],
         [-1.0087e-02],
         [-5.1371e-02],
         [-1.0459e-01],
         [-3.0278e-02],
         [-6.7102e-02],
         [-6.9973e-02],
         [-3.5698e-02],
         [ 1.2975e-02],
         [-9.5618e-02],
         [-5.2890e-02],
         [-5.7761e-02],
         [-6.0248e-02],
         [-7.0761e-02],
         [-8.5534e-02],
  



dict_keys([])
tensor([[[-3.7390e-02],
         [-8.4189e-03],
         [-5.8928e-02],
         [ 5.5104e-02],
         [-2.8374e-02],
         [-2.8055e-02],
         [ 2.9194e-02],
         [-2.3800e-02],
         [-2.7453e-02],
         [ 2.1985e-02],
         [ 1.8095e-02],
         [-4.2371e-02],
         [ 1.4308e-02],
         [-6.7631e-02],
         [-1.0242e-02],
         [-3.8515e-02],
         [ 2.8486e-02],
         [-4.8275e-02],
         [ 3.2100e-02],
         [-1.1467e-02],
         [ 2.5765e-03],
         [-5.9205e-02],
         [-2.9557e-02],
         [ 2.4960e-03],
         [-1.4580e-02],
         [ 2.9858e-03],
         [-1.4274e-02],
         [-3.9150e-02],
         [-6.8057e-02],
         [-8.6481e-03],
         [ 2.7752e-02],
         [-6.3260e-03],
         [-1.2954e-01],
         [-3.8020e-02],
         [-1.3713e-02],
         [-1.2286e-01],
         [ 1.3976e-02],
         [-9.7412e-02],
         [-8.6562e-02],
         [ 2.0474e-02],
         [-2.5338e-02],
  

TrainOutput(global_step=10, training_loss=132.55665283203126, metrics={'train_runtime': 17.5405, 'train_samples_per_second': 2.166, 'train_steps_per_second': 0.57, 'total_flos': 14886233731296.0, 'train_loss': 132.55665283203126, 'epoch': 2.0})

In [None]:
# Split to evaluate on; falls back to "validation" when trainer.eval_split is not set
split = cfg["trainer"].get("eval_split", "validation")
# NOTE(review): `split=` is presumably a keyword of the project trainer's evaluate() - confirm
val_metrics = trainer.evaluate(split=split)



dict_keys([])
tensor([[[-0.0212],
         [ 0.0040],
         [-0.0200],
         [-0.0118],
         [-0.0003],
         [ 0.0037],
         [ 0.0023],
         [-0.0007],
         [-0.0085],
         [ 0.0100],
         [-0.0061],
         [ 0.0094],
         [-0.0019],
         [-0.0264],
         [-0.0164],
         [ 0.0043],
         [-0.0581],
         [-0.0241],
         [-0.0003],
         [-0.0103],
         [-0.0318],
         [-0.0026],
         [ 0.0056],
         [-0.0014],
         [ 0.0072],
         [-0.0046],
         [ 0.0238],
         [ 0.0081],
         [ 0.0464],
         [-0.0557],
         [-0.0764],
         [-0.0403],
         [-0.0300],
         [-0.0170],
         [ 0.0228],
         [ 0.0389],
         [ 0.0269],
         [-0.0504],
         [-0.0364],
         [ 0.0178],
         [-0.0186],
         [-0.0193],
         [ 0.0076],
         [ 0.0303],
         [ 0.0227],
         [ 0.0061],
         [ 0.0317],
         [-0.0157],
         [ 0.0053],
      

dict_keys([])
tensor([[[ 3.8419e-03],
         [-3.5224e-03],
         [-2.3220e-02],
         [-2.3885e-02],
         [-2.2975e-02],
         [-2.0492e-02],
         [-2.2415e-02],
         [-2.9199e-02],
         [-4.1211e-02],
         [-1.9891e-02],
         [-1.6423e-02],
         [ 5.2393e-03],
         [ 4.2332e-02],
         [ 5.0398e-02],
         [ 2.9794e-02],
         [ 2.3218e-02],
         [ 7.2910e-02],
         [ 3.0658e-02],
         [ 2.5612e-02],
         [ 5.0449e-02],
         [-2.4795e-02],
         [ 2.5792e-03],
         [ 1.2809e-02],
         [-1.3303e-02],
         [ 2.7678e-02],
         [ 1.1922e-02],
         [-2.5402e-02],
         [ 6.0843e-02],
         [ 2.4558e-02],
         [ 1.4528e-02],
         [-4.3954e-03],
         [ 2.7157e-02],
         [-1.2929e-02],
         [ 1.2184e-02],
         [ 3.7409e-02],
         [-6.9825e-03],
         [ 5.2330e-02],
         [ 8.7611e-03],
         [-3.6795e-03],
         [ 6.0751e-02],
         [ 7.2244e-03],
  

In [None]:
# Show the metrics returned by trainer.evaluate for the chosen split
print(f"Eval metrics on {split}:", val_metrics)

Eval metrics on validation: {'eval_loss': 126.6963882446289, 'eval_mse': 2162.2548828125, 'eval_mae': 25.705102920532227, 'eval_rmse': 46.50005250333057, 'eval_runtime': 2.3649, 'eval_samples_per_second': 8.034, 'eval_steps_per_second': 2.114, 'epoch': 2.0}


In [None]:
# Uncomment to delete the Cropped_*.csv files that crop_dataset wrote next to the source data.
# cleanup_cropped_files((cropped_train, cropped_val, cropped_test))