## Notebook Setup
___

In [2]:
# Automatically re-import edited modules (e.g. src.model, src.utils) before each cell runs.
%load_ext autoreload
%autoreload 2

## Packages
___

In [17]:
import re
import os
import math
import copy
import types
import yaml

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.nn import (
    CrossEntropyLoss,
    MSELoss
)

import evaluate

from transformers import (
    AutoModelForTokenClassification,
    AutoConfig,
    T5EncoderModel,
    T5Tokenizer,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
    pipeline,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    set_seed,
    )
from transformers.modeling_outputs import TokenClassifierOutput

from peft import (
    LoraConfig,
    get_peft_model,
    TaskType,
    get_peft_config,
    PeftModel,
    PeftConfig,
    prepare_model_for_kbit_training
    )

from datasets import Dataset

import src.config as config

from src.model import (
    get_prottrans_tokenizer_model,
    df_to_dataset,
    inject_linear_layer,
    compute_metrics_full,
    compute_metrics_fast
    )
from src.utils import get_project_root_path
from tqdm import tqdm

---
## Setup and Variables

In [145]:
# Resolve the base model, project root and compute device used throughout the notebook.
base_model_name = config.base_model_name
print("Base Model:\t", base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())

ROOT = get_project_root_path()
print("Path:\t\t", ROOT)

# Device preference: first CUDA GPU, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device:\t {device}")

Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 True
Path:		 /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction
Using device:	 mps


In [146]:
# Training hyper-parameters, all taken from `src.config`.
lr, batch_size, num_epochs, dropout_rate = (
    config.lr,
    config.batch_size,
    config.num_epochs,
    config.dropout_rate,
)

# Label mappings; note that `label_list` is populated from `config.label_decoding`.
label_encoding, label_list = config.label_encoding, config.label_decoding

# Metric function used for evaluation (`compute_metrics_full` is the imported alternative).
compute_metrics = compute_metrics_fast

In [147]:
# Register tqdm's pandas integration so `progress_apply` (used on df_data below) shows a progress bar.
tqdm.pandas()

---
## Create Tokenizer and Load Model

In [6]:
# Encoder-only ProtT5: only encoder hidden states are needed for feature extraction.
model_architecture = T5EncoderModel
tokenizer, model = get_prottrans_tokenizer_model(
    base_model_name,
    model_architecture,
)

---
## Load Data, Split into Dataset, and Tokenize Sequences

In [211]:
# Processed training table: space-separated residues, label lists and a train/test `Split` column.
df_data = pd.read_parquet(os.path.join(ROOT, 'data', 'processed', '5.0_train.parquet.gzip'))

In [212]:
# Peek at the first three raw rows.
df_data.iloc[:3]

Unnamed: 0,Sequence,Label,Split
0,M A P T L F Q K L F S K R T G L G A P G R D A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train
1,M D F T S L E T T T F E E V V I A L G S N V G ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train
2,M D D I S G R Q T L P R I N R L L E H V G N P ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train


In [213]:
# Tokenise all sequences in one call; `padding=True` pads every sequence to the longest one.
# NOTE(review): rows are embedded one at a time later, so global padding only adds compute.
ids = tokenizer.batch_encode_plus(
    df_data['Sequence'],
    padding=True,
    add_special_tokens=True,
)

In [214]:
# Attach the tokenizer output (plain Python lists per row) as two new columns.
df_data = df_data.assign(
    input_ids=ids['input_ids'],
    attention_mask=ids['attention_mask'],
)

In [215]:
# Turn each id/mask list into a (1, seq_len) tensor so every row is a ready-made batch of one.
df_data['input_ids'] = df_data['input_ids'].progress_apply(
    lambda token_ids: torch.tensor(token_ids).unsqueeze(0)
)
df_data['attention_mask'] = df_data['attention_mask'].progress_apply(
    lambda mask_values: torch.tensor(mask_values).unsqueeze(0)
)

  0%|          | 0/20758 [00:00<?, ?it/s]

100%|██████████| 20758/20758 [00:00<00:00, 73815.67it/s]
100%|██████████| 20758/20758 [00:00<00:00, 73977.99it/s]


In [216]:
# Peek at the first three rows after tokenisation.
df_data.iloc[:3]

Unnamed: 0,Sequence,Label,Split,input_ids,attention_mask
0,M A P T L F Q K L F S K R T G L G A P G R D A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19), tensor(3), tensor(13), tensor(11...","[[tensor(1), tensor(1), tensor(1), tensor(1), ..."
1,M D F T S L E T T T F E E V V I A L G S N V G ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19), tensor(10), tensor(15), tensor(1...","[[tensor(1), tensor(1), tensor(1), tensor(1), ..."
2,M D D I S G R Q T L P R I N R L L E H V G N P ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19), tensor(10), tensor(10), tensor(1...","[[tensor(1), tensor(1), tensor(1), tensor(1), ..."


In [217]:
# Sanity check: which device does the first row's token tensor live on?
df_data.at[0, 'input_ids'].device

device(type='cpu')

In [218]:
# Transfer every per-row token / mask tensor to the compute device ahead of inference.
df_data['input_ids'] = df_data['input_ids'].progress_apply(
    lambda token_tensor: token_tensor.to(device)
)
df_data['attention_mask'] = df_data['attention_mask'].progress_apply(
    lambda mask_tensor: mask_tensor.to(device)
)

  0%|          | 0/20758 [00:00<?, ?it/s]

100%|██████████| 20758/20758 [00:03<00:00, 6303.89it/s]
100%|██████████| 20758/20758 [00:03<00:00, 6304.32it/s]


---
## Feature Extraction

In [219]:
# Move the encoder weights to the selected device and switch to evaluation mode (dropout off).
model.to(device).eval()
print(f"Model loaded to {device}.")

Model loaded to mps.


In [223]:
def _embed_row(row):
    """Return the per-residue encoder embeddings for one tokenised row of ``df_data``.

    Parameters
    ----------
    row : pandas.Series
        A row whose ``input_ids`` and ``attention_mask`` are ``(1, seq_len)``
        tensors (built with ``unsqueeze(0)`` above) already on ``device``.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(n_attended_tokens, hidden_size)`` in host memory.
        Only the encoder's ``last_hidden_state`` is kept, not the whole model-output
        object, and positions with ``attention_mask == 0`` (padding) are dropped.
    """
    output = model(input_ids=row['input_ids'], attention_mask=row['attention_mask'])
    # Batch size is 1, so take the single sequence.
    hidden_states = output.last_hidden_state[0]
    attended = row['attention_mask'][0].bool()
    # Copy to host as float32. No device tensors are then retained in df_data,
    # and the result is a plain NumPy array that pandas/pyarrow can handle.
    return hidden_states[attended].float().cpu().numpy()


with torch.no_grad():
    df_data['embeddings'] = df_data.progress_apply(_embed_row, axis=1)

  0%|          | 0/20758 [00:00<?, ?it/s]

 14%|█▍        | 2958/20758 [17:49<1:46:26,  2.79it/s]

In [222]:
# Show the first and last five rows to confirm every column is populated.
display(df_data.head(), df_data.tail())

Unnamed: 0,Sequence,Label,Split,input_ids,attention_mask
0,M A P T L F Q K L F S K R T G L G A P G R D A ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19, device='mps:0'), tensor(3, device...","[[tensor(1, device='mps:0'), tensor(1, device=..."
1,M D F T S L E T T T F E E V V I A L G S N V G ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19, device='mps:0'), tensor(10, devic...","[[tensor(1, device='mps:0'), tensor(1, device=..."
2,M D D I S G R Q T L P R I N R L L E H V G N P ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19, device='mps:0'), tensor(10, devic...","[[tensor(1, device='mps:0'), tensor(1, device=..."
3,M L G T V K M E G H E T S D W N S Y Y A D T Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",test,"[[tensor(19, device='mps:0'), tensor(4, device...","[[tensor(1, device='mps:0'), tensor(1, device=..."
4,M L G A V K M E G H E P S D W S S Y Y A E P E ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",test,"[[tensor(19, device='mps:0'), tensor(4, device...","[[tensor(1, device='mps:0'), tensor(1, device=..."


Unnamed: 0,Sequence,Label,Split,input_ids,attention_mask
20753,M Q T Q V L F E H P L N E K M R T W L R I E F ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",test,"[[tensor(19, device='mps:0'), tensor(16, devic...","[[tensor(1, device='mps:0'), tensor(1, device=..."
20754,M Q S V T P T S Q Y L K A L N E G S H Q P D D ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19, device='mps:0'), tensor(16, devic...","[[tensor(1, device='mps:0'), tensor(1, device=..."
20755,M R I F V Y G S L R T K Q G N S H W M T N A L ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19, device='mps:0'), tensor(8, device...","[[tensor(1, device='mps:0'), tensor(1, device=..."
20756,M T M S L E V F E K L E A K V Q Q A I D T I T ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19, device='mps:0'), tensor(11, devic...","[[tensor(1, device='mps:0'), tensor(1, device=..."
20757,M S A Q P V D I Q I F G R S L R V N C P P D Q ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",train,"[[tensor(19, device='mps:0'), tensor(7, device...","[[tensor(1, device='mps:0'), tensor(1, device=..."


In [None]:
def _nested_rows(array):
    """Split an N-d NumPy array into nested lists whose leaves are 1-D arrays.

    pyarrow infers Arrow list types from 1-D arrays only, so higher-dimensional
    arrays are broken into (nested) Python lists of rows. 0-d arrays become
    Python scalars.
    """
    if array.ndim == 0:
        return array.item()
    if array.ndim == 1:
        return array
    return [_nested_rows(row) for row in array]


def _to_parquet_value(value):
    """Convert one DataFrame cell into a value pyarrow can serialise.

    ``torch.Tensor`` cells (possibly on an accelerator) cannot be written to
    parquet, so they are moved to host memory and converted to NumPy.
    Model-output objects are reduced to their ``last_hidden_state`` tensor.
    Any other value (strings, label lists, ...) is returned unchanged.
    """
    if hasattr(value, 'last_hidden_state'):
        value = value.last_hidden_state
    if isinstance(value, torch.Tensor):
        value = value.detach()
        if value.is_floating_point():
            # Widen half/bfloat16: bfloat16 has no NumPy equivalent.
            value = value.float()
        value = value.cpu().numpy()
    if isinstance(value, np.ndarray):
        return _nested_rows(value)
    return value


# Serialise a converted copy so the in-memory tensors in df_data stay usable.
df_data_parquet = pd.DataFrame(
    {column: series.map(_to_parquet_value) for column, series in df_data.items()}
)
df_data_parquet.to_parquet(
    os.path.join(ROOT, 'data', 'processed', '5.0_train_embeddings.parquet.gzip'),
    compression='gzip',
)

---
## Model