In [1]:
%reload_ext autoreload
%autoreload 2

import gc
import os

import yaml

# Order CUDA devices by PCI bus id and expose only device "0" to this process.
# Both variables are assigned before `import torch` further down.
# NOTE(review): presumably so the CUDA runtime sees them at initialization — confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from typing import List, Optional, Tuple, Union

import datasets
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from torch.nn import CrossEntropyLoss
from transformers import (
    DataCollatorForTokenClassification,
    PretrainedConfig,
    PreTrainedModel,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import TokenClassifierOutput

from plms import (
    PLMConfig,
    ProteinLanguageModelPredictor,
    auto_model,
    auto_tokenizer,
)
from plms.models.plm_token_classification import PLMConfigForTokenClassification, PLMForTokenClassification

In [27]:
# The toy training set consists of only two distinct (sequence, label string)
# pairs; label characters are drawn from "a", "b" and "c".
_SHORT_EXAMPLE = ("ACDEFGHIKLMNPQRSTVWYOOO", "aaaaaaaabbbbbbbbccccccc")
_LONG_EXAMPLE = (
    "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWYOOO",
    "aaaaaaaabbbbbbbbccccccaaaaaaaabbbbbbbbccccc",
)

# Ten rows in a fixed order, mixing the short and the long example.
dataset_strings = [
    _SHORT_EXAMPLE,
    _LONG_EXAMPLE,
    _SHORT_EXAMPLE,
    _LONG_EXAMPLE,
    _SHORT_EXAMPLE,
    _SHORT_EXAMPLE,
    _LONG_EXAMPLE,
    _LONG_EXAMPLE,
    _SHORT_EXAMPLE,
    _LONG_EXAMPLE,
]

# Experiment configuration as inline YAML. Each top-level section is unpacked
# as keyword arguments elsewhere in the notebook (`model`, `training_args`,
# `lora`, `data_collator`).
#
# Only `eval_strategy` is set under `training_args`: `evaluation_strategy` is
# its deprecated predecessor in transformers, and specifying both is redundant
# at best and an unexpected-keyword error on releases that dropped the old name.
config_yaml = """
metadata:
  identifier: test123
model:
  encoder_name_or_path: Rostlab/ProstT5
#   encoder_name_or_path: Rostlab/prot_t5_xl_uniref50
  num_labels: 3
  classifier_dropout: 0.1
  hidden_size: 1024
training_args:
  output_dir: ./tmp/models/checkpoints
  run_name: test
#   report_to: None
  learning_rate: 0.0001
  per_device_train_batch_size: 6
  per_device_eval_batch_size: 6
  num_train_epochs: 100
  logging_steps: 1
  logging_strategy: steps
  eval_steps: 1
  eval_strategy: steps
  eval_on_start: true
  batch_eval_metrics: false
  save_strategy: steps
  save_steps: 300
  save_total_limit: 5
  remove_unused_columns: true
  label_names: ['labels']
  seed: 42
  lr_scheduler_type: cosine
  warmup_steps: 0
lora:
  inference_mode: false
  r: 8
  lora_alpha: 16
  lora_dropout: 0.05
  use_rslora: false
  use_dora: false
  target_modules: ['q', 'v']
  bias: none
data_collator:
  padding: true
  pad_to_multiple_of: 8
extender:
  name: Rostlab/ProstT5
  use_extender: false
"""

# Map each label character to its integer class id: "a" -> 0, "b" -> 1, "c" -> 2.
LABEL_ENCODING = {symbol: class_id for class_id, symbol in enumerate("abc")}

# Parse the inline YAML into a nested dict of settings.
config = yaml.safe_load(config_yaml)

# Tokenizer for the encoder checkpoint named in the model section.
tokenizer = auto_tokenizer(config["model"]["encoder_name_or_path"])

# Split the (sequence, label string) pairs into two parallel columns and map
# every label character to its integer class id.
dataset_dict = {
    "sequence": [sequence for sequence, _ in dataset_strings],
    "labels": [
        [LABEL_ENCODING[label_char] for label_char in label_string]
        for _, label_string in dataset_strings
    ],
}
dataset = datasets.Dataset.from_dict(dataset_dict)

# NOTE(review): labels are built from the raw strings, while input_ids come
# from the tokenizer; if the tokenizer adds special tokens the two columns
# differ in length per row — confirm the alignment is intended.
tokens = tokenizer.encode(dataset_dict["sequence"])

# Attach the tokenizer outputs as additional dataset columns.
for column_name in ("input_ids", "attention_mask"):
    dataset = dataset.add_column(column_name, tokens[column_name])

In [3]:
# Build the token-classification model from the `model` config section.
model_config = PLMConfigForTokenClassification(**config["model"])
base_model = PLMForTokenClassification(model_config)

# LoRA settings come from the `lora` config section; the modules reported by
# the base model's `get_modules_to_save()` are passed as `modules_to_save`.
lora_config = LoraConfig(modules_to_save=base_model.get_modules_to_save(), **config["lora"])

# `model` is the PEFT-wrapped model used by all later cells.
model = get_peft_model(base_model, lora_config)

In [None]:
def _release_cached_memory():
    """Run the garbage collector, then empty the CUDA and MPS caches where available."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


# Collator options (padding, pad_to_multiple_of) come from the config.
collator_options = config["data_collator"]
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer.get_tokenizer(), **collator_options)

training_args = TrainingArguments(**config["training_args"])

# The same toy dataset serves as both the training and the evaluation split.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
    eval_dataset=dataset,
)

# Free cached memory before training starts ...
_release_cached_memory()

trainer.train()
trainer.evaluate()

# ... and again once training and the final evaluation are done.
_release_cached_memory()

In [97]:
# Switch the model to evaluation mode before running the test examples.
model.eval()

test_sequences = [
    "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
    "ACDEFGHIKLMNPQRSTVWYOOO",
]

# The first example gets an all-zero label list two positions longer than its
# sequence.
# NOTE(review): the +2 presumably covers special tokens added by the
# tokenizer — confirm.
extra_label_positions = 2

test_dataset_dict = {
    "sequence": test_sequences,
    "labels": [
        [0] * (len(test_sequences[0]) + extra_label_positions),
        [LABEL_ENCODING[label_char] for label_char in "aaaaaaaaabbbbbbbbcccccccc"],
    ],
}

test_dataset = datasets.Dataset.from_dict(test_dataset_dict)

test_tokens = tokenizer.encode(test_dataset_dict["sequence"])

# Attach the tokenizer outputs as additional dataset columns.
for column_name in ("input_ids", "attention_mask"):
    test_dataset = test_dataset.add_column(column_name, test_tokens[column_name])

In [None]:
# Print name, length and content of every column of the second test example.
second_example = test_dataset[1]
for column_name, column_value in second_example.items():
    print(column_name, len(column_value), column_value)


In [None]:
# Run a single test example through the model as a batch of size one.
index = 1
example = test_dataset[index]
batch = {
    field: torch.tensor(example[field]).to(model.device).unsqueeze(0)
    for field in ("input_ids", "attention_mask", "labels")
}
with torch.no_grad():
    outputs = model(**batch)

# Inverse of LABEL_ENCODING: integer class id -> label character.
LABEL_DECODING = dict(zip(LABEL_ENCODING.values(), LABEL_ENCODING.keys()))

# Take the highest-scoring class per position and print the decoded labels
# of the single batch element as one string.
predictions = outputs.logits.argmax(dim=-1).tolist()
print("".join(LABEL_DECODING[pred] for pred in predictions[0]))

In [None]:
# Show the class-id -> label-character mapping as the cell output.
LABEL_DECODING