In [10]:
%load_ext autoreload
%autoreload 2

import gc
import copy

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
    T5EncoderModelForSequenceClassification,
    create_datasets,
)
import src.config
import src.data
import src.model_new


import peft
from peft import (
    LoraConfig,
    PeftModel
)

import random

ROOT = src.utils.get_project_root_path()
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))

USE_CRF = True
EXPERT = 'ALL'

SEED = 42
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

print("Base Model:\t", src.config.base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
print("Path:\t\t", ROOT)
print(f"Using device:\t {device}")

# torch.set_printoptions(threshold=10_000)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 False
Path:		 /home/ec2-user/developer/prottrans-t5-signalpeptide-prediction
Using device:	 cuda:0


In [2]:
t5_tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_name_or_path=src.config.base_model_name,
        do_lower_case=False,
        use_fast=True,
        legacy=False
    )

In [3]:
t5_base_model_gate = T5EncoderModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.type_encoding),
    custom_dropout_rate=0.1,
    use_crf=USE_CRF
    )

Some weights of T5EncoderModelForSequenceClassification were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['custom_classifier_out.bias', 'custom_classifier_in.bias', 'custom_classifier_in.weight', 'custom_classifier_out.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
t5_base_model_expert = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    )

Some weights of T5EncoderModelForTokenClassification were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['custom_classifier.bias', 'custom_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
adapter_location = '/models/moe_v1_'

In [6]:
gate_adapter_location = adapter_location+'gate'
t5_base_model_gate.load_adapter(ROOT+gate_adapter_location)

In [11]:
FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
# FASTA_FILENAME = '5_SignalP_5.0_Training_set_testing.fasta'
annotations_name = 'Label' # Choose Type or Label

df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

dataset_signalp_type_splits = {}

dataset_signalp_type_splits.update(
    {'ALL': src.model_new.create_datasets(
        splits=src.config.splits,
        tokenizer=t5_tokenizer,
        data=df_data,
        annotations_name=annotations_name,
        dataset_size=src.config.dataset_size,
        encoder=src.config.label_encoding,
    )})

for sequence_type in src.config.type_encoding.keys():
    dataset_signalp = src.model_new.create_datasets(
        splits=src.config.splits,
        tokenizer=t5_tokenizer,
        data=df_data,
        annotations_name=annotations_name,
        dataset_size=src.config.dataset_size,
        encoder=src.config.select_encoding_type[sequence_type],
        sequence_type=sequence_type
        )
    dataset_signalp_type_splits.update({sequence_type: dataset_signalp})

del df_data

dataset_signalp = dataset_signalp_type_splits[EXPERT]
display(dataset_signalp)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 12462
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4149
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4147
    })
})

In [12]:
_ds_index = 0
_ds_type = 'test'

_input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'][:-1])
_labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels']]).to(device)
_attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)

_labels_test_decoded = [src.config.type_decoding[x] for x in _labels_test.tolist()]

TypeError: unhashable type: 'list'

In [None]:
src.model_new.moe_inference(
    sequence=_input_ids_test,
    tokenizer=t5_tokenizer,
    model_gate=t5_base_model_gate,
    model_expert=t5_base_model_gate,
    labels=_labels_test,
    attention_mask=_attention_mask_test,
    device=device,
    result_type='NO_SP'
)

In [16]:
expert_adapter_location = adapter_location+'expert_'+EXPERT
t5_base_model_expert.load_adapter(ROOT+expert_adapter_location)

Loading adapter weights from /home/ec2-user/developer/prottrans-t5-signalpeptide-prediction/models/moe_v1_expert_ALL led to unexpected keys not found in the model:  ['crf.start_transitions', 'crf.end_transitions', 'crf.transitions']. 


In [26]:
# _ds_index = 3250
_ds_index = 3250
_ds_type = 'test'
USE_CRF = True

_input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'][:-1])
_labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels'] + [-100]]).to(device)
_attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)

_labels_test_decoded = [src.config.label_decoding[x] for x in _labels_test.tolist()[0][:-1]]

print('Iput IDs:\t', _input_ids_test)
print('Labels:\t\t', *_labels_test.tolist()[0])
print('Labels Decoded:\t', *_labels_test_decoded)
print('Attention Mask:\t', *_attention_mask_test.tolist()[0])
print('----')

preds = src.model_new.predict_model(
    sequence=_input_ids_test,
    tokenizer=t5_tokenizer,
    model=t5_base_model_expert,
    labels=_labels_test,
    attention_mask=_attention_mask_test,
    device=device,
    viterbi_decoding=USE_CRF,
    )

_result = src.model_new.translate_logits(
    logits=preds.logits,
    viterbi_decoding=USE_CRF,
    decoding=src.config.label_decoding
    )

print('Result: \t',* _result)

Iput IDs:	 M D F L H R N G V L I I Q H L Q K D Y R A Y Y T F L N F M S N V G D P R N I F F I Y F P L C F Q F N Q T V G T K M I W V A V I G D W L N L I
Labels:		 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 0 0 0 0 0 0 0 0 0 0 0 2 2 2 2 2 2 2 2 2 2 2 2 2 2 -100
Labels Decoded:	 O O O O O O O O O O O O O O O O O O O O O O O O M M M M M M M M M M M M M M M M M M M M M I I I I I I I I I I I M M M M M M M M M M M M M M
Attention Mask:	 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
----


AttributeError: 'T5EncoderModelForTokenClassification' object has no attribute 'crf'