In [9]:
import re
import itertools
import sys
import re
from collections import OrderedDict  

In [10]:
# One-letter codes that make_cond_dict accepts as the leading character of a
# plain residue key such as "D135".
restypes = list('ACDEFGHIKLMNPQRSTVWY')

# Non-zero digits, as strings, used to test the character that follows the
# residue letter or the "TM" prefix of a key.
nums = [str(digit) for digit in range(1, 10)]

# Takes the first command-line argument as the metrics directory.
# Nothing in the visible cells reads this value afterwards.
# NOTE(review): inside a Jupyter kernel sys.argv[1] is normally a launcher
# flag (e.g. "-f"), not a user-supplied path - confirm how this is meant to
# be invoked, or replace with an explicit path constant.
metrics_dir = sys.argv[1]

def extract_unique_keys(nested, preserve_order=True):
    """Collect the distinct ``{...}`` placeholder names used inside *nested*.

    Every string reachable through dicts (values only), lists, tuples and
    sets is gathered depth first; other leaf types are skipped. The strings
    are joined with single spaces and scanned once for ``{name}`` patterns.

    Args:
        nested: a string or an arbitrarily nested container of strings.
        preserve_order: when true, names are returned in first-seen order;
            otherwise they are returned sorted.

    Returns:
        list of unique placeholder names (text between the braces).
    """
    fragments = []

    def gather(node):
        """Append every string found at any depth inside node to fragments."""
        if isinstance(node, str):
            fragments.append(node)
            return
        if isinstance(node, dict):
            children = node.values()
        elif isinstance(node, (list, tuple, set)):
            children = node
        else:
            # ints, floats, None, ... carry no text
            return
        for child in children:
            gather(child)

    gather(nested)

    # One regex pass over the joined text
    joined = " ".join(fragments)
    placeholders = [match.group(1) for match in re.finditer(r"\{([^}]*)\}", joined)]

    if not preserve_order:
        unique_sorted = list(set(placeholders))
        unique_sorted.sort()
        return unique_sorted

    # De-duplicate while keeping the first occurrence of each name
    already_seen = set()
    first_seen = []
    for name in placeholders:
        if name not in already_seen:
            already_seen.add(name)
            first_seen.append(name)
    return first_seen

    
def get_helices(keys, TM=False):
    """Collect the helix numbers referenced by a list of selection keys.

    Recognised key shapes:
      * ``<helix>x<pos>`` (e.g. ``7x53``): the text before the first ``x``
        is the helix; the key is skipped when that text is not all digits.
      * ``TM<n>`` (e.g. ``TM3``): helix ``n``.
      * ``H<n>`` (e.g. ``H8``): helix ``n``.
    Any other key, including an empty string, is ignored.

    Args:
        keys: iterable of key strings.
        TM: not used by this function; kept so existing calls still work.

    Returns:
        set of int helix numbers.

    Raises:
        ValueError: if a ``TM``/``H`` key has a suffix that is not an integer.
    """
    helices_used = set()

    for key in keys:
        if 'x' in key:
            helix = key.split('x')[0]
            if helix.isdigit():
                helices_used.add(int(helix))
        # startswith() rather than key[0]: an empty key must not raise IndexError
        elif key.startswith('TM'):
            helices_used.add(int(key[2:]))
        elif key.startswith('H'):
            helices_used.add(int(key[1:]))

    return helices_used


def get_ref_numbers(helices_used):
    """Interactively ask for the segid and each helix's x.50 residue number.

    The segid is requested first, then one residue number per helix in
    ascending helix order.

    Args:
        helices_used: iterable of int helix numbers.

    Returns:
        (helix_refs, segid) where helix_refs maps each helix number to
        ``{'ref_res': int, 'segid': segid}`` and segid is the stripped answer.

    Raises:
        ValueError: if a residue-number answer is not an integer.
    """
    segid = input("\n Enter the segid for the protein: ").strip()

    # Blank line(s) between the segid prompt and the per-helix prompts
    print('\n')

    helix_refs = {
        helix: {
            'ref_res': int(input(f"Enter the residue number for {helix}.50: ")),
            'segid': segid,
        }
        for helix in sorted(helices_used)
    }

    return helix_refs, segid


def process_0x00(keys, helix_refs, segid=False, ref=False):
    """Turn ``<helix>x<pos>`` keys into residue selection strings.

    The residue number is the helix's reference residue (position 50)
    shifted by ``pos - 50``. Keys whose helix part is not all digits are
    skipped.

    Args:
        keys: iterable of keys containing exactly one ``x``.
        helix_refs: mapping helix number -> dict with a ``'ref_res'`` entry.
        segid: when truthy, the selection is restricted to this segid.
        ref: when true, the last four characters of the position part are
            dropped before it is parsed.

    Returns:
        dict mapping each processed key to its selection string.
    """
    selections = {}

    for label in keys:
        helix_part, position_part = label.split('x')
        if not helix_part.isdigit():
            continue

        if ref:
            # drop the trailing four characters (e.g. "_ref") before parsing
            position_part = position_part[:-4]

        shift = int(position_part) - 50
        anchor = helix_refs[int(helix_part)]['ref_res']
        resid = anchor + shift

        if segid:
            selections[label] = f"protein and segid {segid} and resid {resid}"
        else:
            selections[label] = f"protein and resid {resid}"

    return selections


def process_TMX(keys, helix_refs, segid=False, ref=False):
    """Turn whole-helix keys (``TM<n>`` / ``H<n>``) into residue-range selections.

    For each key the user is prompted for a GPCRDB range such as
    ``3.44-3.56``; only the position after each ``.`` is used. Both ends are
    converted to residue numbers relative to the helix's reference residue
    (position 50).

    Args:
        keys: iterable of ``TM<n>`` or ``H<n>`` keys.
        helix_refs: mapping helix number -> ``{'ref_res': int, 'segid': str}``.
        segid: accepted for signature compatibility; the segid stored in
            ``helix_refs`` for the helix is the one written to the selection.
        ref: when true, ``'_ref'`` is appended to each output key.

    Returns:
        dict mapping each (possibly suffixed) key to its selection string.

    Raises:
        ValueError: if the key suffix or the entered range cannot be parsed.
        KeyError: if the helix is missing from ``helix_refs``.
    """
    dict_TMX = {}

    # Iterate the `keys` argument; the previous code looped over an
    # undefined name and failed with NameError.
    for key in keys:
        if key.startswith('H'):
            helix_num = int(key[1:])
        else:
            helix_num = int(key[2:])

        helix_entry = helix_refs[helix_num]
        ref_res = helix_entry['ref_res']
        # Separate local so the `segid` parameter is not overwritten
        helix_segid = helix_entry['segid']

        # Prompt for GPCRDB range
        gpcrdb_range = input(f"Enter the GPCRDB range for TM{helix_num} (e.g., 3.44-3.56): ").strip()
        start_gpcrdb, end_gpcrdb = gpcrdb_range.split('-')
        start_gpcrdb = int(start_gpcrdb.split('.')[1])
        end_gpcrdb = int(end_gpcrdb.split('.')[1])

        # Calculate residue range relative to the x.50 reference residue
        start_res = ref_res + (start_gpcrdb - 50)
        end_res = ref_res + (end_gpcrdb - 50)

        out_key = key + '_ref' if ref else key
        if helix_segid:
            dict_TMX[out_key] = f"protein and segid {helix_segid} and resid {start_res} to {end_res}"
        else:
            dict_TMX[out_key] = f"protein and resid {start_res} to {end_res}"

    return dict_TMX


def process_X000(keys, segid=False):
    """Turn plain residue keys such as ``D135`` into selection strings.

    Everything after the first character of the key is used verbatim as the
    residue number.

    Args:
        keys: iterable of residue keys; a falsy value yields an empty dict.
        segid: when truthy, the selection is restricted to this segid.

    Returns:
        dict mapping each key to its selection string.
    """
    if not keys:
        return {}

    def selection_for(resnum):
        """Build the selection text for one residue number."""
        if segid:
            return f"protein and segid {segid} and resid {resnum}"
        return f"protein and resid {resnum}"

    return {residue_key: selection_for(residue_key[1:]) for residue_key in keys}

                
def process_BW(keys_0x00, keys_TM=None, ref=False):
    """Resolve helix-relative keys into residue selections.

    Prompts (via get_ref_numbers) for the segid and the x.50 residue of
    every helix referenced by either key list, then builds selections for
    the ``<helix>x<pos>`` keys and, if present, the whole-helix keys.

    Args:
        keys_0x00: list of ``<helix>x<pos>`` keys.
        keys_TM: optional list of ``TM<n>`` / ``H<n>`` keys.
        ref: forwarded to process_0x00 / process_TMX.

    Returns:
        (selections, segid). When both key lists are empty nothing is
        prompted and ``({}, False)`` is returned, so callers can always
        unpack two values.
    """
    # Copy into fresh lists; also avoids a shared mutable default argument
    keys_0x00 = list(keys_0x00) if keys_0x00 else []
    keys_TM = list(keys_TM) if keys_TM else []

    # Nothing to resolve: return the same two-value shape as the normal path.
    # Whole-helix keys alone are enough to continue.
    if not keys_0x00 and not keys_TM:
        return {}, False

    # Identify the helices present
    helices_used = get_helices(keys_0x00 + keys_TM)

    helix_refs, segid = get_ref_numbers(helices_used)

    res_dict = dict(process_0x00(keys_0x00, helix_refs, segid, ref=ref))

    if keys_TM:
        # Pass this function's own `keys_TM` argument
        res_dict.update(process_TMX(keys_TM, helix_refs, segid, ref=ref))

    return res_dict, segid
    
    
def _split_bw_result(result):
    """Normalise a process_BW return value to a (selections, segid) pair."""
    if isinstance(result, tuple):
        return result
    # A bare dict means no segid is known
    return result, False


def make_cond_dict(metric_dict):
    """Build a mapping from every ``{key}`` placeholder to a residue selection.

    Placeholders are extracted from *metric_dict* and sorted into groups:
      * ``<helix>x<pos>`` keys without ``ref``: helix-relative positions,
      * ``H8`` and ``TM<digit>...``: whole helices,
      * keys containing ``ref`` (but not ``selection``): reference keys,
      * ``<residue letter><digit>...``: absolute residue numbers.
    Keys that match none of these are reported with a printed message and
    left out of the result.

    The user is prompted for the segid, the x.50 residue numbers and, when
    reference keys exist, whether the reference selections should reuse the
    original ones (Y) or be entered separately (N).

    Args:
        metric_dict: nested structure of selection template strings.

    Returns:
        dict mapping placeholder keys to selection strings.
    """
    keys = extract_unique_keys(metric_dict)

    keys_0x00 = []
    keys_X000 = []
    keys_TMX = []
    keys_ref = []

    for key in keys:
        if 'x' in key and 'ref' not in key:
            keys_0x00.append(key)
        # Exact match: a substring test would also catch residue keys such
        # as "H87" and treat them as helix numbers.
        elif key == 'H8':
            keys_TMX.append(key)
        elif 'ref' in key and 'selection' not in key:
            keys_ref.append(key)
        # Slices instead of indexing so short keys ("", "A", "TM") fall
        # through to the message below instead of raising IndexError.
        elif key[:1] in restypes and key[1:2] in nums:
            keys_X000.append(key)
        elif key[:2] == 'TM' and key[2:3] in nums:
            keys_TMX.append(key)
        else:
            print(f"Couldn't process key {key}, add on your own")

    cond_dict_BW, segid = _split_bw_result(process_BW(keys_0x00, keys_TM=keys_TMX))
    cond_dict_X000 = process_X000(keys_X000, segid)
    cond_dict_ref = {}
    ref_same = ''

    if keys_ref:
        answer = input("\nShould the reference selections be the same as the original? Y/N : ")
        # Accept "y", "yes", "n", "no" in any case
        ref_same = answer.strip().upper()[:1]

    if ref_same == 'Y':
        cond_dict_ref = {key + '_ref': selection for key, selection in cond_dict_BW.items()}
    elif ref_same == 'N':
        cond_dict_ref, _ = _split_bw_result(process_BW(keys_ref, ref=True))

    total_dict = {**cond_dict_BW, **cond_dict_X000, **cond_dict_ref}
    return total_dict
            


In [14]:
# Example metric definitions: metric family -> metric label -> list of atom
# selection templates. Each "{...}" placeholder is a key that make_cond_dict
# resolves into a concrete residue selection.
metrics = {
    'distances': {
        # TM6 Intracellular Position
        # NOTE(review): the entries below reference helices 7 and 1 and
        # residues D135/M108, not helix 6 - confirm this heading.
        '7.53-1.53': ['{7x53} and name CA', '{1x53} and name CA'],
        '7.56N-7.52O': ['{7x56} and name N', '{7x52} and name O'],
        # NOTE(review): label says "7.43O" but the selection uses atom name
        # OH - confirm which is intended.
        '7.43O-D135': ['{7x43} and name OH', '{D135} and name CA'],
        'M108SD-D135': ['{M108} and name SD', '{D135} and name CA'],
    },}

# Interactive: prompts for the segid and the x.50 residue number of each
# referenced helix, then displays the resulting key -> selection mapping.
make_cond_dict(metrics)


 Enter the segid for the protein:  






Enter the residue number for 1.50:  100
Enter the residue number for 7.50:  700


{'7x53': 'protein and resid 703',
 '1x53': 'protein and resid 103',
 '7x56': 'protein and resid 706',
 '7x52': 'protein and resid 702',
 '7x43': 'protein and resid 693',
 'D135': 'protein and resid 135',
 'M108': 'protein and resid 108'}

In [34]:
# Scratch check: slicing off the last four characters removes a "_ref"
# suffix, which is the same slice process_0x00 applies when ref=True.
test = 'thing_ref'

test[:-4]

'thing'