In [586]:
from __future__ import annotations

from collections import defaultdict
from typing import Any, Hashable

import numpy as np

np.random.seed(42)


class Chunk:
  def __init__(self, subsequence=None, start=-1, end=-1, seq_indices=[]):
    self.subsequence = subsequence
    self.start = start
    self.end = end
    self.seq_indices: list[int] = seq_indices

  def __eq__(self, other):
    if not isinstance(other, Chunk):
      return False
    return (
      self.start == other.start
      and self.end == other.end
      and np.array_equal(self.subsequence, other.subsequence)
    )

  def __hash__(self):
    return hash((self.start, self.end, tuple(self.subsequence)))

  def add_index(self, index: int):
    self.seq_indices.append(index)

  def __repr__(self) -> str:
    return f"Chunk({self.subsequence}, {self.start}, {self.end}, #{len(self.seq_indices)})"


class DB:
  chunks: dict[Hashable, Chunk]
  start_map: dict[int, list[Chunk]]
  end_map: dict[int, list[Chunk]]
  continue_map: dict[tuple[str, str], list[Chunk]]
  seq_map: dict[tuple, list[Chunk]]

  def __init__(self) -> None:
    self.chunks = defaultdict(Chunk)
    self.start_map = defaultdict(list)
    self.end_map = defaultdict(list)
    self.start_end_map = defaultdict(list)
    self.continue_map = defaultdict(list)
    self.seq_map = defaultdict(list)
    self.length_map = defaultdict(list)

  def __getitem__(self, key: Hashable) -> Chunk | None:
    return self.chunks.get(key, None)

  def add(self, chunk: Chunk):
    if chunk in self.chunks:
      self.chunks[chunk].seq_indices.extend(chunk.seq_indices)
    else:
      self.chunks[chunk] = chunk
      self.start_map[chunk.start].append(chunk)
      self.end_map[chunk.end].append(chunk)
      self.start_end_map[(chunk.start, chunk.end)].append(chunk)
      self.seq_map[tuple(chunk.subsequence)].append(chunk)
      self.length_map[len(chunk.subsequence)].append(chunk)
      for i in range(len(chunk.subsequence)):
        if i > 1:
          self.continue_map[
            (f"{chunk.start}*", tuple(chunk.subsequence[:i]))
          ].append(chunk)
        if i < len(chunk.subsequence) - 2:
          self.continue_map[
            (f"*{chunk.end}", tuple(chunk.subsequence[i + 1 :]))
          ].append(chunk)

  def get_candidate(self, chunk: Chunk) -> list[Chunk]:
    candidates = {
      *self.get(start=chunk.start, sub_sequences=chunk.subsequence),
      *self.get(end=chunk.end, sub_sequences=chunk.subsequence),
    }
    for i in range(1, len(chunk.subsequence) - 1):
      left = self.get(
        start=chunk.start + i,
        end=chunk.end,
        sub_sequences=chunk.subsequence[i:],
      )
      left_continue = self.get(
        start=chunk.start + i, sub_sequences=chunk.subsequence[i:]
      )
      right = self.get(
        start=chunk.start,
        end=chunk.end - i,
        sub_sequences=chunk.subsequence[:-i],
      )
      right_continue = self.get(
        end=chunk.end - i, sub_sequences=chunk.subsequence[:-i]
      )
      candidates.update(set(left + left_continue + right + right_continue))

      if len(candidates) > 1:
        break

    return sorted(
      list(candidates),
      key=lambda x: len(x.subsequence) * 100000 + len(x.seq_indices),
      reverse=True,
    )

  def get(
    self, start: int = -1, end: int = -1, sub_sequences=None
  ) -> list[Chunk]:
    if sub_sequences is None:
      if start == -1 and end == -1:
        raise ValueError("Either start, end or sub_sequences must be provided")
      elif start == -1:
        return self.end_map[end]
      elif end == -1:
        return self.start_map[start]
      else:
        return self.start_end_map[(start, end)]

    else:
      if start == -1 and end == -1:
        return self.seq_map[tuple(sub_sequences)]
      elif start == -1:
        return self.continue_map[(f"*{end}", tuple(sub_sequences))]
      elif end == -1:
        return self.continue_map[(f"{start}*", tuple(sub_sequences))]
      else:
        c = self.chunks.get(
          Chunk(
            start=start, end=end, subsequence=sub_sequences, seq_indices=[]
          ),
          None,
        )
        return [c] if c else []

  def get_random(self):
    return np.random.choice(list(self.chunks.values()))

In [649]:
def find_chunks(
  data: np.ndarray,
  threshold: int = 10,
  max_chunk_length: int = 6,
):
  n_sequences, n_length = data.shape
  db = DB()

  for chunk_length in range(max_chunk_length, 0, -1):
    # Create a view of all possible chunks of the current length
    chunk_view = np.lib.stride_tricks.sliding_window_view(
      data, (1, chunk_length)
    ).reshape(n_sequences, n_length - chunk_length + 1, chunk_length)

    # Hash each chunk
    chunk_hashes = np.apply_along_axis(lambda x: hash(tuple(x)), 2, chunk_view)

    # Find unique chunks and their counts
    unique_chunks, indices, counts = np.unique(
      chunk_hashes, return_inverse=True, return_counts=True, axis=None
    )

    # Process only chunks that meet the threshold
    mask = counts >= threshold
    for chunk_hash in unique_chunks[mask]:
      # Get the indices of sequences containing this chunk
      seq_indices = np.where(chunk_hashes == chunk_hash)[0]

      # Get the start position of the chunk
      start = np.where(chunk_hashes == chunk_hash)[1][0]
      end = start + chunk_length

      # Get the actual chunk sequence
      chunk_seq = chunk_view[seq_indices[0], start]

      db.add(
        Chunk(
          subsequence=chunk_seq,
          start=int(start),
          end=int(end),
          seq_indices=seq_indices,
        )
      )

  return db


# Example usage
threshold = 100
max_chunk_length = 8
data = np.random.randint(1, 5, (10_000, 32))
db = find_chunks(data, threshold, max_chunk_length)

In [650]:
db.length_map.keys()

dict_keys([6, 5, 4, 3, 2, 1])

In [651]:
nodes = sorted(
  list(db.chunks.values()),
  key=lambda x: len(x.subsequence) * 100000 + len(x.seq_indices),
  reverse=True,
)
len(nodes)

1365

In [652]:
from collections import deque

processed = set()

paths = []

for node in nodes:
  deq = deque([node])
  processed.add(node)

  left_candi = [c for c in db.get_candidate(deq[0]) if c not in processed]
  right_candi = [c for c in db.get_candidate(deq[-1]) if c not in processed]

  while len(left_candi) > 0 or len(right_candi) > 0:
    if len(left_candi) > 0:
      deq.appendleft(left_candi[0])
      processed.add(left_candi[0])
    if len(right_candi) > 0:
      deq.append(right_candi[0])
      processed.add(right_candi[0])

    left_candi = [c for c in db.get_candidate(deq[0]) if c not in processed]
    right_candi = [c for c in db.get_candidate(deq[-1]) if c not in processed]
  paths.append(deq)

paths = sorted(
  [p for p in paths],
  key=lambda x: sum([len(c.seq_indices) for c in x]),
  reverse=True,
)
len(paths)

1365

In [671]:
canvas = np.zeros((data.shape[0] * 10000, data.shape[1]))
canvas.shape

(100000000, 32)

In [757]:
import numpy as np
from tqdm import tqdm


def get_shape(path: list[Chunk]) -> np.ndarray:
  total_length = sum(len(chunk.seq_indices) for chunk in path)
  shape = np.zeros((total_length, canvas.shape[1]))
  current_index = 0
  for chunk in path:
    chunk_length = len(chunk.seq_indices)
    shape[
      current_index : current_index + chunk_length, chunk.start : chunk.end
    ] = chunk.subsequence
    current_index += chunk_length
  return shape


def find_drawable_position(canvas: np.ndarray, shape: np.ndarray) -> int:
  if canvas.shape[1] != shape.shape[1]:
    raise ValueError("Canvas and shape must have the same width")

  canvas_height, shape_height = canvas.shape[0], shape.shape[0]
  max_start = canvas_height - shape_height + 1

  for i in range(0, max_start, max(threshold, int(shape_height / 10))):
    sub_canvas = canvas[i : i + shape_height]
    target = shape != 0
    if np.all(sub_canvas[target] == 0):
      return i

  return -1  # No drawable position found


canvas = np.zeros((data.shape[0] * 10000, data.shape[1]))
positions = []
for path in tqdm(paths, desc="Processing paths"):
  shape = get_shape(path)
  position = find_drawable_position(canvas, shape)
  positions.append(position)
  if position != -1:
    canvas[position : position + shape.shape[0]][shape != 0] = shape[shape != 0]

Processing paths: 100%|██████████| 1365/1365 [01:27<00:00, 15.53it/s]


In [758]:
canvas[0]

array([4., 4., 2., 2., 2., 4., 0., 0., 0., 3., 4., 4., 3., 2., 2., 1., 2.,
       3., 4., 4., 1., 1., 2., 2., 2., 1., 0., 0., 2., 3., 2., 1.])

In [759]:
filtered_canvas = canvas[~np.all(canvas == 0, axis=1)]
print(canvas.shape)
print(filtered_canvas.shape)

(100000000, 32)
(724164, 32)


In [760]:
np.sum(filtered_canvas != 0) / np.prod(filtered_canvas.shape)

np.float64(0.5280062596317961)

In [761]:
def print_row(row: np.ndarray):
  print("".join([f"{int(x)}" if x != 0 else " " for x in row]))

In [762]:
prev_row = np.zeros(canvas.shape[1])
for i in range(0, canvas.shape[0]):
  row = canvas[i]
  if not np.array_equal(row, prev_row) and not np.sum(row) == 0:
    print_row(row)
    prev_row = row

442224   34432212344112221  2321
442224   344322123441       2321
442224   344322123441 3341  2321
442224   34432212344133341  2321
442224  2344 2212344133341  2321
442224  2344     344133341  2321
442224  2344     344133341      
442224  2344     34413334131223 
442224  2344    33441 334131223 
442224  2344    33441 3341      
442224  2344    33441 33414243  
442224  2344    33441433414243  
442224  2344   23344 433414243  
442224  2344   23344  33414243  
442224  2344  22334   33414243  
442224  2344 32233    33414243  
442224 4234  32233    33414243  
442224 4234 13223     33414243  
442224 4234 13223     3341 2433 
442224 4234   22342   3341 2433 
442224 4234     34241 3341 2433 
442224 4234     34241      2433 
442224 42341422134241 231112433 
442224 423414221 42413231112433 
442224 4234      42413     2433 
442224 42344     42413     2433 
442224 42344 124142413     2433 
442224 42344 1241 24134    2433 
442224 4234  1241 24134    2433 
442224 4234  1241 24134     4332
442224 423

KeyboardInterrupt: 

In [681]:
sum([len(c.seq_indices) for c in paths[0]])

82886

In [616]:
canvas

array([[0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [561]:
d = deque([1, 2, 3])
print(d[0], d[-1])

1 3


In [533]:
while len(candi):
  org_candi = candi.copy()
  new_candi = [db.get_candidate(c) for c in candi]
  edges = [(org_candi[i], c[0]) for i, c in enumerate(new_candi) if len(c)]
  print(edges)
  break

[(Chunk([0 2 1 2], 7, 11, #123), Chunk([2 1 2 2], 8, 12, #140)), (Chunk([1 2 2 2], 9, 13, #107), Chunk([2 1 2 2], 8, 12, #140)), (Chunk([2 1 2], 8, 11, #504), Chunk([2 1 2 2], 8, 12, #140)), (Chunk([1 2 2], 9, 12, #489), Chunk([2 1 2 2], 8, 12, #140))]


In [413]:
c = db.get_random()
print(c)
print(db.get(start=c.start, sub_sequences=c.subsequence))
print(db.get(end=c.end, sub_sequences=c.subsequence))
print(db.get(start=c.start, end=c.end - 1, sub_sequences=c.subsequence[:-1]))
print(db.get(start=c.start + 1, end=c.end, sub_sequences=c.subsequence[1:]))

Chunk([2 3 3 0 1 0], 10, 16, #60)
[]
[]
[]
[]


In [103]:
db.get(1, 4)

[Chunk([3 3 1], 1, 4, #2345), Chunk([0 2 1], 1, 4, #2376)]

In [108]:
print((db.get(1, 2)[0]))
print(db.get(1, 2, [2]))

Chunk([0], 1, 2, #40120)
[]


In [27]:
print(db.get(1, 4))
print(db.get(0, 5, [1, 3, 3, 1, 3]))

[Chunk([3, 0, 2], 1, 4, #2298), Chunk([3, 1, 0], 1, 4, #2290), Chunk([1, 1, 0], 1, 4, #2351), Chunk([1, 3, 0], 1, 4, #2374), Chunk([1, 3, 2], 1, 4, #2331)]
[]


In [28]:
db.get(14, 18, (2, 1, 1, 0))

[]

In [29]:
hash((1, 2, (1, 2)))

-1429464707349485113

In [30]:
hash((1, 2, (1, 2)))

-1429464707349485113