In [1]:
from pymatgen.io.cif import CifParser

def get_xyz_from_cif(cif_file):
    """Collect the coordinates of every site in the first structure of a CIF file.

    Args:
        cif_file: Path handed to pymatgen's ``CifParser``.

    Returns:
        list: ``site.coords`` for each site, in the structure's iteration order.
    """
    first_structure = CifParser(cif_file).get_structures()[0]
    return [site.coords for site in first_structure]

In [7]:
import gemmi

def cif_to_pdb(cif_file, pdb_file):
    """Convert an mmCIF file to a PDB file using gemmi.

    Args:
        cif_file: Path of the CIF file to read.
        pdb_file: Path of the PDB file to write.
    """
    # Parse the CIF document and take its only data block
    # (presumably sole_block() rejects multi-block files - confirm in the gemmi docs).
    block = gemmi.cif.read_file(cif_file).sole_block()

    # Build a structure from that block and serialise it in PDB format.
    gemmi.make_structure_from_block(block).write_pdb(pdb_file)

# Example usage: convert a single model. The input is an absolute, user-specific
# path and the result is written to "output.pdb" in the current working directory.
cif_to_pdb("/home/yubeen/af3_docking_block/benchmark_101/7df1_F_J_C/7df1_f_j_c_ab/7df1_f_j_c_ab_model.cif", "output.pdb")

In [2]:
import os

In [8]:
# Convert the "_ab" model of every benchmark entry from mmCIF to PDB.
for i in os.listdir('/home/yubeen/af3_docking_block/benchmark_101'):
    # Skip names containing a dot and names with fewer than four "_"-separated fields.
    if ('.' in i) or (len(i.split('_'))<4) : continue
    # Model sub-directory and file names use the lower-cased entry name with '#' removed.
    af3_item = i.lower().replace('#','')
    input_path = f"/home/yubeen/af3_docking_block/benchmark_101/{i}/{af3_item}_ab/{af3_item}_ab_model.cif"
    output_path = f'/home/kkh517/alphafold2.3_ab_benchmark/{i}/af3_new_ab_block.pdb'
    # If no 'new_ab_block.pdb' exists under the entry name, retry with every 'Z'
    # in the name replaced by the lower-cased third field of the name.
    if not os.path.exists(output_path.replace('af3_new_ab_block.pdb','new_ab_block.pdb')):
        output_i = i.replace('Z',i.split('_')[2].lower())
        output_path = f'/home/kkh517/alphafold2.3_ab_benchmark/{output_i}/af3_new_ab_block.pdb'
    # Both the source model and the existing 'new_ab_block.pdb' must be present before converting.
    # NOTE(review): these checks are asserts, so they vanish when Python runs with -O.
    assert os.path.exists(input_path), "input_path should exists"
    assert os.path.exists(output_path.replace('af3_new_ab_block.pdb','new_ab_block.pdb')), "output_origin_path should exists"

    cif_to_pdb(input_path, output_path)

In [9]:
from tqdm import tqdm

In [10]:
# Convert the "_ag" model of every benchmark entry from mmCIF to PDB, with a progress bar.
# NOTE(review): this cell duplicates the "_ab" loop except for the "_ag" suffixes;
# a shared helper taking the suffix as a parameter would remove the duplication.
for i in tqdm(os.listdir('/home/yubeen/af3_docking_block/benchmark_101')):
    # Skip names containing a dot and names with fewer than four "_"-separated fields.
    if ('.' in i) or (len(i.split('_'))<4) : continue
    # Model sub-directory and file names use the lower-cased entry name with '#' removed.
    af3_item = i.lower().replace('#','')
    input_path = f"/home/yubeen/af3_docking_block/benchmark_101/{i}/{af3_item}_ag/{af3_item}_ag_model.cif"
    output_path = f'/home/kkh517/alphafold2.3_ab_benchmark/{i}/af3_new_ag_block.pdb'
    # If no 'new_ag_block.pdb' exists under the entry name, retry with every 'Z'
    # in the name replaced by the lower-cased third field of the name.
    if not os.path.exists(output_path.replace('af3_new_ag_block.pdb','new_ag_block.pdb')):
        output_i = i.replace('Z',i.split('_')[2].lower())
        output_path = f'/home/kkh517/alphafold2.3_ab_benchmark/{output_i}/af3_new_ag_block.pdb'
    # Both the source model and the existing 'new_ag_block.pdb' must be present before converting.
    assert os.path.exists(input_path), "input_path should exists"
    assert os.path.exists(output_path.replace('af3_new_ag_block.pdb','new_ag_block.pdb')), "output_origin_path should exists"

    cif_to_pdb(input_path, output_path)

100%|██████████| 110/110 [00:05<00:00, 19.97it/s]


In [60]:
import numpy as np
from scipy.spatial.distance import cdist

def kabsch_algorithm(X, Y):
    """
    Compute the optimal rotation matrix R and translation vector t that
    superpose X onto Y with the Kabsch algorithm.

    Args:
        X: (L, 3) array of points to be moved.
        Y: (L, 3) array of target points, same shape as X.

    Returns:
        (R, t): 3x3 rotation matrix and length-3 translation such that
        ``R @ x + t`` maps each point of X towards its partner in Y.
    """
    assert X.shape == Y.shape, "입력 배열 크기가 같아야 합니다."

    centroid_x = X.mean(axis=0)
    centroid_y = Y.mean(axis=0)

    # Cross-covariance of the two centred point sets
    covariance = (X - centroid_x).T @ (Y - centroid_y)

    left, _singular, right_t = np.linalg.svd(covariance)

    rotation = right_t.T @ left.T
    if np.linalg.det(rotation) < 0:
        # A negative determinant means a reflection: flip the last
        # right-singular vector and rebuild the rotation.
        right_t[-1, :] *= -1
        rotation = right_t.T @ left.T

    translation = centroid_y - rotation @ centroid_x

    return rotation, translation

def calculate_rmsd(X, Y):
    """
    RMSD between two (L, 3) ndarrays after optimal superposition.

    X is aligned onto Y with the rotation/translation returned by
    ``kabsch_algorithm`` before the deviation is measured.

    Returns:
        (rmsd, R, t): the RMSD plus the rotation matrix and translation used.
    """
    assert X.shape == Y.shape, "두 입력 배열의 크기가 같아야 합니다."

    # Optimal rotation and translation
    R, t = kabsch_algorithm(X, Y)

    # Rotate the points as column vectors, then transpose back to (L, 3)
    X_aligned = (R @ X.T).T + t

    squared_error = np.sum((X_aligned - Y) ** 2)
    return np.sqrt(squared_error / X.shape[0]), R, t

# Example usage
if __name__ == "__main__":
    # Two unrelated random point clouds; no seed is set, so the printed
    # numbers differ from run to run.
    X = np.random.rand(10, 3)
    Y = np.random.rand(10, 3)
    
    rmsd, R, t = calculate_rmsd(X, Y)
    print("RMSD:", rmsd)
    print("Rotation Matrix (R):\n", R)
    print("Translation Vector (t):", t)

PDB file 'X.pdb' written successfully.
PDB file 'x_alinged.pdb' written successfully.
PDB file 'Y.pdb' written successfully.
RMSD: 0.5251359277710124
Rotation Matrix (R):
 [[ 0.91538032 -0.0328748  -0.4012457 ]
 [ 0.1825922  -0.85435609  0.48655498]
 [-0.3588021  -0.51864719 -0.77605808]]
Translation Vector (t): [0.43716232 0.60630665 1.52220815]


In [29]:
from glob import glob
from pathlib import Path
import string
import os
import pandas as pd
from collections import defaultdict
import sys
import copy
import numpy as np
from collections import defaultdict
import torch


# Three-letter residue name -> one-letter code; "UNK" maps to "X".
to1letter = {
    "UNK": "X",
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LEU": "L",
    "LYS": "K",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
}

# Residue names in index order: the position in this list is the residue-type
# index. The last two entries are 'UNK' and 'MAS' (the latter is labelled
# "mask" in the aa2long table below).
num2aa=[
    'ALA','ARG','ASN','ASP','CYS',
    'GLN','GLU','GLY','HIS','ILE',
    'LEU','LYS','MET','PHE','PRO',
    'SER','THR','TRP','TYR','VAL',
    'UNK','MAS',
    ]

# Reverse lookup: residue name -> index in num2aa.
aa2num= {x:i for i,x in enumerate(num2aa)}

# Row index of each backbone atom name (4-character, PDB-style padded names).
bb_idx = {" N  " : 0, " CA " : 1, " C  " : 2, " O  " : 3}
# Full atom-name table: one tuple per entry of num2aa, in the same order.
# Each tuple has 27 slots: the first 14 hold heavy atoms (N, CA, C, O first,
# then side-chain atoms), the remaining 13 hold hydrogens; None marks an
# unused slot. Code that needs an (N x 14) heavy-atom layout slices [:14].
aa2long=[
    (" N  "," CA "," C  "," O  "," CB ",  None,  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","3HB ",  None,  None,  None,  None,  None,  None,  None,  None), # ala
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2",  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD "," HE ","1HH1","2HH1","1HH2","2HH2"), # arg
    (" N  "," CA "," C  "," O  "," CB "," CG "," OD1"," ND2",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HD2","2HD2",  None,  None,  None,  None,  None,  None,  None), # asn
    (" N  "," CA "," C  "," O  "," CB "," CG "," OD1"," OD2",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ",  None,  None,  None,  None,  None,  None,  None,  None,  None), # asp
    (" N  "," CA "," C  "," O  "," CB "," SG ",  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB "," HG ",  None,  None,  None,  None,  None,  None,  None,  None), # cys
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," OE1"," NE2",  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HE2","2HE2",  None,  None,  None,  None,  None), # gln
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," OE1"," OE2",  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ",  None,  None,  None,  None,  None,  None,  None), # glu
    (" N  "," CA "," C  "," O  ",  None,  None,  None,  None,  None,  None,  None,  None,  None,  None," H  ","1HA ","2HA ",  None,  None,  None,  None,  None,  None,  None,  None,  None,  None), # gly
    (" N  "," CA "," C  "," O  "," CB "," CG "," ND1"," CD2"," CE1"," NE2",  None,  None,  None,  None," H  "," HA ","1HB ","2HB "," HD2"," HE1"," HE2",  None,  None,  None,  None,  None,  None), # his
    (" N  "," CA "," C  "," O  "," CB "," CG1"," CG2"," CD1",  None,  None,  None,  None,  None,  None," H  "," HA "," HB ","1HG2","2HG2","3HG2","1HG1","2HG1","1HD1","2HD1","3HD1",  None,  None), # ile
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB "," HG ","1HD1","2HD1","3HD1","1HD2","2HD2","3HD2",  None,  None), # leu
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," CE "," NZ ",  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ","1HE ","2HE ","1HZ ","2HZ ","3HZ "), # lys
    (" N  "," CA "," C  "," O  "," CB "," CG "," SD "," CE ",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HE ","2HE ","3HE ",  None,  None,  None,  None), # met
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ",  None,  None,  None," H  "," HA ","1HB ","2HB "," HD1"," HD2"," HE1"," HE2"," HZ ",  None,  None,  None,  None), # phe
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD ",  None,  None,  None,  None,  None,  None,  None," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ",  None,  None,  None,  None,  None,  None), # pro
    (" N  "," CA "," C  "," O  "," CB "," OG ",  None,  None,  None,  None,  None,  None,  None,  None," H  "," HG "," HA ","1HB ","2HB ",  None,  None,  None,  None,  None,  None,  None,  None), # ser
    (" N  "," CA "," C  "," O  "," CB "," OG1"," CG2",  None,  None,  None,  None,  None,  None,  None," H  "," HG1"," HA "," HB ","1HG2","2HG2","3HG2",  None,  None,  None,  None,  None,  None), # thr
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"," H  "," HA ","1HB ","2HB "," HD1"," HE1"," HZ2"," HH2"," HZ3"," HE3",  None,  None,  None), # trp
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ",  None,  None," H  "," HA ","1HB ","2HB "," HD1"," HE1"," HE2"," HD2"," HH ",  None,  None,  None,  None), # tyr
    (" N  "," CA "," C  "," O  "," CB "," CG1"," CG2",  None,  None,  None,  None,  None,  None,  None," H  "," HA "," HB ","1HG1","2HG1","3HG1","1HG2","2HG2","3HG2",  None,  None,  None,  None), # val
    (" N  "," CA "," C  "," O  "," CB ",  None,  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","3HB ",  None,  None,  None,  None,  None,  None,  None,  None), # unk
    (" N  "," CA "," C  "," O  "," CB ",  None,  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","3HB ",  None,  None,  None,  None,  None,  None,  None,  None), # mask
    ]

class QuaternionBase: #from Galaxy.core.quaternion
    """Quaternion ``q = (q0, q1, q2, q3)`` with a lazily built 3x3 rotation matrix."""

    def __init__(self):
        self.q = []

    def __repr__(self):
        # __repr__ must return a str; returning self.q itself made repr() raise TypeError.
        return f"{type(self).__name__}({self.q!r})"

    def rotate(self):
        """Return the 3x3 rotation matrix derived from ``self.q``.

        The matrix is built on the first call and cached on ``self.R``;
        later calls return the cached array even if ``self.q`` was changed
        in the meantime.
        """
        if hasattr(self, 'R'):
            return self.R
        q0, q1, q2, q3 = self.q[0], self.q[1], self.q[2], self.q[3]
        self.R = np.array([
            [q0**2 + q1**2 - q2**2 - q3**2, 2.0*(q1*q2 - q0*q3),           2.0*(q1*q3 + q0*q2)],
            [2.0*(q1*q2 + q0*q3),           q0**2 - q1**2 + q2**2 - q3**2, 2.0*(q2*q3 - q0*q1)],
            [2.0*(q1*q3 - q0*q2),           2.0*(q2*q3 + q0*q1),           q0**2 - q1**2 - q2**2 + q3**2],
        ], dtype=float)
        return self.R

class QuaternionQ(QuaternionBase):
    """Quaternion initialised directly from its four components."""

    def __init__(self, q):
        # q is stored as given (no copy, no normalisation).
        self.q = q


def ls_rmsd(_X, _Y): #from Galaxy.utils.subPDB
    """Least-squares superposition RMSD of two paired point sets (quaternion method).

    Args:
        _X: (n, 3) coordinates to be superposed (array-like; not modified).
        _Y: (n, 3) target coordinates, paired row by row with ``_X``.

    Returns:
        (rmsd, (sT, sU)): the minimal RMSD, plus the translation ``sT`` (3,)
        and rotation ``sU`` (3, 3) for which ``sU @ x + sT`` carries the
        centroid of ``_X`` onto the centroid of ``_Y``.

    Raises:
        ValueError: if the inputs are not matching (n, 3) arrays or are empty.
    """
    # Work on float64 copies: the callers' arrays stay untouched, integer or
    # list inputs are accepted, and float32 coordinates no longer lose
    # precision in the X_norm + Y_norm - 2*eig cancellation below.
    X = np.array(_X, dtype=np.float64)
    Y = np.array(_Y, dtype=np.float64)
    if X.ndim != 2 or X.shape[1] != 3 or X.shape != Y.shape:
        raise ValueError(f"ls_rmsd expects two (n, 3) arrays of equal shape, got {X.shape} and {Y.shape}")
    if len(X) == 0:
        raise ValueError("ls_rmsd needs at least one pair of points")
    n = float(len(X))

    # Centre both sets on their centroids.
    X_cntr = X.sum(axis=0)/n # (3,)
    Y_cntr = Y.sum(axis=0)/n
    X -= X_cntr
    Y -= Y_cntr
    X_norm = (X*X).sum()
    Y_norm = (Y*Y).sum()

    # Correlation matrix: Rmatrix[i][j] = sum_k X[k, i] * Y[k, j]
    Rmatrix = X.T @ Y
    Rxx, Rxy, Rxz = Rmatrix[0]
    Ryx, Ryy, Ryz = Rmatrix[1]
    Rzx, Rzy, Rzz = Rmatrix[2]

    # Symmetric 4x4 matrix whose dominant eigenvector is the optimal quaternion.
    S = np.array([
        [Rxx + Ryy + Rzz, Ryz - Rzy,        Rzx - Rxz,        Rxy - Ryx],
        [Ryz - Rzy,       Rxx - Ryy - Rzz,  Rxy + Ryx,        Rxz + Rzx],
        [Rzx - Rxz,       Rxy + Ryx,       -Rxx + Ryy - Rzz,  Ryz + Rzy],
        [Rxy - Ryx,       Rxz + Rzx,        Ryz + Rzy,       -Rxx - Ryy + Rzz],
    ])

    # eigh returns eigenvalues in ascending order, so the last column of eigv
    # belongs to the largest eigenvalue.
    eigl, eigv = np.linalg.eigh(S)
    q = eigv[:, -1] #(4,)
    sU = QuaternionQ(q).rotate() #(3,3)
    sT = Y_cntr - sU.dot(X_cntr)

    # max(0.0, ...) guards against a tiny negative value from round-off.
    rmsd = np.sqrt(max(0.0, (X_norm + Y_norm - 2.0 * eigl[-1]))/n)
    return rmsd, (sT,sU)

class PDB:
    """Minimal reader for the ATOM records of a PDB file.

    NOTE(review): residues are keyed by ``int(line[22:26])`` only; the
    insertion-code column (``line[26]``) is never read, so residues that
    differ only by insertion code share one key. Fixing that would change
    the key type seen by callers, so it is left as is.
    """
    def __init__ (self, pdb_fn):
        self.pdb_fn = pdb_fn
    
    def read(self):
        """Parse the file and return per-chain lookup tables.

        Returns:
            tuple: ``(sequence, idx_total, idx_atm, coord_bb, coord_all_atom)``
              * sequence: chain -> concatenated three-letter residue names,
                one per CA atom encountered.
              * idx_total: chain -> set of residue numbers seen in ATOM records.
              * idx_atm: chain -> {' CA ': [residue numbers with a CA atom]}.
              * coord_bb: chain -> {resNo: (4, 3) float32 array}; only the
                CA row (``bb_idx[' CA ']``) is filled, the rest stay NaN.
              * coord_all_atom: chain -> {resNo: (14, 3) float32 array}
                ordered like the first 14 slots of ``aa2long``; atoms that
                are absent stay NaN.
        """
        sequence = defaultdict(str)
        coord_bb = defaultdict(dict)  #chain: {idx: coordinate [4, 3]}
        coord_all_atom = defaultdict(dict) #chain: {idx: coordinate [14, 3]}
        idx_atm = defaultdict(lambda: defaultdict(list))
        idx_total = defaultdict(set)
        with open(self.pdb_fn) as f_pdb:
            # Stream the file instead of loading every line up front.
            for line in f_pdb:
                if not line.startswith('ATOM'):
                    continue
                chain = line[21]
                resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
                xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                idx_total[chain].add(resNo)
                if resNo not in coord_all_atom[chain]:
                    coord_all_atom[chain][resNo] = np.full((14, 3), np.nan, dtype=np.float32)
                # Residue names missing from aa2num fall back to the 'UNK'
                # template instead of aborting the whole read with a KeyError.
                heavy_atoms = aa2long[aa2num.get(aa, aa2num['UNK'])][:14]
                for i_atm, tgtatm in enumerate(heavy_atoms):
                    if tgtatm == atom:
                        coord_all_atom[chain][resNo][i_atm, :] = xyz
                # Only CA atoms feed the sequence, the per-atom index and coord_bb.
                if atom == ' CA ':
                    idx_atm[chain][atom].append(resNo)
                    sequence[chain] += aa
                    if resNo not in coord_bb[chain]:
                        coord_bb[chain][resNo] = np.full((4, 3), np.nan, dtype = np.float32)
                    coord_bb[chain][resNo][bb_idx[atom], : ] = xyz
                        
        return sequence, idx_total, idx_atm, coord_bb, coord_all_atom       
# [L, 14, 3]                
def get_interface_idx(coord, idx, receptor_chain, ligand_chain, device='cpu', iface_cutoff=10.0):
    """Find receptor/ligand residues whose closest atoms lie within a cutoff.

    Args:
        coord: chain -> {resNo: (14, 3) array of atom coordinates, NaN for
            atoms that are absent}.
        idx: chain -> iterable of residue numbers to include for that chain.
        receptor_chain: iterable of receptor chain IDs (e.g. the string 'C').
        ligand_chain: iterable of ligand chain IDs (e.g. the string 'HL').
        device: torch device used for the distance computation.
        iface_cutoff: distance threshold; a residue pair is in contact when
            the minimum atom-atom distance is <= this value.

    Returns:
        (interface_final, interface_pair_list):
          * interface_final: defaultdict chain -> set of residue numbers
            taking part in at least one receptor-ligand contact.
          * interface_pair_list: set of
            ``(ligand_chain, ligand_resNo, receptor_chain, receptor_resNo)``.
    """
    chain_order = list(receptor_chain) + list(ligand_chain)
    receptor_length = sum(len(idx[ch]) for ch in receptor_chain)

    xyz_tot = []
    row_to_residue = []  # global row index -> (chain, resNo)
    for ch in chain_order:
        residues = list(idx[ch])  # freeze the iteration order once per chain
        xyz_chain = np.full((len(residues), 14, 3), np.nan, dtype=np.float32)
        for j, res in enumerate(residues):
            xyz_chain[j, :, :] = coord[ch][res]
        xyz_tot.append(torch.from_numpy(xyz_chain))
        row_to_residue.extend((ch, res) for res in residues)
    xyz_tot = torch.cat(xyz_tot, dim = 0).to(device=device)    #[total_len, 14, 3]

    # Only receptor-ligand pairs can be interface contacts, so compute just
    # the ligand x receptor block instead of the full [L, L, 14, 14] tensor.
    rec = xyz_tot[:receptor_length]   #[Lr, 14, 3]
    lig = xyz_tot[receptor_length:]   #[Ll, 14, 3]
    diff = lig[:, None, :, None, :] - rec[None, :, None, :, :]   #[Ll, Lr, 14, 14, 3]
    dist = diff.pow(2).sum(dim = -1).sqrt()                      #[Ll, Lr, 14, 14]
    dist = dist.flatten(start_dim=2)                             #[Ll, Lr, 196]
    # Missing atoms (NaN) must never count as a contact, whatever the cutoff.
    dist = dist.masked_fill(torch.isnan(dist), float('inf'))
    min_dist = torch.min(dist, dim = -1)[0]                      #[Ll, Lr]
    contact = torch.le(min_dist, iface_cutoff)

    lig_rows, rec_rows = torch.where(contact)
    lig_rows = lig_rows.tolist()
    rec_rows = rec_rows.tolist()

    # Insert residues in global row order (receptor chains first).
    interface_final = defaultdict(set)
    interface_rows = sorted(set(rec_rows) | {receptor_length + r for r in lig_rows})
    for row in interface_rows:
        ch, res = row_to_residue[row]
        interface_final[ch].add(res)

    interface_pair_list = set()
    for li, ri in zip(lig_rows, rec_rows):
        chain_l, res_l = row_to_residue[receptor_length + li]
        chain_r, res_r = row_to_residue[ri]
        interface_pair_list.add((chain_l, res_l, chain_r, res_r))

    return interface_final, interface_pair_list

def calc_rmsd(ref, model, R, t):
    """RMSD between ``ref`` and ``model`` after moving ``model`` by a given transform.

    Args:
        ref: (n, 3) reference coordinates.
        model: (n, 3) ndarray of model coordinates.
        R: 3x3 rotation matrix applied to each model point.
        t: translation added after the rotation.

    Returns:
        The root-mean-square deviation over the n points; no further fitting is done.
    """
    assert len(ref) == len(model)
    n_points = len(ref)

    # Rotate the points as column vectors, transpose back, then translate.
    aligned_model = (R @ model.T).T + t

    squared_dev = (aligned_model - ref) ** 2
    return np.sqrt(squared_dev.sum(0).sum() / n_points)

def get_capri(model_pdb, ref_pdb, receptor_chain, ligand_chain):
    """Compute ligand RMSD, interface RMSD and fraction of native contacts.

    Both files are parsed with ``PDB(...).read()``. Only residues whose atoms
    are recorded for the same chain in both model and reference are compared.

    Args:
        model_pdb: path of the predicted complex.
        ref_pdb: path of the reference complex.
        receptor_chain: iterable of receptor chain IDs (e.g. 'C').
        ligand_chain: iterable of ligand chain IDs (e.g. 'HL').

    Returns:
        (l_rmsd, i_rmsd, f_nat)
          * l_rmsd: ligand RMSD after superposing the model on the reference
            receptor atoms.
          * i_rmsd: RMSD after superposing the reference interface residues
            (10.0 cutoff).
          * f_nat: fraction of reference residue pairs (5.0 cutoff) that are
            also found in the model.

    Side effects:
        Prints the total RMSD, receptor RMSD, selected torch device and the
        model chain keys.

    NOTE(review): a reference with no contact pair at the 5.0 cutoff makes the
    final division raise ZeroDivisionError. A reference residue that is absent
    from the model appears to raise KeyError when ``model_coord_all_overlap``
    is built.
    """
    
    model_seq, model_idx_tot, model_idx_atm, model_coord_bb, model_coord_all = PDB(model_pdb).read()
    ref_seq, ref_idx_tot, ref_idx_atm, ref_coord_bb, ref_coord_all = PDB(ref_pdb).read()
    
    # should only consider overlapped residues (reference does not have to be complete)
    # idx_overlap[chain][atom name] -> residue numbers present in both structures
    idx_overlap = defaultdict(dict)
    for ch in model_idx_atm.keys():
        for atm in model_idx_atm[ch].keys():
            overlap = list(set(model_idx_atm[ch][atm]) & set(ref_idx_atm[ch][atm]))
            idx_overlap[ch][atm] = overlap
    model_rec_bb, ref_rec_bb = [], []
    model_lig_bb, ref_lig_bb = [], []
    
    # Collect paired model/reference coordinates; both lists are filled in the
    # same loop so row k of one always corresponds to row k of the other.
    for i in receptor_chain:
        for atm in idx_overlap[i].keys():
            for idx in idx_overlap[i][atm]:
                model_rec_bb.append(model_coord_bb[i][idx][bb_idx[atm]])
                ref_rec_bb.append(ref_coord_bb[i][idx][bb_idx[atm]])
    
    for i in ligand_chain:
        for atm in idx_overlap[i].keys():
            for idx in idx_overlap[i][atm]:
                model_lig_bb.append(model_coord_bb[i][idx][bb_idx[atm]])
                ref_lig_bb.append(ref_coord_bb[i][idx][bb_idx[atm]])
    
    model_rec_bb = np.array(model_rec_bb)
    ref_rec_bb = np.array(ref_rec_bb)
    model_lig_bb = np.array(model_lig_bb)
    ref_lig_bb = np.array(ref_lig_bb)
    
    # reshape(-1, 3) also turns an empty list into a (0, 3) array
    model_rec_bb = model_rec_bb.reshape(-1, 3)
    ref_rec_bb = ref_rec_bb.reshape(-1, 3)
    model_lig_bb = model_lig_bb.reshape(-1, 3)
    ref_lig_bb = ref_lig_bb.reshape(-1, 3)
    
    assert model_rec_bb.shape == ref_rec_bb.shape
    assert model_lig_bb.shape == ref_lig_bb.shape
    
    #1. Calculate l-rmsd
    # Get receptor aligned R, T matrix
    model_bb = np.concatenate((model_rec_bb, model_lig_bb))
    ref_bb = np.concatenate((ref_rec_bb, ref_lig_bb))

    # Whole-complex superposition: only total_rmsd is used (printed); the
    # transform unpacked here is overwritten by the receptor fit just below.
    total_rmsd, (t_rec, R_rec) = ls_rmsd(model_bb, ref_bb)
    print(f"total rmsd {total_rmsd}")
    # Receptor-only superposition; its transform is then applied to the ligand.
    receptor_rmsd, (t_rec, R_rec) = ls_rmsd(model_rec_bb, ref_rec_bb)
    print(f"receptor rmsd {receptor_rmsd}")
    l_rmsd = calc_rmsd(ref_lig_bb, model_lig_bb, R_rec, t_rec)
    
    
    #2. Calculate i-rmsd
    # Get interface region of reference
    model_interface_bb = []
    ref_interface_bb = []
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device', device)
    ref_idx_interface, ref_pair_interface = get_interface_idx(ref_coord_all, ref_idx_tot, receptor_chain, ligand_chain, \
            iface_cutoff = 10.0, device = device)
    
    # NOTE(review): count_ori is never read afterwards.
    count_ori = defaultdict(int)
    for k, v in ref_idx_interface.items(): #k: chain #v: interface residue index
        sort_idx = sorted(v)
        for atm in idx_overlap[k].keys(): # keep only residues present in both structures
            for idx in sort_idx:
                if idx in idx_overlap[k][atm]:
                    model_interface_bb.append(model_coord_bb[k][idx][bb_idx[atm]])
                    ref_interface_bb.append(ref_coord_bb[k][idx][bb_idx[atm]])

    model_interface_bb = np.array(model_interface_bb)
    ref_interface_bb = np.array(ref_interface_bb)
    
    model_interface_bb = model_interface_bb.reshape(-1, 3)
    ref_interface_bb = ref_interface_bb.reshape(-1, 3)
    
    # The interface transform (t_inf, R_inf) is not used further; only i_rmsd is returned.
    i_rmsd, (t_inf, R_inf) = ls_rmsd(model_interface_bb, ref_interface_bb)
    
    # Restrict the model's all-atom coordinates to the residues present in the
    # reference so both contact searches run over the same residue set.
    model_coord_all_overlap = defaultdict(dict)
    print(model_coord_all.keys())
    for ch, idxs in ref_coord_all.items():
        for idx in idxs:
            model_coord_all_overlap[ch][idx] = model_coord_all[ch][idx]

    #3. Calculate fNAT
    # Both calls index residues with ref_idx_tot and run on the default device ('cpu').
    ref_idx_interface, ref_pair_interface = get_interface_idx(ref_coord_all, ref_idx_tot, receptor_chain, ligand_chain, iface_cutoff = 5.0)
    model_idx_interface, model_pair_interface = get_interface_idx(model_coord_all_overlap, ref_idx_tot, receptor_chain, ligand_chain, iface_cutoff = 5.0)
    
    pair_both = set(ref_pair_interface) & set(model_pair_interface)
    f_nat = float(len(pair_both))/len(set(ref_pair_interface))
    
    return l_rmsd, i_rmsd, f_nat

def eval_fnat(fnat):
    """Classify a fraction of native contacts into a quality label.

    Returns 'high' (>= 0.5), 'medium' (>= 0.3), 'acceptable' (>= 0.1) or
    'incorrect' (< 0.1). When no comparison holds (e.g. NaN) the function
    falls through and returns None.
    """
    for floor, label in ((0.5, 'high'), (0.3, 'medium'), (0.1, 'acceptable')):
        if fnat >= floor:
            return label
    if fnat < 0.1:
        return 'incorrect'

def eval_lrmsd(lrmsd):
    """Classify a ligand RMSD into a quality label.

    Returns 'high' (<= 1.0), 'medium' (<= 5.0), 'acceptable' (<= 10.0) or
    'incorrect' (> 10.0). When no comparison holds (e.g. NaN) the function
    falls through and returns None.
    """
    for ceiling, label in ((1.0, 'high'), (5.0, 'medium'), (10.0, 'acceptable')):
        if lrmsd <= ceiling:
            return label
    if lrmsd > 10.0:
        return 'incorrect'

def eval_irmsd(irmsd):
    """Classify an interface RMSD into a quality label.

    Returns 'high' (<= 1.0), 'medium' (<= 2.0), 'acceptable' (<= 4.0) or
    'incorrect' (> 4.0). When no comparison holds (e.g. NaN) the function
    falls through and returns None.
    """
    for ceiling, label in ((1.0, 'high'), (2.0, 'medium'), (4.0, 'acceptable')):
        if irmsd <= ceiling:
            return label
    if irmsd > 4.0:
        return 'incorrect'

# Quality labels ranked from worst (0) to best (3), and the reverse mapping
# used to turn a combined rank back into its label.
lst_dict = {'high':3, 'medium':2, 'acceptable':1, 'incorrect':0}
lst_dict_reverse = {3: 'high', 2:'medium', 1:'acceptable', 0:'incorrect'}

def calc_capri_criteria(lrmsd, irmsd, fnat):
    """Combine the three per-metric labels into one overall quality label.

    The overall rank is the lower of the fnat rank and the better of the two
    RMSD ranks, i.e. ``min(fnat, max(lrmsd, irmsd))`` on the ``lst_dict`` scale.

    Returns:
        str: 'high', 'medium', 'acceptable' or 'incorrect'.
    """
    rmsd_rank = max(lst_dict[eval_lrmsd(lrmsd)], lst_dict[eval_irmsd(irmsd)])
    overall_rank = min(lst_dict[eval_fnat(fnat)], rmsd_rank)
    return lst_dict_reverse[overall_rank]


def calc_dockq(lrmsd, irmsd, fnat):
    """Return the DockQ-style score: the mean of fnat and the two scaled RMSD terms.

    Each RMSD is mapped to (0, 1] with ``1 / (1 + (rmsd / d)**2)``, using
    d = 8.5 for the ligand RMSD and d = 1.5 for the interface RMSD.
    """
    lrmsd_scaled = 1/(1+(lrmsd/8.5)**2)
    irmsd_scaled = 1/(1+(irmsd/1.5)**2)
    return (lrmsd_scaled + irmsd_scaled + fnat)/3


def change_chain(pdb, name, light_offset=115):
    """Relabel the H/L/T chains of a model PDB with the chain IDs encoded in ``name``.

    ``name`` has the form ``<id>_<heavy>_<light>_<antigen>`` (e.g.
    ``7sbd_H_L_C``); a field equal to ``'#'`` means that chain is absent.
    Model chain 'H' becomes the heavy ID, 'L' the light ID and 'T' the first
    antigen ID. Residue numbers of chain 'L' are additionally shifted down by
    ``light_offset``. Non-ATOM lines are copied unchanged and a final 'TER'
    line is appended.

    The result is written to ``<stem of pdb>_rechain.pdb`` in the current
    working directory.

    Args:
        pdb: path of the model PDB file.
        name: entry name carrying the target chain IDs.
        light_offset: amount subtracted from every 'L' residue number
            (default 115, the value previously hard-coded).

    Raises:
        KeyError: if an ATOM record uses a chain with no target ID.
    """
    _, hchain, lchain, agchain = name.split('_')[:4]

    # Map by role rather than by position: with a missing light chain ('#')
    # a positional zip would pair 'L' with the antigen and leave 'T' unmapped.
    chain_dict = {}
    for label, ids in zip(('H', 'L', 'T'), (hchain, lchain, agchain)):
        if ids and ids != '#':
            chain_dict[label] = ids[0]

    newlines = []
    with open(pdb) as f_in:
        for line in f_in:
            if line.startswith('ATOM'):
                chain = line[21]
                if chain not in chain_dict:
                    raise KeyError(f"chain {chain!r} in {pdb} has no target chain ID in {name!r}")
                if chain == 'L':
                    # Use the full 4-column residue-number field (line[22:26])
                    # so numbers >= 1000 are shifted correctly as well.
                    line = line[:22] + f"{int(line[22:26]) - light_offset:4d}" + line[26:]
                line = f'{line[:21]}{chain_dict[chain]}{line[22:]}'
            newlines.append(line)

    new_filename=f'{Path(pdb).stem}_rechain.pdb'
    with open(new_filename, 'w') as f_out:
        f_out.writelines(newlines)
        f_out.write('TER\n')


In [32]:
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_wt.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_goodMPNN_350.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_badMPNN_350.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_wt_mAb.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_goodMPNN_mAb.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_badMPNN_mAb.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_bd_c.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_gd_c.pdb"
# pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/scripts/999th_epi_1s78_D_C_A_gt_c.pdb"
# Model to evaluate; change_chain writes '<stem>_rechain.pdb' into the current
# working directory, which is why model_pdb below is a bare file name.
pdb = "/home/kkh517/Github/RFantibody/scripts/examples/rf2/example_outputs/15th_epi_7sbd_H_L_C_3_HLT_best.pdb"
change_chain(pdb, '7sbd_H_L_C')
# model_pdb = f'{Path(pdb).parent}/get_capri/{Path(pdb).stem}_rechain.pdb'
model_pdb = f'{Path(pdb).stem}_rechain.pdb'
# model_pdb = f"7sbd_H_L_C_RFantibody_output.pdb"
print(model_pdb)
# ref_pdb = "/home/kkh517/submit_files/Project/epitope_sampler_halfblood/antibody_meeting_inputs/1s78_D_C_A_wt/1s78_D_C_A_renum.pdb"
# ref_pdb ="/home/kkh517/submit_files/Project/epitope_sampler_halfblood/antibody_meeting_inputs/1s78_D_C_A_wt/1s78_D_C_A_renum.pdb"
# Reference structure the rechained model is compared against.
ref_pdb = "/home/kkh517/benchmark_set_after210930/7sbd_H_L_C/new.pdb"
# print(f"TMscore -c -ter 0 {model_pdb} {ref_pdb}")

15th_epi_7sbd_H_L_C_3_HLT_best_rechain.pdb


In [33]:
# Evaluate the rechained model against the reference: the antigen chain is
# passed as the receptor and the antibody chains (H + L) as the ligand.
ab_chain = 'HL' ; ag_chain = 'C'
l_rmsd, i_rmsd, f_nat = get_capri(model_pdb, ref_pdb, receptor_chain=ag_chain, ligand_chain=ab_chain)
# l_rmsd, i_rmsd, f_nat = get_capri(model_pdb, ref_pdb, receptor_chain='C', ligand_chain='DC')
# Derive the combined quality label and the DockQ-style score from the three metrics.
capri_criteria = calc_capri_criteria(l_rmsd, i_rmsd, f_nat)
dockq = calc_dockq(l_rmsd, i_rmsd, f_nat)
print(l_rmsd, i_rmsd, f_nat, dockq, capri_criteria)

total rmsd 12.78239818633521
receptor rmsd 1.734770995509093
device cpu
dict_keys(['H', 'L', 'C'])
37.34305432220897 14.237077566103983 0.03508771929824561 0.03177491608353966 incorrect


In [25]:
import numpy as np

def write_pdb(xyz_tensor, L_s, output_file="output.pdb"):
    """
    Dump a coordinate array as a CA-only PDB file, one ALA residue per point.

    Parameters:
    xyz_tensor (np.ndarray): Array of shape (N, 3) with one coordinate per residue.
    L_s (list): Chain lengths; chains are lettered 'A', 'B', ... in order and
        residue numbering restarts at 1 in every chain.
    output_file (str): Path of the PDB file to write.

    Raises:
    ValueError: if xyz_tensor holds fewer points than sum(L_s).
    """
    atom_format = "ATOM  {:5d}  CA  ALA {:1s}{:4d}    {:8.3f}{:8.3f}{:8.3f}  1.00  0.00           C"
    ter_format = "TER"

    n_coords = len(xyz_tensor)
    written = 0  # coordinates consumed so far; the atom serial is written + 1

    with open(output_file, 'w') as handle:
        for chain_no, chain_length in enumerate(L_s):
            chain_id = chr(ord('A') + chain_no)
            for res_index in range(1, chain_length + 1):
                if written >= n_coords:
                    raise ValueError("xyz_tensor length is smaller than the sum of L_s.")
                record = atom_format.format(written + 1, chain_id, res_index, *xyz_tensor[written])
                handle.write(record + "\n")
                written += 1

            # Close the chain with a TER record
            handle.write(ter_format + "\n")

    print(f"PDB file '{output_file}' written successfully.")

In [2]:
# Task 1: Identify data points with "[Errno 2] No such file or directory:"
# The data-point name is taken as the 4th-from-last "/"-separated field of the
# matching log line (presumably a path component - verify against the log format).
error_data_points = []
with open('/home/kkh517/submit_files/Project/epitope_sampler_halfblood/test.log', 'r') as file:
    for line in file:
        if "[Errno 2] No such file or directory:" in line:
            split_line = line.split('/')
            error_data_points.append(split_line[-4])

# Task 2: Sort data points by size
# Each record is introduced by a line containing 50 '#' characters; the next
# two lines are read as the item line and the size line.
data_points = []
with open('/home/kkh517/submit_files/Project/epitope_sampler_halfblood/test.log', 'r') as file:
    for line in file:
        if "#"*50 in line:
            item = next(file).strip()
            size = next(file).strip()
            data_points.append((item, size))

# The size is the integer after the last ':' of the size line.
sorted_data_points = sorted(data_points, key=lambda x: int(x[1].split(':')[-1]))

# Print the results
print("Data points with '[Errno 2] No such file or directory:':")
print(error_data_points)
print(len(set(error_data_points)))
print("\nSorted data points by size:")
# Work split: items with size > 800 alternate between gpu01_dict and
# gpu01_dict2 by the parity of the running counter; everything else goes to
# gpu02_dict. Keys are the 1-based position in the sorted list (as strings),
# values are single-element lists holding the item name.
gpu02_dict = {}
gpu01_dict = {}
gpu01_dict2 = {}
i = 0
length_dict = {}  # item name -> size
for item, size in sorted_data_points:
    size = size.split(':')[-1]
    # The item name is the text between the last pair of single quotes on the item line.
    item = item.split(':')[-1].split("'")[-2]
    print(f"Item: {item}, Size: {size}")
    length_dict[item] = int(size)
    i+=1
    if int(size) > 800:
        if i %2 == 0:
            gpu01_dict[f"{str(i)}"] = [f"{item}"]
        else:
            gpu01_dict2[f'{str(i)}'] = [f'{item}']
    else:
        gpu02_dict[f"{str(i)}"] = [f"{item}"]

    

Data points with '[Errno 2] No such file or directory:':
['7uij_H_L_CD', '7vgr_D_C_AB', '7y9t_D_#_AB', '7yix_D_Z_AB', '8b7h_H_L_AB', '8eln_L_#_IJ', '8et0_D_#_AB', '8j80_E_C_AB', '8ozb_C_#_EF', '8pnu_J_#_GHI', '8pnu_J_#_GHI', '8t03_C_D_AB', '8t05_C_D_AB', '8tqi_H_L_AB', '8u3s_B_#_AC', '8ulf_H_L_AB', '9g7k_C_#_AB', '9ima_C_D_AB']
17

Sorted data points by size:
Item: 8c3l_D_#_C, Size:  183
Item: 8qf4_E_#_A, Size:  203
Item: 8pih_C_#_A, Size:  251
Item: 8c5h_N_#_S, Size:  257
Item: 8ozb_C_#_EF, Size:  269
Item: 8k33_B_#_A, Size:  270
Item: 8pij_B_#_A, Size:  289
Item: 9fzc_D_#_B, Size:  296
Item: 8d9y_B_A_I, Size:  304
Item: 7w71_I_M_B, Size:  316
Item: 8f5i_X_Y_A, Size:  332
Item: 7wki_B_#_A, Size:  334
Item: 7yru_H_L_A, Size:  337
Item: 8djg_C_D_E, Size:  338
Item: 8av2_C_#_A, Size:  339
Item: 8jel_C_D_J, Size:  341
Item: 8rz0_F_Z_E, Size:  354
Item: 7sbd_H_L_C, Size:  358
Item: 8gkl_H_L_E, Size:  361
Item: 8dcn_D_E_F, Size:  362
Item: 8f6o_A_B_C, Size:  363
Item: 8db4_A_B_E, Size:  365

In [48]:
# NOTE(review): scratch cell that fails with IndexError (see its output).
# Presumably `item` already holds a cleaned name without quotes, so
# split("'") yields a single element and [-2] is out of range.
# Consider deleting this cell.
item.split(':')[-1].split("'")[-2]

IndexError: list index out of range

In [37]:
# Show each failing data point next to the last "_"-separated field of its name.
for name in error_data_points:
    print(name, name.split('_')[-1])

7uij_H_L_CD CD
7vgr_D_C_AB AB
7y9t_D_#_AB AB
7yix_D_Z_AB AB
8b7h_H_L_AB AB
8eln_L_#_IJ IJ
8et0_D_#_AB AB
8j80_E_C_AB AB
8ozb_C_#_EF EF
8pnu_J_#_GHI GHI
8pnu_J_#_GHI GHI
8t03_C_D_AB AB
8t05_C_D_AB AB
8tqi_H_L_AB AB
8u3s_B_#_AC AC
8ulf_H_L_AB AB
9g7k_C_#_AB AB
9ima_C_D_AB AB


In [3]:
import json

# Destination of the first gpu01 work list
file_path = '/home/kkh517/submit_files/Project/epitope_sampler_halfblood/gpu01_dict.json'

# Save gpu01_dict as a JSON file (an existing file is overwritten)
with open(file_path, 'w') as file:
    json.dump(gpu01_dict, file)

print(f"gpu01_dict saved as JSON file: {file_path}")


gpu01_dict saved as JSON file: /home/kkh517/submit_files/Project/epitope_sampler_halfblood/gpu01_dict.json


In [4]:
import json

# Destination of the second gpu01 work list
file_path = '/home/kkh517/submit_files/Project/epitope_sampler_halfblood/gpu01_dict2.json'

# Save gpu01_dict2 as a JSON file (an existing file is overwritten)
with open(file_path, 'w') as file:
    json.dump(gpu01_dict2, file)

# The message previously named gpu01_dict although gpu01_dict2 is what gets saved.
print(f"gpu01_dict2 saved as JSON file: {file_path}")


gpu01_dict saved as JSON file: /home/kkh517/submit_files/Project/epitope_sampler_halfblood/gpu01_dict2.json


In [5]:
import json

# Destination of the gpu02 work list
file_path = '/home/kkh517/submit_files/Project/epitope_sampler_halfblood/gpu02_dict.json'

# Save gpu02_dict as a JSON file (an existing file is overwritten)
with open(file_path, 'w') as file:
    json.dump(gpu02_dict, file)

print(f"gpu02_dict saved as JSON file: {file_path}")


gpu02_dict saved as JSON file: /home/kkh517/submit_files/Project/epitope_sampler_halfblood/gpu02_dict.json


In [6]:
# Number of items assigned to each of the three work lists
len(gpu01_dict),len(gpu01_dict2),len(gpu02_dict)

(10, 11, 80)

In [34]:
# Flatten the single-element lists of all three work lists into one set of item names.
vals = list(gpu01_dict.values()) + list(gpu01_dict2.values()) + list(gpu02_dict.values())
vals = set([i for sublist in vals for i in sublist])
# vals

In [14]:
import os
# Entries already present in the inference output directory, keeping only
# names with at least four "_"-separated fields.
already_set = set(os.listdir('/home/kkh517/submit_files/Project/epitope_sampler_halfblood/inference_pdb/halfblood_1.0.1_After210930_ES/'))
already_set = set([i for i in already_set if len(i.split('_'))>3])

In [46]:
# Largest size (from length_dict) among the entries that already have output.
# NOTE(review): length_dict[i] raises KeyError for any directory name that
# was not parsed from the log.
length = 0
for i in set(already_set):
    # print(i, length_dict[i])
    if length_dict[i] > length:
        length = length_dict[i]
        # print(i, length)

length
    

1238