In [82]:
from biopandas.pdb import PandasPdb
from pathlib import Path
from scipy.spatial.distance import cdist
import numpy as np
from scipy.spatial.transform import Rotation as R
from Bio import PDB

In [83]:
def kabsch(P, Q):
    """Return the rotation matrix of the Kabsch alignment between P and Q.

    Each point set is centred on its own centroid, the cross-covariance
    ``P_centered.T @ Q_centered`` is decomposed by SVD, and the rotation is
    assembled as ``V @ U.T``.  With row-vector coordinates the result ``R``
    satisfies ``Q_centered @ R ~= P_centered``.

    Only the rotation is returned; the centroids (translation) are not.

    Parameters
    ----------
    P, Q : array-like with one point per row; both must have the same shape.

    Returns
    -------
    numpy.ndarray
        Square rotation matrix with a non-negative determinant.
    """
    p_shifted = P - np.mean(P, axis=0)
    q_shifted = Q - np.mean(Q, axis=0)

    covariance = np.dot(p_shifted.T, q_shifted)
    left, _singular_values, right_t = np.linalg.svd(covariance)

    rotation = np.dot(right_t.T, left.T)
    if np.linalg.det(rotation) < 0:
        # Negative determinant means a reflection: flip the last right
        # singular vector and rebuild so a proper rotation is returned.
        right_t[-1, :] *= -1
        rotation = np.dot(right_t.T, left.T)
    return rotation

def calculate_rmsd(P, Q):
    """Root-mean-square deviation between paired points in P and Q.

    P and Q hold one point per row and are compared row by row; no
    superposition is performed here.

    Returns the square root of the mean squared per-row distance.
    """
    squared_distances = np.square(P - Q).sum(axis=1)
    return np.sqrt(squared_distances.mean())

In [84]:
def pep_chain_id(true):
    """Return the chain ID that has the fewest ATOM records.

    ``true`` must expose ``true.df['ATOM']`` as a DataFrame with a
    ``chain_id`` column.  The chain with the smallest row count is the one
    this notebook treats as the peptide chain.  On a tie, the first chain ID
    in sorted order wins.
    """
    atoms_per_chain = true.df['ATOM'].groupby('chain_id').size()
    return atoms_per_chain.idxmin()

In [85]:
def get_nearby_heavy_atoms(chain, center, radius=10.0):
    """Return the rows of ``chain`` lying within ``radius`` of ``center``.

    Parameters
    ----------
    chain : pandas.DataFrame
        Atom table with ``x_coord``, ``y_coord`` and ``z_coord`` columns.
        No hydrogen filtering happens here; every row of ``chain`` is a
        candidate, so pass a table that already contains only heavy atoms.
    center : array-like of 3 coordinates
        Query point (ndarray, list or tuple).
    radius : float, default 10.0
        Inclusive distance cutoff, in the same units as the coordinates.

    Returns
    -------
    pandas.DataFrame
        The subset of ``chain`` with distance <= ``radius``, in original
        row order and with the original index labels.
    """
    coords = chain[['x_coord', 'y_coord', 'z_coord']].values
    center_row = np.asarray(center, dtype=float).reshape(1, -1)
    # cdist returns an (n_atoms, 1) array.  Flatten it so the DataFrame is
    # filtered with an ordinary 1-D boolean row mask rather than a 2-D one.
    within_radius = cdist(coords, center_row).ravel() <= radius
    return chain[within_radius]

In [86]:
def adjust_pred_residue_numbers(pred_chain, true_chain):
    """Return a copy of ``pred_chain`` renumbered to start where ``true_chain`` starts.

    Despite the function name, the column that is shifted is
    ``atom_number``: a constant offset is applied so that the smallest
    ``atom_number`` in the result equals the smallest ``atom_number`` in
    ``true_chain``.  ``pred_chain`` itself is not modified.
    """
    # Constant gap between the first predicted and first reference atom number.
    offset = pred_chain['atom_number'].min() - true_chain['atom_number'].min()
    return pred_chain.assign(atom_number=pred_chain['atom_number'] - offset)

In [87]:
root_path = Path("/home/light/mqy/ncaa/data/af3/ss_mono")
# Write a hydrogen-free copy (<name>_clean_dete_H.pdb) of each reference
# structure <name>/<name>_clean.pdb found under root_path.
parser = PDB.PDBParser(QUIET=True)  # one parser is reused for every entry
for subdir in root_path.iterdir():
    # Skip CSV files and entries whose name contains "-" or "tmp".
    if any(token in subdir.name for token in (".csv", "-", "tmp")):
        continue
    true_pdb_dir = subdir / f"{subdir.name}_clean.pdb"
    if not true_pdb_dir.exists():
        # A single entry without a cleaned PDB would otherwise raise inside
        # the parser and abort the loop for all remaining entries.
        print(f"{subdir.name}: {true_pdb_dir.name} not found, skipped")
        continue
    structure = parser.get_structure('pdb', true_pdb_dir)
    for model in structure:
        for chain in model:
            for residue in chain:
                # Collect first, then detach: the residue must not be
                # modified while it is being iterated.
                # NOTE(review): hydrogens are recognised by an atom name
                # starting with 'H'; names such as '1HB' are missed and
                # non-hydrogen names starting with 'H' would be removed.
                # Confirm this is adequate for the input files.
                atoms_to_remove = [atom for atom in residue if atom.get_name().startswith('H')]
                for atom in atoms_to_remove:
                    residue.detach_child(atom.id)
    io = PDB.PDBIO()
    io.set_structure(structure)
    new_pdb_path = subdir / f"{subdir.name}_clean_dete_H.pdb"
    io.save(str(new_pdb_path))

In [88]:
root_path = Path("/home/light/mqy/ncaa/data/af3/ss_mono")
# Write a hydrogen-free copy (<name>_dete_H.pdb) of each predicted structure
# <name>/<name>.pdb found under root_path.
for subdir in root_path.iterdir():
    if ".csv" in subdir.name:
        continue
    pred_pdb_dir = subdir / f"{subdir.name}.pdb"
    if not pred_pdb_dir.exists():
        # Entries without a prediction are silently ignored.
        continue

    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('pdb', pred_pdb_dir)
    for model in structure:
        for chain in model:
            for residue in chain:
                # Gather the ids first so the residue is not modified
                # while it is being iterated.
                hydrogen_ids = [atom.id for atom in residue if atom.get_name().startswith('H')]
                for atom_id in hydrogen_ids:
                    residue.detach_child(atom_id)

    new_pdb_path = subdir / f"{subdir.name}_dete_H.pdb"
    io = PDB.PDBIO()
    io.set_structure(structure)
    io.save(str(new_pdb_path))

In [89]:
root_path = Path("/home/light/mqy/ncaa/data/af3/ss_mono")
# For every entry that has a hydrogen-stripped prediction, rotate the
# predicted HETATM atoms with a Kabsch rotation fitted on the peptide atoms
# near each non-standard residue, then print the mean RMSD against the
# reference HETATM atoms.
for subdir in root_path.iterdir():
    if ".csv" in subdir.name:
        continue
    true_pdb_dir = subdir / f"{subdir.name}_clean_dete_H.pdb"
    pred_pdb_dir = subdir / f"{subdir.name}_dete_H.pdb"
    if pred_pdb_dir.exists():
        try:
            ppdb_true = PandasPdb().read_pdb(str(true_pdb_dir))
            ppdb_pred = PandasPdb().read_pdb(str(pred_pdb_dir))

            # Reference peptide = chain with the fewest ATOM records;
            # the predicted peptide is taken from chain 'A'.
            true_pep_chain_id = pep_chain_id(ppdb_true)
            true_pep_chain = ppdb_true.df['ATOM'][ppdb_true.df['ATOM']['chain_id'] == true_pep_chain_id]
            pred_pep_chain = ppdb_pred.df['ATOM'][ppdb_pred.df['ATOM']['chain_id'] == 'A']
            if 'OXT' in true_pep_chain['atom_name'].values:
                # NOTE(review): this drops the LAST row, which assumes the
                # terminal OXT is the final ATOM record of the chain.
                true_pep_chain = true_pep_chain[:-1]

            true_non_std_aa = ppdb_true.df["HETATM"]
            pred_non_std_aa = ppdb_pred.df["HETATM"]
            # One centre per distinct residue_name, as an (n, 3) array.
            # NOTE(review): two residues sharing a residue_name are merged
            # into a single centre by this groupby.
            non_std_aa_centers = true_non_std_aa.groupby('residue_name')[['x_coord', 'y_coord', 'z_coord']].mean().values
            if len(non_std_aa_centers) == 0:
                # Fail with a clear message instead of a later division by zero.
                raise ValueError("no HETATM records found in the reference structure")

            # Everything below up to the loop does not depend on the centre,
            # so it is computed once per structure.
            # Predicted atoms are matched to reference atoms by atom_number
            # after shifting the predicted numbering to the reference start.
            # NOTE(review): this assumes both files list the peptide atoms
            # in the same order - confirm.
            adjusted_pred_pep_chain = adjust_pred_residue_numbers(pred_pep_chain, true_pep_chain)
            P_non_std = true_non_std_aa[['x_coord', 'y_coord', 'z_coord']].values
            Q_non_std = pred_non_std_aa[['x_coord', 'y_coord', 'z_coord']].values
            # NOTE(review): the translation comes from the HETATM centroids,
            # not from the centroids of the aligned neighbourhood, and the
            # RMSD covers all HETATM atoms for every centre.
            P_non_std_centroid = np.mean(P_non_std, axis=0)
            Q_non_std_centered = Q_non_std - np.mean(Q_non_std, axis=0)

            rmsd_list = []
            for center in non_std_aa_centers:
                true_nearby_atoms = get_nearby_heavy_atoms(true_pep_chain, center)
                pred_nearby_atoms = adjusted_pred_pep_chain[adjusted_pred_pep_chain['atom_number'].isin(true_nearby_atoms['atom_number'])]

                # Fit the rotation on the peptide atoms near the centre
                # (get_nearby_heavy_atoms default radius).
                P_nearby = true_nearby_atoms[['x_coord', 'y_coord', 'z_coord']].values
                Q_nearby = pred_nearby_atoms[['x_coord', 'y_coord', 'z_coord']].values
                rotation_matrix = kabsch(P_nearby, Q_nearby)

                # Apply that rotation to the predicted HETATM atoms and score them.
                Q_non_std_rotated = np.dot(Q_non_std_centered, rotation_matrix) + P_non_std_centroid
                rmsd = calculate_rmsd(P_non_std, Q_non_std_rotated)
                rmsd_list.append(rmsd)
            ncaa_num = len(rmsd_list)
            mean_rmsd = sum(rmsd_list) / ncaa_num
            print(f"{subdir.name}'s rmsd = {mean_rmsd:.3f}")
        except Exception as e:
            # Best effort: report the failing entry and continue with the next one.
            print(f"{subdir.name}: {e}")


3LO9's rmsd = 0.978
1P9G's rmsd = 2.548
6MY3's rmsd = 3.105
2CRD's rmsd = 3.184
7N21's rmsd = 3.728
6MY2's rmsd = 2.702
4E86's rmsd = 0.709
1BIG's rmsd = 2.626
2MFX's rmsd = 0.970
1OMC's rmsd = 1.094
6MY1's rmsd = 3.014
4E83's rmsd = 3.249
3HJD's rmsd = 2.168
2M62's rmsd = 0.835
7N24's rmsd = 4.455
2MG6's rmsd = 4.178
3LO6's rmsd = 0.203
5UG3's rmsd = 1.584
7N25's rmsd = 4.340
1KFP's rmsd = 3.072
7N20's rmsd = 3.371
1K64's rmsd = 0.291
2EW4's rmsd = 0.832


In [90]:
# 写进csv