In [8]:
import traceback

import pandas as pd
from pathlib import Path

from tqdm import tqdm

from Bio.SVDSuperimposer import SVDSuperimposer
from Bio.PDB import PDBParser
from Bio.Align import PairwiseAligner
from Bio.SeqUtils import seq1
import numpy as np
import warnings

# Atom names selected per purpose in get_plddts: "backbone" atoms are superimposed
# for the RMSD, "ca" atoms are the ones whose B-factors are averaged as pLDDT.
atoms_list = dict(
    backbone=["N", "CA", "C", "O"],
    ca=["CA"],
)

def get_plddt_cat(plddt: float) -> str:
    """Return the confidence label for an average pLDDT score.

    Bands (upper bounds exclusive): below 70 -> 'bad', below 90 -> 'good',
    otherwise 'high'.
    NOTE(review): a NaN input matches no band and is labelled 'high'.
    """
    for upper_bound, label in ((70, 'bad'), (90, 'good')):
        if plddt < upper_bound:
            return label
    return 'high'

def get_rmsd_cat(rmsd: float) -> str:
    """Return the accuracy label for an RMSD value.

    Above 5.0 -> 'bad'; from 2.0 up to and including 5.0 -> 'good'; below 2.0 -> 'high'.
    NOTE(review): a NaN input fails both comparisons and is labelled 'high'.
    """
    return 'bad' if rmsd > 5.0 else ('good' if rmsd >= 2.0 else 'high')

In [9]:
def _parse_residue_token(token: str) -> tuple:
    """Split a residue-number token such as ``'52'`` or ``'52A'`` into ``(number, insertion_code)``.

    A trailing letter is taken as the insertion code; without one the insertion code
    is ``' '``, matching the third field of a Biopython residue id.
    """
    token = token.strip()
    if token and token[-1].isalpha():
        return int(token[:-1]), token[-1]
    return int(token), ' '


def _read_constraint_residues(path: Path, part_index=None) -> list:
    """Read ``(chain_id, residue_key, resname)`` triples from a constraint file.

    The first line is a header and is skipped, as are blank lines. If ``part_index`` is
    given, each line is first split on ``':'`` and only that side is parsed (0 is the
    antibody side, 1 the antigen side of a ``*_constraint_pairs.txt`` line). The parsed
    text must start with ``chain,resnum,resname``; further columns are ignored.
    ``residue_key`` is the Biopython residue id ``(' ', number, insertion_code)``.
    """
    triples = []
    with open(path) as file:
        file.readline()  # header
        for line in file:
            if not line.strip():
                continue
            if part_index is not None:
                line = line.split(':')[part_index]
            words = [word.strip() for word in line.strip().split(',')]
            number, insert_code = _parse_residue_token(words[1])
            triples.append((words[0], (' ', number, insert_code), words[2]))
    return triples


def _check_resname(chains: dict, chain_id: str, residue_key: tuple, constraint_resname: str,
                   pdb_code: str, label: str) -> None:
    """Raise ValueError unless ``chains[chain_id][residue_key]`` has the constraint's residue name.

    The constraint name is upper-cased before comparing; ``label`` ('native' or 'model')
    only appears in the error message. A missing chain or residue raises KeyError.
    """
    actual_resname = chains[chain_id][residue_key].resname
    if actual_resname != constraint_resname.upper():
        number, insert_code = residue_key[1], residue_key[2].strip()
        raise ValueError(f"For {pdb_code=}, chain id {chain_id} residue {number}{insert_code}, "
                         f"got mismatching residue to constraint, {label}_resname={actual_resname!r}, "
                         f"constraint={constraint_resname}")


def _paired_backbone_coords(native_residues: list, model_residues: list, check_ids: bool) -> tuple:
    """Return equal-length ``(native_coords, model_coords)`` arrays of matched backbone atoms.

    Residues are paired by position. Within each pair only the atom names from
    ``atoms_list['backbone']`` present in *both* residues are used, in that fixed order,
    so atoms are matched by name rather than by their order in the PDB file.
    With ``check_ids`` the paired residues must also have identical residue ids.

    Raises ValueError if the residue counts differ, paired ids differ (with
    ``check_ids``), or no backbone atom could be paired.
    """
    if len(native_residues) != len(model_residues):
        raise ValueError(f"region has {len(native_residues)} native residues but "
                         f"{len(model_residues)} model residues")
    native_coords, model_coords = [], []
    for native_res, model_res in zip(native_residues, model_residues):
        if check_ids and native_res.id != model_res.id:
            raise ValueError(f"native residue {native_res.id} paired with model residue {model_res.id}")
        for atom_name in atoms_list['backbone']:
            if atom_name in native_res and atom_name in model_res:
                native_coords.append(native_res[atom_name].coord)
                model_coords.append(model_res[atom_name].coord)
    if not native_coords:
        raise ValueError("no backbone atoms found in region")
    return np.array(native_coords), np.array(model_coords)


def get_plddts(pdb_code: str, benchmark_folder: Path, region_type: str, model_type: str):
    """Compare an AF2 model with its native structure over one region of a benchmark case.

    Parameters
    ----------
    pdb_code : str
        Case name; its files are read from ``benchmark_folder / pdb_code``.
    benchmark_folder : Path
        Root folder of the benchmark.
    region_type : str
        ``'CDR-EpiVague'``: residues listed in the ``*_residue_constraints_*.csv`` files.
        ``'Para-Epi'``: residues on this molecule's side of ``*_constraint_pairs.txt``.
        ``'full'``: for antibodies, every native residue except PCA; for antigens,
        the residues aligned by sequence between each native/model chain pair.
    model_type : str
        ``'antibody'`` or ``'antigen'``.

    Returns
    -------
    dict
        ``{'pdb': pdb_code, 'rmsd_region': ..., 'plddt_ave_region': ...}``.
        The RMSD comes from ``SVDSuperimposer`` over name-matched atoms listed in
        ``atoms_list['backbone']``. The pLDDT is the mean B-factor of the model's
        ``atoms_list['ca']`` atoms in the region.

    Raises
    ------
    ValueError
        Raised for any of these cases:
        - an invalid ``region_type`` or ``model_type``;
        - antibody chain ids that differ between native and model;
        - a residue name that disagrees with a constraint file;
        - a failure while collecting or pairing the region's atoms.
        Missing files surface as ``FileNotFoundError``. A constraint residue absent
        from the structure it is checked against raises ``KeyError``.
    """
    allowed_model_types = {"antibody", "antigen"}
    if model_type not in allowed_model_types:
        raise ValueError(f"Unrecognised {model_type=}, must be one of {allowed_model_types}")

    allowed_region_types = {'CDR-EpiVague', 'Para-Epi', 'full'}
    if region_type not in allowed_region_types:
        raise ValueError(f"Unrecognised {region_type=}, must be one of {allowed_region_types}")

    parser = PDBParser()
    case_folder = benchmark_folder / pdb_code

    if model_type == 'antibody':
        native = case_folder / f'{pdb_code}_true_complex.pdb'
        model = case_folder / f'AF2_{pdb_code}_antibody_model_imgt.pdb'
        model_chains = {chain.id: chain for chain in parser.get_structure('model', model).get_chains()}
        # keep only the native chains that also exist in the model
        native_chains = {chain.id: chain for chain in parser.get_structure('native', native).get_chains()
                         if chain.id in model_chains}
        if set(native_chains) != set(model_chains):
            raise ValueError("Model chain ids not equal to native chain ids.")
        # Native and model share chain ids (checked above), and the constraint parsing below
        # uses one residue key for both, so chains correspond by id. Pairing by file order
        # with zip() would mis-pair chains whenever the two files list them differently.
        chain_id_mappings = {chain_id: chain_id for chain_id in native_chains}
    else:
        native = case_folder / f'{pdb_code}_antigen.pdb'
        model = case_folder / f'{pdb_code}_AF2_{model_type}_model.pdb'
        native_chains = {chain.id: chain for chain in parser.get_structure('native', native).get_chains()}
        model_chains = {chain.id: chain for chain in parser.get_structure('model', model).get_chains()}
        if len(native_chains) != len(model_chains):
            warnings.warn(f"For {pdb_code}, native chains {list(native_chains)} and model chains "
                          f"{list(model_chains)} differ in number; unmatched chains are ignored")
        # NOTE(review): antigen chains are paired by their order in the two files — confirm
        # both files list the chains consistently.
        chain_id_mappings = dict(zip(native_chains, model_chains))

    region_native_def_nums = {chain_id: set() for chain_id in native_chains}
    region_model_def_nums = {chain_id: set() for chain_id in model_chains}

    if model_type == 'antibody':
        if region_type == 'full':
            for chain_id, native_chain in native_chains.items():
                for residue in native_chain.get_residues():
                    # PCA residues are excluded — presumably because the model carries a
                    # standard residue with a different id there; TODO confirm.
                    if residue.resname != 'PCA':
                        region_native_def_nums[chain_id].add(residue.id)
                        region_model_def_nums[chain_id].add(residue.id)
        else:
            if region_type == 'CDR-EpiVague':
                constraints = _read_constraint_residues(
                    case_folder / f'{pdb_code}_residue_constraints_antibody.csv')
            else:
                constraints = _read_constraint_residues(
                    case_folder / f'{pdb_code}_constraint_pairs.txt', part_index=0)
            for chain_id, residue_key, resname in constraints:
                _check_resname(model_chains, chain_id, residue_key, resname, pdb_code, 'model')
                if residue_key not in native_chains[chain_id].child_dict:
                    warnings.warn(f'For {pdb_code} and {region_type=} chain {chain_id} {residue_key} '
                                  f'not found in native model')
                    continue
                region_native_def_nums[chain_id].add(residue_key)
                region_model_def_nums[chain_id].add(residue_key)
    elif region_type == 'full':
        # NOTE(review): default PairwiseAligner scoring (match 1, mismatch 0, gaps 0) admits many
        # co-optimal alignments and only the first is used; consider explicit gap penalties.
        aligner = PairwiseAligner()
        for native_chain_id, model_chain_id in chain_id_mappings.items():
            native_residues = list(native_chains[native_chain_id].get_residues())
            model_residues = list(model_chains[model_chain_id].get_residues())
            seq_native = "".join(seq1(res.resname) for res in native_residues)
            seq_model = "".join(seq1(res.resname) for res in model_residues)
            alignment = aligner.align(seq_native, seq_model)[0]
            # aligned[0] / aligned[1] hold corresponding (start, end) blocks of equal length
            for (nat_start, nat_end), (mod_start, mod_end) in zip(*alignment.aligned):
                for nat_res, model_res in zip(native_residues[nat_start:nat_end],
                                              model_residues[mod_start:mod_end]):
                    region_native_def_nums[native_chain_id].add(nat_res.id)
                    region_model_def_nums[model_chain_id].add(model_res.id)
    else:
        # native and model have separate constraint files because their numbering may differ
        if region_type == 'CDR-EpiVague':
            native_file = f'{pdb_code}_residue_constraints_{model_type}.csv'
            model_file = f'{pdb_code}_af2_residue_constraints_{model_type}.csv'
            part_index = None
        else:
            native_file = f'{pdb_code}_constraint_pairs.txt'
            model_file = f'{pdb_code}_AF2_constraint_pairs.txt'
            part_index = 1
        for chains, region_nums, file_name, label in (
                (native_chains, region_native_def_nums, native_file, 'native'),
                (model_chains, region_model_def_nums, model_file, 'model')):
            for chain_id, residue_key, resname in _read_constraint_residues(case_folder / file_name, part_index):
                _check_resname(chains, chain_id, residue_key, resname, pdb_code, label)
                region_nums[chain_id].add(residue_key)

    try:
        native_region_res = [res for chain_id in chain_id_mappings
                             for res in native_chains[chain_id].get_residues()
                             if res.id in region_native_def_nums[chain_id]]
        model_region_res = [res for chain_id in chain_id_mappings.values()
                            for res in model_chains[chain_id].get_residues()
                            if res.id in region_model_def_nums[chain_id]]
        native_coords, model_coords = _paired_backbone_coords(
            native_region_res, model_region_res, check_ids=(model_type == 'antibody'))
    except Exception as e:
        print(traceback.format_exc())
        raise ValueError(f"For {pdb_code=} got error {e}") from e

    svd = SVDSuperimposer()
    svd.set(native_coords, model_coords)
    svd.run()
    rmsd_region = svd.get_rms()

    # pLDDT is read from the B-factor field of the model's CA atoms
    model_region_atom_plddt = [atom.bfactor for res in model_region_res for atom in res
                               if atom.get_id() in atoms_list['ca']]
    region_ave_plddt = np.mean(model_region_atom_plddt)

    return {'pdb': pdb_code, 'rmsd_region': rmsd_region, 'plddt_ave_region': region_ave_plddt}

In [10]:
# Score each benchmark case's AF2 antibody model over three regions and tabulate the
# results. Cases that fail (missing files, constraint mismatches, ...) are reported
# and skipped so that one bad case does not stop the run.
benchmark_folder = Path('../../benchmark_haddock_27_July_2024')
# output-column suffix for each region type, in output-column order
REGION_SUFFIXES = {'full': 'full', 'CDR-EpiVague': 'vague', 'Para-Epi': 'para_epi'}

records = []
# sorted() makes the row order independent of the filesystem's directory listing order
case_dirs = sorted(path for path in benchmark_folder.iterdir() if path.is_dir())
for case_dir in tqdm(case_dirs):
    pdb_code = case_dir.name
    try:
        # regions are evaluated in the same order as before: Para-Epi, CDR-EpiVague, full
        region_rows = {region_type: get_plddts(pdb_code, benchmark_folder, region_type=region_type,
                                               model_type='antibody')
                       for region_type in ('Para-Epi', 'CDR-EpiVague', 'full')}
    except Exception as e:
        print(f"Got error {e} for {pdb_code=}.")
        continue
    record = {'pdb': pdb_code}
    for region_type, suffix in REGION_SUFFIXES.items():
        rmsd = region_rows[region_type]['rmsd_region']
        plddt = region_rows[region_type]['plddt_ave_region']
        record[f'rmsd_{suffix}'] = rmsd
        record[f'plddt_ave_{suffix}'] = plddt
        record[f'rmsd_{suffix}_cat'] = get_rmsd_cat(rmsd)
        record[f'plddt_{suffix}_cat'] = get_plddt_cat(plddt)
    records.append(record)
df_antibody = pd.DataFrame.from_records(records)

100%|██████████| 84/84 [00:22<00:00,  3.80it/s]


In [11]:
# Persist the antibody table; the frame's default integer index is written as the first column.
df_antibody.to_csv(path_or_buf='../data/AF2_antibody_rmsd_plddt_multi_regions.csv')

In [12]:
# Score each benchmark case's AF2 antigen model over three regions and tabulate the
# results. Cases that fail (missing files, constraint mismatches, ...) are reported
# and skipped so that one bad case does not stop the run.
benchmark_folder = Path('../../benchmark_haddock_27_July_2024')
# output-column suffix for each region type, in output-column order
REGION_SUFFIXES = {'full': 'full', 'CDR-EpiVague': 'vague', 'Para-Epi': 'para_epi'}

records = []
# sorted() makes the row order independent of the filesystem's directory listing order
case_dirs = sorted(path for path in benchmark_folder.iterdir() if path.is_dir())
for case_dir in tqdm(case_dirs):
    pdb_code = case_dir.name
    try:
        # regions are evaluated in the same order as before: Para-Epi, CDR-EpiVague, full
        region_rows = {region_type: get_plddts(pdb_code, benchmark_folder, region_type=region_type,
                                               model_type='antigen')
                       for region_type in ('Para-Epi', 'CDR-EpiVague', 'full')}
    except Exception as e:
        print(f"Got error {e} for {pdb_code=}.")
        continue
    record = {'pdb': pdb_code}
    for region_type, suffix in REGION_SUFFIXES.items():
        rmsd = region_rows[region_type]['rmsd_region']
        plddt = region_rows[region_type]['plddt_ave_region']
        record[f'rmsd_{suffix}'] = rmsd
        record[f'plddt_ave_{suffix}'] = plddt
        record[f'rmsd_{suffix}_cat'] = get_rmsd_cat(rmsd)
        record[f'plddt_{suffix}_cat'] = get_plddt_cat(plddt)
    records.append(record)
df_antigen = pd.DataFrame.from_records(records)

100%|██████████| 84/84 [00:17<00:00,  4.72it/s]


In [13]:
# Persist the antigen table; the frame's default integer index is written as the first column.
df_antigen.to_csv(path_or_buf='../data/AF2_antigen_rmsd_plddt_multi_regions.csv')