# 测试ProstT5 load

In [None]:
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch
import re
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Local ProstT5 checkpoint; the tokenizer and the model are loaded from the same directory.
PROSTT5_DIR = '/t9k/mnt/AMP/weights/ProstT5-Distilled-12l/final_model'

# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained(PROSTT5_DIR, do_lower_case=False)

# Load the model
model = AutoModelForSeq2SeqLM.from_pretrained(PROSTT5_DIR).to(device)

# Half precision is only used on GPU; on CPU the model stays in full precision
# (supported, but much slower and not recommended).
# Both calls convert the module in place, so their return value is not needed.
if device.type == 'cpu':
    model.float()
else:
    model.half()

# Protein sequences/structures to translate, as a list.
# Amino-acid sequences are expected in upper case ("PRTEINO" below),
# while 3Di sequences need to be lower case.
sequence_examples = ["PRTEINO", "SEQWENCE"]
# Raw residue counts (before any spacing/prefix is added); used below to bound
# the generated lengths, since a 3Di string has one state per residue.
min_len = min(len(s) for s in sequence_examples)
max_len = max(len(s) for s in sequence_examples)

# Replace rare/ambiguous amino acids (U, Z, O, B) by X (3Di sequences do not
# contain them) and put a white-space between every residue (for AAs and 3Di alike).
sequence_examples = [" ".join(re.sub(r"[UZOB]", "X", sequence)) for sequence in sequence_examples]

# Add the task prefix: translation from AAs to 3Di needs "<AA2fold>" prepended.
sequence_examples = ["<AA2fold> " + s for s in sequence_examples]

# Tokenize sequences and pad up to the longest sequence in the batch.
ids = tokenizer.batch_encode_plus(sequence_examples,
                                  add_special_tokens=True,
                                  padding="longest",
                                  return_tensors='pt').to(device)

# Generation configuration for "folding" (AA-->3Di)
gen_kwargs_aa2fold = {
    "do_sample": True,
    "num_beams": 3,
    "top_p": 0.95,
    "temperature": 1.2,
    "top_k": 6,
    "repetition_penalty": 1.2,
}

# Translate from AA to 3Di (AA-->3Di).
# For an encoder-decoder model, `max_length`/`min_length` include the decoder
# start token that generate() prepends, hence the +1: without it the longest
# sequence would receive one 3Di state fewer than it has residues.
with torch.no_grad():
    translations = model.generate(
        ids.input_ids,
        attention_mask=ids.attention_mask,
        max_length=max_len + 1,  # decoder start token + up to max_len residues
        min_length=min_len + 1,  # no end-of-sequence before min_len residues
        early_stopping=True,  # stop a beam once it has generated end-of-sequence
        num_return_sequences=1,  # return only a single sequence
        **gen_kwargs_aa2fold,
    )
# Decode, then remove the white-spaces between tokens.
decoded_translations = tokenizer.batch_decode(translations, skip_special_tokens=True)
structure_sequences = [ts.replace(" ", "") for ts in decoded_translations]  # predicted 3Di strings

print("Input AA sequences: ", sequence_examples)
print("Predicted 3Di sequences: ", structure_sequences)

# Now the same model is used with the inverse translation logic to generate
# an amino-acid sequence from the predicted 3Di sequence (3Di-->AA).

# Add the task prefix: translation from 3Di to AA needs "<fold2AA>" prepended.
# The still space-separated decoded strings are used, not the joined ones.
sequence_examples_backtranslation = ["<fold2AA> " + s for s in decoded_translations]

# Tokenize sequences and pad up to the longest sequence in the batch.
ids_backtranslation = tokenizer.batch_encode_plus(sequence_examples_backtranslation,
                                                  add_special_tokens=True,
                                                  padding="longest",
                                                  return_tensors='pt').to(device)

# Example generation configuration for "inverse folding" (3Di-->AA)
gen_kwargs_fold2AA = {
    "do_sample": True,
    "top_p": 0.85,
    "temperature": 1.0,
    "top_k": 3,
    "repetition_penalty": 1.2,
}

# Translate from 3Di to AA (3Di-->AA); same +1 for the decoder start token as above.
# early_stopping is omitted because it only applies to beam search, which is not used here.
with torch.no_grad():
    backtranslations = model.generate(
        ids_backtranslation.input_ids,
        attention_mask=ids_backtranslation.attention_mask,
        max_length=max_len + 1,  # decoder start token + up to max_len residues
        min_length=min_len + 1,  # no end-of-sequence before min_len residues
        num_return_sequences=1,  # return only a single sequence
        **gen_kwargs_fold2AA,
    )
# Decode, then remove the white-spaces between tokens.
decoded_backtranslations = tokenizer.batch_decode(backtranslations, skip_special_tokens=True)
aminoAcid_sequences = [ts.replace(" ", "") for ts in decoded_backtranslations]  # predicted amino acid strings

print("input 3Di sequences: ", sequence_examples_backtranslation)
print("Predicted back-translated AA sequences: ", aminoAcid_sequences)

# 测试T2struc load

In [5]:
import torch 
import os
from omegaconf import OmegaConf
from transformers import AutoTokenizer, EsmTokenizer
from collections import OrderedDict
from transformers import GenerationConfig
import torch.nn as nn
import logging
from rich.logging import RichHandler
import math
from models import StructureTokenPredictionModel
import logging 

# Module logger. RichHandler must actually be attached (and a level set):
# otherwise the "rich" logger inherits the root logger's default WARNING
# threshold and the logger.info(...) calls below are silently dropped.
logger = logging.getLogger("rich")
if not any(isinstance(h, RichHandler) for h in logger.handlers):
    # Guard so re-running this notebook cell does not stack duplicate handlers.
    logger.addHandler(RichHandler())
logger.setLevel(logging.INFO)

# Default device for loaded models: CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_T2Struc_tokenizers(cfg):
    """Build the two tokenizers described by a T2Struc config.

    Args:
        cfg: Loaded config whose ``model`` section provides the ``"lm"`` entry
            (loaded with ``AutoTokenizer``) and the ``"tokenizer"`` entry
            (loaded with ``EsmTokenizer``).

    Returns:
        ``(text_tokenizer, structure_tokenizer)``, in that order.
    """
    model_cfg = cfg.model
    return (
        AutoTokenizer.from_pretrained(model_cfg["lm"]),
        EsmTokenizer.from_pretrained(model_cfg["tokenizer"]),
    )


def load_T2Struc(model_dir_or_weight: str,
                 dtype: torch.dtype = torch.bfloat16,
                 device: torch.device = DEVICE):
    """Load a T2Struc model together with its two tokenizers.

    Args:
        model_dir_or_weight: Either a directory containing ``config.yaml`` and
            ``pytorch_model.bin``, or the path of a checkpoint file whose
            parent directory contains ``config.yaml``.
        dtype: dtype the model parameters are cast to before the weights are
            loaded (defaults to ``torch.bfloat16``).
        device: Device the model is moved to once the weights are loaded
            (defaults to the module-level ``DEVICE``).

    Returns:
        ``(model, text_tokenizer, structure_tokenizer)``; the model is in eval
        mode on ``device``.

    Raises:
        FileNotFoundError: If ``model_dir_or_weight``, ``config.yaml`` or the
            checkpoint file does not exist.
    """
    # 1) Resolve paths. Reject a non-existent path up front: otherwise a
    #    mistyped directory falls into the "weight file" branch and the error
    #    would point at a config.yaml in its parent directory instead.
    if not os.path.exists(model_dir_or_weight):
        raise FileNotFoundError(f"model path not found: {model_dir_or_weight}")
    if os.path.isdir(model_dir_or_weight):
        model_dir = model_dir_or_weight
        weight_path = os.path.join(model_dir, "pytorch_model.bin")
    else:
        weight_path = model_dir_or_weight
        model_dir = os.path.dirname(weight_path)
    cfg_path = os.path.join(model_dir, "config.yaml")

    if not os.path.isfile(cfg_path):
        raise FileNotFoundError(f"config.yaml not found: {cfg_path}")
    if not os.path.isfile(weight_path):
        raise FileNotFoundError(f"checkpoint not found: {weight_path}")

    # 2) Load config
    cfg = OmegaConf.load(cfg_path)

    # 3) Build the model and cast its parameters to `dtype`; it is moved to
    #    `device` only after the weights have been loaded.
    model = StructureTokenPredictionModel(cfg.model).to(dtype=dtype)

    # 4) Load weights onto CPU. weights_only=True restricts unpickling to
    #    tensors and plain containers (all a state dict needs), so loading a
    #    checkpoint cannot execute arbitrary pickled code.
    logger.info(f"Loading T2Struc weights from: {weight_path}")
    state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
    # load_state_dict copies values into the existing (already cast) parameters;
    # strict=True turns any missing or unexpected key into an error.
    model.load_state_dict(state_dict, strict=True)

    # 5) Move to device and switch to eval mode
    model = model.to(device=device).eval()

    text_tokenizer, structure_tokenizer = load_T2Struc_tokenizers(cfg)
    return model, text_tokenizer, structure_tokenizer

# Fine-tuned T2Struc checkpoint directory (must contain config.yaml and pytorch_model.bin).
T2STRUC_MODEL_DIR = "/t9k/mnt/AMP/weights/T2struc-fined/2025-11-27_21-54-10/final_model"
T2Struc, text_tokenizer, structure_tokenizer = load_T2Struc(T2STRUC_MODEL_DIR)
