In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the project-local `src.*` package) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import src.config as model_config
from src.utils import get_project_root_path
import src.data
import src.metrics
import src.plot

from src.model import (
    ProstT5EncoderModelForSequenceClassification,
    ProstT5EncoderModelTokenClassificationCRF,
    ProstT5EncoderModelTokenClassificationLinear,
    ProtT5EncoderModelForSequenceClassification,
    ProtT5EncoderModelTokenClassificationCRF,
    ProtT5EncoderModelTokenClassificationLinear,
)

import gc
import random
import tqdm

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    set_seed
)

import peft
from peft import (
    LoraConfig,
)

In [3]:
# --- Run configuration -------------------------------------------------------
ROOT = get_project_root_path()

# Device preference: first CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device_name = 'cuda:0'
elif torch.backends.mps.is_available():
    device_name = 'mps'
else:
    device_name = 'cpu'
device = torch.device(device_name)

# Which expert is trained and which model version it belongs to.
EXPERT = model_config.selected_expert
MODEL_VERRSION = model_config.model_version

# Adapter directory, relative to ROOT (later cells use `ROOT + adapter_location`).
adapter_location = f'/models/moe_v{MODEL_VERRSION}_CRF_expert_{EXPERT}'

# --- Reproducibility ---------------------------------------------------------
# Seed every RNG source the notebook touches, in the same order as before.
SEED = model_config.seed
for seed_fn in (torch.manual_seed, random.seed, np.random.seed, set_seed):
    seed_fn(SEED)

# --- Banner ------------------------------------------------------------------
print("Base Model:\t", model_config.base_model_name)
print("MPS Availible:\t", torch.backends.mps.is_available())
print("Path:\t\t", ROOT)
print(f"Using device:\t {device}")

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS Availible:	 False
Path:		 /home/ec2-user/src/SignalGPT
Using device:	 cuda:0


In [4]:
# Tokenizer for the configured base checkpoint; the options are collected in
# one place so they are easy to review and tweak.
tokenizer_options = dict(
    do_lower_case=False,
    use_fast=True,
    legacy=False,
)
t5_tokenizer = T5Tokenizer.from_pretrained(
    pretrained_model_name_or_path=model_config.base_model_name,
    **tokenizer_options,
)

In [5]:
# Raw FASTA file to load (configured in src.config).
FASTA_FILENAME = model_config.dataset_name
# Alternative file used previously: '6_SignalP_6.0_Training_set_testing.fasta'

# Annotation column(s) used as targets: 'Label' here, 'Type' is the alternative.
annotations_name = ['Label']

# Parse the FASTA file and turn it into the processed table the dataset
# builder consumes.
df_data = src.data.process(
    src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME)
)

# Build one tokenized dataset per configured sequence type.
dataset_signalp_type_splits = {}
for sequence_type in model_config.select_encoding_type.keys():
    dataset_signalp_type_splits[sequence_type] = src.data.create_datasets(
        splits=model_config.splits,
        tokenizer=t5_tokenizer,
        data=df_data,
        annotations_name=annotations_name,
        dataset_size=model_config.dataset_size,
        sequence_type=sequence_type,
    )

# The processed table is not needed once the datasets exist.
del df_data

# This notebook trains only the selected expert's dataset.
dataset_signalp = dataset_signalp_type_splits[EXPERT]
print(EXPERT)
print(dataset_signalp)

ALL
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 13288
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 7002
    })
})


In [6]:
# Load the encoder with a CRF token-classification head. The number of output
# labels is the size of the selected expert's decoding table.
# NOTE(review): a ProstT5* class is loaded from `model_config.base_model_name`;
# the recorded run used the ProtT5 checkpoint and
# ProtT5EncoderModelTokenClassificationCRF is also imported — confirm the class
# choice is intended.
t5_base_model = ProstT5EncoderModelTokenClassificationCRF.from_pretrained(
    pretrained_model_name_or_path=model_config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(model_config.select_decoding_type[EXPERT]),
    custom_dropout_rate=0.1,
)

# The classifier head is not part of the checkpoint, so replace its parameters
# with those of a freshly constructed nn.Linear of the same shape.
tmp_lin = nn.Linear(
    in_features=t5_base_model.config.hidden_size,
    out_features=t5_base_model.custom_num_labels
)
t5_base_model.custom_classifier.weight = tmp_lin.weight
t5_base_model.custom_classifier.bias = tmp_lin.bias

# Store the CRF constraint mask as a non-trainable Parameter, then
# re-initialise the CRF parameters.
t5_base_model.crf._constraint_mask = torch.nn.Parameter(t5_base_model.crf.tensor_constraint_mask, requires_grad=False)
t5_base_model.crf.reset_parameters()

# Modules (besides the LoRA adapters) that PEFT marks trainable and stores in
# the adapter checkpoint. A previous `['custom_classifier', 'crf']` assignment
# was immediately overwritten by this value and has been removed as dead code;
# the effective configuration is unchanged.
# NOTE(review): because 'crf' is not listed, the CRF parameters are expected to
# stay frozen at their re-initialised values and not be written by
# `save_pretrained` — confirm this is intended. Adding 'crf' would also make
# PEFT mark `crf._constraint_mask` trainable, so verify before changing.
modules_to_save = ['custom_classifier']

lora_config = LoraConfig(
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=['q', 'k', 'v', 'o'],  # LoRA is applied to modules with these names
    bias="none",
    modules_to_save=modules_to_save,
)

t5_lora_model = peft.get_peft_model(t5_base_model, lora_config)
t5_lora_model.print_trainable_parameters()

  return self.fget.__get__(instance, owner)()
Some weights of ProstT5EncoderModelTokenClassificationCRF were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['crf._constraint_mask', 'crf.end_transitions', 'crf.start_transitions', 'crf.transitions', 'custom_classifier.bias', 'custom_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 3,946,510 || all params: 1,212,088,478 || trainable%: 0.32559586792805073


In [7]:
# Pads each batch (inputs and labels) to a common length.
# NOTE(review): this collator pads `labels` with -100 by default
# (`label_pad_token_id`); verify the CRF head masks those positions before
# using labels as indices.
data_collator = DataCollatorForTokenClassification(tokenizer=t5_tokenizer)

training_args = TrainingArguments(
    output_dir=ROOT+'/models/checkpoints',
    learning_rate=model_config.lr,
    per_device_train_batch_size=model_config.batch_size,
    # Evaluation runs with twice the training batch size.
    per_device_eval_batch_size=model_config.batch_size*2,
    num_train_epochs=model_config.num_epochs,
    logging_steps=model_config.logging_steps,
    evaluation_strategy="steps",
    eval_steps=model_config.eval_steps,
    # weight_decay=0.01,
    # gradient_accumulation_steps=accum,
    save_strategy="steps",
    save_steps=model_config.save_steps,
    # save_total_limit=5,
    # load_best_model_at_end=True,
    # fp16=True,
    # deepspeed=deepspeed_config,
    # Keep all dataset columns; the custom model consumes them directly.
    remove_unused_columns=False,
    label_names=['labels'],
    # Use the notebook-wide configured seed instead of a hard-coded 42, so the
    # seed applied by the Trainer cannot diverge from the one set earlier.
    seed=SEED,
    # debug="underflow_overflow",
)

trainer = Trainer(
    model=t5_lora_model,
    args=training_args,
    train_dataset=dataset_signalp['train'],
    eval_dataset=dataset_signalp['valid'],
    data_collator=data_collator,
    compute_metrics=src.metrics.compute_metrics,
)

In [8]:
# Baseline evaluation on the validation split before any training step.
# NOTE(review): the recorded output of this cell ends in a CUDA
# "index out of bounds" device-side assert. One candidate cause to check:
# DataCollatorForTokenClassification pads `labels` with -100 by default, which
# is an invalid index if the CRF head indexes its tables with raw labels —
# TODO confirm in src.model (rerun with CUDA_LAUNCH_BLOCKING=1 for an accurate
# stack trace).
initial_validation=trainer.evaluate()
# Flag read by the log-assembly cell: the baseline row has not yet been
# prepended to `training_log`.
added_initial_validation = False
print(initial_validation)

encoder_outputs...
truncate
crf
log
done
tensor.shape 1 torch.Size([32, 71])
tensor.shape 2 torch.Size([32, 71])
tensor.shape 1 torch.Size([32, 69, 7])
tensor.shape 2 torch.Size([32, 69, 7])


encoder_outputs...
truncate
crf
log
done
tensor.shape 1 torch.Size([32, 71])
tensor.shape 2 torch.Size([32, 71])
tensor.shape 1 torch.Size([32, 69, 7])
tensor.shape 2 torch.Size([32, 69, 7])
encoder_outputs...
truncate
crf
log
done
tensor.shape 1 torch.Size([32, 71])
tensor.shape 2 torch.Size([32, 71])
tensor.shape 1 torch.Size([32, 69, 7])
tensor.shape 2 torch.Size([32, 69, 7])
encoder_outputs...
truncate
crf
log
done
tensor.shape 1 torch.Size([32, 71])
tensor.shape 2 torch.Size([32, 71])
tensor.shape 1 torch.Size([32, 69, 7])
tensor.shape 2 torch.Size([32, 69, 7])
encoder_outputs...
truncate
crf
log
done
tensor.shape 1 torch.Size([32, 71])
tensor.shape 2 torch.Size([32, 71])
tensor.shape 1 torch.Size([32, 69, 7])
tensor.shape 2 torch.Size([32, 69, 7])
encoder_outputs...
truncate
crf
log
done
tensor.shape 1 torch.Size([32, 71])
tensor.shape 2 torch.Size([32, 71])
tensor.shape 1 torch.Size([32, 69, 7])
tensor.shape 2 torch.Size([32, 69, 7])
encoder_outputs...
truncate
crf
log
done
tens

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [23,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Release unreferenced Python objects and cached CUDA memory before the
# training run starts.
gc.collect()
torch.cuda.empty_cache()
# torch.mps.empty_cache()

# Run the fine-tuning loop configured by `training_args`.
trainer.train()

In [None]:
# Evaluate once more on the validation split after training has finished.
final_validation=trainer.evaluate()
print(final_validation)

In [None]:
# Fold this trainer's log history into the cumulative `training_log` frame:
# append when a log from an earlier run already exists, otherwise start one.
latest_log = pd.DataFrame(trainer.state.log_history)
if 'training_log' in locals():
    training_log = pd.concat([training_log, latest_log], ignore_index=True)
else:
    training_log = latest_log

# The pre-training evaluation is inserted as the first row, exactly once.
if not added_initial_validation:
    added_initial_validation = True
    baseline_row = pd.DataFrame([initial_validation])
    training_log = pd.concat([baseline_row, training_log], ignore_index=True)

display(training_log)

In [None]:
def _confusion_matrix_to_list(value):
    """Return a serialisable form of one `eval_confusion_matrix` cell.

    NumPy arrays are converted to nested lists. Values that are already lists
    are kept unchanged, so re-running this cell no longer replaces previously
    converted matrices with None. Anything else (e.g. the NaN of rows without
    an evaluation) becomes None.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, list):
        return value
    return None


# Confusion matrices are stored as plain lists so the log can be written to
# CSV and parquet.
if 'eval_confusion_matrix' in training_log.columns:
    training_log['eval_confusion_matrix'] = training_log['eval_confusion_matrix'].apply(_confusion_matrix_to_list)

# Save the adapter first, then write the log files into the same directory.
t5_lora_model.save_pretrained(ROOT + adapter_location)
training_log.to_csv(ROOT + adapter_location + '/training_log.csv', index=False)
training_log.to_parquet(ROOT + adapter_location + '/training_log.parquet')

In [None]:

# Reload the training log from the adapter directory; this replaces any
# in-memory `training_log`.
# Alternative locations used for other runs (kept for reference):
# training_log = pd.read_parquet(ROOT + f'/models/moe_v{MODEL_VERRSION}_linear_expert_{EXPERT}/training_log.parquet')
# adapter_location = f'/models/moe_v{MODEL_VERRSION}_expert_{EXPERT}'
training_log = pd.read_parquet(ROOT + adapter_location + '/training_log.parquet')

In [None]:
# Directory that receives the figures (same place as the adapter and log).
figure_dir = ROOT + adapter_location

# Confusion matrix: take the most recent log row that has one.
cm_column = training_log['eval_confusion_matrix']
latest_cm = cm_column[cm_column.notnull()].iloc[-1]
src.plot.confusion_matrix_plot(
    np.array(latest_cm.tolist()),
    model_config.select_decoding_type[EXPERT],
)
plt.savefig(figure_dir + '/fig_cm.jpg', dpi=400)

# Loss curves over the whole training log.
src.plot.loss_plot(training_log)
plt.savefig(figure_dir + '/fig_loss.jpg', dpi=400)