In [2]:
%reload_ext autoreload
%autoreload 2

from _header_model import *

# device = "cpu"

# Report whether the MPS backend is usable and which device is selected.
# NOTE(review): `device` is not defined in this cell; presumably it is provided
# by the `_header_model` star import - confirm.
mps_available = torch.backends.mps.is_available()
print("MPS Availible:\t", mps_available)
print(f"Using device:\t {device}")

MPS Availible:	 False
Using device:	 cuda:0


---
### DataLoader

<!-- Create HF Dataset

```py
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'cath_id','temperature', 'replica'],
        num_rows: n
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'cath_id','temperature', 'replica'],
        num_rows: n
    })
})
```

```rust
input_ids: Amino Acid Sequence
attention_mask: Padding Mask
cath_id: CATH identifier, e.g. 1a0rP01
replica: replica number in {0, 1, 2, 3, 4}
temperature: temperature of trajectory in {320, 348, 379, 413, 450}
sequence: original sequence
pssm: PSSM as numpy array dim(20, L)
``` -->

In [None]:
# Tokenizer for the checkpoint named by BASE_MODEL.
t5_tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=BASE_MODEL, do_lower_case=False, use_fast=True, legacy=False
)

# The PSSM dictionary lives one directory above the notebook; `.item()` unwraps
# the single Python object stored in the loaded array.
# NOTE(review): allow_pickle=True unpickles the file content on load - only use
# it with files from a trusted source.
pssm_file = os.path.join('../', FILE_PATHS["pssm"], "dict_pssm.npy")
dict_pssm = np.load(pssm_file, allow_pickle=True).item()


def pssm_to_hf_dataset(dict_pssm: dict, tokenizer: T5Tokenizer) -> DatasetDict:
    """Convert a PSSM dictionary into a tokenized Hugging Face ``DatasetDict``.

    Args:
        dict_pssm: Mapping from a ``"cath_id|temperature|replica|sequence"`` key
            to a PSSM object exposing ``.tolist()``.
        tokenizer: Tokenizer applied to the space-separated sequences.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"test"`` entries. Both entries
        are built from the same dataset; no real split is performed yet. Each
        entry holds the tokenizer outputs plus the columns ``cath_id``,
        ``temperature``, ``replica``, ``sequence``, ``sequence_processed`` and
        ``pssm``.

    Raises:
        ValueError: If a key has fewer than four ``|``-separated fields, or if
            its temperature field cannot be parsed as an integer.
    """
    # todo add train/test split
    ds_dict = {"cath_id": [], "temperature": [], "replica": [], "sequence": [], "sequence_processed": [], "pssm": []}

    for key, value in dict_pssm.items():
        parts = key.split("|")
        if len(parts) < 4:
            # Fail with the offending key instead of an opaque IndexError below.
            raise ValueError(
                f"Malformed PSSM key {key!r}: expected 'cath_id|temperature|replica|sequence'"
            )
        cath_id, temperature, replica, sequence = parts[:4]
        ds_dict["cath_id"].append(cath_id)
        ds_dict["temperature"].append(int(temperature))
        ds_dict["replica"].append(replica)  # kept as a string, unlike temperature
        ds_dict["sequence"].append(sequence)
        # Insert a space between residues; presumably so the tokenizer treats
        # every amino acid as its own token - confirm against the tokenizer vocab.
        ds_dict["sequence_processed"].append(" ".join(sequence))
        ds_dict["pssm"].append(value.tolist())

    # No padding or truncation here, so the encoded sequences keep their own lengths.
    # NOTE(review): with truncation=False, max_length presumably does not shorten
    # anything - confirm whether it is still needed.
    tokenized_sequences = tokenizer(
        text=ds_dict["sequence_processed"],
        padding=False,
        truncation=False,
        max_length=512,
    )
    ds = Dataset.from_dict(tokenized_sequences)
    for key, value in ds_dict.items():
        ds = ds.add_column(key, value)

    # ds = ds.map(lambda examples: {'pssm': [torch.tensor(pssm) for pssm in examples['pssm']]}, batched=True)

    return DatasetDict({"train": ds, "test": ds})


# Build the splits, drop the metadata columns, and expose the PSSM column
# under the name `labels`.
ds = (
    pssm_to_hf_dataset(dict_pssm=dict_pssm, tokenizer=t5_tokenizer)
    .remove_columns(["cath_id", "replica", "sequence", 'sequence_processed', "temperature"])
    .rename_column("pssm", "labels")
)
# ds = ds.remove_columns("labels")

# Keep only rows 0 and 49 of each split.
for split_name in ("train", "test"):
    ds[split_name] = ds[split_name].select([0, 49])

print(ds)

# i = 0
# print(len(ds["train"]["attention_mask"][i]), ":", *ds["train"]["input_ids"][i])
# print(len(ds["train"]["attention_mask"][i]), ':', *ds["train"]["attention_mask"][i])
# display(pd.DataFrame(ds["train"]["labels"][i]))
# print(type(torch.tensor(ds["train"]["labels"][i])))

---
### Model Loading and LoRA

In [None]:
# Load the base encoder together with the loading report, which lists the keys
# that were missing from the checkpoint.
t5_base_model, loading_info = T5EncoderModelForPssmGeneration.from_pretrained(
    pretrained_model_name_or_path=BASE_MODEL,
    output_loading_info=True,
    # device_map=device,
    # load_in_8bit=False,
    # custom_dropout_rate=0.1,
)

# `missing_keys` holds state-dict entries (parameter names such as
# "<module>.weight"), whereas PEFT matches `modules_to_save` against module
# names. Strip the trailing parameter name so those modules are matched, kept
# trainable and stored with the adapter.
# NOTE(review): confirm via print_trainable_parameters() below that the PSSM
# head is now counted as trainable.
modules_to_save = sorted({key.rsplit(".", 1)[0] for key in loading_info["missing_keys"]})

lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q", "k", "v", "o"],
    bias="none",
    modules_to_save=modules_to_save,
)

t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()
print(loading_info)

---
### Model Training

In [None]:
# Collator that batches the tokenized sequences together with their PSSM labels.
data_collator = DataCollatorForT5Pssm(
    tokenizer=t5_tokenizer,
    padding=True,
    max_length=512,
)

# Hyperparameters and logging/eval/save intervals come from TRAINING_CONFIG.
training_args = TrainingArguments(
    output_dir=FILE_PATHS["models"],
    learning_rate=TRAINING_CONFIG["learning_rate"],
    per_device_train_batch_size=TRAINING_CONFIG["batch_size"],
    per_device_eval_batch_size=TRAINING_CONFIG["batch_size"] * 2,
    num_train_epochs=TRAINING_CONFIG["num_epochs"],
    logging_steps=TRAINING_CONFIG["logging_steps"],
    evaluation_strategy="steps", # use eval_strategy
    eval_steps=TRAINING_CONFIG["eval_steps"],
    save_strategy="steps",
    save_steps=TRAINING_CONFIG["save_steps"],
    remove_unused_columns=True,
    # label_names=["labels"],
    seed=SEED,
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=ds["train"],
    # Step-based evaluation is enabled above and `compute_metrics` is set, so
    # the Trainer needs an evaluation set. The "test" split currently holds
    # the same rows as "train" until a real split is added.
    eval_dataset=ds["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# print(*ds['train']['input_ids'], sep="\n")

# for name, param in t5_base_model.named_parameters():
#     print(name)

# t5_base_model.encoder.block[0].layer[0].SelfAttention.q.weight

In [None]:
# Lift the pandas display limits so frames shown during training are not truncated.
pd.options.display.max_columns = None
pd.options.display.max_rows = None

# Collect unreferenced Python objects, then empty the cache of whichever
# accelerator backend is available.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

trainer.train()

In [None]:
# parent_class_name = T5EncoderModelForPssmGeneration.__bases__[0].__name__
# parent_class_name