In [1]:
import numpy as np
import scipy.spatial

class particle_generator:
    def __init__(self, config) -> None:
        self.config = config

    def _generate_input(self):
        return np.random.randn(self.config['data_num'], self.config['input_length'], self.config['input_dim'])

    def _build_distance_matrix(self, input):
        # for a m*n input, return m*m pairwise distance matrix
        res = []
        for x in input:
            densed = scipy.spatial.distance.pdist(x)
            res.append(scipy.spatial.distance.squareform(densed))
        return res
        
    def _parallel_build_distance_matrix(self, input, cpus=8):
        assert len(input) 
        import multiprocessing
        try:
            multiprocessing.set_start_method('fork')
        except Exception:
            pass

        n = self.config['data_num']

        with multiprocessing.Pool(cpus) as pool:
            res = pool.map(self._build_distance_matrix,[n//cpus]*cpus, chunksize=1)

        res = sum(res, start=[]) # Combine the results.
        return res

In [2]:
config = {'data_num': 100000, 'input_length': 128, 'input_dim':2, 'r0':0.5}

generator = particle_generator(config)

In [3]:
x = generator._generate_input()

In [4]:
y = generator._build_distance_matrix(x)

In [5]:
generator._parallel_build_distance_matrix(x)

TypeError: 'int' object is not iterable

In [None]:
8 //4