In [1]:
%load_ext autoreload
%autoreload 2

import gc
import copy

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

import evaluate

from transformers import (
    T5Tokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
)

from src.model_new import (
    T5EncoderModelForTokenClassification,
    create_datasets,
)
import src.config
import src.data
import src.model_new


import peft
from peft import (
    LoraConfig,
    PeftModel
)

import random

print("Base Model:\t", src.config.base_model_name)
print("MPS:\t\t", torch.backends.mps.is_available())
ROOT = src.utils.get_project_root_path()
print("Path:\t\t", ROOT)
device = torch.device('cuda:0' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f"Using device:\t {device}")

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# torch.set_printoptions(threshold=10_000)

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'
Base Model:	 Rostlab/prot_t5_xl_uniref50
MPS:		 True
Path:		 /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction
Using device:	 mps


In [2]:
t5_tokenizer = T5Tokenizer.from_pretrained(
        pretrained_model_name_or_path=src.config.base_model_name,
        do_lower_case=False,
        use_fast=True,
        legacy=False
    )

In [3]:
t5_base_model = T5EncoderModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=src.config.base_model_name,
    device_map='auto',
    load_in_8bit=False,
    custom_num_labels=len(src.config.label_decoding),
    custom_dropout_rate=0.1,
    )

Some weights of T5EncoderModelForTokenClassification were not initialized from the model checkpoint at Rostlab/prot_t5_xl_uniref50 and are newly initialized: ['custom_classifier.bias', 'custom_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
adapter_location = '/models/testing_1'
t5_base_model.load_adapter(ROOT+adapter_location)

Loading adapter weights from /Users/finnlueth/Developer/gits/prottrans-t5-signalpeptide-prediction/models/testing_1 led to unexpected keys not found in the model:  ['crf.start_transitions', 'crf.end_transitions', 'crf.transitions']. 


---

In [5]:
FASTA_FILENAME = '5_SignalP_5.0_Training_set.fasta'
# FASTA_FILENAME = '5_SignalP_5.0_Training_set_testing.fasta'
annotations_name = 'Label' # Choose Type or Label

df_data = src.data.process(src.data.parse_file(ROOT + '/data/raw/' + FASTA_FILENAME))

dataset_signalp = src.model_new.create_datasets(
    splits=src.config.splits,
    tokenizer=t5_tokenizer,
    data=df_data,
    annotations_name=annotations_name,
    # dataset_size=src.config.dataset_size,
    dataset_size=3,
    encoder=src.config.select_encodings[annotations_name],
    )

del df_data

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [6]:
display(dataset_signalp)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9
    })
    valid: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 3
    })
})

In [7]:
count = 0
for sublist in dataset_signalp['valid']['labels']:
    for element in sublist:
        if element == 0:
            count += 1

count
# dataset_signalp['test']['labels']

210

---

In [9]:
_ds_index = 2
_ds_type = 'test'

_input_ids_test = t5_tokenizer.decode(dataset_signalp[_ds_type][_ds_index]['input_ids'])
_labels_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['labels'] + [-100]]).to(device)
_attention_mask_test = torch.tensor([dataset_signalp[_ds_type][_ds_index]['attention_mask']]).to(device)

_labels_test_decoded = [src.config.label_decoding[x] for x in _labels_test.tolist()[:-1]]

print('Iput IDs:\t', _input_ids_test)
print('Labels:\t\t', *_labels_test.tolist())
print('Labels Decoded:\t', *_labels_test_decoded)
print('Attention Mask:\t', *_attention_mask_test.tolist()[0])
print('----')

preds = src.model_new.predict_model(
    sequence=_input_ids_test,
    tokenizer=t5_tokenizer,
    model=t5_base_model,
    labels=_labels_test,
    attention_mask=_attention_mask_test,
    device=device,
    )

_result = src.model_new.translate_logits(preds.logits.cpu().numpy())

print('Result: \t',* _result)

Iput IDs:	 M A A V I L E R L G A L W V Q N L R G K L A L G I L P Q S H I H T S A S L E I S R K W E K K N K I V Y P P Q L P G E P R R P A E I Y H C R R</s>
Labels:		 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100]
Labels Decoded:	
Attention Mask:	 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
----
torch.Size([1, 71])
torch.Size([1, 71])
torch.Size([1, 71])
Result: 	 S S S O T L T M S L L O M T S S M M L M T L M O L T L L M M L L L S L S S S S S O O O T T M S I O M I T T I I I O I I T I T S T I L S T O T M


In [None]:
# pd.read_csv(ROOT + adapter_location + '/training_log.csv')
training_log = pd.read_parquet(ROOT + adapter_location + '/training_log.parquet')

In [None]:
display(training_log)

In [None]:
data_cm = np.concatenate(training_log['eval_confusion_matrix'].iloc[-1]).reshape(6, 6)

In [None]:
pd.DataFrame(data_cm)[0].sum()

In [None]:
ax = sns.heatmap(
    data_cm,
    annot=True,
    xticklabels=[src.config.label_decoding[label] for label in range(len(src.config.label_decoding))],
    yticklabels=[src.config.label_decoding[label] for label in range(len(src.config.label_decoding))],
    fmt='d'
    )

ax.set_title('Confusion Matrix')
ax.set_xlabel('Actual')
ax.set_ylabel('Predicted')

In [None]:
# for index, (param_name, param) in enumerate(t5_base_model.named_parameters()):
#     # if index == 11:
#     #     break
#     if param.requires_grad:
#         print(param_name)
#     # print(param)
#     if param_name in ['custom_classifier.bias', 'custom_classifier.weight']:
#         print(param)

In [None]:
# type(t5_base_model)

In [None]:
# [x for x in t5_base_model.custom_classifier.parameters()]

In [None]:
# t5_base_model.custom_classifier.weight