# 测试ProstT5 load

In [None]:
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch
import re
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('/t9k/mnt/AMP/weights/ProstT5-Distilled-12l/final_model', do_lower_case=False)

# Load the model
model = AutoModelForSeq2SeqLM.from_pretrained("/t9k/mnt/AMP/weights/ProstT5-Distilled-12l/final_model").to(device)

# only GPUs support half-precision currently; if you want to run on CPU use full-precision (not recommended, much slower)
model.float() if device.type=='cpu' else model.half()
print(model)

# prepare your protein sequences/structures as a list.
# Amino acid sequences are expected to be upper-case ("PRTEINO" below)
# while 3Di-sequences need to be lower-case.
sequence_examples = ["PRTEINO", "SEQWENCE"]
min_len = min([ len(s) for s in sequence_examples])
max_len = max([ len(s) for s in sequence_examples])

# replace all rare/ambiguous amino acids by X (3Di sequences does not have those) and introduce white-space between all sequences (AAs and 3Di)
sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]

# add pre-fixes accordingly. For the translation from AAs to 3Di, you need to prepend "<AA2fold>"
sequence_examples = [ "<AA2fold>" + " " + s for s in sequence_examples]

# tokenize sequences and pad up to the longest sequence in the batch
ids = tokenizer.batch_encode_plus(sequence_examples,
                                  add_special_tokens=True,
                                  padding="longest",
                                  return_tensors='pt').to(device)

# Generation configuration for "folding" (AA-->3Di)
gen_kwargs_aa2fold = {
                  "do_sample": True,
                  "num_beams": 3, 
                  "top_p" : 0.95, 
                  "temperature" : 1.2, 
                  "top_k" : 6,
                  "repetition_penalty" : 1.2,
}

# translate from AA to 3Di (AA-->3Di)
with torch.no_grad():
  translations = model.generate( 
              ids.input_ids, 
              attention_mask=ids.attention_mask, 
              max_length=max_len, # max length of generated text
              min_length=min_len, # minimum length of the generated text
              early_stopping=True, # stop early if end-of-text token is generated
              num_return_sequences=1, # return only a single sequence
              **gen_kwargs_aa2fold
  )
# Decode and remove white-spaces between tokens
decoded_translations = tokenizer.batch_decode( translations, skip_special_tokens=True )
structure_sequences = [ "".join(ts.split(" ")) for ts in decoded_translations ] # predicted 3Di strings



print("Input AA sequences: ", sequence_examples)
print("Predicted 3Di sequences: ", structure_sequences)
# Now we can use the same model and invert the translation logic
# to generate an amino acid sequence from the predicted 3Di-sequence (3Di-->AA)

# add pre-fixes accordingly. For the translation from 3Di to AA (3Di-->AA), you need to prepend "<fold2AA>"
sequence_examples_backtranslation = [ "<fold2AA>" + " " + s for s in decoded_translations]

# tokenize sequences and pad up to the longest sequence in the batch
ids_backtranslation = tokenizer.batch_encode_plus(sequence_examples_backtranslation,
                                  add_special_tokens=True,
                                  padding="longest",
                                  return_tensors='pt').to(device)

# Example generation configuration for "inverse folding" (3Di-->AA)
gen_kwargs_fold2AA = {
            "do_sample": True,
            "top_p" : 0.85,
            "temperature" : 1.0,
            "top_k" : 3,
            "repetition_penalty" : 1.2,
}

# translate from 3Di to AA (3Di-->AA)
with torch.no_grad():
  backtranslations = model.generate( 
              ids_backtranslation.input_ids, 
              attention_mask=ids_backtranslation.attention_mask, 
              max_length=max_len, # max length of generated text
              min_length=min_len, # minimum length of the generated text
              #early_stopping=True, # stop early if end-of-text token is generated; only needed for beam-search
              num_return_sequences=1, # return only a single sequence
              **gen_kwargs_fold2AA
)
# Decode and remove white-spaces between tokens
decoded_backtranslations = tokenizer.batch_decode( backtranslations, skip_special_tokens=True )
aminoAcid_sequences = [ "".join(ts.split(" ")) for ts in decoded_backtranslations ] # predicted amino acid strings

print("input 3Di sequences: ", sequence_examples_backtranslation)
print("Predicted back-translated AA sequences: ", aminoAcid_sequences)

# 测试T2struc load

In [None]:
import torch 
import os
from omegaconf import OmegaConf
from transformers import AutoTokenizer, EsmTokenizer
from collections import OrderedDict
from transformers import GenerationConfig
import torch.nn as nn
import logging
from rich.logging import RichHandler
import math
from models import StructureTokenPredictionModel
import logging 

logger = logging.getLogger("rich")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_T2Struc_tokenizers(cfg):
    text_tokenizer =  AutoTokenizer.from_pretrained(cfg.model["lm"])
    structure_tokenizer = EsmTokenizer.from_pretrained(cfg.model["tokenizer"])
    return text_tokenizer, structure_tokenizer


def load_T2Struc(model_dir_or_weight: str,
                dtype: torch.dtype = torch.bfloat16,
                device: torch.device = DEVICE):
    """
    Load T2Struc model from a directory (containing config.yaml + pytorch_model.bin)
    or directly from a weight file path.
    """
    # 1) Resolve paths
    if os.path.isdir(model_dir_or_weight):
        model_dir = model_dir_or_weight
        cfg_path = os.path.join(model_dir, "config.yaml")
        weight_path = os.path.join(model_dir, "pytorch_model.bin")
    else:
        weight_path = model_dir_or_weight
        model_dir = os.path.dirname(weight_path)
        cfg_path = os.path.join(model_dir, "config.yaml")

    if not os.path.isfile(cfg_path):
        raise FileNotFoundError(f"config.yaml not found: {cfg_path}")
    if not os.path.isfile(weight_path):
        raise FileNotFoundError(f"checkpoint not found: {weight_path}")

    # 2) Load config
    cfg = OmegaConf.load(cfg_path)

    # 3) Build model (on CPU first, in desired dtype)
    model = StructureTokenPredictionModel(cfg.model).to(dtype=dtype)

    # 4) Load weights (to CPU)
    logger.info(f"Loading T2Struc weights from: {weight_path}")
    state_dict = torch.load(weight_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    # 5) Move to device + eval
    model = model.to(device=device).eval()

    text_tokenizer, structure_tokenizer = load_T2Struc_tokenizers(cfg)


    return model, text_tokenizer, structure_tokenizer

T2Struc, text_tokenizer, structure_tokenizer = load_T2Struc("/t9k/mnt/AMP/weights/T2struc-fined/2025-11-27_21-54-10/final_model")
print(T2Struc)


# 测试合并模型 end to end 


In [None]:
import torch
import os
import logging
from omegaconf import OmegaConf
# 假设 EndToEndModel 保存在 models 文件夹下
from models.EndToEndModel import EndToEndModel

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def load_and_convert_weights():
    # ===================== Paths (Based on your notebook) ===================== #
    t2struc_path = "/t9k/mnt/AMP/weights/T2struc-fined/2025-11-27_21-54-10/final_model"
    prostt5_path = "/t9k/mnt/AMP/weights/ProstT5-Distilled-12l/final_model"
    output_path = "/t9k/mnt/AMP/weights/EndToEndModel_merged"
    
    device = torch.device('cpu') # Load on CPU to save memory during merge

    logger.info("Loading T2Struc Config...")
    t2struc_cfg_path = os.path.join(t2struc_path, "config.yaml")
    t2struc_cfg = OmegaConf.load(t2struc_cfg_path)
    
    # ===================== Initialize Empty EndToEndModel ===================== #
    logger.info("Initializing EndToEndModel structure...")
    # 我们传入 t2struc 的 model 配置部分，和 ProstT5 的路径
    model = EndToEndModel(t2struc_cfg.model, prostt5_path)
    
    # ===================== Load T2Struc Weights ===================== #
    t2struc_weight_path = os.path.join(t2struc_path, "pytorch_model.bin")
    logger.info(f"Loading T2Struc weights from {t2struc_weight_path}...")
    t2struc_state_dict = torch.load(t2struc_weight_path, map_location="cpu")
    
    # T2Struc 的权重在 EndToEndModel 中对应的 prefix 是 "t2struc."
    # 原始权重里的 key 类似 "lm.shared.weight", "plm.transformer..."
    # 我们需要给它们加上前缀
    new_t2struc_state_dict = {}
    for k, v in t2struc_state_dict.items():
        new_key = "t2struc." + k
        new_t2struc_state_dict[new_key] = v
        
    # 加载到模型中 (strict=False, 因为还有 prostt5 和 projector 的权重没加载)
    keys = model.load_state_dict(new_t2struc_state_dict, strict=False)
    logger.info(f"T2Struc weights loaded. Missing keys (expected ProstT5/Proj): {len(keys.missing_keys)}")

    # ===================== Load ProstT5 Weights ===================== #
    # ProstT5 通常是一个标准的 HuggingFace 模型，可能是 pytorch_model.bin 或 model.safetensors
    # 这里的路径是文件夹，我们让 T5ForConditionalGeneration 自己加载，然后我们提取它的 state_dict
    logger.info(f"Loading ProstT5 weights from {prostt5_path}...")
    
    # 这里的技巧是：EndToEndModel 初始化时已经通过 from_pretrained 加载了 ProstT5 的初始权重（如果路径正确）
    # 在 EndToEndModel.__init__ 中：self.prostt5 = T5ForConditionalGeneration(config)
    # 如果那里用的是 Config 初始化而不是 from_pretrained(path)，权重是随机的。
    # 为了保险起见，我们在这里显式加载一次权重并赋值。
    
    from transformers import T5ForConditionalGeneration
    temp_prostt5 = T5ForConditionalGeneration.from_pretrained(prostt5_path)
    prostt5_state_dict = temp_prostt5.state_dict()
    
    # ProstT5 的权重在 EndToEndModel 中对应的 prefix 是 "prostt5."
    new_prostt5_state_dict = {}
    for k, v in prostt5_state_dict.items():
        new_key = "prostt5." + k
        new_prostt5_state_dict[new_key] = v
    
    # 加载 ProstT5 权重
    keys = model.load_state_dict(new_prostt5_state_dict, strict=False)
    logger.info(f"ProstT5 weights loaded. Missing keys (expected Proj only): {len(keys.missing_keys)}")
    
    # 检查 Projector 权重
    # Projector 是随机初始化的，不需要加载，但应该在 missing_keys 中
    projector_keys = [k for k in keys.missing_keys if "projector" in k]
    logger.info(f"Projector keys (randomly initialized): {projector_keys}")

    # ===================== Save Merged Model ===================== #
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        
    save_path = os.path.join(output_path, "pytorch_model.bin")
    logger.info(f"Saving combined model to {save_path}...")
    torch.save(model.state_dict(), save_path)
    
    # 同时保存 Config 以便后续加载
    # 保存 T2Struc 的 config
    OmegaConf.save(t2struc_cfg, os.path.join(output_path, "t2struc_config.yaml"))
    # 保存 ProstT5 的 config
    model.prostt5.config.save_pretrained(output_path)
    
    logger.info("Done!")

if __name__ == "__main__":
    load_and_convert_weights()

## 测试end to end model流程能否跑通


In [1]:
import torch
import os
from omegaconf import OmegaConf
from transformers import AutoTokenizer, EsmTokenizer, T5Tokenizer
from models.EndToEndModel import EndToEndModel  # 确保能导入你定义的类


def load_end_to_end_model(merged_model_dir, device='cuda'):
    """
    加载合并后的 EndToEndModel 和所需的 Tokenizers

    Args:
        merged_model_dir: 上一步保存权重的目录 (例如 /t9k/mnt/AMP/weights/EndToEndModel_merged)
        device: 'cuda' or 'cpu' or torch.device
    """
    print(f"Loading model from {merged_model_dir}...")

    # ================= 1. 加载配置 =================
    t2struc_config_path = os.path.join(merged_model_dir, "t2struc_config.yaml")
    if not os.path.exists(t2struc_config_path):
        raise FileNotFoundError(f"Config not found at {t2struc_config_path}")

    t2struc_cfg = OmegaConf.load(t2struc_config_path)

    # ================= 2. 实例化模型 =================
    model = EndToEndModel(t2struc_cfg.model, prostt5_config_path=merged_model_dir)

    # ================= 3. 加载合并后的权重 =================
    weight_path = os.path.join(merged_model_dir, "pytorch_model.bin")
    print(f"Loading weights from {weight_path}...")

    state_dict = torch.load(weight_path, map_location='cpu')

    # strict=True 确保所有键都匹配（T2Struc, ProstT5, Projectors）
    model.load_state_dict(state_dict, strict=True)

    model.to(device)
    model.eval()

    # ================= 4. 加载 Tokenizers =================
    print("Loading tokenizers...")

    # (1) Text Tokenizer
    text_tokenizer_path = t2struc_cfg.model.lm
    text_tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_path)

    # (2) Structure Tokenizer
    structure_tokenizer_path = t2struc_cfg.model.tokenizer
    structure_tokenizer = EsmTokenizer.from_pretrained(structure_tokenizer_path)

    # (3) ProstT5 Tokenizer
    try:
        prostt5_tokenizer = T5Tokenizer.from_pretrained(merged_model_dir, do_lower_case=False)
    except Exception:
        prostt5_orig_path = '/t9k/mnt/AMP/weights/ProstT5-Distilled-12l/final_model'
        print(f"Tokenizer not found in merged dir, loading from {prostt5_orig_path}")
        prostt5_tokenizer = T5Tokenizer.from_pretrained(prostt5_orig_path, do_lower_case=False)

    print("Model and tokenizers loaded successfully!")

    return model, text_tokenizer, structure_tokenizer, prostt5_tokenizer


# -----------------------------
# 1. 基础设置
# -----------------------------
merged_path = "/t9k/mnt/AMP/weights/EndToEndModel_merged"
device = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# 2. 加载模型
# -----------------------------
model, text_tokenizer, structure_tokenizer, prostt5_tokenizer = load_end_to_end_model(merged_path, device)
model.eval()

# 输出 structure tokenizer 的 special token ids
start_id = structure_tokenizer.cls_token_id
stop_id = structure_tokenizer.eos_token_id
pad_id = structure_tokenizer.pad_token_id
print(f"Structure Tokenizer - start_id: {start_id}, stop_id: {stop_id}, pad_id: {pad_id}")

# 输出 model 的参数总量
total_params = sum(p.numel() for p in model.parameters())
print(f"Total model parameters: {total_params}")

# -----------------------------
# 3. 构造测试输入
# -----------------------------
input_text = "generate an  peptide with alpha-helical structure."

# 编码文本
text_inputs = text_tokenizer(
    input_text,
    return_tensors="pt",
    padding=True,
    truncation=True
)
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

# 构造 batch（Generate 模式只需要 text）
batch = {
    "text_ids": text_inputs["input_ids"],
    "text_masks": text_inputs["attention_mask"],
    "max_structure_len": 64,  # 中间结构生成最大长度
    "max_seq_len": 50         # 最终序列生成最大长度
}

# -----------------------------
# 4. 推理（Generate）
# -----------------------------
with torch.no_grad():
    generated_ids = model.generate(batch)

# -----------------------------
# 5. 解码结果
# -----------------------------
generated_seq = prostt5_tokenizer.batch_decode(
    generated_ids,
    skip_special_tokens=True
)

# -----------------------------
# 6. 打印结果
# -----------------------------
print(f"Input: {input_text}")
print(f"Generated Sequence: {generated_seq[0]}")


  from .autonotebook import tqdm as notebook_tqdm


Loading model from /t9k/mnt/AMP/weights/EndToEndModel_merged...




Loading weights from /t9k/mnt/AMP/weights/EndToEndModel_merged/pytorch_model.bin...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading tokenizers...
Tokenizer not found in merged dir, loading from /t9k/mnt/AMP/weights/ProstT5-Distilled-12l/final_model
Model and tokenizers loaded successfully!
Structure Tokenizer - start_id: 1, stop_id: 2, pad_id: 0
Total model parameters: 2668102912
Input: generate an  peptide with alpha-helical structure.
Generated Sequence: M T G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G G
