In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gc
import copy

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
    T5EncoderModelForSequenceClassification,
    create_datasets,
)
import src.config
import src.data
import src.model_new


import peft
from peft import (
    LoraConfig,
    PeftModel
)

import random

from tqdm import tqdm
tqdm.pandas()


In [3]:
ROOT = src.utils.get_project_root_path()
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))


EXPERT = 'NO_SP'
MODEL_VERRSION = src.config.model_version
MODEL = 'linear'
adapter_location = f'/models/moe_v{MODEL_VERRSION}_'

USE_CRF = MODEL == 'crf'
SEED = 42
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

print("Base Model:\t", src.config.base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
print("Path:\t\t", ROOT)
print(f"Using device:\t {device}")
print('Using CRF\t\t', USE_CRF)

# torch.set_printoptions(threshold=10_000)

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 False
Path:		 /home/ec2-user/developer/prottrans-t5-signalpeptide-prediction
Using device:	 cuda:0
Using CRF		 False


In [4]:
t5_tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_name_or_path=src.config.base_model_name,
        do_lower_case=False,
        use_fast=True,
        legacy=False
    )

In [5]:
# FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
# # FASTA_FILENAME = '5_SignalP_5.0_Training_set_testing.fasta'
# annotations_name = ['Label'] + ['Type'] # Choose Type or Label

# df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

# dataset_signalp_type_splits = {}

# for sequence_type in src.config.select_encoding_type.keys():
#     dataset_signalp = src.model_new.create_datasets(
#         splits=src.config.splits,
#         tokenizer=t5_tokenizer,
#         data=df_data,
#         annotations_name=annotations_name,
#         dataset_size=src.config.dataset_size,
#         sequence_type=sequence_type
#         )
#     dataset_signalp_type_splits.update({sequence_type: dataset_signalp})

# del df_data

# dataset_signalp = dataset_signalp_type_splits[EXPERT]
# display(dataset_signalp_type_splits)

In [6]:
# t5_base_model_gate = T5EncoderModelForSequenceClassification.from_pretrained(
#     pretrained_model_name_or_path=src.config.base_model_name,
#     device_map='auto',
#     load_in_8bit=False,
#     custom_num_labels=len(src.config.type_encoding),
#     custom_dropout_rate=0.1,
#     )

In [7]:
# len(src.config.label_encoding)

In [8]:
src.config.select_encoding_type[EXPERT]

{'I': 0, 'M': 1, 'O': 2}

In [9]:
t5_base_model_expert = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.select_encoding_type[EXPERT]),
    custom_dropout_rate=0.1,
    use_crf=USE_CRF
    )

Some weights of T5EncoderModelForTokenClassification were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['custom_classifier.bias', 'custom_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# gate_adapter_location = adapter_location+'gate'
# t5_base_model_gate.load_adapter(ROOT+gate_adapter_location, adapter_name=f"gate")
# t5_base_model_gate.set_adapter("gate")

In [11]:
expert_adapter_location = adapter_location+f'{MODEL}_expert_{EXPERT}'
adapter_name = f"adapter_{EXPERT}"
print(expert_adapter_location, adapter_name)
t5_base_model_expert.load_adapter(ROOT+expert_adapter_location, adapter_name=adapter_name)
t5_base_model_expert.set_adapter(adapter_name)

/models/moe_v1_linear_expert_NO_SP adapter_NO_SP


---
todo

X gate\
X lin all\
X lin expert\
X crf all \
O crf expert

In [12]:
# FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
# df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))
# df_data = df_data[df_data['Partition_No'] == 4].reset_index(drop=True)
# df_data['Sequence_Raw'] = df_data['Sequence'].apply(lambda x: x.replace(' ', ''))
# # df_data['Mask'] = [x[1:] for x in dataset_signalp_type_splits['ALL']['test']['attention_mask']]
# df_data['input_ids'] = dataset_signalp_type_splits['ALL']['test']['input_ids']
# df_data['ds_attention_mask'] = dataset_signalp_type_splits['ALL']['test']['attention_mask']
# df_data['ds_labels'] = dataset_signalp_type_splits['ALL']['test']['labels']
# df_data['ds_type'] = dataset_signalp_type_splits['ALL']['test']['type']

In [13]:
df_data = pd.read_parquet('./results/df_data.parquet.gzip')

In [14]:
# print(dataset_signalp_type_splits['ALL']['test'])
display(df_data.head(), df_data.shape)

Unnamed: 0,Uniprot_AC,Kingdom,Type,Partition_No,Sequence,Label,Sequence_Raw,input_ids,ds_attention_mask,ds_labels,ds_type,predicted_type,predicted_label_linear_ALL,predicted_label_linear_experts,predicted_label_crf_ALL,predicted_label_crf_experts,predicted_label_linear_experts_imperfect_viterbi,predicted_label_crf_experts_imperfect,predicted_label_linear_experts_imperfect
0,P55317,EUKARYA,NO_SP,4,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MLGTVKMEGHETSDWNSYYADTQEAYSSVPVSNMNSGLGSMNSMNT...,"[19, 4, 5, 11, 6, 14, 19, 9, 5, 20, 9, 11, 7, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,
1,P35583,EUKARYA,NO_SP,4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MLGAVKMEGHEPSDWSSYYAEPEGYSSVSNMNAGLGMNGMNTYMSM...,"[19, 4, 5, 3, 6, 14, 19, 9, 5, 20, 9, 13, 7, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,
2,Q8UVD9,EUKARYA,NO_SP,4,M E I S T P D F G F G T E D S S A Q Q S A N R ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MEISTPDFGFGTEDSSAQQSANRAIPQPVPAPAFPLKETASDTGGT...,"[19, 9, 12, 7, 11, 13, 10, 15, 5, 15, 5, 11, 9...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,
3,Q99PF5,EUKARYA,NO_SP,4,M S D Y S T G G P P P G P P P P A G G G G G A ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MSDYSTGGPPPGPPPPAGGGGGAAGAGGGPPPGPPGAGDRGGGGPG...,"[19, 7, 10, 18, 7, 11, 5, 5, 13, 13, 13, 5, 13...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,
4,Q9URU9,EUKARYA,NO_SP,4,M N F R P E Q Q Y I L E K P G I L L S F E Q L ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MNFRPEQQYILEKPGILLSFEQLRINFKHILRHLEHESHVINSTLT...,"[19, 17, 15, 8, 13, 9, 16, 16, 18, 12, 4, 9, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,


(4147, 19)

In [15]:
# df_data['predicted_type'] = df_data.progress_apply(lambda x:
#     src.model_new.translate_logits(
#         src.model_new.predict_model(
#             sequence=x['Sequence'],
#             tokenizer=t5_tokenizer,
#             model=t5_base_model_gate,
#             attention_mask=torch.Tensor([x['ds_attention_mask']]).to(device),
#             device=device
#             )['logits'].unsqueeze(0),
#         src.config.type_decoding,
#         viterbi_decoding=False
#         )[0],  axis=1
#     )
# display(df_data.head())
# df_data.to_parquet('./results/df_data.parquet.gzip', compression='gzip')

In [16]:
# # test_df = df_data.tail(700).copy(deep=True)

# df_data[f'predicted_label_crf_{EXPERT}'] = df_data.progress_apply(lambda x:
#     ''.join(src.model_new.translate_logits(
#         src.model_new.predict_model(
#             sequence=x['Sequence'],
#             tokenizer=t5_tokenizer,
#             model=t5_base_model_expert,
#             attention_mask=torch.Tensor([x['ds_attention_mask']]).to(device),
#             device=device
#             )['logits'],
#         src.config.select_decoding_type[EXPERT],
#         viterbi_decoding=USE_CRF
#         ))[:np.count_nonzero(x['ds_attention_mask'])-1],  axis=1
#     )
# display(df_data.head())

In [17]:
# test_df = df_data.tail(20).copy(deep=True)
column_name = 'predicted_label_linear_experts_imperfect'

df_data[column_name] = df_data.progress_apply(lambda x:
    ''.join(src.model_new.translate_logits(
        src.model_new.predict_model(
            sequence=x['Sequence'],
            tokenizer=t5_tokenizer,
            model=t5_base_model_expert,
            attention_mask=torch.Tensor([x['ds_attention_mask']]).to(device),
            device=device
            )['logits'],
        src.config.select_decoding_type[EXPERT],
        viterbi_decoding=USE_CRF
        ))[:np.count_nonzero(x['ds_attention_mask'])-1] if x['predicted_type'] == EXPERT else x[column_name],  axis=1
    )

  attention_mask=torch.Tensor([x['ds_attention_mask']]).to(device),
100%|██████████| 4147/4147 [03:20<00:00, 20.69it/s] 


In [18]:
display(df_data.head(), df_data.shape)

Unnamed: 0,Uniprot_AC,Kingdom,Type,Partition_No,Sequence,Label,Sequence_Raw,input_ids,ds_attention_mask,ds_labels,ds_type,predicted_type,predicted_label_linear_ALL,predicted_label_linear_experts,predicted_label_crf_ALL,predicted_label_crf_experts,predicted_label_linear_experts_imperfect_viterbi,predicted_label_crf_experts_imperfect,predicted_label_linear_experts_imperfect
0,P55317,EUKARYA,NO_SP,4,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MLGTVKMEGHETSDWNSYYADTQEAYSSVPVSNMNSGLGSMNSMNT...,"[19, 4, 5, 11, 6, 14, 19, 9, 5, 20, 9, 11, 7, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
1,P35583,EUKARYA,NO_SP,4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MLGAVKMEGHEPSDWSSYYAEPEGYSSVSNMNAGLGMNGMNTYMSM...,"[19, 4, 5, 3, 6, 14, 19, 9, 5, 20, 9, 13, 7, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
2,Q8UVD9,EUKARYA,NO_SP,4,M E I S T P D F G F G T E D S S A Q Q S A N R ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MEISTPDFGFGTEDSSAQQSANRAIPQPVPAPAFPLKETASDTGGT...,"[19, 9, 12, 7, 11, 13, 10, 15, 5, 15, 5, 11, 9...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
3,Q99PF5,EUKARYA,NO_SP,4,M S D Y S T G G P P P G P P P P A G G G G G A ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MSDYSTGGPPPGPPPPAGGGGGAAGAGGGPPPGPPGAGDRGGGGPG...,"[19, 7, 10, 18, 7, 11, 5, 5, 13, 13, 13, 5, 13...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...
4,Q9URU9,EUKARYA,NO_SP,4,M N F R P E Q Q Y I L E K P G I L L S F E Q L ...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,MNFRPEQQYILEKPGILLSFEQLRINFKHILRHLEHESHVINSTLT...,"[19, 17, 15, 8, 13, 9, 16, 16, 18, 12, 4, 9, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",0,NO_SP,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...,IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII...


(4147, 19)

In [19]:
inspect = df_data[df_data['predicted_type'] == EXPERT]
inspect[column_name].value_counts().shape, inspect[column_name].value_counts()

((267,),
 predicted_label_linear_experts_imperfect
 IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII    2802
 IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIM       5
 IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII          4
 IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIMMMMMMMMMMMMMMMMMMOIIII       2
 IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII                     2
                                                                           ... 
 IIIIIMMMMMMMMMMIMMIIIIIIIIIIIIIIIIIIIMMIMMMMMMMMMMMMIMIIIIIIIIIIIIIIII       1
 IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIMMMMMMMMMMMMMMMMMMIIIII       1
 IIIIIIIIIIIMIMMMMMMMMIIMIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII       1
 IIIIIIIIIIIIIIIIIIIIIMMMMMMMMMMMMMMMMMMMMIIIIIIIIIIOIIIIIIIIIIIIIIIIII       1
 IIIIIIIOIIIIIIIOIIIIIIIIIOIIIIOIIIIOIIOOOOIMMIOIMMMMMMMMMMMMMMMMIMMMII       1
 Name: count, Length: 267, dtype: int64)

In [20]:
sum(df_data[column_name].isna())

0

In [21]:
# test_df[test_df['Sequence_Raw'].str.len() != 70]

In [22]:
# loc = 3458
# print(test_df['predicted_label_linear_ALL'].loc[loc], test_df['Label'].loc[loc])
# print(test_df['predicted_label_linear_ALL'].loc[loc].__len__(), test_df['Label'].loc[loc].__len__())

In [23]:
# df_data[df_data['predicted_type'] != df_data['Type']]
# t5_tokenizer.decode(t5_tokenizer.encode('M T E T L P P V T E S A V A L Q A E V T Q R E L F E F V L N D P L L A S S L Y I N I A L A G L S I L L F V F M T R G L D D P R A K L I A V S'))
# dataset_signalp_type_splits['ALL']['test']['attention_mask']
# decoded_shit = [(x, ''.join([t5_tokenizer.decode(z) for z in y])) for x, y in zip(df_data['Sequence'], dataset_signalp_type_splits['ALL']['test']['input_ids'])]
# print(*decoded_shit, sep='\n')
# len(decoded_shit)
# len(df_data['Sequence'].at[0]), len(df_data['Mask'].at[0])# print(*[(x, y) for x, y in zip(df_data['Sequence'], df_data['Mask'])], sep='\n')
#  df_singalp_split = df_singalp6_preds[
#     df_singalp6_preds['Type'] == 'SP'
# ]
# test_df = df_data.head().copy(deep=True)
# test_df
# test_df['asd'] = test_df.apply(lambda x: (x['Type'], x['Mask']), axis=1)

In [24]:
# df_data['predicted_label_linear_experts_imperfect'] = None

In [25]:
# df_data = df_data.rename(columns={'predicted_label_linear_experts_imperfect': 'predicted_label_linear_experts_imperfect_viterbi'})

In [26]:
# df_data

In [27]:
df_data.to_parquet('./results/df_data.parquet.gzip', compression='gzip')

In [54]:
# df_data.columns

In [53]:
# with pd.option_context('display.max_rows', None,
#                        'display.max_columns', None,
#                        'display.precision', 3,
#                        'display.max_colwidth', 100
#                        ):
#     display(df_data[[
#         'Type',
#         'predicted_type',
#         'Label',
#         'predicted_label_linear_ALL',
#         'predicted_label_linear_experts',
#         'predicted_label_crf_ALL',
#         'predicted_label_crf_experts',
#         'predicted_label_linear_experts_imperfect_viterbi',
#         'predicted_label_crf_experts_imperfect',
#         'predicted_label_linear_experts_imperfect'
#         ]].tail(1000))

---

In [28]:
# expert_adapter_location = adapter_location + f'expert_{EXPERT}'
# t5_base_model_expert.load_adapter(ROOT+expert_adapter_location)

# FASTA_FILENAME = '5_SignalP_5.0_Training_set_testing.fasta'
# df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))
# df_data = df_data[df_data['Partition_No'] == 4].reset_index(drop=True)
# df_data['Sequence'] = df_data['Sequence'].apply(lambda x: x.replace(' ', ''))

In [29]:
# # df_data['Type_Prediction'] = 
# df_data['Label'].iloc[17:18].apply(lambda x: src.model_new.moe_inference(
#     sequence=x,
#     tokenizer=t5_tokenizer,
#     model_gate=t5_base_model_gate,
#     model_expert=t5_base_model_expert,
#     device=device,
#     # result_type='SP',
#     )[0])

In [30]:
# EXPERT = 'LIPO'
# expert_adapter_location = ROOT + adapter_location + f'expert_{EXPERT}'
# print(expert_adapter_location)

In [31]:
# t5_base_model_expert.load_adapter(expert_adapter_location, adapter_name=f"{EXPERT}_1")

In [32]:
# t5_base_model_expert.unload(EXPERT)

In [33]:
# _ds_index = 4
# # _input_ids_test = df_data['Sequence'].iloc[_ds_index]
# # _labels_test = df_data['Label'].iloc[_ds_index]
# # _type_test = df_data['Type'].iloc[_ds_index]
# _input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'][:-1])
# _labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels'] + [-100]]).to(device)
# _attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)


# print('Iput IDs:\t', _input_ids_test)
# print('Labels:\t\t', _labels_test)
# print('Type:\t\t', _type_test)

# result = src.model_new.moe_inference(
#     sequence=_input_ids_test,
#     attentino_mask=_attention_mask_test,
#     tokenizer=t5_tokenizer,
#     model_gate=t5_base_model_gate,
#     model_expert=t5_base_model_expert,
#     device=device,
#     result_type='LIPO',
#     use_crf=True,
# )

# print(result)

In [34]:
# t5_base_model_gate.unload()

---

In [35]:
# # _ds_index = 220
# # _ds_type = 'test'
# # USE_CRF = True

# # _input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'][:-1])
# # _labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels'] + [-100]]).to(device)
# # _attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)

# # _labels_test_decoded = [src.config.label_decoding[x] for x in _labels_test.tolist()[0][:-1]]

# # print('Iput IDs:\t', _input_ids_test)
# # print('Labels:\t\t', *_labels_test.tolist()[0])
# # print('Labels Decoded:\t', *_labels_test_decoded)
# # print('Attention Mask:\t', *_attention_mask_test.tolist()[0])
# # print('----')

# # _ds_index = 3250
# _ds_index = 3250
# _ds_type = 'test'
# USE_CRF = True

# _input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'][:-1])
# _labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels'] + [-100]]).to(device)
# _attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)

# _labels_test_decoded = [src.config.label_decoding[x] for x in _labels_test.tolist()[0][:-1]]

# print('Iput IDs:\t', _input_ids_test)
# print('Labels:\t\t', *_labels_test.tolist()[0])
# print('Labels Decoded:\t', *_labels_test_decoded)
# print('Attention Mask:\t', *_attention_mask_test.tolist()[0])
# print('----')

# preds = src.model_new.predict_model(
#     sequence=_input_ids_test,
#     tokenizer=t5_tokenizer,
#     model=t5_base_model_expert,
#     labels=_labels_test,
#     attention_mask=_attention_mask_test,
#     device=device,
#     viterbi_decoding=USE_CRF,
#     )

# _result = src.model_new.translate_logits(
#     logits=preds.logits,
#     viterbi_decoding=USE_CRF,
#     decoding=src.config.label_decoding
#     )

# print('Result: \t',* _result)