## PSMILES 标准化

In [1]:
from canonicalize_psmiles.canonicalize import canonicalize
from rdkit import Chem
from typing import Optional
import tqdm as notebook_tqdm

In [2]:
def has_two_stars(ps):
    """Return True when ``ps`` contains exactly two ``[*]`` attachment points."""
    star_count = ps.count('[*]')
    return star_count == 2

def canonicalize_or_skip(psmiles: str) -> Optional[str]:
    """Canonicalize a PSMILES string, or return None when it should be skipped.

    An entry is skipped when canonicalization raises, when the canonical form
    does not have exactly two ``[*]`` attachment points, or when RDKit cannot
    parse it after every ``[*]`` is swapped for an ``[Xe]`` placeholder atom.
    """
    try:
        canonical = canonicalize(psmiles)
        if not has_two_stars(canonical):
            return None
        # RDKit syntax/valence check on a copy with an inert placeholder atom.
        # NOTE(review): RDKit can also parse [*] dummy atoms directly; confirm the
        # Xe placeholder's allowed valences never reject a valid structure.
        probe = Chem.MolFromSmiles(canonical.replace('[*]', '[Xe]'))
        return None if probe is None else canonical
    except Exception:
        # Best-effort filter: any failure simply means "skip this entry".
        return None



In [3]:
# 加载 CSV、自动识别 PSMILES 列、准备随机种子与工具函数。
import re
import random
from pathlib import Path
import pandas as pd

# Seed both RNGs the span-masking step draws from: `random` picks span start
# positions and NumPy's global RNG draws the Poisson span lengths.
import numpy as np

random.seed(42)
np.random.seed(42)

In [4]:
CSV_PATH = Path("data/PSMILES_Tg_only.csv")

# Column names accepted as the PSMILES column, in priority order, matched
# exactly and case-insensitively. Set ps_col by hand if the CSV uses another name.
PSMILES_COLUMN_NAMES = ("psmiles", "smiles")

# Load the data and pick the PSMILES column (falls back to the first column).
df = pd.read_csv(CSV_PATH)
candidate_cols = [
    col
    for wanted in PSMILES_COLUMN_NAMES
    for col in df.columns
    if str(col).strip().lower() == wanted
]
ps_col = candidate_cols[0] if candidate_cols else df.columns[0]
ps_list = df[ps_col].astype(str).tolist()

print("Using column:", ps_col, "Total rows:", len(ps_list))

Using column: PSMILES Total rows: 7367


In [5]:
from transformers import AutoTokenizer

# polyBERT ships its own SentencePiece vocabulary; load the slow (non-fast) tokenizer.
ENC_NAME = "kuelumbus/polyBERT"
tok = AutoTokenizer.from_pretrained(ENC_NAME, use_fast=False)

# Sanity-check the vocabulary size and the special tokens used for masking/padding.
print(f"vocab_size = {len(tok)}")
print(f"mask       = {tok.mask_token} {tok.mask_token_id}")
print(f"pad        = {tok.pad_token} {tok.pad_token_id}")


  from .autonotebook import tqdm as notebook_tqdm


vocab_size = 270
mask       = [MASK] 268
pad        = [PAD] 267


In [None]:
import re

# Whether valid_two_stars() requires exactly two '[*]' attachment points (toggle True/False).
ENFORCE_TWO_STARS: bool = True

def simple_clean(ps: str) -> str:
    """Lightly normalize a PSMILES string without reordering anything.

    Maps the full-width characters ``＃``, ``／`` and ``＊`` to their ASCII
    forms and removes every whitespace character. No canonicalization is done.
    """
    fullwidth_to_ascii = str.maketrans({"＃": "#", "／": "/", "＊": "*"})
    return re.sub(r"\s+", "", ps.translate(fullwidth_to_ascii))

def valid_two_stars(ps: str) -> bool:
    """Return True if ``ps`` passes the two-star rule.

    Always True when ``ENFORCE_TWO_STARS`` is off; otherwise ``ps`` must
    contain exactly two ``[*]`` tokens.
    """
    if not ENFORCE_TWO_STARS:
        return True
    return ps.count("[*]") == 2


In [None]:
import random, numpy as np

def span_corrupt_ids(
    input_ids,
    mask_id,
    attention_mask=None,
    tokenizer=None,
    corruption_ratio=0.2,
    mean_span=3,
    protect_two_stars=True,
    rng=None,
    np_rng=None,
):
    """Replace random contiguous spans of ``input_ids`` with ``mask_id``.

    Safety rules:
    - never masks the tokenizer's special tokens (``all_special_ids``, e.g.
      CLS/SEP/PAD/MASK);
    - only masks positions whose ``attention_mask`` value is 1 (when given);
    - by default also protects the PSMILES attachment-point token ``'[*]'``.

    Args:
        input_ids: sequence of token ids (list, tuple, or 1-D tensor/array).
        mask_id: id written into masked positions; ``None`` disables masking.
        attention_mask: optional sequence aligned with ``input_ids``.
        tokenizer: optional tokenizer providing ``all_special_ids`` and
            ``convert_tokens_to_ids``.
        corruption_ratio: fraction of maskable positions to mask; at least one
            position is masked whenever any position is maskable.
        mean_span: Poisson mean of each span length (minimum length 1).
        protect_two_stars: keep ``'[*]'`` tokens unmasked.
        rng: optional ``random.Random`` used to pick span starts; defaults to
            the module-level ``random`` state.
        np_rng: optional NumPy ``Generator``/``RandomState`` used for span
            lengths; defaults to ``np.random``.

    Returns:
        A new list of Python ints; ``input_ids`` itself is not modified.
    """
    # int() so tensor/array elements hash and compare like plain ids: a 0-d
    # tensor hashes by identity and would otherwise slip past `protected`.
    ids = [int(x) for x in input_ids]
    n = len(ids)
    if mask_id is None or n == 0:
        return ids
    rng = random if rng is None else rng
    np_rng = np.random if np_rng is None else np_rng

    # Token ids that must never be masked.
    protected = set()
    if tokenizer is not None and getattr(tokenizer, "all_special_ids", None):
        protected.update(int(t) for t in tokenizer.all_special_ids)
    if protect_two_stars and tokenizer is not None:
        try:
            star_id = tokenizer.convert_tokens_to_ids("[*]")
        except Exception:
            star_id = None
        # NOTE(review): if '[*]' is not a single vocabulary entry this lookup may
        # return the unk id, leaving stars unprotected — confirm for the tokenizer.
        if isinstance(star_id, int) and star_id >= 0:
            protected.add(star_id)

    # Maskable positions: attended and not carrying a protected id.
    candidates = [
        i
        for i in range(n)
        if (attention_mask is None or attention_mask[i] == 1) and ids[i] not in protected
    ]
    if not candidates:
        return ids
    candidate_set = set(candidates)

    num_to_mask = max(1, int(len(candidates) * corruption_ratio))
    masked = set()
    while len(masked) < num_to_mask and len(masked) < len(candidates):
        # Uniform over still-unmasked candidates: the same distribution as
        # rejection-sampling from all candidates, without wasted draws.
        start = rng.choice([i for i in candidates if i not in masked])
        span = max(1, int(np_rng.poisson(mean_span)))
        # Extend rightwards; protected/unattended positions inside the span are skipped.
        for k in range(start, min(start + span, n)):
            if k in masked or k not in candidate_set:
                continue
            ids[k] = mask_id
            masked.add(k)
            if len(masked) >= num_to_mask:
                break

    return ids



In [None]:
def test_span_corrupt(psmiles: str, corruption_ratio=0.2, mean_span=3):
    """Tokenize one PSMILES, span-mask it, and print the before/after tokens.

    Uses the global ``tok`` (truncated/padded to 128 tokens) and
    ``span_corrupt_ids`` with '[*]' protection enabled. Prints only; returns None.
    """
    encoded = tok(psmiles, truncation=True, max_length=128, padding="max_length", return_tensors="pt")
    ids_before = encoded["input_ids"][0].tolist()
    mask_flags = encoded["attention_mask"][0].tolist()

    ids_after = span_corrupt_ids(
        input_ids=ids_before,
        mask_id=tok.mask_token_id,
        attention_mask=mask_flags,
        tokenizer=tok,
        corruption_ratio=corruption_ratio,
        mean_span=mean_span,
        protect_two_stars=True,
    )

    # Map ids back to token strings (original first, then masked).
    before_tokens, after_tokens = (tok.convert_ids_to_tokens(seq) for seq in (ids_before, ids_after))

    report = (
        ("原始 PSMILES:", psmiles),
        ("\n原始 tokens:", before_tokens),
        ("\n掩码后 tokens:", after_tokens),
        ("\n掩码后可读形式:", tok.convert_tokens_to_string(after_tokens)),
    )
    for header, payload in report:
        print(header)
        print(payload)


# Example run
test_span_corrupt("[*]CC(=O)OCC[*]")


In [None]:
def batch_mask_psmiles(ps_list, corruption_ratio=0.2, mean_span=3, max_length=128):
    """Span-mask each PSMILES in ``ps_list`` and return the masked strings.

    Every string is tokenized with the global polyBERT tokenizer ``tok``
    (truncated/padded to ``max_length``), then corrupted by
    ``span_corrupt_ids`` with '[*]' protection, then turned back into text.

    Only content tokens and ``[MASK]`` are written back. ``[CLS]``, ``[SEP]`` and
    padding are dropped because these strings are tokenized again later
    (``PSMILESDataset``). The tokenizer adds its own ``[CLS]``/``[SEP]`` there.
    Literal ``[PAD]`` text would come back as attended pad ids and push the
    sequence past ``max_length``.

    Args:
        ps_list: iterable of PSMILES strings.
        corruption_ratio: fraction of maskable tokens to mask per string.
        mean_span: mean span length passed to ``span_corrupt_ids``.
        max_length: tokenization length (truncate/pad).

    Returns:
        A list of masked strings in the same order as ``ps_list``.
    """
    mask_id = tok.mask_token_id
    # Framing/padding ids to drop from the serialized text; [MASK] must survive.
    frame_ids = {tok.cls_token_id, tok.sep_token_id, tok.pad_token_id} - {None, mask_id}

    masked_psmiles = []
    for psmiles in notebook_tqdm.tqdm(ps_list, desc="Processing PSMILES"):
        enc = tok(psmiles, truncation=True, max_length=max_length, padding="max_length", return_tensors="pt")
        input_ids = enc["input_ids"][0].tolist()
        attention_mask = enc["attention_mask"][0].tolist()

        masked_ids = span_corrupt_ids(
            input_ids=input_ids,
            mask_id=mask_id,
            attention_mask=attention_mask,
            tokenizer=tok,
            corruption_ratio=corruption_ratio,
            mean_span=mean_span,
            protect_two_stars=True,
        )

        kept_ids = [
            token_id
            for token_id, attended in zip(masked_ids, attention_mask)
            if attended == 1 and token_id not in frame_ids
        ]
        # NOTE(review): SentencePiece tokenizers may insert spaces around [MASK]
        # in convert_tokens_to_string; confirm the re-tokenized input looks as intended.
        masked_psmiles.append(tok.convert_tokens_to_string(tok.convert_ids_to_tokens(kept_ids)))
    return masked_psmiles

# Where the masked dataset is written (relative to the working directory).
OUTPUT_CSV = "PSMILES_Tg_only_masked.csv"

# Mask every PSMILES once; the result is row-aligned with df.
masked_ps_list = batch_mask_psmiles(ps_list)
df["Masked_PSMILES"] = masked_ps_list

# Preview and persist.
print(df.head())
df.to_csv(OUTPUT_CSV, index=False)

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
MAX_LEN = 128  # fixed token-sequence length for encoder inputs and decoder labels


class PSMILESDataset(Dataset):
    """(masked PSMILES -> original PSMILES) pairs for denoising seq2seq training.

    Each item tokenizes the masked string as the encoder input and the original
    string as the target. Both are truncated/padded to ``max_len``.

    Args:
        dataframe: table holding both the original and the masked PSMILES.
        tokenizer: Hugging Face-style tokenizer (polyBERT here), called with
            ``truncation``/``max_length``/``padding``/``return_tensors``.
        max_len: fixed sequence length.
        source_col: column with the original PSMILES (targets).
        masked_col: column with the masked PSMILES (encoder inputs).
    """

    def __init__(
        self,
        dataframe,
        tokenizer,
        max_len=128,
        source_col="PSMILES",
        masked_col="Masked_PSMILES",
    ):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.source_col = source_col
        self.masked_col = masked_col

    def __len__(self):
        return len(self.data)

    def _encode(self, text):
        """Tokenize one string to ``max_len``; return 1-D ``(input_ids, attention_mask)``."""
        enc = self.tokenizer(
            text,
            truncation=True,  # anything longer than max_len is cut (rare for PSMILES)
            max_length=self.max_len,
            padding="max_length",  # pad to a fixed length so batches stack
            return_tensors="pt",
        )
        # squeeze(0) drops the batch dimension added by return_tensors="pt".
        return enc["input_ids"].squeeze(0), enc["attention_mask"].squeeze(0)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        # Encoder input: masked PSMILES; attention_mask is 1 for real tokens, 0 for padding.
        input_ids, attention_mask = self._encode(row[self.masked_col])
        # Target: original PSMILES token ids. Pad ids are kept here; the loss
        # is expected to ignore them.
        labels, _ = self._encode(row[self.source_col])
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

# Dataset of (masked -> original) PSMILES pairs, served in shuffled batches of 32.
dataset = PSMILESDataset(df, tok, max_len=MAX_LEN)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Peek at the tensor shapes of the first batch only.
for batch in dataloader:
    for label, key in (
        ("Input IDs:", "input_ids"),
        ("Attention Mask:", "attention_mask"),
        ("Labels:", "labels"),
    ):
        print(label, batch[key].shape)
    break

In [None]:
from transformers import AutoModel, BertConfig, BertLMHeadModel, EncoderDecoderModel
import torch

# Pick the best available accelerator: CUDA, then Apple-Silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Pretrained polyBERT as the encoder; every encoder weight starts frozen.
encoder = AutoModel.from_pretrained(ENC_NAME)
encoder.to(device)

for param in encoder.parameters():
    param.requires_grad = False

# Encoder depth (None if the config does not expose it) and how many of the
# top layers are reserved for later unfreezing.
hidden_layers = getattr(encoder.config, "num_hidden_layers", None)
TOP_ENCODER_LAYERS = 2  # raise to 4 if more capacity is needed later

def select_encoder_top_layers(model, num_layers=TOP_ENCODER_LAYERS):
    """Return ``(name, param)`` pairs belonging to the top ``num_layers`` encoder layers.

    Layers are matched by the parameter-name prefix
    ``encoder.encoder.layer.{i}.`` for the last ``num_layers`` indices. The
    encoder depth comes from the module-level ``hidden_layers``; when that is
    None the result is empty. The caller decides whether to freeze or unfreeze
    the returned parameters.
    """
    if hidden_layers is None:
        return []
    first_layer = max(hidden_layers - num_layers, 0)
    layer_prefixes = tuple(
        f"encoder.encoder.layer.{layer_idx}." for layer_idx in range(first_layer, hidden_layers)
    )
    return [
        (name, param)
        for name, param in model.named_parameters()
        if name.startswith(layer_prefixes)
    ]

# Randomly initialised BERT-style decoder: causal self-attention (is_decoder=True)
# plus cross-attention over the encoder outputs (add_cross_attention=True).
decoder_config = BertConfig( # BERT-based decoder configuration
    vocab_size=len(tok),
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=6,
    num_attention_heads=8,
    max_position_embeddings=MAX_LEN,  # decoder sequences are at most MAX_LEN tokens
    is_decoder=True,
    add_cross_attention=True,
    pad_token_id=tok.pad_token_id,
    # BOS/EOS reuse the tokenizer's CLS/SEP ids, falling back to PAD if missing.
    bos_token_id=tok.cls_token_id if tok.cls_token_id is not None else tok.pad_token_id,
    eos_token_id=tok.sep_token_id if tok.sep_token_id is not None else tok.pad_token_id,
    tie_word_embeddings=True,  # tie the decoder's input embeddings to its LM-head output weights
)

decoder = BertLMHeadModel(decoder_config)

# Frozen polyBERT encoder + new decoder.
# NOTE(review): if the encoder's hidden size differs from the decoder's 512,
# EncoderDecoderModel adds a trainable encoder->decoder projection — make sure
# the optimizer includes it.
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model.config.decoder_start_token_id = decoder_config.bos_token_id # first decoder input token (BOS)
model.config.pad_token_id = tok.pad_token_id # keep in sync with the tokenizer
model.config.vocab_size = len(tok)  # output vocabulary size

model.config.tie_encoder_decoder = False # do NOT share weights between encoder and decoder
model.config.use_cache = False # no KV-cache during training: saves memory and avoids graph incompatibilities
model.tie_weights() # re-apply weight tying; with tie_encoder_decoder=False this is presumably only within-model tying (decoder embeddings <-> LM head) — confirm
model.to(device)

# Handle to the top encoder layers planned for unfreezing after warm-up.
# They are explicitly (re)set to requires_grad=False so they stay frozen until then.
# The printed count is the number of parameter tensors, not optimizer groups.
top_encoder_params = select_encoder_top_layers(model, num_layers=TOP_ENCODER_LAYERS)
for name, param in top_encoder_params:
    param.requires_grad = False
print(f"Reserved {len(top_encoder_params)} encoder parameter groups for later fine-tuning.")


In [None]:
from transformers import get_linear_schedule_with_warmup
import torch.nn as nn

decoder_lr = 2e-4  # learning rate for everything trained from scratch (decoder, projection)
encoder_lr = 2e-5  # learning rate for the reserved top encoder layers
NUM_EPOCHS = 5  # number of training epochs
GRAD_CLIP_NORM = 1.0  # global grad-norm clip; guards against spikes in the cross-attention decoder
UNFREEZE_AFTER_STEPS = 10_000  # global step at which the reserved top encoder layers are unfrozen

# Everything outside the pretrained encoder is trained from scratch: the
# decoder and, when encoder/decoder hidden sizes differ, EncoderDecoderModel's
# `enc_to_dec_proj`. Selecting by "not under encoder." keeps that projection
# in the optimizer; a "decoder"/"lm_head" prefix filter would silently leave it
# at its random initialization.
decoder_params = [
    param for name, param in model.named_parameters() if not name.startswith("encoder.")
]
# Reserved top encoder layers (frozen until maybe_unfreeze_encoder enables grads).
encoder_head_params = [param for _, param in top_encoder_params]

optimizer_groups = [{"params": decoder_params, "lr": decoder_lr}]  # group 1: large LR
if encoder_head_params:
    # Group 2: small LR. While frozen these params get no .grad, so AdamW skips them.
    optimizer_groups.append({"params": encoder_head_params, "lr": encoder_lr})

optimizer = torch.optim.AdamW(optimizer_groups, betas=(0.9, 0.999), weight_decay=0.01)
total_steps = max(len(dataloader) * NUM_EPOCHS, 1)  # total optimizer steps
warmup_steps = int(0.06 * total_steps)  # linear warm-up over the first 6%, then linear decay to 0
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
if UNFREEZE_AFTER_STEPS >= total_steps:
    print(
        f"Note: UNFREEZE_AFTER_STEPS={UNFREEZE_AFTER_STEPS} >= total_steps={total_steps}; "
        "the reserved encoder layers will stay frozen for this run."
    )

# Token-level cross-entropy; label positions equal to the pad id do not contribute.
loss_fn = nn.CrossEntropyLoss(ignore_index=tok.pad_token_id, label_smoothing=0.1)
encoder_unfreeze_state = {"done": False}  # one-shot flag for maybe_unfreeze_encoder

def maybe_unfreeze_encoder(model, current_step, num_layers=TOP_ENCODER_LAYERS):
    """Unfreeze the top ``num_layers`` encoder layers once ``UNFREEZE_AFTER_STEPS`` is reached.

    This is a one-shot switch: the module-level ``encoder_unfreeze_state["done"]``
    flag makes later calls no-ops. It does nothing before the threshold.
    """
    already_done = encoder_unfreeze_state["done"]
    if not already_done and current_step >= UNFREEZE_AFTER_STEPS:
        for _, param in select_encoder_top_layers(model, num_layers):
            param.requires_grad = True
        encoder_unfreeze_state["done"] = True
        print(f"Unfroze top {num_layers} encoder layers at step {current_step}.")

def train(model, dataloader, num_epochs, start_step=0):
    """Teacher-forced training of the encoder-decoder on (masked -> original) PSMILES.

    Uses the module-level ``optimizer``, ``scheduler``, ``loss_fn``, ``device``
    and ``tok``. Each step works as follows:
    - possibly unfreezes the top encoder layers (``maybe_unfreeze_encoder``);
    - computes the token-level loss;
    - clips the global gradient norm to ``GRAD_CLIP_NORM``;
    - steps both the optimizer and the scheduler.

    Args:
        model: the EncoderDecoderModel being trained.
        dataloader: yields dicts with ``input_ids``, ``attention_mask``, ``labels``.
        num_epochs: number of passes over ``dataloader``.
        start_step: global step to resume counting from.

    Returns:
        The global step count after the last batch.
    """
    global_step = start_step
    model.train()
    for epoch in range(num_epochs):
        progress = notebook_tqdm.tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        for batch in progress:
            # Before each step, check whether the top encoder layers should be unfrozen.
            maybe_unfreeze_encoder(model, global_step)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            # Decoder inputs are the labels shifted right by one, with the
            # decoder start token prepended.
            decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels)
            decoder_attention_mask = (decoder_input_ids != tok.pad_token_id).long()

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )
            logits = outputs.logits
            # The decoder input is already shifted, so logits[:, t] is the
            # prediction for labels[:, t]. Shifting again would train the model
            # to predict y_{t+1} without having seen y_t.
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()

            global_step += 1
            progress.set_postfix({"loss": loss.item(), "step": global_step})
    return global_step

# Run the full training loop now. This call is active (not commented out) and
# trains for NUM_EPOCHS epochs, which can take a long time.
final_step = train(model, dataloader, num_epochs=NUM_EPOCHS)
print(f"Training finished at step {final_step}.")


In [None]:

@torch.no_grad()
def evaluate(model, dataloader, pad_token_id=None, device=None):
    """Return ``(mean loss per target token, perplexity)`` over ``dataloader``.

    The loss uses the same target alignment as training. The decoder inputs
    are the labels shifted right, so ``logits[:, t]`` scores ``labels[:, t]``.
    Padding positions are skipped. No label smoothing is applied, so the
    perplexity is a plain ``exp(mean NLL)``.

    Args:
        model: encoder-decoder model with ``prepare_decoder_input_ids_from_labels``
            whose forward returns an object with ``.logits``.
        dataloader: iterable of dicts with ``input_ids``, ``attention_mask``, ``labels``.
        pad_token_id: label id to ignore; defaults to ``tok.pad_token_id``.
        device: where batches are moved; defaults to the device of the model's
            first parameter.

    The model's train/eval mode is restored before returning.
    """
    if pad_token_id is None:
        pad_token_id = tok.pad_token_id
    if device is None:
        device = next(model.parameters()).device

    was_training = model.training
    model.eval()
    total_nll, total_tokens = 0.0, 0
    try:
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            dec_in = model.prepare_decoder_input_ids_from_labels(labels)
            dec_mask = (dec_in != pad_token_id).long()

            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=dec_in,
                decoder_attention_mask=dec_mask,
            ).logits

            # Summed NLL over non-pad targets, so batches are weighted by token count.
            total_nll += torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                labels.reshape(-1),
                ignore_index=pad_token_id,
                reduction="sum",
            ).item()
            total_tokens += int((labels != pad_token_id).sum().item())
    finally:
        model.train(was_training)

    mean_nll = total_nll / max(total_tokens, 1)
    ppl = torch.exp(torch.tensor(mean_nll)).item()
    return mean_nll, ppl

# NOTE(review): this evaluates on the same `dataloader` used for training, so the
# "val" numbers are training-set loss/perplexity, not a held-out validation score.
val_loss, val_ppl = evaluate(model, dataloader)
print(f"val loss/token={val_loss:.4f}, perplexity={val_ppl:.2f}")
