In [None]:
def set_values(matrix, atom_desc, left_bottom, right_up, bin_size):
    """Add one atom to the per-layer voxel grids stored in ``matrix``.

    ``matrix`` is indexed by layer name (e.g. a numpy structured array whose
    fields are 'CA', 'C', 'N', 'O', 'S', 'X', residue names, 'XXX', 'HELIX',
    'STRAND'); each layer is a 3-D grid that is incremented in place.

    ``atom_desc`` is ``[type, name, elem, resn, ss, x, y, z]``, the layout
    built by ``mhc2matrix``'s PyMOL iterate expression.

    Returns:
        0 if the atom lies outside the half-open box ``[left_bottom, right_up)``
          or maps past the last bin of the grid (nothing is written),
        1 if ``type`` is not 'ATOM' (only the 'X' and 'XXX' layers are bumped),
        2 for an 'ATOM' record.
    """
    aas = {'GLY', 'ALA', 'ILE', 'LEU',
           'PRO', 'SER', 'THR', 'CYS',
           'MET', 'ASP', 'ASN', 'GLU',
           'GLN', 'LYS', 'ARG', 'HIS',
           'PHE', 'TYR', 'TRP'}

    coords = (atom_desc[5], atom_desc[6], atom_desc[7])

    # check that the atom is inside the box (lower bound inclusive, upper exclusive)
    for c, lo, hi in zip(coords, left_bottom, right_up):
        if c - lo < 0.0 or c - hi >= 0.0:
            return 0

    # compute the voxel index; np.floor returns a float and numpy only accepts
    # integer indices, so convert explicitly
    voxel = tuple(int(np.floor((c - lo) / bin_size))
                  for c, lo in zip(coords, left_bottom))

    # The grid has int((right_up - left_bottom) / bin_size) bins per axis, so a
    # box extent that is not a multiple of bin_size (or float rounding near
    # right_up) can map an in-box atom one past the last bin: skip it.
    if any(not 0 <= i < n for i, n in zip(voxel, matrix['X'].shape)):
        return 0

    # non-ATOM records (e.g. hetero atoms) only count as unknown element/residue
    if atom_desc[0] != 'ATOM':
        matrix['X'][voxel] += 1.0
        matrix['XXX'][voxel] += 1.0
        return 1

    if atom_desc[1] == 'CA':
        matrix['CA'][voxel] += 1.0

    # element layer: C/N/O/S have their own layer, everything else goes to 'X'
    if atom_desc[2] in ('C', 'N', 'O', 'S'):
        matrix[atom_desc[2]][voxel] += 1.0
    else:
        matrix['X'][voxel] += 1.0

    # residue layer: standard amino acids have their own layer, others 'XXX'
    if atom_desc[3] in aas:
        matrix[atom_desc[3]][voxel] += 1.0
    else:
        matrix['XXX'][voxel] += 1.0

    # secondary-structure layers
    if atom_desc[4] == 'H':
        matrix['HELIX'][voxel] += 1.0

    if atom_desc[4] == 'S':
        matrix['STRAND'][voxel] += 1.0

    return 2

# convert complex to 4d matrix
# name should be preloaded to pymol with cmd.load()
def mhc2matrix(name, left_bottom=(-30.0, -10.0, -10.0), right_up=(30.0, 10.0, 10.0), bin_size=1.0):
    """Voxelize the PyMOL object ``name`` into a stack of per-feature 3-D grids.

    The box ``[left_bottom, right_up)`` is split into cubic bins of edge
    ``bin_size``; the per-axis bin count is truncated with ``int``. Every atom
    of ``name`` (fetched via ``cmd.iterate_state`` with state 0) is counted
    into the layers selected by ``set_values``.

    Args:
        name: name of an object already loaded into PyMOL.
        left_bottom: (x, y, z) lower corner of the box.
        right_up: (x, y, z) upper corner of the box.
        bin_size: voxel edge length, in the same units as the coordinates.

    Returns:
        float32 numpy array of shape
        (len(layer_names), x_bins, y_bins, z_bins), i.e. (28, ...), with
        layers in the order of ``layer_names`` below.
    """
    shape = tuple(int((right_up[i] - left_bottom[i]) / bin_size) for i in range(3))

    layer_names = ['CA', 'C', 'N', 'O', 'S', 'X',
                   'GLY', 'ALA', 'ILE', 'LEU',
                   'PRO', 'SER', 'THR', 'CYS',
                   'MET', 'ASP', 'ASN', 'GLU',
                   'GLN', 'LYS', 'ARG', 'HIS',
                   'PHE', 'TYR', 'TRP', 'XXX',
                   'HELIX', 'STRAND']

    # One float32 field per layer. numpy needs a concrete list of
    # (name, type) tuples here; a Python 3 zip iterator is not a valid dtype.
    grid = np.zeros(shape, dtype=[(layer, np.float32) for layer in layer_names])

    # Collect [type, name, elem, resn, ss, x, y, z] for every atom; the names
    # inside the expression string are PyMOL atom properties, not locals.
    atoms = []
    cmd.iterate_state(0, name, 'atoms.append([type, name, elem, resn, ss, x, y, z])',
                      space={'atoms': atoms})

    for atom_desc in atoms:
        set_values(grid, atom_desc, left_bottom, right_up, bin_size)

    # stack the named fields into a plain (layers, x, y, z) array
    return np.array([grid[layer] for layer in layer_names])

In [None]:
# Build the PyMOL selection 'fixed' for a peptide complex and return atom IDs.
def select_atoms(name, pept, plen):
    """Create the PyMOL selection 'fixed' for complex ``name`` and return atom IDs.

    The peptide chain ``pept`` is parsed with Biopython from
    ``<corrected_pdb_path>/<name>.pdb``. The residue numbers of its first and
    last CA are taken from the first CA-CA polypeptide fragment. The selection
    'fixed' is then built from those two terminal CA atoms plus
    ``byres <name> be. 15 of <name>//<pept>//``.

    NOTE(review): ``be.`` is PyMOL's ``beyond`` operator (atoms farther than
    15 A from the peptide), while the original comment said "in range of
    15 A" (which would be ``w.``/``within``). Confirm which is intended.

    NOTE(review): the returned IDs come from iterating over the whole object
    ``name``, not over the 'fixed' selection created just above. If callers
    expect only the fixed atoms, this should iterate over 'fixed'. It is left
    unchanged pending confirmation.

    Args:
        name: PDB id / PyMOL object name; also the file stem of the PDB file.
        pept: chain identifier of the peptide.
        plen: expected peptide length in residues.

    Returns:
        List of PyMOL atom IDs collected by ``cmd.iterate_state(0, name, ...)``.

    Raises:
        AssertionError: if the first-to-last CA residue span differs from
            ``plen``.
    """
    import os

    parser = Bio.PDB.PDBParser(QUIET=True)
    builder = Bio.PDB.CaPPBuilder()
    pdb_file = os.path.join(corrected_pdb_path, name + '.pdb')
    model = parser.get_structure(id=name, file=pdb_file)[0]
    peptide = builder.build_peptides(model[pept])[0]
    first_ca = peptide[0].get_id()[1]
    last_ca = peptide[-1].get_id()[1]

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # AssertionError is kept so existing callers that catch it still work.
    span = last_ca - first_ca + 1
    if span != plen:
        raise AssertionError(
            'peptide chain %s of %s spans residues %i-%i (%i), expected length %i'
            % (pept, name, first_ca, last_ca, span, plen))

    cmd.select('fixed', '%s//%s/%i/CA %s//%s/%i/CA byres %s be. 15 of %s//%s//'
               % (name, pept, first_ca,
                  name, pept, last_ca,
                  name, name, pept))

    idx = []
    cmd.iterate_state(0, name, 'idx.append(ID)', space={'idx': idx})

    return idx