In [1]:
import time
from omegaconf import OmegaConf
from torch import nn
import sys
import os
sys.path.append(os.getcwd())
from torch.utils.tensorboard import SummaryWriter
from data_filters.sliding_window import SlidingWindowFilterTuple
from examples.promoters.data import Promoters
from torch.utils.data import Dataset, DataLoader
import torch
from metrics.f1 import F1Metric
from memup.accumulator import Accumulator
from memup.loss import PredictorLoss, LossModule, PredictorLossStateOnly, EvalLossStateOnly
from memup.preproc import IncrementStep
from examples.promoters.modules import MemUpMemoryImpl, DataCollectorTrain, DataCollectorLastState
from common_modules.transformers import BertRecurrentTransformerWithTokenizer, BertClassifier
from transformers import AutoTokenizer
from gena_lm.modeling_bert import BertModel, BertForSequenceClassification
from memup.base import MemoryRollout
from examples.promoters.modules import DataType
from metrics.accuracy import AccuracyMetric
from absl import flags

# --- Configuration -----------------------------------------------------------
conf = OmegaConf.load('/home/slavic/PycharmProjects/filtered-transformer/examples/promoters/config.yaml')
data_cfg, model_cfg = conf.data, conf.model

# --- Datasets and loaders ----------------------------------------------------
# conf.data.train is iterated as a list of file names under conf.data.path;
# conf.data.test is a single file name.
train_files = [os.path.join(data_cfg.path, fold_name) for fold_name in data_cfg.train]
train_data = Promoters(train_files)
test_data = Promoters([os.path.join(data_cfg.path, data_cfg.test)])

train_loader = DataLoader(train_data, shuffle=True, batch_size=model_cfg.batch_size)
test_loader = DataLoader(test_data, shuffle=False, batch_size=model_cfg.eval_batch_size)

# --- Recurrent-memory settings -----------------------------------------------
rollout = model_cfg.rec_block_size
state_length = model_cfg.state_size
# Only "text" is listed in pad_fields; "target" and "length" are listed in skip_fields.
data_filter = SlidingWindowFilterTuple[DataType](
    rollout,
    pad_fields={"text"},
    padding=model_cfg.rec_block_padding,
    skip_fields={"target", "length"},
)

# --- Models ------------------------------------------------------------------
GENA_LM_CHECKPOINT = 'AIRI-Institute/gena-lm-bert-base'
tokenizer = AutoTokenizer.from_pretrained(GENA_LM_CHECKPOINT)
# Only the `.bert` encoder of the sequence-classification model is kept.
bert_model: BertModel = BertForSequenceClassification.from_pretrained(GENA_LM_CHECKPOINT).bert
hidden_size = bert_model.config.hidden_size

# NOTE(review): the positional ints (4, 3 and 4, 2) are hyperparameters whose meaning is
# defined in common_modules.transformers — TODO replace with named config entries.
mem_transformer = BertRecurrentTransformerWithTokenizer(
    bert_model, tokenizer, model_cfg.max_token_length, 4, 3, hidden_size * 2
).cuda()
# The first argument (2) presumably is the number of output classes; the logits printed later have 2 columns.
predictor = BertClassifier(2, bert_model.config, 4, 2, hidden_size).cuda()

# --- Trained weights ---------------------------------------------------------
# torch.load unpickles the file: only load checkpoints from a trusted source.
weights = torch.load("/home/slavic/PycharmProjects/promoter.pt")
mem_transformer.load_state_dict(weights["mem"])
predictor.load_state_dict(weights["pred"])



2023-02-27 21:42:50.617602: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-27 21:42:51.090046: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-27 21:42:51.090085: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


/home/slavic/PycharmProjects/promoters16/fold_1.csv
/home/slavic/PycharmProjects/promoters16/fold_2.csv
/home/slavic/PycharmProjects/promoters16/fold_3.csv
/home/slavic/PycharmProjects/promoters16/fold_5.csv


Some weights of the model checkpoint at AIRI-Institute/gena-lm-bert-base were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

<All keys matched successfully>

In [25]:
from tqdm.notebook import tqdm
import pandas as pd
from difflib import SequenceMatcher

# Unlabelled promoter sequences to score with the trained model.
FA_PATH = "/home/slavic/Downloads/Telegram Desktop/pr_spl_fwd_test_mod.fa"
MAX_SEQ_LEN = 2000  # each sequence is truncated to this many characters before inference

# The file is read as one text field per line.
#  - dtype=str / keep_default_na=False: never coerce lines to numbers or turn strings such
#    as "NA" into NaN (NaN would crash the slicing below).
#  - comment=">": lines starting with ">" (FASTA headers) are skipped instead of being
#    scored as sequences.
# NOTE(review): assumes one sequence per line; a FASTA file with wrapped sequence lines
# would yield one fragment per line — confirm the file layout.
fa_data = pd.read_csv(
    FA_PATH,
    header=None,
    dtype=str,
    keep_default_na=False,
    comment=">",
)
fa_data_list = fa_data[0].str[:MAX_SEQ_LEN].tolist()
assert fa_data_list, f"no sequences read from {FA_PATH}"


In [26]:
# Inference only: put both modules in eval mode.
mem_transformer.eval()
predictor.eval()

# Recurrent rollout of the memory transformer over the sliding windows produced by data_filter.
# NOTE(review): `steps` caps the rollout; confirm 200 steps cover T characters with the
# configured rec_block_size. The ignored return values may include a "done" flag worth checking.
memup_iter = MemoryRollout[DataType](
    steps=200,
    memory=MemUpMemoryImpl(mem_transformer),
    data_filter=data_filter,
    info_update=[IncrementStep()]
)

# Loss/metric wrapper mirroring the training setup; it is not used by the prediction below.
predictor_loss = PredictorLossStateOnly(predictor, [
        LossModule(nn.CrossEntropyLoss(), "CE", 1.0),
        LossModule(AccuracyMetric(), "Accuracy", 0.0)
])

device = torch.device("cuda")

with torch.no_grad():

    text = fa_data_list
    assert len(text) > 0, "no input sequences to score"

    # The .fa sequences are unlabelled, and `labels` was previously taken from stale kernel
    # state (NameError on a fresh kernel). Pass placeholder targets so DataType can be built;
    # only state_seq feeds the predictor below, and target_seq is unused.
    # NOTE(review): assumes targets are integer class tensors as in training — TODO confirm.
    labels = torch.zeros(len(text), dtype=torch.long, device=device)

    state2 = torch.zeros(len(text), state_length, bert_model.config.hidden_size, device=device)
    # Use the longest sequence rather than text[0], so a short first sequence cannot shorten
    # the rollout for the others. Shorter texts are presumably padded by data_filter
    # (pad_fields={"text"}).
    T = max(len(seq) for seq in text)
    print("T =", T)

    collector, _, _, _ = memup_iter.forward(DataType(text, labels, T), state2, {}, DataCollectorLastState())
    target_seq, state_seq = collector.result()
    # Concatenate the collected states along dim 0 and classify them.
    s0 = torch.cat(state_seq, 0)
    logits = predictor(s0)
    pred = logits.argmax(-1).reshape(-1).cpu().numpy()
    print(logits[:10])
    print(pred)

T = 2000




tensor([[-2.5012,  1.9648],
        [-0.2476,  0.0933],
        [-1.6873,  1.2889],
        [-1.3233,  0.8667],
        [-1.3856,  0.9312],
        [-0.1955,  0.0634],
        [-0.8124,  0.5294],
        [-2.4897,  1.9106],
        [-2.4382,  1.9563],
        [-1.6044,  1.0846]], device='cuda:0')
[1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 1 0 0 1 0 1 0 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 0 1 1
 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 0 1 1 0 1 1 1 1 0 1 1 1 0 1 1 1 1 1 1 1
 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1
 1 1 0 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 0 1 0]
