In [1]:
import numpy as np
from layout import Layout
from pprint import pprint
from collections import defaultdict

In [2]:
class Run:
    def __init__(self, src_rank, src_i, src_j, src_start_k, dst_rank, dst_i, dst_j, dst_start_k):
        self.src_rank = src_rank
        self.dst_rank = dst_rank
        self.src_loc = (src_i, src_j, src_start_k)
        self.dst_loc = (dst_i, dst_j, dst_start_k)
        self.len = 1 
    def test_and_extend(self, src_rank, src_i, src_j, src_k, dst_rank, dst_i, dst_j, dst_k):
        if (self.src_rank == src_rank and self.dst_rank == dst_rank
            and self.src_loc == (src_i, src_j, src_k - self.len)
            and self.dst_loc == (dst_i, dst_j, dst_k - self.len)):
            self.len += 1
            return True
        else:
            return False
    def __str__(self):
        return f'rank {self.src_rank} {self.src_loc} -> rank {self.dst_rank} {self.dst_loc}, length {self.len}'
    

In [3]:
def build_run_table(src_rank, src_layout, dst_layout):
    run_d = defaultdict(list)
    for i in range(src_layout.shape[0]):
        for j in range(src_layout.shape[1]):
            current_run = None
            for k in range(src_layout.shape[2]):
                gbl_i, gbl_j, gbl_k = src_layout.lcl_to_gbl(src_rank, (i, j, k))
                dst_rank, (dst_i, dst_j, dst_k) = dst_layout.gbl_to_lcl((gbl_i, gbl_j, gbl_k))
                if not current_run or not current_run.test_and_extend(src_rank, i, j, k, dst_rank, dst_i, dst_j, dst_k):
                    if current_run:
                        run_d[(i,j)].append(current_run)
                    current_run = Run(src_rank, i, j, k, dst_rank, dst_i, dst_j, dst_k)
            if current_run:
                run_d[(i, j)].append(current_run)
    return {key: value for key, value in run_d.items()}  # turn it into a regular dict

In [4]:
block_layout = Layout(4*4*4, (128, 128, 128), (512, 512, 512))
xy_slab_layout = Layout(4*4*4, (512, 512, 8), (512, 512, 512))
z_pencil_layout = Layout(4*4*4, (64, 64, 512), (512, 512, 512))
yz_slab_layout = Layout(4*4*4, (8, 512, 512), (512, 512, 512))
x_pencil_layout = Layout(4*4*4, (512, 64, 64), (512, 512, 512))

In [5]:
run_d = build_run_table(0, yz_slab_layout, x_pencil_layout)

In [9]:
run_x_indices = sorted(list(set([i for i,j in run_d])))
i_mid = run_x_indices[len(run_x_indices)//2]
run_y_indices = sorted(list(set([j for i,j in run_d])))
j_mid = run_y_indices[len(run_y_indices)//2]


for run in run_d[(i_mid, j_mid)]:
    print(str(run))

tot_runs = sum([len(run_d[key]) for key in run_d])
print(f'Total runs in the table: {tot_runs}')

rank 0 (4, 256, 0) -> rank 32 (4, 0, 0), length 64
rank 0 (4, 256, 64) -> rank 33 (4, 0, 0), length 64
rank 0 (4, 256, 128) -> rank 34 (4, 0, 0), length 64
rank 0 (4, 256, 192) -> rank 35 (4, 0, 0), length 64
rank 0 (4, 256, 256) -> rank 36 (4, 0, 0), length 64
rank 0 (4, 256, 320) -> rank 37 (4, 0, 0), length 64
rank 0 (4, 256, 384) -> rank 38 (4, 0, 0), length 64
rank 0 (4, 256, 448) -> rank 39 (4, 0, 0), length 64
Total runs in the table: 32768


In [89]:
len([1,2,3])

TypeError: 'int' object is not callable

* block_layout -> yz_slab_layout yields 16384 runs
* yz_slab_layout to x_pencil_layout yields 32768 runs
* x_pencil_layout -> block_layout yields 32768 runs

* block_layout -> xy_slab_layout yields 262144 runs
* xy_slab_layout -> z_pencil_layout yields 262144 runs
* z_pencil_layout -> block_layout yields 16384 runs

In [90]:
src_layout = z_pencil_layout
dst_layout = block_layout
print(src_layout.shape)
test = np.zeros(src_layout.shape)
dst_rank_d = defaultdict(list)
for i, j in run_d:
    #test[i,j,0] = 10000 * i + j
    offset = (((i * src_layout.shape[1]) + j) * src_layout.shape[2])
    #print (i,j,offset)
    #print(test.flatten()[offset])
    for run in run_d[(i, j)]:
        src_i, src_j, src_start_k = run.src_loc
        assert src_i == i and src_j == j, 'src i or j mismatch?'
        run_abs_start = offset + src_start_k
        run_len = run.len
        #print('*',run_abs_start, run_len, run.dst_rank)
        dst_rank_d[run.dst_rank].append((run_abs_start, run_len))
for key in dst_rank_d:
    print(f'To rank {key}:')
    for offset, run_len in dst_rank_d[key]:
        print(f'    {run_len} at {offset}')


(64, 64, 512)
To rank 0:
    64 at 0
    64 at 512
    64 at 1024
    64 at 1536
    64 at 2048
    64 at 2560
    64 at 3072
    64 at 3584
    64 at 4096
    64 at 4608
    64 at 5120
    64 at 5632
    64 at 6144
    64 at 6656
    64 at 7168
    64 at 7680
    64 at 8192
    64 at 8704
    64 at 9216
    64 at 9728
    64 at 10240
    64 at 10752
    64 at 11264
    64 at 11776
    64 at 12288
    64 at 12800
    64 at 13312
    64 at 13824
    64 at 14336
    64 at 14848
    64 at 15360
    64 at 15872
    64 at 16384
    64 at 16896
    64 at 17408
    64 at 17920
    64 at 18432
    64 at 18944
    64 at 19456
    64 at 19968
    64 at 20480
    64 at 20992
    64 at 21504
    64 at 22016
    64 at 22528
    64 at 23040
    64 at 23552
    64 at 24064
    64 at 24576
    64 at 25088
    64 at 25600
    64 at 26112
    64 at 26624
    64 at 27136
    64 at 27648
    64 at 28160
    64 at 28672
    64 at 29184
    64 at 29696
    64 at 30208
    64 at 30720
    64 at 31232
    64 