In [None]:
import sys, torch
from pathlib import Path
# Project root. Defaults to the original local checkout; set the
# POLYMERS_GENERATOR_ROOT environment variable to point somewhere else.
import os

repo = Path(os.environ.get("POLYMERS_GENERATOR_ROOT",
                           "/Users/tangren/Documents/PolymersGenerator"))

# The cells below use `from src.<module> import ...`, which only resolves when
# the directory that CONTAINS `src` (the repo root) is importable. The original
# `repo / "src"` entry is kept so any bare-name imports keep working.
for _path in (str(repo / "src"), str(repo)):
    if _path not in sys.path:  # keep the cell idempotent when it is re-run
        sys.path.append(_path)

In [None]:
# 导入模块与设备
from src.tokenizer import PolyBertTokenizer
from src.dataset import make_loader
from src.model import VAESmiles
from src.train import train_one_epoch, val_loss, set_seed
from transformers import AutoModel
import torch.optim as optim

# Seed through the project's helper before anything stochastic runs.
set_seed(42)

# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


In [None]:
# Load the data and the tokenizer
csv_path = repo / "PSMILES_Tg_only.csv"
tokenizer = PolyBertTokenizer("kuelumbus/polyBERT")

# Settings shared by both loaders; only `shuffle` differs between them.
loader_kwargs = dict(batch_size=128, col="PSMILES", max_len=256)

# NOTE(review): both loaders read the same CSV, so the "validation" loss is
# measured on the training data. Confirm whether a held-out split is intended.
train_loader = make_loader(csv_path, tokenizer, shuffle=True, **loader_kwargs)
val_loader = make_loader(csv_path, tokenizer, shuffle=False, **loader_kwargs)


In [None]:
# Build the VAE model with a polyBERT encoder
polybert = AutoModel.from_pretrained("kuelumbus/polyBERT")

# Architecture hyperparameters; the encoder width is read from polyBERT's config.
arch_kwargs = dict(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    encoder_hid_dim=polybert.config.hidden_size,
    decoder_hid_dim=512,
    z_dim=128,
    n_layers=1,
    drop=0.1,
)
# Special-token ids are taken from the tokenizer.
special_id_kwargs = dict(
    pad_id=tokenizer.pad_id,
    bos_id=tokenizer.bos_id,
    eos_id=tokenizer.eos_id,
)
# How the pretrained encoder is plugged into the VAE.
encoder_kwargs = dict(
    use_polybert=True,
    polybert=polybert,
    freeze_polybert=True,
    polybert_pooling="cls",
)

model = VAESmiles(**arch_kwargs, **special_id_kwargs, **encoder_kwargs)
model = model.to(device)

optimizer = optim.AdamW(model.parameters(), weight_decay=0.01, lr=3e-4)



In [None]:
# Training loop
epochs, best = 10, float("inf")
kl_anneal_epochs = 10  # the training KL weight ramps linearly to 1.0 over this many epochs

# Weight used when scoring the validation set. It is held constant so the
# scores of different epochs are comparable: evaluating with the annealed
# weight makes early (low-KL) epochs look artificially better and biases
# checkpoint selection toward them.
# Assumes the third positional argument of `val_loss` is the KL weight,
# mirroring `train_one_epoch` — confirm against src.train.
eval_kl_w = 1.0

# Create the checkpoint directory once instead of on every improvement.
ckpt_dir = repo / "checkpoints"
ckpt_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = ckpt_dir / "notebook.pt"

for epoch in range(epochs):
    # Linear KL warm-up, capped at 1.0 (training only).
    kl_w = min(1.0, (epoch + 1) / kl_anneal_epochs)
    train_loss = train_one_epoch(model, train_loader, optimizer,
                                 kl_w, tokenizer.pad_id, device)
    val_loss_value = val_loss(model, val_loader, eval_kl_w,
                              tokenizer.pad_id, device)
    # `kl_w` printed here is the training weight; `val` is scored at eval_kl_w.
    print(f"[{epoch+1}/{epochs}] train={train_loss:.4f} "
          f"val={val_loss_value:.4f} kl_w={kl_w:.2f}")

    # Save only when validation improves by more than a 1e-3 margin.
    if val_loss_value + 1e-3 < best:
        best = val_loss_value
        torch.save(
            {
                "model": model.state_dict(),
                "tokenizer_name": "kuelumbus/polyBERT",
                "tokenizer": tokenizer.get_vocab(),
                "pad_token_id": tokenizer.pad_id,
                "bos_token_id": tokenizer.bos_id,
                "eos_token_id": tokenizer.eos_id,
                "use_polybert": True,
            },
            ckpt_path,
        )


In [None]:
# Generation and reconstruction
@torch.no_grad()
def sample_smiles(model, tokenizer, num=16, max_len=256):
    """Generate strings by decoding latent vectors drawn from a standard normal.

    Args:
        model: Module exposing `mu.out_features` (used as the latent width)
            and `sample(z, max_len=...)`, which returns a batch of token ids.
        tokenizer: Object with `decode(list_of_ids)`.
        num: Number of latent vectors to draw.
        max_len: Passed through to `model.sample`.

    Returns:
        A list with one decoded string per sampled latent vector.
    """
    # Use the model's own device rather than a notebook-global `device`, so the
    # helper still works when the model is moved or the function is reused.
    model_device = next(model.parameters()).device
    z = torch.randn(num, model.mu.out_features, device=model_device)
    token_ids = model.sample(z, max_len=max_len)
    return [tokenizer.decode(row.tolist()) for row in token_ids.cpu()]

@torch.no_grad()
def reconstruct(model, tokenizer, smiles):
    """Encode one string into the latent space and decode it back.

    Args:
        model: Module exposing `encode(ids, mask)` -> (mu, logvar),
            `reparameterize(mu, logvar)` and `sample(z, max_len=...)`.
        tokenizer: Object with `encode(str)`, `decode(list_of_ids)` and `pad_id`.
        smiles: The input string to reconstruct.

    Returns:
        The decoded string produced from the input's latent code.
    """
    # Use the model's own device rather than a notebook-global `device`, so the
    # helper still works when the model is moved or the function is reused.
    model_device = next(model.parameters()).device
    ids = tokenizer.encode(smiles)
    enc = torch.tensor(ids, device=model_device).unsqueeze(0)  # add batch dim
    mask = (enc != tokenizer.pad_id).long()  # 1 for real tokens, 0 for padding
    mu, logvar = model.encode(enc, mask)
    z = model.reparameterize(mu, logvar)
    # Cap the decoded length at the length of the encoded input.
    out = model.sample(z, max_len=enc.size(1))
    return tokenizer.decode(out.squeeze(0).tolist())

# Switch to evaluation mode before sampling.
model.eval()

# Draw ten samples, then round-trip one example PSMILES string.
example_psmiles = "[*]#C[SiH2]C#Cc1cccc(C#[*])c1"
generated = sample_smiles(model, tokenizer, num=10)
recon = reconstruct(model, tokenizer, example_psmiles)


In [None]:
# Inference with the saved model
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(repo / "checkpoints/notebook.pt", map_location=device)
tokenizer = PolyBertTokenizer(ckpt["tokenizer_name"])
polybert = AutoModel.from_pretrained(ckpt["tokenizer_name"])

# Rebuild the network from the freshly loaded encoder instead of relying on a
# `model` object left over from the training cells, so this cell also works
# after a kernel restart. The architecture values below must match the ones
# used at training time, otherwise load_state_dict raises on a shape mismatch.
model = VAESmiles(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    encoder_hid_dim=polybert.config.hidden_size,
    decoder_hid_dim=512,
    z_dim=128,
    n_layers=1,
    pad_id=ckpt["pad_token_id"],
    bos_id=ckpt["bos_token_id"],
    eos_id=ckpt["eos_token_id"],
    drop=0.1,
    use_polybert=ckpt["use_polybert"],
    polybert=polybert,
    freeze_polybert=True,
    polybert_pooling="cls",
)
model.load_state_dict(ckpt["model"])
model.to(device).eval()
# sample_smiles / reconstruct can be reused from here on
