In [1]:
import sys, torch
from pathlib import Path
repo = Path("/Users/tangren/Documents/PolymersGenerator")
sys.path.append(str(repo / "src"))  # 允许导入 src 包

In [2]:
# 导入模块与设备
from src.tokenizer import PolyBertTokenizer
from src.dataset import make_loader
from src.modelv2 import VAESmiles
from src.train import train_one_epoch, val_loss, set_seed
from transformers import AutoModel
import torch.optim as optim
import tqdm as notebook_tqdm

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "mps"
                      if torch.backends.mps.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# 加载数据与tokenizer
csv_path = "data/PSMILES_Tg_only.csv"
tokenizer = PolyBertTokenizer("./polybert")
train_loader = make_loader(
    csv_path,
    tokenizer,
    batch_size=128,
    shuffle=True,
    col="PSMILES",
    max_len=256,
)
val_loader = make_loader(
    csv_path,
    tokenizer,
    batch_size=128,
    shuffle=False,
    col="PSMILES",
    max_len=256,
)


In [4]:
# 构建带polyBERT编码器的VAE模型
polybert = AutoModel.from_pretrained("./polybert").to(device)
model = VAESmiles(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    encoder_hid_dim=polybert.config.hidden_size,
    decoder_hid_dim=512,
    z_dim=128,
    n_layers=1,
    pad_id=tokenizer.pad_id,
    bos_id=tokenizer.bos_id,
    eos_id=tokenizer.eos_id,
    drop=0.1,
    use_polybert=True,
    polybert=polybert,
    freeze_polybert=True,
    polybert_pooling="cls",
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)



In [None]:
# 训练循环
epochs, best = 10, float("inf")
for epoch in range(epochs):
    kl_w = min(1.0, (epoch + 1) / 10.0)
    train_loss = train_one_epoch(model, train_loader, optimizer,
                                 kl_w, tokenizer.pad_id, device)
    val_loss_value = val_loss(model, val_loader, kl_w,
                              tokenizer.pad_id, device)
    print(f"[{epoch+1}/{epochs}] train={train_loss:.4f} "
          f"val={val_loss_value:.4f} kl_w={kl_w:.2f}")

    if val_loss_value + 1e-3 < best:
        best = val_loss_value
        (repo / "checkpoints").mkdir(exist_ok=True)
        torch.save(
            {
                "model": model.state_dict(),
                "tokenizer_name": "./polybert",
                "tokenizer": tokenizer.get_vocab(),
                "pad_token_id": tokenizer.pad_id,
                "bos_token_id": tokenizer.bos_id,
                "eos_token_id": tokenizer.eos_id,
                "use_polybert": True,
            },
            repo / "checkpoints/notebook.pt",
        )



In [6]:
ckpt_path = repo / "checkpoints/modelv2_best.pt"
ckpt = torch.load(ckpt_path, map_location=device)

tokenizer = PolyBertTokenizer("./polybert")
polybert = AutoModel.from_pretrained("./polybert").to(device)

model = VAESmiles(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    encoder_hid_dim=polybert.config.hidden_size,
    decoder_hid_dim=512,
    z_dim=128,
    n_layers=1,
    pad_id=tokenizer.pad_id,
    bos_id=tokenizer.bos_id,
    eos_id=tokenizer.eos_id,
    drop=0.1,
    use_polybert=True,
    polybert=polybert,
    freeze_polybert=True,
    polybert_pooling="cls",
).to(device)
model.load_state_dict(ckpt["model"])
model.eval()

  ckpt = torch.load(ckpt_path, map_location=device)


VAESmiles(
  (drop): Dropout(p=0.1, inplace=False)
  (emb): Embedding(270, 256, padding_idx=267)
  (pos_emb): Embedding(256, 256)
  (latent_proj): Linear(in_features=128, out_features=256, bias=True)
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((256,), eps=1e-05, elementwi

In [7]:
kl_w = 1.0  # 评估时通常直接用 1
val_loss_value = val_loss(model, val_loader, kl_w, tokenizer.pad_id, device)
print(f"modelv2 验证集损失: {val_loss_value:.4f}")

modelv2 验证集损失: 0.7816


In [9]:
# 重构示例
import random, pandas as pd

df = pd.read_csv("data/PSMILES_Tg_only.csv")
subset = random.sample(df["PSMILES"].tolist(), 4)

def reconstruct(smiles):
    ids = tokenizer.encode(smiles)
    inp = torch.tensor(ids, device=device).unsqueeze(0)
    mask = (inp != tokenizer.pad_id).long()
    mu, logvar = model.encode(inp, mask)
    z = model.reparameterize(mu, logvar)
    out = model.sample(z, max_len=inp.size(1))
    return tokenizer.decode(out.squeeze(0).tolist())

for s in subset:
    rec = reconstruct(s)
    print(f"orig: {s}")
    print(f"reco: {rec}\n")


orig: [*]Oc1ccc2ccc(Oc3ccc4c(c3)C(=O)N(c3cccc(N5C(=O)c6ccc([*])cc6C5=O)c3)C4=O)cc2c1
reco: [*]Oc1ccc(Oc2ccc(Oc3ccc(N4C(=O)c4cccc(Oc5cccc(Oc6cccc([*])c6C5=O)ccc4)c3)c1)c1

orig: [*]C=CC([*])(C)c1ccccc1
reco: [*]CC([*])c1ccccc(C(=O)O

orig: [*]C(=O)NCCCCCCCNC(=O)C(OC)C([*])OC
reco: [*]CC([*])C(=O)OCCCCCCCCCCCCCCCCCCCC

orig: [*]c1ccc(OC(=O)Oc2ccc(C([*])(C)C)cc2CC)c(CC)c1
reco: [*]C(=O)c1ccc(C(=O)c2ccc(OC(=O)c3ccc(Oc4ccc(O



In [10]:
# 随机生成指标
from rdkit import Chem

@torch.no_grad() # 推理阶段不记录梯度，省显存、提速
def sample_smiles(num=256, max_len=256):
    z = torch.randn(num, model.mu.out_features, device=device) # 从标准正态分布采样潜变量 z
    token_ids = model.sample(z, max_len=max_len) # 让模型在潜变量条件下生成 token 序列
    return [tokenizer.decode(row.tolist()) for row in token_ids.cpu()]

gen = sample_smiles(num=512) # 一次性生成 512 个 SMILES 字符串（可能包含无效或重复）
def to_rdkit(smiles):
    return Chem.MolFromSmiles(smiles.replace("[*]", "[Xe]")) # 把 [*] 替换为 [Xe] 再交给 RDKit 解析

valid = [s for s in gen if to_rdkit(s)]
validity = len(valid) / len(gen) # 对每个生成的 SMILES 调 to_rdkit，能解析就当作有效
uniqueness = len(set(gen)) / len(gen) # 计算生成集合中不重复 SMILES 的比例
train_set = set(df["PSMILES"].astype(str)) # 训练集中所有 SMILES 的集合
novelty = len([s for s in set(gen) if s not in train_set]) / max(len(set(gen)), 1) # 计算生成集合中不在训练集的比例

print(f"Validity: {validity:.3f}")
print(f"Uniqueness: {uniqueness:.3f}")
print(f"Novelty: {novelty:.3f}")

Validity: 0.289
Uniqueness: 0.877
Novelty: 0.971


[17:43:03] SMILES Parse Error: extra open parentheses for input: '[Xe]c1ccc(Oc2ccc(C(=O)Nc3ccc(Oc4ccc(Oc5ccc(Oc6ccc(Oc7ccc(C([Xe])(C)cc7)ccccc7)ccc6)ccc5)ccc3)ccc2)c1'
[17:43:03] SMILES Parse Error: unclosed ring for input: '[Xe]c1ccc(Oc2ccc(-c3ccc(-c4ccc(-c5ccc(-c6ccc(-c6cccc(-c7cccccc([Xe])cc7)ccc7)ccc6)cc5)cc4)ccc3)ccc2)c1'
[17:43:03] Explicit valence for atom # 42 O, 3, is greater than permitted
[17:43:03] SMILES Parse Error: unclosed ring for input: '[Xe]c1ccc(-c2ccc(-c3ccc(-c4ccc(-c5ccc([Xe])ccc6)cc4)c3)cc2)c1'
[17:43:03] Can't kekulize mol.  Unkekulized atoms: 21 22 23 24 34
[17:43:03] SMILES Parse Error: unclosed ring for input: '[Xe]CC([Xe])(C)C(=O)Oc1ccc(C(=O)Oc2ccc(-c3ccccc(-c4)ccc3)cc2)c1'
[17:43:03] SMILES Parse Error: extra open parentheses for input: '[Xe]c1ccc(Oc2ccc(NC(=O)c3ccc(C(=O)N(c4ccc(Oc5ccc(C([Xe])(C)cccc6)ccc5)cc4)cc3)cc2)c1'
[17:43:03] SMILES Parse Error: unclosed ring for input: '[Xe]CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC(=O)Nc1ccc(Oc2ccc(NC(=O)c3cccc([Xe])ccc4)cc2

In [11]:
print("length of valid:",len(valid))
print("\nSome valid generated SMILES:", valid[1])
unique_valid = set(valid)
print("Number of unique valid SMILES:", len(unique_valid))
print("\nSome valid generated SMILES:", unique_valid)

length of valid: 148

Some valid generated SMILES: [*]CC([*])(C)C(=O)OCCCCCCCCCCCCCCOCCOCOCOCOCOCOCOCOCOCOC
Number of unique valid SMILES: 92

Some valid generated SMILES: {'[*]CC([*])C(=O)OCCCCCCCCCCCCCCCCCCCCCCCCCCOCOCCCOCCCOCCOCOCCOCOCCCCC', '[*]CC([*])(C)C(=O)OCCCCCCCCCCCCCCCCCC(=O)OCCCCCCCCCCCCCC', '[*]CC([*])(C)C(=O)OCCCCCCCCCCCCCCCCCCCCOc1cc(Oc2ccc(OC)cc2)c1', '[*]CC([*])c1ccc(OC(=O)c2ccccc2)cc1', '[*]CC([*])(C)C(=O)OCCCCCCCCCCCCCCCCCCCCCCCCCCC(=O)OCCCCCCCCCCCCCCC', '[*]C(=O)c1ccc(C(=O)c2ccc(C(=O)c3ccc(C([*])=O)cc3)cc2)cc1', '[*]Nc1ccc(C(=O)c2ccc(C([*])=O)cc2)cc1', '[*]CC([*])(C)C(=O)OCCCCCCCCCCCCCCCCCCCCCCOc1ccc(OCCCCCC)cc1', '[*]CCCCCCCCC(=O)OCCCCCCCCCCCC([*])=O', '[*]CC([*])(C)C(=O)OCCCCCCCCCCCCCCCCCCCCCCCCCCC(=O)O', '[*]CCCCCCCCCCCCCCCCCCCCCC(=O)Nc1ccc(C([*])=O)cc1', '[*]CC([*])(C)C(=O)OCCCCCCOc1ccc(OC)cc1', '[*]CC([*])c1ccc(OCCCCCC)cc1', '[*]CC([*])(C)C(=O)OCCCCCCCCCCCCCCCCCCCCCCC(=O)C', '[*]CC([*])(C)C(=O)OC', '[*]CC([*])C(=O)OCCCCCCCCCCCCCCC', '[*]CC([*])(C)C(=O)OCCCCCCCC

In [12]:
# 插值示例
def encode_to_z(smiles):
    ids = tokenizer.encode(smiles)
    inp = torch.tensor(ids, device=device).unsqueeze(0)
    mask = (inp != tokenizer.pad_id).long()
    mu, logvar = model.encode(inp, mask)
    return mu.squeeze(0), logvar.squeeze(0)

s1, s2 = subset[:2]
z1, _ = encode_to_z(s1)
z2, _ = encode_to_z(s2)

alphas = torch.linspace(0, 1, steps=6, device=device)
interpolations = []
for a in alphas:
    z = (1 - a) * z1 + a * z2
    ids = model.sample(z.unsqueeze(0), max_len=128)
    interpolations.append(tokenizer.decode(ids.squeeze(0).tolist()))

for a, seq in zip(alphas.tolist(), interpolations):
    print(f"α={a:.2f}: {seq}")

α=0.00: [*]Oc1ccc(Oc2ccc(NC(=O)c3ccc(Oc4ccc(Oc5ccc(Oc6cccc(N6C(=O)c7ccccc([*])C7=O)ccc6)c5)c3)cc1
α=0.20: [*]Oc1ccc(Oc2ccc(NC(=O)c3ccc(C(=O)N(c4ccc(Oc5ccc([*])cc6)c4)c3)cc2)c1
α=0.40: [*]CC([*])(C)C(=O)Oc1ccc(C(=O)Oc2cccc(Oc3ccc3)cc2)c1
α=0.60: [*]CC([*])(C)C(=O)Oc1ccc(C(=O)Oc2cccc(OC)c2)cccc1
α=0.80: [*]CC([*])(C)C(=O)OCCCCCCCC
α=1.00: [*]CC([*])c1ccc(C(=O)OCCCCCCCCC)cc1
