In [37]:
import numpy as np
import math

In [38]:
def get_parts_matrix(max_dist, n_parts, seed):
        # return parts as binary mask on 7 x 7 grid to ease evaluation
        np.random.seed(seed)
        n_row = 7
        n_pix_per_cell = 50
        part_locs = [np.random.randint(350, size=2) for _ in range(n_parts)]
        parts = np.zeros((n_parts, n_row, n_row))
        if max_dist == 0:
            for part_loc, part_id in zip(part_locs, range(n_parts)):
                x_coord = int(part_loc[0] // n_pix_per_cell)
                y_coord = int(part_loc[1] // n_pix_per_cell)
                if part_id != -1:
                    parts[part_id, y_coord, x_coord] = 1
        else:
            for part_loc, part_id in zip(part_locs, range(n_parts)):
                if part_id == -1:
                    continue
                x_coord_part = part_loc[0]
                y_coord_part = part_loc[1]
                for cell_x, cell_y in np.ndindex(n_row, n_row):
                    x_coord_cell = cell_x * n_pix_per_cell
                    y_coord_cell = cell_y * n_pix_per_cell
                    dx = max(x_coord_cell - x_coord_part, 0, x_coord_part - x_coord_cell - n_pix_per_cell)
                    dy = max(y_coord_cell - y_coord_part, 0, y_coord_part - y_coord_cell - n_pix_per_cell)
                    dist = math.sqrt(dx * dx + dy * dy) / n_pix_per_cell
                    print(f"{cell_x}, {cell_y}, {dist}, {1 - dist / max_dist}, {max(parts[part_id, cell_y, cell_x], 1 - dist / max_dist)}")
                    parts[part_id, cell_y, cell_x] = max(parts[part_id, cell_y, cell_x], 1 - dist / max_dist)
        return parts

In [39]:
get_parts_matrix(0, 1, 0)

array([[[0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]]])

In [40]:
get_parts_matrix(3.0, 1, 0)

0, 0, 2.44, 0.18666666666666665, 0.18666666666666665
0, 1, 2.4407375934335915, 0.18642080218880286, 0.18642080218880286
0, 2, 2.6603007348794234, 0.11323308837352553, 0.11323308837352553
0, 3, 3.193305497443049, -0.0644351658143496, 0.0
0, 4, 3.9137194585202453, -0.3045731528400817, 0.0
0, 5, 4.736792163479415, -0.5789307211598052, 0.0
0, 6, 5.617579549948537, -0.8725265166495122, 0.0
1, 0, 1.44, 0.52, 0.52
1, 1, 1.4412494579357176, 0.5195835140214276, 0.5195835140214276
1, 2, 1.7880715869338117, 0.4039761376887294, 0.4039761376887294
1, 3, 2.513404066201851, 0.16219864459938294, 0.16219864459938294
1, 4, 3.3818929610500685, -0.1272976536833561, 0.0
1, 5, 4.307806866608576, -0.43593562220285875, 0.0
1, 6, 5.260912468384168, -0.7536374894613893, 0.0
2, 0, 0.44, 0.8533333333333333, 0.8533333333333333
2, 1, 0.4440720662234904, 0.8519759779255032, 0.8519759779255032
2, 2, 1.1476933388322859, 0.6174355537225713, 0.6174355537225713
2, 3, 2.1064662351910606, 0.29784458826964644, 0.29784458826

array([[[0.18666667, 0.52      , 0.85333333, 1.        , 0.81333333,
         0.48      , 0.14666667],
        [0.1864208 , 0.51958351, 0.85197598, 0.98      , 0.81226496,
         0.47961553, 0.14643232],
        [0.11323309, 0.40397614, 0.61743555, 0.64666667, 0.60038908,
         0.37131531, 0.07640798],
        [0.        , 0.16219864, 0.29784459, 0.31333333, 0.28841335,
         0.13865738, 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ,
         0.        , 0.        ]]])