# 2025.05.21 전처리

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm

data = pd.read_csv('data/AF_data_info.csv')
data.head(1)

## Read PDB File

In [2]:
def read_pdb(pdb_path):
    records = []
    with open(pdb_path, 'r') as f:
        for line in f:
            if line.startswith(('ATOM', 'HETATM')):
                record = {
                    'record': line[0:6].strip(),
                    'atom_serial': int(line[6:11]),
                    'atom_name': line[12:16].strip(),
                    'alt_loc': line[16],
                    'res_name': line[17:20].strip(),
                    'chain_id': line[21],
                    'res_seq': int(line[22:26]),
                    'i_code': line[26],
                    'x': float(line[30:38]),
                    'y': float(line[38:46]),
                    'z': float(line[46:54]),
                    'occupancy': float(line[54:60]),
                    'b_factor': float(line[60:66]),
                    'element': line[76:78].strip(),
                    'charge': line[78:80].strip()
                }
                records.append(record)
    records = pd.DataFrame(records)
    records = records[records['element'] != 'H']
    return records

## Read Mol2 to PDB format

In [3]:
def read_mol2(mol2_path, ligand=False):
    with open(mol2_path) as f:
        lines = f.readlines()

    if ligand:
        atom_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>ATOM'))
        bond_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>BOND'))
        atom_end = bond_start
        atom_lines = lines[atom_start+1:atom_end]
        return ligand_reader(atom_lines)
    
    else:
        atom_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>ATOM'))
        bond_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>BOND'))
        chain_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>SUBSTRUCTURE'))
        set_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>SET'))

        atom_end = bond_start
        bond_end = chain_start
        chain_end = set_start

        atom_lines = lines[atom_start+1:atom_end]
        bond_lines = lines[bond_start+1:bond_end]
        chain_lines = lines[chain_start+1:chain_end]
        return to_pdb_format(atom_lines, bond_lines, chain_lines)


def ligand_reader(atom):
    records = []
    for idx, line in enumerate(atom):
        fields = line.split()
        atom_id = int(fields[0])
        atom_name = fields[1]
        x, y, z = map(float, fields[2:5])
        atom_type = fields[5]
        resid = int(fields[6])
        resname_full = fields[7]
        res_name = resname_full[:3]
        occupancy = 1.0
        b_factor = 0.0
        element = atom_type.split('.')[0][0].upper()
        charge = fields[8] if len(fields) > 8 and fields[8].replace('.', '', 1).replace('-', '', 1).isdigit() else ''
        record = {
            'record': 'ATOM',
            'atom_serial': atom_id,
            'atom_name': atom_name,
            'alt_loc': '',
            'res_name': res_name,
            'chain_id': '',
            'res_seq': resid,
            'i_code': '',
            'x': x,
            'y': y,
            'z': z,
            'atom_type': atom_type,
            'occupancy': occupancy,
            'b_factor': b_factor,
            'element': element,
            'charge': charge
        }
        records.append(record)
    records = pd.DataFrame(records)
    records = records[records['element'] != 'H']       
    return records.reset_index(drop=True)


def parse_substructure(lines):
    substructures = []
    for line in lines:
        parts = line.strip().split()
        sub = {
            'id': int(parts[0]),
            'name': parts[1],
            'root_atom': int(parts[2]),
            'type': parts[3],
            'dict_type': parts[4],
            'chain': parts[5],
            'residue': parts[6],
            'unk': parts[7],
            'status': parts[8],
        }
        substructures.append(sub)
    return pd.DataFrame(substructures)


def to_pdb_format(atom, bond, chain):
    records = []
    chains = parse_substructure(chain)
    
    box = {}
    prev = None
    root_atom_chain = None
    chain_atoms = [1]
    for idx, line in enumerate(atom):
        fields = line.split()
        atom_id = int(fields[0])
        atom_name = fields[1]
        x, y, z = map(float, fields[2:5])
        atom_type = fields[5]
        resid = int(fields[6])
        
        resname_full = fields[7]
        res_name = resname_full[:3]
        
        if idx == 0:
            prev = resid
        else:
            if resid == prev:
                chain_atoms.append(atom_id)
            else:
                for a in chain_atoms:
                    box[a] = root_atom_chain
                prev = resid
                chain_atoms = [atom_id]

        if atom_id in chains['root_atom'].values:
            chain_id = chains[chains['root_atom'] == atom_id]['chain'].values[0]
            root_atom_chain = chain_id
        else:
            chain_id = None
            
        occupancy = 1.0
        b_factor = 0.0
        element = atom_type.split('.')[0][0].upper()
        charge = fields[8] if len(fields) > 8 and fields[8].replace('.', '', 1).replace('-', '', 1).isdigit() else ''
        record = {
            'record': 'ATOM',
            'atom_serial': atom_id,
            'atom_name': atom_name,
            'alt_loc': '',
            'res_name': res_name,
            'chain_id': chain_id,
            'res_seq': resid,
            'i_code': '',
            'x': x,
            'y': y,
            'z': z,
            'atom_type': atom_type,
            'occupancy': occupancy,
            'b_factor': b_factor,
            'element': element,
            'charge': charge
        }
        records.append(record)
    
    records = pd.DataFrame(records)
    records = records[records['element'] != 'H']       
    records['chain_id'] = records['atom_serial'].map(box)
    return records.reset_index(drop=True)

## Parsing Pocket

In [4]:
amino_acids = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',
    'CYS': 'C', 'GLU': 'E', 'GLN': 'Q', 'GLY': 'G',
    'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',
    'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',
    'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
    }

rev_amino_acids = {v: k for k, v in amino_acids.items()}

modified_residue_map = {
    "MSE": "M",  # Selenomethionine → Methionine
    "SEP": "S",  # Phosphoserine → Serine
    "TPO": "T",  # Phosphothreonine → Threonine
    "PTR": "Y",  # Phosphotyrosine → Tyrosine
    "HYP": "P",  # Hydroxyproline → Proline
    "KCX": "K",  # Carboxylysine → Lysine
    "CSO": "C",  # Oxidized cysteine → Cysteine
    "CGU": "E",  # γ-carboxy-glutamate → Glutamic Acid
    "F2F": "F",  # Fluorophenylalanine → Phenylalanine
    "ASH": "D",  # Protonated Aspartic Acid → Aspartic Acid
    "GLH": "E",  # Protonated Glutamic Acid → Glutamic Acid
    "CYX": "C",  # Disulfide-bonded Cysteine → Cysteine
    "HID": "H",  # Neutral Histidine (delta-protonated)
    "HIE": "H",  # Neutral Histidine (epsilon-protonated)
    "HIP": "H",  # Positively charged Histidine
}
def parse_residue_and_align(lst, chain=True):
    result = []
    prev = None
    seen_counts = {}
    
    if chain:
        
        for num, chain in lst:
            if num == prev:
                continue
            count = seen_counts.get(num, 0) + 1
            seen_counts[num] = count
            key = (f"{num}" if count == 1 else f"{num}-{count}", chain)
            result.append(key)
            prev = num
        result = {k: v for k, v in result}
    
    else:    
        for num in lst:
            if num == prev:
                result.append(key)
                continue 
            count = seen_counts.get(num, 0) + 1
            seen_counts[num] = count
            key = f"{num}" if count == 1 else f"{num}-{count}"
            result.append(key)
            prev = num
            
    return result


def read_residue(res):
    global amino_acids

    init = 0
    result = []
    res_dict = {}
    already = set()
    for r, c, idx, ic in res:
        if r == 'HOH':
            continue
        if r not in amino_acids:
            conv_r = 'X'
        else:
            conv_r = amino_acids[r]
        key = (c, idx, ic)
        if key in already:
            continue
        already.add(key)
        result.append(conv_r)
        res_dict[key] = init
        init += 1
    rev_res_dict = {v: k for k, v in res_dict.items()}
    return ''.join(result), res_dict, rev_res_dict

## PyMOL Modules

In [5]:

import os
import pymol
from pymol import cmd
from pathlib import Path
from tqdm import tqdm

# pymol.finish_launching() # if you see the work in pymol, you need to run this


def align_and_save(apo_path, holo_path, output_path):
    cmd.reinitialize()  # Clear previous structures
    cmd.load(str(apo_path), 'apo')
    cmd.load(str(holo_path), 'holo')
    alignment_info = cmd.super('apo', 'holo')
    cmd.save(str(output_path), 'apo')
    rmsd = alignment_info[0]  # RMSD
    return rmsd


def point_mutation(apo_path, output_path, mutations):
    global rev_amino_acids
    cmd.reinitialize()  # Clear previous structures
    cmd.load(apo_path, 'apo')
    cmd.wizard('mutagenesis')
    for row in mutations:
        chain, resid, ap_res, ho_res = map(str, row)
        # PyMOL residue selection: /object//chain/resi
        if ho_res == 'X':
            continue
        selection = f'/apo//{chain}/{resid}/'
        print(f'selcction: {selection}')
        print(f'apo: {ap_res} -> holo: {ho_res}')
        cmd.get_wizard().set_mode(rev_amino_acids[ho_res])
        cmd.get_wizard().do_select(selection)
        cmd.get_wizard().apply()
    cmd.set_wizard()
    cmd.save(output_path, 'apo')

In [6]:
from pathlib import Path
import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist

def find_pocket_residues(complex_df: pd.DataFrame, ligand_df: pd.DataFrame, cutoff: float = 10.0):
    ligand_df = ligand_df.reset_index(drop=True)
    complex_df = complex_df.reset_index(drop=True)
    ligand_coords = ligand_df[['x', 'y', 'z']].to_numpy()
    protein_df = complex_df[~complex_df['res_name'].isin(ligand_df['res_name'].unique())].copy()
    protein_df = protein_df.reset_index(drop=True)
    protein_coords = protein_df[['x', 'y', 'z']].to_numpy()
    
    # calculate distance between ligand and protein
    dist_matrix = cdist(ligand_coords, protein_coords)
    
    # extract indices of atoms within cutoff
    close_idx = np.where(dist_matrix <= cutoff)[1]
    close_atoms = protein_df.iloc[close_idx]

    # extract pocket residues
    pocket_residues = {
        (row['chain_id'], row['res_seq'], row['i_code'])
        for _, row in close_atoms.iterrows()
    }
    pocket_df = find_atom(complex_df, pocket_residues)
    return pocket_df


def write_pocket_to_mol2(pocket_df: pd.DataFrame, output_path: str):
    pocket_df = pocket_df.reset_index(drop=True)
    pocket_df['atom_id'] = np.arange(1, len(pocket_df) + 1)
    pocket_df[['x', 'y', 'z', 'charge']] = pocket_df[['x', 'y', 'z', 'charge']].astype(float)

    with open(output_path, 'w') as f:
        # MOLECULE section
        f.write("@<TRIPOS>MOLECULE\n")
        f.write("POCKET\n")
        f.write(f"{len(pocket_df)} 0 1\n")
        f.write("PROTEIN\n")
        f.write("USER_CHARGES\n\n")

        # ATOM section
        f.write("@<TRIPOS>ATOM\n")
        for _, row in pocket_df.iterrows():
            f.write(
                f"{row['atom_id']:>7} {row['atom_name']:<8} "
                f"{row['x']:>10.4f} {row['y']:>10.4f} {row['z']:>10.4f} "
                f"{row['atom_type']:<6} {row['res_seq']:>5} {row['res_name']:<8} "
                f"{row.get('charge', 0.0):>9.4f}\n"
            )

        # BOND section (empty)
        f.write("@<TRIPOS>BOND\n")


def read_csar(mol2_path):
    mol2_path = Path(mol2_path)
    output_path = mol2_path.parent / f'{mol2_path.parent.stem}_ligand.mol2'

    with open(mol2_path) as f:
        lines = f.readlines()

    atom_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>ATOM'))
    bond_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>BOND'))
    chain_start = next(i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>SUBSTRUCTURE'))
    set_start = next((i for i, l in enumerate(lines) if l.startswith('@<TRIPOS>SET')), len(lines))

    atom_lines = lines[atom_start+1:bond_start]
    bond_lines = lines[bond_start+1:chain_start]
    chain_lines = lines[chain_start+1:set_start]

    # --- parse ligand definition ---
    ligand_line_idx = next(i for i, l in enumerate(lines) if l.strip().startswith("LIGAND"))

    ligand_atom_ids = []
    i = ligand_line_idx + 1
    while i < len(lines):
        line = lines[i].strip()
        if not line or not line.split()[0].isdigit():
            break
        ligand_atom_ids.extend(int(token) for token in line.split() if token.isdigit())
        i += 1

    ligand_atom_ids_set = set(ligand_atom_ids[1:])
    ligand_atom_lines = [l for l in atom_lines if int(l.split()[0]) in ligand_atom_ids_set]
    ligand_atom_id_map = {int(l.split()[0]): idx+1 for idx, l in enumerate(ligand_atom_lines)}

    # --- extract ligand bond ---
    ligand_bond_lines = []
    bond_num = 1
    for l in bond_lines:
        fields = l.split()
        if len(fields) < 4: # skip protein bonds
            continue
        a1, a2 = int(fields[1]), int(fields[2])
        if a1 in ligand_atom_ids_set and a2 in ligand_atom_ids_set:
            new_a1 = ligand_atom_id_map[a1]
            new_a2 = ligand_atom_id_map[a2]
            new_line = f"{bond_num:>6} {new_a1:>5} {new_a2:>5} {fields[3]}"
            if len(fields) > 4:
                new_line += " " + " ".join(fields[4:])
            ligand_bond_lines.append(new_line + "\n")
            bond_num += 1

    # --- Save ligand mol2 ---
    with open(output_path, 'w') as f:
        f.write("@<TRIPOS>MOLECULE\n")
        f.write(f"{mol2_path.parent.stem}\n")
        f.write(f"{len(ligand_atom_lines)} {len(ligand_bond_lines)} 1\n")
        f.write("SMALL\n")
        f.write("USER_CHARGES\n\n")

        f.write("@<TRIPOS>ATOM\n")
        for idx, line in enumerate(ligand_atom_lines, 1):
            tokens = line.split()
            atom_id = idx
            atom_name = tokens[1]
            x, y, z = map(float, tokens[2:5])
            atom_type = tokens[5]
            resid = tokens[6]
            resname = tokens[7]
            charge = float(tokens[8]) if len(tokens) > 8 else 0.0

            formatted = f"{atom_id:>7} {atom_name:<8} {x:>10.4f} {y:>10.4f} {z:>10.4f} "
            formatted += f"{atom_type:<6} {resid:>5} {resname:<8} {charge:>9.4f}\n"
            f.write(formatted)

        f.write("@<TRIPOS>BOND\n")
        for line in ligand_bond_lines:
            f.write(line)
    
    # print(f"Ligand saved: {output_path}")
    complex = to_pdb_format(atom_lines, bond_lines, chain_lines)
    ligand =  ligand_reader(ligand_atom_lines)
    pocket = find_pocket_residues(complex, ligand)
    write_pocket_to_mol2(pocket, output_path.parent / 'pocket_atoms.mol2')
    print('Pocket saved: ', output_path.parent / 'pocket_atoms.mol2')
    return pocket

# pk = read_csar(holo)

## APO Align

In [7]:
import pickle
from pathlib import Path
import logging

# 로그 설정
out = Path('data/raw/')
out.mkdir(parents=True, exist_ok=True)
log_path = out / "processing.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[
        logging.FileHandler(log_path, mode='w'),
        logging.StreamHandler()
    ]
)

In [None]:
no_pdbbind = ['1mh5', '1d6v', '1ct8', '1i7z', '1a4k', '1a0q', '1i6v', '1c12', '2qhr', '4nyi', '4nyj', '4nym', '2hrp', '3eql', '4kmu', '4kn4', '4kn7', '2mpa', '1kcs', '1zyr']
no_csar = ['1gz9', '2v7t', '2j4k', '2v7u', '1swk', '2vhw', '1gzc', '1bcj', '2hr6']
no_rdkit = ['5tmp', '1ksn', '4kcx', '4lhm', '4xtw', '4xtx', '1sl3', '1mue', '1a7x', '3zjt', '4ob0', '4ob1', '4ob2', '4y63', '2fm5', '2fou', '2fov', '2foy', '3bho', '3v9b', '1nu1', '4u5t', '4i60', '3whw', '3cst', '3fxz', '3fy0', '4ie2', '4ie3', '3e6k', '3vjs', '3vjt', '2vr0', '3bwf', '3zju', '4lv1', '4wkv', '4wku', '3zp9', '4hww', '2pll', '2q2n', '3kck', '2z3z', '4l6q', '1qpf', '3e9b', '3egk', '4inr', '4inu', '4ixu', '4ixv', '1h07', '2ci9', '3wax', '3wd2', '4hze', '4i06', '4wvs', '4no1', '2eep', '4wkt', '1epq', '4rlp', '4hxq', '4bcb']
no_match = no_pdbbind + no_csar + no_rdkit # 20 + 9 + 66
print(len(no_match)) # no match pocket

# Final

In [None]:

## 72mins
final_data = []
import pandas as pd
from rdkit import Chem
from Bio import pairwise2
from Bio.pairwise2 import format_alignment

def sequence_reader(df):
    atom_end = df[df['record'] == 'ATOM'].index[-1]
    main_structure = df[:atom_end]
    read, ridx, r_ridx = read_residue(df[:atom_end][['res_name', 'chain_id', 'res_seq', 'i_code']].values)
    return main_structure, read, ridx, r_ridx


def find_root_atom(df):
    return df[df['atom_name'] == 'CA'][['chain_id', 'res_seq', 'i_code']].values


def find_atom(df, keys):
    result = []
    for k in keys:
        result.append(df[(df['chain_id'] == k[0]) & (df['res_seq'] == k[1]) & (df['i_code'] == k[2])])
    result = pd.concat(result).reset_index(drop=True)
    return result


count = 0
max_mutation = 0

rmsd_results = {}
apo_path = Path('data/raw/afdb')
holo_path = Path('data/raw/pdb')

# output path
apo_align_path = Path('data/raw/af_align')
apo_align_path.mkdir(exist_ok=True, parents=True)

mutant_path = Path('data/raw/mutant')
mutant_path.mkdir(exist_ok=True, parents=True)

a2h_path = Path('data/preprocessed/a2h')
a2h_path.mkdir(exist_ok=True, parents=True)

# main loop
num = 0
for idx, row in data.iterrows():
    pdb = row['PDB']
    uni = row['UniProt']
    st = row['SET']

    if pdb in no_match:
        continue

    if st == 'PDB':
        continue
    
    elif st != 'CSAR':
        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"    
        holo = holo_path / st / pdb / f"{pdb}_protein.pdb"
        pocket = holo_path / st / pdb / f"{pdb}_pocket.pdb"
        ligand = holo_path / st / pdb / f"{pdb}_ligand.mol2"
        
        # read pdb format
        ap = read_pdb(apo)
        ho = read_pdb(holo)
        pk = read_pdb(pocket)
        
        # read ligand
        li = Chem.MolFromMol2File(ligand)
        smi = Chem.CanonSmiles(Chem.MolToSmiles(li))
        row['Ligand'] = smi
        row['Ligand_Len'] = len(smi)
        final_data.append(row.to_dict())
        
        # read pocket and parse chain_id from holo
        ho_keys = parse_residue_and_align(ho[['res_seq', 'chain_id']].values)
        pk['c_key'] = parse_residue_and_align(pk['res_seq'].values, chain=False)
        pk['chain_id'] = pk['c_key'].map(ho_keys)
        pk = pk.drop(columns=['c_key']).reset_index(drop=True)
        pk.loc[pk['res_name'] == 'HOH', 'chain_id'] = ' '

    else: # CSAR
        
        # path
        apo = apo_path / f"AF-{uni}-F1-model_v4.pdb"
        holo = [f for f in(holo_path / st / pdb).glob('*complex.mol2')][0]
        
        # read for pdb format
        ap = read_pdb(apo)
        ho = read_mol2(holo)
        pk = read_csar(holo)
        
        # fix csar sequence
        pk_seq = sequence_reader(pk)[1]
        try:
            li = Chem.MolFromMol2File(holo.parent / f'{pdb}_ligand.mol2')
            smi = Chem.MolToSmiles(li)
        except:
            li = Chem.MolFromMol2File(holo.parent / f'{pdb}_ligand_rm_mg.mol2')
            smi = Chem.MolToSmiles(li)
        smi = Chem.CanonSmiles(smi)
        row['Pocket'] = pk_seq
        row['Pocket_Len'] = len(pk_seq)
        row['Ligand'] = smi
        row['Ligand_Len'] = len(smi)
        final_data.append(row.to_dict())
        
        # remove ligand
        ho = ho[ho['res_name'] != 'INH'].reset_index(drop=True)

    num += 1
    if num % 1000 == 0:
        print(num)

    logging.info(f"\n[{idx+1}/{len(data)}] Processing code: {pdb}")

    # # common process
    # extract amino acid sequence
    ho_main, ho_read, ho_ridx, ho_r_ridx = sequence_reader(ho)
    pk_main, pk_read, pk_ridx, pk_r_ridx = sequence_reader(pk)
    ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)
    # print(len(ho_ridx), ho_read)
    # print(len(pk_ridx), pk_read)
    # print(len(ap_ridx), ap_read)

    # search pocket in holo
    pk_ca = find_root_atom(pk_main)
    ho_pk = find_atom(ho_main, pk_ca)
    ho_ca = find_root_atom(ho_pk)
    ho_pk_idx = sorted([ho_ridx[(idx[0], idx[1], idx[2])] for idx in ho_ca])
    
    # extract pocket amino acids in holo (save position)
    ho_pk_read = ho_read[ho_pk_idx[0]:ho_pk_idx[-1]+1]
    ho_pk_res_dict = {ho_pk_idx[0] + i:i  for i in range(len(ho_pk_read))}
    ho_pk_res_idx = [ho_pk_res_dict[i] for i in ho_pk_idx] # for mutation position
    
    # sequence alignment (score params: match, mismatch, gap open, gap extension)
    alignments = pairwise2.align.localms(ho_pk_read, ap_read, 2, -1, -5, -1)
    ho_align, ap_align = alignments[0][:2]

    # mapping position & check mutation
    mapping_align = {}
    align_position = {}
    mutations = []
    st_pos = ho_pos = ap_pos = -1

    for a, b in zip(ho_align, ap_align):
        if a != '-':
            ho_pos += 1
            st_pos += 1
        if b != '-':
            ap_pos += 1
            st_pos += 1
        if a != '-' and b != '-':
            st_pos += 1
            align_position[st_pos] = ho_pos
            mapping_align[ho_pos] = ap_pos
            if a != b and ho_pos in ho_pk_res_idx:
                mutations.append((ap_pos, b, a))
                
    # check mutation
    if mutations:
        max_mutation = max(max_mutation, len(mutations))

        transform = []
        for mut in mutations:
            ap_pk_idx, ap_res, ho_res = mut
            ap_chain, ap_resid, _ = ap_r_ridx[ap_pk_idx]
            transform.append((ap_chain, ap_resid, ap_res, ho_res))
        
        mut_out = mutant_path / f"AF-{uni}-mut-{pdb}.pdb"    
        point_mutation(apo, mut_out, transform)

        # reload
        apo = mut_out
        ap = read_pdb(apo)
        ap_main, ap_read, ap_ridx, ap_r_ridx = sequence_reader(ap)
    
    # find pocket in apo (ho_pk_idx -> ho_pk_res_dict -> mapping align)
    ap_ca = []
    rm_ca = []
    for i in ho_pk_idx:
        hp_key = ho_pk_res_dict[i]  # atom within pocket
        if hp_key in mapping_align.keys(): # atom within mapping & pocket
            ap_ca.append(ap_r_ridx[mapping_align[hp_key]]) # find 
        else:
            rm_ca.append(i)
    if rm_ca: # select atom atoms only in apo
        ho_pk = find_atom(ho_main, [ho_r_ridx[i] for i in ho_pk_idx if not i in rm_ca])
    ap_pk = find_atom(ap_main, ap_ca) # find pocket in apo

    # superimpose
    align_out = apo_align_path / f"AF-{uni}-align-{pdb}.pdb"
    rmsd = align_and_save(apo, holo, align_out)
    rmsd_results[pdb] = rmsd
    
    # mutated & aligned apo structure
    aa = read_pdb(align_out)
    aa_main, aa_read, aa_ridx, aa_r_ridx = sequence_reader(aa)
    aa_pk = find_atom(aa_main, ap_ca)
    
    # save
    a2h_out = a2h_path / f"{pdb}_a2h.pkl"
    a2h = {'APO': aa_pk, 'HOLO': ho_pk}
    with open(a2h_out, 'wb') as f:
        pd.to_pickle(a2h, f)

print('done')

In [10]:
# from rdkit import Chem, RDLogger
# RDLogger.DisableLog('rdApp.*')

# for _, row in data.iterrows():
#     pdb = row['PDB']
#     st = row['SET']
    
#     if pdb in no_match:
#         continue
    
#     if st == 'PDB':
#         continue
#     path = f'data/raw/pdb/{st}/{pdb}/{pdb}_ligand.mol2'
#     mol = Chem.MolFromMol2File(path)
#     if mol:
#         smi = Chem.MolToSmiles(mol)
#         smi = Chem.CanonSmiles(smi)
#     else:
#         print(pdb, st, smi)

In [11]:
with open('data/raw/rmsd_results.txt', 'w') as f:
    for k, v in rmsd_results.items():
        f.write(f"{k}: {v}\n")

In [None]:
with open(a2h_out, 'rb') as f:
    loaded_bundle = pd.read_pickle(f)
loaded_bundle['APO']

In [None]:
final = pd.DataFrame(final_data)
final.to_csv('data/final.csv', index=False, header=True)
final

In [None]:
final = final[~final['PDB'].isin(no_match)]
final['SET'].value_counts()

In [None]:
final[final['SET'] == 'Refined']['UniProt'].value_counts()

In [None]:
align_pdbs = [f.stem.split('-')[3] for f in apo_align_path.glob('*.pdb')]
len(align_pdbs) # exclude 29 pdbs (no_match)

In [None]:
import json
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)

general = final[final['SET'] == 'General']
refined = final[final['SET'] == 'Refined']

seeds = [42, 100, 123, 456, 789]
for idx, seed in enumerate(seeds):
    # minimum sample per UniProt
    num_uni = refined['UniProt'].nunique()
    samples_per_uni = 1000 // num_uni

    # sampling each UniProt (maximum samples_per_uni)
    sampled = refined.groupby('UniProt', group_keys=False).apply(
        lambda x: x.sample(n=min(samples_per_uni, len(x)), random_state=seed)
    )

    # remaining samples
    remaining = 1000 - len(sampled)
    if remaining > 0:
        additional = refined[~refined.index.isin(sampled.index)].sample(n=remaining, random_state=seed)
        valid = pd.concat([sampled, additional])

    #check 
    valid['UniProt'].value_counts()

    # create training set
    remain_refined = refined[~refined.index.isin(valid.index)]
    train = pd.concat([general, remain_refined]).reset_index(drop=True)

    # extract pdb list
    train_pdb = train['PDB'].tolist()
    valid_pdb = valid['PDB'].tolist()
    fold_pdb = {'SEED': seed, 'TRN': train_pdb, 'VAL': valid_pdb}

    # save
    with open(f'cache/fold_{idx+1}.json', 'w') as f:
        json.dump(fold_pdb, f)
    
    # check
    print(f'Fold {idx+1}: Seed {seed}')
    print(len(train), len(valid))
    print()

In [None]:
num_uni

In [None]:
print('done')

In [None]:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
count = 0
box = {}
no_rdkit = []
for _, row in final.iterrows():
    pdb = row['PDB']
    st = row['SET']
    # smi = row['Ligand']
    # mol = Chem.MolFromSmiles(smi)
    # if mol is None:
    #     count += 1
    #     box[st] = box.get(st, 0) + 1
    # else:
    #     smi = Chem.CanonSmiles(Chem.MolToSmiles(mol))
    
    
    path = f'data/raw/pdb/{st}/{pdb}/{pdb}_ligand.mol2'
    mol = Chem.MolFromMol2File(path)
    if mol is None:
        count += 1
        box[st] = box.get(st, 0) + 1
        no_rdkit.append(pdb)
        print(pdb, st, row['Ligand'])
    else:
        smi = Chem.CanonSmiles(Chem.MolToSmiles(mol))            
print(box, count)

In [None]:
ff = final[final['PDB'] != '2qvu'].reset_index(drop=True)
len(ff), len(ff['UniProt'].unique())


In [25]:
ff.to_csv('data/data_info.csv', index=False, header=True)

# Dummy

In [None]:
# search largest vector
import pickle
from pathlib import Path

def find_largest_vector(directory):
    directory = Path(directory)
    max_count = 0
    max_ca = 0
    max_file = None
    max_ca_file = None

    with open(a2h_path / f"{intereset}_a2h.pkl", 'rb') as f:
        lb = pd.read_pickle(f)
        count = len(data['APO'])
        ca = pd.DataFrame(data['APO'])
        ca = ca[ca['atom_name'] == 'CA']
        if count > max_count:
            max_count = count
            max_file = pkl_file
        
        ca_count = len(ca)
        if ca_count > max_ca:
            max_ca = ca_count
            max_ca_file = pkl_file

    print(f"Largest vector: {max_file.name} with {max_count} entries")
    print(f"Largest CA vector: {max_ca_file.name} with {max_ca} entries")

find_largest_vector("data/raw/af_align")

In [None]:
import pickle
from pathlib import Path
import pandas as pd
import numpy as np

def find_files_with_coord_diff(directory, threshold=1e-3):
    directory = Path(directory)
    files_with_diff = []

    for pkl_file in directory.glob("*.pkl"):
        with open(pkl_file, "rb") as f:
            data = pickle.load(f)
        apo = pd.DataFrame(data['APO'])
        holo = pd.DataFrame(data['HOLO'])

        # 기준: res_seq, atom_name이 같은 원자끼리 좌표 비교
        merged = pd.merge(
            apo, holo,
            on=['res_seq', 'atom_name'],
            suffixes=('_apo', '_holo')
        )

        # coord 컬럼에서 좌표 추출
        def extract_coords(row, prefix):
            return np.array(row[f'coord_{prefix}'])

        # 좌표 차이 계산
        merged['coord_diff'] = merged.apply(
            lambda row: np.linalg.norm(
                extract_coords(row, 'apo') - extract_coords(row, 'holo')
            ),
            axis=1
        )

        # 임계값(threshold) 이상 차이가 있는 경우
        if (merged['coord_diff'] > threshold).any():
            files_with_diff.append(pkl_file.name)

    print(f"Files with coordinate differences (>{threshold} Å):")
    for fname in files_with_diff:
        print(fname)

find_files_with_coord_diff("APO_to_Holo", threshold=0.1)  # 0.1Å 이상 차이

In [31]:
for _, row in ff.iterrows():
    if '~' in row['Ligand']:
        print(row['PDB'], row['SET'], row['Ligand'])

3sjt General C[C@]([NH3+])(CCC[CH2]~B(O)(O)O)C(=O)[O-]
3skk General [NH3+][C@@](CCC[CH2]~B(O)(O)O)(C(=O)[O-])C(F)F
4kii General CCCCCCCCCCCC(=O)[N+]12c3cccc[n]3~[Rh+2]1~[n]1ccccc12
1ijr General CC(=O)N[C@@H](Cc1ccc(OCC(=O)[O-])[c](~[P+](=O)([O-])[O-])c1)C(=O)N[C@@H](C)c1ccc(OCC2CCCCC2)c(C(N)=O)c1
4z46 General C1CC[C@H]2[NH2+]~[Pt+]~[NH2+][C@@H]2C1


In [256]:
sse_path = Path('data/preprocessed/sse')
no_sse_pdb = []
for _, row in ff.iterrows():
    pdb = row['PDB']
    sse_global = sse_path / 'global' / f'{pdb}.csv'
    sse_pocket = sse_path / 'pocket' / f'{pdb}.csv'
    if not sse_global.exists() or not sse_pocket.exists():
        no_sse_pdb.append(pdb)
    # if '~' in row['Ligand']:
    #     print(row['PDB'], row['SET'], row['Ligand'])
no_sse_pdb

['1uld', '1q0y', '1xw6']

In [43]:
import shutil
cc = Path('data/raw/copy_csar')
cc.mkdir(parents=True, exist_ok=True)
csar_path = Path('data/raw/pdb/CSAR')
for pdb in csar_path.glob('*'):
    file = [f for f in pdb.glob('*complex.mol2')][0]
    shutil.copy(file, cc / f'{pdb.stem}_protein.mol2')
print('done')

done


In [250]:
def dssp_to_ss8(dssp_path):
    ss8_seq = []
    with open(dssp_path, 'r') as f:
        parsing = False
        for line in f:
            if not parsing:
                if line.startswith("  #  RESIDUE AA"):
                    parsing = True
                continue

            if len(line) < 17:
                print("over")
                continue
            
            ss_char = line[16]
            # print(line.strip())
            if ss_char == ' ':
                ss_char = 'C'  # loop
            ss8_seq.append(ss_char)

    return ''.join(ss8_seq)


In [251]:
tmp = r"data\preprocessed\CSAR_DSSP\1ax1.dssp"
dssp_to_ss8(tmp)


'CEEEEEEESSCCTTCSSEEEEETCEECTTSCEESSCBCTTSPBPSSCEEEEEESSCEECBCTTTCCBCEEEEEEEEECCCCCSSSCCCEEEEEEEECTTCCCCBCGGGTTTBSSSSCCGGGCCEEEEEECSCCTTSCSCSSEEEEEESSSSCSEEEECCCCTTCCEEEEEEEETTTTEEEEEEEETTTTEEEEEEEECCGGGTSCSEEEEEEEEEECSSTTCCCCCEEEEEEEEEEECC'

In [74]:
ss8_out = Path('data/preprocessed/csar.out.ss8')
dssps = Path('data/preprocessed/CSAR_DSSP').glob('*.dssp')

with open(ss8_out, 'w') as f:
    for dssp in dssps:
        ss8 = dssp_to_ss8(dssp)
        f.write(f'>{dssp.stem}\n{ss8}\n')

# SSE

In [82]:
csar_ff = ff[ff['SET'] == 'CSAR'].reset_index(drop=True)
pocs = []
glos = []
for _, row in csar_ff.iterrows():
    pdb = row['PDB']
    poc = row['Pocket']
    glo = row['Global']
    pocs.append({'id': pdb, 'seq': poc})
    glos.append({'id': pdb, 'seq': glo})

pocs = pd.DataFrame(pocs)
glos = pd.DataFrame(glos)
pocs.to_csv('data/preprocessed/csar_pockets.csv', index=False)
glos.to_csv('data/preprocessed/csar_proteins.csv', index=False)


In [155]:
a2h_path = Path('data/preprocessed/a2h')
pocket_inform = {}
for _, row in csar_ff.iterrows():
    pdb = row['PDB']
    with open(a2h_path / f'{pdb}_a2h.pkl', 'rb') as f:
        data = pickle.load(f)['HOLO']
    cas = data[data['atom_name'] == 'CA'][['res_seq', 'chain_id']].values
    pocket_inform[pdb] = [f"{v}-{k}" for k, v in cas]

In [257]:
with open(a2h_path / f'1uld_a2h.pkl', 'rb') as f:
    data = pickle.load(f)['HOLO']
data[data['atom_name'] == 'CA']

Unnamed: 0,record,atom_serial,atom_name,alt_loc,res_name,chain_id,res_seq,i_code,x,y,z,atom_type,occupancy,b_factor,element,charge
1,ATOM,3255,CA,,GLU,B,208,,37.173,13.994,29.573,C.3,1.0,0.0,C,0.0397
10,ATOM,4363,CA,,ASN,B,276,,46.293,22.036,41.599,C.3,1.0,0.0,C,0.0143
18,ATOM,4189,CA,,ARG,B,265,,28.335,17.624,52.664,C.3,1.0,0.0,C,-0.2637
29,ATOM,2941,CA,,VAL,B,188,,41.524,21.656,40.304,C.3,1.0,0.0,C,-0.0875
36,ATOM,3284,CA,,VAL,B,210,,33.537,17.256,32.289,C.3,1.0,0.0,C,-0.0875
43,ATOM,3466,CA,,ALA,B,221,,37.59,8.616,50.321,C.3,1.0,0.0,C,0.0337
48,ATOM,2977,CA,,ASN,B,190,,39.6,22.991,46.478,C.3,1.0,0.0,C,0.0143
56,ATOM,3500,CA,,GLY,B,223,,32.013,10.789,47.963,C.3,1.0,0.0,C,-0.0252
60,ATOM,3319,CA,,VAL,B,212,,33.541,17.011,38.645,C.3,1.0,0.0,C,-0.0875
67,ATOM,3180,CA,,SER,B,203,,37.984,19.221,38.963,C.3,1.0,0.0,C,-0.0249


In [260]:
success = []
csar_pocket_dssp = Path('data/preprocessed/CSAR_DSSP_pocket')
csar_pocket_dssp.mkdir(parents=True, exist_ok=True)
dssps = Path('data/preprocessed/CSAR_DSSP').glob('*.dssp')
for dssp in dssps:
    pdb = dssp.stem
    # if not pdb in pocket_inform:
    #     continue
    if pdb != '1uld':
        continue
    
    box = []
    new_num = 1
    with open(dssp, 'r') as f:
        parsing = False
        for line in f:
            if line.startswith('  #  RESIDUE AA'):
                box.append(line)
                parsing = True
                continue
            if parsing:
                if line[13] == '!':
                    continue
                num = int(line[:5])
                resid = int(line[6:10])
                chain = line[11]
                
                key = f"{chain}-{resid}"
                if key in pocket_inform[pdb]:
                    box.append(f"{new_num:>5}" + line[5:])
                    new_num += 1
                else:
                    if box[-1].startswith('  #  RESIDUE AA'):
                        continue
                    if box and box[-1][13] != '!':
                        box.append(f"{new_num:>5}{'!':>9}\n")
                        new_num += 1
            else:
                box.append(line)
        
        if not box[-1].startswith('  #  RESIDUE AA'):
            box = box[:-1]
            with open(csar_pocket_dssp / f'{pdb}.dssp', 'w') as f:
                for line in box:
                    f.write(line)
            success.append(pdb)
print('done')
print(len(success))

['==== Secondary Structure Definition by the program DSSP, NKI version 4.5.1                         ==== DATE=2025-05-13        .\n', 'REFERENCE W. KABSCH AND C.SANDER, BIOPOLYMERS 22 (1983) 2577-2637                                                              .\n', 'HEADER    SUGAR BINDING PROTEIN                   12-SEP-03   1ULD                                                             .\n', 'COMPND    MOL_ID: 1; MOLECULE: galectin-2; CHAIN: A, B, C, D; SYNONYM: CGL2; ENGINEERED: YES                                   .\n', 'SOURCE    MOL_ID: 1; GENE: cgl2; ORGANISM_SCIENTIFIC: Coprinopsis cinerea; ORGANISM_TAXID: 5346; EXPRESSION_SYSTEM: Sacchar... .\n', 'AUTHOR    P.J.Walser; P.W.Haebel; M.Kuenzler; U.Kues; M.Aebi; N.Ban                                                            .\n', '  600  4  0  0  0 TOTAL NUMBER OF RESIDUES, NUMBER OF CHAINS, NUMBER OF SS-BRIDGES(TOTAL,INTRACHAIN,INTERCHAIN)                .\n', ' 27366.9   ACCESSIBLE SURFACE OF PROTEIN (ANGSTROM**2)      

IndexError: list index out of range

In [259]:
box

['==== Secondary Structure Definition by the program DSSP, NKI version 4.5.1                         ==== DATE=2025-05-13        .\n',
 'REFERENCE W. KABSCH AND C.SANDER, BIOPOLYMERS 22 (1983) 2577-2637                                                              .\n',
 'HEADER    SUGAR BINDING PROTEIN                   12-SEP-03   1ULD                                                             .\n',
 'COMPND    MOL_ID: 1; MOLECULE: galectin-2; CHAIN: A, B, C, D; SYNONYM: CGL2; ENGINEERED: YES                                   .\n',
 'SOURCE    MOL_ID: 1; GENE: cgl2; ORGANISM_SCIENTIFIC: Coprinopsis cinerea; ORGANISM_TAXID: 5346; EXPRESSION_SYSTEM: Sacchar... .\n',
 'AUTHOR    P.J.Walser; P.W.Haebel; M.Kuenzler; U.Kues; M.Aebi; N.Ban                                                            .\n',
 '  600  4  0  0  0 TOTAL NUMBER OF RESIDUES, NUMBER OF CHAINS, NUMBER OF SS-BRIDGES(TOTAL,INTRACHAIN,INTERCHAIN)                .\n',
 ' 27366.9   ACCESSIBLE SURFACE OF PROTEIN (ANGSTROM**2

In [220]:
import pandas as pd
import numpy as np
import glob
import re
import os
import sys

pred_sse = Path('data/preprocessed/PredSSE_pocket')
pred_sse.mkdir(parents=True, exist_ok=True)

f=open(r"data/preprocessed/csar.out.ss8", "r")   
for line in f.readlines():
    if '>' in line:
        pdbname=line.strip().split('>')[1]    
    else:
        num = -1
        i= -1
        sse =''
        for s in line:
            num +=1 
#             print(num)
            sse= sse+s
        try:
            with open(r'data/preprocessed/CSAR_DSSP_pocket/%s.dssp'%(pdbname), "r") as fg:
                for _ in range(28):
                    next(fg)
                for line1 in fg:
                    if line1[13]=='!':
                        continue
                    else:
                        i +=1
                        if i <= num:
                            line1=line1[:16] + sse[i] + line1[17:]
                            with open(r'data/preprocessed/PredSSE_pocket/%s_pocket.dssp'%(pdbname),"a+") as fw:
                                fw.write(line1)
        except:
            continue
f.close()

In [85]:
import pandas as pd
import numpy as np
import glob
import re
import os
import sys

pred_sse = Path('data/preprocessed/PredSSE')
pred_sse.mkdir(parents=True, exist_ok=True)

f=open(r"data/preprocessed/csar.out.ss8", "r")   
for line in f.readlines():
    if '>' in line:
        pdbname=line.strip().split('>')[1]    
    else:
        num = -1
        i= -1
        sse =''
        for s in line:
            num +=1 
#             print(num)
            sse= sse+s
        with open(r'data/preprocessed/CSAR_DSSP/%s.dssp'%(pdbname), "r") as fg:
            for _ in range(28):
                next(fg)
            for line1 in fg:
                if line1[13]=='!':
                    continue
                else:
                    i +=1
                    if i <= num:
                        line1=line1[:16] + sse[i] + line1[17:]
                        with open(r'data/preprocessed/PredSSE/%s_protein.dssp'%(pdbname),"a+") as fw:
                            fw.write(line1)
f.close()

In [233]:
from tqdm import tqdm
import pandas as pd
import numpy as np
import sys
import glob
import re
import os

columns_g = ['residue','chain','aa','structure']
columns_l = ['residue','chain','aa','structure']
structure = ('B', 'C', 'E', 'G', 'H', 'I', 'S', 'T')
aa = ('G', 'A', 'V', 'L', 'I', 'M', 'F', 'P', 'W', 'S', 'T', 'Y', 'C', 'Q', 'N', 'D', 'E', 'K', 'R', 'H', 'X')

c1 = {
    'non_polar' : ('G','A','V','L','I','M','F','P','W'),
    'polar' : ('S','T','Y','C','Q','N'),
    'acidic' : ('D','E'),
    'basic' : ('K','R','H'),
}

def f1(aa, key):
    if aa=='X':
        return 1/len(c1)
    return 1 if aa in c1[key] else 0


c2 = {
    1 : ('A','G','V'),
    2 : ('I','L','F','P'),
    3 : ('Y','M','T','S'),
    4 : ('H','N','Q','W'),
    5 : ('R','K'),
    6 : ('D','E'),
    7 : ('C',)
}

def f2(aa, key):
    if aa=='X':
        return 1/len(c2)
    return 1 if aa in c2[key] else 0


def idx_df_l_init(df_g):
    chains = df_g.chain.unique()
    
    if len(chains) == 1:
        return lambda row: chains[0]+row.residue
    
    i = -1
    cidx = 0
    
    def idx_df_l(row):
        nonlocal i
        nonlocal df_g

        while True:
            i+=1
            row_g = df_g.iloc[i]
            
            if row_g.residue==row.residue:
                if row.aa != row_g.aa:
                    continue
                if row.chain and row.chain!=row_g.chain:
                    continue
                break
            
        return row_g.chain+row.residue
    return idx_df_l

In [245]:
f_g = glob.glob('data/preprocessed/PredSSE/*')
f_l = glob.glob('data/preprocessed/PredSSE_pocket/*')

demo = set(pd.read_csv('data/preprocessed/csar_proteins.csv').id)
demo = set(i for i in demo if i in success)
dataset = {
    'demo': demo.copy(),
}
columns_s = ['id','seq']
seq_dfs = {
    'demo': pd.DataFrame(columns=columns_s),
}
pocket_dfs = {
    'demo': pd.DataFrame(columns=columns_s),
}
other_seq = set()


for fi_g in tqdm(f_g):
    seq_name = os.path.basename(fi_g)[:4]
    
    flag = True
    phase = None
    for k in dataset:
        if seq_name in dataset[k]:
            dataset[k].remove(seq_name)
            phase = k
            flag = False
            break
    if flag:
        other_seq.add(seq_name)
        continue
    fi_l = os.path.join(os.path.dirname(fi_g),os.path.basename(fi_g)).replace('protein','pocket')
    if not os.path.exists(fi_l):
        print(f'{fi_l} not exists')
        continue
   
    df_g = pd.DataFrame(columns=columns_g)
    with open(fi_g) as f:
        for line in f:
            if line[13]=='!':
                continue
            else:
                row = pd.Series([line[5:11].strip(),line[11:12].strip(),line[12:14].strip(),line[16:17].strip()],
                           index=columns_g)
            df_g = pd.concat([df_g, row.to_frame().T], ignore_index=True)
            # df_g = df_g.append(row, ignore_index=True)
     
    df_g.aa = df_g.aa.apply(lambda x: 'C' if re.match('[A-Z]',x) is None else x)
    df_g.structure = df_g.structure.apply(lambda x: 'C' if re.match('[A-Z]',x) is None else x)

    assert len(df_g.structure.unique())<= 8
    assert len(df_g.aa.unique())<= 21
    
    for key in c1:
        new_feature = df_g.aa.apply(lambda x:f1(x,key))
        df_g = eval(f'df_g.assign({key}=new_feature)')
    
    for key in c2:
        new_feature = df_g.aa.apply(lambda x:f2(x,key))
        df_g = eval(f'df_g.assign(c2_{key}=new_feature)')
    
    for si in structure:
        new_feature = df_g.structure.apply(lambda x:1 if x==si else 0)
        df_g = eval(f'df_g.assign(s2_{si}=new_feature)')
    
    for ai in aa:
        new_feature = df_g.aa.apply(lambda x:1 if x==ai else 0)
        df_g = eval(f'df_g.assign(a_{ai}=new_feature)')
    
    # seq_dfs[phase] = seq_dfs[phase].append(pd.Series([seq_name,''.join(df_g.aa)],index=columns_s),ignore_index=True)
    seq_dfs[phase] = pd.concat([seq_dfs[phase], pd.Series([seq_name,''.join(df_g.aa)],index=columns_s)], ignore_index=True)
    
    df_g = df_g.assign(idx=df_g.chain+df_g.residue)
    
    
    df_l = pd.DataFrame(columns=columns_l)
    with open(fi_l) as f:
        for line in f:
            if line[13]=='!':
                continue
            else:
                row = pd.Series([line[5:11].strip(),line[11:12].strip(),line[12:14].strip(),line[16:17].strip()],
                           index=columns_l)
            
            # df_l = df_l.append(row, ignore_index=True)
            df_l = pd.concat([df_l, row.to_frame().T], ignore_index=True)
    
    df_l.aa = df_l.aa.apply(lambda x: 'C' if re.match('[A-Z]',x) is None else x)
    
    # pocket_dfs[phase] = pocket_dfs[phase].append(pd.Series([seq_name,''.join(df_l.aa)],index=columns_s),ignore_index=True)
    pocket_dfs[phase] = pd.concat([pocket_dfs[phase], pd.Series([seq_name,''.join(df_l.aa)],index=columns_s)], ignore_index=True)
    
    try:
        df_l = df_l.assign(idx=df_l.apply(idx_df_l_init(df_g),axis=1))
    except:
        print(f"{seq_name} error: {sys.exc_info()[0]}")
        continue
    
    df_g.drop('aa',axis=1,inplace=True)
    df_g.drop('structure',axis=1,inplace=True)
    df_g.drop('residue',axis=1,inplace=True)
    df_g.drop('chain',axis=1,inplace=True)
    
    df_g.to_csv(f'{phase}/global/{seq_name}.csv')
    
    df_l.drop('aa',axis=1,inplace=True)
    df_l.drop('structure',axis=1,inplace=True)
    df_l.drop('residue',axis=1,inplace=True)
    df_l.drop('chain',axis=1,inplace=True)
    
    df_l = df_l.merge(df_g)
    
    df_l.to_csv(f'{phase}/pocket/{seq_name}.csv')
    
    del df_g
    del df_l

for k in pocket_dfs:
    pocket_dfs[k].to_csv(f'{k}_pocket_.csv')

100%|██████████| 125/125 [00:02<00:00, 58.16it/s]


In [248]:
lens = []
for fi in Path('demo/pocket').glob('*.csv'):
    df = pd.read_csv(fi)
    lens.append(len(df))
print(sorted(lens, reverse=True))

[102, 92, 85, 83, 79, 79, 78, 77, 77, 75, 74, 73, 72, 70, 70, 68, 68, 67, 67, 67, 66, 65, 65, 65, 62, 61, 60, 59, 58, 57, 57, 49, 49, 48, 47, 45, 45, 44, 42, 41, 39, 38, 37, 36, 36, 20, 19, 13, 11, 6, 1]


In [249]:
len([102, 92, 85, 83, 79, 79, 78, 77, 77, 75, 74, 73, 72, 70, 70, 68, 68, 67, 67, 67, 66, 65, 65, 65])

24

In [254]:
lens = []
for fi in Path('demo/pocket').glob('*.csv'):
    if fi.stem.split('_')[0] in no_sse_pdb:
        print(fi.stem.split('_')[0])
        df = pd.read_csv(fi)
        lens.append(len(df))
print(sorted(lens, reverse=True))

1q6e
1q6g
1s7y
1txf
1y93
1zhx
2dm5
4ubp
[102, 78, 77, 75, 68, 67, 20, 11]


In [255]:
lens = []
for fi in Path('data/preprocessed/sse/pocket').glob('*.csv'):
    df = pd.read_csv(fi)
    lens.append(len(df))
print(sorted(lens, reverse=True))

[137, 125, 124, 124, 123, 121, 120, 119, 119, 119, 118, 116, 116, 116, 116, 115, 115, 115, 114, 114, 114, 114, 114, 113, 112, 112, 112, 112, 111, 111, 111, 111, 111, 110, 110, 110, 109, 109, 109, 108, 108, 107, 107, 107, 107, 106, 106, 106, 106, 106, 106, 106, 105, 104, 103, 103, 103, 101, 100, 100, 99, 99, 98, 98, 98, 98, 98, 96, 95, 95, 95, 94, 94, 94, 94, 94, 94, 94, 93, 93, 93, 93, 92, 92, 92, 92, 92, 92, 91, 91, 91, 90, 90, 90, 90, 90, 90, 90, 90, 89, 89, 89, 89, 89, 89, 89, 88, 88, 88, 88, 88, 88, 88, 87, 87, 87, 87, 87, 87, 87, 87, 86, 86, 86, 86, 86, 86, 86, 86, 86, 85, 85, 85, 85, 85, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 84, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 83, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 82, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 81, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,

In [263]:
fff = ff[~ff['PDB'].isin(no_sse_pdb)].reset_index(drop=True)
fff['SET'].value_counts()


SET
General    7412
Refined    2732
CORE        229
CSAR         62
Name: count, dtype: int64

In [264]:
fff.to_csv('data/new_data_info.csv', index=False)

In [271]:
aims = ['rmse', 'loss']
models = ['DeepDTA', 'DeepDTAF', 'CAPLA']
loss = ['mse_mean', 'mse_sum']
apex = [True, False]
scheduler = [True, False]
epoch = [20, 50]

for aim in aims:
    for model in models:
        for l in loss:
            for a in apex:
                for s in scheduler:
                    for e in epoch:
                        if apex:
                            if scheduler:
                                command = f"python main.py --model {model} --project {model}_{aim}_{l}_ap-{a}_sh-{s}_epoch{e} --n_epoch {e} --loss {l} --apex --scheduler --aim {aim}"
                            else:
                                command = f"python main.py --model {model} --project {model}_{aim}_{l}_ap-{a}_sh-{s}_epoch{e} --n_epoch {e} --loss {l} --apex --aim {aim}"
                        else:
                            if scheduler:
                                command = f"python main.py --model {model} --project {model}_{aim}_{l}_ap-{a}_sh-{s}_epoch{e} --n_epoch {e} --loss {l} --scheduler --aim {aim}"
                            else:
                                command = f"python main.py --model {model} --project {model}_{aim}_{l}_ap-{a}_sh-{s}_epoch{e} --n_epoch {e} --loss {l} --aim {aim}"
                        print(command)


python main.py --model DeepDTA --project DeepDTA_rmse_mse_mean_ap-True_sh-True_epoch20 --n_epoch 20 --loss mse_mean --apex --scheduler --aim rmse
python main.py --model DeepDTA --project DeepDTA_rmse_mse_mean_ap-True_sh-True_epoch50 --n_epoch 50 --loss mse_mean --apex --scheduler --aim rmse
python main.py --model DeepDTA --project DeepDTA_rmse_mse_mean_ap-True_sh-False_epoch20 --n_epoch 20 --loss mse_mean --apex --scheduler --aim rmse
python main.py --model DeepDTA --project DeepDTA_rmse_mse_mean_ap-True_sh-False_epoch50 --n_epoch 50 --loss mse_mean --apex --scheduler --aim rmse
python main.py --model DeepDTA --project DeepDTA_rmse_mse_mean_ap-False_sh-True_epoch20 --n_epoch 20 --loss mse_mean --apex --scheduler --aim rmse
python main.py --model DeepDTA --project DeepDTA_rmse_mse_mean_ap-False_sh-True_epoch50 --n_epoch 50 --loss mse_mean --apex --scheduler --aim rmse
python main.py --model DeepDTA --project DeepDTA_rmse_mse_mean_ap-False_sh-False_epoch20 --n_epoch 20 --loss mse_mean --

In [285]:
models = ['DeepDTA', 'DeepDTAF', 'CAPLA']
aims = ['rmse', 'mse']
reduction = ['mse_mean', 'mse_sum']
epoch = [20, 40, 60, 80, 100]

for m in models:
    for a in aims:
        for r in reduction:
            for e in epoch:
                command = f"python main.py --model {m} --n_epoch {e} --loss {r} --aim {a} --project {m}_{a}_{r.split('_')[1]}_epoch{e}"
                print(command)


python main.py --model DeepDTA --n_epoch 20 --loss mse_mean --aim rmse --project DeepDTA_rmse_mean_epoch20
python main.py --model DeepDTA --n_epoch 40 --loss mse_mean --aim rmse --project DeepDTA_rmse_mean_epoch40
python main.py --model DeepDTA --n_epoch 60 --loss mse_mean --aim rmse --project DeepDTA_rmse_mean_epoch60
python main.py --model DeepDTA --n_epoch 80 --loss mse_mean --aim rmse --project DeepDTA_rmse_mean_epoch80
python main.py --model DeepDTA --n_epoch 100 --loss mse_mean --aim rmse --project DeepDTA_rmse_mean_epoch100
python main.py --model DeepDTA --n_epoch 20 --loss mse_sum --aim rmse --project DeepDTA_rmse_sum_epoch20
python main.py --model DeepDTA --n_epoch 40 --loss mse_sum --aim rmse --project DeepDTA_rmse_sum_epoch40
python main.py --model DeepDTA --n_epoch 60 --loss mse_sum --aim rmse --project DeepDTA_rmse_sum_epoch60
python main.py --model DeepDTA --n_epoch 80 --loss mse_sum --aim rmse --project DeepDTA_rmse_sum_epoch80
python main.py --model DeepDTA --n_epoch 10

In [284]:
import json
summaries = []
results = [fd for fd in Path('logs').glob('*') if fd.stem != 'DeepDTAF_log']
for r in results:
    name = r.stem 
    md = name.split('_')[0]
    aim = name.split('_')[1]
    mse_reduction = name.split('_')[3]
    apex = name.split('_')[4].split('-')[1]
    scheduler = name.split('_')[5].split('-')[1]
    epoch = name.split('_')[6][5:]

    summary = {'model': md, 
               'aim': aim, 
               'mse_reduction': mse_reduction, 
               'apex': apex, 'scheduler': scheduler, 
               'epoch': epoch}
    
    file = r / 'average_test_metrics.json'
    if not file.exists():
        continue

    with open(file, 'r') as f:
        tst_result = json.load(f)
    for dt in ['CORE', 'CSAR']:
        rmse = tst_result[dt]['rmse']
        mae = tst_result[dt]['mae']
        pcc = tst_result[dt]['pcc']
        sd = tst_result[dt]['sd']
        ci = tst_result[dt]['ci']
        
        summary[f'{dt}_rmse'] = rmse
        summary[f'{dt}_mae'] = mae
        summary[f'{dt}_pcc'] = pcc
        summary[f'{dt}_sd'] = sd
        summary[f'{dt}_ci'] = ci
    
    summaries.append(summary)
    
summaries = pd.DataFrame(summaries)
summaries.to_csv('summaries.csv', index=False)
summaries


Unnamed: 0,model,aim,mse_reduction,apex,scheduler,epoch,CORE_rmse,CORE_mae,CORE_pcc,CORE_sd,CORE_ci,CSAR_rmse,CSAR_mae,CSAR_pcc,CSAR_sd,CSAR_ci
0,CAPLA,rmse,mean,True,False,20,1.510282,1.226740,0.708645,1.487428,0.756714,2.222766,1.721082,0.571250,2.165147,0.704666
1,CAPLA,rmse,mean,True,False,50,1.527745,1.223725,0.700007,1.505596,0.752728,2.144618,1.693175,0.602664,2.107476,0.710923
2,CAPLA,rmse,mean,True,True,20,1.510282,1.226740,0.708645,1.487428,0.756714,2.222766,1.721082,0.571250,2.165147,0.704666
3,CAPLA,rmse,mean,True,True,50,1.527745,1.223725,0.700007,1.505596,0.752728,2.144618,1.693175,0.602664,2.107476,0.710923
4,DeepDTAF,loss,mean,False,False,20,1.537258,1.238472,0.700638,1.503907,0.750762,2.212589,1.779897,0.537089,2.224508,0.702969
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
63,DeepDTA,rmse,sum,False,True,50,1.483163,1.210515,0.735650,1.426866,0.767094,1.831741,1.433866,0.747953,1.752715,0.786373
64,DeepDTA,rmse,sum,True,False,20,1.441149,1.167233,0.739143,1.413604,0.771426,1.741605,1.343187,0.759746,1.715038,0.789873
65,DeepDTA,rmse,sum,True,False,50,1.483163,1.210515,0.735650,1.426866,0.767094,1.831741,1.433866,0.747953,1.752715,0.786373
66,DeepDTA,rmse,sum,True,True,20,1.441149,1.167233,0.739143,1.413604,0.771426,1.741605,1.343187,0.759746,1.715038,0.789873


In [281]:
base_df = pd.read_csv('data/deepdtaf_data.csv')
base_df[base_df['PDB'] == '5lz4'] # TST71은 데이터 제공 X

Unnamed: 0.1,Unnamed: 0,PDB,Ligand,Ligand_Len,Pocket,Pocket_Len,Global,Global_Len,Set,Affinity
13322,13322,5lz4,CC1(C)CN(CCOc2ccc(c(C#N)c2)F)C(=O)C1,36,WIYLFLSHGLGAFTLYSSSRRHSFGGATLDAWMFPFYNHQFADFTFAL,48,KIPRGNGPYSVGCTDLMFDHTNKGTFLRLYYPSQDNDRLDTLWIPN...,371,TST71,5.7
