In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import src.config as model_config
from src.utils import get_project_root_path
import src.data
import src.metrics

from src.model import (
    T5EncoderModelForTokenClassification,
)

import gc
import copy
import random
import tqdm

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback
)

import peft
from peft import (
    LoraConfig,
)

In [3]:
ROOT = get_project_root_path()
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))

USE_CRF = model_config.use_crf

EXPERT = model_config.selected_expert
MODEL_VERRSION = model_config.model_version

adapter_location = f'/models/moe_v{MODEL_VERRSION}_linear_expert_{EXPERT}'

SEED = model_config.seed
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

print("Base Model:\t", model_config.base_model_name)
print("MPS Availible:\t", torch.backends.mps.is_available())
print("Path:\t\t", ROOT)
print(f"Using device:\t {device}")

Base Model:	 Rostlab/ProstT5
MPS Availible:	 True
Path:		 /Users/finnlueth/Developer/gits/SignalGPT
Using device:	 mps


In [4]:
t5_tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=model_config.base_model_name,
    do_lower_case=False,
    use_fast=True,
    legacy=False,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
model_config.select_encoding_type.keys()

dict_keys(['ALL', 'NO_SP', 'SP', 'LIPO', 'TAT'])

In [6]:
# df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + model_config.dataset_name))
# df_data['Label'].apply(lambda x: len(x)).describe()

In [7]:
# FASTA_FILENAME = model_config.dataset_name
FASTA_FILENAME = '6_SignalP_6.0_Training_set_testing.fasta'
annotations_name = ['Label'] #+ ['Type'] # Choose Type or Label

df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

dataset_signalp_type_splits = {}

for sequence_type in model_config.select_encoding_type.keys():
    dataset_signalp = src.data.create_datasets(
        splits=model_config.splits,
        tokenizer=t5_tokenizer,
        data=df_data,
        annotations_name=annotations_name,
        dataset_size=model_config.dataset_size,
        sequence_type=sequence_type
        )
    dataset_signalp_type_splits.update({sequence_type: dataset_signalp})

del df_data

In [8]:
dataset_signalp = dataset_signalp_type_splits[EXPERT]
print(EXPERT)
# print(dataset_signalp_type_splits)
print(dataset_signalp)

ALL
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 6
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3
    })
})


In [9]:
# index = 1
# print(len(dataset_signalp['train']['input_ids'][index]), dataset_signalp['train']['input_ids'][index])
# print(len(dataset_signalp['train']['labels'][index]), dataset_signalp['train']['labels'][index])
# print(len(dataset_signalp['train']['attention_mask'][index]), dataset_signalp['train']['attention_mask'][index])

In [10]:
# for x in range(3):
#     print(len(dataset_signalp['valid'][x]['labels']))
#     print(*dataset_signalp['valid'][x]['labels'])

In [11]:
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=model_config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(model_config.select_decoding_type[EXPERT]),
    custom_dropout_rate=0.1,
    use_crf=USE_CRF,
)

tmp_lin = nn.Linear(
    in_features=t5_base_model.config.hidden_size,
    out_features=t5_base_model.custom_num_labels
)
t5_base_model.custom_classifier.weight = tmp_lin.weight
t5_base_model.custom_classifier.bias = tmp_lin.bias

t5_base_model.crf._constraint_mask = torch.nn.Parameter(t5_base_model.crf.tensor_constraint_mask, requires_grad=False)

if USE_CRF:
    t5_base_model.crf.reset_parameters()
    modules_to_save = ['custom_classifier', 'crf']
else:
    modules_to_save = ['custom_classifier']

lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q', 'k', 'v', 'o'],
    bias="none",
    modules_to_save=modules_to_save,
)

t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

  return self.fget.__get__(instance, owner)()
Some weights of T5EncoderModelForTokenClassification were not initialized from the model checkpoint at Rostlab/ProstT5 and are newly initialized: ['crf._constraint_mask', 'crf.end_transitions', 'crf.start_transitions', 'crf.transitions', 'custom_classifier.bias', 'custom_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 3,946,798 || all params: 1,212,111,150 || trainable%: 0.3256135379993823


In [12]:
crf = src.model.ConditionalRandomField(
                num_tags=7,
                constraints=src.config.allowed_transitions_encoded,
                )

In [13]:
aaa = torch.nn.Parameter(torch.tensor([[1., 0., 1., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 1., 1., 0., 0., 0., 0., 1.],
        [0., 0., 1., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 1., 0., 0.],
        [1., 1., 0., 1., 1., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]]), requires_grad=False)

In [21]:
t5_base_model.crf.original_module._constraint_mask, t5_base_model.crf.original_module.tensor_constraint_mask

(Parameter containing:
 tensor([[1., 0., 1., 0., 0., 0., 0., 0., 1.],
         [0., 1., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 1., 0., 0.],
         [1., 1., 0., 1., 1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='mps:0',
        requires_grad=True),
 Parameter containing:
 tensor([[1., 0., 1., 0., 0., 0., 0., 0., 1.],
         [0., 1., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 1., 0., 0.],
         [1., 1., 0., 1., 1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='mps:0',
        requires_grad=True))

In [22]:
t5_base_model.crf.modules_to_save.default._constraint_mask, t5_base_model.crf.modules_to_save.default.tensor_constraint_mask

(Parameter containing:
 tensor([[1., 0., 1., 0., 0., 0., 0., 0., 1.],
         [0., 1., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 1., 0., 0.],
         [1., 1., 0., 1., 1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='mps:0',
        requires_grad=True),
 tensor([[1., 0., 1., 0., 0., 0., 0., 0., 1.],
         [0., 1., 0., 1., 0., 0., 0., 0., 0.],
         [1., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 1., 1., 0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 1., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 1., 0., 0.],
         [1., 1., 0., 1., 1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

In [16]:
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

training_args = TrainingArguments(
    output_dir=ROOT+'/models/checkpoints',
    learning_rate=model_config.lr,
    per_device_train_batch_size=model_config.batch_size,
    per_device_eval_batch_size=model_config.batch_size,
    num_train_epochs=model_config.num_epochs,
    logging_steps=model_config.logging_steps,
    evaluation_strategy="steps",
    eval_steps=model_config.eval_steps,
    # weight_decay=0.01,
    # gradient_accumulation_steps=accum,
    save_strategy="steps",
    save_steps=model_config.save_steps,
    # save_total_limit=5,
    # load_best_model_at_end=True,
    # fp16=True,
    # deepspeed=deepspeed_config,
    remove_unused_columns=False,
    label_names=['labels'],
    seed=42,
    # debug="underflow_overflow",
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
    data_collator=data_collator,
    compute_metrics=src.metrics.compute_metrics,
)

In [19]:
initial_validation=trainer.evaluate()
added_initial_validation = False
print(initial_validation)

loc("cast"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/4e1473ee-9f66-11ee-8daf-cedaeb4cabe2/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":745:0)): error: 'anec.gain_offset_control' op result #0 must be 4D/5D memref of 16-bit float or 8-bit signed integer or 8-bit unsigned integer values, but got 'memref<1x3x1x70xi1>'


  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.4429473578929901, 'eval_accuracy_metric': 0.12560386473429952, 'eval_precision_metric': 0.12560386473429952, 'eval_recall_metric': 0.12560386473429952, 'eval_f1_metric': 0.12560386473429952, 'eval_matthews_correlation': -0.00443142609272187, 'eval_confusion_matrix': array([[ 0, 12,  1,  6, 14, 12,  3],
       [ 0,  0,  2,  2, 13,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [12, 12,  3, 25, 19, 20,  6],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  6, 14,  1,  0],
       [ 0,  0,  0,  0,  0,  0,  0]]), 'eval_runtime': 24.9334, 'eval_samples_per_second': 0.12, 'eval_steps_per_second': 0.04}


In [None]:
gc.collect()
torch.cuda.empty_cache()
# torch.mps.empty_cache()

trainer.train()

In [None]:
final_validation=trainer.evaluate()
print(final_validation)

In [None]:
if 'training_log' not in locals():
    training_log = pd.DataFrame(trainer.state.log_history)
else:
    training_log = pd.concat([training_log, pd.DataFrame(trainer.state.log_history)], ignore_index=True)
if not added_initial_validation:
    added_initial_validation = True
    training_log = pd.concat([pd.DataFrame([initial_validation]), training_log], ignore_index=True)
display(training_log)

In [None]:
if 'eval_confusion_matrix' in training_log.columns:
    training_log['eval_confusion_matrix'] = training_log['eval_confusion_matrix'].apply(lambda x: x.tolist() if type(x)==np.ndarray else None)
t5_lora_model.save_pretrained(ROOT + adapter_location)
training_log.to_csv(ROOT + adapter_location + '/training_log.csv', index=False)
training_log.to_parquet(ROOT + adapter_location + '/training_log.parquet')

---

In [None]:

# training_log = pd.read_parquet(ROOT + f'/models/moe_v{MODEL_VERRSION}_linear_expert_{EXPERT}/training_log.parquet')
# adapter_location = f'/models/moe_v{MODEL_VERRSION}_expert_{EXPERT}'
training_log = pd.read_parquet(ROOT + adapter_location + '/training_log.parquet')

In [None]:
src.model_new.confusion_matrix_plot(
    np.array(training_log['eval_confusion_matrix'][training_log['eval_confusion_matrix'].notnull()].iloc[-1].tolist()),
    model_config.select_decoding_type[EXPERT]
    )
plt.savefig(ROOT + adapter_location + '/fig_cm.jpg', dpi=400)

src.model_new.loss_plot(training_log)
plt.savefig(ROOT + adapter_location + '/fig_loss.jpg', dpi=400)

---

In [None]:
# [x for x in t5_lora_model.custom_classifier.modules_to_save.default.named_parameters()]

In [None]:
# display(pd.Series([item for row in dataset_signalp['train']['labels'] for item in row]).value_counts())
# display(pd.Series([item for row in dataset_signalp['valid']['labels'] for item in row]).value_counts())
# display(pd.Series([item for row in dataset_signalp['test']['labels'] for item in row]).value_counts())

# src.model_new.make_confusion_matrix(
#     training_log['eval_confusion_matrix'].iloc[-1],
#     model_config.select_decoding_type[EXPERT])

In [None]:
# _ds_index = 3250
# _ds_index = 3250
# _ds_type = 'test'
# USE_CRF = False

# _input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'][:-1])
# _labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels'] + [-100]]).to(device)
# _attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)

# _labels_test_decoded = [model_config.label_decoding[x] for x in _labels_test.tolist()[0][:-1]]

# print('Iput IDs:\t', _input_ids_test)
# print('Labels:\t\t', *_labels_test.tolist()[0])
# print('Labels Decoded:\t', *_labels_test_decoded)
# print('Attention Mask:\t', *_attention_mask_test.tolist()[0])
# print('----')

# preds = src.model_new.predict_model(
#     sequence=_input_ids_test,
#     tokenizer=t5_tokenizer,
#     model=t5_lora_model,
#     labels=_labels_test,
#     attention_mask=_attention_mask_test,
#     device=device,
#     viterbi_decoding=USE_CRF,
#     )

# _result = src.model_new.translate_logits(
#     logits=preds.logits,
#     viterbi_decoding=USE_CRF,
#     decoding=model_config.label_decoding
#     )

# print('Result: \t',* _result)