# Testing the data module

This notebook smoke-tests the `ProteinDataModule` class from the `plft.datamodule` module. It shows how to load a YAML configuration, override its data file paths, and build the data module with the configured tokenizer and preprocessing function for four example tasks: sequence regression, sequence classification, token regression, and token classification.


In [1]:
# Import necessary libraries
from pathlib import Path
from typing import Optional
from omegaconf import OmegaConf, DictConfig
from transformers import AutoTokenizer
from plft.datamodule import ProteinDataModule
from plft.configs.registries import PREPROC_REGISTRY

# Functions for loading the datasets

In [2]:
def load_dataset(
    project_root: Path,
    cfg: DictConfig,
):
    """
    Build a ``ProteinDataModule`` from a configuration and return its datasets.

    Args:
        project_root (Path): The root directory of the project. Relative data
            paths found in the config are resolved against it; absolute paths
            are used as given.
        cfg (DictConfig): An already-loaded configuration. Reads
            ``cfg["tokenizer"]["name"]`` and, from ``cfg["data"]``, the required
            ``train_file``, ``val_file`` and ``test_file`` entries plus the
            optional ``preprocess``, ``max_length``, ``sequence_column``,
            ``label_column`` and ``optional_features`` entries.

    Returns:
        Whatever ``ProteinDataModule.get_datasets()`` returns (in this
        notebook's recorded output, a ``DatasetDict`` with ``train``,
        ``validation`` and ``test`` splits).

    Raises:
        ValueError: If ``train_file``, ``val_file`` or ``test_file`` is missing
            or empty in ``cfg["data"]``.
    """
    tok_cfg = cfg["tokenizer"]
    data_cfg = cfg["data"]
    root = Path(project_root)  # also accept a plain string for the root

    # Assume data paths are relative to the **project root**
    def resolve_path(key):
        raw = data_cfg.get(key)
        if not raw:
            # Fail early with the offending key instead of an opaque
            # TypeError from Path(None).
            raise ValueError(
                f"cfg.data.{key} is missing or empty; expected a data file path"
            )
        candidate = Path(raw)
        return str(candidate) if candidate.is_absolute() else str(root / candidate)

    # Resolve all three paths before doing any expensive work (tokenizer download).
    train_file = resolve_path("train_file")
    val_file   = resolve_path("val_file")
    test_file  = resolve_path("test_file")

    print("Train:", train_file)
    print("Valid:", val_file)
    print("Test :", test_file)

    print("Loading tokenizer:", tok_cfg["name"])
    tokenizer = AutoTokenizer.from_pretrained(tok_cfg["name"])
    print("Tokenizer uses_fast:", getattr(tokenizer, "is_fast", False))

    # Preprocess function via registry.
    # NOTE: a name that is not in the registry silently falls back to the
    # identity function, so a typo in `data.preprocess` disables preprocessing.
    preproc_name = data_cfg.get("preprocess", "none")
    preprocess_fn = PREPROC_REGISTRY.get(preproc_name, lambda s: s)
    print(preproc_name, "preprocess function:", preprocess_fn)
    dm = ProteinDataModule(
        train_file=train_file,
        val_file=val_file,
        test_file=test_file,
        tokenizer=tokenizer,
        preprocess_fn=preprocess_fn,
        max_length=data_cfg.get("max_length", 512),
        sequence_column=data_cfg.get("sequence_column", "sequence"),
        label_column=data_cfg.get("label_column", "label"),
        optional_features=data_cfg.get("optional_features", []),
    )

    datasets = dm.get_datasets()
    return datasets

In [3]:
def patch_config(cfg_path: str, **overrides):
    """
    Read a YAML config from ``cfg_path`` and apply keyword overrides to it.

    Each override name is used as a dot-notation key (for example
    ``"data.train_file"``) and its value replaces the existing entry
    (``merge=False``). This notebook also passes a ``pathlib.Path`` as
    ``cfg_path``.

    Returns:
        The loaded configuration with all overrides applied.
    """
    patched = OmegaConf.load(cfg_path)
    for dotted_key in overrides:
        OmegaConf.update(patched, dotted_key, overrides[dotted_key], merge=False)
    return patched

# File paths for the example datasets

In [4]:
# Define project root and config path
# The project root is taken to be the parent of the current working directory.
project_root = Path("..").resolve()
cfg_path = project_root / "plft" / "configs" / "template.yaml"

In [5]:
# Define data files
# Every example maps dotted config keys to paths (relative to the project root)
# of the form data/training_data/<dataset>/<dataset>_<task>_<split>.csv.
example_data = {
    f"{dataset}_{task}": {
        f"data.{field}": f"data/training_data/{dataset}/{dataset}_{task}_{split}.csv"
        for field, split in (
            ("train_file", "train"),
            ("val_file", "validation"),
            ("test_file", "test"),
        )
    }
    for dataset, task in (
        ("gb1", "seq_regression"),
        ("scl", "seq_classification"),
        ("chezod", "token_regression"),
        ("ss", "token_classification"),
    )
}

# Examples of loading datasets with the data module

## Sequence regression dataset

In [6]:

# Point the template config at the gb1 sequence-regression CSVs, then build the datasets.
cfg = patch_config(cfg_path, **example_data["gb1_seq_regression"])
datasets = load_dataset(project_root=project_root, cfg=cfg)

Train: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/gb1/gb1_seq_regression_train.csv
Valid: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/gb1/gb1_seq_regression_validation.csv
Test : /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/gb1/gb1_seq_regression_test.csv
Loading tokenizer: Rostlab/prot_bert
Tokenizer uses_fast: True
protbert preprocess function: <function ProtBert_preprocess at 0x30e1bde40>


Map:   0%|          | 0/2691 [00:00<?, ? examples/s]

Applied preprocess_fn to sequences.
Sample preprocessed sequence: M Q Y K L I L N G K T L K G E T T T E A V D A A T A E K V F K Q Y A N D N G V D G E W T Y D D A T K T F T V T E L E V L F Q G P L D P N S M A T Y E V L C E V A R K L G T D D R E V V L F L L N V F I P Q P T L A Q L I G A L R A L K E E G R L T F P L L A E C L F R A G R R D L L R D L L H L D P R F L E R H L A G T M S Y F S P Y Q L T V L H V D G E L C A R D I R S L I F L S K D T I G S R S T P Q T F L H W V Y C M E N L D L L G P T D V D A L M S M L R S L S R V D L Q R Q V Q T L M G L H L S G P S H S Q H Y R H T P L E H H H H H H
Applied preprocess_fn to sequences.
Sample preprocessed sequence: M Q Y K L I L N G K T L K G E T T T E A V D A A T A E K V F K Q Y A N D N G I D E E W T Y D D A T K T F T T T E L E V L F Q G P L D P N S M A T Y E V L C E V A R K L G T D D R E V V L F L L N V F I P Q P T L A Q L I G A L R A L K E E G R L T F P L L A E C L F R A G R R D L L R D L L H L D P R F L E R H L A G T M S Y F S P Y Q L T V L H 

Map:   0%|          | 0/299 [00:00<?, ? examples/s]

Applied preprocess_fn to sequences.
Sample preprocessed sequence: M Q Y K L I L N G K T L K G E T T T E A V D A A T A E K V F K Q Y A N D N G V E G E W T Y D D A T K T F T V T E L E V L F Q G P L D P N S M A T Y E V L C E V A R K L G T D D R E V V L F L L N V F I P Q P T L A Q L I G A L R A L K E E G R L T F P L L A E C L F R A G R R D L L R D L L H L D P R F L E R H L A G T M S Y F S P Y Q L T V L H V D G E L C A R D I R S L I F L S K D T I G S R S T P Q T F L H W V Y C M E N L D L L G P T D V D A L M S M L R S L S R V D L Q R Q V Q T L M G L H L S G P S H S Q H Y R H T P L E H H H H H H


Map:   0%|          | 0/5743 [00:00<?, ? examples/s]

Applied preprocess_fn to sequences.
Sample preprocessed sequence: M Q Y K L I L N G K T L K G E T T T E A V D A A T A E K V F K Q Y A N D N G A A A E W T Y D D A T K T F T A T E L E V L F Q G P L D P N S M A T Y E V L C E V A R K L G T D D R E V V L F L L N V F I P Q P T L A Q L I G A L R A L K E E G R L T F P L L A E C L F R A G R R D L L R D L L H L D P R F L E R H L A G T M S Y F S P Y Q L T V L H V D G E L C A R D I R S L I F L S K D T I G S R S T P Q T F L H W V Y C M E N L D L L G P T D V D A L M S M L R S L S R V D L Q R Q V Q T L M G L H L S G P S H S Q H Y R H T P L E H H H H H H
Applied preprocess_fn to sequences.
Sample preprocessed sequence: M Q Y K L I L N G K T L K G E T T T E A V D A A T A E K V F K Q Y A N D N G E Y V E W T Y D D A T K T F T N T E L E V L F Q G P L D P N S M A T Y E V L C E V A R K L G T D D R E V V L F L L N V F I P Q P T L A Q L I G A L R A L K E E G R L T F P L L A E C L F R A G R R D L L R D L L H L D P R F L E R H L A G T M S Y F S P Y Q L T V L H 

In [7]:
# Show the splits, feature columns and row counts of the loaded datasets
datasets

DatasetDict({
    train: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2691
    })
    validation: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 299
    })
    test: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 5743
    })
})

In [8]:
# Print selected columns of the first training example
first_train_example = datasets["train"][0]
for column in ("sequence", "label", "input_ids", "attention_mask"):
    print(f"{column}:", first_train_example[column])

sequence: MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALKEEGRLTFPLLAECLFRAGRRDLLRDLLHLDPRFLERHLAGTMSYFSPYQLTVLHVDGELCARDIRSLIFLSKDTIGSRSTPQTFLHWVYCMENLDLLGPTDVDALMSMLRSLSRVDLQRQVQTLMGLHLSGPSHSQHYRHTPLEHHHHHH
label: 1.0
input_ids: [2, 21, 18, 20, 12, 5, 11, 5, 17, 7, 12, 15, 5, 12, 7, 9, 15, 15, 15, 9, 6, 8, 14, 6, 6, 15, 6, 9, 12, 8, 19, 12, 18, 20, 6, 17, 14, 17, 7, 8, 14, 7, 9, 24, 15, 20, 14, 14, 6, 15, 12, 15, 19, 15, 8, 15, 9, 5, 9, 8, 5, 19, 18, 7, 16, 5, 14, 16, 17, 10, 21, 6, 15, 20, 9, 8, 5, 23, 9, 8, 6, 13, 12, 5, 7, 15, 14, 14, 13, 9, 8, 8, 5, 19, 5, 5, 17, 8, 19, 11, 16, 18, 16, 15, 5, 6, 18, 5, 11, 7, 6, 5, 13, 6, 5, 12, 9, 9, 7, 13, 5, 15, 19, 16, 5, 5, 6, 9, 23, 5, 19, 13, 6, 7, 13, 13, 14, 5, 5, 13, 14, 5, 5, 22, 5, 14, 16, 13, 19, 5, 9, 13, 22, 5, 6, 7, 15, 21, 10, 20, 19, 10, 16, 20, 18, 5, 15, 8, 5, 22, 8, 14, 7, 9, 5, 23, 6, 13, 14, 11, 13, 10, 5, 11, 19, 5, 10, 12, 14, 15, 11, 7, 10, 13, 10, 15, 16

## Sequence classification dataset

In [52]:

# Point the template config at the scl sequence-classification CSVs, then build the datasets.
cfg = patch_config(cfg_path, **example_data["scl_seq_classification"])
datasets = load_dataset(project_root=project_root, cfg=cfg)

Train: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/scl/scl_seq_classification_train.csv
Valid: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/scl/scl_seq_classification_validation.csv
Test : /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/scl/scl_seq_classification_test.csv
Loading tokenizer: Rostlab/prot_bert
Tokenizer uses_fast: True


Map:   0%|          | 0/1678 [00:00<?, ? examples/s]

In [53]:
# Show the splits, feature columns and row counts of the loaded datasets
datasets

DatasetDict({
    train: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 9503
    })
    validation: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1678
    })
    test: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2768
    })
})

In [57]:
# Print selected columns of the first training example
first_train_example = datasets["train"][0]
for column in ("sequence", "label", "input_ids", "attention_mask"):
    print(f"{column}:", first_train_example[column])

sequence: MEVLEEPAPGPGGADAAERRGLRRLLLSGFQEELRALLVLAGPAFLAQLMMFLISFISSVFCGHLGKLELDAVTLAIAVINVTGISVGHGLSSACDTLISQTYGSQNLKHVGVILQRGTLILLLCCFPCWALFINTEQILLLFRQDPDVSRLTQTYVMVFIPALPAAFLYTLQVKYLLNQGIVLPQVITGIAANLVNALANYLFLHQLHLGVMGSALANTISQFALAIFLFLYILWRKLHHATWGGWSWECLQDWASFLQLAIPSMLMLCIEWWAYEVGSFLSGILGMVELGAQSITYELAIIVYMIPAGFSVAANVRVGNALGAGNIDQAKKSSAISLIVTELFAVTFCVLLLGCKDLVGYIFTTDWDIVALVAQVVPIYAVSHLFEALACTCGGVLRGTGNQKVGAIVNAIGYYVIGLPIGISLMFVAKLGVIGLWSGIIICSVCQTSCFLVFIARLNWKLACQQAQVHANLKVNVALNSAVSQEPAHPVGPESHGEIMMTDLEKKDEIQLDQQMNQQQALPVHPKDSNKLSGKQLALRRGLLFLGVVLVLVGGILVRVYIRTE
label: 0
input_ids: [2, 21, 9, 8, 5, 9, 9, 16, 6, 16, 7, 16, 7, 7, 6, 14, 6, 6, 9, 13, 13, 7, 5, 13, 13, 5, 5, 5, 10, 7, 19, 18, 9, 9, 5, 13, 6, 5, 5, 8, 5, 6, 7, 16, 6, 19, 5, 6, 18, 5, 21, 21, 19, 5, 11, 10, 19, 11, 10, 10, 8, 19, 23, 7, 22, 5, 7, 12, 5, 9, 5, 14, 6, 8, 15, 5, 6, 11, 6, 8, 11, 17, 8, 15, 7, 11, 10, 8, 7, 22, 7, 5, 10, 10, 6, 23, 14, 15, 5, 11, 10, 18, 15, 20, 7, 10, 18, 17, 5, 12, 22, 8, 7, 8, 11, 5, 

## Token regression dataset

In [69]:

# Point the template config at the chezod token-regression CSVs, then build the datasets.
cfg = patch_config(cfg_path, **example_data["chezod_token_regression"])
datasets = load_dataset(project_root=project_root, cfg=cfg)

Train: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/chezod_token_regression_train.csv
Valid: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/chezod_token_regression_validation.csv
Test : /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/chezod/chezod_token_regression_test.csv
Loading tokenizer: Rostlab/prot_bert
Tokenizer uses_fast: True


Map:   0%|          | 0/176 [00:00<?, ? examples/s]

In [70]:
# Show the splits, feature columns and row counts of the loaded datasets
datasets

DatasetDict({
    train: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 821
    })
    validation: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 176
    })
    test: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 177
    })
})

In [71]:
# Print selected columns of the first training example
first_train_example = datasets["train"][0]
for column in ("sequence", "label", "input_ids", "attention_mask"):
    print(f"{column}:", first_train_example[column])

sequence: RTNQAGLELIGNAEGCRRDPYMCPAGVWTDGIGNTHGVTPGVRKTDQQIAADWEKNILIAERCINQHFRGKDMPDNAFSAMTSAAFNMGCNSLRTYYSKARGMRVETSIHKWAQKGEWVNMCNHLPDFVNSNGVPLRGLKIRREKERQLCLTGLVNEHHHHHH
label: [999, 11.949, 14.319, 13.721, 13.014, 13.399, 13.538, 14.388, 14.856, 13.593, 12.446, 12.227, 14.036, 14.011, 13.477, 13.963, 14.782, 15.049, 13.729, 13.556, 13.904, 15.265, 14.584, 13.669, 10.621, 10.82, 10.562, 9.88, 7.584, 3.22, 3.351, 4.162, 3.65, 4.064, 4.426, 5.057, 4.362, 1.636, 2.463, 2.236, 3.227, 3.643, 6.034, 10.755, 13.402, 14.742, 13.645, 13.884, 14.788, 14.961, 13.882, 13.417, 14.173, 15.023, 13.969, 13.908, 14.375, 14.955, 14.867, 14.266, 14.236, 14.468, 14.625, 14.812, 14.093, 13.88, 13.775, 14.228, 13.578, 14.462, 14.485, 14.922, 13.134, 11.731, 12.771, 13.29, 14.0, 10.49, 10.234, 9.71, 11.881, 12.987, 13.881, 14.644, 14.621, 14.578, 14.89, 14.451, 14.736, 14.761, 15.084, 14.821, 14.566, 14.682, 15.186, 14.709, 14.646, 14.646, 14.844, 14.212, 11.829, 12.763, 12.456, 14.748, 12.448, 13.049, 1

## Token classification dataset

In [61]:
# Point the template config at the ss token-classification CSVs, then build the datasets.
cfg = patch_config(cfg_path, **example_data["ss_token_classification"])
datasets = load_dataset(project_root=project_root, cfg=cfg)

Train: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/ss/ss_token_classification_train.csv
Valid: /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/ss/ss_token_classification_validation.csv
Test : /Users/leerangyang/Documents/Workspace/Projects/plft/data/training_data/ss/ss_token_classification_test.csv
Loading tokenizer: Rostlab/prot_bert
Tokenizer uses_fast: True


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/9712 [00:00<?, ? examples/s]

Map:   0%|          | 0/1080 [00:00<?, ? examples/s]

Map:   0%|          | 0/364 [00:00<?, ? examples/s]

In [62]:
# Show the splits, feature columns and row counts of the loaded datasets
datasets

DatasetDict({
    train: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 9712
    })
    validation: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1080
    })
    test: Dataset({
        features: ['sequence', 'label', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 364
    })
})

In [63]:
# Print selected columns of the first training example
first_train_example = datasets["train"][0]
for column in ("sequence", "label", "input_ids", "attention_mask"):
    print(f"{column}:", first_train_example[column])

sequence: VTKPTIAAVGGYAMNNGTGTTLYTKAADTRRSTGSTTKIMTAKVVLAQSNLNLDAKVTIQKAYSDYVVANNASQAHLIVGDKVTVRQLLYGLMLPSGCDAAYALADKYGSGSTRAARVKSFIGKMNTAATNLGLHNTHFDSFDGIGNGANYSTPRDLTKIASSAMKNSTFRTVVKTKAYTAKTVTKTGSIRTMDTWKNTNGLLSSYSGAIGVKTGAGPEAKYCLVFAATRGGKTVIGTVLASTSIPARESDATKIMNYGFAL
label: [999, 999, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1