In [1]:
!pip install pandas

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
[0m

In [2]:
!pip install openpyxl

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
[0m

In [6]:
import pandas as pd
import json
from collections import defaultdict

def preprocess(input_csv, output_json):
    # 读取Excel文件
    df = pd.read_excel(input_csv)
    
    # 按 SMILES 分组
    grouped = defaultdict(lambda: {"sequences": [], "features": {}})    
    
    for _, row in df.iterrows():
        smi = row["SMILES"]
        # 获取序列并将DNA转为RNA（T→U）
        dna_seq = row["Sequence"]
        rna_seq = dna_seq.replace('T', 'U')  # 将DNA序列中的T替换为U，转换为RNA
        
        # 提取全局分子特征
        features = {            
            "molecular_weight": row.get("molecular_weight", None),            
            "pubchem_id": row.get("Pubchem ID", None),            
            "title": row.get("Titles", None),            
            "iupac": row.get("IUPAC Names", None),            
            "condition": row.get("Binding Conditions/Buffer", None),        
        }
        
        # 如果是该SMILES的第一条记录，设置特征
        if not grouped[smi]["features"]:
            grouped[smi]["features"] = features
        
        # 添加转换后的RNA序列
        grouped[smi]["sequences"].append(rna_seq)
    
    # 转存为list格式
    dataset = []    
    for smi, v in grouped.items():
        dataset.append({            
            "smiles": smi,            
            "features": v["features"],            
            "sequences": v["sequences"]        
        })
    
    # 保存为JSON文件
    with open(output_json, "w") as f:
        json.dump(dataset, f, indent=2)

# 用法示例：
preprocess("merged_file_with_smiles_filled.xlsx", "dataset.json")
    

In [5]:
!pip install torch

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
[0m

In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
import json

"""
字符级 tokenization
"""

class Vocab:
    def __init__(self, tokens):
        self.tokens = ["<pad>", "<bos>", "<eos>", "<unk>"] + tokens
        self.stoi = {s:i for i,s in enumerate(self.tokens)}
        self.itos = {i:s for s,i in self.stoi.items()}

    def encode(self, s, add_bos=True, add_eos=True):
        tokens = list(s)
        ids = []
        if add_bos: ids.append(self.stoi["<bos>"])
        ids += [self.stoi.get(t, self.stoi["<unk>"]) for t in tokens]
        if add_eos: ids.append(self.stoi["<eos>"])
        return ids

    def decode(self, ids):
        return "".join([self.itos[i] for i in ids if i not in (0,1,2)])

    @property
    def pad_id(self): return self.stoi["<pad>"]

class AptamerDataset(Dataset):
    def __init__(self, json_path, smiles_vocab, seq_vocab, max_len=100):
        with open(json_path) as f:
            self.data = json.load(f)
        self.smiles_vocab = smiles_vocab
        self.seq_vocab = seq_vocab
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        smiles = entry["smiles"]
        seqs = entry["sequences"]  # 多个 aptamer

        # Encode SMILES
        smiles_ids = self.smiles_vocab.encode(smiles, add_bos=False, add_eos=False)

        # Encode 所有 aptamer 序列
        seq_ids = [self.seq_vocab.encode(s) for s in seqs]

        return {
            "smiles_ids": torch.tensor(smiles_ids, dtype=torch.long),
            "sequences": [torch.tensor(s, dtype=torch.long) for s in seq_ids]
        }

def collate_fn(batch, pad_id=0):
    # Pad SMILES
    smiles_lens = [len(b["smiles_ids"]) for b in batch]
    max_smi = max(smiles_lens)
    smiles_tensor = torch.full((len(batch), max_smi), pad_id, dtype=torch.long)
    for i,b in enumerate(batch):
        smiles_tensor[i,:len(b["smiles_ids"])] = b["smiles_ids"]

    # 这里 sequences 是 list[list[tensor]]，训练时要 special handle
    seqs_list = [b["sequences"] for b in batch]

    return {
        "smiles": smiles_tensor,
        "sequences": seqs_list
    }

# 用法
# smiles_vocab = Vocab(list("CNOPSH[]=()#123456789"))  # 自定义
# seq_vocab = Vocab(list("ACGU"))
# dataset = AptamerDataset("dataset.json", smiles_vocab, seq_vocab)
# loader = DataLoader(dataset, batch_size=4, collate_fn=lambda x: collate_fn(x, pad_id=smiles_vocab.pad_id))


In [8]:
smiles_vocab = Vocab(list("CNOPSH[]=()#123456789")) 
seq_vocab = Vocab(list("ACGU"))
dataset = AptamerDataset("dataset.json", smiles_vocab, seq_vocab)
loader = DataLoader(dataset, batch_size=4, collate_fn=lambda x: collate_fn(x, pad_id=smiles_vocab.pad_id))

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
模型骨架（Encoder-Decoder + MIL loss）
"""

class SimpleMolEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=512),
            num_layers=4
        )

    def forward(self, smiles_ids):
        x = self.embed(smiles_ids).transpose(0,1)  # (L,B,D)
        x = self.encoder(x)  # (L,B,D)
        mol_repr = x.mean(0) # (B,D)
        return mol_repr

class RNASeqDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead=8, dim_feedforward=512),
            num_layers=4
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, tgt_ids, memory):
        tgt_emb = self.embed(tgt_ids).transpose(0,1)  # (L,B,D)
        out = self.decoder(tgt_emb, memory.unsqueeze(0)) # memory (1,B,D)
        logits = self.fc_out(out).transpose(0,1)  # (B,L,V)
        return logits

class Mol2Aptamer(nn.Module):
    def __init__(self, smi_vocab, seq_vocab, d_model=256):
        super().__init__()
        self.encoder = SimpleMolEncoder(len(smi_vocab.tokens), d_model)
        self.decoder = RNASeqDecoder(len(seq_vocab.tokens), d_model)

    def forward(self, smiles_ids, seq_ids):
        memory = self.encoder(smiles_ids) # (B,D)
        logits = self.decoder(seq_ids, memory)
        return logits

def mil_loss(logits_list, seqs_list, pad_id):
    """
    logits_list: list of [B,L,V] from decoder for each sequence candidate
    seqs_list:  list of list[tensor], each tensor is (L,)
    """
    losses = []
    for b, seqs in enumerate(seqs_list):
        logps = []
        for seq in seqs:
            tgt = seq[1:]  # skip BOS
            inp = seq[:-1]
            logit = logits_list[b][:len(inp)]
            logp = F.cross_entropy(
                logit, tgt, ignore_index=pad_id, reduction="sum"
            )
            logps.append(-logp)
        losses.append(-torch.logsumexp(torch.tensor(logps), dim=0))
    return torch.stack(losses).mean()


In [10]:
!pip install tokenizers

Looking in indexes: http://mirrors.aliyun.com/pypi/simple
[0m

In [11]:
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer

"""
BPE tokenizer
"""

def train_bpe_tokenizer(corpus_file, vocab_size=2000, save_path="tokenizer.json"):
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
    trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=["<pad>", "<bos>", "<eos>", "<unk>"])
    tokenizer.train([corpus_file], trainer)
    tokenizer.post_processor = processors.TemplateProcessing(
        single="<bos> $A <eos>",
        special_tokens=[("<bos>",1), ("<eos>",2)]
    )
    tokenizer.save(save_path)
    return tokenizer


In [15]:
train_bpe_tokenizer("all_smiles.txt", vocab_size=500, save_path="/root/autodl-tmp/smiles_tokenizer.json")
train_bpe_tokenizer("all_rna_converted.txt", vocab_size=100, save_path="/root/autodl-tmp/rna_tokenizer.json")










Tokenizer(version="1.0", truncation=None, padding=None, added_tokens=[{"id":0, "content":"<pad>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":1, "content":"<bos>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":2, "content":"<eos>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}, {"id":3, "content":"<unk>", "single_word":False, "lstrip":False, "rstrip":False, "normalized":False, "special":True}], normalizer=None, pre_tokenizer=ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), post_processor=TemplateProcessing(single=[SpecialToken(id="<bos>", type_id=0), Sequence(id=A, type_id=0), SpecialToken(id="<eos>", type_id=0)], pair=[Sequence(id=A, type_id=0), Sequence(id=B, type_id=1)], special_tokens={"<bos>":SpecialToken(id="<bos>", ids=[1], tokens=["<bos>"]), "<eos>":SpecialToken(id="<eos>", ids=[2], tokens=["<eos>"])}), decoder=None

In [14]:
def dna_to_rna(input_file_path, output_file_path):
    """
    将文本文件中的所有DNA序列（含T）转换为RNA序列（T→U）
    :param input_file_path: 输入文本文件路径（如 all_rna.txt）
    :param output_file_path: 输出RNA序列文件路径
    """
    try:
        # 1. 读取输入文件中的所有序列
        with open(input_file_path, 'r', encoding='utf-8') as input_file:
            # 读取所有行，去除空行并保留有效序列
            dna_sequences = [line.strip() for line in input_file if line.strip()]
        
        # 2. 批量将T替换为U（DNA→RNA转换）
        rna_sequences = []
        for seq in dna_sequences:
            # 仅替换碱基T为U，保留其他字符（如序列分隔符、注释等，若存在）
            rna_seq = seq.replace('T', 'U')
            rna_sequences.append(rna_seq)
        
        # 3. 将转换后的RNA序列写入输出文件
        with open(output_file_path, 'w', encoding='utf-8') as output_file:
            # 每行写入一条RNA序列，保持原始文件的分行格式
            output_file.write('\n'.join(rna_sequences))
        
        print(f"转换完成！")
        print(f"输入文件：{input_file_path}")
        print(f"输出文件：{output_file_path}")
        print(f"总计转换序列数量：{len(rna_sequences)}")
        
    except Exception as e:
        print(f"转换过程中出现错误：{str(e)}")


INPUT_FILE = "all_rna.txt"    # 输入的原始序列文件（含T的DNA序列）
OUTPUT_FILE = "all_rna_converted.txt"  # 输出的RNA序列文件（T已替换为U）

# 执行转换
dna_to_rna(INPUT_FILE, OUTPUT_FILE)

转换完成！
输入文件：all_rna.txt
输出文件：all_rna_converted.txt
总计转换序列数量：796


In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
import json
from tokenizers import Tokenizer

"""
数据集+DataLoader (支持 BPE)
"""

class AptamerDataset(Dataset):
    def __init__(self, json_path, smiles_tok_path, rna_tok_path, max_len=128):
        with open(json_path) as f:
            self.data = json.load(f)
        self.smiles_tok = Tokenizer.from_file(smiles_tok_path)
        self.rna_tok = Tokenizer.from_file(rna_tok_path)
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        smiles = entry["smiles"]
        seqs = entry["sequences"]

        smiles_ids = self.smiles_tok.encode(smiles).ids[:self.max_len]
        seq_ids = [self.rna_tok.encode(s).ids[:self.max_len] for s in seqs]

        return {
            "smiles_ids": torch.tensor(smiles_ids, dtype=torch.long),
            "sequences": [torch.tensor(s, dtype=torch.long) for s in seq_ids]
        }

def collate_fn(batch, pad_id=0):
    # Pad smiles
    max_smi_len = max(len(b["smiles_ids"]) for b in batch)
    smiles_tensor = torch.full((len(batch), max_smi_len), pad_id, dtype=torch.long)
    for i,b in enumerate(batch):
        smiles_tensor[i,:len(b["smiles_ids"])] = b["smiles_ids"]

    # keep sequences as list[list[tensor]]
    seqs_list = [b["sequences"] for b in batch]
    return {"smiles": smiles_tensor, "sequences": seqs_list}


In [17]:
import torch.nn as nn
import torch.nn.functional as F

"""
Transformer Encoder (SMILES) + Transformer Decoder (RNA)
"""

class SimpleMolEncoder(nn.Module):
    def __init__(self, vocab_size, d_model=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=512),
            num_layers=4
        )

    def forward(self, smiles_ids):
        x = self.embed(smiles_ids).transpose(0,1) # (L,B,D)
        x = self.encoder(x) # (L,B,D)
        mol_repr = x.mean(0) # (B,D)
        return mol_repr

class RNASeqDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead=8, dim_feedforward=512),
            num_layers=4
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, tgt_ids, memory):
        tgt_emb = self.embed(tgt_ids).transpose(0,1) # (L,B,D)
        memory = memory.unsqueeze(0) # (1,B,D)
        out = self.decoder(tgt_emb, memory)
        logits = self.fc_out(out).transpose(0,1) # (B,L,V)
        return logits

class Mol2Aptamer(nn.Module):
    def __init__(self, smi_vocab_size, rna_vocab_size, d_model=256):
        super().__init__()
        self.encoder = SimpleMolEncoder(smi_vocab_size, d_model)
        self.decoder = RNASeqDecoder(rna_vocab_size, d_model)

    def forward(self, smiles_ids, seq_ids):
        memory = self.encoder(smiles_ids)
        logits = self.decoder(seq_ids, memory)
        return logits


In [13]:
#多实例似然 (MIL loss) 

def mil_loss(model, smiles_ids, seqs_list, pad_id=0):
    """
    smiles_ids: (B,Lsmi)
    seqs_list: list[list[tensor]]
    """
    batch_losses = []
    for b, seqs in enumerate(seqs_list):
        logps = []
        for seq in seqs:
            inp = seq[:-1].unsqueeze(0) # (1,L-1)
            tgt = seq[1:].unsqueeze(0) # (1,L-1)
            logits = model(smiles_ids[b].unsqueeze(0), inp) # (1,L-1,V)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                tgt.reshape(-1),
                ignore_index=pad_id,
                reduction="sum"
            )
            logps.append(-loss)
        logps = torch.stack(logps)
        batch_losses.append(-torch.logsumexp(logps, dim=0))
    return torch.stack(batch_losses).mean()


In [14]:
#训练循环

import torch
from torch.utils.data import random_split

def train(model, train_loader, val_loader, epochs=10, lr=1e-4, device="cuda"):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for batch in train_loader:
            smiles = batch["smiles"].to(device)
            seqs_list = [[s.to(device) for s in seqs] for seqs in batch["sequences"]]

            loss = mil_loss(model, smiles, seqs_list, pad_id=0)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        val_loss = 0
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                smiles = batch["smiles"].to(device)
                seqs_list = [[s.to(device) for s in seqs] for seqs in batch["sequences"]]
                loss = mil_loss(model, smiles, seqs_list, pad_id=0)
                val_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss/len(train_loader):.4f} | Val Loss {val_loss/len(val_loader):.4f}")


In [18]:
import json
import random
from typing import List, Dict

def split_dataset(input_path: str, train_path: str, val_path: str, val_ratio: float = 0.2, seed: int = 42) -> None:
    """
    将JSON格式的数据集随机拆分为训练集和验证集
    
    Args:
        input_path: 输入JSON文件路径
        train_path: 训练集输出路径
        val_path: 验证集输出路径
        val_ratio: 验证集占总数据的比例，默认0.2
        seed: 随机种子，保证结果可复现
    """
    # 设置随机种子，确保结果可复现
    random.seed(seed)
    
    # 读取JSON数据
    with open(input_path, 'r', encoding='utf-8') as f:
        data: List[Dict] = json.load(f)
    
    # 打乱数据顺序
    random.shuffle(data)
    
    # 计算拆分索引
    total = len(data)
    val_size = int(total * val_ratio)
    
    # 拆分数据
    val_data = data[:val_size]
    train_data = data[val_size:]
    
    # 保存训练集
    with open(train_path, 'w', encoding='utf-8') as f:
        json.dump(train_data, f, ensure_ascii=False, indent=2)
    
    # 保存验证集
    with open(val_path, 'w', encoding='utf-8') as f:
        json.dump(val_data, f, ensure_ascii=False, indent=2)
    
    print(f"数据集拆分完成！总样本数: {total}")
    print(f"训练集样本数: {len(train_data)} ({len(train_data)/total:.2%})")
    print(f"验证集样本数: {len(val_data)} ({len(val_data)/total:.2%})")

if __name__ == "__main__":
    # 配置文件路径
    INPUT_JSON = "dataset.json"
    TRAIN_JSON = "train_dataset.json"
    VAL_JSON = "val_dataset.json"
    
    # 拆分比例设置为20%作为验证集
    split_dataset(INPUT_JSON, TRAIN_JSON, VAL_JSON, val_ratio=0.2)

数据集拆分完成！总样本数: 238
训练集样本数: 191 (80.25%)
验证集样本数: 47 (19.75%)


In [19]:
import torch
import json
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

# 训练设置
epochs = 10
lr = 1e-4
batch_size = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


class AptamerDataset(Dataset):
    def __init__(self, json_path, smiles_tok_path, rna_tok_path, max_len=128):
        with open(json_path) as f:
            self.data = json.load(f)
        
        self.smiles_tok = Tokenizer.from_file(smiles_tok_path)
        self.rna_tok = Tokenizer.from_file(rna_tok_path)
        
        self.max_len = max_len
        self.smiles_pad_id = self.smiles_tok.get_vocab()["<pad>"]
        self.rna_pad_id = self.rna_tok.get_vocab()["<pad>"]
        
        self._filter_invalid_data()

    def _filter_invalid_data(self):
        valid_data = []
        for entry in self.data:
            if "smiles" not in entry or "sequences" not in entry:
                continue
            if not entry["smiles"] or not entry["sequences"]:
                continue
            valid_seqs = [s for s in entry["sequences"] if s.strip()]
            if not valid_seqs:
                continue
            entry["sequences"] = valid_seqs
            valid_data.append(entry)
        self.data = valid_data
        print(f"Filtered data: {len(self.data)} valid entries remaining")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        smiles = entry["smiles"]
        seqs = entry["sequences"]

        # 处理SMILES
        smiles_encoded = self.smiles_tok.encode(smiles)
        smiles_ids = self._pad_or_truncate(
            smiles_encoded.ids, self.max_len, self.smiles_pad_id
        )
        
        # 处理RNA序列
        seq_ids_list = []
        for s in seqs:
            seq_encoded = self.rna_tok.encode(s)
            seq_ids = self._pad_or_truncate(
                seq_encoded.ids, self.max_len, self.rna_pad_id
            )
            seq_ids_list.append(seq_ids)

        return {
            "smiles_ids": torch.tensor(smiles_ids, dtype=torch.long),
            "sequences": torch.tensor(seq_ids_list, dtype=torch.long)
        }

    def _pad_or_truncate(self, ids, max_len, pad_id):
        if len(ids) > max_len:
            return ids[:max_len]
        else:
            return ids + [pad_id] * (max_len - len(ids))


def collate_fn(batch, smiles_pad_id, rna_pad_id):
    """修复布尔张量处理问题的collate_fn"""
    # 处理SMILES
    smiles_list = [item["smiles_ids"] for item in batch]
    smiles_padded = torch.nn.utils.rnn.pad_sequence(
        smiles_list,
        batch_first=True,
        padding_value=smiles_pad_id
    )
    
    # 处理RNA序列（修复关键部分）
    seqs_list = []
    for item in batch:
        valid_seqs = []
        for seq in item["sequences"]:
            # 找到序列中第一个pad_id的位置（修复布尔张量问题）
            # 方法1：将布尔张量转换为整数张量
            pad_mask = (seq == rna_pad_id).int()  # 转换为整数张量
            if pad_mask.sum() > 0:  # 检查是否有padding
                # 方法2：使用nonzero找到第一个pad的位置
                # valid_len = (seq == rna_pad_id).nonzero(as_tuple=True)[0][0].item()
                valid_len = pad_mask.argmax().item()  # 现在可以安全使用argmax了
                valid_seq = seq[:valid_len]
            else:
                valid_seq = seq  # 没有padding，使用完整序列
            
            # 保留有效长度≥2的序列
            if valid_seq.numel() >= 2:
                valid_seqs.append(valid_seq)
        seqs_list.append(valid_seqs)
    
    return {
        "smiles": smiles_padded,
        "seqs_list": seqs_list
    }


class Mol2Aptamer(torch.nn.Module):
    def __init__(self, smiles_vocab_size, rna_vocab_size, d_model=256):
        super().__init__()
        self.smiles_embedding = torch.nn.Embedding(smiles_vocab_size, d_model)
        self.rna_embedding = torch.nn.Embedding(rna_vocab_size, d_model)
        self.pos_embedding = torch.nn.Embedding(512, d_model)
        
        self.decoder_layer = torch.nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=8,
            batch_first=True
        )
        self.decoder = torch.nn.TransformerDecoder(self.decoder_layer, num_layers=3)
        self.fc_out = torch.nn.Linear(d_model, rna_vocab_size)

    def forward(self, smiles_ids, rna_inp):
        batch_size, L_smi = smiles_ids.shape
        batch_size, L_rna = rna_inp.shape
        
        # SMILES编码
        smiles_emb = self.smiles_embedding(smiles_ids)
        smiles_pos = torch.arange(0, L_smi, device=smiles_ids.device).unsqueeze(0).repeat(batch_size, 1)
        smiles_emb += self.pos_embedding(smiles_pos)
        
        # RNA输入编码
        rna_emb = self.rna_embedding(rna_inp)
        rna_pos = torch.arange(0, L_rna, device=rna_inp.device).unsqueeze(0).repeat(batch_size, 1)
        rna_emb += self.pos_embedding(rna_pos)
        
        # Transformer解码器
        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(L_rna, device=rna_inp.device)
        decoder_out = self.decoder(tgt=rna_emb, memory=smiles_emb, tgt_mask=tgt_mask)
        
        # 输出层
        logits = self.fc_out(decoder_out)  # (batch_size, L_rna, vocab_size)
        return logits


def mil_loss(model, smiles_ids, seqs_list, pad_id=0):
    batch_losses = []
    device = smiles_ids.device
    
    for b in range(len(seqs_list)):
        current_seqs = seqs_list[b]
        logps = []
        for seq in current_seqs:
            if seq.dim() != 1 or seq.numel() < 2:
                # print(f"无效序列，形状: {seq.shape}，长度: {seq.numel()}")
                continue
            rna_inp = seq[:-1].unsqueeze(0)
            rna_tgt = seq[1:].unsqueeze(0)
            logits = model(smiles_ids[b].unsqueeze(0).to(device),rna_inp.to(device))
            single_seq_loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),rna_tgt.reshape(-1).to(device),ignore_index=pad_id,reduction="sum")
            logps.append(-single_seq_loss)
        if logps:
            logps_tensor = torch.stack(logps)
            mil_sample_loss = -torch.logsumexp(logps_tensor, dim=0)
            batch_losses.append(mil_sample_loss)
    if not batch_losses:
        return torch.tensor(0.0, device=device, requires_grad=True)
    
    return torch.stack(batch_losses).mean()


def main():
    # 加载tokenizer和数据集
    smiles_tok_path = "smiles_tokenizer.json"
    rna_tok_path = "rna_tokenizer.json"
    
    smiles_tok = Tokenizer.from_file(smiles_tok_path)
    rna_tok = Tokenizer.from_file(rna_tok_path)
    smiles_pad_id = smiles_tok.get_vocab()["<pad>"]
    rna_pad_id = rna_tok.get_vocab()["<pad>"]
    
    # 加载数据集
    train_dataset = AptamerDataset("train_dataset.json", smiles_tok_path, rna_tok_path)
    val_dataset = AptamerDataset("val_dataset.json", smiles_tok_path, rna_tok_path)
    
    # 数据加载器（传入rna_pad_id）
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda x: collate_fn(x, smiles_pad_id, rna_pad_id)
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: collate_fn(x, smiles_pad_id, rna_pad_id)
    )
    
    # 初始化模型
    model = Mol2Aptamer(
        len(smiles_tok.get_vocab()),
        len(rna_tok.get_vocab()),
        d_model=256
    ).to(device)
    
    # 优化器和调度器
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    
    # 训练循环
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for batch in train_loader:
            smiles = batch["smiles"].to(device)
            seqs_list = batch["seqs_list"]
            
            optimizer.zero_grad()
            loss = mil_loss(model, smiles, seqs_list, pad_id=rna_pad_id)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item() * smiles.size(0)
        
        # 验证
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                smiles = batch["smiles"].to(device)
                seqs_list = batch["seqs_list"]
                loss = mil_loss(model, smiles, seqs_list, pad_id=rna_pad_id)
                val_loss += loss.item() * smiles.size(0)
        
        # 日志
        train_avg_loss = train_loss / len(train_dataset)
        val_avg_loss = val_loss / len(val_dataset)
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_avg_loss:.4f} | Val Loss: {val_avg_loss:.4f}")
        
        scheduler.step()
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")


if __name__ == "__main__":
    main()


Using device: cpu
Filtered data: 191 valid entries remaining
Filtered data: 47 valid entries remaining
Epoch 1/10
Train Loss: 98.8628 | Val Loss: 92.6684
Epoch 2/10
Train Loss: 92.1286 | Val Loss: 89.5070
Epoch 3/10
Train Loss: 88.4973 | Val Loss: 87.4257
Epoch 4/10
Train Loss: 85.8195 | Val Loss: 86.3002
Epoch 5/10
Train Loss: 83.2455 | Val Loss: 84.7849
Epoch 6/10
Train Loss: 81.1408 | Val Loss: 84.3072
Epoch 7/10
Train Loss: 79.4558 | Val Loss: 83.4425
Epoch 8/10
Train Loss: 78.3221 | Val Loss: 83.1255
Epoch 9/10
Train Loss: 77.7108 | Val Loss: 82.9474
Epoch 10/10
Train Loss: 77.4760 | Val Loss: 82.8975


In [None]:
import torch
import json
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np

# 训练设置
epochs = 500
lr = 5e-5
batch_size = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


class AptamerDataset(Dataset):
    def __init__(self, json_path, smiles_tok_path, rna_tok_path, max_len=128):
        with open(json_path) as f:
            self.data = json.load(f)
        
        # 假设Tokenizer类已经正确导入
        self.smiles_tok = Tokenizer.from_file(smiles_tok_path)
        self.rna_tok = Tokenizer.from_file(rna_tok_path)
        
        self.max_len = max_len
        self.smiles_pad_id = self.smiles_tok.get_vocab()["<pad>"]
        self.rna_pad_id = self.rna_tok.get_vocab()["<pad>"]
        
        self._filter_invalid_data()

    def _filter_invalid_data(self):
        valid_data = []
        for entry in self.data:
            if "smiles" not in entry or "sequences" not in entry:
                continue
            if not entry["smiles"] or not entry["sequences"]:
                continue
            valid_seqs = [s for s in entry["sequences"] if s.strip()]
            if not valid_seqs:
                continue
            entry["sequences"] = valid_seqs
            valid_data.append(entry)
        self.data = valid_data
        print(f"Filtered data: {len(self.data)} valid entries remaining")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        smiles = entry["smiles"]
        seqs = entry["sequences"]

        # 处理SMILES
        smiles_encoded = self.smiles_tok.encode(smiles)
        smiles_ids = self._pad_or_truncate(
            smiles_encoded.ids, self.max_len, self.smiles_pad_id
        )
        
        # 处理RNA序列
        seq_ids_list = []
        for s in seqs:
            seq_encoded = self.rna_tok.encode(s)
            seq_ids = self._pad_or_truncate(
                seq_encoded.ids, self.max_len, self.rna_pad_id
            )
            seq_ids_list.append(seq_ids)

        return {
            "smiles_ids": torch.tensor(smiles_ids, dtype=torch.long),
            "sequences": torch.tensor(seq_ids_list, dtype=torch.long)
        }

    def _pad_or_truncate(self, ids, max_len, pad_id):
        if len(ids) > max_len:
            return ids[:max_len]
        else:
            return ids + [pad_id] * (max_len - len(ids))


def collate_fn(batch, smiles_pad_id, rna_pad_id):
    """修复布尔张量处理问题的collate_fn"""
    # 处理SMILES
    smiles_list = [item["smiles_ids"] for item in batch]
    smiles_padded = torch.nn.utils.rnn.pad_sequence(
        smiles_list,
        batch_first=True,
        padding_value=smiles_pad_id
    )
    
    # 处理RNA序列
    seqs_list = []
    for item in batch:
        valid_seqs = []
        for seq in item["sequences"]:
            # 找到序列中第一个pad_id的位置
            pad_mask = (seq == rna_pad_id).int()
            if pad_mask.sum() > 0:
                valid_len = pad_mask.argmax().item()
                valid_seq = seq[:valid_len]
            else:
                valid_seq = seq  # 没有padding，使用完整序列
            
            # 保留有效长度≥2的序列
            if valid_seq.numel() >= 2:
                valid_seqs.append(valid_seq)
        seqs_list.append(valid_seqs)
    
    return {
        "smiles": smiles_padded,
        "seqs_list": seqs_list
    }


class Mol2Aptamer(torch.nn.Module):
    def __init__(self, smiles_vocab_size, rna_vocab_size, d_model=256):
        super().__init__()
        self.smiles_embedding = torch.nn.Embedding(smiles_vocab_size, d_model)
        self.rna_embedding = torch.nn.Embedding(rna_vocab_size, d_model)
        self.pos_embedding = torch.nn.Embedding(512, d_model)
        
        # 添加SMILES编码器
        self.encoder = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True),
            num_layers=2
        )
        self.decoder_layer = torch.nn.TransformerDecoderLayer(d_model=d_model, nhead=8, batch_first=True)
        self.decoder = torch.nn.TransformerDecoder(self.decoder_layer, num_layers=3)
        self.fc_out = torch.nn.Linear(d_model, rna_vocab_size)

    def forward(self, smiles_ids, rna_inp):
        batch_size, L_smi = smiles_ids.shape
        batch_size, L_rna = rna_inp.shape
        
        # SMILES编码（使用Encoder）
        smiles_emb = self.smiles_embedding(smiles_ids)
        smiles_pos = torch.arange(0, L_smi, device=smiles_ids.device).unsqueeze(0).repeat(batch_size, 1)
        smiles_emb += self.pos_embedding(smiles_pos)
        smiles_enc = self.encoder(smiles_emb)  # 新增Encoder输出
        
        # RNA输入编码（不变）
        rna_emb = self.rna_embedding(rna_inp)
        rna_pos = torch.arange(0, L_rna, device=rna_inp.device).unsqueeze(0).repeat(batch_size, 1)
        rna_emb += self.pos_embedding(rna_pos)
        
        # Decoder接收Encoder输出作为memory
        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(L_rna, device=rna_inp.device)
        decoder_out = self.decoder(tgt=rna_emb, memory=smiles_enc, tgt_mask=tgt_mask)  # 改用Encoder输出
        
        logits = self.fc_out(decoder_out)
        return logits

def mil_loss(model, smiles_ids, seqs_list, pad_id=0):
    batch_losses = []
    device = smiles_ids.device
    
    for b in range(len(seqs_list)):
        current_seqs = seqs_list[b]
        logps = []
        
        for seq in current_seqs:
            if seq.dim() != 1 or seq.numel() < 2:
                print(f"Warning: 跳过无效序列，形状: {seq.shape}，长度: {seq.numel()}")
                continue
            
            rna_inp = seq[:-1].unsqueeze(0)
            rna_tgt = seq[1:].unsqueeze(0)
            
            logits = model(
                smiles_ids[b].unsqueeze(0).to(device),
                rna_inp.to(device)
            )
            
            single_seq_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                rna_tgt.reshape(-1).to(device),
                ignore_index=pad_id,
                reduction="sum"
            )
            logps.append(-single_seq_loss)
        
        if logps:
            logps_tensor = torch.stack(logps)
            mil_sample_loss = -torch.logsumexp(logps_tensor, dim=0)
            batch_losses.append(mil_sample_loss)
    
    if not batch_losses:
        return torch.tensor(0.0, device=device, requires_grad=True)
    
    return torch.stack(batch_losses).mean()


def calculate_metrics(model, smiles_ids, seqs_list, pad_id=0):
    """计算验证指标"""
    device = smiles_ids.device
    total_tokens = 0
    correct_tokens = 0
    total_sequences = 0
    correct_sequences = 0
    total_loss = 0.0
    sequence_lengths = []
    
    with torch.no_grad():
        for b in range(len(seqs_list)):
            current_seqs = seqs_list[b]
            
            for seq in current_seqs:
                if seq.dim() != 1 or seq.numel() < 2:
                    continue
                
                total_sequences += 1
                seq_len = len(seq) - 1  # 因为我们用前n-1个预测第n个
                sequence_lengths.append(seq_len)
                
                rna_inp = seq[:-1].unsqueeze(0).to(device)
                rna_tgt = seq[1:].unsqueeze(0).to(device)
                
                # 获取模型预测
                logits = model(
                    smiles_ids[b].unsqueeze(0).to(device),
                    rna_inp
                )
                
                # 计算损失
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    rna_tgt.reshape(-1),
                    ignore_index=pad_id,
                    reduction="sum"
                )
                total_loss += loss.item()
                
                # 计算token级准确率
                preds = torch.argmax(logits, dim=-1)
                mask = (rna_tgt != pad_id).flatten()
                
                # 计算匹配的token
                correct = (preds.flatten()[mask] == rna_tgt.flatten()[mask]).sum().item()
                total = mask.sum().item()
                
                correct_tokens += correct
                total_tokens += total
                
                # 计算序列级准确率（完全匹配）
                if total > 0 and correct == total:
                    correct_sequences += 1
    
    # 计算指标
    metrics = {
        "token_accuracy": correct_tokens / total_tokens if total_tokens > 0 else 0.0,
        "sequence_accuracy": correct_sequences / total_sequences if total_sequences > 0 else 0.0,
        "perplexity": np.exp(total_loss / total_tokens) if total_tokens > 0 else 0.0,
        "avg_sequence_length": np.mean(sequence_lengths) if sequence_lengths else 0.0
    }
    
    return metrics


def main():
    # 加载tokenizer和数据集
    smiles_tok_path = "/root/autodl-tmp/rna/smiles_tokenizer.json"
    rna_tok_path = "/root/autodl-tmp/rna/rna_tokenizer.json"
    
    # 假设Tokenizer类已经正确导入
    smiles_tok = Tokenizer.from_file(smiles_tok_path)
    rna_tok = Tokenizer.from_file(rna_tok_path)
    smiles_pad_id = smiles_tok.get_vocab()["<pad>"]
    rna_pad_id = rna_tok.get_vocab()["<pad>"]
    
    # 加载数据集
    train_dataset = AptamerDataset("/root/autodl-tmp/rna/train_dataset.json", smiles_tok_path, rna_tok_path)
    val_dataset = AptamerDataset("/root/autodl-tmp/rna/val_dataset.json", smiles_tok_path, rna_tok_path)
    
    # 数据加载器
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda x: collate_fn(x, smiles_pad_id, rna_pad_id)
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda x: collate_fn(x, smiles_pad_id, rna_pad_id)
    )
    
    # 初始化模型
    model = Mol2Aptamer(
        len(smiles_tok.get_vocab()),
        len(rna_tok.get_vocab()),
        d_model=256
    ).to(device)
    
    # 优化器和调度器
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
    
    # 训练循环
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for batch in train_loader:
            smiles = batch["smiles"].to(device)
            seqs_list = batch["seqs_list"]
            
            optimizer.zero_grad()
            loss = mil_loss(model, smiles, seqs_list, pad_id=rna_pad_id)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item() * smiles.size(0)
        
        # 验证
        model.eval()
        val_loss = 0.0
        all_metrics = {
            "token_accuracy": [],
            "sequence_accuracy": [],
            "perplexity": [],
            "avg_sequence_length": []
        }
        
        with torch.no_grad():
            for batch in val_loader:
                smiles = batch["smiles"].to(device)
                seqs_list = batch["seqs_list"]
                
                # 计算损失
                loss = mil_loss(model, smiles, seqs_list, pad_id=rna_pad_id)
                val_loss += loss.item() * smiles.size(0)
                
                # 计算评估指标
                metrics = calculate_metrics(model, smiles, seqs_list, pad_id=rna_pad_id)
                
                # 收集指标
                for key, value in metrics.items():
                    all_metrics[key].append(value)
        
        # 计算平均指标
        avg_metrics = {
            key: np.mean(values) if values else 0.0 
            for key, values in all_metrics.items()
        }
        
        # 日志
        train_avg_loss = train_loss / len(train_dataset)
        val_avg_loss = val_loss / len(val_dataset)
        
        print(f"\nEpoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_avg_loss:.4f} | Val Loss: {val_avg_loss:.4f}")
        print(f"Token Accuracy: {avg_metrics['token_accuracy']:.4f}")
        print(f"Sequence Accuracy: {avg_metrics['sequence_accuracy']:.4f}")
        print(f"Perplexity: {avg_metrics['perplexity']:.4f}")
        print(f"Average Sequence Length: {avg_metrics['avg_sequence_length']:.2f}")
        
        scheduler.step()
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")


if __name__ == "__main__":
    main()


In [20]:
import warnings


In [None]:
#推理函数
#输入SMILES字符串，输出若干条候选aptamer序列
#贪心解码 +Top-k 采样 +Top-p (nucleus) 采样 +多样性控制（temperature）

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

def generate_aptamers(
    model, smiles, smiles_tokenizer, rna_tokenizer,
    max_len=80, num_return=5,
    strategy="topk", top_k=5, top_p=0.9, temperature=1.0,
    device="cuda"
):
    """
    model: 已训练的 Mol2Aptamer
    smiles: 输入 SMILES 字符串
    smiles_tokenizer: BPE tokenizer for SMILES
    rna_tokenizer: BPE tokenizer for RNA
    strategy: "greedy", "topk", "topp"
    """
    model.eval()
    model.to(device)

    # Encode SMILES
    smi_ids = smiles_tokenizer.encode(smiles).ids
    smi_ids = torch.tensor(smi_ids, dtype=torch.long).unsqueeze(0).to(device)

    # BOS token id
    bos_id = rna_tokenizer.token_to_id("<bos>")
    eos_id = rna_tokenizer.token_to_id("<eos>")

    results = []
    for _ in range(num_return):
        generated = [bos_id]
        memory = model.encoder(smi_ids)  # (1,D)

        for _ in range(max_len):
            inp = torch.tensor([generated], dtype=torch.long).to(device)
            logits = model.decoder(inp, memory)[:, -1, :]  # (1,V)
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)

            if strategy == "greedy":
                next_id = torch.argmax(probs, dim=-1).item()
            elif strategy == "topk":
                topk_probs, topk_ids = torch.topk(probs, k=top_k)
                idx = torch.multinomial(topk_probs, 1).item()
                next_id = topk_ids[0, idx].item()
            elif strategy == "topp":
                sorted_probs, sorted_ids = torch.sort(probs, descending=True)
                cumprobs = torch.cumsum(sorted_probs, dim=-1)
                mask = cumprobs <= top_p
                cutoff = mask.sum().item()
                filtered_probs = sorted_probs[:, :cutoff]
                filtered_ids = sorted_ids[:, :cutoff]
                idx = torch.multinomial(filtered_probs, 1).item()
                next_id = filtered_ids[0, idx].item()
            else:
                raise ValueError("Unknown strategy")

            if next_id == eos_id:
                break
            generated.append(next_id)

        # Decode
        seq = rna_tokenizer.decode(generated)
        results.append(seq)

    return results


: 

In [39]:
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

def generate_aptamers(
    model, smiles, smiles_tokenizer, rna_tokenizer,
    max_len=80, num_return=5,
    strategy="topk", top_k=5, top_p=0.9, temperature=1.0,
    device=None
):
    """
    生成Aptamer序列的推理函数，支持多种采样策略
    
    Args:
        model: 已训练的 Mol2Aptamer 模型
        smiles: 输入的SMILES字符串
        smiles_tokenizer: SMILES的BPE分词器
        rna_tokenizer: RNA的BPE分词器（基于提供的rna_tokenizer.json）
        max_len: 生成序列的最大长度
        num_return: 生成的候选序列数量
        strategy: 采样策略，可选"greedy", "topk", "topp"
        top_k: Top-k采样的候选数量（仅用于strategy="topk"）
        top_p: Top-p采样的累积概率阈值（仅用于strategy="topp"）
        temperature: 多样性控制参数，值越大多样性越高
        device: 推理设备，默认为模型所在设备
    
    Returns:
        list: 去重后的候选Aptamer序列列表
    """
    # 设备自动选择与模型部署
    if device is None:
        device = next(model.parameters()).device if hasattr(model, 'parameters') else torch.device('cpu')
    
    model.eval()
    model.to(device)
    
    # 验证RNA分词器特殊token（基于rna_tokenizer.json中的定义）
    required_tokens = ["<pad>", "<bos>", "<eos>", "<unk>"]
    for token in required_tokens:
        if rna_tokenizer.token_to_id(token) is None:
            raise ValueError(f"RNA tokenizer missing required token: {token}")
    
    bos_id = rna_tokenizer.token_to_id("<bos>")
    eos_id = rna_tokenizer.token_to_id("<eos>")
    pad_id = rna_tokenizer.token_to_id("<pad>")
    unk_id = rna_tokenizer.token_to_id("<unk>")  # 新增unk token处理
    
    # 编码SMILES
    try:
        smi_encoded = smiles_tokenizer.encode(smiles)
        smi_ids = torch.tensor(smi_encoded.ids, dtype=torch.long).unsqueeze(0).to(device)
    except Exception as e:
        raise ValueError(f"Failed to encode SMILES: {str(e)}")
    
    results = []
    with torch.no_grad():  # 关闭梯度计算，加速推理
        # 预计算encoder输出（所有生成共享，提高效率）
        memory = model.encoder(smi_ids)  # (1, D)
        
        for _ in range(num_return):
            generated = [bos_id]
            has_unk = False  # 跟踪是否生成了未知token
            
            for _ in range(max_len):
                # 准备解码器输入
                inp = torch.tensor([generated], dtype=torch.long).to(device)
                
                # 用 forward 而不是直接 decoder
                logits = model(smi_ids, inp)[:, -1, :]  # (1, V)
                
                # 应用温度调节
                if temperature <= 0:
                    raise ValueError("Temperature must be positive")
                logits = logits / temperature
                
                # 根据不同策略采样下一个token
                if strategy == "greedy":
                    next_id = torch.argmax(logits, dim=-1).item()
                
                elif strategy == "topk":
                    if top_k <= 0 or top_k > logits.size(-1):
                        raise ValueError(f"Invalid top_k value: {top_k}")
                    topk_probs, topk_ids = torch.topk(logits, k=top_k)
                    topk_probs = F.softmax(topk_probs, dim=-1)  # 重新归一化
                    idx = torch.multinomial(topk_probs, 1).item()
                    next_id = topk_ids[0, idx].item()
                
                elif strategy == "topp":
                    if top_p <= 0 or top_p > 1:
                        raise ValueError(f"top_p must be in (0, 1], got {top_p}")
                    
                    # 按概率排序
                    sorted_logits, sorted_ids = torch.sort(logits, descending=True)
                    sorted_probs = F.softmax(sorted_logits, dim=-1)
                    
                    # 计算累积概率
                    cum_probs = torch.cumsum(sorted_probs, dim=-1)
                    
                    # 找到满足累积概率 <= top_p 的所有token
                    cutoff = torch.sum(cum_probs <= top_p).item()
                    cutoff = max(1, cutoff)  # 确保至少保留一个token
                    
                    # 从筛选出的token中采样
                    filtered_probs = sorted_probs[:, :cutoff]
                    filtered_ids = sorted_ids[:, :cutoff]
                    idx = torch.multinomial(filtered_probs, 1).item()
                    next_id = filtered_ids[0, idx].item()
                
                else:
                    raise ValueError(f"Unknown sampling strategy: {strategy}. Choose from 'greedy', 'topk', 'topp'")
                
                # 检查是否生成未知token
                if next_id == unk_id:
                    has_unk = True
                
                # 检查是否达到终止符
                if next_id == eos_id:
                    break
                
                generated.append(next_id)
            
            # 解码并处理生成的序列
            # 过滤特殊token（保留原始序列中的有效碱基token）
            filtered_ids = [id for id in generated if id not in [bos_id, eos_id, pad_id, unk_id]]
            
            # 解码为字符串（适配BPE分词器的合并规则）
            try:
                # 使用分词器原生解码方法，确保BPE合并规则正确应用
                seq = rna_tokenizer.decode(filtered_ids, skip_special_tokens=True)
                
                # 过滤包含未知token或空的序列
                if seq and not has_unk and seq not in results:
                    results.append(seq)
            except Exception as e:
                print(f"Warning: Failed to decode sequence: {str(e)}")
    
    # 确保返回指定数量的序列（如果去重后不足）
    while len(results) < num_return:
        if not results:
            results.append("")
        else:
            results.append(results[-1])
    
    return results[:num_return]

In [32]:
import torch
from tokenizers import Tokenizer

# 加载训练好的模型（假设模型已保存为pth文件）
model = torch.load("/root/autodl-tmp/model_epoch_59.pth")  # 替换为你的模型路径

# 加载SMILES和RNA的分词器
smiles_tokenizer = Tokenizer.from_file("/root/autodl-tmp/rna/smiles_tokenizer.json")  # 替换为你的SMILES分词器路径
rna_tokenizer = Tokenizer.from_file("/root/autodl-tmp/rna/rna_tokenizer.json")       # 替换为你的RNA分词器路径（即提供的json文件）

In [33]:
class Mol2Aptamer(torch.nn.Module):
    def __init__(self, smiles_vocab_size, rna_vocab_size, d_model=256):
        super().__init__()
        self.smiles_embedding = torch.nn.Embedding(smiles_vocab_size, d_model)
        self.rna_embedding = torch.nn.Embedding(rna_vocab_size, d_model)
        self.pos_embedding = torch.nn.Embedding(512, d_model)

        self.decoder_layer = torch.nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=8,
            batch_first=True
        )
        self.decoder = torch.nn.TransformerDecoder(self.decoder_layer, num_layers=3)
        self.fc_out = torch.nn.Linear(d_model, rna_vocab_size)

    def forward(self, smiles_ids, rna_inp):
        batch_size, L_smi = smiles_ids.shape
        batch_size, L_rna = rna_inp.shape

        smiles_emb = self.smiles_embedding(smiles_ids)
        smiles_pos = torch.arange(0, L_smi, device=smiles_ids.device).unsqueeze(0).repeat(batch_size, 1)
        smiles_emb += self.pos_embedding(smiles_pos)

        rna_emb = self.rna_embedding(rna_inp)
        rna_pos = torch.arange(0, L_rna, device=rna_inp.device).unsqueeze(0).repeat(batch_size, 1)
        rna_emb += self.pos_embedding(rna_pos)

        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(L_rna, device=rna_inp.device)
        decoder_out = self.decoder(tgt=rna_emb, memory=smiles_emb, tgt_mask=tgt_mask)

        logits = self.fc_out(decoder_out)
        return logits


In [34]:
class Mol2Aptamer(torch.nn.Module):
    def __init__(self, smiles_vocab_size, rna_vocab_size, d_model=256):
        super().__init__()
        self.smiles_embedding = torch.nn.Embedding(smiles_vocab_size, d_model)
        self.rna_embedding = torch.nn.Embedding(rna_vocab_size, d_model)
        self.pos_embedding = torch.nn.Embedding(512, d_model)

        self.decoder_layer = torch.nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=8,
            batch_first=True
        )
        self.decoder = torch.nn.TransformerDecoder(self.decoder_layer, num_layers=3)
        self.fc_out = torch.nn.Linear(d_model, rna_vocab_size)

    def encoder(self, smiles_ids):  
        batch_size, L_smi = smiles_ids.shape
        smiles_emb = self.smiles_embedding(smiles_ids)
        smiles_pos = torch.arange(0, L_smi, device=smiles_ids.device).unsqueeze(0).repeat(batch_size, 1)
        smiles_emb += self.pos_embedding(smiles_pos)
        return smiles_emb

    def forward(self, smiles_ids, rna_inp):
        smiles_emb = self.encoder(smiles_ids)

        batch_size, L_rna = rna_inp.shape
        rna_emb = self.rna_embedding(rna_inp)
        rna_pos = torch.arange(0, L_rna, device=rna_inp.device).unsqueeze(0).repeat(batch_size, 1)
        rna_emb += self.pos_embedding(rna_pos)

        tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(L_rna, device=rna_inp.device)
        decoder_out = self.decoder(tgt=rna_emb, memory=smiles_emb, tgt_mask=tgt_mask)

        logits = self.fc_out(decoder_out)
        return logits


In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer

# --------------------------
# 1. 重新定义模型（确保嵌入层维度=256）
# --------------------------
class Mol2Aptamer(nn.Module):
    def __init__(self, smiles_vocab_size, rna_vocab_size, d_model=256, nhead=8, num_encoder_layers=2, num_decoder_layers=3):
        super().__init__()
        self.d_model = d_model  # 显式保存模型维度，便于验证
        
        # Encoder：2层 TransformerEncoder（d_model=256）
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=d_model,       # 注意力层期望的维度=256
                nhead=nhead,
                dim_feedforward=2048,
                batch_first=True
            ),
            num_layers=num_encoder_layers
        )
        
        # Decoder：3层 TransformerDecoder（d_model=256）
        self.decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=2048,
                batch_first=True
            ),
            num_layers=num_decoder_layers
        )
        
        # 嵌入层：明确设置输出维度=d_model=256
        self.smiles_embedding = nn.Embedding(smiles_vocab_size, d_model)  # 输出 (seq_len, 256)
        self.rna_embedding = nn.Embedding(rna_vocab_size, d_model)
        self.pos_embedding = nn.Embedding(512, d_model)  # 位置编码维度也=256
        
        # 输出层
        self.fc_out = nn.Linear(d_model, rna_vocab_size)

    def forward(self, smiles_ids, rna_inp):
        batch_size, seq_len_smi = smiles_ids.shape
        batch_size, seq_len_rna = rna_inp.shape
        device = smiles_ids.device
        
        # --------------------------
        # 关键：验证嵌入层输出维度（确保=256）
        # --------------------------
        smiles_emb = self.smiles_embedding(smiles_ids)
        assert smiles_emb.shape[-1] == self.d_model, \
            f"嵌入层输出维度错误：期望 {self.d_model}，实际 {smiles_emb.shape[-1]}"
        
        # 叠加位置编码（维度需与嵌入层一致）
        smiles_pos = torch.arange(seq_len_smi, device=device).unsqueeze(0).repeat(batch_size, 1)
        smiles_emb += self.pos_embedding(smiles_pos)
        assert smiles_emb.shape[-1] == self.d_model, \
            f"位置编码后维度错误：期望 {self.d_model}，实际 {smiles_emb.shape[-1]}"
        
        # Encoder前向（输入维度=256，匹配注意力层期望）
        memory = self.encoder(smiles_emb)
        
        # RNA嵌入与位置编码（同样验证维度）
        rna_emb = self.rna_embedding(rna_inp)
        assert rna_emb.shape[-1] == self.d_model, \
            f"RNA嵌入层输出维度错误：期望 {self.d_model}，实际 {rna_emb.shape[-1]}"
        rna_pos = torch.arange(seq_len_rna, device=device).unsqueeze(0).repeat(batch_size, 1)
        rna_emb += self.pos_embedding(rna_pos)
        
        # Decoder前向
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len_rna, device=device)
        decoder_out = self.decoder(tgt=rna_emb, memory=memory, tgt_mask=tgt_mask)
        
        # 输出层
        logits = self.fc_out(decoder_out)
        return logits

# --------------------------
# 2. 加载分词器 + 初始化模型（显式确认嵌入层维度）
# --------------------------
smiles_tokenizer = Tokenizer.from_file("/root/autodl-tmp/rna/smiles_tokenizer.json")
rna_tokenizer = Tokenizer.from_file("/root/autodl-tmp/rna/rna_tokenizer.json")

# 词汇表大小
smiles_vocab_size = len(smiles_tokenizer.get_vocab())
rna_vocab_size = len(rna_tokenizer.get_vocab())

# 设备与模型初始化（d_model=256，嵌入层输出维度=256）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Mol2Aptamer(
    smiles_vocab_size=smiles_vocab_size,
    rna_vocab_size=rna_vocab_size,
    d_model=256,          # 强制嵌入层输出维度=256
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=3
).to(device)

# --------------------------
# 3. 修复权重加载：确保嵌入层权重被正确加载
# --------------------------
# 加载原始权重
state_dict = torch.load("/root/autodl-tmp/model_epoch_59.pth", map_location=device)

# 查看所有权重键，确认嵌入层键是否存在（如 "smiles_embedding.weight"）
print("原始权重中的嵌入层相关键：")
for key in state_dict.keys():
    if "embedding" in key:
        print(f"- {key}: 形状 {state_dict[key].shape}")
print("-" * 50)

# 过滤权重：保留所有与当前模型匹配的键（包括嵌入层、位置编码、输出层）
filtered_state_dict = {}
model_keys = set(model.state_dict().keys())  # 当前模型需要的键
for key, value in state_dict.items():
    if key in model_keys:
        # 额外验证嵌入层权重维度是否正确（如 smiles_embedding.weight 形状应为 (vocab_size, 256)）
        if "embedding.weight" in key:
            assert value.shape[-1] == model.d_model, \
                f"{key} 维度错误：期望 {model.d_model}，实际 {value.shape[-1]}"
        filtered_state_dict[key] = value
    else:
        print(f"跳过冗余键：{key}")

# 加载过滤后的权重
model.load_state_dict(filtered_state_dict, strict=False)  # strict=False：允许模型有未加载的键（如无）
model.eval()
print("权重加载成功！")

# --------------------------
# 4. 修复生成函数：确保SMILES编码后维度正确
# --------------------------
def generate_aptamers(
    model, smiles, smiles_tokenizer, rna_tokenizer,
    max_len=80, num_return=5,
    strategy="topk", top_k=10, top_p=0.9, temperature=0.8,
    device=None
):
    if device is None:
        device = next(model.parameters()).device
    
    model.eval()
    model.to(device)
    
    # 验证特殊token
    required_tokens = ["<pad>", "<bos>", "<eos>", "<unk>"]
    for token in required_tokens:
        token_id = rna_tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f"RNA tokenizer missing required token: {token}")
    
    bos_id = rna_tokenizer.token_to_id("<bos>")
    eos_id = rna_tokenizer.token_to_id("<eos>")
    pad_id = rna_tokenizer.token_to_id("<pad>")
    unk_id = rna_tokenizer.token_to_id("<unk>")
    smiles_pad_id = smiles_tokenizer.token_to_id("<pad>")
    if smiles_pad_id is None:
        raise ValueError("SMILES tokenizer missing <pad> token")
    
    # --------------------------
    # 修复SMILES编码：确保输入到嵌入层的张量格式正确
    # --------------------------
    try:
        smi_encoded = smiles_tokenizer.encode(smiles)
        max_smi_len = 128  # 与训练时一致
        # 补全/截断SMILES到max_smi_len
        smi_ids = smi_encoded.ids[:max_smi_len]
        smi_ids += [smiles_pad_id] * (max_smi_len - len(smi_ids))
        # 转换为张量：(batch_size=1, seq_len=128)
        smi_ids = torch.tensor(smi_ids, dtype=torch.long).unsqueeze(0).to(device)
        assert smi_ids.shape == (1, max_smi_len), \
            f"SMILES张量形状错误：期望 (1, {max_smi_len})，实际 {smi_ids.shape}"
    except Exception as e:
        raise ValueError(f"Failed to encode SMILES: {str(e)}")
    
    results = []
    with torch.no_grad():
        # --------------------------
        # 验证Encoder输入维度（嵌入层输出应为256）
        # --------------------------
        smiles_emb = model.smiles_embedding(smi_ids)
        assert smiles_emb.shape == (1, max_smi_len, model.d_model), \
            f"SMILES嵌入后形状错误：期望 (1, {max_smi_len}, {model.d_model})，实际 {smiles_emb.shape}"
        
        # 预计算Encoder输出
        memory = model.encoder(smiles_emb)
        
        for _ in range(num_return):
            generated = [bos_id]
            has_unk = False
            
            for _ in range(max_len - 1):
                # RNA输入张量：(1, current_len)
                rna_inp = torch.tensor([generated], dtype=torch.long).to(device)
                
                # 模型前向传播
                logits = model(smi_ids, rna_inp)[:, -1, :]  # (1, rna_vocab_size)
                
                # 温度调节
                logits = logits / temperature
                
                # 采样策略
                if strategy == "greedy":
                    next_id = torch.argmax(logits, dim=-1).item()
                elif strategy == "topk":
                    topk_probs, topk_ids = torch.topk(logits, k=top_k)
                    topk_probs = F.softmax(topk_probs, dim=-1)
                    idx = torch.multinomial(topk_probs, 1).item()
                    next_id = topk_ids[0, idx].item()
                elif strategy == "topp":
                    sorted_logits, sorted_ids = torch.sort(logits, descending=True)
                    sorted_probs = F.softmax(sorted_logits, dim=-1)
                    cum_probs = torch.cumsum(sorted_probs, dim=-1)
                    cutoff = max(1, torch.sum(cum_probs <= top_p).item())
                    filtered_probs = sorted_probs[:, :cutoff]
                    filtered_ids = sorted_ids[:, :cutoff]
                    idx = torch.multinomial(filtered_probs, 1).item()
                    next_id = filtered_ids[0, idx].item()
                else:
                    raise ValueError(f"Unknown strategy: {strategy}")
                
                # 终止条件
                if next_id == unk_id:
                    has_unk = True
                if next_id == eos_id:
                    break
                generated.append(next_id)
            
            # 后处理
            filtered_ids = [id for id in generated if id not in [bos_id, eos_id, pad_id, unk_id]]
            try:
                seq = rna_tokenizer.decode(filtered_ids, skip_special_tokens=True)
                if seq and not has_unk and seq not in results:
                    results.append(seq)
            except Exception as e:
                print(f"Warning: Failed to decode sequence: {str(e)}")
    
    # 补全候选数量
    while len(results) < num_return:
        results.append(results[-1] if results else "")
    return results[:num_return]

# --------------------------
# 5. 测试生成（成功运行）
# --------------------------
if __name__ == "__main__":
    # 输入SMILES（苯酚）
    smiles = "C1=CC=C(C=C1)O"
    
    # 生成前先验证模型嵌入层维度
    print(f"模型嵌入层输出维度：{model.d_model}")
    print(f"SMILES嵌入层权重形状：{model.smiles_embedding.weight.shape}")
    print(f"RNA嵌入层权重形状：{model.rna_embedding.weight.shape}")
    
    # 生成Aptamer
    candidates = generate_aptamers(
        model=model,
        smiles=smiles,
        smiles_tokenizer=smiles_tokenizer,
        rna_tokenizer=rna_tokenizer,
        max_len=80,
        num_return=5,
        strategy="topk",
        top_k=10,
        temperature=0.8,
        device=device
    )
    
    # 打印结果
    print("\n候选Aptamer序列：")
    for i, seq in enumerate(candidates, 1):
        print(f"{i}. {seq}")

原始权重中的嵌入层相关键：
- smiles_embedding.weight: 形状 torch.Size([188, 256])
- rna_embedding.weight: 形状 torch.Size([100, 256])
- pos_embedding.weight: 形状 torch.Size([512, 256])
--------------------------------------------------
跳过冗余键：decoder_layer.self_attn.in_proj_weight
跳过冗余键：decoder_layer.self_attn.in_proj_bias
跳过冗余键：decoder_layer.self_attn.out_proj.weight
跳过冗余键：decoder_layer.self_attn.out_proj.bias
跳过冗余键：decoder_layer.multihead_attn.in_proj_weight
跳过冗余键：decoder_layer.multihead_attn.in_proj_bias
跳过冗余键：decoder_layer.multihead_attn.out_proj.weight
跳过冗余键：decoder_layer.multihead_attn.out_proj.bias
跳过冗余键：decoder_layer.linear1.weight
跳过冗余键：decoder_layer.linear1.bias
跳过冗余键：decoder_layer.linear2.weight
跳过冗余键：decoder_layer.linear2.bias
跳过冗余键：decoder_layer.norm1.weight
跳过冗余键：decoder_layer.norm1.bias
跳过冗余键：decoder_layer.norm2.weight
跳过冗余键：decoder_layer.norm2.bias
跳过冗余键：decoder_layer.norm3.weight
跳过冗余键：decoder_layer.norm3.bias
权重加载成功！
模型嵌入层输出维度：256
SMILES嵌入层权重形状：torch.Size([188, 256])
RNA嵌入层权重形状：torch.Si

In [48]:
# 自定义生成参数（适合调整多样性、长度等）
aptamers = generate_aptamers(
    model=model,
    smiles="CC(=O)OC1=CC=CC=C1C(=O)O",  
    smiles_tokenizer=smiles_tokenizer,
    rna_tokenizer=rna_tokenizer,
    max_len=100,               # 最长生成100个token
    num_return=3,              # 返回3个候选序列
    strategy="topp",           # 使用top-p采样（核采样）
    top_p=0.95,                # 累积概率阈值0.95
    temperature=0.8,           # 降低多样性（值越小越确定）
)

print("生成的Aptamer序列：")
for i, seq in enumerate(aptamers):
    print(f"候选{i+1}: {seq}")

生成的Aptamer序列：
候选1: Ġ GGU CUU AC GU C GUU GAU GG GGGU AGC AA G
候选2: Ġ CA CGG GAU UUU GC GGU AGG CA GUU CC GUCC UU CGG CC GU CGU GGGG GGU AUCC CU CC GCC UU CC GUCC AACC
候选3: Ġ AA GCU UU UUU GACU AUU A GAA GAU CGG CC AU GAA AGCC CC AC AU GCGU GCU UU GCC C GUU ACC AC GU GC AC GC


In [49]:
!pip install RNA

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Collecting RNA
  Downloading http://mirrors.aliyun.com/pypi/packages/36/1d/8ac00264df042f96c5e5efb1ac03000024a355d490f012fbe59b15da26d4/rna-0.13.2-py3-none-any.whl (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.4/75.4 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: RNA
Successfully installed RNA-0.13.2
[0m

In [50]:
!pip install viennarna

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Looking in indexes: http://mirrors.aliyun.com/pypi/simple
Collecting viennarna
  Downloading http://mirrors.aliyun.com/pypi/packages/03/2f/73ea7b15dc226f120dbc89f3f4bb40b6a5a26ac969898255f22123a94d1a/ViennaRNA-2.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.2/13.2 MB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: viennarna
Successfully installed viennarna-2.7.0
[0m

In [51]:
#过滤函数（RNAfold计算ΔG）
import RNA

def filter_by_rnafold(sequences, min_length=20, max_length=80, max_homopolymer=6, max_candidates=5):
    """
    sequences: list of str
    返回过滤+排序后的序列
    """
    results = []
    for seq in sequences:
        # 长度限制
        if len(seq) < min_length or len(seq) > max_length:
            continue
        # 去掉长同聚核苷酸 (AAAAAAA)
        if any(base*max_homopolymer in seq for base in "ACGU"):
            continue
        # 用 RNAfold 预测 ΔG
        structure, mfe = RNA.fold(seq)
        results.append((seq, mfe))

    # 按 ΔG 从低到高排序（越低越稳定）
    results = sorted(results, key=lambda x: x[1])
    return results[:max_candidates]


In [52]:
def generate_and_filter(
    model, smiles, smiles_tokenizer, rna_tokenizer,
    num_generate=50, return_top=5,
    strategy="topk", top_k=10, top_p=0.9, temperature=0.8
):
    # Step1: 生成候选
    candidates = generate_aptamers(
        model, smiles, smiles_tokenizer, rna_tokenizer,
        max_len=80, num_return=num_generate,
        strategy=strategy, top_k=top_k, top_p=top_p, temperature=temperature,
        device=device
    )

    # Step2: 过滤 & 打分
    filtered = filter_by_rnafold(candidates, max_candidates=return_top)

    return filtered


In [53]:
smiles = "C1=CC=C(C=C1)O"  # phenol

top_candidates = generate_and_filter(
    model, smiles, smiles_tokenizer, rna_tokenizer,
    num_generate=100, return_top=5,
    strategy="topp", top_p=0.9, temperature=0.7
)

print("Top 适配体候选（按ΔG排序）：")
for seq, mfe in top_candidates:
    print(f"{seq}   ΔG={mfe:.2f}")


Top 适配体候选（按ΔG排序）：
Ġ CGA GAGG AGU GGU GG GGU CA GAU GCA CU CGG ACC CC AUU CU CC C   ΔG=-4.10
Ġ CA AUGG CC ACC CC GG GGU GG GCGC GAA AGU GGU   ΔG=-3.10
Ġ CU CU CGG GA CGA CC CA CGU CC GGGU GG CUU GAU AGG GG GGU GGU CC AUCC CU CC   ΔG=-2.20
Ġ CC GGU ACA CA GG AGG CU GGU GCGC GGU GAA GU GCC GA GU CGU AA C   ΔG=-2.10
Ġ CGU AC GACU CA GG GCC A GAGG GAU CGG GU GGU CGU GGU CCU GAU GCA AUCU CU CC C   ΔG=-1.80
