In [1]:
import iotbx.pdb
from scitbx.array_family import flex
from cctbx import crystal
import cctbx
import math
from iotbx.data_manager import DataManager
from itertools import chain
from collections import defaultdict
from scipy.spatial import KDTree
import sys
sys.path.append("../")
from qscore_utils import *
# Read the example PDB file through the DataManager and keep the resulting
# model object; later cells pass `model` to the cctbx probe generator.
dm = DataManager()
dm.process_model_file("../../data/1yjp.pdb")
model = dm.get_model()

In [2]:
import itertools
class FlexContainer:
    """Numpy-flavoured wrapper around a cctbx ``flex`` array.

    The container stores a flex array (``data``) together with a ``shape``
    tuple and offers a small part of the ndarray interface: element-wise
    operators, ``sum`` and indexing with ints, slices and lists.

    Indexing notes:
      * indices are expanded axis by axis and combined with
        ``itertools.product`` (outer product). This matches numpy for
        ints/slices and for a single list index, but differs from numpy's
        paired fancy indexing when several axes are indexed with lists.
      * when ``debug`` is true, every ``__getitem__`` result is compared with
        the equivalent numpy indexing and an AssertionError is raised on
        disagreement.
    """

    @staticmethod
    def _is_flex(obj):
        """Return True when ``obj`` looks like a scitbx flex array (decided by its type name)."""
        return "scitbx_array_family_flex" in str(type(obj))

    def __init__(self, data,debug=True,shape=None):
        """Wrap ``data``.

        Args:
            data: a flex array, or a numpy array with an integer, floating or
                boolean dtype (converted to flex.uint64 / flex.double /
                flex.bool respectively).
            debug: if true, indexing results are cross-checked against numpy.
            shape: optional int or sequence of ints. When it differs from the
                flex focus, ``data`` is reshaped; the reshape is applied to the
                flex object that was passed in.

        Raises:
            AssertionError: for an unsupported numpy dtype or a non-flex input.
        """
        if isinstance(data, np.ndarray):
            # dtype.kind avoids rebuilding tuples of scalar types from
            # np.sctypeDict on every construction.
            kind = data.dtype.kind
            if kind in "iu":
                data = flex.uint64(data)
            elif kind == "f":
                data = flex.double(data)
            elif kind == "b":
                data = flex.bool(data)
            else:
                raise AssertionError("Unrecognized numpy dtype: "+str(data.dtype))

        if not self._is_flex(data):
            raise AssertionError("Initialize with a flex array")

        if shape is None:
            shape = data.focus()
        elif isinstance(shape, int):
            shape = (shape,)
        else:
            # Normalise lists etc. so the comparison with focus() is meaningful.
            shape = tuple(shape)
        if shape != data.focus():
            data.reshape(flex.grid(shape))

        self._data = data
        self._shape = shape
        self._debug = debug  # if true, some numpy consistency checks are performed dynamically

    def __repr__(self):
        """Return numpy's repr of the wrapped data."""
        return repr(self.numpy())

    def __len__(self):
        """Return the size of the first axis."""
        return self.shape[0]

    def _intake_other(self,other):
        """Unwrap the right-hand operand of a binary operator to a flex array.

        Accepts a FlexContainer, a flex array or a numpy array; anything else
        raises AssertionError.
        """
        if isinstance(other, FlexContainer):
            return other.data
        if self._is_flex(other):
            return other
        if isinstance(other, np.ndarray):
            return FlexContainer(other).data
        raise AssertionError("Unable to intake the 'other' object of type: "+str(type(other)))

    def sum(self, axis=None):
        """Sum all elements (``axis=None``) or along one axis.

        Args:
            axis: None, or an axis number. Negative values count from the last
                axis, as in numpy.

        Returns:
            A new container backed by a flex.double array. A result that would
            be zero-dimensional is returned with shape ``(1,)``, the same
            convention ``__getitem__`` uses.

        Raises:
            ValueError: if ``axis`` is outside the valid range.
        """
        if axis is None:
            total_sum = flex.sum(self.data)
            return self.__class__(flex.double([total_sum]), shape=(1,))

        n_dims = len(self.shape)
        requested_axis = axis
        if -n_dims <= axis < 0:
            axis += n_dims
        if n_dims == 1 and axis != 0:
            raise ValueError(f"For a 1D FlexContainer, 'axis' should be either 0 or None, but got {requested_axis}")
        if axis < 0 or axis >= n_dims:
            raise ValueError(f"Invalid axis {requested_axis}. Expected value between 0 and {n_dims - 1}")

        # Every combination of indices over the axes that are kept; the summed
        # axis stays a full slice.
        other_axes = [i for i in range(n_dims) if i != axis]
        all_indices = list(itertools.product(*(range(self.shape[i]) for i in other_axes)))
        result = flex.double(len(all_indices), 0.0)

        for result_idx, kept_indices in enumerate(all_indices):
            current_slice = [slice(None)] * n_dims
            for ax, idx in zip(other_axes, kept_indices):
                current_slice[ax] = idx
            flat = self.__class__._nd_to_1d_indices(tuple(current_slice), self.shape)
            result[result_idx] = flex.sum(self.data.select(flex.uint32(flat)))

        new_shape = tuple(dim for i, dim in enumerate(self.shape) if i != axis)
        if len(new_shape) == 0:
            # Summing the only axis: keep a one-element array instead of an
            # empty grid.
            new_shape = (1,)
        return self.__class__(result, shape=new_shape)

    # ---- element-wise binary operators -------------------------------------
    # Each one unwraps the operand and applies the flex operator; the result
    # takes its shape from the focus of the flex result.

    def __add__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data + other)

    def __sub__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data - other)

    def __mul__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data * other)

    def __truediv__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data / other)

    def __floordiv__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data // other)

    def __mod__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data % other)

    def __pow__(self, power, modulo=None):
        # `power` is passed straight to flex (no intake), so plain scalars work.
        if modulo:
            return self.__class__(pow(self.data, power, modulo))
        return self.__class__(self.data ** power)

    def __lshift__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data << other)

    def __rshift__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data >> other)

    def __and__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data & other)

    def __or__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data | other)

    def __xor__(self, other):
        other = self._intake_other(other)
        return self.__class__(self.data ^ other)

    # Reflected (__radd__, ...) and in-place (__iadd__, ...) variants are not
    # implemented.

    def __getitem__(self, indices):
        """Return a new container holding the selected elements.

        ``indices`` may be an int, slice, list, numpy/flex index array, or a
        tuple of those (one entry per axis; missing axes are full slices).
        A selection that would be zero-dimensional is returned with shape
        ``(1,)``.

        Raises:
            IndexError: if an integer index is out of bounds.
            AssertionError: in debug mode, if the result disagrees with numpy.
        """
        if not isinstance(indices, tuple):
            indices = (indices,)
        indices = self.__class__._preprocess_indices(indices, self.shape)
        flat = self.__class__._nd_to_1d_indices(indices, self.shape)
        new_shape = self.__class__._compute_sliced_shape(self.shape, indices)
        if len(new_shape) == 0:
            new_shape = (1,)

        new_data = self.data.select(flex.uint32(flat))
        new_data.reshape(flex.grid(new_shape))
        out = self.__class__(new_data, shape=new_shape)

        if self._debug:
            # tuple(...) is the portable spelling of self.np[*indices], which
            # is only valid syntax on Python >= 3.11.
            expected = self.np[tuple(indices)]
            assert np.all(out.np == expected), \
                "FlexContainer indexing disagrees with numpy for indices: "+str(indices)
        return out

    @staticmethod
    def _preprocess_indices(indices, shape):
        """Normalise an index expression to a list with one entry per axis.

        A single int/slice/array is wrapped in a list; any other sequence is
        read as "one entry per axis". numpy and flex arrays become Python
        lists, numpy integer scalars become ints, and missing trailing axes are
        padded with full slices.
        """
        if isinstance(indices, (int, np.integer, slice, np.ndarray)) or FlexContainer._is_flex(indices):
            indices = [indices]
        else:
            indices = list(indices)

        for i, idx in enumerate(indices):
            if isinstance(idx, np.ndarray):
                indices[i] = idx.tolist()
            elif isinstance(idx, np.integer):
                indices[i] = int(idx)
            elif FlexContainer._is_flex(idx):
                indices[i] = list(idx)

        while len(indices) < len(shape):
            indices.append(slice(None, None, None))

        # NOTE(review): this keeps at most len(shape)+1 entries, i.e. one more
        # than the number of axes; the extra entry is ignored by the zip() calls
        # downstream. Left unchanged to keep the returned list stable.
        return indices[:len(shape)+1]

    @staticmethod
    def _normalize_index(index, dim_size):
        """Map a possibly negative integer index onto ``range(dim_size)``.

        Raises:
            IndexError: if ``index`` is outside ``[-dim_size, dim_size)``.
        """
        if index < -dim_size or index >= dim_size:
            raise IndexError(f"index {index} is out of bounds for axis with size {dim_size}")
        return index + dim_size if index < 0 else index

    @staticmethod
    def _nd_to_1d_indices(indices, shape):
        """Expand per-axis indices into row-major flat indices.

        Args:
            indices: sequence with one int, slice or list of ints per axis.
            shape: shape of the array being indexed.

        Returns:
            A list of flat indices, ordered as the outer product of the
            per-axis selections.

        Raises:
            IndexError: if an integer index is out of bounds.
        """
        # Row-major strides: the last axis varies fastest.
        strides = [1] * len(shape)
        for i in reversed(range(len(shape)-1)):
            strides[i] = strides[i+1] * shape[i+1]

        per_axis = []
        for idx, dim_size in zip(indices, shape):
            if isinstance(idx, slice):
                # slice.indices clamps and resolves negative/None bounds, the
                # same rule _compute_sliced_shape uses for the output shape.
                per_axis.append(range(*idx.indices(dim_size)))
            elif isinstance(idx, int):
                per_axis.append([FlexContainer._normalize_index(idx, dim_size)])
            else:
                per_axis.append([FlexContainer._normalize_index(i, dim_size) for i in idx])

        return [sum(i * s for i, s in zip(combo, strides))
                for combo in itertools.product(*per_axis)]  # expensive line

    @staticmethod
    def _compute_sliced_shape(shape, indices):
        """Return the shape produced by applying ``indices`` to ``shape``.

        Slices and lists each contribute one output axis; ints drop their axis.

        Raises:
            ValueError: for an index that is not a slice, int or list.
        """
        output_shape = []
        for dim_size, idx in zip(shape, indices):
            if isinstance(idx, slice):
                output_shape.append(len(range(*idx.indices(dim_size))))
            elif isinstance(idx, int):
                continue
            elif isinstance(idx, list):
                output_shape.append(len(idx))
            else:
                raise ValueError("Invalid index encountered: " + str(idx))
        return tuple(output_shape)

    def __setitem__(self, indices,value):
        """Assign ``value`` to the selected elements, in place.

        ``value`` may be a Python scalar, a numpy array, a flex array or a
        FlexContainer. A single-element value is repeated to fill a larger
        selection.

        Raises:
            AssertionError: if ``value`` has an unsupported type.
            IndexError: if an integer index is out of bounds.
        """
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(value, bool):
            value = flex.bool([value])
        elif isinstance(value, float):
            value = flex.double([value])
        elif isinstance(value, int):
            value = flex.uint64([value])
        elif isinstance(value, np.ndarray):
            value = self.__class__(value).data

        if not (self._is_flex(value) or isinstance(value, self.__class__)):
            raise AssertionError("Set new data with a flex array or FlexContainer")
        if isinstance(value, self.__class__):
            value = value.data

        # Same convention as __getitem__: a bare (non-tuple) index addresses
        # the first axis.
        if not isinstance(indices, tuple):
            indices = (indices,)
        indices = self.__class__._preprocess_indices(indices, self.shape)
        flat = self.__class__._nd_to_1d_indices(indices, self.shape)

        value = value.as_1d()
        if value.size() == 1 and len(flat) != 1:
            # Broadcast a single element over the whole selection.
            value = type(value)(len(flat), value[0])

        # set_selected works on the flattened array; always restore the grid,
        # even when the assignment fails.
        self._data.reshape(flex.grid((self._data.size(),)))
        try:
            self._data = self._data.set_selected(flex.uint32(flat), value)
        finally:
            self._data.reshape(flex.grid(self.shape))

    @property
    def np(self):
        """The data as a numpy array (shorthand for ``numpy()``)."""
        return self.numpy()

    def numpy(self):
        """Return the data converted with the flex ``as_numpy_array`` method."""
        return self.data.as_numpy_array()

    @property
    def ndim(self):
        """Number of axes."""
        return len(self.shape)

    @property
    def shape(self):
        """Shape tuple tracked alongside the flex array."""
        return self._shape

    @shape.setter
    def shape(self,value):
        # Only updates the bookkeeping; use reshape() to change the flex grid too.
        self._shape = value

    @property
    def data(self):
        """The wrapped flex array."""
        return self._data

    @staticmethod
    def _reshape_to_nd_list(data, shape):
        """Split a flat sequence into nested lists of the given shape."""
        if not isinstance(data, list):
            data = list(data)
        # Base case: only one dimension left.
        if len(shape) == 1:
            return data[:shape[0]]
        if shape[0] == 0:
            # Avoid dividing by zero for an empty leading axis.
            return []
        # Recursive case: cut into shape[0] equal chunks and reshape each one.
        chunk_size = len(data) // shape[0]
        return [FlexContainer._reshape_to_nd_list(data[i*chunk_size:(i+1)*chunk_size], shape[1:])
                for i in range(shape[0])]

    @staticmethod
    def _get_depth(lst, level=1):
        """Recursively determine the depth of the nested list.

        A non-list (or an empty list) counts as ``level``; a flat list of
        scalars therefore has depth 2.
        """
        if not isinstance(lst, list) or not lst:
            return level
        return max(FlexContainer._get_depth(item, level + 1) for item in lst)

    @staticmethod
    def _flatten_nested_list(nested_list, target_dim=None):
        """Flatten ``nested_list`` by ``_get_depth(nested_list) - target_dim`` levels.

        ``target_dim`` defaults to 1 and is measured on the same scale as
        ``_get_depth``.

        Raises:
            ValueError: if ``target_dim`` is not between 1 and the current depth.
        """
        current_depth = FlexContainer._get_depth(nested_list)

        if target_dim is None:
            target_dim = 1

        if target_dim <= 0 or target_dim > current_depth:
            raise ValueError(f"Target dimension must be between 1 and {current_depth} (inclusive).")

        flattened_list = nested_list
        for _ in range(current_depth - target_dim):
            flattened_list = [item
                              for sublist in flattened_list
                              for item in (sublist if isinstance(sublist, list) else [sublist])]
        return flattened_list

    def tolist(self):
        """Return the data as nested Python lists following ``shape``."""
        return self._reshape_to_nd_list(self.data, self.shape)

    def reshape(self,shape):
        """Reshape the wrapped flex array in place and update ``shape``; returns None."""
        if isinstance(shape, int):
            shape = (shape,)
        shape = tuple(shape)
        self.data.reshape(flex.grid(shape))
        self.shape = shape

In [13]:
# Scratch check: call numpy's own __repr__ explicitly on an array.
# NOTE(review): `h` is not defined in any visible cell, so this raises
# NameError on a fresh kernel (see the recorded traceback below).
a = h.as_numpy_array()
a.__class__.__repr__(a)

NameError: name 'h' is not defined

> [0;32m/tmp/ipykernel_16304/1353764124.py[0m(1)[0;36m<module>[0;34m()[0m
[0;32m----> 1 [0;31m[0ma[0m [0;34m=[0m [0mh[0m[0;34m.[0m[0mas_numpy_array[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m[0ma[0m[0;34m.[0m[0m__class__[0m[0;34m.[0m[0m__repr__[0m[0;34m([0m[0ma[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  exit


In [3]:
import numpy as np
# Show how _preprocess_indices normalises index expressions for a (3, 3) shape:
# a bare sequence is read as one entry per axis ...
print(FlexContainer._preprocess_indices([0,1,2],(3,3)))
# ... a nested list is a list index on the first axis, padded with a full slice ...
print(FlexContainer._preprocess_indices([[0,1,2]],(3,3)))
# ... a single int is padded with a full slice ...
print(FlexContainer._preprocess_indices((1,),(3,3)))
# ... and a complete (slice, int) pair is returned unchanged.
print(FlexContainer._preprocess_indices((slice(None,None,None),2),(3,3)))

[0, 1, 2]
[[0, 1, 2], slice(None, None, None)]
[1, slice(None, None, None)]
[slice(None, None, None), 2]


In [4]:
import numpy as np
from scipy.spatial import distance

# Example data: two small sets of 3-D points.
A = np.array([[1, 2,3], [3, 4,5], [5, 6,7]])
B = np.array([[7, 8,9], [9, 10,11], [11, 12,13]])
shape = A.shape

# Reference: full pairwise distance matrix, and its row-major flattening.
dist_matrix = distance.cdist(A, B)
flattened_dists = dist_matrix.flatten()

# All (i, j) pairs in the same row-major order as `flattened_dists`.
# j runs over the rows of B (not A), so the pairing stays valid when the two
# point sets have different sizes.
indices = np.array([(i,j) for i in range(A.shape[0]) for j in range(B.shape[0])])

i_s = indices[:,0]
j_s = indices[:,1]

# Vectorised distances through fancy indexing; must reproduce cdist.
ds = np.sqrt(np.sum((A[i_s] - B[j_s])**2,axis=1))
np.testing.assert_allclose(ds, flattened_dists)

# Print the three equivalent computations side by side for every pair.
for idx,(i, j) in enumerate(indices):
    d = dist_matrix[i,j]
    print(f"Distance between A[{i}] and B[{j}] is {d}")
    d2 = np.sqrt(np.sum((A[i] - B[j])**2))
    print(f"Distance between A[{i}] and B[{j}] is {d2}")
    d3 = ds[idx]
    print(f"Distance between A[{i}] and B[{j}] is {d3}\n")

Distance between A[0] and B[0] is 10.392304845413264
Distance between A[0] and B[0] is 10.392304845413264
Distance between A[0] and B[0] is 10.392304845413264

Distance between A[0] and B[1] is 13.856406460551018
Distance between A[0] and B[1] is 13.856406460551018
Distance between A[0] and B[1] is 13.856406460551018

Distance between A[0] and B[2] is 17.320508075688775
Distance between A[0] and B[2] is 17.320508075688775
Distance between A[0] and B[2] is 17.320508075688775

Distance between A[1] and B[0] is 6.928203230275509
Distance between A[1] and B[0] is 6.928203230275509
Distance between A[1] and B[0] is 6.928203230275509

Distance between A[1] and B[1] is 10.392304845413264
Distance between A[1] and B[1] is 10.392304845413264
Distance between A[1] and B[1] is 10.392304845413264

Distance between A[1] and B[2] is 13.856406460551018
Distance between A[1] and B[2] is 13.856406460551018
Distance between A[1] and B[2] is 13.856406460551018

Distance between A[2] and B[0] is 3.4641016

In [5]:
# Wrap the example point sets; the float cast makes FlexContainer store them
# as flex.double rather than flex.uint64.
A_flex = FlexContainer(A.astype(float))
B_flex = FlexContainer(B.astype(float))

In [6]:
# Flat (row-major) indices for rows 1, 2, 1 of a (3, 3) array, all columns.
FlexContainer._nd_to_1d_indices(([1, 2, 1],slice(None,None,None)),A_flex.shape)

[3, 4, 5, 6, 7, 8, 3, 4, 5]

In [7]:
# Shapes produced by list indexing on the first axis of a (3, 3) container.
a = A_flex
print(a.shape)
print(a[[1,2,1]].shape) # a list index selects whole rows, so this is (3, 3) as in numpy
print(a[[1]].shape)     # single-element list keeps the axis: (1, 3)

(3, 3)
(3, 3)
(1, 3)


In [8]:

# Pairwise distances computed entirely through FlexContainer operations ...
a = ((((A_flex[i_s]-B_flex[j_s])**2)).sum(axis=1))**0.5

# ... and the same expression in plain numpy as the reference.
ds = (((A[i_s] - B[j_s])**2).sum(axis=1))**0.5

# Both routes must agree element-wise.
assert np.all(np.isclose(a.np,ds))

In [9]:
from scipy.spatial.distance import pdist,cdist
from cctbx.array_family import flex
import itertools

# 1000 random 3-D points, used as both point sets.
# NOTE(review): no random seed is set, so the values differ between runs.
arr = np.random.random((1000,3))
A,B = arr,arr
# NOTE(review): this rebinds A_flex/B_flex from FlexContainer instances (set in
# an earlier cell) to raw flex.double arrays.
A_flex,B_flex = flex.double(A), flex.double(B)



In [10]:
from qscore_cctbx import radial_shell_worker_cctbx
from qscore import radial_shell_worker

In [11]:
%pdb

Automatic pdb calling has been turned ON


In [12]:
from qscore import radial_shell_worker, radial_shell_mp

In [14]:
# Run the numpy worker once with a hand-built argument tuple.
# NOTE(review): `atoms_xyz` is not defined in any visible cell. The meaning of
# the positional values (16, 8, 2, 0.9) is presumably n_probes,
# n_probes_target, shell radius and rtol, mirroring the tuple built in
# radial_shell_cctbx_mp -- TODO confirm against qscore.radial_shell_worker.
args = 0,atoms_xyz.as_numpy_array(),16, 8,2,0.9 
probes_xyz, keep_mask, all_pts_np = radial_shell_worker(args) # probe_xyz, keep_mask = 

In [26]:
# Compare the numpy and cctbx point sets entry by entry; the assert message
# reports the first (atom, probe) pair that differs.
# NOTE(review): `all_pts_cctbx` is not defined in any visible cell, and the
# bounds 66 and 8 are hard-coded.
for i in range(66): # for each atom
    for j in range(8): # for each probe
        pt_np = all_pts_np[i][j]
        pt_flex = all_pts_cctbx[i][j].as_numpy_array()
        assert np.all(np.isclose(pt_np,pt_flex)), (i,j)
        

AssertionError: (12, 6)

In [30]:
# Print both versions of the first mismatching entry reported above.
i,j = 12,6
print(all_pts_np[i][j])
print(all_pts_cctbx[i][j].as_numpy_array())

[-5.61218292  0.34915574  4.945     ]
[[-6.23383033  1.61436062  4.785     ]]


In [18]:
# Exact-equality check between the flex and numpy probe coordinates.
# NOTE(review): `probes_xyz_flex` is not defined in any visible cell, and the
# recorded run failed because `probes_xyz` was not defined at that point.
assert np.all(probes_xyz_flex.as_numpy_array()==probes_xyz)

NameError: name 'probes_xyz' is not defined

In [31]:
from multiprocessing import Pool, cpu_count


In [55]:
# Display the shell radii (assigned in the `radii = [...]` cell further down).
radii

[1e-06, 1, 2]

In [35]:
# Shell radii used by both implementations; the numpy multiprocessing version
# provides the reference result.
# NOTE(review): `atoms_xyz` is not defined in any visible cell.
radii = [1e-6,1,2]
probe_results_np= radial_shell_mp(atoms_xyz.as_numpy_array(),radii=radii,n_probes=16,n_probes_target=8)
#probe_xyz,keep_mask = radial_shell_mp(atoms_xyz.as_numpy_array(),radii=radii,n_probes=16,n_probes_target=8)

In [61]:
from multiprocessing import Pool, cpu_count
def radial_shell_cctbx_mp(model,n_probes=64,n_probes_target=8,radii=None,rtol=0.9,num_processes=cpu_count(),debug=True):
    """Run ``radial_shell_worker_cctbx`` once per shell radius and collect the results.

    Args:
        model: object passed through to the worker; ``get_number_of_atoms()``
            is called on it when ``debug`` is false.
        n_probes, n_probes_target, rtol: forwarded unchanged to the worker.
        radii: iterable of shell radii, one worker call each (required).
        num_processes: number of worker processes; values <= 1 run serially in
            this process.
        debug: when true (default, the behaviour the notebook relies on) return
            the third element of every worker result as a list with one entry
            per radius. When false, return ``(out_probes, out_mask)`` assembled
            from the first two elements of every worker result.

    Returns:
        ``debug=True``: list of per-shell objects as produced by the worker.
        ``debug=False``: a flex.double grid of shape
        (n_shells, n_atoms, n_probes_target, 3) initialised to -1.0 and a
        flex.bool grid of shape (n_shells, n_atoms, n_probes_target).

    Raises:
        TypeError: if ``radii`` is None.

    NOTE(review): the ``debug=False`` assembly sat behind an unconditional
    return in the original and has not been exercised in this notebook;
    ``flex.uint32_range`` is used exactly as the original spelled it -- confirm
    it exists in the installed cctbx before relying on this path.
    """
    if radii is None:
        raise TypeError("radii must be an iterable of shell radii, got None")
    radii = list(radii)

    # One argument tuple per shell.
    args = [(i,model,n_probes,n_probes_target,radius_shell,rtol) for i,radius_shell in enumerate(radii)]

    if num_processes > 1:
        # Fan the shells out over a pool of worker processes.
        with Pool(num_processes) as p:
            results = p.map(radial_shell_worker_cctbx, args)
    else:
        results = [radial_shell_worker_cctbx(arg) for arg in args]

    probe_xyz_all = [result[0] for result in results]
    keep_mask_all = [result[1] for result in results]
    all_pts_all = [result[2] for result in results]
    if debug:
        return all_pts_all

    n_shells = len(radii)
    n_atoms = model.get_number_of_atoms()
    out_shape = (n_shells,n_atoms,n_probes_target,3)

    # Copy each shell's probes into its contiguous slot of one flat array.
    shell_size = math.prod(out_shape[1:])
    out_probes = flex.double(math.prod(out_shape),-1.0)
    for i,shell_probes in enumerate(probe_xyz_all):
        start = i*shell_size
        out_probes = out_probes.set_selected(flex.uint32_range(start,start+shell_size),shell_probes.as_1d())
    out_probes.reshape(flex.grid(*out_shape))

    # Same layout for the keep mask. The original re-created the mask after
    # this loop, which discarded everything that had just been copied in.
    mask_size = n_atoms*n_probes_target
    out_mask = flex.bool(n_shells*mask_size,False)
    for i,shell_mask in enumerate(keep_mask_all):
        start = i*mask_size
        out_mask = out_mask.set_selected(flex.uint32_range(start,start+mask_size),shell_mask.as_1d())
    out_mask.reshape(flex.grid(n_shells,n_atoms,n_probes_target))

    return out_probes, out_mask

In [62]:
# cctbx counterpart of probe_results_np, using the same radii and probe counts.
probe_results_cctbx= radial_shell_cctbx_mp(model,radii=radii,n_probes=16,n_probes_target=8)


IndexError: tuple index out of range

In [None]:
# Compare the numpy and cctbx probe results shell by shell, atom by atom.
for i in range(3):  # one entry per shell; three radii are used above
    # Convert each shell once instead of once per atom.
    shell_cctbx = probe_results_cctbx[i].as_numpy_array()
    for j in range(model.get_number_of_atoms()):
        a = probe_results_np[i][j]
        b = shell_cctbx[j]
        # The message identifies the failing (shell, atom) pair; the original
        # used print(i,j), which evaluates to None.
        assert np.all(np.isclose(a,b)), (i,j)

In [53]:
# numpy result for index [2][12], shown for comparison with the cctbx result below.
probe_results_np[2][12]

array([[-4.43800000e+00,  1.59000000e+00,  1.90500000e+00],
       [-5.08773129e+00,  4.64189514e-01,  2.38500000e+00],
       [-3.81407847e+00,  2.62927308e-01,  2.54500000e+00],
       [-3.00695678e+00,  2.52301410e+00,  2.86500000e+00],
       [-5.25574576e+00, -2.19334648e-01,  3.66500000e+00],
       [-3.88169632e+00, -3.29407779e-01,  3.82500000e+00],
       [-5.61218292e+00,  3.49155742e-01,  4.94500000e+00],
       [-4.23896468e+00,  2.42797347e-03,  5.10500000e+00]])

In [54]:
# cctbx result for the same entry; compare with the numpy output above.
probe_results_cctbx[2].as_numpy_array()[12]

array([[-4.438     ,  1.59      ,  1.905     ],
       [-5.08773129,  0.46418951,  2.385     ],
       [-3.81407847,  0.26292731,  2.545     ],
       [-3.00695678,  2.5230141 ,  2.865     ],
       [-5.25574576, -0.21933465,  3.665     ],
       [-3.88169632, -0.32940778,  3.825     ],
       [-6.23383033,  1.61436062,  4.785     ],
       [-5.61218292,  0.34915574,  4.945     ]])

In [37]:
# Index arrays (one per axis) of every element where the two results differ.
# NOTE(review): `probe_xyz` is not defined in any visible cell and
# `out_probes` is only a local of radial_shell_cctbx_mp; this also rebinds the
# scratch name `a` used by earlier cells.
a,b,c,d = np.where(~np.isclose(probe_xyz,out_probes.as_numpy_array()))

In [50]:
# Per-shell comparison of the numpy and cctbx probe coordinates.
# NOTE(review): `probe_xyz` is not defined in any visible cell and
# `probe_xyz_all` is only a local of radial_shell_cctbx_mp.
for i in range(3):
    a = probe_xyz[i]
    b = probe_xyz_all[i].as_numpy_array()
    assert np.all(np.isclose(a,b))

AssertionError: 

False

In [38]:
# numpy-side values at the mismatching positions found with np.where.
probe_xyz[a,b,c,d]

array([-5.61218292e+00,  3.49155742e-01,  4.94500000e+00, -4.23896468e+00,
        2.42797347e-03,  5.10500000e+00, -1.91869799e+00,  3.84089035e+00,
        5.30729412e+00,  1.07687204e-01,  2.70859406e+00,  6.24847059e+00,
       -3.28062730e-01,  4.25401302e+00,  6.48376471e+00, -1.82012496e+00,
        4.89006463e+00,  6.71905882e+00, -3.27832259e+00,  4.15030251e+00,
        6.95435294e+00, -4.36311930e-01,  4.05501170e+00,  8.13082353e+00,
       -2.11011507e+00,  4.39425206e+00,  8.36611765e+00, -1.46980203e+00,
        2.01291598e+00,  8.83670588e+00, -1.98073152e+00, -7.18973909e-01,
        6.10600000e+00, -3.34388081e-01, -7.85548900e-01,  6.37266667e+00,
        6.11426224e-01,  6.05332017e-01,  6.63933333e+00,  6.96810849e-03,
        2.20052937e+00,  6.90600000e+00, -3.04281973e+00, -1.22051276e-01,
        7.70600000e+00, -1.59903631e+00, -1.14186503e+00,  7.97266667e+00,
        4.94085670e-02, -4.62034454e-01,  8.23933333e+00, -1.47410918e+00,
        2.05948338e+00,  

In [39]:
# cctbx-side values at the same mismatching positions.
out_probes.as_numpy_array()[a,b,c,d]

array([-6.23383033,  1.61436062,  4.785     , -5.61218292,  0.34915574,
        4.945     , -1.702     ,  2.925     ,  5.072     , -1.93259453,
        3.89576347,  5.33866667,  0.18042622,  2.81533202,  6.40533333,
       -0.42403189,  4.41052937,  6.672     , -2.0819653 ,  4.88404277,
        6.93866667, -3.47460398,  3.84156825,  7.20533333, -0.22640721,
        3.54356763,  8.272     , -1.90510918,  4.26948338,  8.53866667,
       -1.271     ,  0.715     ,  5.306     , -1.97645861, -0.81554397,
        6.22907692, -0.21276174, -0.79775108,  6.53676923,  0.67422924,
        0.77036652,  6.84446154, -2.49218647, -0.66955323,  8.07523077,
       -0.5853128 , -0.82450308,  8.38292308, -1.70183662,  1.68992281,
        8.99830769, -1.271     ,  0.715     ,  9.306     ,  2.35211673,
        2.69285523, -1.97471429,  0.98066484,  2.90177417, -1.06042857,
        0.71761155,  1.77174567, -0.94614286,  4.46745093,  2.29545139,
       -0.37471429,  0.86993515,  3.18437026,  0.08242857,  4.00