In [None]:
import torch
from colonmodel import jambaregression
from colontokenrizer import colonTokenizer

# 1. Load the tokenizer. The length limit is the same value (300) that is
#    passed to the model as max_seq_len further down in this script.
max_token_length = 300
tokenizer = colonTokenizer(model_max_length=max_token_length)

# 2. Example inputs: a list of raw nucleotide sequences plus one expression
#    value and one cell-type code per sequence.
seqs = [
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGACCACTGGACAGTTATTCGTGTCTCTTACAATTACCAAACAGA",  # WT
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACTTTTGAGACCACAAACTTACCGCGCTGTTCTTTGGGAGGGCCAAGCACCGCTGACGTGTCCCTTGGACAGTTACTCGCGTCTCTTACCATTACCAAACAGC",  # mut1 (labelled "mu1" in the original notes)
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGACCATTGGACAGTTACTCGCGTCTCTTGCAATTACCAAACAGC",  # mut2
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGACCACTGGACAGTTACTCGCGTCTCTTACAATTACCAGGCAGC",  # mut3
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGACCATTGGGCAGTTATTCGTGTCTCTTACCACTACCAGACAGA",  # mut4
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAAGCACCGCTGACGTGTCCCTTGGACAGTTACTCGCGTCTCTTACCATTACCAAACAGC",  # BLNK-5'UTR-SNP10-3
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACCTCTGACGTGTTCCTTGGACAGTTACTCGCGTCTCTTGCAATTACCAAACAGC",  # BLNK-5'UTR-SNP10-4
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACCGCTGACGTGTCCCTTGGACAGTTACTCGCGTCTCTTACCATTACCAAACAGC",  # BLNK-5'UTR-SNP8-5
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGATCCTTGGACAGTTACTCGCGTCTCTTACAATTACCAGGCAGC",  # BLNK-5'UTR-SNP8-6
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGTTCCTTGGACAGTTACTCGCGTCTCTTGCAATTACCAAACAGC",  # BLNK-5'UTR-SNP8-7
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGACCATTGGACAGTTACTCGCGTCTCTTACCATTACCAAACAGC",  # BLNK-5'UTR-SNP8-10
    "ACTTCTCCCTAGAGCAGGGGTGTTTGCCAGCAGCCTGCACTCTCAGAAATCAGACTTGAGTGGCCGGAACCCTTGAGACCAGAGGCTTACCATGCTGCTCCCTAGGAGGGCCAGGAACTGCTGACGTGACCACTGGGCAGTTATTCGTGTCTCTTACCACTACCAGGCAGA",  # BLNK-5'UTR-SNP8-11
]

# Both per-sample lists are derived from len(seqs) so they always stay in step
# with the sequence list when entries are added or removed.
tpm = [1] * len(seqs)  # example RNA expression values, one per sequence


cell_type = [10] * len(seqs)  # cell-type code per sequence (original note: cell 293T)

# Tokenize every sequence and assemble the batch tensors that are fed to the model.
# Fail early with a clear message when the per-sample inputs are empty or out of
# step, instead of letting the problem surface later during stacking or inference.
if not seqs:
    raise ValueError("seqs is empty; nothing to run inference on")
if not (len(seqs) == len(tpm) == len(cell_type)):
    raise ValueError(
        "seqs, tpm and cell_type must have the same length, got "
        f"{len(seqs)}, {len(tpm)} and {len(cell_type)}"
    )

encodings = [tokenizer(seq) for seq in seqs]

device = "cuda:0"

# torch.stack requires every per-sequence tensor to have the same length.
# NOTE(review): presumably colonTokenizer pads/truncates to model_max_length
# by default (no padding argument is passed here) — confirm.
input_ids = torch.stack(
    [torch.tensor(enc["input_ids"], dtype=torch.long) for enc in encodings]
).to(device)  # [B, L]
attention_mask = torch.stack(
    [torch.tensor(enc["attention_mask"], dtype=torch.bool) for enc in encodings]
).to(device)
tpm_tensor = torch.tensor(tpm, dtype=torch.float32, device=device)            # [B]
cell_type_tensor = torch.tensor(cell_type, dtype=torch.long, device=device)   # [B]

# Build the regression model.
# NOTE(review): presumably these hyper-parameters must match the configuration
# the checkpoint below was trained with — confirm.
model = jambaregression(
    d_model=16,
    num_mamba_blocks=3,
    d_intermediate=32,
    vocab_size=5,
    max_seq_len=300,
)

# Map the checkpoint onto the same device the inputs live on, rather than a
# separately spelled "cuda" string, so there is a single source of truth.
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True where the installed torch supports it).
state_dict = torch.load("pure_weights_colon/colon_model.pt", map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # evaluation mode for inference

# Run a single forward pass over the whole batch with autograd disabled,
# then print the raw model output.
forward_kwargs = {
    "attention_mask": attention_mask,
    "tpm": tpm_tensor,
    "cell_type": cell_type_tensor,
}
with torch.no_grad():
    output = model(input_ids, **forward_kwargs)
print("Inference outputs:", output)

Inference outputs: tensor([[0.7965],
        [0.8032],
        [0.8017],
        [0.8083],
        [0.8097],
        [0.8090],
        [0.8016],
        [0.8105],
        [0.8064],
        [0.8028],
        [0.8035],
        [0.8164]], device='cuda:0')


: 