In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
# --- Load the evaluation split: regression targets + FASTA sequences ---------
dest_path = '/liuzicheng/ljh/hyena-dna/data/drosophila_enhancer_activity'
split = 'test'
target_columns = ["Dev_log2_enrichment", "Hk_log2_enrichment"]

# Tab-separated activity table; only the two target columns are kept.
targets_file = os.path.join(dest_path, 'Sequences_activity_' + split + ".txt")
data = pd.read_csv(targets_file, sep="\t")[target_columns]

fasta_file = os.path.join(dest_path, 'Sequences_' + split + ".fa")
all_seqs = []

with open(fasta_file) as fin:
    # `current_chunks` is None until the first '>' header has been seen; after
    # that it collects the (possibly line-wrapped) sequence of the open record.
    current_chunks = None
    for line in fin:
        stripped = line.strip()
        if not stripped:
            # Skip blank lines instead of stopping: a stray empty line must not
            # silently drop every record that follows it.
            continue
        if stripped.startswith(">"):
            if current_chunks:
                all_seqs.append("".join(current_chunks))
            current_chunks = []
            continue
        # Sequence data before any header means the file is not valid FASTA.
        assert current_chunks is not None, "FASTA sequence data found before any '>' header"
        current_chunks.append(stripped)
    # Flush the last record (there is no trailing header to trigger it).
    if current_chunks:
        all_seqs.append("".join(current_chunks))

# One [dev, hk] float32 pair per row, in file order.
all_labels = data.values.astype("float32").tolist()

assert len(all_seqs) == len(
    all_labels
), "Number of targets does not match number of sequences"
batch_size = 128




def restrict(x):
    """Pool ``x`` over its second-to-last dimension by averaging.

    The running mean along ``dim=-2`` is formed (cumulative sum divided by the
    1-based position index) and only its final position is returned, i.e. the
    mean over all positions of that dimension. The pooled dimension is kept
    with size 1, so an input of shape ``(..., L, D)`` yields ``(..., 1, D)``.
    """
    steps = x.size(-2)
    # 1, 2, ..., L as a column so it broadcasts across the last dimension.
    counts = torch.arange(1, steps + 1, device=x.device, dtype=x.dtype).unsqueeze(-1)
    running_mean = torch.cumsum(x, dim=-2) / counts
    return running_mean[..., -1:, :]
from scipy.stats import pearsonr
from scipy import stats
def pearsonr_1(outs, y, len_batch=None):
    """Per-target Pearson correlation between predictions and targets.

    Args:
        outs: tensor whose columns 0 and 1 hold the 'dev' and 'hk' predictions.
        y: tensor with the matching targets in columns 0 and 1.
        len_batch: accepted for signature compatibility; not used.

    Returns:
        dict with ``pearsonr_<label>`` and ``pearsonr2_<label>`` (r squared)
        for each of 'dev' and 'hk', plus ``pearsonr``, the mean of the two r
        values.
    """
    # TODO: generalize, currently for Monash dataset
    predictions = outs.detach()
    metrics = {}
    for column, label in enumerate(['dev', 'hk']):
        targets_np = y[:, column].cpu().numpy()
        preds_np = predictions[:, column].cpu().numpy()
        corr = stats.pearsonr(targets_np, preds_np)[0]
        metrics[f'pearsonr_{label}'] = corr
        metrics[f'pearsonr2_{label}'] = corr ** 2
    metrics['pearsonr'] = (metrics['pearsonr_dev'] + metrics['pearsonr_hk']) / 2
    return metrics
import torch.nn.functional as F
def mse(outs, y, len_batch=None):
    """Mean squared error between ``outs`` and ``y``.

    Without ``len_batch``: positions where ``y`` is NaN are excluded from the
    mean. As a diagnostic, the indices of NaN entries in ``outs`` and in ``y``
    are printed when any are present.

    With ``len_batch``: only the first ``len_batch[i]`` positions (along dim 1)
    of batch item ``i`` contribute; the masking indexes three dimensions.

    Returns:
        A scalar tensor with the loss.
    """
    if len_batch is not None:
        # Computes the loss of the first `lens` items in the batches
        # TODO document the use case of this
        keep = torch.zeros_like(outs, dtype=torch.bool)
        for row, length in enumerate(len_batch):
            keep[row, :length, :] = 1
        return F.mse_loss(torch.masked_select(outs, keep), torch.masked_select(y, keep))

    target_is_nan = torch.isnan(y)
    observed = ~target_is_nan
    loss = ((outs[observed] - y[observed]) ** 2).mean()

    # Diagnostics: report where NaNs occur in the predictions, then the targets.
    pred_is_nan = torch.isnan(outs)
    if pred_is_nan.any():
        nan_indices_outs = torch.nonzero(pred_is_nan, as_tuple=False)
        print(nan_indices_outs)

    if target_is_nan.any():
        nan_indices = torch.nonzero(target_is_nan, as_tuple=False)
        print(nan_indices)
    return loss

In [10]:
from transformers import AutoTokenizer,AutoModel
import torch
from torch import nn

# Token length every sequence is padded/truncated to before the forward pass.
max_length=128

with torch.no_grad():
    # NOTE: despite its name, `state_dict` holds the pretrained-model directory.
    state_dict='/liuzicheng/ljh/hyena-dna/weight/hyenadna/hyenadna-large-1m-seqlen'
    hyena_tokenizer=AutoTokenizer.from_pretrained(state_dict, trust_remote_code=True)
    hyena_model=AutoModel.from_pretrained(state_dict, trust_remote_code=True).to('cuda')
    full_sequence=[]
    checkpoint=torch.load('/liuzicheng/ljh/hyena-dna/outputs/2024-05-04/13-01-30-043019/checkpoints/val/pearsonr.ckpt')['state_dict']
    # Strip the training-time prefixes in place so the keys can match the bare
    # backbone and the bare nn.Linear decoder respectively.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "model.backbone."
        )
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "decoder.0.output_transform."
        )

    hyena_decoder = nn.Linear(256,2).to('cuda')

    # strict=False tolerates the checkpoint keys that do not belong to the
    # module being loaded; it would also hide a decoder that received nothing,
    # so that case is checked explicitly.
    hyena_model.load_state_dict(checkpoint,strict=False)
    decoder_load=hyena_decoder.load_state_dict(checkpoint,strict=False)
    if decoder_load.missing_keys:
        raise RuntimeError(f"decoder weights not found in checkpoint: {decoder_load.missing_keys}")
    hyena_model.eval()
    hyena_decoder.eval()

    target_list=[]
    seq_list=[]
    for i in range(len(all_seqs)):
        sequence_encoded=hyena_tokenizer(all_seqs[i],
                            add_special_tokens= False,  # this is what controls adding eos
                            padding="max_length",
                            max_length=max_length,
                            truncation=True,
                        )
        seq_ids=sequence_encoded['input_ids']
        seq_ids = torch.LongTensor(seq_ids)
        target = all_labels[i]

        seqs=torch.reshape(seq_ids,(1,max_length)).to('cuda')
        target_list.append(target)
        hidden_states=hyena_model(input_ids=seqs).last_hidden_state
        hidden_states=restrict(hidden_states)  # mean over dim -2, kept with size 1
        out1=hyena_decoder(hidden_states)
        out1_hyena=out1.squeeze(1).squeeze(0).cpu().detach().numpy()

        seq_list.append(out1_hyena)
        # Running metric every 1000 samples. The tensors are built only here,
        # not on every iteration, which avoids quadratic work over the dataset.
        if i>=1 and i%1000==0:
            print(i)
            seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
            target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
            # NOTE: this rebinds `pearsonr`, which the first cell imported from scipy.stats.
            pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
            print(pearsonr)

    # Final tensors and metric over the complete set, so the samples after the
    # last multiple of 1000 are included.
    seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
    target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
    if seq_list:
        pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
        print(pearsonr)
            

#plot the bar plot of the pearsonr


1000
{'pearsonr_dev': 0.28364549906781383, 'pearsonr2_dev': 0.08045476914142918, 'pearsonr_hk': 0.7182049665806906, 'pearsonr2_hk': 0.5158183740211708, 'pearsonr': 0.5009252328242522}
2000
{'pearsonr_dev': 0.2686941683903085, 'pearsonr2_dev': 0.07219655612695947, 'pearsonr_hk': 0.6727362950160191, 'pearsonr2_hk': 0.45257412263188024, 'pearsonr': 0.47071523170316376}
3000
{'pearsonr_dev': 0.44836453244429925, 'pearsonr2_dev': 0.20103075395399508, 'pearsonr_hk': 0.6601745122358464, 'pearsonr2_hk': 0.4358303866058378, 'pearsonr': 0.5542695223400729}
4000
{'pearsonr_dev': 0.4860841635844665, 'pearsonr2_dev': 0.23627781408761037, 'pearsonr_hk': 0.6544510277299389, 'pearsonr2_hk': 0.4283061476967733, 'pearsonr': 0.5702675956572028}
5000
{'pearsonr_dev': 0.4929511378143658, 'pearsonr2_dev': 0.24300082427247785, 'pearsonr_hk': 0.6431594390840369, 'pearsonr2_hk': 0.413654064082893, 'pearsonr': 0.5680552884492014}
6000
{'pearsonr_dev': 0.48780199029144755, 'pearsonr2_dev': 0.2379507817322975, 'p

In [4]:
from transformers import AutoTokenizer,AutoModel
import torch
from torch import nn

def group_by_kmer(seq: str, kmer: int) -> str:
    """Split ``seq`` into consecutive, non-overlapping chunks of ``kmer`` characters.

    The chunks are joined with single spaces and upper-cased; the last chunk
    may be shorter than ``kmer``.
    """
    chunks = [seq[start:start + kmer] for start in range(0, len(seq), kmer)]
    return " ".join(chunks).upper()
# Token length every sequence is padded/truncated to before the forward pass.
max_length=128
with torch.no_grad():
    # NOTE: despite its name, `state_dict` holds the pretrained-model directory.
    state_dict='/liuzicheng/ljh/hyena-dna/weight/dnabert/dnabert3/3-new-12w-0'
    bert_tokenizer=AutoTokenizer.from_pretrained(state_dict, trust_remote_code=True)
    bert_model=AutoModel.from_pretrained(state_dict, trust_remote_code=True).to('cuda')
    full_sequence=[]
    checkpoint=torch.load('/liuzicheng/ljh/hyena-dna/outputs/2024-03-29/04-45-15-121292/checkpoints/val/pearsonr.ckpt')['state_dict']
    # Strip the training-time prefixes in place so the keys can match the bare
    # backbone and the bare nn.Linear decoder respectively.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "model.backbone."
        )
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "decoder.0.output_transform."
        )

    bert_decoder = nn.Linear(768,2).to('cuda')

    # strict=False tolerates the checkpoint keys that do not belong to the
    # module being loaded; it would also hide a decoder that received nothing,
    # so that case is checked explicitly.
    bert_model.load_state_dict(checkpoint,strict=False)
    decoder_load=bert_decoder.load_state_dict(checkpoint,strict=False)
    if decoder_load.missing_keys:
        raise RuntimeError(f"decoder weights not found in checkpoint: {decoder_load.missing_keys}")
    bert_model.eval()
    bert_decoder.eval()

    target_list=[]
    seq_list=[]
    for i in range(len(all_seqs)):
        # Space-separated, non-overlapping 3-mers are what gets tokenized here.
        all_seqs_group = group_by_kmer(all_seqs[i],kmer=3)
        sequence_encoded=bert_tokenizer(all_seqs_group,
                            add_special_tokens= False,  # this is what controls adding eos
                            padding="max_length",
                            max_length=max_length,
                            truncation=True,
                        )
        seq_ids=sequence_encoded['input_ids']
        seq_ids = torch.LongTensor(seq_ids)
        target = all_labels[i]

        seqs=torch.reshape(seq_ids,(1,max_length)).to('cuda')
        target_list.append(target)
        hidden_states=bert_model(input_ids=seqs).last_hidden_state
        hidden_states=restrict(hidden_states)  # mean over dim -2, kept with size 1
        out1=bert_decoder(hidden_states)
        out1_bert=out1.squeeze(1).squeeze(0).cpu().detach().numpy()

        seq_list.append(out1_bert)
        # Running metric every 1000 samples. The tensors are built only here,
        # not on every iteration, which avoids quadratic work over the dataset.
        if i>=1 and i%1000==0:
            print(i)
            seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
            target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
            # NOTE: this rebinds `pearsonr`, which the first cell imported from scipy.stats.
            pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
            print(pearsonr)

    # Final tensors and metric over the complete set, so the samples after the
    # last multiple of 1000 are included.
    seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
    target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
    if seq_list:
        pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
        print(pearsonr)
            

#plot the bar plot of the pearsonr


1000
{'pearsonr_dev': 0.41906931003452125, 'pearsonr2_dev': 0.1756190866128097, 'pearsonr_hk': 0.8753353633535679, 'pearsonr2_hk': 0.7662119983373227, 'pearsonr': 0.6472023366940446}
2000
{'pearsonr_dev': 0.41060924408617483, 'pearsonr2_dev': 0.1685999513290199, 'pearsonr_hk': 0.8400015690857895, 'pearsonr2_hk': 0.7056026360665884, 'pearsonr': 0.6253054065859822}
3000
{'pearsonr_dev': 0.6120348823835176, 'pearsonr2_dev': 0.3745866972542062, 'pearsonr_hk': 0.8317270664662886, 'pearsonr2_hk': 0.691769913092618, 'pearsonr': 0.7218809744249031}
4000
{'pearsonr_dev': 0.6464362819337207, 'pearsonr2_dev': 0.41787986660029286, 'pearsonr_hk': 0.8259247689789676, 'pearsonr2_hk': 0.682151724012961, 'pearsonr': 0.7361805254563442}
5000
{'pearsonr_dev': 0.6581854274833475, 'pearsonr2_dev': 0.43320805695143694, 'pearsonr_hk': 0.8189810741750629, 'pearsonr2_hk': 0.6707299998569399, 'pearsonr': 0.7385832508292052}
6000
{'pearsonr_dev': 0.6603334668402109, 'pearsonr2_dev': 0.43604028742921186, 'pearson

In [5]:
from transformers import AutoTokenizer,AutoModel
import torch
from torch import nn

def group_by_kmer(seq: str, kmer: int) -> str:
    """Split ``seq`` into consecutive, non-overlapping chunks of ``kmer`` characters.

    The chunks are joined with single spaces and upper-cased; the last chunk
    may be shorter than ``kmer``.
    """
    chunks = [seq[start:start + kmer] for start in range(0, len(seq), kmer)]
    return " ".join(chunks).upper()
# Token length every sequence is padded/truncated to before the forward pass.
max_length=128
with torch.no_grad():
    # NOTE: despite its name, `state_dict` holds the pretrained-model directory.
    state_dict='/liuzicheng/ljh/hyena-dna/weight/dnabert2'
    bert2_tokenizer=AutoTokenizer.from_pretrained(state_dict, trust_remote_code=True)
    bert2_model=AutoModel.from_pretrained(state_dict, trust_remote_code=True).to('cuda')
    full_sequence=[]
    checkpoint=torch.load('/liuzicheng/ljh/hyena-dna/outputs/2024-03-29/08-13-54-405523/checkpoints/val/pearsonr.ckpt')['state_dict']
    # Strip the training-time prefixes in place so the keys can match the bare
    # backbone and the bare nn.Linear decoder respectively.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "model.backbone."
        )
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "decoder.0.output_transform."
        )

    bert2_decoder = nn.Linear(768,2).to('cuda')

    # strict=False tolerates the checkpoint keys that do not belong to the
    # module being loaded; it would also hide a decoder that received nothing,
    # so that case is checked explicitly.
    bert2_model.load_state_dict(checkpoint,strict=False)
    decoder_load=bert2_decoder.load_state_dict(checkpoint,strict=False)
    if decoder_load.missing_keys:
        raise RuntimeError(f"decoder weights not found in checkpoint: {decoder_load.missing_keys}")
    bert2_model.eval()
    bert2_decoder.eval()

    target_list=[]
    seq_list=[]
    for i in range(len(all_seqs)):

        sequence_encoded=bert2_tokenizer(all_seqs[i],
                            add_special_tokens= False,  # this is what controls adding eos
                            padding="max_length",
                            max_length=max_length,
                            truncation=True,
                        )
        seq_ids=sequence_encoded['input_ids']
        seq_ids = torch.LongTensor(seq_ids)
        target = all_labels[i]

        seqs=torch.reshape(seq_ids,(1,max_length)).to('cuda')
        target_list.append(target)
        # NOTE(review): `export_hidden_states` looks like a typo for
        # `output_hidden_states`; kept as-is because the model's remote code
        # decides how the keyword is handled - confirm against that code.
        hidden_states=bert2_model(input_ids=seqs,export_hidden_states=True)[0]
        hidden_states=restrict(hidden_states)  # mean over dim -2, kept with size 1
        out1=bert2_decoder(hidden_states)
        out1_bert2=out1.squeeze(1).squeeze(0).cpu().detach().numpy()

        seq_list.append(out1_bert2)
        # Running metric every 1000 samples. The tensors are built only here,
        # not on every iteration, which avoids quadratic work over the dataset.
        if i>=1 and i%1000==0:
            print(i)
            seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
            target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
            # NOTE: this rebinds `pearsonr`, which the first cell imported from scipy.stats.
            pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
            print(pearsonr)

    # Final tensors and metric over the complete set, so the samples after the
    # last multiple of 1000 are included.
    seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
    target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
    if seq_list:
        pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
        print(pearsonr)
            

#plot the bar plot of the pearsonr


Some weights of BertModel were not initialized from the model checkpoint at /liuzicheng/ljh/hyena-dna/weight/dnabert2 and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


1000
{'pearsonr_dev': 0.40376231061410334, 'pearsonr2_dev': 0.16302400347243967, 'pearsonr_hk': 0.8870294712558203, 'pearsonr2_hk': 0.7868212828763802, 'pearsonr': 0.6453958909349619}
2000
{'pearsonr_dev': 0.38980792075562903, 'pearsonr2_dev': 0.15195021508382675, 'pearsonr_hk': 0.8537560044292595, 'pearsonr2_hk': 0.7288993150990137, 'pearsonr': 0.6217819625924442}
3000
{'pearsonr_dev': 0.6066960771598909, 'pearsonr2_dev': 0.3680801300412003, 'pearsonr_hk': 0.8385283466776545, 'pearsonr2_hk': 0.7031297881819608, 'pearsonr': 0.7226122119187728}
4000
{'pearsonr_dev': 0.6424585324177863, 'pearsonr2_dev': 0.41275296587641574, 'pearsonr_hk': 0.8311816574792742, 'pearsonr2_hk': 0.6908629477299935, 'pearsonr': 0.7368200949485302}
5000
{'pearsonr_dev': 0.657806057359012, 'pearsonr2_dev': 0.43270880909820775, 'pearsonr_hk': 0.8275756276598912, 'pearsonr2_hk': 0.6848814194966629, 'pearsonr': 0.7426908425094516}
6000
{'pearsonr_dev': 0.6589346042820474, 'pearsonr2_dev': 0.43419481272033844, 'pear

In [7]:
from transformers import AutoTokenizer,AutoModel
import torch
from torch import nn

def group_by_kmer(seq: str, kmer: int) -> str:
    """Split ``seq`` into consecutive, non-overlapping chunks of ``kmer`` characters.

    The chunks are joined with single spaces and upper-cased; the last chunk
    may be shorter than ``kmer``.
    """
    chunks = [seq[start:start + kmer] for start in range(0, len(seq), kmer)]
    return " ".join(chunks).upper()
# Token length every sequence is padded/truncated to before the forward pass.
max_length=128
with torch.no_grad():
    # NOTE: despite its name, `state_dict` holds the pretrained-model directory.
    state_dict='/liuzicheng/ljh/hyena-dna/weight/genalm/gena-lm-bert-large-t2t'
    genalm_tokenizer=AutoTokenizer.from_pretrained(state_dict, trust_remote_code=True)
    genalm_model=AutoModel.from_pretrained(state_dict, trust_remote_code=True).to('cuda')
    full_sequence=[]
    checkpoint=torch.load('/liuzicheng/ljh/hyena-dna/outputs/2024-03-30/01-49-47-161996/checkpoints/val/pearsonr.ckpt')['state_dict']
    # Strip the training-time prefixes in place so the keys can match the bare
    # backbone and the bare nn.Linear decoder respectively.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "model.backbone."
        )
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "decoder.0.output_transform."
        )

    genalm_decoder = nn.Linear(1024,2).to('cuda')

    # strict=False tolerates the checkpoint keys that do not belong to the
    # module being loaded; it would also hide a decoder that received nothing,
    # so that case is checked explicitly.
    genalm_model.load_state_dict(checkpoint,strict=False)
    decoder_load=genalm_decoder.load_state_dict(checkpoint,strict=False)
    if decoder_load.missing_keys:
        raise RuntimeError(f"decoder weights not found in checkpoint: {decoder_load.missing_keys}")
    genalm_model.eval()
    genalm_decoder.eval()

    target_list=[]
    seq_list=[]
    for i in range(len(all_seqs)):

        sequence_encoded=genalm_tokenizer(all_seqs[i],
                            add_special_tokens= False,  # this is what controls adding eos
                            padding="max_length",
                            max_length=max_length,
                            truncation=True,
                        )
        seq_ids=sequence_encoded['input_ids']
        seq_ids = torch.LongTensor(seq_ids)
        target = all_labels[i]

        seqs=torch.reshape(seq_ids,(1,max_length)).to('cuda')
        target_list.append(target)
        hidden_states=genalm_model(input_ids=seqs, output_hidden_states=True,).hidden_states[-1]
        hidden_states=restrict(hidden_states)  # mean over dim -2, kept with size 1
        out1=genalm_decoder(hidden_states)
        out1_genalm=out1.squeeze(1).squeeze(0).cpu().detach().numpy()

        seq_list.append(out1_genalm)
        # Progress counter only; the metric is computed once after the loop.
        if i>=1 and i%1000==0:
            print(i)

    # Build the result tensors once, after the loop, instead of rebuilding them
    # from the growing lists on every iteration (quadratic work).
    seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
    target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
    if seq_list:
        # NOTE: this rebinds `pearsonr`, which the first cell imported from scipy.stats.
        pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
        print(pearsonr)
            

#plot the bar plot of the pearsonr




1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
{'pearsonr_dev': 0.6241779409090845, 'pearsonr2_dev': 0.3895981019175046, 'pearsonr_hk': 0.7401273713593093, 'pearsonr2_hk': 0.547788525835241, 'pearsonr': 0.6821526561341968}


In [8]:
from transformers import AutoTokenizer,AutoModel,AutoModelForMaskedLM
import torch
from torch import nn

def group_by_kmer(seq: str, kmer: int) -> str:
    """Split ``seq`` into consecutive, non-overlapping chunks of ``kmer`` characters.

    The chunks are joined with single spaces and upper-cased; the last chunk
    may be shorter than ``kmer``.
    """
    chunks = [seq[start:start + kmer] for start in range(0, len(seq), kmer)]
    return " ".join(chunks).upper()
# Token length every sequence is padded/truncated to before the forward pass.
max_length=128
with torch.no_grad():
    # NOTE: despite its name, `state_dict` holds the pretrained-model directory.
    state_dict='/liuzicheng/ljh/hyena-dna/weight/nt/nucleotide-transformer-v2-500m-multi-species'
    nt_tokenizer=AutoTokenizer.from_pretrained(state_dict, trust_remote_code=True)
    nt_model=AutoModelForMaskedLM.from_pretrained(state_dict, trust_remote_code=True).to('cuda')
    full_sequence=[]
    checkpoint=torch.load('/liuzicheng/ljh/hyena-dna/outputs/2024-03-29/10-42-51-953352/checkpoints/val/pearsonr.ckpt')['state_dict']
    # Strip the training-time prefixes in place so the keys can match the bare
    # backbone and the bare nn.Linear decoder respectively.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "model.backbone."
        )
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "decoder.0.output_transform."
        )

    nt_decoder = nn.Linear(1024,2).to('cuda')

    # strict=False tolerates the checkpoint keys that do not belong to the
    # module being loaded; it would also hide a decoder that received nothing,
    # so that case is checked explicitly.
    nt_model.load_state_dict(checkpoint,strict=False)
    decoder_load=nt_decoder.load_state_dict(checkpoint,strict=False)
    if decoder_load.missing_keys:
        raise RuntimeError(f"decoder weights not found in checkpoint: {decoder_load.missing_keys}")
    nt_model.eval()
    nt_decoder.eval()

    target_list=[]
    seq_list=[]
    for i in range(len(all_seqs)):

        sequence_encoded=nt_tokenizer(all_seqs[i],
                            add_special_tokens= False,  # this is what controls adding eos
                            padding="max_length",
                            max_length=max_length,
                            truncation=True,
                        )
        seq_ids=sequence_encoded['input_ids']
        seq_ids = torch.LongTensor(seq_ids)
        target = all_labels[i]

        seqs=torch.reshape(seq_ids,(1,max_length)).to('cuda')
        target_list.append(target)
        hidden_states=nt_model(input_ids=seqs,output_hidden_states=True)['hidden_states'][-1]
        hidden_states=restrict(hidden_states)  # mean over dim -2, kept with size 1
        out1=nt_decoder(hidden_states)
        out1_nt=out1.squeeze(1).squeeze(0).cpu().detach().numpy()

        seq_list.append(out1_nt)
        # Progress counter only; the metric is computed once after the loop.
        if i>=1 and i%1000==0:
            print(i)

    # Build the result tensors once, after the loop, instead of rebuilding them
    # from the growing lists on every iteration (quadratic work).
    seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
    target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
    if seq_list:
        # NOTE: this rebinds `pearsonr`, which the first cell imported from scipy.stats.
        pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
        print(pearsonr)
            

#plot the bar plot of the pearsonr


1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
{'pearsonr_dev': 0.6122770428054839, 'pearsonr2_dev': 0.37488317714662833, 'pearsonr_hk': 0.7367069263616648, 'pearsonr2_hk': 0.5427370953492514, 'pearsonr': 0.6744919845835744}


In [9]:
from transformers import AutoTokenizer,AutoModel,AutoModelForMaskedLM
import torch
from torch import nn

def group_by_kmer(seq: str, kmer: int) -> str:
    """Split ``seq`` into consecutive, non-overlapping chunks of ``kmer`` characters.

    The chunks are joined with single spaces and upper-cased; the last chunk
    may be shorter than ``kmer``.
    """
    chunks = [seq[start:start + kmer] for start in range(0, len(seq), kmer)]
    return " ".join(chunks).upper()
# Token length every sequence is padded/truncated to before the forward pass.
max_length=128
with torch.no_grad():
    # NOTE: despite its name, `state_dict` holds the pretrained-model directory.
    state_dict='/liuzicheng/ljh/hyena-dna/weight/mamba/caduceus-ph_seqlen-131k_d_model-256_n_layer-16'
    mamba_tokenizer=AutoTokenizer.from_pretrained(state_dict, trust_remote_code=True)
    mamba_model=AutoModel.from_pretrained(state_dict, trust_remote_code=True).to('cuda')
    full_sequence=[]
    checkpoint=torch.load('/liuzicheng/ljh/hyena-dna/outputs/2024-04-19/14-39-48-177210/checkpoints/val/pearsonr.ckpt')['state_dict']
    # Strip the training-time prefixes in place so the keys can match the bare
    # backbone and the bare nn.Linear decoder respectively.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "model.backbone."
        )
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "decoder.0.output_transform."
        )

    mamba_decoder = nn.Linear(256,2).to('cuda')

    # strict=False tolerates the checkpoint keys that do not belong to the
    # module being loaded; it would also hide a decoder that received nothing,
    # so that case is checked explicitly.
    mamba_model.load_state_dict(checkpoint,strict=False)
    decoder_load=mamba_decoder.load_state_dict(checkpoint,strict=False)
    if decoder_load.missing_keys:
        raise RuntimeError(f"decoder weights not found in checkpoint: {decoder_load.missing_keys}")
    mamba_model.eval()
    mamba_decoder.eval()

    target_list=[]
    seq_list=[]
    for i in range(len(all_seqs)):

        sequence_encoded=mamba_tokenizer(all_seqs[i],
                            add_special_tokens= False,  # this is what controls adding eos
                            padding="max_length",
                            max_length=max_length,
                            truncation=True,
                        )
        seq_ids=sequence_encoded['input_ids']
        seq_ids = torch.LongTensor(seq_ids)
        target = all_labels[i]

        seqs=torch.reshape(seq_ids,(1,max_length)).to('cuda')
        target_list.append(target)
        hidden_states=mamba_model(seqs,output_hidden_states=True).last_hidden_state
        hidden_states=restrict(hidden_states)  # mean over dim -2, kept with size 1
        out1=mamba_decoder(hidden_states)
        out1_mamba=out1.squeeze(1).squeeze(0).cpu().detach().numpy()

        seq_list.append(out1_mamba)
        # Progress counter only; the metric is computed once after the loop.
        if i>=1 and i%1000==0:
            print(i)

    # Build the result tensors once, after the loop, instead of rebuilding them
    # from the growing lists on every iteration (quadratic work).
    seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
    target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
    if seq_list:
        # NOTE: this rebinds `pearsonr`, which the first cell imported from scipy.stats.
        pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
        print(pearsonr)
            

#plot the bar plot of the pearsonr


1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
{'pearsonr_dev': 0.44316330657723835, 'pearsonr2_dev': 0.19639371629647134, 'pearsonr_hk': 0.5303301817137167, 'pearsonr2_hk': 0.28125010163650377, 'pearsonr': 0.48674674414547753}


In [11]:
from transformers import AutoTokenizer,AutoModel
import torch
from torch import nn
import sys
sys.path.append('/liuzicheng/ljh/hyena-dna/')
from src.models.sequence.deepSTAR import DeepSTAR
# Number of leading positions kept from each one-hot encoded sequence in the
# inference loop of this cell (rows beyond it are sliced off).
max_length=128
def genomic_to_one_hot(genomic_sequence):
    """One-hot encode a nucleotide string into a ``(len(sequence), 4)`` float array.

    Columns are ordered A, C, G, T. Matching is case-sensitive: any character
    other than upper-case A/C/G/T (e.g. 'N' or lower-case bases) gets a uniform
    row of 0.25 in every column.
    """
    base_to_column = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    # Start every row as the "unknown base" encoding (equal weight per column);
    # an all-zero row would be the alternative convention.
    encoded = np.full((len(genomic_sequence), 4), 0.25)
    for row, base in enumerate(genomic_sequence):
        column = base_to_column.get(base)
        if column is not None:
            encoded[row, :] = 0.0
            encoded[row, column] = 1.0
    return encoded
with torch.no_grad():
    DeepSTAR_model=DeepSTAR(input_size=128,output_size=2)
    full_sequence=[]
    checkpoint=torch.load('/liuzicheng/ljh/hyena-dna/outputs/2024-05-13/09-21-43-779050/checkpoints/val/pearsonr.ckpt')['state_dict']
    # Strip the training-time prefix in place so the keys can match the bare model.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            checkpoint, "model.backbone."
        )

    # strict=False: keys that do not match the model are ignored silently.
    DeepSTAR_model.load_state_dict(checkpoint,strict=False)
    DeepSTAR_model.to('cuda')
    DeepSTAR_model.eval()

    target_list=[]
    seq_list=[]
    for i in range(len(all_seqs)):
        seq=genomic_to_one_hot(all_seqs[i])
        # Keep only the first `max_length` positions of the one-hot matrix.
        seq_ids = torch.from_numpy(seq).float()[:max_length,:]
        if seq_ids.shape[0] != max_length:
            # A reshape to (1, max_length, -1) would either fail or silently
            # scramble positions and channels for a short sequence.
            raise ValueError(
                f"sequence {i} has {seq_ids.shape[0]} positions; expected at least {max_length}"
            )
        target = all_labels[i]

        # Add the batch dimension -> (1, max_length, 4).
        seqs=seq_ids.unsqueeze(0).to('cuda')
        target_list.append(target)
        hidden_states=DeepSTAR_model(seqs)
        out1_deepstar=hidden_states.squeeze(1).squeeze(0).cpu().detach().numpy()

        seq_list.append(out1_deepstar)
        # Running metric every 1000 samples. The tensors are built only here,
        # not on every iteration, which avoids quadratic work over the dataset.
        if i>=1 and i%1000==0:
            print(i)
            seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
            target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
            # NOTE: this rebinds `pearsonr`, which the first cell imported from scipy.stats.
            pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
            print(pearsonr)

    # Final tensors and metric over the complete set, so the samples after the
    # last multiple of 1000 are included.
    seq_list_tensor=torch.as_tensor(np.asarray(seq_list, dtype=np.float32))
    target_list_tensor=torch.as_tensor(np.asarray(target_list, dtype=np.float32))
    if seq_list:
        pearsonr=pearsonr_1(seq_list_tensor,target_list_tensor)
        print(pearsonr)
            

#plot the bar plot of the pearsonr


1000
{'pearsonr_dev': 0.24202421722145104, 'pearsonr2_dev': 0.058575721721656114, 'pearsonr_hk': 0.6362079135623132, 'pearsonr2_hk': 0.4047605092793118, 'pearsonr': 0.43911606539188214}
2000
{'pearsonr_dev': 0.19344125536819476, 'pearsonr2_dev': 0.03741951927842314, 'pearsonr_hk': 0.6061167038653557, 'pearsonr2_hk': 0.36737745870460325, 'pearsonr': 0.3997789796167752}
3000
{'pearsonr_dev': 0.3831511174324757, 'pearsonr2_dev': 0.1468047787897548, 'pearsonr_hk': 0.597412798440215, 'pearsonr2_hk': 0.3569020517401689, 'pearsonr': 0.4902819579363453}
4000
{'pearsonr_dev': 0.42583537391442033, 'pearsonr2_dev': 0.18133576567683418, 'pearsonr_hk': 0.5899780072093755, 'pearsonr2_hk': 0.34807404899074595, 'pearsonr': 0.507906690561898}
5000
{'pearsonr_dev': 0.43487575685905355, 'pearsonr2_dev': 0.18911692390373466, 'pearsonr_hk': 0.5858931499149908, 'pearsonr2_hk': 0.34327078311730985, 'pearsonr': 0.5103844533870221}
6000
{'pearsonr_dev': 0.43364567462259984, 'pearsonr2_dev': 0.18804857111888973