In [None]:
# Leftover smoke-test cell: just echoes "1"; it defines nothing later cells use.
print("1")


: 

In [None]:
import json
import logging
import sys
import os
import time
from typing import Dict
sys.path.append("/home/jovyan/filtered-transformer/")

from data_filters.top_errors import InputTarget, TopErrorsFilter
from memup.accumulator import Accumulator
from metrics.pearson import MeanPearsonCorrCoefPerChannel, PearsonCorrLoss
from torch import Tensor, nn
from memup.loss import LossModule, PredictorLoss, PredictorLossWithContext
from memup.base import CT, SD, DataCollectorAppend, MemoryOut, MemoryRollout, State
from memup.preproc import IncrementStep
from pathlib import Path
import torch
from torch.utils.data import DataLoader, DistributedSampler
import transformers
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser
from transformers.optimization import AdamW
from sklearn.metrics import average_precision_score
import numpy as np
from examples.enformer_2.data import EnformerDataset
from examples.enformer_2.modules import BertForEnformer, DataCollectorTrain, DataFilter, MemUpMemoryImpl, Predictor
from gena_lm.modeling_bert import BertPreTrainedModel, BertModel
from torch.nn.utils.rnn import pad_sequence

# Every CUDA allocation below targets the first GPU.
torch.cuda.set_device("cuda:0")

# Tokenizer and config are loaded from the same pretrained checkpoint name.
_GENA_LM_CHECKPOINT = 'AIRI-Institute/gena-lm-bert-base'
tokenizer = AutoTokenizer.from_pretrained(_GENA_LM_CHECKPOINT)
model_cfg = AutoConfig.from_pretrained(_GENA_LM_CHECKPOINT)
# Size the label head to the dataset's target count.
model_cfg.num_labels = EnformerDataset.TG_COUNT

# nn.Module.cuda() returns the module itself, so construction and device move chain.
model = BertForEnformer(config=model_cfg).cuda()
model.eval()

predictor = Predictor(model_cfg).cuda()
predictor.eval()

# Checkpoint holds two state dicts: "mem_acc" (memory model) and "pred_acc" (predictor).
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
# (weights_only=True may be usable if the file contains plain state dicts — confirm).
weights = torch.load("/home/jovyan/enformer.pt", map_location="cpu")
model.load_state_dict(weights["mem_acc"])
predictor.load_state_dict(weights["pred_acc"])

def collate_fn(batch, pad_token_id=None):
    """Pad and stack a list of dataset samples into batched tensors.

    Each sample is a dict with "left", "center" and "right" entries, each mapping
    feature names to numpy arrays (converted with ``torch.from_numpy``).
    Sequence features are right-padded along their first axis with
    ``pad_sequence(batch_first=True)``.

    Args:
        batch: list of samples as yielded by the dataset.
        pad_token_id: padding value for ``input_ids``. Defaults to the
            module-level ``tokenizer.pad_token_id``, so the existing
            ``DataLoader(collate_fn=collate_fn)`` usage is unchanged.

    Returns:
        ``{"left": ..., "center": ..., "right": ...}``. Each value is a dict of
        padded tensors of shape (batch, max_len, ...). "center" additionally
        carries "bins_mask" and "labels". "labels" is stacked, not padded,
        because ``torch.stack`` requires every sample's labels to have the same
        shape.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    # Padding value per sequence feature; masks and type ids pad with 0.
    pad_values = {'input_ids': pad_token_id, 'token_type_ids': 0, 'attention_mask': 0, 'bins_mask': 0}

    def pad_batch(name, feature_keys):
        """Right-pad each feature in `feature_keys` of segment `name` across the batch."""
        return {
            k: pad_sequence(
                [torch.from_numpy(el[name][k]) for el in batch],
                batch_first=True,
                padding_value=pad_values[k],
            )
            for k in feature_keys
        }

    padded_center = pad_batch("center", ['input_ids', 'token_type_ids', 'attention_mask', 'bins_mask'])
    padded_center['labels'] = torch.stack([torch.from_numpy(el["center"]["labels"]) for el in batch])

    # The flanking context segments carry no bins_mask / labels.
    context_keys = ['input_ids', 'token_type_ids', 'attention_mask']
    return {
        "left": pad_batch("left", context_keys),
        "center": padded_center,
        "right": pad_batch("right", context_keys),
    }


# Data filter handed to the MemoryRollout (`memup_iter_acc`) constructed below.
# NOTE(review): the meaning of the positional args (14, 300) is not visible here.
# 14 equals the chunk size used for predictor calls later in this notebook, but
# whether the two are related is unconfirmed — check DataFilter's signature.
data_filter = DataFilter(14, 300)

class ContextCollector(DataCollectorAppend[Dict[str, Tensor], Tensor]):
    """Collector that records each rollout step's memory output on the CPU.

    Cells below concatenate ``collector.collection`` along dim 1. Presumably
    DataCollectorAppend appends every ``apply`` result to that list — confirm
    against memup.base.
    """

    def apply(self, data: SD, out: MemoryOut, state: State) -> CT:
        # Steps that produce no output contribute None instead of a tensor.
        if out is None:
            return None
        # Move off the GPU so many steps can be buffered without exhausting device memory.
        return out.cpu()

# Both the memory model and the predictor are wrapped in Accumulators with the
# same decay. The training cell calls pred_acc.accumulate() after each step.
_ACCUMULATOR_DECAY = 0.9
mem_acc = Accumulator(model, decay=_ACCUMULATOR_DECAY)
pred_acc = Accumulator(predictor, decay=_ACCUMULATOR_DECAY)

# Per-element Poisson NLL (reduction="none") is fed to the top-errors filter.
# NOTE(review): errors_filter is not referenced by any cell visible in this notebook.
errors_filter = TopErrorsFilter(
    14,
    (14, 14),
    pred_acc,
    nn.PoissonNLLLoss(log_input=False, reduction="none"),
    is_random=True,
)

# Memory rollout driven by the accumulated memory model; steps caps the rollout length.
memup_iter_acc = MemoryRollout[Dict[str, torch.Tensor]](
    steps=1000,
    memory=MemUpMemoryImpl(mem_acc),
    data_filter=data_filter,
    info_update=[IncrementStep()],
)

: 

In [5]:
# Fine-tune only the predictor head on top of a frozen memory rollout.
N_BINS = 896          # expected context length per sample; other lengths are skipped
CHUNK = 14            # positions fed to the predictor per optimizer step
MEMORY_LEN = 200      # dim 1 of the zero-initialised rollout state
MAX_TRAIN_ITER = 10   # last batch index used (batches 0..10, i.e. 11 batches)

data_path = "/mnt/nfs_dna/DNALM/downstream_tasks/enformer/human/h5/human_train.h5"
train_dataset = EnformerDataset(tokenizer, data_path)

print(f'len(train_dataset): {len(train_dataset)}')

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=20, num_workers=5, collate_fn=collate_fn)

mem_acc.get_module().eval()
predictor.train()

# NOTE(review): transformers.optimization.AdamW is deprecated upstream; torch.optim.AdamW
# is the replacement but has different defaults (eps, bias correction) — switch deliberately.
optimizer = AdamW([
    {"params": predictor.parameters(), "lr": 5e-5},
], weight_decay=1e-5)

# Stateless loss module: build once instead of once per step.
poisson_loss = nn.PoissonNLLLoss(log_input=False)

for it, batch in enumerate(train_dataloader):
    if it > MAX_TRAIN_ITER:
        break

    print()
    labels = batch["center"]["labels"]
    state = torch.zeros(labels.shape[0], MEMORY_LEN, model_cfg.hidden_size, device=torch.device("cuda:0"))

    # The memory rollout is not trained here, so run it without building a graph.
    with torch.no_grad():
        context_collector, last_state, _, _ = memup_iter_acc.forward(batch, state, {}, ContextCollector())
        context = torch.cat(context_collector.collection, 1)
        print("context", context.shape)
    if context.shape[1] != N_BINS:
        continue

    last_state = last_state.cuda()
    for j in range(0, N_BINS, CHUNK):
        optimizer.zero_grad()
        pred_j = predictor(context[:, j:j + CHUNK].cuda(), last_state)
        tg_j = labels[:, j:j + CHUNK].cuda()
        loss = poisson_loss(pred_j, tg_j)
        # Monitoring metric only: detach so it does not extend the autograd graph.
        # Flatten batch x position, keeping the channel (last) dim.
        n_channels = pred_j.shape[-1]
        pc = PearsonCorrLoss()(pred_j.detach().reshape(-1, n_channels), tg_j.reshape(-1, n_channels)).item()
        print(loss.item(), pc)
        loss.backward()
        optimizer.step()
        pred_acc.accumulate()



len(train_dataset): 34012

stage left
stage center
stage right
context torch.Size([20, 896, 4608])
0.6783071160316467 -0.5552586317062378
0.7053796052932739 -0.540067732334137
0.6897053718566895 -0.5507819056510925
0.6714479327201843 -0.5676361322402954
0.7063537240028381 -0.5597556829452515
0.6333743333816528 -0.5462321639060974
0.6811623573303223 -0.5505009293556213
0.6368356943130493 -0.5733550786972046
0.6317087411880493 -0.5485631227493286
0.6617066264152527 -0.5710166096687317
0.7001914978027344 -0.5541484355926514
0.5223304033279419 -0.5695292949676514
0.6826533675193787 -0.5671328902244568
0.6664441823959351 -0.5565056800842285
0.6912162899971008 -0.5757006406784058
0.6986082792282104 -0.5586433410644531
0.697401762008667 -0.5637553930282593
0.6868051886558533 -0.561008632183075
0.6852378845214844 -0.5736974477767944
0.6650139689445496 -0.5729855298995972
0.6717947721481323 -0.5745871067047119
0.6764571666717529 -0.5581704378128052
0.6846470832824707 -0.5719487071037292
0.65325

In [2]:
# Evaluate the accumulated memory model + predictor on the held-out test split.
N_BINS = 896       # expected context length per sample; other lengths are skipped
CHUNK = 14         # positions fed to the predictor per call
MEMORY_LEN = 200   # dim 1 of the zero-initialised rollout state
N_TRACKS = 5313    # output channels (last dim of the predictions printed below)

pearson_corr_coef = MeanPearsonCorrCoefPerChannel(N_TRACKS)
mem_acc.get_module().eval()
pred_acc.get_module().eval()

# Per-batch PearsonCorrLoss values. NOTE(review): in the recorded outputs these are
# negative (~-0.53) while the per-channel Pearson is positive (~0.5), so this looks
# like a negated correlation (a loss) — confirm before reporting it as a correlation.
simple_corr = []

data_path = "/mnt/nfs_dna/DNALM/downstream_tasks/enformer/human/h5/human_test.h5"
# Use test_* names so the training cell's train_dataset/train_dataloader are not clobbered.
test_dataset = EnformerDataset(tokenizer, data_path)

print(f'len(test_dataset): {len(test_dataset)}')

test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=30, num_workers=5, collate_fn=collate_fn)

for it, batch in enumerate(test_dataloader):
    print()
    labels = batch["center"]["labels"]
    state = torch.zeros(labels.shape[0], MEMORY_LEN, model_cfg.hidden_size, device=torch.device("cuda:0"))

    with torch.no_grad():
        context_collector, last_state, _, _ = memup_iter_acc.forward(batch, state, {}, ContextCollector())
        context = torch.cat(context_collector.collection, 1)
        print("context", context.shape)
        if context.shape[1] != N_BINS:
            continue

        # Loop-invariant: move the final memory state to the GPU once per batch, not per chunk.
        last_state_gpu = last_state.cuda()
        predictions = torch.cat(
            [pred_acc(context[:, j:j + CHUNK].cuda(), last_state_gpu).cpu() for j in range(0, N_BINS, CHUNK)],
            1,
        )
        print(predictions.shape)

        pearson_corr_coef.update(predictions, labels)
        p_corr = pearson_corr_coef.compute().mean().item()
        print("pearson_corr_coef", p_corr)

        # Flatten batch x position, keeping the channel (last) dim.
        n_channels = predictions.shape[-1]
        simple_corr.append(PearsonCorrLoss()(predictions.reshape(-1, n_channels), labels.reshape(-1, n_channels)).item())
        print("simple_corr_coef", sum(simple_corr) / len(simple_corr))

    



len(train_dataset): 1937

stage left




stage center
stage right
context torch.Size([30, 896, 4608])
torch.Size([30, 896, 5313])
pearson_corr_coef 0.48307672142982483
simple_corr_coef -0.5330652594566345

stage left
stage center
stage right
context torch.Size([30, 896, 4608])
torch.Size([30, 896, 5313])
pearson_corr_coef 0.49964195489883423
simple_corr_coef -0.5433982014656067

stage left
stage center
stage right
context torch.Size([30, 896, 4608])
torch.Size([30, 896, 5313])
pearson_corr_coef 0.50095534324646
simple_corr_coef -0.5415904919306437

stage left
stage center
stage right
context torch.Size([30, 896, 4608])
torch.Size([30, 896, 5313])
pearson_corr_coef 0.5036499500274658
simple_corr_coef -0.5358397215604782

stage left
stage center
stage right
context torch.Size([30, 896, 4608])
torch.Size([30, 896, 5313])
pearson_corr_coef 0.5023065209388733
simple_corr_coef -0.5389904141426086

stage left
stage center
stage right
context torch.Size([30, 896, 4608])
torch.Size([30, 896, 5313])
pearson_corr_coef 0.4914115071296692

KeyboardInterrupt: 