In [3]:
import numpy as np
from itertools import combinations, combinations_with_replacement

from typing import Tuple, Generator, Iterator

In [4]:
class ColumnBunches():
    """Iterate over k-tuples ("bunches") of fixed-width column windows of an array.

    The columns of ``data`` (axis 1) are split into ``M = ceil(N / window)``
    consecutive windows, where ``N = data.shape[1]``; the last window is
    narrower when ``window`` does not divide ``N``. Iterating the object yields
    one tuple of ``k`` window slices per sorted index combination
    ``i_1 <= i_2 <= ... <= i_k`` (strictly increasing when ``skip_diag``).

    Parameters
    ----------
    data : np.ndarray
        Array whose axis 1 is split into windows.
    k : int
        Number of windows per bunch (the length of each yielded tuple).
    window : int
        Width of each window in columns; must be >= 1.
    skip_diag : bool, default False
        If True, iteration yields only bunches of ``k`` *distinct* windows;
        otherwise windows may repeat within a bunch.

    Raises
    ------
    ValueError
        If ``window < 1``.

    Notes
    -----
    The object is a one-shot iterator: ``__iter__`` returns ``self`` and the
    index iterators are created once in ``__init__``. The window arrays are
    basic-slicing views of ``data``, not copies.
    """

    def __init__(self, data: np.ndarray, k: int, window: int, skip_diag=False):
        # Without this, window=0 raises ZeroDivisionError below and a negative
        # window produces a negative M that silently yields no bunches.
        if window < 1:
            raise ValueError(f"window must be a positive integer, got {window!r}")

        self.data = data
        self.k = k  # number of windows per bunch
        self.w = window
        self.skip_diag = skip_diag

        N = data.shape[1]
        # Integer ceil(N / w): a trailing partial window counts as a window.
        self.M = (N - 1) // self.w + 1

        self.IJK = self._IJK_gen(self.skip_diag)
        self.IJK_diag = self._IJK_diag_gen()

    def _IJK_gen(self, skip_diag=False) -> Iterator[Tuple[int, ...]]:
        """Return a fresh iterator over sorted k-tuples of indices in ``range(M)``.

        With ``skip_diag`` the indices are strictly increasing
        (``itertools.combinations``); otherwise repeats are allowed
        (``itertools.combinations_with_replacement``).
        """
        indices = range(self.M)
        if skip_diag:
            return combinations(indices, self.k)
        return combinations_with_replacement(indices, self.k)

    def _IJK_diag_gen(self) -> Generator[Tuple[int, ...], None, None]:
        """Yield only the "diagonal" index tuples: those with a repeated index.

        These are exactly the tuples ``combinations_with_replacement`` produces
        that ``combinations`` does not. This is independent of ``skip_diag``.
        """
        for IJK in self._IJK_gen(skip_diag=False):
            # Fewer distinct indices than k means some window appears twice.
            if len(set(IJK)) < self.k:
                yield IJK

    def _get_columns(self, indeces) -> Tuple[np.ndarray, ...]:
        """Return the window slices of ``data`` for each window index in ``indeces``.

        A tuple (rather than a generator) is returned so a bunch can be
        indexed, passed to ``len()``, and iterated more than once.
        """
        return tuple(self.data[:, self.w * idx : self.w * (idx + 1)] for idx in indeces)

    def __next__(self) -> Tuple[np.ndarray, ...]:
        """Return the next bunch; raises StopIteration when exhausted."""
        return self._get_columns(next(self.IJK))

    def __iter__(self) -> Iterator:
        """Return ``self`` (the object is its own, single-pass iterator)."""
        return self

    def next_diag(self) -> Tuple[np.ndarray, ...]:
        """Return the next bunch containing a repeated window.

        Raises StopIteration once all diagonal bunches have been returned.
        """
        return self._get_columns(next(self.IJK_diag))

---

In [5]:
# 502 identical rows, each holding the column indices 0..1999.
some_data = np.tile(np.arange(2_000), (502, 1))

In [13]:
%%timeit
bunches = ColumnBunches(some_data, k=4, window=2, skip_diag=True)
some_bunch = next(bunches)
some_other_bunch = bunches.next_diag()

59.2 µs ± 1.94 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [4]:
%%timeit
bunches = ColumnBunches(some_data, k=10, window=2, skip_diag=True)
some_bunch = next(bunches)
some_other_bunch = bunches.next_diag()

65.9 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [5]:
%%timeit
bunches = ColumnBunches(some_data, k=20, window=2, skip_diag=True)
some_bunch = next(bunches)
some_other_bunch = bunches.next_diag()

78.5 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [6]:
%%timeit
bunches = ColumnBunches(some_data, k=100, window=2, skip_diag=True)
some_bunch = next(bunches)
some_other_bunch = bunches.next_diag()

176 µs ± 781 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]:
%%timeit
bunches = ColumnBunches(some_data, k=200, window=2, skip_diag=True)
some_bunch = next(bunches)
some_other_bunch = bunches.next_diag()

300 µs ± 5.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


---

In [9]:
k = 4  # windows per bunch used by the demo cells below
# Two identical rows of column indices 0..9: small enough to print in full.
M1 = np.tile(np.arange(10), (2, 1))
M1

array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [10]:
bunches = ColumnBunches(M1, window=3, k=k)

# Every bunch, repeated windows included; each window array on its own line.
for window_group in bunches:
    for window_cols in window_group:
        print(window_cols)
    print(">-------")

[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[

In [11]:
# Only the 'non-diagonal' bunches: skip_diag=True keeps bunches whose windows
# are all distinct.
bunches = ColumnBunches(M1, window=3, k=k, skip_diag=True)

for window_group in bunches:
    for window_cols in window_group:
        print(window_cols)
    print(">-------")

[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------


In [12]:
# Only the 'diagonal' bunches: those in which at least one window repeats.
# next_diag() always draws from the with-replacement index stream, so the
# skip_diag flag here does not affect what is printed.
bunches = ColumnBunches(M1, k=k, window=3, skip_diag=True)

while True:
    try:
        bunch = bunches.next_diag()
    except StopIteration:
        # Exhaustion is signalled by StopIteration; any other error is a real
        # failure and must not be silently swallowed.
        break
    for columns in bunch:
        print(columns)
    print(">-------")

[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[6 7 8]
 [6 7 8]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[0 1 2]
 [0 1 2]]
[[9]
 [9]]
[[9]
 [9]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
>-------
[[0 1 2]
 [0 1 2]]
[[3 4 5]
 [3 4 5]]
[[3 4 5]
 [3 4 5]]
[[6 7 8]
 [6 7 8]]
>-------
[[0 1 2]
 [0 1 2]]
[