In [1]:
import numpy as np
from typing import Tuple, Generator, Iterator

# for plotting voxels
import matplotlib.pyplot as plt
from functools import reduce
from mpl_toolkits.mplot3d import Axes3D # registers the 3D projection, unused

In [2]:
class ColumnBunches3():
    """Iterate over triples of column bunches of an array.

    The columns (axis 1) of ``data`` are split into ``M`` consecutive bunches
    of ``window`` columns each; the last bunch may be narrower.  ``M`` is
    ``ceil(N / window)`` for ``N`` columns (0 when there are no columns).
    A triple of bunch indices ``(I, J, K)`` with ``I <= J <= K`` selects three
    bunches; it is *diagonal* when at least two indices coincide, i.e. the
    same bunch appears more than once.

    * Iterating the object (``for bunch in obj``) yields every triple
      ``I <= J <= K``, or only the strictly increasing ``I < J < K`` ones when
      ``skip_diag=True``.
    * ``next_diag()`` returns the next diagonal triple, independently of
      ``skip_diag``.

    Each yielded bunch is a list of three column slices of ``data``.  Both
    sequences are one-shot: the generators are created once in ``__init__``
    and ``__iter__`` returns ``self``, so a new instance is needed to iterate
    again.

    Parameters
    ----------
    data : np.ndarray
        Array (at least 2-D) whose axis-1 columns are bunched.
    window : int
        Number of columns per bunch; must be >= 1.
    skip_diag : bool, default False
        If True, the main iterator skips diagonal triples.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1.
    """

    def __init__(self, data: np.ndarray, window: int, skip_diag=False):
        # a non-positive window would divide by zero or produce bogus bunches
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window!r}")

        self.data = data
        N = data.shape[1]
        self.w = window
        # 0 -> inner loops start at the outer index (I <= J <= K);
        # 1 -> they start one past it (I < J < K, no repeated bunch)
        self.sd = int(skip_diag)

        # number of bunches, ceil(N / w); evaluates to 0 when N == 0
        self.M = (N - 1) // self.w + 1

        # I, J, K indices identify bunches of w columns;
        # they lie in range(M), i.e. 0..M-1
        self.IJK = self._IJK_gen()
        self.IJK_diag = self._IJK_diag_gen()

    def _IJK_gen(self) -> Generator:
        """Yield (I, J, K) in lexicographic order with I <= J <= K (I < J < K if skip_diag)."""
        for I in range(self.M):
            for J in range(I + self.sd, self.M):
                for K in range(J + self.sd, self.M):
                    yield I, J, K

    def _IJK_diag_gen(self) -> Generator:
        """Yield the diagonal triples I <= J <= K (some index repeated) in lexicographic order."""
        for I in range(self.M):
            for J in range(I, self.M):
                for K in range(J, self.M):
                    # the indices are sorted, so a repeat means I == J or J == K
                    if I == J or J == K:
                        yield I, J, K

    def _get_columns(self, indeces) -> list:
        """Return the column slices (views of ``data``) for the given bunch indices."""
        return [self.data[:, self.w * idx : self.w * (idx + 1)] for idx in indeces]

    def __next__(self) -> list:
        """Return the next bunch triple; raises StopIteration when exhausted."""
        return self._get_columns(next(self.IJK))

    def __iter__(self) -> Iterator:
        return self

    def next_diag(self) -> list:
        """Return the next diagonal bunch triple; raises StopIteration when exhausted."""
        return self._get_columns(next(self.IJK_diag))

    def show_voxels(self) -> "Axes3D":
        """Draw the bunch-index triples as a 3-D voxel plot and return its axes.

        Every triple of the main iterator is drawn as a voxel.  Non-diagonal
        triples are yellow and diagonal ones blue, so with ``skip_diag=True``
        only yellow voxels appear.  Array/plot axes map as x <- J, y <- I,
        z <- K.  Fresh generators are used, so the instance's own iterators
        are not consumed.
        """
        shape = (self.M,) * 3
        colors = np.empty(shape, dtype=object)

        # voxel position [x, y, z] = [J, I, K], matching the axis labels below
        voxels = np.zeros(shape, dtype=bool)
        for I, J, K in self._IJK_gen():
            voxels[J, I, K] = True
        colors[voxels] = '#FFD65DF0'

        voxels_diag = np.zeros(shape, dtype=bool)
        for I, J, K in self._IJK_diag_gen():
            voxels_diag[J, I, K] = True
        # only voxels also present in `voxels` are drawn; the rest stay hidden
        colors[voxels_diag] = '#7A88CCF0'

        fig = plt.figure()
        # Figure.gca() no longer accepts keyword arguments (Matplotlib >= 3.6)
        ax = fig.add_subplot(projection='3d')
        ax.grid(False)
        ax.set_yticks(range(self.M))
        ax.set_xticks(range(self.M))
        ax.set_zticks(range(self.M))
        ax.set_ylabel('I')
        ax.set_xlabel('J')
        ax.set_zlabel('K')
        ax.invert_yaxis()
        ax.invert_zaxis()
        ax.voxels(voxels, facecolors=colors, alpha=0.5, edgecolor='k')
        return ax

In [3]:
# toy input: two identical rows holding the column indices 0..9, so every
# printed bunch shows which columns it contains
M1 = np.tile(np.arange(10), (2, 1))
M1

array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [4]:
bunches = ColumnBunches3(M1, window=3)  # skip_diag defaults to False: diagonal triples included

In [5]:
# illustration of the distinction:
# * blue: diagonal voxels (repeated columns in the returned bunches)
# * yellow: non-diagonal voxels (no repeated bunches)

bunches = ColumnBunches3(M1, window=3, skip_diag=False)

# %matplotlib inline
%matplotlib qt
bunches.show_voxels()
plt.show()

In [6]:
# get all column bunches (diagonal and non-diagonal):
# with the default skip_diag=False every triple I <= J <= K is visited

bunches = ColumnBunches3(M1, window=3)

for triple in bunches:
    print(*triple, sep="\n")
    print(">-------")

[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
[[9]
 [9]]
>-------
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
>-------
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
>-------
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
[[9]
 [9]]
>-------
[[6 7 8]
 [6 7 8

In [14]:
# get 'non-diagonal' (yellow) bunches only: skip_diag=True restricts the
# main iterator to strictly increasing bunch indices I < J < K

bunches = ColumnBunches3(M1, window=3, skip_diag=True)

for triple in bunches:
    print("\n".join(str(block) for block in triple))
    print(">-------")

[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------


In [15]:
# get 'diagonal' (blue) bunches only
# (next_diag() walks the diagonal triples regardless of skip_diag)

bunches = ColumnBunches3(M1, window=3, skip_diag=True)

while True:
    # only exhaustion of the diagonal generator ends the loop; any other
    # error (e.g. while printing) must surface instead of being swallowed
    try:
        bunch = bunches.next_diag()
    except StopIteration:
        break
    for columns in bunch:
        print(columns)
    print(">-------")

[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
[[9]
 [9]]
>-------
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
>-------
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
>-------
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
[[9]
 [9]]
>-------
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
>-------
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
[[9]
 [9]]
>-------
[[9]
 [9]]
[[9]
 [9]]
[[9]
 [9]]
>-------
