In [1]:
import os

# 设置环境变量
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

# 打印环境变量以确认设置成功
print(os.environ.get('HF_ENDPOINT'))

https://hf-mirror.com


In [2]:
from transformers import AutoTokenizer, AutoModel
from tokenizers import Tokenizer
from transformers import GPT2LMHeadModel, AutoConfig,GPT2Tokenizer
from transformers import AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# 从 Trainer 获取训练好的模型

# 检查 GPU 可用性并设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model =  AutoModelForSequenceClassification.from_pretrained("gene_eng_gpt2_v0_rna3d_ft").to(device)
tokenizer = AutoTokenizer.from_pretrained("gene_eng_gpt2_v0_rna3d_ft")
model.eval()  # 设置为评估模式

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(90000, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=768, out_features=3, bias=False)
)

In [4]:
def get_rna_pos(seq):
    """
    获得ran序列最后一个残基的三维坐标预测
    """
    # 分词和填充
    inputs = tokenizer(seq, truncation=True, padding='max_length', max_length=256, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)  # 移动到 GPU   形状: (1, 256)
    attention_mask = inputs["attention_mask"].to(device)  # 移动到 GPU  形状: (1, 256)
    
    # 推理
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = outputs.logits  # 形状: (1, 3)
    
    # 转换为 numpy 或列表
    predictions = predictions.squeeze(0).cpu().numpy().tolist()  # 从 GPU 移回 CPU 并转为 NumPy  ， [x, y, z]
    return predictions

In [9]:
def get_rna_all_pos(sequence):
    """
    获得rna每个残基的三维坐标预测
    """
    seq_pos_list_local = []
    for i in range(0, len(sequence)):
        seq = sequence[0:i+1].replace("U","T") #U-->T
        #如果seq长度大于1024则只要最后面的1024个字符
        if len(seq) > 1024:
            seq = seq[-1024:]
        pos = get_rna_pos(seq)
        #print(seq, pos)
        seq_pos_list_local.append(pos)
    return seq_pos_list_local

In [13]:
import numpy as np
from scipy.spatial import distance
from scipy.spatial.transform import Rotation

def get_d0(Lref):
    """
    get distance scaling factor in Angstroms
    """
    if Lref >= 30:
        return  0.6 * np.sqrt(Lref - 0.5) - 2.5
    elif 24 <= Lref < 30:
        return 0.7
    elif 20 <= Lref < 24:
        return 0.6
    elif 16 <= Lref < 20:
        return 0.5
    elif 12 <= Lref < 16:
        return 0.4
    else:  # Lref < 12
        return 0.3


def compute_tm_score(pos_predict_list, pos_truth_list):
    """
    计算 TM-score，用于评估预测的 RNA 3D 结构与真实结构的相似度。
    
    参数:
    pos_predict_list (list of lists): 预测的残基坐标，格式为 [[x1, y1, z1], [x2, y2, z2], ...]
    pos_truth_list (list of lists): 真实的残基坐标，格式为 [[x1, y1, z1], [x2, y2, z2], ...]
    
    返回:
    float: TM-score 值，范围通常在 0 到 1 之间，1 表示完全一致
    """
    # 转换为 NumPy 数组
    pos_predict = np.array(pos_predict_list)
    pos_truth = np.array(pos_truth_list)
    
    # 检查长度是否一致
    if pos_predict.shape != pos_truth.shape:
        raise ValueError("预测和真实坐标列表的长度必须一致")
    
    L_ref = len(pos_truth)  # 参考结构的残基数
    
    # 质心对齐：将两组坐标平移到原点
    centroid_predict = np.mean(pos_predict, axis=0)
    centroid_truth = np.mean(pos_truth, axis=0)
    pos_predict_centered = pos_predict - centroid_predict
    pos_truth_centered = pos_truth - centroid_truth
    
    # Kabsch 算法：计算旋转矩阵以对齐预测结构
    covariance_matrix = np.dot(pos_predict_centered.T, pos_truth_centered)
    U, _, Vt = np.linalg.svd(covariance_matrix)
    rotation_matrix = np.dot(Vt.T, U.T)
    # 确保旋转矩阵的行列式为正（避免镜像对齐）
    if np.linalg.det(rotation_matrix) < 0:
        Vt[-1, :] *= -1
        rotation_matrix = np.dot(Vt.T, U.T)
    pos_predict_aligned = np.dot(pos_predict_centered, rotation_matrix)
    
    # 计算对齐后每对残基的欧几里得距离
    d_i = np.linalg.norm(pos_predict_aligned - pos_truth_centered, axis=1)
    
    # 计算 d_0（缩放因子）
    d0 = get_d0(L_ref)  # 对于较短的序列，简化处理
    
    # 计算 TM-score
    tm_score = (1 / L_ref) * np.sum(1 / (1 + (d_i / d0)**2))
    
    return tm_score


def get_eva(seq, pos_list_gt):
    """
    获得rna 3d预测评分
    seq：rna序列
    pos_list: ran序列每个残基的位置坐标list
    """
    #获得每个残基预测坐标
    pos_list_predict = get_rna_all_pos(seq)

    #获得评分
    tm_score = compute_tm_score(pos_list_predict, pos_list_gt)

    return tm_score

In [14]:
#查看距离
from datasets import load_dataset
# 1. load ~11k samples from promoters prediction dataset
dataset = load_dataset("json", data_files="rna_pos_all_test.jsonl")
dataset

DatasetDict({
    train: Dataset({
        features: ['seq', 'pos_list'],
        num_rows: 12
    })
})

In [15]:
item = dataset["train"][0]
seq = item["seq"]
pos_list_gt = item["pos_list"]
get_eva(seq, pos_list_gt)

0.011943309635017758

In [16]:
eva_list = []
for item in dataset["train"]:
    seq = item["seq"]
    pos_list_gt = item["pos_list"]
    eva = get_eva(seq, pos_list_gt)
    #print(eva)
    eva_list.append(eva)

In [17]:
import numpy as np
# 计算平均值
average = np.mean(eva_list)
print("平均值:", average)

平均值: 0.1336094835798191


In [18]:
eva_list

[0.011943309635017758,
 0.01846775457298123,
 1.5831164338617964e-33,
 9.39129826446217e-36,
 0.3337100659504432,
 0.30910811204256333,
 0.017259134945893786,
 0.5009053766509648,
 0.048660292726138914,
 0.14083365863433844,
 0.10399353708952036,
 0.1184325607099676]