In [1]:
import re
import numpy as np
import pandas as pd
import os
from copy import deepcopy
import shutil
from Bio import PDB
from rdkit import Chem

In [2]:
# from RFDiffusion
def calc_rmsd(xyz1, xyz2, eps=1e-6):
    """
    Superimpose ``xyz2`` onto ``xyz1`` and report the resulting RMSD.

    Both inputs are (L, 3) coordinate arrays with matching row order. Each
    set is centred on its own centroid before the rotation is fitted.

    Returns:
        (rmsd, U): ``rmsd`` is sqrt(mean squared deviation + eps); ``U`` is
        the 3x3 matrix such that ``(xyz2 - xyz2.mean(0)) @ U`` overlays the
        centred ``xyz1``.
    """
    # Work on centred copies; the caller's arrays are left untouched.
    target = xyz1 - xyz1.mean(0)
    mobile = xyz2 - xyz2.mean(0)

    # SVD of the 3x3 covariance between the two centred point sets.
    covariance = mobile.T @ target
    V, S, W = np.linalg.svd(covariance)

    # Flip the last column when needed so the result is a proper rotation
    # (right-handed) rather than a reflection.
    handedness = np.ones([3,3])
    handedness[:,-1] = np.sign(np.linalg.det(V)*np.linalg.det(W))
    U = (handedness*V) @ W

    # Apply the rotation and measure the residual deviation.
    rotated = mobile @ U
    delta = rotated - target
    n_points = rotated.shape[0]
    rmsd = np.sqrt(np.sum(delta*delta, axis=(0,1)) / n_points + eps)

    return rmsd, U

def extract_backbone_positions(pdb_file, bb=True):
    """
    Collect per-residue atom coordinates from a PDB file.

    Only residues accepted by ``PDB.is_aa(residue, standard=True)`` are kept,
    across every model and chain of the parsed structure.

    Args:
        pdb_file: path handed to ``PDBParser.get_structure``.
        bb: when True, each residue contributes a 4-tuple of the N, CA, C and
            O coordinates, and residues lacking any of those atoms are
            skipped. When False, each residue contributes a list with the
            coordinates of every atom whose element is not 'H'.

    Returns:
        A list with one entry per retained residue, in file order.
    """
    backbone_names = ('N', 'CA', 'C', 'O')
    structure = PDB.PDBParser(QUIET=True).get_structure('protein', pdb_file)

    coords = []
    for model in structure:
        for chain in model:
            for residue in chain:
                # Heteroatoms, waters and non-standard residues are ignored.
                if not PDB.is_aa(residue, standard=True):
                    continue
                try:
                    if bb:
                        entry = tuple(residue[name].get_coord() for name in backbone_names)
                    else:
                        entry = [atom.get_coord() for atom in residue if atom.element != 'H']
                except KeyError:
                    # A backbone atom is missing from this residue: drop it.
                    continue
                coords.append(entry)

    return coords

def update_positions(structure, new_positions):
    """
    Overwrite the coordinates of every atom in ``structure`` in place.

    Atoms are visited in model -> chain -> residue -> atom order and the
    i-th atom visited receives ``new_positions[i]``, so ``new_positions``
    must hold at least as many entries as the structure has atoms.
    """
    every_atom = (
        atom
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
    )
    for position_index, atom in enumerate(every_atom):
        atom.coord = new_positions[position_index]

### stat 5A residue num

In [5]:
# Count, for every design in `<path>/0_diffusion`, how many residues have a
# backbone atom close to two hard-coded reference point sets and to `lig_pos`.
# NOTE(review): `lig_pos` is not assigned anywhere above this cell in the
# visible source; it has to exist in the kernel already for this cell to run.
path = 'generated_result/CP_SS_TS_mask2'
rf_gen_list = os.listdir(f'{path}/0_diffusion')

# Reference point sets (8 and 6 points). Presumably atoms on the "right" and
# "left" part of the ligand -- TODO confirm where these values come from.
pos_right = np.array([[-5.440000057220459, 2.9570000171661377, 2.7920000553131104],
             [-3.7829999923706055, 2.055999994277954, 1.2660000324249268],
             [-2.7790000438690186, 2.6679999828338623, 2.0239999294281006],
             [-2.382999897003174, 1.4229999780654907, -0.7110000252723694],
             [-3.50600004196167, 1.3359999656677246, 0.026000000536441803],
             [-3.1080000400543213, 3.428999900817871, 3.134000062942505],
             [-5.11299991607666, 2.194000005722046, 1.6779999732971191],
             [-4.434999942779541, 3.5829999446868896, 3.5199999809265137]])
pos_right_expand = pos_right[np.newaxis, np.newaxis, :, :]

pos_left = np.array([[-0.8939999938011169, -0.2370000034570694, 0.04899999871850014],
            [-1.7589999437332153, -1.4110000133514404, 0.3709999918937683],
            [-2.996999979019165, -3.259000062942505, -0.41600000858306885],
            [-3.6059999465942383, -3.7170000076293945, -1.718000054359436],
            [-1.8639999628067017, -1.8040000200271606, 1.5119999647140503],
            [-2.319999933242798, -2.0199999809265137, -0.675000011920929]])
pos_left_expand = pos_left[np.newaxis, np.newaxis, :, :]

contact_dict = {}
for f in rf_gen_list:
    if '.pdb' not in f:
        continue
    rf_pos = np.array(extract_backbone_positions(f'{path}/0_diffusion/{f}'))

    # (residue, backbone atom, 1, xyz) so it broadcasts against a point set.
    rf_atoms = rf_pos[:, :, np.newaxis, :]

    counts = []
    for reference_points, cutoff in (
        (pos_right_expand, 6),
        (pos_left_expand, 6),
        (lig_pos, 3.2),
    ):
        distance = np.sqrt(np.sum((rf_atoms - reference_points) ** 2, axis=-1))
        # Closest approach of each residue: min over reference points, then
        # over the residue's backbone atoms.
        min_dist = distance.min(axis=-1).min(axis=-1)
        counts.append(sum(min_dist < cutoff))
    right_num, left_num, clash_num = counts

    # Key is the integer between the last '_' and the first following '.'.
    design_index = int(f.split('_')[-1].split('.')[0])
    contact_dict[design_index] = {
        'file': f,
        'right res num': right_num,
        'left res num': left_num,
        'clash num': clash_num
    }

In [6]:
import pandas as pd

# One row per design; `total` adds the two contact counts and the table is
# ordered from most to fewest contacts before being written out.
contact_df = (
    pd.DataFrame(contact_dict).T
    .assign(total=lambda table: table['right res num'] + table['left res num'])
    .sort_values('total', ascending=False)
)
contact_df.to_csv(f'{path}/contact_stat.csv', index=False)

In [9]:
# Keep designs with more than 5 contacting residues and no clash, and copy
# them into a flat output folder.
new_path = 'CP_SS_TS_mask2'
os.makedirs(new_path, exist_ok=True)

enough_contacts = contact_df['total'] > 5
clash_free = contact_df['clash num'] == 0
for f in contact_df.loc[enough_contacts & clash_free, 'file']:
    shutil.copyfile(f'{path}/0_diffusion/{f}', f'{new_path}/{f}')

## AF2 output

### Align AF2

In [3]:
# Superimpose every AF2 model in `<path>/2_af2` onto the design backbone from
# `<path>/0_diffusion` and write the moved model to `<path>/2_af2_aligned`.
# Every AF2 model is aligned against every design file, so with more than one
# design the last one processed determines the saved coordinates.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'
rf_gen_list = os.listdir(f'{path}/0_diffusion')
af2_gen_list = os.listdir(f'{path}/2_af2')
align_path = f'{path}/2_af2_aligned'
os.makedirs(align_path, exist_ok=True)

for f in rf_gen_list:
    if '.pdb' not in f:
        continue
    rf_pos = np.array(extract_backbone_positions(f'{path}/0_diffusion/{f}')).reshape(-1,3)
    af2_files = [name for name in af2_gen_list if '.pdb' in name]
    for af2_f in af2_files:
        af2_source = f'{path}/2_af2/{af2_f}'
        af2_target = f'{align_path}/{af2_f}'

        af2_pos = np.array(extract_backbone_positions(af2_source)).reshape(-1,3)
        # calc_rmsd centres both inputs itself and does not modify them.
        rmsd, U = calc_rmsd(rf_pos, af2_pos)

        full_af2_pos = np.concatenate(extract_backbone_positions(af2_source, bb=False))
        # U was fitted about the *backbone* centroid of the AF2 model, so the
        # all-atom coordinates must be shifted by that same centroid (not by
        # their own all-atom mean) before rotating; otherwise the saved model
        # is translated away from the optimal superposition.
        full_af2_pos = (full_af2_pos - af2_pos.mean(0)) @ U + rf_pos.mean(0)

        shutil.copyfile(af2_source, af2_target)
        parser = PDB.PDBParser(QUIET=True)
        structure = parser.get_structure('protein', af2_target)
        # update_positions walks every atom of the parsed file, while
        # full_af2_pos only holds non-hydrogen atoms of standard residues.
        # NOTE(review): assumes the AF2 file contains nothing else -- confirm.
        update_positions(structure, full_af2_pos)

        io = PDB.PDBIO()
        io.set_structure(structure)
        io.save(af2_target)

### align ligand with fix residue

In [6]:
# Read the reference complex and keep the whitespace-split tokens of every
# line that contains 'HETATM'. Tokens 5..7 are taken as x, y, z and the last
# non-empty token as the atom type.
# NOTE(review): token positions depend on how the columns of this particular
# file split on whitespace -- confirm against the file.
with open('pl_Benchmark/CP_SS_TS/7jrq_CP_SS_TS.pdb', 'r') as f:
    ref_lig_lines = [
        re.sub(r'\s+', ' ', l).split(' ')
        for l in f
        if 'HETATM' in l
    ]
ref_ligand_coords = np.array([tokens[5:8] for tokens in ref_lig_lines]).astype(float)
# Each line ends with a newline, so the split leaves a trailing '' and the
# final real token sits at index -2.
ligand_atom_type = [tokens[-2] for tokens in ref_lig_lines]

In [4]:
def extract_res_coord(structure, res_id):
    """
    Return the coordinates of all atoms belonging to residue ``res_id``.

    Every model and chain of ``structure`` is searched; a residue matches
    when the second field of ``residue.get_id()`` (the sequence number)
    equals ``res_id``. Matches from several chains/models are concatenated.
    """
    return [
        atom.coord
        for model in structure
        for chain in model
        for residue in chain
        if residue.get_id()[1] == res_id
        for atom in residue
    ]

In [6]:
# Place the reference ligand into every aligned AF2 model by superimposing
# reference residue 64 onto residue `fix_res_id` of the model, then write
#   <path>/2_af2_aligned_complex/<model>  : protein ATOM lines + ligand HETATM lines
#   <path>/2_af2_aligned_lig/<model>      : ligand alone, as ATOM records
# `ref_ligand_coords` and `ligand_atom_type` come from the reference-ligand
# parsing cell.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'
aligned_dir = f'{path}/2_af2_aligned'
af2_gen_list = [name for name in os.listdir(aligned_dir) if '.pdb' in name]
os.makedirs(f'{path}/2_af2_aligned_complex', exist_ok=True)
os.makedirs(f'{path}/2_af2_aligned_lig', exist_ok=True)
fix_res_id = 61

parser = PDB.PDBParser(QUIET=True)
reference_structure = parser.get_structure('complex', 'pl_Benchmark/CP_SS_TS/7jrq_CP_SS_TS.pdb')
ref_H_coords = np.array(extract_res_coord(reference_structure, 64))

for af2_f in af2_gen_list:
    af2_file = f'{aligned_dir}/{af2_f}'

    # Fit on the very file whose ATOM lines are written out below. The
    # previous version parsed `4_redesign_af2_aligned/<model>` here while
    # copying lines from `2_af2_aligned/<model>`, mixing two pipeline stages.
    af2_structure = parser.get_structure('complex', af2_file)
    af2_H_coords = np.array(extract_res_coord(af2_structure, fix_res_id))

    # Rotation taking the centred reference residue onto the model residue;
    # the same rigid transform is then applied to the ligand.
    rmsd, U_H = calc_rmsd(af2_H_coords, ref_H_coords)
    new_ligand_coords = (ref_ligand_coords - ref_H_coords.mean(0)) @ U_H + af2_H_coords.mean(0)

    # Read the model directly (no scratch copy in the working directory).
    with open(af2_file, 'r') as f:
        new_lines = [l for l in f if l[:4] == 'ATOM']
    lig_lines = []

    # Continue numbering after the last protein atom: next serial, next
    # residue number, next chain letter.
    last_atom_line = new_lines[-1]
    lig_atom_id = int(last_atom_line[6:11]) + 1
    lig_res_id = int(last_atom_line[22:26]) + 1
    lig_chain_id = chr(ord(last_atom_line[21]) + 1)

    for i, (pos, atom_type) in enumerate(zip(new_ligand_coords, ligand_atom_type)):
        # Everything after the serial number is shared by both output files:
        # atom name (4), resname 'HBA', chain, resnum (4), x/y/z (8.3f each),
        # occupancy 1.00, temp factor 25.02, element right-justified in 12.
        tail = " %s HBA %s%s    %8.3f%8.3f%8.3f%6.2f%6.2f%s\n" % (
            str(atom_type).center(4),
            lig_chain_id,
            str(lig_res_id).rjust(4),
            float(pos[0]), float(pos[1]), float(pos[2]),
            1.00, 25.02,
            str(atom_type).rjust(12),
        )
        new_lines.append('HETATM' + str(lig_atom_id).rjust(5) + tail)
        # The ligand-only file restarts serials at 1 and uses ATOM records.
        lig_lines.append('ATOM  ' + str(i + 1).rjust(5) + tail)
        lig_atom_id += 1

    with open(f'{path}/2_af2_aligned_complex/{af2_f}', 'w') as f:
        f.writelines(new_lines)

    with open(f'{path}/2_af2_aligned_lig/{af2_f}', 'w') as f:
        f.writelines(lig_lines)

### stat rmsd, plddt, clash

In [None]:
# Per-model statistics for the AF2 + ligand complexes: backbone RMSD against
# each design in 0_diffusion, and the residues with a non-hydrogen atom
# closer than 3.2 to the ligand (whole ligand, and its largest fragment after
# one C-Fe bond has been removed).
# NOTE(review): `remove_bond` is only defined further down in the visible
# source, so it must already exist in the kernel when this cell runs.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'
rf_gen_list = os.listdir(f'{path}/0_diffusion')
af2_path = f'{path}/2_af2_aligned_complex'
af2_gen_list = [name for name in os.listdir(af2_path) if '.pdb' in name]
rmsd_res = {}

score_file = pd.read_csv(f'{path}/2_af2/scores.csv')
for f in rf_gen_list:
    if '.pdb' not in f:
        continue
    rf_pos = extract_backbone_positions(f'{path}/0_diffusion/{f}')
    pdb_id = f.split('.')[0]
    # af2_files = [f for f in af2_gen_list if pdb_id + '_' in f and 'native' not in f]
    for af2_f in [name for name in af2_gen_list if '.pdb' in name]:
        model_file = f'{af2_path}/{af2_f}'

        af2_pos = extract_backbone_positions(model_file)
        rmsd = calc_rmsd(np.array(af2_pos).reshape(-1,3), np.array(rf_pos).reshape(-1,3))[0]

        # One (n_atoms, 3) array per residue, hydrogens excluded.
        af2_pos = [np.array(p) for p in extract_backbone_positions(model_file, bb=False)]

        with open(f'{path}/2_af2_aligned_lig/{af2_f}', 'r') as lig_f:
            ligand_pdb_block = lig_f.read()
        mol = Chem.MolFromPDBBlock(ligand_pdb_block, sanitize=False, removeHs=False)
        lig_pos = mol.GetConformer(0).GetPositions()
        mol = remove_bond(mol, 'C', 'Fe')
        frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
        Fe_DPP = max(frags, key=lambda m: m.GetNumAtoms())
        DPP_pos = Fe_DPP.GetConformer(0).GetPositions()

        # The score table is keyed by the file name up to '_model'.
        model_name = af2_f[:af2_f.find('model')-1]
        sequence = score_file[score_file['Name'] == model_name]['Sequence'].item()

        # Closest approach of each residue to the largest ligand fragment.
        dpp_min = np.array([np.linalg.norm(p[:, np.newaxis] - DPP_pos, axis=2).min() for p in af2_pos])
        dpp_close = dpp_min < 3.2
        clash_num = dpp_close.sum()
        # Labels are one-letter code + 1-based position in the sequence.
        DPP_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(dpp_close)[0]]

        # Same measurement against every ligand atom.
        lig_min = np.array([np.linalg.norm(p[:, np.newaxis] - lig_pos, axis=2).min() for p in af2_pos])
        lig_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(lig_min < 3.2)[0]]

        rmsd_res[model_name] = {
            'rmsd': rmsd,
            'clash_num': clash_num,
            'DPP_clash_indices': DPP_clash_indices,
            'ligand_clash_indices': lig_clash_indices,
            'ori_pdb': f,
            'af2_pdb': af2_f
        }

## Prepare pocket redesign

### reformulate pdb

In [3]:
# Build the input for pocket redesign:
#   1. take the HETATM lines of the first design in 0_diffusion, rewrite the
#      characters of their atom-name field, and save them as `<path>/ligand.pdb`;
#   2. superimpose every AF2 model onto each design backbone;
#   3. rewrite each aligned model as REMARK line + protein ATOM lines + ligand.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'
rf_gen_list = os.listdir(f'{path}/0_diffusion')
af2_gen_list = os.listdir(f'{path}/2_af2')
align_path = f'{path}/2_af2_aligned'
os.makedirs(align_path, exist_ok=True)

fix_res = ['REMARK 666 MATCH TEMPLATE X HBA    0 MATCH MOTIF A HIS 61 1 1 \n']

f = [f for f in rf_gen_list if '.pdb' in f][0]
with open(f'{path}/0_diffusion/{f}', 'r') as r:
    lines = r.readlines()
ligand = [l for l in lines if 'HETATM' in l]

# The ligand lines are renumbered in four hard-coded groups by overwriting
# string positions 13 (and 14), which lie inside the PDB atom-name field.
# NOTE(review): the group sizes 4 / 44 / 2 / 1 assume a fixed atom order in
# the design's ligand -- confirm against the input file.
for i in range(4):
    l = list(ligand[i])
    l[13] = str(i + 1)
    ligand[i] = ''.join(l)
for i in range(44):
    l = list(ligand[i + 4])
    idx = str(i + 1)
    # One or two digits, written starting at position 13.
    l[13:13 + len(idx)] = idx
    ligand[i + 4] = ''.join(l)
for i in range(2):
    l = list(ligand[i + 48])
    l[13] = str(i + 1)
    ligand[i + 48] = ''.join(l)
# Last ligand line: name characters 'E1' and the character before the
# newline set to 'E'.
l = list(ligand[-1])
l[13] = 'E'
l[14] = '1'
l[-2] = 'E'
ligand[-1] = ''.join(l)
with open(f'{path}/ligand.pdb', 'w') as r:
    r.write(''.join(ligand))

for f in rf_gen_list:
    if '.pdb' not in f:
        continue
    rf_pos = np.array(extract_backbone_positions(f'{path}/0_diffusion/{f}')).reshape(-1,3)
    af2_files = [name for name in af2_gen_list if '.pdb' in name]

    for af2_f in af2_files:
        af2_source = f'{path}/2_af2/{af2_f}'
        af2_target = f'{align_path}/{af2_f}'

        af2_pos = np.array(extract_backbone_positions(af2_source)).reshape(-1,3)
        # calc_rmsd centres both inputs itself and does not modify them.
        rmsd, U = calc_rmsd(rf_pos, af2_pos)

        full_af2_pos = np.concatenate(extract_backbone_positions(af2_source, bb=False))
        # U was fitted about the backbone centroid of the AF2 model, so that
        # centroid (not the all-atom mean) must be removed before rotating;
        # otherwise the model ends up shifted relative to the design and to
        # the ligand appended below.
        full_af2_pos = (full_af2_pos - af2_pos.mean(0)) @ U + rf_pos.mean(0)

        shutil.copyfile(af2_source, af2_target)
        parser = PDB.PDBParser(QUIET=True)
        structure = parser.get_structure('protein', af2_target)
        # NOTE(review): update_positions assigns to every atom in the file,
        # full_af2_pos holds only non-hydrogen atoms of standard residues;
        # assumes the AF2 file contains nothing else -- confirm.
        update_positions(structure, full_af2_pos)

        io = PDB.PDBIO()
        io.set_structure(structure)
        io.save(af2_target)

        with open(af2_target, 'r') as r:
            lines = r.readlines()

        # Match the record name at the start of the line: a plain substring
        # test for 'ATOM' would also pick up 'HETATM' records.
        prot = [l for l in lines if l.startswith('ATOM')]
        lines = fix_res + prot + ligand
        with open(af2_target, 'w') as r:
            r.write(''.join(lines))

### stat rmsd, plddt, clash

In [4]:
from rdkit import Chem
def remove_bond(mol, atom1_symbol, atom2_symbol):
    """
    Delete the first bond joining the two given element symbols.

    Bonds are scanned in ``mol.GetBonds()`` order and the symbols may appear
    in either direction. Only the first matching bond is removed; the edit is
    done on an ``RWMol`` copy, which is converted back with ``GetMol()``.
    When no bond matches, ``mol`` itself is returned unchanged.
    """
    wanted_pairs = (
        (atom1_symbol, atom2_symbol),
        (atom2_symbol, atom1_symbol),
    )
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtom(), bond.GetEndAtom()
        if (begin.GetSymbol(), end.GetSymbol()) not in wanted_pairs:
            continue
        begin_idx, end_idx = begin.GetIdx(), end.GetIdx()
        editable = Chem.RWMol(mol)
        editable.RemoveBond(begin_idx, end_idx)
        return editable.GetMol()

    return mol

In [5]:
# Statistics for the aligned AF2 models: backbone RMSD against each design
# and residues with a non-hydrogen atom closer than 3.2 to the ligand stored
# in `<path>/ligand.pdb` (whole ligand, and its largest fragment after one
# C-Fe bond has been removed).
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'
rf_gen_list = os.listdir(f'{path}/0_diffusion')
af2_path = f'{path}/2_af2_aligned'
af2_gen_list = [name for name in os.listdir(af2_path) if '.pdb' in name]
rmsd_res = {}

# The ligand is shared by every model, so it is loaded once.
with open(f'{path}/ligand.pdb', 'r') as lig_f:
    ligand_pdb_block = lig_f.read()
mol = Chem.MolFromPDBBlock(ligand_pdb_block, sanitize=False, removeHs=False)
lig_pos = mol.GetConformer(0).GetPositions()
mol = remove_bond(mol, 'C', 'Fe')
frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
Fe_DPP = max(frags, key=lambda m: m.GetNumAtoms())
DPP_pos = Fe_DPP.GetConformer(0).GetPositions()

score_file = pd.read_csv(f'{path}/2_af2/scores.csv')
for f in rf_gen_list:
    if '.pdb' not in f:
        continue
    rf_pos = extract_backbone_positions(f'{path}/0_diffusion/{f}')
    pdb_id = f.split('.')[0]
    # af2_files = [f for f in af2_gen_list if pdb_id + '_' in f and 'native' not in f]
    for af2_f in [name for name in af2_gen_list if '.pdb' in name]:
        model_file = f'{af2_path}/{af2_f}'

        af2_pos = extract_backbone_positions(model_file)
        rmsd = calc_rmsd(np.array(af2_pos).reshape(-1,3), np.array(rf_pos).reshape(-1,3))[0]

        # One (n_atoms, 3) array per residue, hydrogens excluded.
        af2_pos = [np.array(p) for p in extract_backbone_positions(model_file, bb=False)]

        # The score table is keyed by the file name up to '_model'.
        model_name = af2_f[:af2_f.find('model')-1]
        sequence = score_file[score_file['Name'] == model_name]['Sequence'].item()

        # Closest approach of each residue to the largest ligand fragment.
        dpp_min = np.array([np.linalg.norm(p[:, np.newaxis] - DPP_pos, axis=2).min() for p in af2_pos])
        dpp_close = dpp_min < 3.2
        clash_num = dpp_close.sum()
        # Labels are one-letter code + 1-based position in the sequence.
        DPP_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(dpp_close)[0]]

        # Same measurement against every ligand atom.
        lig_min = np.array([np.linalg.norm(p[:, np.newaxis] - lig_pos, axis=2).min() for p in af2_pos])
        lig_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(lig_min < 3.2)[0]]

        rmsd_res[model_name] = {
            'rmsd': rmsd,
            'clash_num': clash_num,
            'DPP_clash_indices': DPP_clash_indices,
            'ligand_clash_indices': lig_clash_indices,
            'ori_pdb': f,
            'af2_pdb': af2_f
        }

In [6]:
import pandas as pd
# Attach the lDDT value from the AF2 score table to each entry of `rmsd_res`
# (both are keyed by the model name), keep the models with lDDT > 87 and
# RMSD < 1.5, and save that selection.
scores = pd.read_csv(f'{path}/2_af2/scores.csv')
for n, p in zip(scores['Name'], scores['lDDT']):
    # Every name in the score table is expected to have an rmsd_res entry;
    # a missing one raises KeyError here.
    rmsd_res[n]['plddt'] = p

selected_res = {
    k: v
    for k, v in rmsd_res.items()
    if v['plddt'] > 87 and v['rmsd'] < 1.5
}

# One row per selected model.
df = pd.DataFrame(selected_res).T
df.to_csv(f'{path}/scores.csv')

In [27]:
# Copy the selected AF2 models with fewer than 10 clashing residues into the
# folder used as input for pocket redesign.
new_path = f'{path}/2_af2_redesign'
os.makedirs(new_path, exist_ok=True)

for entry in selected_res.values():
    design_name = entry['ori_pdb']
    model_name = entry['af2_pdb']
    if not entry['clash_num'] < 10:
        continue
    # shutil.copyfile(f'{path}/0_diffusion/{o}', f'{new_path}/{o}')
    shutil.copyfile(f'{af2_path}/{model_name}', f'{new_path}/{model_name}')

### Stat Redesign RMSD

#### Without adjusting the ligand to HIS

In [22]:
# For every redesigned structure: superimpose its AF2 prediction onto it,
# save the aligned prediction, and record backbone RMSD, ligand clashes of
# the redesign, and the deviation of the ligand-contacting residue.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'

redesign_path = f'{path}/3.1_design_pocket_ligandMPNN'
redesign_list = os.listdir(redesign_path)

redesign_af2_path = f'{path}/4_redesign_af2'
redesign_af2_list = [name for name in os.listdir(redesign_af2_path) if '.pdb' in name]

align_redesign_af2_path = f'{path}/4_redesign_af2_aligned'
os.makedirs(align_redesign_af2_path, exist_ok=True)

# Residue *number* of the residue whose placement is compared.
interact_res_id = 61

rmsd_res = {}
scores = pd.read_csv(f'{redesign_af2_path}/redesign_sequence.csv')
for f in redesign_list:
    if '.pdb' not in f:
        continue
    redesign_file = f'{redesign_path}/{f}'
    redesign_bb = np.array(extract_backbone_positions(redesign_file)).reshape(-1,3)

    pdb_id = f.replace('.pdb', '')
    # First AF2 file whose name contains the redesign's name.
    af2_f = [name for name in redesign_af2_list if pdb_id in name][0]
    af2_source = f'{redesign_af2_path}/{af2_f}'
    af2_target = f'{align_redesign_af2_path}/{af2_f}'
    af2_bb = np.array(extract_backbone_positions(af2_source)).reshape(-1,3)

    # align af2 structure
    _, U = calc_rmsd(redesign_bb, af2_bb)
    full_af2_pos = np.concatenate(extract_backbone_positions(af2_source, bb=False))
    # U was fitted about the backbone centroid of the AF2 model, so that
    # centroid (not the all-atom mean) must be removed before rotating;
    # otherwise the saved model is shifted off the optimal superposition,
    # which directly inflates the un-refitted H_rmsd computed below.
    full_af2_pos = (full_af2_pos - af2_bb.mean(0)) @ U + redesign_bb.mean(0)

    shutil.copyfile(af2_source, af2_target)
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('protein', af2_target)
    # NOTE(review): update_positions assigns to every atom in the file, while
    # full_af2_pos holds only non-hydrogen atoms of standard residues;
    # assumes the AF2 file contains nothing else -- confirm.
    update_positions(structure, full_af2_pos)
    io = PDB.PDBIO()
    io.set_structure(structure)
    io.save(af2_target)

    # calculate rmsd
    aligned_bb = np.array(extract_backbone_positions(af2_target)).reshape(-1,3)
    rmsd = calc_rmsd(aligned_bb, redesign_bb)[0]

    # stat clash: ligand = HETATM lines of the redesign whose third-from-last
    # character is not 'H'.
    with open(redesign_file, 'r') as lig_f:
        ligand_pdb_block = lig_f.read()
    ligand_pdb_block = '\n'.join(
        l for l in ligand_pdb_block.split('\n')
        if 'HETATM' in l and l[-3] != 'H'
    )
    mol = Chem.MolFromPDBBlock(ligand_pdb_block, sanitize=False, removeHs=False)
    lig_pos = mol.GetConformer(0).GetPositions()
    frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    Fe_DPP = max(frags, key=lambda m: m.GetNumAtoms())
    DPP_pos = Fe_DPP.GetConformer(0).GetPositions()

    # One entry per residue, hydrogens excluded.
    redesign_residues = extract_backbone_positions(redesign_file, bb=False)
    redesign_pos = [np.array(p) for p in redesign_residues]

    sequence = scores[scores['Name'] == af2_f[:af2_f.rfind('model')-1]]['Sequence'].item()
    distance = np.array([np.linalg.norm(p[:, np.newaxis] - DPP_pos, axis=2).min() for p in redesign_pos])
    clash_num = (distance < 3.2).sum()
    # Labels are one-letter code + 1-based position in the sequence.
    DPP_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(distance < 3.2)[0]]

    distance = np.array([np.linalg.norm(p[:, np.newaxis] - lig_pos, axis=2).min() for p in redesign_pos])
    lig_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(distance < 3.2)[0]]

    # rmsd of HIS. The per-residue lists are 0-based, so residue number
    # `interact_res_id` sits at index `interact_res_id - 1` (the same
    # index -> number mapping as the clash labels above; assumes residues
    # are numbered consecutively from 1). Indexing with the residue number
    # itself compared the following residue instead.
    his_index = interact_res_id - 1
    redesign_af2_residues = extract_backbone_positions(af2_target, bb=False)
    redesign_H_pos = np.array(redesign_residues[his_index]).reshape(-1,3)
    af2_H_pos = np.array(redesign_af2_residues[his_index]).reshape(-1,3)
    # Direct coordinate deviation after the global alignment (no refit).
    interact_rmsd = np.sqrt(np.sum((redesign_H_pos-af2_H_pos) ** 2, axis=(0,1)) / redesign_H_pos.shape[0])

    rmsd_res[pdb_id] = {
        'rmsd': rmsd,
        'H_rmsd': interact_rmsd,
        'clash_num': clash_num,
        'DPP_clash_indices': DPP_clash_indices,
        'ligand_clash_indices': lig_clash_indices,
        'ori_pdb': f,
        'af2_pdb': af2_f
    }

In [23]:
import pandas as pd
# Add each model's lDDT to its rmsd_res entry, then keep the entries with
# lDDT > 87 and RMSD < 1.5 and save them, one row per model.
for model_name, plddt in zip(scores['Name'], scores['lDDT']):
    rmsd_res[model_name]['plddt'] = plddt

selected_res = {
    name: entry
    for name, entry in rmsd_res.items()
    if v_ok(entry)
} if False else {
    name: entry
    for name, entry in rmsd_res.items()
    if entry['plddt'] > 87 and entry['rmsd'] < 1.5
}

df = pd.DataFrame(selected_res).T
df.to_csv(f'{redesign_af2_path}/scores.csv')

In [11]:
# Gather the selected redesigns and their aligned AF2 predictions in
# `<path>/selected_res/{redesign,redesign_af2}`.
new_path = f'{path}/selected_res'
for folder in (new_path, f'{new_path}/redesign', f'{new_path}/redesign_af2'):
    os.makedirs(folder, exist_ok=True)

for entry in selected_res.values():
    redesign_name = entry['ori_pdb']
    model_name = entry['af2_pdb']
    shutil.copyfile(f'{redesign_path}/{redesign_name}', f'{new_path}/redesign/{redesign_name}')
    shutil.copyfile(f'{align_redesign_af2_path}/{model_name}', f'{new_path}/redesign_af2/{model_name}')

#### Adjust Ligand with HIS

In [5]:
## align af2 structures
## align af2 structures
# Superimpose the AF2 prediction of every redesign onto the redesign's
# backbone and save it in `<path>/4_redesign_af2_aligned`.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'

redesign_path = f'{path}/3.1_design_pocket_ligandMPNN'
redesign_list = os.listdir(redesign_path)

redesign_af2_path = f'{path}/4_redesign_af2'
redesign_af2_list = [name for name in os.listdir(redesign_af2_path) if '.pdb' in name]

align_redesign_af2_path = f'{path}/4_redesign_af2_aligned'
os.makedirs(align_redesign_af2_path, exist_ok=True)

for f in redesign_list:
    if '.pdb' not in f:
        continue
    redesign_pos = np.array(extract_backbone_positions(f'{redesign_path}/{f}')).reshape(-1,3)

    pdb_id = f.replace('.pdb', '')
    # First AF2 file whose name contains the redesign's name.
    af2_f = [name for name in redesign_af2_list if pdb_id in name][0]
    af2_source = f'{redesign_af2_path}/{af2_f}'
    af2_target = f'{align_redesign_af2_path}/{af2_f}'
    redesign_af2_pos = np.array(extract_backbone_positions(af2_source)).reshape(-1,3)

    # align af2 structure (calc_rmsd centres its inputs without modifying them)
    rmsd, U = calc_rmsd(redesign_pos, redesign_af2_pos)
    full_af2_pos = np.concatenate(extract_backbone_positions(af2_source, bb=False))
    # U was fitted about the backbone centroid of the AF2 model, so that
    # centroid (not the all-atom mean) must be removed before rotating;
    # otherwise the saved model is shifted off the optimal superposition.
    full_af2_pos = (full_af2_pos - redesign_af2_pos.mean(0)) @ U + redesign_pos.mean(0)

    shutil.copyfile(af2_source, af2_target)
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('protein', af2_target)
    # NOTE(review): update_positions assigns to every atom in the file, while
    # full_af2_pos holds only non-hydrogen atoms of standard residues;
    # assumes the AF2 file contains nothing else -- confirm.
    update_positions(structure, full_af2_pos)
    io = PDB.PDBIO()
    io.set_structure(structure)
    io.save(af2_target)

In [81]:
def extract_res_coord(structure, res_id):
    """
    Return the coordinates of the non-hydrogen atoms of residue ``res_id``.

    Every model and chain of ``structure`` is searched; a residue matches
    when the second field of ``residue.get_id()`` (the sequence number)
    equals ``res_id``. Atoms whose ``element`` is 'H' are left out. Matches
    from several chains/models are concatenated.
    """
    return [
        atom.coord
        for model in structure
        for chain in model
        for residue in chain
        if residue.get_id()[1] == res_id
        for atom in residue
        if atom.element != 'H'
    ]

def extract_ligand_info(pdb_path):
    """
    Return the whitespace-split tokens of every line containing 'HETATM'.

    Each run of whitespace (including the trailing newline) is collapsed to
    a single space before splitting on ' ', so a line that ends with a
    newline yields a trailing empty-string token.
    """
    with open(pdb_path, 'r') as handle:
        return [
            re.sub(r'\s+', ' ', line).split(' ')
            for line in handle
            if 'HETATM' in line
        ]

# Reference complex: ligand coordinates / atom types from its HETATM lines,
# and the coordinates of residue 64, which is later superimposed onto the
# fixed residue of each model.
# NOTE(review): tokens 5..7 are taken as x, y, z, which depends on how this
# file's columns split on whitespace -- confirm against the file.
reference_pdb = 'pl_Benchmark/CP_SS_TS/7jrq_CP_SS_TS.pdb'

ref_lig_lines = extract_ligand_info(reference_pdb)
ref_ligand_coords = np.array([tokens[5:8] for tokens in ref_lig_lines]).astype(float)
# Index -2 is the last real token (the split leaves a trailing '').
ligand_atom_type = [tokens[-2] for tokens in ref_lig_lines]

parser = PDB.PDBParser(QUIET=True)
reference_structure = parser.get_structure('complex', reference_pdb)
ref_H_coords = np.array(extract_res_coord(reference_structure, 64))

In [82]:
## align ligand with H
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'

rf_gen_list = os.listdir(f'{path}/0_diffusion')
rf_gen_list = [r for r in rf_gen_list if '.pdb' in r]
af2_gen_list = os.listdir(f'{path}/4_redesign_af2_aligned')
redesign_af2_aligned_complex_path = f'{path}/4_redesign_af2_aligned_complex'
os.makedirs(redesign_af2_aligned_complex_path, exist_ok=True)
fix_res_id = 61

for af2_f in af2_gen_list:
    parser = PDB.PDBParser(QUIET=True)
    af2_structure = parser.get_structure('complex', f'{path}/4_redesign_af2_aligned/{af2_f}')
    af2_H_coords = np.array(extract_res_coord(af2_structure, fix_res_id))

    # # use redesign output as reference
    # ref_lig_lines = extract_ligand_info(f"{path}/3.1_design_pocket_ligandMPNN/{af2_f[:af2_f.rfind('model')-1] + '.pdb'}")
    # ref_ligand_coords = [l[6:9] for l in ref_lig_lines]
    # ref_ligand_coords = np.array(ref_ligand_coords).astype(float)
    # ligand_atom_type = [l[-2] for l in ref_lig_lines]
    
    # parser = PDB.PDBParser(QUIET=True)
    # reference_structure = parser.get_structure('complex', f"{path}/3.1_design_pocket_ligandMPNN/{af2_f[:af2_f.rfind('model')-1] + '.pdb'}")
    # ref_H_coords = np.array(extract_res_coord(reference_structure, fix_res_id))

    rmsd, U_H = calc_rmsd(np.array(af2_H_coords), np.array(ref_H_coords))
    new_ligand_coords = (ref_ligand_coords - ref_H_coords.mean(0)) @ U_H + af2_H_coords.mean(0)
    
    with open(f'{path}/4_redesign_af2_aligned/{af2_f}', 'r') as f:
        lines = f.readlines()
    new_lines = []
    for l in lines:
        if l[:4] == 'ATOM':
            new_lines.append(l)
    lig_atom_id, lig_res_id, lig_chain_id = int(new_lines[-1][6:11]) + 1, int(new_lines[-1][22:26]) + 1, chr(ord(new_lines[-1][21]) + 1)
    for i, (pos, atom_type) in enumerate(zip(new_ligand_coords, ligand_atom_type)):
        j0 = str('HETATM').ljust(6)  # atom#6s
        j0_lig = str('ATOM').ljust(6)  # atom#6s
        j1 = str(lig_atom_id).rjust(5)  # aomnum#5d
        j1_lig = str(i+1).rjust(5)  # aomnum#5d
        j2 = str(atom_type).center(4)  # atomname$#4s
        j3 = 'HBA'.ljust(3)  # resname#1s
        j4 = lig_chain_id.rjust(1)  # Astring
        j5 = str(lig_res_id).rjust(4)  # resnum
        j6 = str('%8.3f' % (float(pos[0]))).rjust(8)  # x
        j7 = str('%8.3f' % (float(pos[1]))).rjust(8)  # y
        j8 = str('%8.3f' % (float(pos[2]))).rjust(8)  # z\
        j9 = str('%6.2f' % (1.00)).rjust(6)  # occ
        j10 = str('%6.2f' % (25.02)).ljust(6)  # temp
        j11 = str(atom_type).rjust(12)  # elname
        new_lines.append("%s%s %s %s %s%s    %s%s%s%s%s%s\n" % (j0, j1, j2, j3, j4, j5, j6, j7, j8, j9, j10, j11))
        lig_atom_id += 1
    
    with open(f'{redesign_af2_aligned_complex_path}/{af2_f}', 'w') as f:
        f.writelines(new_lines)
    f.close()

In [83]:
# Statistics for the AF2 predictions with the ligand re-placed on the fixed
# residue: backbone RMSD against the redesign, ligand clashes of the AF2
# model, and the deviation of the ligand-contacting residue.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn_constrain'

redesign_path = f'{path}/3.1_design_pocket_ligandMPNN'
# Defined here as well so the cell does not depend on an earlier cell for it.
redesign_af2_path = f'{path}/4_redesign_af2'
align_redesign_af2_complex_path = f'{path}/4_redesign_af2_aligned_complex'
redesign_list = os.listdir(redesign_path)
redesign_af2_list = os.listdir(align_redesign_af2_complex_path)

# Residue *number* of the residue whose placement is compared.
interact_res_id = 61

rmsd_res = {}
scores = pd.read_csv(f'{redesign_af2_path}/redesign_sequence.csv')
for f in redesign_list:
    if '.pdb' not in f:
        continue
    redesign_file = f'{redesign_path}/{f}'
    redesign_bb = np.array(extract_backbone_positions(redesign_file)).reshape(-1,3)

    pdb_id = f.replace('.pdb', '')
    # First complex file whose name contains the redesign's name.
    af2_f = [name for name in redesign_af2_list if pdb_id in name][0]
    complex_file = f'{align_redesign_af2_complex_path}/{af2_f}'

    # calculate rmsd
    af2_bb = np.array(extract_backbone_positions(complex_file)).reshape(-1,3)
    rmsd = calc_rmsd(af2_bb, redesign_bb)[0]

    # stat clash: ligand = HETATM lines of the complex whose third-from-last
    # character is not 'H'.
    with open(complex_file, 'r') as lig_f:
        ligand_pdb_block = lig_f.read()
    ligand_pdb_block = '\n'.join(
        l for l in ligand_pdb_block.split('\n')
        if 'HETATM' in l and l[-3] != 'H'
    )
    mol = Chem.MolFromPDBBlock(ligand_pdb_block, sanitize=False, removeHs=False)
    lig_pos = mol.GetConformer(0).GetPositions()
    frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    Fe_DPP = max(frags, key=lambda m: m.GetNumAtoms())
    DPP_pos = Fe_DPP.GetConformer(0).GetPositions()

    # One entry per residue of the AF2 model, hydrogens excluded.
    redesign_af2_residues = extract_backbone_positions(complex_file, bb=False)
    redesign_af2_pos = [np.array(p) for p in redesign_af2_residues]

    sequence = scores[scores['Name'] == af2_f[:af2_f.rfind('model')-1]]['Sequence'].item()
    distance = np.array([np.linalg.norm(p[:, np.newaxis] - DPP_pos, axis=2).min() for p in redesign_af2_pos])
    clash_num = (distance < 3.2).sum()
    # Labels are one-letter code + 1-based position in the sequence.
    DPP_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(distance < 3.2)[0]]

    distance = np.array([np.linalg.norm(p[:, np.newaxis] - lig_pos, axis=2).min() for p in redesign_af2_pos])
    lig_clash_indices = [sequence[idx] + str(idx + 1) for idx in np.where(distance < 3.2)[0]]

    # rmsd of HIS. The per-residue lists are 0-based, so residue number
    # `interact_res_id` sits at index `interact_res_id - 1` (the same
    # index -> number mapping as the clash labels above; assumes residues
    # are numbered consecutively from 1). Indexing with the residue number
    # itself compared the following residue instead.
    his_index = interact_res_id - 1
    redesign_residues = extract_backbone_positions(redesign_file, bb=False)
    redesign_H_pos = np.array(redesign_residues[his_index]).reshape(-1,3)
    af2_H_pos = np.array(redesign_af2_residues[his_index]).reshape(-1,3)
    # Direct coordinate deviation, without refitting.
    interact_rmsd = np.sqrt(np.sum((redesign_H_pos-af2_H_pos) ** 2, axis=(0,1)) / redesign_H_pos.shape[0])

    rmsd_res[pdb_id] = {
        'rmsd': rmsd,
        'H_rmsd': interact_rmsd,
        'clash_num': clash_num,
        'DPP_clash_indices': DPP_clash_indices,
        'ligand_clash_indices': lig_clash_indices,
        'ori_pdb': f,
        'af2_pdb': af2_f
    }

In [84]:
# Attach the value from the `lDDT` column of `scores` to the RMSD record that
# is stored under the matching `Name`.
for name, plddt in zip(scores['Name'], scores['lDDT']):
    rmsd_res[name]['plddt'] = plddt

# Keep only the records with plddt above 87 and rmsd below 1.5.
selected_res = {
    key: record
    for key, record in rmsd_res.items()
    if record['plddt'] > 87 and record['rmsd'] < 1.5
}

# The dict maps key -> record, so the frame is built column-wise and then
# transposed to get one selected record per row.
df = pd.DataFrame(selected_res).T
# df.to_csv(f'{redesign_af2_path}/aligned_ref_H_scores.csv')

In [46]:
# Destination tree for the selected records: a root folder plus one
# sub-folder per structure source.
new_path = f'{path}/selected_res_aligned_ref_H'
for folder in (new_path, f'{new_path}/redesign', f'{new_path}/redesign_af2'):
    os.makedirs(folder, exist_ok=True)

# Copy both structure files of every selected record into its sub-folder.
for record in selected_res.values():
    o, af2 = record['ori_pdb'], record['af2_pdb']
    shutil.copyfile(f'{redesign_path}/{o}', f'{new_path}/redesign/{o}')
    shutil.copyfile(f'{align_redesign_af2_complex_path}/{af2}', f'{new_path}/redesign_af2/{af2}')

### RMSD statistics: RFdiffusion backbones vs. AF2 predictions

In [3]:
# Run folder read by this cell: reference structures come from `0_diffusion`,
# the structures compared against them from `2_af2_aligned`.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn'
rf_gen_list = os.listdir(f'{path}/0_diffusion')
af2_gen_list = [name for name in os.listdir(f'{path}/2_af2_aligned') if '.pdb' in name]
min_res = {}

for f in rf_gen_list:
    if '.pdb' not in f:
        continue
    rf_pos = extract_backbone_positions(f'{path}/0_diffusion/{f}')
    pdb_id = f.split('.')[0]
    # Candidates: AF2 files whose name contains "<pdb_id>_" but not "native".
    af2_files = [
        name for name in af2_gen_list
        if 'native' not in name and pdb_id + '_' in name
    ]

    # 9999999 is a sentinel meaning "no candidate has been scored yet".
    min_rmsd = 9999999
    for af2_f in af2_files:
        af2_pos = extract_backbone_positions(f'{path}/2_af2_aligned/{af2_f}')
        af2_xyz = np.array(af2_pos).reshape(-1, 3)
        rf_xyz = np.array(rf_pos).reshape(-1, 3)
        rmsd = calc_rmsd(af2_xyz, rf_xyz)[0]
        if rmsd < min_rmsd:
            min_rmsd, min_file = rmsd, af2_f

    # Sentinel untouched: nothing was compared for this file, so no entry.
    if min_rmsd == 9999999:
        continue
    # Keyed by the last "_"-separated token of the file stem.
    min_res[pdb_id.split('_')[-1]] = dict(rmsd=min_rmsd, ori_pdb=f, af2_pdb=min_file)

In [6]:
# Copy every design whose best AF2 model has rmsd <= 2 (together with that
# AF2 model) into a flat output folder.
new_path = 'CP_SS_TS_243'
os.makedirs(new_path, exist_ok=True)

# Iterate `min_res`, the dict built from `{path}/0_diffusion` and
# `{path}/2_af2_aligned` -- the same folders the copies below read from.
# The previous version looped over `selected_res`, which is filled by an
# earlier, unrelated section whose file names live in other directories, so
# this cell only worked with leftover kernel state.
for v in min_res.values():
    rmsd = v['rmsd']
    o = v['ori_pdb']
    af2 = v['af2_pdb']
    if rmsd <= 2:
        shutil.copyfile(f'{path}/0_diffusion/{o}', f'{new_path}/{o}')
        shutil.copyfile(f'{path}/2_af2_aligned/{af2}', f'{new_path}/{af2}')

### Residues within 8 Å of the ligand

In [36]:
# Label the residues of an AF2 model that have at least one atom closer than
# 8 to the ligand coordinates (`lig_clash_indices`), and list every other
# position as 'A<index>' (`fix_indices`).
#
# NOTE(review): `lig_pos` and `sequence` are not assigned in this cell; they
# are whatever an earlier cell left in the kernel. Confirm they belong to the
# structure analysed here before trusting the result.
path = 'generated_result/CP_SS_TS_mask2_71_proteinmpnn'
rf_gen_list = os.listdir(f'{path}/0_diffusion')
af2_path = 'CP_SS_TS_mask2_71_proteinmpnn'
af2_gen_list = os.listdir(af2_path)
af2_gen_list = [f for f in af2_gen_list if 'model' in f]
# Loop-invariant: the candidate list does not depend on the file being
# processed below, so build it once instead of once per outer iteration.
# af2_files = [f for f in af2_gen_list if pdb_id + '_' in f and 'native' not in f]
af2_files = [name for name in af2_gen_list if '.pdb' in name]

for f in rf_gen_list:
    if '.pdb' not in f:
        continue
    rf_pos = extract_backbone_positions(f'{path}/0_diffusion/{f}')
    pdb_id = f.split('.')[0]
    for af2_f in af2_files:
        # Backbone pass: RMSD between this AF2 model and the current file.
        af2_pos = extract_backbone_positions(f'{af2_path}/{af2_f}')
        rmsd = calc_rmsd(np.array(af2_pos).reshape(-1,3), np.array(rf_pos).reshape(-1,3))[0]

        # Second pass with bb=False: one coordinate array per residue.
        af2_pos = extract_backbone_positions(f'{af2_path}/{af2_f}', bb=False)
        af2_pos = [np.array(p) for p in af2_pos]

        # Smallest atom-to-ligand distance for each residue.
        distance = [np.linalg.norm(p[:, np.newaxis] - lig_pos, axis=2).min() for p in af2_pos]
        clash_index = np.where(np.array(distance) < 8)[0]
        # Plain-int set: O(1) membership instead of scanning the NumPy array
        # once per residue.
        near_ligand = set(clash_index.tolist())
        lig_clash_indices = [sequence[idx] + str(idx + 1) for idx in clash_index]
        fix_indices = ['A' + str(idx + 1) for idx in range(len(sequence)) if idx not in near_ligand]

        # NOTE(review): only the first AF2 file is ever analysed, and since
        # `af2_files` is the same for every outer iteration the index lists
        # are recomputed for that same file each time -- confirm intent.
        break

In [38]:
# Display the AF2 file name currently bound to `af2_f`.
af2_f

'0_diffusion_71_T0.3_0_236_model_4.0_r3_af2.pdb'

In [31]:
# Show the labels collected in `lig_clash_indices` as one space-separated string.
' '.join(lig_clash_indices)

'V17 L21 L26 L27 L32 L38 L39 L42 V43 T45 F46 L47 A49 T50 K58 H61 R62 A63 L64 T65 K66 K67 L68 N69 A70 L72 L78 L86 E87 A88 V89 V90 A91 A92 A93 N94 A95 A97 P98 D99 D100 P101 E102 L104 I105 A106 A107 L108 Y109 D110 G111 L112 A113 L117 I118 A123 L126 L130 I134'

In [37]:
# Show the labels collected in `fix_indices` as one space-separated string.
' '.join(fix_indices)

'A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A15 A16 A18 A19 A20 A22 A23 A24 A25 A28 A29 A30 A31 A33 A34 A36 A37 A52 A55 A56 A59 A60 A63 A71 A73 A74 A75 A76 A77 A79 A80 A81 A82 A83 A84 A85 A87 A88 A95 A96 A99 A114 A115 A116 A119 A120 A121 A122 A124 A125 A128 A129 A131 A132 A133 A135 A136 A137 A138 A139'