CSR SALAD Version 10.1
8 Oct 2025
by Jackson Cahn

Adapted from:
CSR SALAD Version 8
11 Dec 2015
by Jackson Cahn

Copyright California Institute of Technology
All rights reserved

To use: change the values in the first cell, then run all cells. Results will appear at the bottom.

In [None]:
# ---- User settings: edit these values, then run every cell ----

# Path of the structure file to analyse (read with Bio.PDB's PDBParser in readfile).
infile = './my_file.pdb'

max_size = 400 #maximum size of designed library, should be roughly 0.5x your screening capacity

# The three ex_* switches control which residue categories are dropped from the design space.
ex_motif = True #Turn on to exclude glycine-rich motif from design space. Note that even when not excluded it is deprioritized.
ex_diphos = True #Same, for other diphosphate-contacting residues
ex_periph = False #Same, for more peripheral residues

# NOTE(review): `verbose` is not referenced in the part of the file reviewed here;
# presumably it toggles printing of the accumulated log -- confirm.
verbose = False

In [2]:
#imports

import sys,os
from glob import glob
from Bio.PDB import *
from operator import itemgetter
from numpy import array,matrix, dot,reshape
from numpy.linalg import norm
from math import sqrt,acos,degrees
from itertools import product
import pandas as pd
from IPython.display import display

In [3]:
#Dictionaries

# Maps a three-letter residue name to the atom names used to locate its side chain.
# Besides the 20 standard residues, the table has entries for CSO and MSE.
# GLY has no side chain, so its CA atom is used instead.
sidechaincenteratoms = { #the pseudocenter of each sidechain is defined as the mean xyz of these atoms
    'GLY': ['CA'],
    'ALA': ['CB'],
    'VAL': ['CG1', 'CG2'],
    'ILE': ['CG1','CG2','CD1'],
    'LEU': ['CD1', 'CD2', 'CG', 'CB'],
    'SER': ['OG'],
    'THR': ['OG1', 'CG2'],
    'ASP': ['OD1', 'OD2'],
    'ASN': ['OD1', 'ND2'],
    'GLU': ['OE1', 'OE2','CG'],
    'GLN': ['OE1', 'NE2'],
    'LYS': ['NZ', 'CE'],
    'ARG': ['NE', 'NH1', 'NH2'],
    'CYS': ['SG'],
    'CSO': ['SG', 'OD'],
    'MET': ['CE', 'SD', 'CG'],
    'MSE': ['SE'],
    'PHE': ['CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
    'TYR': ['CE1', 'CE2', 'CZ', 'OH'],
    'TRP': ['CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3'],
    'HIS': ['CG', 'ND1', 'CD2', 'CE1', 'NE2'],
    'PRO': ['CB', 'CG', 'CD'],
}

# Degenerate-codon options for NADP->NAD libraries.
# Key:   (three-letter residue name, residue class).
# Value: list of [encoded amino acids as one-letter codes, degenerate codon] pairs.
# A codon of '---' accompanies entries that list only the residue's own letter
# (presumably "keep the wild-type residue" -- confirm against make_library).
# '*' in the amino-acid string marks a stop codon.
Plibraries = {
    ('PHE', "Edge") : [['DFVY','KWC'],['DEFLVY*','KWK']],
    ('PHE', "Bidentate") : [['DFVY','KWC'],['DEFLVY*','KWK'],['DEFHLQVY*','BWK']],
    ('PHE', "Face") : [['F','---'],['FY','TWC'],['DFVY','KWC'],['FHLQY*','YWK']],
    ('PHE', "Simple") : [['DFVY','KWC'],['DEFLVY*','KWK']],
    ('LEU', "Edge") : [['ELQV','SWA']],
    ('LEU', "Bidentate") : [['ELQV','SWA'],['DEHLQV','SWK']],
    ('LEU', "Face") : [['FHLY','YWC'],['FILMST','WYK']],
    ('LEU', "Simple") : [['ELQV','SWA'],['DEHLQV','SWK']],
    ('ILE', "Edge") : [['EIKV','RWA']],
    ('ILE', "Bidentate") : [['EIKV','RWA']],
    ('ILE', "Face") : [['FINY','WWC'],['FHILNY','HWC']],
    ('ILE', "Simple") : [['EIKV','RWA']],
    ('MET', "Edge") : [['EKMV','RWG']],
    ('MET', "Bidentate") : [['EKMV','RWG']],
    ('MET', "Face") : [['FILM','WTK'],['FILMPST','HYK']],
    ('MET', "Simple") : [['EKMV','RWG']],
    ('VAL', "Edge") : [['DEV','GWK'],['ELQV','SWA'],['DEIKMNV','RWK']],
    ('VAL', "Bidentate") : [['DEV','GWK']],
    ('VAL', "Face") : [['DHLV','SWC'],['DHILNV','VWC']],
    ('VAL', "Simple") : [['DEV','GWK']],
    ('SER', "Edge") : [['INS','ADC'],['DGINSV','RDC']],
    ('SER', "Bidentate") : [['DGNS','RRC'],['ADGNST','RVC']],
    ('SER', "Face") : [['DGINSV','RDC']],
    ('SER', "Simple") : [['DGNS','RRC'],['DGHILNRSV','VDC']],
    # The first PRO/Edge option previously read "Peripheral" (a class name),
    # not an amino-acid string; restored to the wild-type letter 'P'.
    ('PRO', "Edge") : [['P','---'],['ADHP','SMC']],
    ('PRO', "Bidentate") : [['ADHP','SMC']],
    ('PRO', "Face") : [['ILPT','MYA']],
    ('PRO', "Simple") : [['ADHP','SMC']],
    ('THR', "Edge") : [['ADNT','RMC'],['ADEKNT','RMK']],
    ('THR', "Bidentate") : [['ADNT','RMC'],['ADEKNT','RMK']],
    ('THR', "Face") : [['INT','AHC'],['HNPT','MMC'],['IKMNRST','ANK']],
    ('THR', "Simple") : [['ADNT','RMC'],['ADEKNT','RMK']],
    ('ALA', "Edge") : [['AD','GMC'],['ADNT','RMC']],
    ('ALA', "Bidentate") : [['AS','KCA'],['ADGNST','RVC']],
    ('ALA', "Face") : [['APT','VCA'],['AILPTV','VYA']],
    ('ALA', "Simple") : [['AS','KCA'],['ADGNST','RVC']],
    ('TYR', "Edge") : [['FINY','WWC'],['FHILNY','HWC']],
    ('TYR', "Bidentate") : [['FINY','WWC'],['DFINVY','DWC']],
    ('TYR', "Face") : [['FY','TWC'],['FHLY','YWC'],['ADFINSTVY','DHC']],
    ('TYR', "Simple") : [['DEKNY*','DAK'],['DEFIKLMNVY*','DWK']],
    ('HIS', "Edge") : [['DHN','VAC'],['DEHKNQ','VAK']],
    ('HIS', "Bidentate") : [['DHN','VAC'],['DEHKNQ','VAK']],
    ('HIS', "Face") : [['HNPT','MMC']],
    ('HIS', "Simple") : [['DHN','VAC'],['DEHKNQ','VAK']],
    ('GLN', "Edge") : [['Q','---'],['EQ','SAA']],
    ('GLN', "Bidentate") : [['Q','---'],['EQ','SAA']],
    ('GLN', "Face") : [['LPQ','CHA']],
    ('GLN', "Simple") : [['Q','---'],['EQ','SAA']],
    ('LYS', "Edge") : [['EK','RAA'],['EKQ','VAA'],['EIKV','RWA']],
    ('LYS', "Bidentate") : [['DEGKNRS','RRK']],
    ('LYS', "Face") : [['KLMQ','MWG'],['KLMPQT','MHG']],
    ('LYS', "Simple") : [['DEKN','RAK'],['DEIKMNV','RWK'],['DEFIKLMNVY*','DWK']],
    ('ASP', "Edge") : [['D','---'],['DE','GAK'],['DEHQ','SAK']],
    ('ASP', "Bidentate") : [['D','---'],['DE','GAK'],['DEHQ','SAK']],
    ('ASP', "Face") : [['D','---'],['DE','GAK'],['DEHQ','SAK']],
    ('ASP', "Simple") : [['D','---'],['DE','GAK'],['DEHQ','SAK']],
    ('GLU', "Edge") : [['E','---'],['EKQ','VAA']],
    ('GLU', "Bidentate") : [['E','---'],['EKQ','VAA']],
    ('GLU', "Face") : [['E','---'],['EKQ','VAA']],
    ('GLU', "Simple") : [['E','---'],['EKQ','VAA']],
    ('CYS', "Edge") : [['CDGY','KRC'],['CDFGVY','KDC']],
    ('CYS', "Bidentate") : [['CDGY','KRC']],
    ('CYS', "Face") : [['CFGV','KKC'],['CDFGVY','KDC']],
    ('CYS', "Simple") : [['CDGY','KRC']],
    ('TRP', "Edge") : [['LMRW','WKG'],['LQRW*','YDG']],
    ('TRP', "Bidentate") : [['QRW*','YRG'],['CHQRWY*','YRK']],
    ('TRP', "Face") : [['CFLW','TKK']],
    ('TRP', "Simple") : [['QRW*','YRG'],['CHQRWY*','YRK']],
    ('ARG', "Edge") : [['EGQR','SRA'],['EGLQRV','SDA']],
    ('ARG', "Bidentate") : [['DEGHQR','SRK']],
    ('ARG', "Face") : [['CHPRSY','YVC'],['CHNPRSTY','HVC'],['CFHILNPRSTY','HNC']],
    ('ARG', "Simple") : [['DEGHQR','SRK'],['DEGKNRS','RRK'],['DEGHKNQRS','VRK']],
    ('GLY', "Edge") : [['DG','GRC'],['DGV','GDC']],
    ('GLY', "Bidentate") : [['DG','GRC'],['ADGV','GNC']],
    ('GLY', "Face") : [['DGINSV','RDC']],
    ('GLY', "Simple") : [['DG','GRC'],['ADGV','GNC'],['DEGKNRS','RRK']],
    ('ASN', "Edge") : [['N','---'],['DEKN','RAK']],
    ('ASN', "Bidentate") : [['DN','RAC'],['DEHKNQ','VAK']],
    ('ASN', "Face") : [['FINY','WWC']],
    ('ASN', "Simple") : [['DN','RAC'],['ADNT','RMC'],['ADEKNT','RMK']]
}

# Degenerate-codon options for NAD->NADP libraries.
# Key:   (three-letter residue name, residue class).
# Value: list of [encoded amino acids as one-letter codes, degenerate codon] pairs.
# A codon of '---' accompanies entries that list only the residue's own letter
# (presumably "keep the wild-type residue" -- confirm against make_library).
# '*' in the amino-acid string marks a stop codon.
Nlibraries = {
    ('PHE', "Edge") : [['F','---'],['FIKLMNY*','WWK']],
    ('PHE', "Bidentate") : [['FS','TYC'],['FHLPSY','YHC']],
    ('PHE', "Face") : [['CFLR','YKC'],['CFHLRY','YDC']],
    ('PHE', "Simple") : [['FINSTY','WHC']],
    ('LEU', "Edge") : [['EIKLQV','VWA']],
    ('LEU', "Bidentate") : [['LS','TYA'],['ILST','WYA']],
    ('LEU', "Face") : [['LR','CKA'],['HLR','CDC'],['HLQR','CDK']],
    ('LEU', "Simple") : [['LR','CKA'],['IKLQR','MDA']],
    ('ILE', "Edge") : [['IKMN','AWK']],
    ('ILE', "Bidentate") : [['IKMNRS','ADK']],
    ('ILE', "Face") : [['IR','AKA'],['IKRT','ANA']],
    ('ILE', "Simple") : [['IKMNRS','ADK']],
    ('MET', "Edge") : [['KM','AWG'],['IKMN','AWK']],
    ('MET', "Bidentate") : [['KM','AWG'],['IKMNRS','ADK']],
    ('MET', "Face") : [['MR','AKG']],
    ('MET', "Simple") : [['KM','AWG'],['IKMNRS','ADK']],
    ('VAL', "Edge") : [['AITV','RYA'],['AEIKTV','RHA']],
    ('VAL', "Bidentate") : [['GISV','RKC'],['DGHILNRSV','VDC']],
    ('VAL', "Face") : [['GMRV','RKG'],['AGMRTV','RBG']],
    ('VAL', "Simple") : [['GISV','RKC'],['DGHILNRSV','VDC']],
    ('SER', "Edge") : [['AS','KCA'],['ALSV','KYA']],
    ('SER', "Bidentate") : [['GS','RGC'],['AGST','RSC']],
    ('SER', "Face") : [['GS','RGC'],['AGST','RSC']],
    # The first SER/Simple option previously read "Simple" (a class name),
    # not an amino-acid string; restored to the wild-type letter 'S'.
    ('SER', "Simple") : [['S','---'],['GS','RGC'],['AGST','RSC']],
    ('PRO', "Edge") : [['LP','CYA'],['ALPV','SYA']],
    ('PRO', "Bidentate") : [['PS','YCA'],['PST','HCA'],['APST','NCA']],
    ('PRO', "Face") : [['PRST','MSC']],
    ('PRO', "Simple") : [['PS','YCA'],['PST','HCA'],['APST','NCA']],
    ('THR', "Edge") : [['AT','RCA'],['ADNT','RMC']],
    ('THR', "Bidentate") : [['AST','DCA']],
    ('THR', "Face") : [['AGRT','RSA'],['ADGHNPRST','VVC']],
    ('THR', "Simple") : [['AST','DCA']],
    ('ALA', "Edge") : [['A','---'],['AS','KCA'],['ALSV','KYA']],
    ('ALA', "Bidentate") : [['AGST','RSC']],
    ('ALA', "Face") : [['AGRT','RSA']],
    ('ALA', "Simple") : [['AGST','RSC']],
    ('TYR', "Edge") : [['FIKLMNY*','WWK']],
    ('TYR', "Bidentate") : [['HNY','HAC'],['HKNQY*','HAK']],
    ('TYR', "Face") : [['CHRY','YRC']],
    ('TYR', "Simple") : [['HNY','HAC'],['HKNQY*','HAK']],
    ('HIS', "Edge") : [['HILN','MWC']],
    ('HIS', "Bidentate") : [['HNRS','MRC']],
    ('HIS', "Face") : [['HNRS','MRC']],
    ('HIS', "Simple") : [['HNPT','MMC'],['HNPRST','MVC']],
    ('GLN', "Edge") : [['IKLQ','MWA']],
    ('GLN', "Bidentate") : [['KQR','MRA'],['KPQRT','MVA']],
    ('GLN', "Face") : [['KPQRT','MVA']],
    ('GLN', "Simple") : [['HKNQRS','MRK']],
    ('LYS', "Edge") : [['K','---'],['KN','AAK']],
    ('LYS', "Bidentate") : [['KN','AAK'],['KNT','AMK'],['KNRST','AVK']],
    ('LYS', "Face") : [['KNRS','ARK'],['HKNQRS','MRK']],
    ('LYS', "Simple") : [['KN','AAK'],['KNT','AMK'],['KNRST','AVK']],
    ('ASP', "Edge") : [['ADGNST','RVC'],['ADGINSTV','RNC']],
    ('ASP', "Bidentate") : [['ADNT','RMC'],['ADGNST','RVC']],
    ('ASP', "Face") : [['ADNT','RMC'],['ADHNPT','VMC']],
    ('ASP', "Simple") : [['DGNS','RRC'],['ADGNST','RVC']],
    ('GLU', "Edge") : [['AES*','KMA'],['AELSV*','KHA'],['AEIKLSTV*','DHA']],
    ('GLU', "Bidentate") : [['EGKQR','VRA'],['DEGHKNQRS','VRW']],
    ('GLU', "Face") : [['EGQR','SRA'],['AEGKPQRT','VVA']],
    ('GLU', "Simple") : [['AEKST*','DMA'],['ADEKNSTY*','DMK']],
    ('CYS', "Edge") : [['ACGS','KSC'],['ACFGSV','KBC']],
    ('CYS', "Bidentate") : [['ACGS','KSC']],
    ('CYS', "Face") : [['CHRY','YRC'],['CHNPRSTY','HVC']],
    ('CYS', "Simple") : [['ACGS','KSC']],
    ('TRP', "Edge") : [['LMRW','WKG']],
    ('TRP', "Bidentate") : [['KRSTW*','WVG']],
    ('TRP', "Face") : [['QRW*','YRG'],['CHQRWY*','YRK']],
    ('TRP', "Simple") : [['KRSTW*','WVG']],
    ('ARG', "Edge") : [['R','---'],['LR','CKA'],['CFHLRY','YDC']],
    ('ARG', "Bidentate") : [['KNRS','ARK'],['HKNQRS','MRK']],
    ('ARG', "Face") : [['R','---'],['QR','CRA'],['LPQR','CNA']],
    ('ARG', "Simple") : [['KNRS','ARK'],['HKNQRS','MRK']],
    ('GLY', "Edge") : [['G','---'],['AGV','GBA']],
    ('GLY', "Bidentate") : [['G','---'],['AGST','RSC']],
    ('GLY', "Face") : [['G','---'],['AGRT','RSA']],
    ('GLY', "Simple") : [['G','---'],['AGST','RSC']],
    ('ASN', "Edge") : [['NS','ARC'],['IKMN','AWK']],
    ('ASN', "Bidentate") : [['NS','ARC'],['HKNQRS','MRK']],
    ('ASN', "Face") : [['KNRST','AVK']],
    ('ASN', "Simple") : [['NS','ARC'],['HKNQRS','MRK']],
}

# Degenerate nucleotide letters grouped by how many bases each letter stands for.
# '-' (used in the '---' placeholder codon) is counted as a single base.
degen_nucleotide_sizes = {
    1: ['A', 'G', 'C', 'T', '-'],
    2: ['Y', 'R', 'S', 'W', 'K', 'M'],
    3: ['B', 'D', 'H', 'V'],
    4: ['N'],
}

# Inverted view of the table above: nucleotide letter -> number of bases it covers.
degen_sizes_flat = {
    letter: size
    for size, letters in degen_nucleotide_sizes.items()
    for letter in letters
}

# Three-letter residue name -> one-letter amino-acid code.
# MSE and CSO are mapped to the same letters as MET ('M') and CYS ('C').
_three_letter_to_one = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
        'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N', 
        'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W', 
        'ALA': 'A', 'VAL': 'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M',
        'MSE': 'M', 'CSO': 'C'}

In [4]:
"""Analyze cofactor binding sites."""

def analyze_cofactor_binding(infile):
    """Run the whole binding-site analysis for one structure file.

    Steps: load the structure and cofactors, collect first-shell residues
    around the 2'-moiety atoms, classify them, build the library and gather
    recovery suggestions.

    Reads and appends to the module-level ``LogString``; the module-level
    ``passovers`` list is filled by the helpers called from here. Both must
    exist before this function is called.

    Args:
        infile: path of the structure file handed to ``readfile``.

    Returns:
        ``(library_df, recovery_suggestions, lib[1])`` where ``lib`` is the
        value returned by ``make_library`` (its second element is presumably
        the library size -- confirm against ``make_library``).

    Raises:
        AssertionError: if no 2'-moiety atoms or no classifiable residues
            are found.
    """
    # ns and cof_atoms are published as module globals on purpose:
    # _is_ring_binder (reached through classify) reads both by bare name, so
    # keeping them local here would make that helper fail with NameError.
    global LogString, passovers, ns, cof_atoms

    print("Analyzing cofactor binding site...")

    # Load and prepare structure
    structure, cof, is_nad, ns, motif = readfile(infile)
    phos_atoms, cof_atoms = defineatoms(structure, cof, is_nad)
    assert len(phos_atoms) != 0, "No phosphate atoms found"

    # Find first shell residues
    contactlist, recon = get_first_shell(phos_atoms, ns, motif, cof_atoms, 4.2, is_nad)

    # Expand search radius if needed
    if len(contactlist) < 3 and not ex_periph:
        LogString += f'Only {len(contactlist)} residues found, expanding search radius to 5.2.\n'
        tmp_contactlist, tmp_recon = get_first_shell(phos_atoms, ns, motif, cof_atoms, 5.2, is_nad)

        # contactlist does not change inside the loop, so the folded residue
        # numbers already present are computed once.
        known_ids = {s.get_id()[1] % 500 for s in contactlist}
        for residue in tmp_contactlist:
            if residue.get_id()[1] % 500 not in known_ids:
                recon[residue] = tmp_recon[residue]
                LogString += f'  ->Added residue {residue}.\n'

    # Classify residues
    resclass = classify(recon, cof, motif, is_nad)

    # Extended edge search if needed
    if "Edge" not in [e[1] for e in resclass]:
        LogString += 'No edge residue detected, checking for weaker geometry edge residues.\n'
        all_residues = Selection.unfold_entities(structure[0], 'R')
        resclass = extended_edge_search(resclass, all_residues, cof, is_nad)

    # Sort by residue number (folded modulo 500, as elsewhere in this file)
    resclass.sort(key=lambda c: c[0].get_id()[1] % 500)
    assert len(resclass) != 0, "No residues classified"

    # Generate library
    lib = make_library(resclass, is_nad)

    # Create library display table
    library_df = create_library_table(resclass, lib)

    # Get additional residues for recovery
    recovery_suggestions = get_recovery_targets(cof_atoms, ns, contactlist, structure, phos_atoms, is_nad)

    return library_df, recovery_suggestions, lib[1]

def readfile(infile):
    """Load a structure file and locate its cofactors and glycine-rich motifs.

    Appends progress messages to the module-level ``LogString``.

    Args:
        infile: path of the structure file (parsed with ``PDBParser``).

    Returns:
        ``[structure, cofactors, is_nad, ns, motif]`` where ``cofactors`` is
        the list of NAP/NDP/NAD/NAI residues in model 0, ``is_nad`` is True
        if any of them is NAD or NAI, ``ns`` is a ``NeighborSearch`` over all
        atoms of model 0, and ``motif`` lists the residues of every
        glycine-rich motif that was found.

    Exits the interpreter via ``sys.exit()`` when no cofactor is present.
    """
    global LogString

    name = '.'.join(infile.split('.')[:-1])
    structure = PDBParser(QUIET=True).get_structure(name, infile) # load molecule, ignore errors
    res_list = Selection.unfold_entities(structure[0], 'R')
    motif = []
    is_nad = False

    # identify cofactor molecules
    cofactors = []
    for r in res_list:
        resname = r.get_resname()
        if resname in ('NAP', 'NDP'):
            cofactors.append(r)
            LogString += 'Cofactor found: %s \n' % str(r)
        elif resname in ('NAD', 'NAI'):
            cofactors.append(r)
            if not is_nad:
                LogString += '**Cofactor type set to NAD**\n'
            is_nad = True
            LogString += 'Cofactor found: %s \n' % str(r)
    if not cofactors:
        print("Cofactor not found")
        sys.exit()

    # find glycine-rich motifs for binding pyrophosphate
    # each entry: (offsets that must be GLY, GLY, GLY/ALA; number of residues recorded)
    motif_patterns = [
        ([0, 2, 4], 5),  # G-X-G-X-G/A
        ([0, 2, 5], 6),  # G-X-G-X-X-G/A
        ([0, 3, 5], 6),  # G-X-X-G-X-G/A
    ]

    n_res = len(res_list)
    for i in range(n_res):
        for positions, length in motif_patterns:
            # Bounds are checked per pattern so that windows ending on the
            # last residues are still examined without indexing past the end.
            if i + length > n_res:
                continue
            if _check_gly_motif(res_list, i, positions):
                motif_segment = res_list[i:i + length]
                motif.extend(motif_segment)
                LogString += 'Gly-rich motif found: \n--' + '\n--'.join(str(r) for r in motif_segment) + '\n'
                break  # at most one pattern is recorded per start position

    # make KDTree for neighbor searching
    atom_list = Selection.unfold_entities(structure[0], 'A')
    ns = NeighborSearch(atom_list)

    return [structure, cofactors, is_nad, ns, motif]

def _check_gly_motif(res_list, start, positions):
    """Tell whether a glycine-rich pattern begins at index ``start``.

    The residues at ``start + positions[0]`` and ``start + positions[1]``
    must be GLY; the one at ``start + positions[2]`` must be GLY or ALA.
    """
    requirements = (
        (positions[0], ('GLY',)),
        (positions[1], ('GLY',)),
        (positions[2], ('GLY', 'ALA')),
    )
    for offset, accepted in requirements:
        if res_list[start + offset].get_resname() not in accepted:
            return False
    return True


def defineatoms(structure, coflist, is_nad):
    """Collect the cofactor atoms and the subset around the 2' position.

    For NADP the subset is the 2'-phosphate group (O2B, P2B, O1X, O2X, O3X);
    for NAD it is O2B, N3A, C4A and O3B. ``structure`` is accepted but not
    used.

    Returns:
        ``[phos_atoms, cof_atoms]``.
    """
    cof_atoms = Selection.unfold_entities(coflist, 'A')
    if is_nad:
        phos_names = ['O2B', 'N3A', 'C4A', 'O3B']
    else:
        phos_names = ['O2B', 'P2B', 'O1X', 'O2X', 'O3X']
    phos_atoms = [at for at in cof_atoms if at.get_name() in phos_names]
    return [phos_atoms, cof_atoms]

def get_first_shell(phos_atoms, ns, motif, cof_atoms, dist, is_nad):
    """Find the residues whose atoms lie within ``dist`` of the 2'-moiety atoms.

    Every atom returned by the neighbour search is passed through a chain of
    filters (valid side-chain atom, not behind the cofactor for NAD, motif
    handling, orientation); survivors are recorded by
    ``_process_residue_contact``. Reads the module-level ``ex_motif`` and
    ``passovers`` and appends to ``LogString``.

    Returns:
        ``[residue_list, contact_dict]`` as built by ``_format_results``.
    """
    global LogString

    shell = {}     # {resid: (residue, vector_to_phosphate)}
    contacts = {}  # {resid: contact_atoms_list}

    for center in phos_atoms:
        for candidate in _find_neighbor_atoms(center, ns, dist):
            owner = candidate.get_parent()

            if not _is_valid_neighbor(candidate, phos_atoms, cof_atoms):
                continue

            if is_nad and _is_behind_cofactor(candidate, center):
                LogString += f'Atom {candidate} of residue {owner} is too far back. Skipping.\n'
                continue

            if _should_skip_motif_residue(owner, motif, ex_motif, passovers):
                continue

            if not _is_residue_oriented_correctly(owner, center, is_nad):
                continue

            # Passed every filter: record (or update) this residue's entry
            _process_residue_contact(
                owner, ns, center, cof_atoms, phos_atoms,
                dist, shell, contacts, is_nad
            )

    return _format_results(shell, contacts)

def _find_neighbor_atoms(phos_atom, ns, dist):
    """Return every atom within ``dist`` of ``phos_atom`` and log the hits."""
    global LogString
    found = ns.search(phos_atom.get_coord(), dist, level='A')
    LogString += f'Atoms near cofactor:\n  {found}\n'
    return found


def _is_valid_neighbor(atom, phos_atoms, cof_atoms):
    """Accept only amino-acid atoms that are not backbone N/C/O or hydrogens.

    Atoms of non-amino-acid molecules are rejected; a log line is written
    for them unless they belong to the cofactor itself.
    """
    global LogString

    parent = atom.get_parent()
    if is_aa(parent):
        name = atom.get_name()
        # CA is deliberately not excluded here; only N, C, O and H* are.
        return not (name in ['N', 'C', 'O'] or name[0] == 'H')

    if not (atom in phos_atoms or atom in cof_atoms):
        LogString += f'Atom {atom} of molecule {parent} is not amino acid. Skipping.\n'
    return False


def _is_behind_cofactor(neighbor_atom, phos_atom):
    """Compare the atom's distances to O2B and O4B of the cofactor.

    Returns True unless the atom is more than 0.75 closer to O2B than to O4B.
    """
    cofactor = phos_atom.get_parent()
    to_o2b = norm(neighbor_atom.get_vector() - cofactor['O2B'].get_vector())
    to_o4b = norm(neighbor_atom.get_vector() - cofactor['O4B'].get_vector())
    return to_o2b - to_o4b > -0.75


def _should_skip_motif_residue(residue, motif, ex_motif, passovers):
    """Decide whether a residue is skipped because of the glycine-rich motif.

    Residues outside the motif are never skipped. With ``ex_motif`` on, every
    motif residue is skipped and non-glycines are appended to ``passovers``
    (once per residue number). With it off, only motif glycines are skipped.
    """
    if residue not in motif:
        return False

    is_gly = residue.get_resname() == 'GLY'
    if not ex_motif:
        return is_gly

    if not is_gly:
        already_noted = [e.get_id()[1] for e in passovers]
        if residue.get_id()[1] not in already_noted:
            passovers.append(residue)
    return True


def _is_residue_oriented_correctly(residue, phos_atom, is_nad):
    """Reject residues whose side chain points away from the cofactor.

    Glycine is rejected when its CA is farther from ``phos_atom`` than its N.
    Other residues are rejected when the angle from ``_compute_angle``
    exceeds 100 degrees (NAD) or 90 degrees (NADP).
    """
    global LogString

    resname = residue.get_resname()

    # Special check for glycine
    if resname == 'GLY':
        target = phos_atom.get_vector()
        ca_dist = norm(residue['CA'].get_vector() - target)
        n_dist = norm(residue['N'].get_vector() - target)
        if ca_dist > n_dist:
            LogString += f'Residue {residue} points away from cofactor. Skipping.\n'
            return False

    # The angle test needs a CA atom; residue names containing 'CA' are exempt
    if not residue.has_id('CA') or 'CA' in resname:
        return True

    _, angle = _compute_angle(residue, phos_atom, is_nad)
    limit = 100 if is_nad else 90
    if degrees(angle) > limit:
        LogString += f'Residue {residue} points away from cofactor. Skipping.\n'
        return False

    return True

def _process_residue_contact(residue, ns, phos_atom, cof_atoms, phos_atoms, dist, near_residues, recontacts, is_nad):
    """Record one residue hit in ``near_residues`` / ``recontacts`` (in place).

    Entries are keyed by residue number modulo 500. When the key already
    exists, the instance whose pseudocenter is closer to ``phos_atom`` wins,
    and contacts from the other instance are carried over only if they are in
    the fixed list of important atom names. Nothing is recorded when
    ``_check_recontacts`` returns False.
    """
    global LogString

    key = residue.get_id()[1] % 500
    pcenter, _ = _compute_angle(residue, phos_atom, is_nad)
    offset = pcenter - phos_atom.get_vector()

    found = _check_recontacts(residue, ns, cof_atoms, phos_atoms, dist, is_nad)
    if found is False:
        return

    # Contact atoms that survive a merge between two instances of one key
    keep = ['O1A', 'O2A', 'O1N', 'O2N', 'O3', 'C5A', 'O3B']

    if key not in near_residues or recontacts[key] is False:
        # Nothing usable stored yet under this key
        near_residues[key] = (residue, offset)
        recontacts[key] = found
        _log_residue_contacts(residue, found)
        return

    previous = recontacts[key]
    if offset.norm() < near_residues[key][1].norm():
        # New instance is closer: it becomes primary, old important contacts are kept
        merged = _merge_contact_lists(found, previous, keep)
        near_residues[key] = (residue, offset)
        recontacts[key] = merged
    else:
        # Stored instance stays primary; only add new important contacts
        recontacts[key] = _merge_contact_lists(previous, found, keep)

def _format_results(near_residues, recontacts):
    """Turn the two working dicts into ``[residue_list, contact_dict]``.

    ``residue_list`` holds every stored residue; ``contact_dict`` maps a
    residue to its contact list and leaves out entries whose value is False.
    """
    residues = [entry[0] for entry in near_residues.values()]
    contacts = {
        entry[0]: recontacts[resid]
        for resid, entry in near_residues.items()
        if recontacts[resid] is not False
    }
    return [residues, contacts]

def _merge_contact_lists(primary, secondary, important_atoms):
    """Return a copy of ``primary`` extended with selected names from ``secondary``.

    A name is added only if it is listed in ``important_atoms`` and not
    already present; each addition is logged.
    """
    global LogString
    combined = [*primary]
    for name in secondary:
        if name in important_atoms and name not in combined:
            combined.append(name)
            LogString += f'Atom {name} appended to contacts.\n'
    return combined


def _log_residue_contacts(residue, contact_atoms):
    """Append a summary of the residue's cofactor contacts to ``LogString``."""
    global LogString
    real_contacts = list(filter(lambda name: name != 'Peripheral', contact_atoms))
    LogString += f'Residue {residue} contacts atoms {real_contacts} of cofactor.\n'
    # 'Peripheral' is a marker entry, not an atom name
    if 'Peripheral' in contact_atoms:
        LogString += '  ->Assigned as peripheral.\n'



def _compute_angle(res, root, is_nad):
    """Return ``[pseudocenter, angle]`` for a residue relative to the cofactor.

    The angle (radians) is measured at the residue's CA between the
    CA->pseudocenter direction and the CA->anchor direction, where the
    anchor is the cofactor's C2B (NAD) or P2B (NADP) atom.

    The angle is 0 for glycine and whenever one of the two directions has
    zero length (for example when the pseudocenter fell back to CA), because
    no side-chain direction can be defined in those cases.

    Args:
        res: residue to evaluate; must have a CA atom unless it is GLY.
        root: a cofactor atom; its parent residue supplies the anchor atom.
        is_nad: selects the anchor atom.
    """
    pcenter = _get_pseudocenter(res)
    if res.get_resname() == 'GLY':
        return [pcenter, 0]

    CA = res['CA'].get_vector()
    anchor_name = 'C2B' if is_nad else 'P2B'
    root_loc = root.get_parent()[anchor_name].get_vector()

    v1 = pcenter - CA
    v2 = root_loc - CA
    denom = norm(v1) * norm(v2)
    if denom == 0:
        # Avoid 0/0: a NaN angle is never greater than any limit, so callers
        # already treated this case like an angle of 0.
        return [pcenter, 0]

    # Rounding can push the cosine of (anti)parallel vectors slightly outside
    # [-1, 1], which would make acos raise ValueError.
    cosang = min(1.0, max(-1.0, dot(v1, v2) / denom))
    return [pcenter, acos(cosang)]


def _get_pseudocenter(r):
    """Return the mean position of the residue's side-chain centre atoms.

    Falls back to the CA position (with a printed notice) when the residue
    name is not in ``sidechaincenteratoms`` or none of the listed atoms is
    present in the residue.
    """
    resname = r.get_resname()
    if resname not in sidechaincenteratoms:
        print(f"{r} is a nonstandard amino acid.")
        return r['CA'].get_vector()

    wanted = sidechaincenteratoms[resname]
    coords = [
        at.get_vector()
        for at in Selection.unfold_entities(r, 'A')
        if at.get_name() in wanted
    ]
    if not coords:
        print(f"{r} is missing its sidechain.")
        return r['CA'].get_vector()

    # Sum the vectors one by one, then divide by how many were found
    total = coords[0]
    for vec in coords[1:]:
        total = total + vec
    return total / len(coords)

def _check_recontacts(r, ns, cof_atoms, phos, dist, is_nad):
    """List the cofactor atoms near a residue's side-chain centre atoms.

    Recontacts are the names of cofactor atoms, other than the 2'-moiety
    atoms in ``phos``, found within ``dist * 1.025`` of any side-chain centre
    atom of ``r``. For NAD, O3B is treated as an ordinary cofactor atom (so
    it can be reported as a recontact) rather than as part of ``phos``.

    Args:
        r: residue to inspect.
        ns: neighbour-search object providing ``search(coord, radius, level)``.
        cof_atoms: all cofactor atoms.
        phos: 2'-moiety atoms; this list is NOT modified.
        dist: base search radius.
        is_nad: True when the cofactor is NAD.

    Returns:
        A list of atom names; ``["Peripheral"]`` when a common hydrogen-bonding
        residue touches neither ``phos`` nor any other cofactor atom; or
        ``False`` when any other residue does not touch ``phos``.
    """
    if is_nad:
        # Filter into a new list. The caller is iterating over this very
        # list, so removing O3B from it in place made the caller skip atoms.
        phos = [e for e in phos if e.get_name() != 'O3B']

    # Residue names missing from the table simply have no centre atoms
    center_names = sidechaincenteratoms.get(r.get_resname(), ())

    recontacts = []
    pcount = 0  # number of hits on 2'-moiety atoms
    for at in Selection.unfold_entities(r, 'A'):
        if at.get_name() not in center_names:
            continue
        for neighbor_atom in ns.search(at.get_coord(), dist * 1.025, level='A'):
            if neighbor_atom not in cof_atoms:
                continue
            if neighbor_atom in phos:
                pcount += 1
                continue
            name = neighbor_atom.get_name()
            if name not in recontacts:
                recontacts.append(name)

    if pcount == 0:
        if r.get_resname() in ['SER','THR','TYR','HIS','GLN','ASN','LYS','ARG','GLU','ASP']: #common H binders
            if len(recontacts) == 0:
                recontacts.append("Peripheral")
        else:
            return False
    return recontacts


def classify(recontacts, cofactor, motif, is_nad):
    """Assign a class to every residue in ``recontacts``.

    Residues whose class is switched off by the user settings are appended to
    the module-level ``passovers`` list instead of being returned.

    Returns:
        A list of ``[residue, class_name]`` pairs.
    """
    global LogString

    kept = []
    for residue, contacts in recontacts.items():
        label = _determine_residue_class(residue, contacts, cofactor, motif, is_nad)
        if not label:
            continue
        if _should_exclude_residue(residue, label):
            passovers.append(residue)
            continue
        kept.append([residue, label])

    return kept

def _determine_residue_class(residue, recontacts, cofactor, motif, is_nad):
    """Pick the class name for one residue from its cofactor contacts.

    The checks run in a fixed order and the first match wins: Motif,
    Peripheral, no-contact handling, Face, plane check, Pyrophosphate,
    Bidentate, Ring-binder (NAD only), and finally "Nonsimple".
    """
    global LogString

    # Motif membership only counts when motif residues are not excluded
    if not ex_motif and residue in motif:
        return "Motif"

    if recontacts == ["Peripheral"]:
        return "Peripheral"

    if not recontacts:
        return _classify_no_recontact_residue(residue, cofactor, is_nad)

    touched = set(recontacts)

    # Face: C5A together with at least one other ring atom
    if 'C5A' in touched and touched & {'C4A', 'N3A', 'C2A', 'N1A', 'C6A', 'C8A', 'N7A'}:
        return "Face"

    # Other ring contacts are settled by geometry; the plane check may abstain
    if touched & {'N9A', 'C4A', 'N3A', 'C2A', 'N1A', 'C6A', 'N7A'}:
        verdict = _plane_check(residue, cofactor, is_nad)
        if verdict:
            return verdict

    # Diphosphate oxygens
    if touched & {'O1A', 'O2A', 'O1N', 'O2N', 'O3'}:
        return "Pyrophosphate"

    if 'O3B' in touched:
        return "Bidentate"

    # NAD-specific ring-binder check (searches at an expanded distance)
    if is_nad and _is_ring_binder(residue):
        return "Ring-binder"

    return "Nonsimple"

def _classify_no_recontact_residue(residue, cofactor, is_nad):
    """Classify a residue that has no recontacts with the cofactor.

    Everything is "Simple" except an arginine whose CD atom lies within 4.2
    of C4A or C5A of the nearest cofactor, which is returned as "R-chain".

    Args:
        residue: the residue to classify.
        cofactor: list of cofactor residues (passed on to ``_nearest``).
        is_nad: True when the cofactor is NAD.
    """
    global LogString

    if residue.get_resname() != 'ARG' or not residue.has_id('CD'):
        return "Simple"

    # Check arginine CD distances to cofactor atoms
    _, n_cof = _nearest(cofactor, residue, is_nad)
    if n_cof is None:
        # _nearest found no cofactor inside its cutoff, so there is nothing
        # to measure against (same guard as in _plane_check).
        return "Simple"

    cd = residue['CD'].get_vector()
    for ring_atom in ('C4A', 'C5A'):
        gap = norm(n_cof[ring_atom].get_vector() - cd)
        if gap < 4.2:
            LogString += f'Arg {residue} defined as face on the basis of CD-{ring_atom} distance of {gap:.1f}.\n'
            return "R-chain"

    return "Simple"

def _nearest(cofactors, residue, is_nad):
    """Find the cofactor closest to the residue (for multimers).

    Distance is measured from the residue pseudocenter to the cofactor's
    O2B (NAD) or P2B (NADP) atom. Only cofactors closer than 20 qualify.

    Returns:
        ``[pseudocenter, cofactor]``; ``cofactor`` is None when none qualifies.
    """
    res_pcenter = _get_pseudocenter(residue)
    anchor = 'O2B' if is_nad else 'P2B'

    best_gap, best_cof = 20, None
    for cof in cofactors:
        gap = norm(res_pcenter - cof[anchor].get_vector())
        if gap < best_gap:
            best_gap, best_cof = gap, cof  # keep the closer one

    return [res_pcenter, best_cof]


def _is_ring_binder(residue):
    """Report whether the residue fails to reach O2B even at a 4.4 radius.

    Returns True (and logs a line) when no cofactor atom named O2B lies
    within 4.4 of any side-chain centre atom of the residue.

    NOTE: ``ns`` and ``cof_atoms`` are not parameters; they are looked up as
    module-level names and must exist there when this function runs.
    """
    global LogString

    nearby_names = [
        hit.get_name()
        for atom in Selection.unfold_entities(residue, 'A')
        if atom.get_name() in sidechaincenteratoms[residue.get_resname()]
        for hit in ns.search(atom.get_coord(), 4.4, level='A')
        if hit in cof_atoms
    ]

    if 'O2B' in nearby_names:
        return False

    LogString += f'Residue {residue} is peripheral even with expanded search radius.\n'
    return True


def _should_exclude_residue(residue, residue_class):
    """Return the user's exclusion setting that applies to ``residue_class``.

    "Peripheral" and "Ring-binder" follow ``ex_periph``, "Pyrophosphate"
    follows ``ex_diphos``; every other class is never excluded. ``residue``
    itself is not consulted.
    """
    if residue_class in ("Peripheral", "Ring-binder"):
        return ex_periph
    if residue_class == "Pyrophosphate":
        return ex_diphos
    return False

def _plane_check(res, cofactors, is_nad):
    """Classify a residue by its position relative to the adenine ring.

    The residue pseudocenter is expressed in a coordinate system anchored on
    C4A of the nearest cofactor and then tested, in order, for "Offset face",
    "Edge" and "R-chain". Returns None when no test matches or when no
    cofactor is near enough.
    """
    global LogString

    res_pcenter, cofactor = _nearest(cofactors, res, is_nad)
    if cofactor is None:
        return None

    # Express the pseudocenter in adenine-centred coordinates
    to_adenine = _get_adenine_transform_matrix(cofactor)
    local = _transform_vector(res_pcenter, cofactor['C4A'].get_vector(), to_adenine)
    px, py, pz = local[0], local[1], local[2]

    LogString += f'Residue {res} has pseudocenter-C4A vector <{px:.2f},{py:.2f},{pz:.2f}> in adenine-centered coordinate space.\n'

    resname = res.get_resname()

    # Above or below the ring, roughly over it
    if _is_offset_face(local):
        LogString += '  ->Meets threshold for assignment as face.\n'
        return "Offset face"

    # In the ring plane; glycine never qualifies as an edge residue
    if (_is_potential_edge(local) and resname != 'GLY'
            and _confirm_edge_by_ca_position(res, cofactor, to_adenine)):
        LogString += '  ->Meets threshold for assignment as edge.\n'
        return "Edge"

    # Selected residue types get a final orientation test of their side-chain plane
    planar_types = ['ARG', 'GLN', 'TYR', 'PHE', 'GLU', 'TRP', 'HIS', 'PRO']
    if resname in planar_types and _is_parallel_to_adenine(res, cofactor):
        LogString += '  ->Below threshold, assigned as face.\n'
        return "R-chain"

    return None


def _get_adenine_coordinate_system(cofactor):
    """Return three unit vectors (x, y, z) describing the adenine ring frame.

    x points from C4A to N3A, y from C4A to C5A, and z is the normalised
    cross product of x and y. x and y are not orthogonalised, so they are
    perpendicular only if those two bonds are.

    NOTE: a second definition of this function with the same behaviour
    appears later in this file; the later one is what callers end up using.
    """
    origin = cofactor['C4A'].get_vector()
    toward_n3a = cofactor['N3A'].get_vector() - origin
    toward_c5a = cofactor['C5A'].get_vector() - origin

    x_axis = toward_n3a.normalized()
    y_axis = toward_c5a.normalized()
    normal = (x_axis ** y_axis).normalized()  # perpendicular to both x and y

    return x_axis, y_axis, normal


def _get_adenine_transform_matrix(cofactor):
    """Return the matrix that converts coordinates into the adenine ring frame.

    The frame vectors form the columns of a 3x3 matrix; its inverse is
    returned.

    NOTE: a second definition of this function with the same behaviour
    appears later in this file; the later one is what callers end up using.
    """
    axes = _get_adenine_coordinate_system(cofactor)

    # Row i holds component i of the x, y and z frame vectors
    columns = matrix([[axis[i] for axis in axes] for i in range(3)])

    return columns.I


def _transform_vector(point, origin, transform_matrix):
    """Express ``point - origin`` in the frame given by ``transform_matrix``.

    Returns the three transformed components as a plain list of floats.
    """
    rel = point - origin
    column = matrix([rel[0], rel[1], rel[2]]).T  # 3x1 column vector
    out = transform_matrix * column
    return [float(out[i, 0]) for i in range(3)]


def _is_offset_face(pc_vec):
    """Test whether a ring-frame position qualifies as an offset face contact.

    All three must hold: |z| above 2.5; either |x| + |y| below 2.5 or the
    point inside the box |x| < 3, -2 < y < 3; and vector length below 5.35.
    """
    x, y, z = pc_vec[0], pc_vec[1], pc_vec[2]

    if not abs(z) > 2.5:
        return False

    close_to_axis = abs(x) + abs(y) < 2.5
    inside_box = abs(x) < 3 and -2 < y < 3
    if not (close_to_axis or inside_box):
        return False

    return bool(sqrt(x**2 + y**2 + z**2) < 5.35)


def _is_potential_edge(pc_vec):
    """Test whether a ring-frame position could be an edge contact.

    Requires |z| below 3.2, |z| smaller than |y|, and |x| below 5.
    """
    out_of_plane = abs(pc_vec[2])
    if not out_of_plane < 3.2:
        return False
    if not out_of_plane < abs(pc_vec[1]):
        return False
    return bool(abs(pc_vec[0]) < 5)


def _confirm_edge_by_ca_position(res, cofactor, transform_matrix):
    """Check the residue's CA position to confirm an edge assignment.

    In ring-frame coordinates relative to C4A, CA must have a negative y and
    an |z| smaller than the larger of |x| and |y|. The transformed vector is
    logged.
    """
    global LogString

    cx, cy, cz = _transform_vector(res['CA'].get_vector(), cofactor['C4A'].get_vector(), transform_matrix)

    LogString += f'Residue {res} has CA-C4A vector <{cx:.2f},{cy:.2f},{cz:.2f}> in adenine-centered coordinate space.\n'

    in_plane_extent = max(abs(cx), abs(cy))
    return bool(cy < 0 and abs(cz) < in_plane_extent)


def _get_adenine_coordinate_system(cofactor):
    """Return three unit axes derived from the cofactor's adenine atoms.

    * x points from ``C4A`` towards ``N3A``
    * y points from ``C4A`` towards ``C5A``
    * z is the normalized cross product of x and y, i.e. perpendicular to both

    x and y are taken straight from the two atom-to-atom directions and are
    not orthogonalised against each other, so the triple is a basis of unit
    vectors rather than a guaranteed orthonormal one.

    Returns the tuple ``(x_axis, y_axis, z_axis)``.
    """
    origin = cofactor['C4A'].get_vector()
    toward_n3a = cofactor['N3A'].get_vector() - origin
    toward_c5a = cofactor['C5A'].get_vector() - origin

    x_axis = toward_n3a.normalized()
    y_axis = toward_c5a.normalized()
    # ``**`` between two vectors is the cross product here.
    z_axis = (x_axis ** y_axis).normalized()

    return x_axis, y_axis, z_axis


def _get_adenine_transform_matrix(cofactor):
    """Build the matrix that maps a vector onto the adenine-based axes.

    The axes from ``_get_adenine_coordinate_system`` become the columns of a
    3x3 basis matrix, and the inverse of that matrix is returned.

    NOTE(review): this redefines a function of the same name that already
    appears earlier in the file; being the later definition, this is the one
    callers get.  One of the two copies can probably be removed.
    """
    x_axis, y_axis, z_axis = _get_adenine_coordinate_system(cofactor)

    # One row per Cartesian component, one column per axis.
    rows = [[x_axis[k], y_axis[k], z_axis[k]] for k in range(3)]

    return matrix(rows).I


def _is_parallel_to_adenine(res, cofactor):
    """Compare the side-chain plane normal of ``res`` with the adenine normal.

    The side-chain plane is defined by the first three atom names returned
    by ``_get_plane_atoms``; residues with fewer than three such atoms yield
    ``False`` immediately.  The angle between that plane's normal and the z
    axis from ``_get_adenine_coordinate_system`` is computed, a line is
    appended to the global ``LogString``, and the function returns whether
    that angle lies within 20 degrees of 90.

    NOTE(review): as written, ``True`` is returned when the two *normals* are
    within 20 degrees of perpendicular, while the logged value
    (``90 - deviation``) is the angle folded into 0-90 degrees.  Given the
    function's name one might expect the opposite test -- confirm against
    the caller before changing it.

    Returns a plain ``bool``.
    """
    global LogString

    # Get atoms defining the residue plane
    plane_atoms = _get_plane_atoms(res)
    if len(plane_atoms) < 3:
        return False

    # Normal to the plane through the first three atoms
    anchor, second, third = (res[name].get_vector() for name in plane_atoms[:3])
    residue_normal = ((second - anchor) ** (third - anchor)).normalized()

    _, _, adenine_normal = _get_adenine_coordinate_system(cofactor)

    cos_angle = dot(adenine_normal, residue_normal) / (norm(adenine_normal) * norm(residue_normal))
    # Rounding can leave the cosine of (anti)parallel vectors a hair outside
    # [-1, 1], which would make acos raise "math domain error".  NaN fails
    # both comparisons and is passed through untouched.
    if cos_angle > 1.0:
        cos_angle = 1.0
    elif cos_angle < -1.0:
        cos_angle = -1.0

    angle = degrees(acos(cos_angle))
    deviation = abs(90 - angle)

    LogString += f'Angle between normal vector to side-chain of residue {res} and normal vector to adenine is {90 - deviation:.2f}.\n'

    return bool(deviation < 20)


def _get_plane_atoms(res):
    """Return the atom names used to define a plane for residue ``res``.

    * ``PRO``: the fixed list ``['CB', 'CD', 'N']``
    * ``GLU`` / ``GLN``: the residue's ``sidechaincenteratoms`` entry with
      ``'CD'`` appended
    * any other residue name present in ``sidechaincenteratoms``: the first
      three names of its entry (fewer if the entry is shorter)
    * anything else: an empty list

    A new list is returned in every case; ``sidechaincenteratoms`` is not
    modified.
    """
    resname = res.get_resname()

    if resname == 'PRO':
        return ['CB', 'CD', 'N']

    if resname in ('GLU', 'GLN'):
        return [*sidechaincenteratoms[resname], 'CD']

    if resname in sidechaincenteratoms:
        return sidechaincenteratoms[resname][:3]

    return []

def extended_edge_search(resclass, res_list, cofactors, is_nad):
    """Reclassify residues preceding face-type residues as "Floor".

    For every ``(residue, class)`` entry of ``resclass`` whose class is
    "Offset face", "Face" or "R-chain", the two entries of ``res_list``
    that precede that residue are examined with
    ``_should_reclassify_as_floor``; those that qualify are set to "Floor"
    via ``_update_classification``.

    ``resclass`` is modified in place (entries may be rewritten or appended
    while it is being walked) and is also returned.
    """
    face_like = {"Offset face", "Face", "R-chain"}

    for res, res_class in resclass:
        if res_class in face_like:
            # Look at the two residues listed just before this one
            for prior in _get_prior_residues(res, res_list, 2):
                if _should_reclassify_as_floor(prior, res, resclass, cofactors, is_nad):
                    _update_classification(resclass, prior, "Floor")

    return resclass

def _get_prior_residues(target_res, res_list, n=2):
    for i, res in enumerate(res_list):
        if res == target_res and i >= n:
            return res_list[i-n:i]
    return []


def _should_reclassify_as_floor(candidate, reference_res, resclass, cofactors, is_nad):
    """Decide whether ``candidate`` qualifies as a "Floor" residue.

    The checks run in this order and the first failure returns ``False``:

    1. ``candidate`` is not a glycine and is not already classified as
       "Edge", "Face", "Offset face" or "R-chain".
    2. ``_nearest`` supplies a cofactor for it (not ``None``).
    3. The side-chain pseudocenter returned by ``_nearest`` is no farther
       from the cofactor's ``O2B`` atom than the residue's ``CA`` atom is.
    4. The angle from ``_calculate_edge_angle`` is at most 45 degrees.
    5. The angle from ``_calculate_normal_angle`` is within 20 degrees of 90.

    Steps 4 and 5 each append a line to the global ``LogString``; a final
    line is added when the residue passes.  ``reference_res`` is accepted
    but not used by the current implementation.
    """
    global LogString

    # Glycines and residues already in an edge-/face-like class are left alone
    already_assigned = {"Edge", "Face", "Offset face", "R-chain"}
    if candidate.get_resname() == 'GLY' or _get_current_class(candidate, resclass) in already_assigned:
        return False

    res_pcenter, cofactor = _nearest(cofactors, candidate, is_nad)
    if cofactor is None:
        return False

    # The pseudocenter must not be farther from O2B than the CA atom is
    o2_position = cofactor['O2B'].get_vector()
    pcenter_distance = norm(o2_position - res_pcenter)
    ca_distance = norm(o2_position - candidate['CA'].get_vector())
    if pcenter_distance > ca_distance:
        return False

    edge_angle = _calculate_edge_angle(candidate, res_pcenter, cofactor)
    LogString += f'Angle between residue {candidate} side chain and adenine edge is {edge_angle:.2f}.\n'
    if edge_angle > 45:
        return False

    out_of_plane = abs(90 - _calculate_normal_angle(candidate, res_pcenter, cofactor))
    LogString += f'Angle between side-chain of residue {candidate} and normal vector to adenine is {out_of_plane:.2f}.\n'
    if not out_of_plane < 20:
        return False

    LogString += '  ->Meets threshold for assignment as floor residue.\n'
    return True

def _calculate_edge_angle(residue, res_pcenter, cofactor):
    edge_vector = cofactor['N9A'].get_vector() - cofactor['N3A'].get_vector()
    residue_vector = res_pcenter - residue['CA'].get_vector()
    
    cos_angle = dot(edge_vector, residue_vector) / (norm(edge_vector) * norm(residue_vector))
    return degrees(acos(cos_angle))


def _calculate_normal_angle(residue, res_pcenter, cofactor):
    edge_vector = cofactor['N9A'].get_vector() - cofactor['N3A'].get_vector()
    plane_vector = cofactor['C4A'].get_vector() - cofactor['C5A'].get_vector()
    normal_vector = edge_vector ** plane_vector  # Cross product
    
    residue_vector = res_pcenter - residue['CA'].get_vector()
    cos_angle = dot(normal_vector, residue_vector) / (norm(normal_vector) * norm(residue_vector))
    return degrees(acos(cos_angle))

def _get_current_class(residue, resclass):
    for res, res_class in resclass:
        if res.get_id() == residue.get_id():
            return res_class
    return None

def _update_classification(resclass, residue, new_class):
    # If it already has a class
    for entry in resclass:
        if entry[0].get_id() == residue.get_id():
            entry[1] = new_class
            return    
    # Not found, add new entry
    resclass.append([residue, new_class])



In [5]:
"""Design library for cofactor specificity reversal"""

def make_library(resinfo, is_nad):
    """Assemble the per-residue codon options and size the library.

    Args:
        resinfo: iterable of ``(residue, residue_class)`` pairs.
        is_nad: when truthy the codon options come from ``Nlibraries``,
            otherwise from ``Plibraries``.

    For every residue the options are looked up by
    ``(residue name, simplified class)``.  Residues classed "Peripheral",
    "Ring-binder", "Pyrophosphate" or "Motif" additionally get a leading
    ``[one-letter code, '---']`` option.

    Returns:
        Whatever ``select_size`` returns for the assembled options.
    """
    print("Designing library for cofactor specificity reversal...")

    codon_tables = Nlibraries if is_nad else Plibraries
    omittable_classes = ("Peripheral", "Ring-binder", "Pyrophosphate", "Motif")

    raw_lib = []
    for res, resclass in resinfo:
        resname = res.get_resname()
        # Work on a copy: inserting into the list stored in the module-level
        # table would leak the extra option into every other residue sharing
        # this (name, class) key and into later runs of this function.
        options = list(codon_tables[(resname, _simplify_resclass(resclass))])

        if resclass in omittable_classes:
            options.insert(0, [_three_letter_to_one[resname], '---'])

        raw_lib.append((res.get_id()[1], options))

    library = select_size(raw_lib, resinfo, is_nad)
    return library

def _simplify_resclass(rc):
    if rc in ["Edge","Bidentate","Face","Simple"]: return rc
    elif rc == "Floor": return "Edge"
    elif rc in ["R-chain","Offset face","Ring-binder"]: return "Face"
    elif rc in ["Nonsimple","Simple*","Peripheral"]: return "Simple"
    elif rc in ["Motif","Pyrophosphate"]: return "Bidentate"
    else:
        print("unknown resclass", rc)
        sys.exit()

def select_size(raw_lib, resclass, is_nad):
    """Choose one codon option per position so the library fits ``max_size``.

    Args:
        raw_lib: list of ``(residue number, options)`` entries.
        resclass: list of ``(residue, residue_class)`` entries, parallel to
            ``raw_lib``.
        is_nad: forwarded to ``_reduce_minimum_size``.

    Returns:
        ``[library_out, result_string]`` -- the formatted selection from
        ``_format_output_library`` and a sentence reporting the resulting
        library size.
    """
    # Residue number -> class, used to prioritise which positions to shrink
    class_by_number = {entry[0].get_id()[1]: entry[1] for entry in resclass}

    # Add omission options if even the smallest library is too large
    sized_lib = _reduce_minimum_size(raw_lib, resclass, is_nad)

    # Pick the option index for every position
    chosen = _find_target_size(sized_lib, class_by_number)

    library_out = _format_output_library(sized_lib, resclass, chosen)
    final_size = _count_size(sized_lib, chosen)

    return [library_out, f"\n Suggested library size: {final_size} "]

def _reduce_minimum_size(raw_lib, resclass, is_nad):
    """Add "omit this residue" options until the smallest library fits.

    The smallest library is the one that takes option 0 at every position.
    If its size already is at most the global ``max_size``, ``raw_lib`` is
    returned untouched.  Otherwise amino acids are visited in a fixed
    priority order (one order for ``is_nad`` truthy, another otherwise);
    every position whose residue is of the current amino acid gets an
    ``[aa, '---']`` option prepended, and the walk stops as soon as the
    smallest library fits.

    ``raw_lib`` is modified in place (entries are replaced by new tuples)
    and returned.  Progress is appended to the global ``LogString``.
    """
    global LogString

    first_options = [0] * len(raw_lib)
    min_size = _count_size(raw_lib, first_options)
    LogString += f'Minimum default library size is {min_size}.\n'

    if min_size <= max_size:
        return raw_lib

    # Priority order for amino acids to omit
    nadp_order = ['W','C','M','F','P','I','L','E','D','V','Q','H','A','Y','G','N','T','K','R','S']
    nad_order = ['C','W','M','Y','H','F','T','Q','A','P','G','R','S','N','L','V','E','K','I','D']
    order = nad_order if is_nad else nadp_order

    for aa in order:
        for lib_posn, res_entry in enumerate(resclass):
            if _three_letter_to_one[res_entry[0].get_resname()] != aa:
                continue

            # Prepend the omission option for this position
            res_number, options = raw_lib[lib_posn][0], raw_lib[lib_posn][1]
            raw_lib[lib_posn] = (res_number, [[aa, '---']] + options)

            min_size = _count_size(raw_lib, first_options)
            LogString += (
                f'\n  ->Adding option to omit residue {res_number} from library. '
                f'Minimum size is now {min_size}.\n'
            )

            if min_size <= max_size:
                return raw_lib

    return raw_lib


def _find_target_size(raw_lib, resclass_dict):
    """Pick an option index per position so the library fits ``max_size``.

    Every position starts at its last (largest-index) option.  While the
    library is larger than the global ``max_size``, the residue classes are
    visited in a fixed priority order; for each class the first position of
    that class that can still be lowered has its index reduced by one.  The
    passes repeat until the size fits, or until a whole pass changes
    nothing, in which case a warning is logged.

    Args:
        raw_lib: list of ``(residue number, options)`` entries.
        resclass_dict: mapping from residue number to residue class.

    Returns:
        A list (despite the name ``index_dict``) holding the chosen option
        index for each position.  Progress is appended to the global
        ``LogString``.
    """
    global LogString

    # Start with largest possible library
    index_dict = [len(entry[1]) - 1 for entry in raw_lib]
    size = _count_size(raw_lib, index_dict)
    LogString += f' Starting at largest library size, {size}.\n'

    if size <= max_size:
        return index_dict

    resclass_priority = (
        "Motif", "Pyrophosphate", "Ring-binder", "Peripheral",
        "Floor", "Edge", "R-chain", "Offset face", "Face",
        "Simple*", "Simple", "Nonsimple", "Bidentate",
    )

    while size > max_size:
        reduced_any = False

        for resclass_type in resclass_priority:
            # First position of this class that still has a smaller option
            posn = next(
                (i for i, entry in enumerate(raw_lib)
                 if index_dict[i] > 0 and resclass_dict[entry[0]] == resclass_type),
                None,
            )
            if posn is None:
                continue

            index_dict[posn] -= 1
            size = _count_size(raw_lib, index_dict)
            LogString += f'Reducing library size at residue {raw_lib[posn][0]}. Library size is now {size}.\n'
            reduced_any = True

            if size <= max_size:
                return index_dict

        # A full pass without any reduction means nothing is left to shrink
        if not reduced_any:
            LogString += f'WARNING: Cannot reduce library size further. Stuck at {size}.\n'
            break

    return index_dict


def _format_output_library(raw_lib, resclass, index_dict):
    global passovers
    
    # Special case for single position
    if len(index_dict) == 1:
        return raw_lib[0][1][index_dict[0]]
    
    # Build library and track omitted residues
    library_out = []
    for i in range(len(raw_lib)):
        variant = raw_lib[i][1][index_dict[i]]
        library_out.append(variant)
        
        # Track residues marked for omission
        if variant[1] == '---':
            passovers.append(resclass[i][0])
    
    return library_out


def _count_size(raw_lib, index_dict):
    """Return the library size implied by one option index per position.

    For every position the codon of the selected option is expanded with
    ``_expand_codon`` and the per-position counts are multiplied together.
    ``raw_lib`` and ``index_dict`` must have the same length (asserted).
    """
    assert len(raw_lib) == len(index_dict)

    total = 1
    for entry, choice in zip(raw_lib, index_dict):
        # entry[1] holds the options; each option is (amino acids, codon)
        total *= _expand_codon(entry[1][choice][1])
    return total


def _expand_codon(codon):
    """Return the product of ``degen_sizes_flat`` over the symbols of ``codon``.

    Each character of ``codon`` is looked up in the module-level
    ``degen_sizes_flat`` table and the looked-up values are multiplied; an
    empty codon gives 1.
    """
    per_symbol = [degen_sizes_flat[symbol] for symbol in codon]

    combined = 1
    for count in per_symbol:
        combined *= count
    return combined



def create_library_table(resclass, lib):
    """Build the results table, one row per library position.

    Args:
        resclass: list of ``(residue, classification)`` entries.
        lib: sequence whose first element is the selected library.  With a
            single entry in ``resclass`` that element is the lone
            ``(amino acids, codon)`` option itself; otherwise it is a list
            of such options, one per entry.

    Returns:
        A ``pandas.DataFrame`` with the columns ``Residue``, ``Type``,
        ``Codon`` and ``Amino Acids``.
    """
    single_position = len(resclass) == 1

    rows = []
    for position, (residue, classification) in enumerate(resclass):
        # A one-position library is stored unwrapped
        variant = lib[0] if single_position else lib[0][position]

        rows.append({
            'Residue': f'{residue.get_resname()} {residue.get_id()[1]}',
            'Type': classification,
            'Codon': variant[1],
            'Amino Acids': variant[0],
        })

    return pd.DataFrame(rows)

# def style_results_table(df):
#     """Apply styling to results DataFrame."""
#     return df.style.set_properties(**{
#         'text-align': 'left',
#         'font-size': '12pt',
#     }).set_table_styles([
#         {'selector': 'th', 'props': [('font-weight', 'bold'), 
#                                       ('text-align', 'center'),
#                                       ('background-color', '#f0f0f0')]}
#     ])

In [6]:
"""Identify residues for site-saturation recovery."""
def get_recovery_targets(cof_atoms, ns, contactlist, structure, phos_atoms, is_nad):
    """Tabulate residues suggested for activity-recovery mutagenesis.

    Three tiers are reported, in this order:

    * High   -- residues in the global ``passovers`` list
    * Medium -- residues returned by ``_get_backing``
    * Low    -- residues returned by ``get_second_shell``, plus those from
      ``find_hbonds`` whose residue number is not already in one of the
      other two groups

    The Medium and Low groups are each sorted by residue number modulo 500.

    Args:
        cof_atoms, ns, contactlist: forwarded to ``_get_backing``.
        structure: forwarded to ``find_hbonds`` (with ``contactlist``, ``ns``).
        phos_atoms, is_nad: forwarded to ``get_second_shell``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``Residue``, ``Priority``
        and ``Reason``, or ``None`` when no residue qualifies.
    """
    global passovers

    print("Identifying positions for activity recovery...")

    def by_position(residue):
        return residue.get_id()[1] % 500

    # Get backing residues
    backinglist = sorted(_get_backing(cof_atoms, ns, contactlist), key=by_position)

    # Get hydrogen bonding residues
    hbonders = find_hbonds(structure, contactlist, ns)

    # Get second shell residues
    secondshell = list(get_second_shell(phos_atoms, ns, contactlist, is_nad))

    # Add hbonders that are not already listed in either group
    # NOTE(review): this comparison uses the raw residue number while the
    # helper functions key on the number modulo 500 -- confirm that is intended.
    listed_numbers = {res.get_id()[1] for res in backinglist}
    listed_numbers.update(res.get_id()[1] for res in secondshell)
    secondshell.extend(h for h in hbonders if h.get_id()[1] not in listed_numbers)
    secondshell.sort(key=by_position)

    # Create recovery priority table
    tiers = (
        (passovers, 'High', 'Excluded from library'),
        (backinglist, 'Medium', 'Backing residue'),
        (secondshell, 'Low', 'Second shell/H-bonding'),
    )
    recovery_positions = []
    for residues, priority, reason in tiers:
        for res in residues:
            recovery_positions.append({
                'Residue': f'{res.get_resname()} {res.get_id()[1]}',
                'Priority': priority,
                'Reason': reason,
            })

    return pd.DataFrame(recovery_positions) if recovery_positions else None

def _get_backing(cof_atoms,ns,contactlist):
    backing_residues = {}
    for at in cof_atoms:
        if at.get_name() in ['N7A','C5A','N6A','N1A']:
            center = at.get_coord()
            neighbors = ns.search(center,4.1,level='R')
            for n in neighbors:
                if n.has_id('CA') and 'CA' not in n.get_resname() and n.get_id()[1]%500 not in [r.get_id()[1]%500 for r in contactlist]:
                    if n.get_id()[1]%500 not in backing_residues: backing_residues[n.get_id()[1]%500] = n
    return backing_residues.values()

def get_second_shell(phos_atoms, ns, contactlist, is_nad):
    """Collect charged residues in a wider shell around the given atoms.

    For every atom in ``phos_atoms``, ``ns.search`` is queried at residue
    level with a radius of 8.2 (``ns`` is presumably a Bio.PDB
    ``NeighborSearch`` -- confirm).  A neighbour is kept when it

    * has an atom called ``CA``,
    * does not have ``'CA'`` inside its residue name (presumably to skip
      calcium ions -- confirm),
    * does not share its residue number, modulo 500, with any residue in
      ``contactlist``, and
    * is ``ASP``/``GLU`` when ``is_nad`` is truthy, or ``LYS``/``ARG``/``HIS``
      otherwise (``is_nad`` is read by truthiness, as in ``make_library``).

    Only the first residue seen for each number-modulo-500 is kept.

    Returns:
        A ``dict_values`` view of the kept residues in discovery order.
    """
    # Both are loop-invariant, so they are worked out once up front
    contact_keys = {r.get_id()[1] % 500 for r in contactlist}
    wanted_resnames = ('ASP', 'GLU') if is_nad else ('LYS', 'ARG', 'HIS')

    second_shell_residues = {}
    for at in phos_atoms:
        for n in ns.search(at.get_coord(), 8.2, level='R'):
            if not n.has_id('CA') or 'CA' in n.get_resname():
                continue
            key = n.get_id()[1] % 500
            if key in contact_keys or n.get_resname() not in wanted_resnames:
                continue
            # setdefault keeps the first residue seen for this key
            second_shell_residues.setdefault(key, n)

    return second_shell_residues.values()

def find_hbonds(structure, contactlist, ns):
    """Find amino-acid residues with polar side-chain atoms near the contacts.

    Steps:

    1. Every residue of ``structure[0]`` whose residue number, modulo 500,
       matches a residue in ``contactlist`` is collected ("expanded
       contacts").
    2. For each atom of those residues whose name starts with ``O`` or ``N``
       but is not exactly ``O`` or ``N``, ``ns.search`` is queried at atom
       level with a radius of 3.8.
    3. A neighbouring atom counts when its name passes the same test, its
       parent residue satisfies ``is_aa`` and that residue is not itself an
       expanded contact.

    Only the first residue seen for each number-modulo-500 is kept.

    Returns:
        A ``dict_values`` view of the residues found, in discovery order.
    """
    # Computed once instead of once per residue in the structure
    contact_keys = {c.get_id()[1] % 500 for c in contactlist}

    expanded_contacts = []
    for r in Selection.unfold_entities(structure[0], 'R'):
        if r.get_id()[1] % 500 in contact_keys and r not in expanded_contacts:
            expanded_contacts.append(r)

    hbond_residues = {}
    for r in expanded_contacts:
        for at in Selection.unfold_entities(r, 'A'):
            name = at.get_name()
            # Only O*/N* atoms other than the ones named exactly O or N
            if name[0] not in ('O', 'N') or name in ('O', 'N'):
                continue

            for n in ns.search(at.get_coord(), 3.8, level='A'):
                partner = n.get_name()
                if partner[0] not in ('O', 'N') or partner in ('O', 'N'):
                    continue

                nr = n.get_parent()
                if not is_aa(nr) or nr in expanded_contacts:
                    continue
                # setdefault keeps the first residue seen for this key
                hbond_residues.setdefault(nr.get_id()[1] % 500, nr)

    return hbond_residues.values()


In [None]:
# Main execution

# Module-level holders.  LogString and passovers are read and written by the
# functions above through ``global``; by_aa and lengthlist are initialised
# here as well (presumably for code outside this section).
LogString = ''
by_aa = {resname: [] for resname in sidechaincenteratoms}
lengthlist = []
passovers = []

# Run the analysis on the configured input file
library_df, recovery_df, library_info = analyze_cofactor_binding(infile)

# Library table first, then the accompanying summary text
print("\n=== LIBRARY DESIGN RESULTS ===\n")
# (style_results_table is commented out above, so the frame is shown unstyled)
#styled_results = style_results_table(library_df)
display(library_df)
print(f"\n{library_info}")

# Recovery suggestions, one table per priority tier
if recovery_df is None or recovery_df.empty:
    print("\nNo obvious amino acids for activity recovery.")
else:
    print("\n=== SITE-SATURATION MUTAGENESIS TARGETS ===")
    print("The following residues should be targeted for activity recovery:\n")

    for priority in ('High', 'Medium', 'Low'):
        priority_df = recovery_df[recovery_df['Priority'] == priority]
        if priority_df.empty:
            continue
        print(f"\n{priority} Priority:")
        display(priority_df[['Residue', 'Reason']])

# Activity log, only when requested
if verbose:
    print("\n=== ACTIVITY LOG ===\n")
    # Swap angle brackets for square ones, then turn '[br]' into line breaks
    clean_log = LogString.translate(str.maketrans('<>', '[]')).replace('[br]', '\n')
    print(clean_log)