# 2025.05.21 전처리

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

# Load the dataset table and preview its first row.
data = pd.read_csv(Path('data') / 'AF_data_info.csv')
data.head(n=1)

## Read PDB File

In [2]:
def read_pdb(pdb_path):
    """Parse the ATOM/HETATM records of a PDB file into a DataFrame.

    Fields are sliced from the fixed PDB columns: record name, serial, atom
    name, altLoc, residue name, chain, residue number, insertion code, x/y/z,
    occupancy, B-factor, element and charge. Rows whose element is 'H' are
    dropped.

    Args:
        pdb_path: path to the PDB file (str or Path).

    Returns:
        DataFrame with columns record, atom_serial, atom_name, alt_loc,
        res_name, chain_id, res_seq, i_code, x, y, z, occupancy, b_factor,
        element, charge and a fresh 0..n-1 index. A file without ATOM/HETATM
        records yields an empty frame with these columns.
    """
    columns = [
        'record', 'atom_serial', 'atom_name', 'alt_loc', 'res_name',
        'chain_id', 'res_seq', 'i_code', 'x', 'y', 'z', 'occupancy',
        'b_factor', 'element', 'charge',
    ]
    records = []
    with open(pdb_path, 'r') as f:
        for line in f:
            if not line.startswith(('ATOM', 'HETATM')):
                continue
            records.append({
                'record': line[0:6].strip(),
                'atom_serial': int(line[6:11]),
                'atom_name': line[12:16].strip(),
                'alt_loc': line[16],
                'res_name': line[17:20].strip(),
                'chain_id': line[21],
                'res_seq': int(line[22:26]),
                'i_code': line[26],
                'x': float(line[30:38]),
                'y': float(line[38:46]),
                'z': float(line[46:54]),
                'occupancy': float(line[54:60]),
                'b_factor': float(line[60:66]),
                'element': line[76:78].strip(),
                'charge': line[78:80].strip(),
            })
    # Explicit columns keep the schema even when no records were found.
    table = pd.DataFrame(records, columns=columns)
    table = table[table['element'] != 'H']
    # Reset the index, as ligand_reader/to_pdb_format do. sequence_reader takes
    # an index *label* and uses it as a positional slice bound, which is only
    # correct when labels equal positions (not the case after dropping H rows).
    return table.reset_index(drop=True)

## Read Mol2 into PDB-style atom tables

In [3]:
def read_mol2(mol2_path, ligand=False):
    """Read a Tripos mol2 file into a PDB-style atom table.

    Args:
        mol2_path: path to the mol2 file.
        ligand: if True, only the ATOM section is parsed, with
            ``ligand_reader``. Otherwise the BOND and SUBSTRUCTURE sections are
            sliced as well and passed to ``to_pdb_format``, which assigns
            chain ids.

    Returns:
        DataFrame produced by ``ligand_reader`` or ``to_pdb_format``.

    Raises:
        ValueError: if a required section (ATOM, BOND and, for ``ligand=False``,
            SUBSTRUCTURE) is missing.
    """
    with open(mol2_path) as f:
        lines = f.readlines()

    def section_start(tag, default=None):
        # Index of the first line starting with ``tag``; ``default`` if absent,
        # or a descriptive error when the section is required.
        idx = next((i for i, l in enumerate(lines) if l.startswith(tag)), default)
        if idx is None:
            raise ValueError(f'{mol2_path}: missing {tag} section')
        return idx

    atom_start = section_start('@<TRIPOS>ATOM')
    bond_start = section_start('@<TRIPOS>BOND')
    atom_lines = lines[atom_start + 1:bond_start]
    if ligand:
        return ligand_reader(atom_lines)

    chain_start = section_start('@<TRIPOS>SUBSTRUCTURE')
    # The SET section is optional in mol2; without it SUBSTRUCTURE runs to EOF.
    set_start = section_start('@<TRIPOS>SET', default=len(lines))

    bond_lines = lines[bond_start + 1:chain_start]
    chain_lines = lines[chain_start + 1:set_start]
    return to_pdb_format(atom_lines, bond_lines, chain_lines)


def ligand_reader(atom):
    """Convert mol2 ATOM-section lines into a PDB-style atom table.

    Each line is split on whitespace as ``atom_id name x y z atom_type
    subst_id subst_name [charge ...]``. Every row is marked 'ATOM' with blank
    alt_loc/chain_id/i_code, occupancy 1.0 and B-factor 0.0. The residue name
    is the first three characters of ``subst_name``. ``charge`` is the ninth
    field kept as a string when it looks numeric, else ''.

    The element is the part of the Tripos atom type before the first '.',
    upper-cased ('C.ar' -> 'C', 'Cl' -> 'CL'). Two-letter elements are
    therefore not truncated to their first letter, which would turn chlorine
    into carbon or mercury into hydrogen. Hydrogens ('H') are dropped and
    blank lines are skipped.

    Args:
        atom: iterable of ATOM-section lines.

    Returns:
        DataFrame with columns record, atom_serial, atom_name, alt_loc,
        res_name, chain_id, res_seq, i_code, x, y, z, atom_type, occupancy,
        b_factor, element, charge and index 0..n-1; empty (same columns) when
        there are no atoms.
    """
    columns = [
        'record', 'atom_serial', 'atom_name', 'alt_loc', 'res_name',
        'chain_id', 'res_seq', 'i_code', 'x', 'y', 'z', 'atom_type',
        'occupancy', 'b_factor', 'element', 'charge',
    ]
    records = []
    for line in atom:
        fields = line.split()
        if not fields:
            continue
        atom_type = fields[5]
        x, y, z = map(float, fields[2:5])
        charge = fields[8] if len(fields) > 8 and fields[8].replace('.', '', 1).replace('-', '', 1).isdigit() else ''
        records.append({
            'record': 'ATOM',
            'atom_serial': int(fields[0]),
            'atom_name': fields[1],
            'alt_loc': '',
            'res_name': fields[7][:3],
            'chain_id': '',
            'res_seq': int(fields[6]),
            'i_code': '',
            'x': x,
            'y': y,
            'z': z,
            'atom_type': atom_type,
            'occupancy': 1.0,
            'b_factor': 0.0,
            'element': atom_type.split('.')[0].upper(),
            'charge': charge,
        })
    table = pd.DataFrame(records, columns=columns)
    table = table[table['element'] != 'H']
    return table.reset_index(drop=True)


def parse_substructure(lines):
    """Parse mol2 @<TRIPOS>SUBSTRUCTURE data lines into a DataFrame.

    A data line is ``subst_id subst_name root_atom [subst_type [dict_type
    [chain [sub_type [inter_bonds [status ...]]]]]]``. The first three fields
    are required. Missing optional fields become None, and fields past the
    ninth (e.g. comments) are ignored. Blank lines and lines starting with '#'
    or '@' are skipped.

    Args:
        lines: iterable of SUBSTRUCTURE-section lines.

    Returns:
        DataFrame with columns id, name, root_atom, type, dict_type, chain,
        residue, unk, status (one row per substructure). When there are no
        data lines it is empty but still has these columns.
    """
    columns = ['id', 'name', 'root_atom', 'type', 'dict_type',
               'chain', 'residue', 'unk', 'status']
    substructures = []
    for line in lines:
        parts = line.split()
        if not parts or parts[0].startswith(('#', '@')):
            continue
        sub = {
            'id': int(parts[0]),
            'name': parts[1],
            'root_atom': int(parts[2]),
        }
        # Optional fields occupy token positions 3..8.
        for pos, key in enumerate(columns[3:], start=3):
            sub[key] = parts[pos] if pos < len(parts) else None
        substructures.append(sub)
    return pd.DataFrame(substructures, columns=columns)


def to_pdb_format(atom, bond, chain):
    """Convert mol2 ATOM lines into a PDB-style atom table with chain ids.

    Atoms are grouped into residues by runs of consecutive lines with the same
    ``subst_id`` (7th field). Every atom of a run gets the chain of the
    SUBSTRUCTURE entry whose root atom falls in that run. A run without a root
    atom inherits the chain of the most recent root atom seen.

    Args:
        atom: lines of the @<TRIPOS>ATOM section.
        bond: lines of the @<TRIPOS>BOND section (accepted for interface
            compatibility; not used).
        chain: lines of the @<TRIPOS>SUBSTRUCTURE section, parsed with
            ``parse_substructure``.

    Returns:
        DataFrame with columns record, atom_serial, atom_name, alt_loc,
        res_name, chain_id, res_seq, i_code, x, y, z, atom_type, occupancy,
        b_factor, element, charge, with hydrogens removed and index 0..n-1.
        Atoms seen before any root atom have no chain id (None/NaN).
    """
    columns = [
        'record', 'atom_serial', 'atom_name', 'alt_loc', 'res_name',
        'chain_id', 'res_seq', 'i_code', 'x', 'y', 'z', 'atom_type',
        'occupancy', 'b_factor', 'element', 'charge',
    ]
    chains = parse_substructure(chain)
    # Root atom id -> chain id. setdefault keeps the first SUBSTRUCTURE row for
    # a root atom; a dict lookup replaces a full-column scan for every atom.
    root_to_chain = {}
    for root, chain_id in zip(chains.get('root_atom', []), chains.get('chain', [])):
        root_to_chain.setdefault(int(root), chain_id)

    records = []
    box = {}            # atom id -> chain id, filled one residue run at a time
    run_atoms = []      # atom ids of the residue run currently being read
    prev_resid = None
    current_chain = None
    for line in atom:
        fields = line.split()
        if not fields:  # tolerate blank lines inside the ATOM section
            continue
        atom_id = int(fields[0])
        resid = int(fields[6])

        # A new subst_id closes the previous run: its atoms take the chain of
        # the last root atom seen, i.e. the root that lies inside that run.
        if run_atoms and resid != prev_resid:
            box.update(dict.fromkeys(run_atoms, current_chain))
            run_atoms = []
        run_atoms.append(atom_id)
        prev_resid = resid
        current_chain = root_to_chain.get(atom_id, current_chain)

        x, y, z = map(float, fields[2:5])
        atom_type = fields[5]
        charge = fields[8] if len(fields) > 8 and fields[8].replace('.', '', 1).replace('-', '', 1).isdigit() else ''
        records.append({
            'record': 'ATOM',
            'atom_serial': atom_id,
            'atom_name': fields[1],
            'alt_loc': '',
            'res_name': fields[7][:3],
            'chain_id': None,  # assigned from `box` below
            'res_seq': resid,
            'i_code': '',
            'x': x,
            'y': y,
            'z': z,
            'atom_type': atom_type,
            'occupancy': 1.0,
            'b_factor': 0.0,
            # Atom-type prefix before '.', upper-cased, so two-letter elements
            # survive ('Cl' -> 'CL', 'Ca' -> 'CA', 'C.ar' -> 'C').
            'element': atom_type.split('.')[0].upper(),
            'charge': charge,
        })
    # Close the final run as well; otherwise its atoms get no chain id.
    box.update(dict.fromkeys(run_atoms, current_chain))

    table = pd.DataFrame(records, columns=columns)
    table = table[table['element'] != 'H'].copy()
    table['chain_id'] = table['atom_serial'].map(box)
    return table.reset_index(drop=True)

## Parsing Pocket

In [4]:
# Standard three-letter -> one-letter amino-acid codes.
amino_acids = dict(
    ALA='A', ARG='R', ASN='N', ASP='D',
    CYS='C', GLU='E', GLN='Q', GLY='G',
    HIS='H', ILE='I', LEU='L', LYS='K',
    MET='M', PHE='F', PRO='P', SER='S',
    THR='T', TRP='W', TYR='Y', VAL='V',
)

# One-letter -> three-letter codes (point_mutation uses these as mutagenesis targets).
rev_amino_acids = {one: three for three, one in amino_acids.items()}

# Modified or alternatively protonated residue names -> one-letter code of the
# parent amino acid. NOTE: read_residue does not consult this table; it maps
# every name missing from `amino_acids` to 'X'.
modified_residue_map = dict(
    MSE="M",  # selenomethionine -> methionine
    SEP="S",  # phosphoserine -> serine
    TPO="T",  # phosphothreonine -> threonine
    PTR="Y",  # phosphotyrosine -> tyrosine
    HYP="P",  # hydroxyproline -> proline
    KCX="K",  # carboxylysine -> lysine
    CSO="C",  # oxidized cysteine -> cysteine
    CGU="E",  # gamma-carboxy-glutamate -> glutamic acid
    F2F="F",  # fluorophenylalanine -> phenylalanine
    ASH="D",  # protonated aspartic acid -> aspartic acid
    GLH="E",  # protonated glutamic acid -> glutamic acid
    CYX="C",  # disulfide-bonded cysteine -> cysteine
    HID="H",  # neutral histidine (delta-protonated)
    HIE="H",  # neutral histidine (epsilon-protonated)
    HIP="H",  # positively charged histidine
)
def parse_residue_and_align(lst, chain=True):
    """Label residue numbers so that repeated numbers stay distinguishable.

    Consecutive equal numbers are treated as one residue. The k-th separate run
    of a number is labelled ``"<num>"`` when k == 1 and ``"<num>-<k>"``
    otherwise.

    Args:
        lst: with ``chain=True``, an iterable of ``(res_seq, chain_id)`` pairs;
            otherwise an iterable of ``res_seq`` values.
        chain: selects the input/output form (see Returns).

    Returns:
        ``chain=True``: dict ``label -> chain_id`` taken from the first row of
        each run. ``chain=False``: list with one label per input element.
    """
    occurrences = {}

    def next_label(number):
        occurrences[number] = occurrences.get(number, 0) + 1
        k = occurrences[number]
        return f"{number}" if k == 1 else f"{number}-{k}"

    previous = None
    if not chain:
        labels = []
        current = None
        for number in lst:
            if number != previous:
                current = next_label(number)
                previous = number
            labels.append(current)
        return labels

    label_to_chain = {}
    for number, chain_id in lst:
        if number == previous:
            continue
        label_to_chain[next_label(number)] = chain_id
        previous = number
    return label_to_chain


def read_residue(res):
    """Build a one-letter sequence from per-atom residue tuples.

    Args:
        res: iterable of ``(res_name, chain_id, res_seq, i_code)`` rows, one per
            atom. Waters ('HOH') are ignored. Each distinct
            ``(chain_id, res_seq, i_code)`` key contributes one letter, at its
            first occurrence.

    Returns:
        ``(sequence, key_to_index, index_to_key)``. Residue names missing from
        ``amino_acids`` become 'X'; the two dicts map residue keys to sequence
        positions and back.
    """
    letters = []
    key_to_index = {}
    for res_name, chain_id, res_seq, i_code in res:
        if res_name == 'HOH':
            continue
        key = (chain_id, res_seq, i_code)
        if key in key_to_index:
            continue  # later atoms of a residue that is already recorded
        key_to_index[key] = len(letters)
        letters.append(amino_acids.get(res_name, 'X'))
    index_to_key = {pos: key for key, pos in key_to_index.items()}
    return ''.join(letters), key_to_index, index_to_key

## PyMOL Modules

In [5]:

import os
import pymol
from pymol import cmd
from pathlib import Path
from tqdm import tqdm

# pymol.finish_launching() # if you see the work in pymol, you need to run this


def align_and_save(apo_path, holo_path, output_path):
    """Superimpose the apo structure onto the holo structure and save it.

    Both files are loaded into a fresh PyMOL session as objects 'apo' and
    'holo'. 'apo' is moved onto 'holo' with ``cmd.super``, and the moved
    'apo' object is written to ``output_path``.

    Returns:
        The first element of ``cmd.super``'s result (the RMSD).
    """
    cmd.reinitialize()  # start from an empty PyMOL session
    for obj_name, path in (('apo', apo_path), ('holo', holo_path)):
        cmd.load(str(path), obj_name)
    super_result = cmd.super('apo', 'holo')
    cmd.save(str(output_path), 'apo')
    return super_result[0]


def point_mutation(apo_path, output_path, mutations):
    """Apply point mutations to a structure with PyMOL's mutagenesis wizard.

    Args:
        apo_path: structure file to mutate (str or Path).
        output_path: where the mutated 'apo' object is saved (str or Path).
        mutations: iterable of ``(chain, resid, apo_res, holo_res)`` rows with
            one-letter residue codes. ``holo_res`` is the target residue, and
            rows whose target is 'X' (unknown residue) are skipped.
    """
    cmd.reinitialize()  # Clear previous structures
    # Convert paths with str(), as align_and_save does; callers pass pathlib.Path.
    cmd.load(str(apo_path), 'apo')
    cmd.wizard('mutagenesis')
    try:
        for row in mutations:
            chain, resid, ap_res, ho_res = map(str, row)
            if ho_res == 'X':
                continue
            # PyMOL residue selection: /object//chain/resi
            selection = f'/apo//{chain}/{resid}/'
            print(f'selcction: {selection}')
            print(f'apo: {ap_res} -> holo: {ho_res}')
            wizard = cmd.get_wizard()
            wizard.set_mode(rev_amino_acids[ho_res])  # three-letter target name
            wizard.do_select(selection)
            wizard.apply()
    finally:
        # Always dismiss the wizard, even when a mutation fails.
        cmd.set_wizard()
    cmd.save(str(output_path), 'apo')

## APO Align

In [6]:
import pickle
from pathlib import Path
import logging

# Logging setup: messages go to processing.log (truncated on each run) and to the console.
out = Path('data/raw/af_align')
out.mkdir(exist_ok=True, parents=True)
log_path = out.joinpath("processing.log")
logging.basicConfig(
    handlers=[
        logging.FileHandler(log_path, mode='w'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO,
)

In [None]:
# Show the first row of the dataset table.
data.head(1)

In [None]:
# Tab-separated table of pocket residue ids per PDB entry (columns include PDB, Residue).
pocket_idx = pd.read_table('data/raw/pocket_idx.tsv')
pocket_idx.head(n=2)

In [None]:
# PDB ids skipped by the processing loops because no matching pocket was found.
no_pdbbind = (
    '1mh5 1d6v 1ct8 1i7z 1a4k 1a0q 1i6v 1c12 2qhr 4nyi '
    '4nyj 4nym 2hrp 3eql 4kmu 4kn4 4kn7 2mpa 1kcs 1zyr'
).split()
no_csar = '1gz9 2v7t 2j4k 2v7u 1swk 2vhw 1gzc 1bcj 2hr6'.split()
no_match = [*no_pdbbind, *no_csar]
print(len(no_match))  # number of entries without a matching pocket

In [None]:
## 72mins

import pandas as pd
from Bio import pairwise2
from Bio.pairwise2 import format_alignment

def sequence_reader(df):
    """Cut a structure table at its last ATOM row and read its residue sequence.

    Args:
        df: atom table with record, res_name, chain_id, res_seq and i_code
            columns (e.g. from read_pdb, read_mol2 or find_atom).

    Returns:
        ``(main_structure, sequence, key_to_index, index_to_key)``.
        ``main_structure`` is every row up to and including the last row whose
        record is 'ATOM', so trailing rows such as HETATM waters are dropped.
        The other three values are ``read_residue`` of that slice.

    Raises:
        IndexError: if ``df`` has no 'ATOM' rows.
    """
    atom_positions = np.flatnonzero((df['record'] == 'ATOM').to_numpy())
    if atom_positions.size == 0:
        raise IndexError("structure has no ATOM records")
    # Slice by *position*, inclusive of the last ATOM row. Using an index label
    # as a positional bound (df[:label]) drops that row, and it cuts at the
    # wrong place whenever the index is not 0..n-1 (e.g. after H filtering).
    main_structure = df.iloc[:atom_positions[-1] + 1]
    read, ridx, r_ridx = read_residue(
        main_structure[['res_name', 'chain_id', 'res_seq', 'i_code']].values
    )
    return main_structure, read, ridx, r_ridx


def find_root_atom(df):
    """Return the ``(chain_id, res_seq, i_code)`` key of every CA atom, in row order."""
    is_alpha_carbon = df['atom_name'].eq('CA')
    return df.loc[is_alpha_carbon, ['chain_id', 'res_seq', 'i_code']].to_numpy()


def find_atom(df, keys):
    """Collect the atoms of each residue key, in the order the keys are given.

    Args:
        df: atom table with chain_id, res_seq and i_code columns.
        keys: iterable of ``(chain_id, res_seq, i_code)`` residue keys.

    Returns:
        The matching rows for each key, concatenated in key order (rows repeat
        if a key repeats) and re-indexed 0..n-1. With no keys, an empty frame
        with ``df``'s columns is returned instead of ``pd.concat`` raising.
    """
    parts = [
        df[(df['chain_id'] == k[0]) & (df['res_seq'] == k[1]) & (df['i_code'] == k[2])]
        for k in keys
    ]
    if not parts:
        return df.iloc[0:0].reset_index(drop=True)
    return pd.concat(parts).reset_index(drop=True)


# --- Apo (AlphaFold) -> holo pocket mapping: run state and output folders ---
# NOTE: `count` is not updated anywhere in this cell, so print(count) at the end prints 0.
count = 0
max_mutation = 0  # most pocket mutations applied to a single entry so far

rmsd_results = {}  # PDB id -> RMSD returned by align_and_save
apo_path = Path('data/raw/afdb')  # AlphaFold models (AF-<UniProt>-F1-model_v4.pdb)
holo_path = Path('data/raw/pdb')  # experimental structures, laid out as <SET>/<PDB id>/

# output path
apo_align_path = Path('data/raw/af_align')  # apo models superimposed onto holo
apo_align_path.mkdir(exist_ok=True, parents=True)

mutant_path = Path('data/raw/mutant')  # apo models with pocket point mutations applied
mutant_path.mkdir(exist_ok=True, parents=True)

a2h_path = Path('data/processed/a2h')  # pickled {'APO': ..., 'HOLO': ...} pocket atom tables
a2h_path.mkdir(exist_ok=True, parents=True)

# main loop
num = 0  # entries that reached the common processing step (progress printed every 1000)
for idx, row in data.iterrows():
    pdb = row['PDB']
    uni = row['UniProt']
    st = row['SET']

    # logging.info(f"\n[{idx+1}/{len(data)}] Processing code: {pdb}")
    
    # Skip entries listed as having no matching pocket.
    if pdb in no_match:
        continue

    if st == 'PDB':
        continue
    
    elif st != 'CSAR':
        # Every set other than CSAR is skipped here, so this run processes only
        # CSAR entries; the statements below this `continue` are unreachable.
        continue
        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"    
        holo = holo_path / st / pdb / f"{pdb}_protein.pdb"
        pocket = holo_path / st / pdb / f"{pdb}_pocket.pdb"
        
        # read pdb format
        ap = read_pdb(apo)
        ho = read_pdb(holo)
        pk = read_pdb(pocket)
        
        # read pocket and parse chain_id from holo
        # Per-atom pocket labels ('<res_seq>' or '<res_seq>-<k>' for the k-th
        # run of a residue number) are looked up in the holo label -> chain
        # table; waters get a blank chain id.
        ho_keys = parse_residue_and_align(ho[['res_seq', 'chain_id']].values)
        pk['c_key'] = parse_residue_and_align(pk['res_seq'].values, chain=False)
        pk['chain_id'] = pk['c_key'].map(ho_keys)
        pk = pk.drop(columns=['c_key']).reset_index(drop=True)
        pk.loc[pk['res_name'] == 'HOH', 'chain_id'] = ' '

    else: # CSAR
        # continue  # the provided pocket was built incorrectly (a lot of time lost); just extract residues within 9 Å from the mol2 file instead
        
        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"
        holo = [f for f in(holo_path / st / pdb).glob('*complex.mol2')][0]  # first *complex.mol2 in the entry folder
        pocket = pocket_idx[pocket_idx['PDB'] == pdb]['Residue'].values[0]  # comma-separated pocket residue ids
        
        # read for pdb format
        ap = read_pdb(apo)
        ho = read_mol2(holo)
        
        # read pocket index from DeepDTAF/CAPLA SSEs
        # NOTE(review): the reason these four entries are skipped is not recorded here.
        if pdb in ['2jj3', '2fai', '2r6w', '2z4b']:
            continue

        # Each pocket token is a one-character chain prefix, the residue number
        # and an optional trailing insertion-code letter (e.g. 'A123', 'A123B').
        pk_sse = []
        for pk_idx in pocket.split(','):
            
            if pk_idx[0].isalpha():
                pk_chain, pk_res_seq = pk_idx[0], pk_idx[1:]
            else:
                # NOTE(review): a non-letter prefix is dropped and chain '2' is
                # assumed; confirm this matches the pocket_idx format.
                pk_chain, pk_res_seq = '2', pk_idx[1:]

            if pk_res_seq[-1].isalpha():
                pk_res_seq, pk_i_code = pk_res_seq[:-1], pk_res_seq[-1]
            else:
                pk_i_code = ''
            
            pk_sse.append((pk_chain, int(pk_res_seq), pk_i_code))
        pk = find_atom(ho, pk_sse)

    num += 1
    if num % 1000 == 0:
        print(num)

    # # common process
    # extract amino acid sequence
    # Each call returns (main structure rows, one-letter sequence,
    # (chain_id, res_seq, i_code) -> sequence index, and the reverse map).
    ho_main, ho_read, ho_ridx, ho_r_ridx = sequence_reader(ho)
    pk_main, pk_read, pk_ridx, pk_r_ridx = sequence_reader(pk)
    ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)
    # print(len(ho_ridx), ho_read)
    # print(len(pk_ridx), pk_read)
    # print(len(ap_ridx), ap_read)

    # search pocket in holo
    # Pocket CA residue keys -> the matching holo atoms -> sorted holo sequence indices.
    pk_ca = find_root_atom(pk_main)
    ho_pk = find_atom(ho_main, pk_ca)
    ho_ca = find_root_atom(ho_pk)
    ho_pk_idx = sorted([ho_ridx[(idx[0], idx[1], idx[2])] for idx in ho_ca])
    
    # extract pocket amino acids in holo (save position)
    # ho_pk_read is the contiguous holo sequence from the first to the last pocket
    # residue; ho_pk_res_dict maps holo sequence index -> index within that span.
    ho_pk_read = ho_read[ho_pk_idx[0]:ho_pk_idx[-1]+1]
    ho_pk_res_dict = {ho_pk_idx[0] + i:i  for i in range(len(ho_pk_read))}
    ho_pk_res_idx = [ho_pk_res_dict[i] for i in ho_pk_idx] # for mutation position
    
    # sequence alignment (score params: match, mismatch, gap open, gap extension)
    # Local alignment of the holo pocket span against the full apo sequence;
    # only the top-ranked alignment is used.
    alignments = pairwise2.align.localms(ho_pk_read, ap_read, 2, -1, -5, -1)
    ho_align, ap_align = alignments[0][:2]

    # mapping position & check mutation
    mapping_align = {}  # span index -> apo sequence index, for columns with no gap
    align_position = {}  # filled below but not read later in this cell
    mutations = []  # (apo sequence index, apo residue, holo residue) at pocket positions
    st_pos = ho_pos = ap_pos = -1

    # ho_pos / ap_pos count non-gap characters, i.e. positions in the ungapped
    # holo span / apo sequence.
    for a, b in zip(ho_align, ap_align):
        if a != '-':
            ho_pos += 1
            st_pos += 1
        if b != '-':
            ap_pos += 1
            st_pos += 1
        if a != '-' and b != '-':
            st_pos += 1
            align_position[st_pos] = ho_pos
            mapping_align[ho_pos] = ap_pos
            if a != b and ho_pos in ho_pk_res_idx:
                mutations.append((ap_pos, b, a))
                
    # check mutation
    if mutations:
        max_mutation = max(max_mutation, len(mutations))

        # Translate apo sequence indices to (chain, resid) and mutate the apo
        # model so its pocket residues match the holo sequence.
        transform = []
        for mut in mutations:
            ap_pk_idx, ap_res, ho_res = mut
            ap_chain, ap_resid, _ = ap_r_ridx[ap_pk_idx]
            transform.append((ap_chain, ap_resid, ap_res, ho_res))
        
        mut_out = mutant_path / f"AF-{uni}-mut-{pdb}.pdb"    
        point_mutation(apo, mut_out, transform)

        # reload
        # Continue with the mutated model and its re-read sequence maps.
        apo = mut_out
        ap = read_pdb(apo)
        ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)

    # find pocket in apo (ho_pk_idx -> ho_pk_res_dict -> mapping align)
    ap_ca = []  # apo residue keys of pocket residues that aligned
    rm_ca = []  # holo sequence indices of pocket residues without an aligned apo residue
    for i in ho_pk_idx:
        hp_key = ho_pk_res_dict[i]  # atom within pocket
        if hp_key in mapping_align.keys(): # atom within mapping & pocket
            ap_ca.append(ap_r_ridx[mapping_align[hp_key]]) # find 
        else:
            rm_ca.append(i)
    if rm_ca: # select atom atoms only in apo
        # Keep only the holo pocket residues that have an apo counterpart.
        ho_pk = find_atom(ho_main, [ho_r_ridx[i] for i in ho_pk_idx if not i in rm_ca])
    ap_pk = find_atom(ap_main, ap_ca) # find pocket in apo (only used by the commented-out debug code below)

    # superimpose
    # Superimpose the whole (possibly mutated) apo model onto the holo structure.
    align_out = apo_align_path / f"AF-{uni}-align-{pdb}.pdb"
    rmsd = align_and_save(apo, holo, align_out)
    rmsd_results[pdb] = rmsd
    
    # mutated & aligned apo structure
    # Re-read the superimposed model and pull out its pocket atoms by residue key.
    aa = read_pdb(align_out)
    aa_main, aa_read, aa_ridx, aa_r_ridx = sequence_reader(aa)
    aa_pk = find_atom(aa_main, ap_ca)
    
    # save
    # Pickle the paired pocket atom tables under the keys 'APO' and 'HOLO'.
    a2h_out = a2h_path / f"{pdb}_a2h.pkl"
    a2h = {'APO': aa_pk, 'HOLO': ho_pk}
    with open(a2h_out, 'wb') as f:
        pd.to_pickle(a2h, f)

    # if len(mutations) >= 1:
    #     print(pdb)
    #     print(format_alignment(*alignments[0]))
    #     print(sequence_reader(ho_pk)[1])
    #     print(sequence_reader(ap_pk)[1])        
    #     print()
    #     break
    

print(count)

In [None]:
# Reload the bundle at a2h_out (the last file written by the processing loop) and show its HOLO pocket.
loaded_bundle = pd.read_pickle(a2h_out)
loaded_bundle['HOLO']

In [None]:
# Pocket atoms of the mutated and superimposed AlphaFold model from the same bundle.
loaded_bundle['APO']

In [None]:
# Number of dataset entries per source set (e.g. 'PDB', 'CSAR').
data['SET'].value_counts()

In [None]:
# (PDB id, RMSD) pairs from the superposition step, largest RMSD first.
sorted(rmsd_results.items(), key=lambda x: x[1], reverse=True)

In [29]:
intereset = '2xxw'  # PDB id of the bundle to inspect
lb = pd.read_pickle(a2h_path / f"{intereset}_a2h.pkl")

In [None]:
# APO pocket atoms of the inspected entry.
lb['APO']

In [None]:
# HOLO pocket atoms of the inspected entry.
lb['HOLO']


In [None]:
# One-letter sequences of the APO and HOLO pockets, for a side-by-side check.
print(sequence_reader(lb['APO'])[1])
print(sequence_reader(lb['HOLO'])[1])


# 2025.05.23 
- Randomly sample a validation set from the refined set
- CSAR pocket search

# CSAR

In [10]:
from pathlib import Path
import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist

def find_pocket_residues(complex_df: pd.DataFrame, ligand_df: pd.DataFrame, cutoff: float = 10.0):
    """Select the protein residues with any atom within ``cutoff`` of the ligand.

    Args:
        complex_df: atom table of the complex (columns x, y, z, res_name,
            chain_id, res_seq, i_code).
        ligand_df: atom table of the ligand. Atoms of ``complex_df`` whose
            residue name appears in it are excluded from the candidates.
        cutoff: distance threshold, in the coordinates' units (Å for mol2).

    Returns:
        All atoms of the selected residues (via ``find_atom``), with residues
        in the order they appear in ``complex_df``.
    """
    ligand_df = ligand_df.reset_index(drop=True)
    complex_df = complex_df.reset_index(drop=True)
    ligand_coords = ligand_df[['x', 'y', 'z']].to_numpy()
    protein_df = complex_df[~complex_df['res_name'].isin(ligand_df['res_name'].unique())].copy()
    protein_df = protein_df.reset_index(drop=True)
    protein_coords = protein_df[['x', 'y', 'z']].to_numpy()

    # calculate distance between ligand and protein
    dist_matrix = cdist(ligand_coords, protein_coords)

    # Protein atoms within cutoff of *any* ligand atom. np.unique removes the
    # duplicates and sorts, so atoms (and hence residues) keep structure order.
    close_idx = np.unique(np.nonzero(dist_matrix <= cutoff)[1])
    close_atoms = protein_df.iloc[close_idx]

    # Residue keys in first-appearance order. A set would make the order (and
    # so the pocket sequence) depend on per-process string hash randomization.
    pocket_residues = list(dict.fromkeys(
        zip(close_atoms['chain_id'], close_atoms['res_seq'], close_atoms['i_code'])
    ))
    pocket_df = find_atom(complex_df, pocket_residues)
    return pocket_df


def write_pocket_to_mol2(pocket_df: pd.DataFrame, output_path: str):
    """Write pocket atoms to a minimal Tripos mol2 file (atoms only, no bonds).

    Atoms are renumbered 1..n in row order. Each ATOM line carries atom_name,
    x/y/z, atom_type, res_seq (as subst_id), res_name (as subst_name) and the
    partial charge.

    Args:
        pocket_df: atom table with columns atom_name, x, y, z, atom_type,
            res_seq, res_name and optionally charge. The caller's frame is not
            modified.
        output_path: destination file path.
    """
    pocket_df = pocket_df.reset_index(drop=True)
    pocket_df['atom_id'] = np.arange(1, len(pocket_df) + 1)
    pocket_df[['x', 'y', 'z']] = pocket_df[['x', 'y', 'z']].astype(float)
    # Charges parsed from mol2 are '' when the field is missing or not numeric
    # (see the charge parsing in ligand_reader/to_pdb_format). Write those as
    # 0.0 instead of failing the float conversion.
    if 'charge' in pocket_df.columns:
        pocket_df['charge'] = pd.to_numeric(pocket_df['charge'], errors='coerce').fillna(0.0)
    else:
        pocket_df['charge'] = 0.0

    out = [
        # MOLECULE section
        "@<TRIPOS>MOLECULE\n",
        "POCKET\n",
        f"{len(pocket_df)} 0 1\n",
        "PROTEIN\n",
        "USER_CHARGES\n\n",
        # ATOM section
        "@<TRIPOS>ATOM\n",
    ]
    for row in pocket_df.to_dict('records'):
        out.append(
            f"{row['atom_id']:>7} {row['atom_name']:<8} "
            f"{row['x']:>10.4f} {row['y']:>10.4f} {row['z']:>10.4f} "
            f"{row['atom_type']:<6} {row['res_seq']:>5} {row['res_name']:<8} "
            f"{row['charge']:>9.4f}\n"
        )
    # BOND section (empty)
    out.append("@<TRIPOS>BOND\n")

    with open(output_path, 'w') as f:
        f.writelines(out)


def read_csar(mol2_path):
    """Split a CSAR complex mol2 into ligand and pocket files and return the pocket.

    The atoms listed in the mol2 SET named ``LIGAND`` are taken as the ligand.
    They are written, with the bonds among them renumbered from 1, to
    ``<dir>/<dir name>_ligand.mol2``. The protein residues near the ligand are
    then selected with ``find_pocket_residues`` and written to
    ``<dir>/pocket_atoms.mol2``.

    Args:
        mol2_path: path to the complex mol2 file.

    Returns:
        DataFrame of pocket atoms (see ``find_pocket_residues``).

    Raises:
        ValueError: if the ATOM, BOND or SUBSTRUCTURE section, or the LIGAND
            set definition, is missing.
    """
    mol2_path = Path(mol2_path)
    output_path = mol2_path.parent / f'{mol2_path.parent.stem}_ligand.mol2'

    with open(mol2_path) as f:
        lines = f.readlines()

    def find_line(predicate, what, default=None):
        # Index of the first line matching `predicate`; `default` if none, or a
        # descriptive ValueError when no default is given.
        idx = next((i for i, l in enumerate(lines) if predicate(l)), default)
        if idx is None:
            raise ValueError(f"{mol2_path}: {what} not found")
        return idx

    atom_start = find_line(lambda l: l.startswith('@<TRIPOS>ATOM'), '@<TRIPOS>ATOM section')
    bond_start = find_line(lambda l: l.startswith('@<TRIPOS>BOND'), '@<TRIPOS>BOND section')
    chain_start = find_line(lambda l: l.startswith('@<TRIPOS>SUBSTRUCTURE'), '@<TRIPOS>SUBSTRUCTURE section')
    # SET is optional; without it SUBSTRUCTURE runs to the end of the file.
    set_start = find_line(lambda l: l.startswith('@<TRIPOS>SET'), '@<TRIPOS>SET section', default=len(lines))

    atom_lines = lines[atom_start+1:bond_start]
    bond_lines = lines[bond_start+1:chain_start]
    chain_lines = lines[chain_start+1:set_start]

    # --- parse ligand definition ---
    # Member lines follow the LIGAND line and start with an integer. The first
    # integer is dropped below (presumably the mol2 STATIC-set member count).
    ligand_line_idx = find_line(lambda l: l.strip().startswith("LIGAND"), 'LIGAND set definition')

    ligand_atom_ids = []
    i = ligand_line_idx + 1
    while i < len(lines):
        line = lines[i].strip()
        if not line or not line.split()[0].isdigit():
            break
        ligand_atom_ids.extend(int(token) for token in line.split() if token.isdigit())
        i += 1

    ligand_atom_ids_set = set(ligand_atom_ids[1:])
    ligand_atom_lines = []
    for l in atom_lines:
        fields = l.split()
        # Skip blank lines; int() on an empty split would raise.
        if fields and int(fields[0]) in ligand_atom_ids_set:
            ligand_atom_lines.append(l)
    ligand_atom_id_map = {int(l.split()[0]): idx + 1 for idx, l in enumerate(ligand_atom_lines)}

    # --- extract ligand bond ---
    ligand_bond_lines = []
    bond_num = 1
    for l in bond_lines:
        fields = l.split()
        if len(fields) < 4:  # skip blank or malformed bond lines
            continue
        a1, a2 = int(fields[1]), int(fields[2])
        # Keep only bonds with both ends in the ligand, renumbered to the new atom ids.
        if a1 in ligand_atom_ids_set and a2 in ligand_atom_ids_set:
            new_a1 = ligand_atom_id_map[a1]
            new_a2 = ligand_atom_id_map[a2]
            new_line = f"{bond_num:>6} {new_a1:>5} {new_a2:>5} {fields[3]}"
            if len(fields) > 4:
                new_line += " " + " ".join(fields[4:])
            ligand_bond_lines.append(new_line + "\n")
            bond_num += 1

    # --- Save ligand mol2 ---
    with open(output_path, 'w') as f:
        f.write("@<TRIPOS>MOLECULE\n")
        f.write(f"{mol2_path.parent.stem}\n")
        f.write(f"{len(ligand_atom_lines)} {len(ligand_bond_lines)} 1\n")
        f.write("SMALL\n")
        f.write("USER_CHARGES\n\n")

        f.write("@<TRIPOS>ATOM\n")
        for idx, line in enumerate(ligand_atom_lines, 1):
            tokens = line.split()
            atom_id = idx
            atom_name = tokens[1]
            x, y, z = map(float, tokens[2:5])
            atom_type = tokens[5]
            resid = tokens[6]
            resname = tokens[7]
            charge = float(tokens[8]) if len(tokens) > 8 else 0.0

            formatted = f"{atom_id:>7} {atom_name:<8} {x:>10.4f} {y:>10.4f} {z:>10.4f} "
            formatted += f"{atom_type:<6} {resid:>5} {resname:<8} {charge:>9.4f}\n"
            f.write(formatted)

        f.write("@<TRIPOS>BOND\n")
        for line in ligand_bond_lines:
            f.write(line)

    print(f"Ligand saved: {output_path}")
    complex_df = to_pdb_format(atom_lines, bond_lines, chain_lines)
    ligand = ligand_reader(ligand_atom_lines)
    pocket = find_pocket_residues(complex_df, ligand)
    pocket_path = output_path.parent / 'pocket_atoms.mol2'
    write_pocket_to_mol2(pocket, pocket_path)
    print('Pocket saved: ', pocket_path)
    return pocket

# pk = read_csar(holo)

In [None]:
## 72mins
# Collectors for this run: CSAR ids whose pocket mapping failed, and corrected CSAR rows.
no_csar, new_csar = [], []
import pandas as pd
from rdkit import Chem
from Bio import pairwise2
from Bio.pairwise2 import format_alignment

def sequence_reader(df):
    """Cut a structure table at its last ATOM row and read its residue sequence.

    Args:
        df: atom table with record, res_name, chain_id, res_seq and i_code
            columns (e.g. from read_pdb, read_mol2 or find_atom).

    Returns:
        ``(main_structure, sequence, key_to_index, index_to_key)``.
        ``main_structure`` is every row up to and including the last row whose
        record is 'ATOM', so trailing rows such as HETATM waters are dropped.
        The other three values are ``read_residue`` of that slice.

    Raises:
        IndexError: if ``df`` has no 'ATOM' rows.
    """
    atom_positions = np.flatnonzero((df['record'] == 'ATOM').to_numpy())
    if atom_positions.size == 0:
        raise IndexError("structure has no ATOM records")
    # Slice by *position*, inclusive of the last ATOM row. Using an index label
    # as a positional bound (df[:label]) drops that row, and it cuts at the
    # wrong place whenever the index is not 0..n-1 (e.g. after H filtering).
    main_structure = df.iloc[:atom_positions[-1] + 1]
    read, ridx, r_ridx = read_residue(
        main_structure[['res_name', 'chain_id', 'res_seq', 'i_code']].values
    )
    return main_structure, read, ridx, r_ridx


def find_root_atom(df):
    """Return the ``(chain_id, res_seq, i_code)`` key of every CA atom, in row order."""
    is_alpha_carbon = df['atom_name'].eq('CA')
    return df.loc[is_alpha_carbon, ['chain_id', 'res_seq', 'i_code']].to_numpy()


def find_atom(df, keys):
    """Collect the atoms of each residue key, in the order the keys are given.

    Args:
        df: atom table with chain_id, res_seq and i_code columns.
        keys: iterable of ``(chain_id, res_seq, i_code)`` residue keys.

    Returns:
        The matching rows for each key, concatenated in key order (rows repeat
        if a key repeats) and re-indexed 0..n-1. With no keys, an empty frame
        with ``df``'s columns is returned instead of ``pd.concat`` raising.
    """
    parts = [
        df[(df['chain_id'] == k[0]) & (df['res_seq'] == k[1]) & (df['i_code'] == k[2])]
        for k in keys
    ]
    if not parts:
        return df.iloc[0:0].reset_index(drop=True)
    return pd.concat(parts).reset_index(drop=True)


# --- Apo (AlphaFold) -> holo pocket mapping, CSAR variant with pockets rebuilt by read_csar ---
# NOTE: `count` is not updated anywhere in this cell.
count = 0
max_mutation = 0  # most pocket mutations applied to a single entry so far

rmsd_results = {}  # PDB id -> RMSD returned by align_and_save
apo_path = Path('data/raw/afdb')  # AlphaFold models (AF-<UniProt>-F1-model_v4.pdb)
holo_path = Path('data/raw/pdb')  # experimental structures, laid out as <SET>/<PDB id>/

# output path
apo_align_path = Path('data/raw/af_align')  # apo models superimposed onto holo
apo_align_path.mkdir(exist_ok=True, parents=True)

mutant_path = Path('data/raw/mutant')  # apo models with pocket point mutations applied
mutant_path.mkdir(exist_ok=True, parents=True)

a2h_path = Path('data/processed/a2h')  # pickled {'APO': ..., 'HOLO': ...} pocket atom tables
a2h_path.mkdir(exist_ok=True, parents=True)

# main loop
num = 0  # entries that reached the common processing step (progress printed every 1000)
for idx, row in data.iterrows():
    pdb = row['PDB']
    uni = row['UniProt']
    st = row['SET']

    # logging.info(f"\n[{idx+1}/{len(data)}] Processing code: {pdb}")
    
    # Skip entries listed as having no matching pocket.
    if pdb in no_match:
        continue

    if st == 'PDB':
        continue
    
    elif st != 'CSAR':
        # Every set other than CSAR is skipped here, so this run processes only
        # CSAR entries; the statements below this `continue` are unreachable.
        continue
        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"    
        holo = holo_path / st / pdb / f"{pdb}_protein.pdb"
        pocket = holo_path / st / pdb / f"{pdb}_pocket.pdb"
        
        # read pdb format
        ap = read_pdb(apo)
        ho = read_pdb(holo)
        pk = read_pdb(pocket)
        
        # read pocket and parse chain_id from holo
        # Per-atom pocket labels ('<res_seq>' or '<res_seq>-<k>' for the k-th
        # run of a residue number) are looked up in the holo label -> chain
        # table; waters get a blank chain id.
        ho_keys = parse_residue_and_align(ho[['res_seq', 'chain_id']].values)
        pk['c_key'] = parse_residue_and_align(pk['res_seq'].values, chain=False)
        pk['chain_id'] = pk['c_key'].map(ho_keys)
        pk = pk.drop(columns=['c_key']).reset_index(drop=True)
        pk.loc[pk['res_name'] == 'HOH', 'chain_id'] = ' '

    else: # CSAR
        # continue  # the provided pocket was built incorrectly (a lot of time lost); just extract residues within 10 Å from the mol2 file instead
        
        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"
        holo = [f for f in(holo_path / st / pdb).glob('*complex.mol2')][0]  # first *complex.mol2 in the entry folder
        # NOTE: `pocket` is not used below in this branch (the pocket comes from
        # read_csar); `.values[0]` raises IndexError if pdb is missing from pocket_idx.
        pocket = pocket_idx[pocket_idx['PDB'] == pdb]['Residue'].values[0]
        
        # read for pdb format
        ap = read_pdb(apo)
        ho = read_mol2(holo)
        # read_csar also writes <dir>_ligand.mol2 and pocket_atoms.mol2 next to the complex.
        pk = read_csar(holo)
        
        # fix csar sequence
        # Recompute the pocket sequence and the ligand SMILES for this CSAR row.
        pk_seq = sequence_reader(pk)[1]
        try:
            li = Chem.MolFromMol2File(holo.parent / f'{pdb}_ligand.mol2')
            smi = Chem.MolToSmiles(li)
        except:
            # Fallback ligand file. NOTE(review): the bare except also hides
            # unrelated errors, and RDKit may expect str rather than pathlib.Path
            # paths; confirm both calls behave as intended.
            li = Chem.MolFromMol2File(holo.parent / f'{pdb}_ligand_rm_mg.mol2')
            smi = Chem.MolToSmiles(li)
        row['Pocket'] = pk_seq
        row['Pocket_Len'] = len(pk_seq)
        row['Ligand'] = smi
        row['Ligand_Len'] = len(smi)
        new_csar.append(row.to_dict())
        
        # remove ligand
        # Drop atoms of residues named 'INH' (presumably the CSAR ligand residue name).
        ho = ho[ho['res_name'] != 'INH'].reset_index(drop=True)

    num += 1
    if num % 1000 == 0:
        print(num)

    # # common process
    # extract amino acid sequence
    # Each call returns (main structure rows, one-letter sequence,
    # (chain_id, res_seq, i_code) -> sequence index, and the reverse map).
    ho_main, ho_read, ho_ridx, ho_r_ridx = sequence_reader(ho)
    pk_main, pk_read, pk_ridx, pk_r_ridx = sequence_reader(pk)
    ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)
    # print(len(ho_ridx), ho_read)
    # print(len(pk_ridx), pk_read)
    # print(len(ap_ridx), ap_read)

    # search pocket in holo
    # Pocket CA residue keys -> the matching holo atoms -> sorted holo sequence indices.
    pk_ca = find_root_atom(pk_main)
    ho_pk = find_atom(ho_main, pk_ca)
    ho_ca = find_root_atom(ho_pk)
    ho_pk_idx = sorted([ho_ridx[(idx[0], idx[1], idx[2])] for idx in ho_ca])
    
    # extract pocket amino acids in holo (save position)
    # ho_pk_read is the contiguous holo sequence from the first to the last pocket
    # residue; ho_pk_res_dict maps holo sequence index -> index within that span.
    ho_pk_read = ho_read[ho_pk_idx[0]:ho_pk_idx[-1]+1]
    ho_pk_res_dict = {ho_pk_idx[0] + i:i  for i in range(len(ho_pk_read))}
    ho_pk_res_idx = [ho_pk_res_dict[i] for i in ho_pk_idx] # for mutation position
    
    # sequence alignment (score params: match, mismatch, gap open, gap extension)
    # Local alignment of the holo pocket span against the full apo sequence;
    # only the top-ranked alignment is used.
    alignments = pairwise2.align.localms(ho_pk_read, ap_read, 2, -1, -5, -1)
    ho_align, ap_align = alignments[0][:2]

    # mapping position & check mutation
    mapping_align = {}  # span index -> apo sequence index, for columns with no gap
    align_position = {}  # filled below but not read later in this cell
    mutations = []  # (apo sequence index, apo residue, holo residue) at pocket positions
    st_pos = ho_pos = ap_pos = -1

    # ho_pos / ap_pos count non-gap characters, i.e. positions in the ungapped
    # holo span / apo sequence.
    for a, b in zip(ho_align, ap_align):
        if a != '-':
            ho_pos += 1
            st_pos += 1
        if b != '-':
            ap_pos += 1
            st_pos += 1
        if a != '-' and b != '-':
            st_pos += 1
            align_position[st_pos] = ho_pos
            mapping_align[ho_pos] = ap_pos
            if a != b and ho_pos in ho_pk_res_idx:
                mutations.append((ap_pos, b, a))
                
    # check mutation
    if mutations:
        max_mutation = max(max_mutation, len(mutations))

        # Translate apo sequence indices to (chain, resid) and mutate the apo
        # model so its pocket residues match the holo sequence.
        transform = []
        for mut in mutations:
            ap_pk_idx, ap_res, ho_res = mut
            ap_chain, ap_resid, _ = ap_r_ridx[ap_pk_idx]
            transform.append((ap_chain, ap_resid, ap_res, ho_res))
        
        mut_out = mutant_path / f"AF-{uni}-mut-{pdb}.pdb"    
        point_mutation(apo, mut_out, transform)

        # reload
        # Continue with the mutated model and its re-read sequence maps.
        apo = mut_out
        ap = read_pdb(apo)
        ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)

    # Any failure while mapping the pocket onto the apo model (e.g. a missing
    # key) skips the entry; its id is printed with the alignment and recorded in no_csar.
    try:
        # find pocket in apo (ho_pk_idx -> ho_pk_res_dict -> mapping align)
        ap_ca = []  # apo residue keys of pocket residues that aligned
        rm_ca = []  # holo sequence indices of pocket residues without an aligned apo residue
        for i in ho_pk_idx:
            hp_key = ho_pk_res_dict[i]  # atom within pocket
            if hp_key in mapping_align.keys(): # atom within mapping & pocket
                ap_ca.append(ap_r_ridx[mapping_align[hp_key]]) # find 
            else:
                rm_ca.append(i)
        if rm_ca: # select atom atoms only in apo
            # Keep only the holo pocket residues that have an apo counterpart.
            ho_pk = find_atom(ho_main, [ho_r_ridx[i] for i in ho_pk_idx if not i in rm_ca])
        ap_pk = find_atom(ap_main, ap_ca) # find pocket in apo
    except:
        print(pdb)
        print(format_alignment(*alignments[0]))
        no_csar.append(pdb)
        continue

    # superimpose
    # Superimpose the whole (possibly mutated) apo model onto the holo structure.
    align_out = apo_align_path / f"AF-{uni}-align-{pdb}.pdb"
    rmsd = align_and_save(apo, holo, align_out)
    rmsd_results[pdb] = rmsd
    
    # mutated & aligned apo structure
    # Re-read the superimposed model and pull out its pocket atoms by residue key.
    aa = read_pdb(align_out)
    aa_main, aa_read, aa_ridx, aa_r_ridx = sequence_reader(aa)
    aa_pk = find_atom(aa_main, ap_ca)
    
    # save
    # Pickle the paired pocket atom tables under the keys 'APO' and 'HOLO'.
    a2h_out = a2h_path / f"{pdb}_a2h.pkl"
    a2h = {'APO': aa_pk, 'HOLO': ho_pk}
    with open(a2h_out, 'wb') as f:
        pd.to_pickle(a2h, f)

    # if len(mutations) >= 1:
    #     print(pdb)
    #     print(format_alignment(*alignments[0]))
    #     print(sequence_reader(ho_pk)[1])
    #     print(sequence_reader(ap_pk)[1])        
    #     print()
    #     break

print('done')

In [None]:
# Save the CSAR rows collected in new_csar (with recomputed Pocket, Pocket_Len, Ligand, Ligand_Len) to CSV.
pd.DataFrame(new_csar).to_csv('data/fix_csar.csv', index=False, header=True)

In [None]:
# search largest vector
import pickle
from pathlib import Path

def find_largest_vector(directory):
    """Report the a2h bundles with the most APO atoms and the most APO CA atoms.

    Args:
        directory: folder containing ``*.pkl`` bundles, i.e. dicts with an
            ``'APO'`` atom table as written by the a2h processing loop.
            NOTE: unpickling can execute code, so only point this at trusted,
            locally produced files.

    Returns:
        ``(max_file, max_count, max_ca_file, max_ca)``. The paths are None and
        the counts 0 when no bundle with APO atoms is found. The results are
        also printed.
    """
    directory = Path(directory)
    max_count = 0
    max_ca = 0
    max_file = None
    max_ca_file = None

    for pkl_file in sorted(directory.glob('*.pkl')):
        bundle = pd.read_pickle(pkl_file)
        apo = pd.DataFrame(bundle['APO'])

        count = len(apo)
        if count > max_count:
            max_count, max_file = count, pkl_file

        ca_count = int((apo['atom_name'] == 'CA').sum()) if 'atom_name' in apo.columns else 0
        if ca_count > max_ca:
            max_ca, max_ca_file = ca_count, pkl_file

    if max_file is None:
        print(f"No APO entries found in {directory}")
    else:
        print(f"Largest vector: {max_file.name} with {max_count} entries")
    if max_ca_file is not None:
        print(f"Largest CA vector: {max_ca_file.name} with {max_ca} entries")
    return max_file, max_count, max_ca_file, max_ca

# The a2h bundles (dicts with 'APO'/'HOLO' atom tables) are written to a2h_path
# (data/processed/a2h). data/raw/af_align holds the aligned .pdb files and
# processing.log, so it contains no *.pkl bundles to scan.
find_largest_vector(a2h_path)

In [None]:
import pickle
from pathlib import Path
import pandas as pd
import numpy as np


def _coord_matrix(merged, suffix):
    """Return an (n, 3) float array of the coordinates carrying *suffix* in *merged*.

    Prefers the separate ``x``/``y``/``z`` columns produced by ``read_pdb``.
    It falls back to a single ``coord`` column that holds one 3-vector per row
    (an older layout).
    """
    xyz = [f'x{suffix}', f'y{suffix}', f'z{suffix}']
    if all(col in merged.columns for col in xyz):
        return merged[xyz].to_numpy(dtype=float)
    # legacy layout: one 3-vector per row in a 'coord' column
    vectors = [np.asarray(c, dtype=float) for c in merged[f'coord{suffix}']]
    return np.stack(vectors) if vectors else np.empty((0, 3))


def find_files_with_coord_diff(directory, threshold=1e-3):
    """List a2h pickles whose APO and HOLO atoms differ by more than *threshold*.

    Each ``*.pkl`` in *directory* must hold ``{'APO': ..., 'HOLO': ...}`` tables.
    Atoms are paired by an inner merge on ``res_seq`` and ``atom_name``. The
    Euclidean distance of every pair is then compared against *threshold*
    (strictly greater).
    NOTE(review): pairing by ``res_seq`` assumes apo and holo share residue
    numbering; if they do not, most atoms will simply not pair — confirm.

    Args:
        directory: folder with the pickles.
        threshold: distance (same units as the coordinates, presumably Å).

    Returns:
        Names of the flagged files, in sorted filename order. They are also printed.
    """
    directory = Path(directory)
    files_with_diff = []

    for pkl_file in sorted(directory.glob("*.pkl")):
        with open(pkl_file, "rb") as f:
            # NOTE: unpickling can execute code; only use on locally generated files.
            entry = pickle.load(f)
        apo = pd.DataFrame(entry['APO'])
        holo = pd.DataFrame(entry['HOLO'])

        # Criterion: compare coordinates of atoms with the same res_seq and atom_name
        merged = pd.merge(
            apo, holo,
            on=['res_seq', 'atom_name'],
            suffixes=('_apo', '_holo')
        )
        if merged.empty:
            # nothing paired; a row-wise apply would fail here
            continue

        # per-atom distance between the paired apo and holo coordinates
        coord_diff = np.linalg.norm(
            _coord_matrix(merged, '_apo') - _coord_matrix(merged, '_holo'),
            axis=1,
        )

        # flag the file if any atom moved more than the threshold
        if (coord_diff > threshold).any():
            files_with_diff.append(pkl_file.name)

    print(f"Files with coordinate differences (>{threshold} Å):")
    for fname in files_with_diff:
        print(fname)
    return files_with_diff

find_files_with_coord_diff("APO_to_Holo", threshold=0.1)  # flag differences greater than 0.1 Å

# Final

In [None]:
## 72mins  (author's note; presumably the wall-clock time of this cell)
# Accumulates one metadata dict per processed complex; later written to data/final.csv.
final_data = []
import pandas as pd
from rdkit import Chem
# NOTE(review): Bio.pairwise2 is deprecated in recent Biopython releases
# (Bio.Align.PairwiseAligner is the replacement); verify against the installed version.
from Bio import pairwise2
from Bio.pairwise2 import format_alignment

def sequence_reader(df):
    """Split a structure table into its ATOM block and residue-sequence maps.

    Args:
        df: atom table with 'record', 'res_name', 'chain_id', 'res_seq' and
            'i_code' columns (as produced by the readers in this notebook).

    Returns:
        ``(main_structure, read, ridx, r_ridx)``:
        * ``main_structure``: every row from the start up to and including
          the last ``ATOM`` record (by position), so trailing HETATM rows are excluded;
        * ``read, ridx, r_ridx``: the result of ``read_residue`` on those rows.
          In this notebook ``ridx`` maps a (chain, res_seq, i_code) key to its
          position in ``read`` and ``r_ridx`` maps positions back to keys.

    Raises:
        IndexError: if *df* contains no ``ATOM`` record.
    """
    is_atom = (df['record'] == 'ATOM').to_numpy()
    # Work positionally. read_pdb drops hydrogens without resetting the index,
    # so index labels and positions diverge. Slicing `df[:label]` both dropped
    # the last ATOM row and let trailing HETATM rows slip in.
    atom_end = np.flatnonzero(is_atom)[-1] + 1  # exclusive positional end
    main_structure = df.iloc[:atom_end]
    read, ridx, r_ridx = read_residue(
        main_structure[['res_name', 'chain_id', 'res_seq', 'i_code']].values
    )
    return main_structure, read, ridx, r_ridx


def find_root_atom(df):
    """Return the residue keys of all C-alpha atoms in *df*.

    The result is a 2-D array with one ``[chain_id, res_seq, i_code]`` row
    per ``'CA'`` atom, in the table's row order.
    """
    is_alpha_carbon = df['atom_name'].eq('CA')
    key_columns = ['chain_id', 'res_seq', 'i_code']
    return df.loc[is_alpha_carbon, key_columns].values


def find_atom(df, keys):
    """Gather every row of *df* that belongs to one of the residues in *keys*.

    Each key is indexed as ``(chain_id, res_seq, i_code)``. Rows are returned
    grouped in the order of *keys*; a repeated key repeats its rows. The
    result has a fresh 0..n-1 index.

    Raises:
        ValueError: (from ``pd.concat``) when *keys* is empty.
    """
    chain_col = df['chain_id']
    seq_col = df['res_seq']
    icode_col = df['i_code']

    def rows_for(key):
        # all atoms of a single residue
        return df[(chain_col == key[0]) & (seq_col == key[1]) & (icode_col == key[2])]

    matched = [rows_for(key) for key in keys]
    return pd.concat(matched).reset_index(drop=True)


count = 0  # NOTE(review): never read in this cell
max_mutation = 0  # largest number of pocket mutations seen for a single complex

rmsd_results = {}  # PDB id -> value returned by align_and_save for that complex
apo_path = Path('data/raw/afdb')  # apo models: AF-<UniProt>-F1-model_v4.pdb
holo_path = Path('data/raw/pdb')  # holo structures: <SET>/<PDB>/...

# output path
# superimposed apo models: AF-<UniProt>-align-<PDB>.pdb
apo_align_path = Path('data/raw/af_align')
apo_align_path.mkdir(exist_ok=True, parents=True)

# apo models with pocket point mutations applied: AF-<UniProt>-mut-<PDB>.pdb
mutant_path = Path('data/raw/mutant')
mutant_path.mkdir(exist_ok=True, parents=True)

# per-complex pickles {'APO': apo pocket atoms, 'HOLO': holo pocket atoms}: <PDB>_a2h.pkl
a2h_path = Path('data/processed/a2h')
a2h_path.mkdir(exist_ok=True, parents=True)

# main loop
# For each complex: read apo/holo/pocket, locate the pocket residues in the holo
# sequence, align the holo pocket span to the apo sequence, mutate apo residues
# that differ at pocket positions, superimpose apo onto holo, and save the
# matched pocket atoms of both structures.
num = 0  # progress counter of complexes that reached the common stage
for idx, row in data.iterrows():
    pdb = row['PDB']
    uni = row['UniProt']
    st = row['SET']

    # no_match is defined in an earlier cell (complexes excluded from processing)
    if pdb in no_match:
        continue

    # entries whose SET is 'PDB' are not processed here
    if st == 'PDB':
        continue
    
    elif st != 'CSAR':
        # non-CSAR sets: holo protein and pocket come as separate PDB files;
        # the metadata row is recorded unchanged
        final_data.append(row.to_dict())

        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"    
        holo = holo_path / st / pdb / f"{pdb}_protein.pdb"
        pocket = holo_path / st / pdb / f"{pdb}_pocket.pdb"
        
        # read pdb format
        ap = read_pdb(apo)
        ho = read_pdb(holo)
        pk = read_pdb(pocket)
        
        # read pocket and parse chain_id from holo
        # (parse_residue_and_align is defined elsewhere; its output keys the pocket
        # residues so the holo chain_id can be mapped onto them)
        ho_keys = parse_residue_and_align(ho[['res_seq', 'chain_id']].values)
        pk['c_key'] = parse_residue_and_align(pk['res_seq'].values, chain=False)
        pk['chain_id'] = pk['c_key'].map(ho_keys)
        pk = pk.drop(columns=['c_key']).reset_index(drop=True)
        # waters get a blank chain id
        pk.loc[pk['res_name'] == 'HOH', 'chain_id'] = ' '

    else: # CSAR
        
        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"
        # first *complex.mol2 file in the entry folder (protein + ligand)
        holo = [f for f in(holo_path / st / pdb).glob('*complex.mol2')][0]
        # NOTE(review): `pocket` is not used further in this branch; pk comes from read_csar below
        pocket = pocket_idx[pocket_idx['PDB'] == pdb]['Residue'].values[0]
        
        # read for pdb format
        ap = read_pdb(apo)
        ho = read_mol2(holo)
        pk = read_csar(holo)
        
        # fix csar sequence
        # recompute the pocket sequence and ligand SMILES and overwrite them in the metadata row
        pk_seq = sequence_reader(pk)[1]
        try:
            li = Chem.MolFromMol2File(holo.parent / f'{pdb}_ligand.mol2')
            smi = Chem.MolToSmiles(li)
        except:
            # fall back to the *_ligand_rm_mg.mol2 variant when the first attempt raises
            # (presumably when RDKit cannot parse the original ligand file)
            li = Chem.MolFromMol2File(holo.parent / f'{pdb}_ligand_rm_mg.mol2')
            smi = Chem.MolToSmiles(li)
        row['Pocket'] = pk_seq
        row['Pocket_Len'] = len(pk_seq)
        row['Ligand'] = smi
        row['Ligand_Len'] = len(smi)
        final_data.append(row.to_dict())
        
        # remove ligand
        # (the ligand residue is named 'INH' in the CSAR complex file)
        ho = ho[ho['res_name'] != 'INH'].reset_index(drop=True)

    num += 1
    if num % 1000 == 0:
        print(num)

    # NOTE(review): `logging` is not imported in this cell; presumably imported earlier — confirm
    logging.info(f"\n[{idx+1}/{len(data)}] Processing code: {pdb}")

    # # common process
    # extract amino acid sequence
    # *_ridx maps a (chain, res_seq, i_code) key to a position in *_read;
    # *_r_ridx maps a position back to that key
    ho_main, ho_read, ho_ridx, ho_r_ridx = sequence_reader(ho)
    pk_main, pk_read, pk_ridx, pk_r_ridx = sequence_reader(pk)
    ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)
    # print(len(ho_ridx), ho_read)
    # print(len(pk_ridx), pk_read)
    # print(len(ap_ridx), ap_read)

    # search pocket in holo
    # pocket CA keys -> matching holo residues -> sorted positions in the holo sequence
    pk_ca = find_root_atom(pk_main)
    ho_pk = find_atom(ho_main, pk_ca)
    ho_ca = find_root_atom(ho_pk)
    ho_pk_idx = sorted([ho_ridx[(idx[0], idx[1], idx[2])] for idx in ho_ca])
    
    # extract pocket amino acids in holo (save position)
    # ho_pk_read: contiguous holo subsequence from the first to the last pocket residue
    ho_pk_read = ho_read[ho_pk_idx[0]:ho_pk_idx[-1]+1]
    # holo sequence position -> position within ho_pk_read
    ho_pk_res_dict = {ho_pk_idx[0] + i:i  for i in range(len(ho_pk_read))}
    ho_pk_res_idx = [ho_pk_res_dict[i] for i in ho_pk_idx] # for mutation position
    
    # sequence alignment (score params: match, mismatch, gap open, gap extension)
    alignments = pairwise2.align.localms(ho_pk_read, ap_read, 2, -1, -5, -1)
    ho_align, ap_align = alignments[0][:2]

    # mapping position & check mutation
    # NOTE(review): assumes the aligned strings contain the full input sequences
    # (with gaps), so the counters below index into ho_pk_read / ap_read
    mapping_align = {}  # position in ho_pk_read -> aligned position in ap_read
    align_position = {}  # NOTE(review): not read later in this cell (nor is st_pos)
    mutations = []  # (apo position, apo residue, holo residue) at pocket positions
    st_pos = ho_pos = ap_pos = -1

    for a, b in zip(ho_align, ap_align):
        if a != '-':
            ho_pos += 1
            st_pos += 1
        if b != '-':
            ap_pos += 1
            st_pos += 1
        if a != '-' and b != '-':
            st_pos += 1
            align_position[st_pos] = ho_pos
            mapping_align[ho_pos] = ap_pos
            if a != b and ho_pos in ho_pk_res_idx:
                mutations.append((ap_pos, b, a))
                
    # check mutation
    if mutations:
        max_mutation = max(max_mutation, len(mutations))

        # (chain, residue number, apo residue, holo residue) per mutation, for point_mutation
        transform = []
        for mut in mutations:
            ap_pk_idx, ap_res, ho_res = mut
            ap_chain, ap_resid, _ = ap_r_ridx[ap_pk_idx]
            transform.append((ap_chain, ap_resid, ap_res, ho_res))
        
        mut_out = mutant_path / f"AF-{uni}-mut-{pdb}.pdb"    
        point_mutation(apo, mut_out, transform)

        # reload
        # later steps use the mutated model; mapping_align positions are assumed
        # unchanged (presumably point mutations keep the residue count — confirm)
        apo = mut_out
        ap = read_pdb(apo)
        ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)
    
    # find pocket in apo (ho_pk_idx -> ho_pk_res_dict -> mapping align)
    ap_ca = []  # apo residue keys of pocket residues that have an aligned partner
    rm_ca = []  # holo pocket positions without an aligned apo residue
    for i in ho_pk_idx:
        hp_key = ho_pk_res_dict[i]  # position of this pocket residue within ho_pk_read
        if hp_key in mapping_align.keys(): # residue is both in the pocket and aligned
            ap_ca.append(ap_r_ridx[mapping_align[hp_key]]) # apo residue key via the alignment
        else:
            rm_ca.append(i)
    if rm_ca: # rebuild the holo pocket from residues that also exist in apo
        ho_pk = find_atom(ho_main, [ho_r_ridx[i] for i in ho_pk_idx if not i in rm_ca])
    ap_pk = find_atom(ap_main, ap_ca) # find pocket in apo

    # superimpose
    # align_and_save (defined elsewhere) writes the superimposed apo to align_out;
    # its return value is stored as the RMSD. For CSAR, `holo` is the complex .mol2 path.
    align_out = apo_align_path / f"AF-{uni}-align-{pdb}.pdb"
    rmsd = align_and_save(apo, holo, align_out)
    rmsd_results[pdb] = rmsd
    
    # mutated & aligned apo structure
    # pocket atoms are re-extracted with the same residue keys (assumes the saved
    # file keeps chain/residue numbering — confirm)
    aa = read_pdb(align_out)
    aa_main, aa_read, aa_ridx, aa_r_ridx = sequence_reader(aa)
    aa_pk = find_atom(aa_main, ap_ca)
    
    # save
    a2h_out = a2h_path / f"{pdb}_a2h.pkl"
    a2h = {'APO': aa_pk, 'HOLO': ho_pk}
    with open(a2h_out, 'wb') as f:
        pd.to_pickle(a2h, f)

print('done')

In [None]:
# Persist superposition results, one "<pdb>: <rmsd>" line per complex.
with open('data/raw/rmsd_results.txt', 'w') as f:
    f.writelines(f"{pdb_id}: {value}\n" for pdb_id, value in rmsd_results.items())

In [None]:
# Assemble the collected metadata rows into one table, save it, and display it.
final = pd.DataFrame(data=final_data)
final.to_csv(path_or_buf='data/final.csv', header=True, index=False)
final

In [None]:
# Drop complexes listed in no_match, then show how many remain per SET.
final = final.loc[~final['PDB'].isin(no_match)]
final['SET'].value_counts()

In [None]:
final.loc[final['SET'].eq('Refined'), 'UniProt'].value_counts()  # complexes per UniProt in the Refined set

In [None]:
# PDB ids that have a superimposed apo model (files named AF-<UniProt>-align-<PDB>.pdb).
# rsplit keeps the text after the last '-', so the id is correct even when the
# UniProt accession itself contains '-' (e.g. an isoform such as "P12345-2"),
# where split('-')[3] picked the wrong field. The narrower glob ignores unrelated .pdb files.
align_pdbs = [f.stem.rsplit('-', 1)[-1] for f in apo_align_path.glob('AF-*-align-*.pdb')]
len(align_pdbs) # the 29 no_match PDBs are excluded

In [None]:
import json
import warnings
from pathlib import Path

warnings.filterwarnings('ignore', category=DeprecationWarning)

# Build one train/validation split per seed.
# - validation: VALID_SIZE complexes from the Refined set. First take up to
#   `samples_per_uni` complexes per UniProt, then top up at random from the rest.
# - training: every General complex plus the Refined complexes not used for validation.
VALID_SIZE = 1000

general = final[final['SET'] == 'General']
refined = final[final['SET'] == 'Refined']

# minimum sample per UniProt (independent of the seed, so computed once)
num_uni = refined['UniProt'].nunique()
samples_per_uni = VALID_SIZE // num_uni if num_uni else 0

# json.dump below needs the output directory to exist
Path('cache').mkdir(parents=True, exist_ok=True)

seeds = [42, 100, 123, 456, 789]
for idx, seed in enumerate(seeds):
    # sampling each UniProt (maximum samples_per_uni)
    sampled = refined.groupby('UniProt', group_keys=False).apply(
        lambda x: x.sample(n=min(samples_per_uni, len(x)), random_state=seed)
    )

    # remaining samples
    # Previously `valid` was only assigned when remaining > 0, so a full quota
    # silently reused the previous seed's split (or raised NameError on the first fold).
    remaining = VALID_SIZE - len(sampled)
    if remaining > 0:
        pool = refined[~refined.index.isin(sampled.index)]
        # never request more rows than exist (DataFrame.sample raises otherwise)
        additional = pool.sample(n=min(remaining, len(pool)), random_state=seed)
        valid = pd.concat([sampled, additional])
    else:
        valid = sampled

    # create training set
    remain_refined = refined[~refined.index.isin(valid.index)]
    train = pd.concat([general, remain_refined]).reset_index(drop=True)

    # extract pdb list
    train_pdb = train['PDB'].tolist()
    valid_pdb = valid['PDB'].tolist()
    fold_pdb = {'SEED': seed, 'TRN': train_pdb, 'VAL': valid_pdb}

    # save
    with open(f'cache/fold_{idx+1}.json', 'w') as f:
        json.dump(fold_pdb, f)
    
    # check
    print(f'Fold {idx+1}: Seed {seed}')
    print(len(train), len(valid))
    print()