In [None]:
import gc
import pickle as pkl
import numpy as np
from numba import njit

from scipy.ndimage import label

from collections import defaultdict

---
---
---

In [None]:
# The 26 neighbour displacements of a voxel in a 3-D grid: every (dx, dy, dz)
# with components in {-1, 0, 1} except the centre (0, 0, 0), listed in
# lexicographic order from (-1, -1, -1) to (1, 1, 1).
offsets = np.array([(a - 1, b - 1, c - 1)
                    for a, b, c in np.ndindex(3, 3, 3)
                    if (a, b, c) != (1, 1, 1)])

In [None]:
def shift_cube(data_cube, shift):
    '''
    Periodically shift a 3-D array by one integer offset per axis.

    The element ``data_cube[x, y, z]`` ends up at
    ``((x + x_shift) % nx, (y + y_shift) % ny, (z + z_shift) % nz)``, where
    ``(nx, ny, nz)`` are the lengths of the first three axes. Values pushed off
    one face re-enter on the opposite face (periodic boundaries).

    Parameters
    ----------
    data_cube : np.ndarray
        Array whose axes 0, 1 and 2 are shifted.
    shift : sequence of three ints
        ``(x_shift, y_shift, z_shift)``, for example a row of ``offsets``.
        Any integer values are accepted, not only -1, 0 and 1.

    Returns
    -------
    np.ndarray
        A new array with the same shape and dtype as ``data_cube``; the input
        is not modified.
    '''
    # Unpack explicitly so that a shift of the wrong length fails loudly here.
    x_shift, y_shift, z_shift = shift

    # np.roll performs exactly this periodic (wrap-around) shift, for any
    # integer offset and for axes of different lengths.
    return np.roll(data_cube, (x_shift, y_shift, z_shift), axis=(0, 1, 2))

---

In the functions below we essentially rewrite `array[array == i] = j` as an explicit numba loop. The loop is slower than the NumPy expression for cube sizes 128 and 256, but faster for size 512, which is the size that takes the longest to process.

In [None]:
@njit
def edge_label(x, size):

    '''
    Map a coordinate on one face of a box to the opposite face.

    Used when connecting indices from one edge to the other: 0 becomes
    ``size - 1``, ``size - 1`` becomes 0, and any other coordinate is returned
    unchanged.
    '''

    last = size - 1
    if x == 0:
        return last
    if x == last:
        return 0
    return x

In [None]:
@njit
def labeled_pairs(labeled):
    '''
    Collect pairs of labels that meet across the periodic boundary.

    Every voxel on the outer surface of ``labeled`` is compared with its
    mirror voxel. The mirror is obtained by passing each coordinate through
    ``edge_label``, which sends 0 to n - 1, n - 1 to 0 and leaves interior
    coordinates unchanged. When both voxels are non-zero and carry different
    labels, ``[l1, l2]`` is recorded.

    NOTE(review): for voxels on an edge or corner of the box, several
    coordinates are flipped at once. For example, (0, 0, k) is compared only
    with (n0 - 1, n1 - 1, k), not with (n0 - 1, 0, k) or (0, n1 - 1, k).
    Confirm this matches the connectivity used when the labels were produced.

    Parameters
    ----------
    labeled : 3-D integer array
        Label array where 0 is background. Axes may have different lengths.

    Returns
    -------
    list of [l1, l2]
        Pairs in C (i, j, k) scan order. Because the mirror of a mirror voxel
        is the original voxel, a touching pair is recorded in both orders, and
        it may repeat.
    '''
    n0, n1, n2 = labeled.shape
    # Empty list whose element type ([label, label]) numba can infer.
    labels_pairs = [[labeled[0, 0, 0], labeled[0, 0, 0]]][:0]
    for i in range(n0):
        i_face = (i == 0) or (i == n0 - 1)
        for j in range(n1):
            # Only surface voxels matter. If neither i nor j lies on a face, the
            # only surface voxels of this row are k = 0 and k = n2 - 1, so step
            # directly from one to the other instead of scanning the whole row.
            # The visiting order stays the same as in a full C-order scan.
            if i_face or (j == 0) or (j == n1 - 1):
                k_step = 1
            else:
                k_step = max(n2 - 1, 1)
            for k in range(0, n2, k_step):
                l1 = labeled[i, j, k]
                if l1 != 0:
                    l2 = labeled[edge_label(i, n0), edge_label(j, n1), edge_label(k, n2)]
                    if (l2 != 0) and (l2 != l1):
                        labels_pairs.append([l1, l2])

    return labels_pairs

---

In [None]:
def combine_lists(lists):

    '''
    Merge pairs that share an element into sorted connected groups.

    Each ``[a, b]`` in ``lists`` is treated as an undirected edge between ``a``
    and ``b``. The connected components of that graph are returned, each as a
    sorted list. Components are ordered by the first appearance of any of
    their elements in ``lists``.

    Example: [[1, 2], [2, 4], [3, 1], [5, 6]] -> [[1, 2, 3, 4], [5, 6]]
    '''
    # Adjacency sets: every element maps to the elements it was paired with.
    graph = defaultdict(set)
    for i0, i1 in lists:
        graph[i0].add(i1)
        graph[i1].add(i0)

    visited = set()
    combined_lists = []

    # Iterative depth-first search. An explicit stack avoids Python's recursion
    # limit when a component is a long chain of pairs.
    for start in graph:
        if start in visited:
            continue
        visited.add(start)
        component = []
        stack = [start]
        while stack:
            node = stack.pop()
            component.append(node)
            for neighbor in graph[node]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    stack.append(neighbor)
        combined_lists.append(sorted(component))

    return combined_lists

In [None]:
@njit
def abc_a(labels_pairs_lens, labels_pairs_flat, labeled):

    '''
    [a,b,c]->[a]

    Relabel ``labeled`` in place so that, within every group, all labels are
    replaced by the group's first label.

    Groups are stored flattened. ``labels_pairs_flat`` holds the groups one
    after another, and ``labels_pairs_lens`` holds the length of each group.
    For example, lens = [3, 2] and flat = [a, b, c, d, e] encode the groups
    [a, b, c] and [d, e], which give b -> a, c -> a and e -> d.

    The replacements are applied in group order, with the same result as
    rewriting the array one replacement at a time. They are first composed on
    a lookup table indexed by label value, so the array itself is traversed
    only once instead of once per replaced label.

    NOTE(review): assumes ``labeled`` holds integer labels.
    '''
    if labeled.size == 0:
        return

    # table[v - low] is the label a voxel that currently holds v ends up with.
    low = labeled.min()
    high = labeled.max()
    n_values = high - low + 1
    table = np.empty(n_values, dtype=labeled.dtype)
    for v in range(n_values):
        table[v] = v + low

    # Apply every "lpf_j -> lpf_i" replacement to the table values, in the same
    # order a per-voxel rewrite would apply it to the array. Chained
    # replacements therefore compose exactly as they would on the array.
    index = 0
    for n in labels_pairs_lens:
        for j in range(index + 1, index + n):
            lpf_j = labels_pairs_flat[j]
            lpf_i = labels_pairs_flat[index]
            for v in range(n_values):
                if table[v] == lpf_j:
                    table[v] = lpf_i
        index += n

    # Single pass over the array to apply the composed relabelling.
    n0, n1, n2 = labeled.shape
    for i0 in range(n0):
        for i1 in range(n1):
            for i2 in range(n2):
                labeled[i0, i1, i2] = table[labeled[i0, i1, i2] - low]

---

In [None]:
@njit
def fake_indices_into_labeled(fake_indices, labeled):
    '''
    Set to 0, in place, every voxel of ``labeled`` whose label is one of
    ``fake_indices``.

    Membership is precomputed in a table indexed by label value, so the array
    is traversed once instead of once per fake label.

    NOTE(review): assumes ``labeled`` holds integer labels.
    '''
    if labeled.size == 0:
        return

    low = labeled.min()
    high = labeled.max()
    is_fake = np.zeros(high - low + 1, dtype=np.bool_)
    for fake_index in fake_indices:
        # Non-integer values, and labels outside [low, high], match no voxel.
        as_int = int(fake_index)
        if as_int == fake_index and as_int >= low and as_int <= high:
            is_fake[as_int - low] = True

    n0, n1, n2 = labeled.shape
    for i0 in range(n0):
        for i1 in range(n1):
            for i2 in range(n2):
                if is_fake[labeled[i0, i1, i2] - low]:
                    labeled[i0, i1, i2] = 0

In [None]:
@njit
def list_pairs_maker(labeled):
    '''
    Group voxel coordinates by label.

    Returns one entry per label value in 1..max(labeled) that has at least one
    voxel, in increasing label order. Each entry is the list of ``[i, j, k]``
    coordinates, in C scan order, of the voxels carrying that label.
    Background voxels (label 0) are skipped. Labels with no voxels produce no
    entry, so list position is not necessarily ``label - 1``.

    NOTE(review): assumes labels are >= 0; ``labeled`` must be non-empty,
    because ``np.max`` of an empty array raises.
    '''
    max_labeled = np.max(labeled)
    # Empty, typed lists of [i, j, k] coordinates, one per label 1..max.
    list_pairs = [[[0, 0, 0]][:0] for _ in range(max_labeled)]

    n0, n1, n2 = labeled.shape
    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                l1 = labeled[i, j, k]
                if l1 != 0:
                    list_pairs[l1 - 1].append([i, j, k])

    # Drop labels that have no voxels.
    list_pairs = [coords for coords in list_pairs if (len(coords) != 0)]

    return list_pairs

---
---
---