In [1]:
import numpy as np
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import optax
import plotly.graph_objects as go

In [2]:
# List every 8-bit pattern next to its zero-padded binary representation.
for bitpattern in range(1 << 8):
    print(bitpattern, format(bitpattern, '08b'))

0 00000000
1 00000001
2 00000010
3 00000011
4 00000100
5 00000101
6 00000110
7 00000111
8 00001000
9 00001001
10 00001010
11 00001011
12 00001100
13 00001101
14 00001110
15 00001111
16 00010000
17 00010001
18 00010010
19 00010011
20 00010100
21 00010101
22 00010110
23 00010111
24 00011000
25 00011001
26 00011010
27 00011011
28 00011100
29 00011101
30 00011110
31 00011111
32 00100000
33 00100001
34 00100010
35 00100011
36 00100100
37 00100101
38 00100110
39 00100111
40 00101000
41 00101001
42 00101010
43 00101011
44 00101100
45 00101101
46 00101110
47 00101111
48 00110000
49 00110001
50 00110010
51 00110011
52 00110100
53 00110101
54 00110110
55 00110111
56 00111000
57 00111001
58 00111010
59 00111011
60 00111100
61 00111101
62 00111110
63 00111111
64 01000000
65 01000001
66 01000010
67 01000011
68 01000100
69 01000101
70 01000110
71 01000111
72 01001000
73 01001001
74 01001010
75 01001011
76 01001100
77 01001101
78 01001110
79 01001111
80 01010000
81 01010001
82 01010010
83 01010011
84

In [3]:
import optax
from jax.experimental.host_callback import id_print
# Recent JAX releases no longer ship an importable `jax.config` submodule; the
# configuration object is exposed as an attribute of the top-level package,
# which also works on older releases.
from jax import config
# Raise as soon as a computation produces a NaN instead of letting it propagate.
config.update("jax_debug_nans", True)


def eval_plane(plane, point):
    """Evaluate the affine function encoded by ``plane`` at ``point``.

    The first ``len(point)`` entries of ``plane`` are used as the normal
    coefficients and the entry right after them as the offset; the result is
    ``normal . point - offset``.
    """
    n_coords = point.shape[0]
    normal, offset = plane[:n_coords], plane[n_coords]
    return jnp.dot(normal, point) - offset

@jax.jit
@jax.vmap
def find_plane(points, signs):
    """Fit a plane to signed points by gradient descent, one problem per batch row.

    ``jax.vmap`` maps over the leading axis of both arguments, so the body
    below sees a single problem:

    points: (N, d) point coordinates (the original note suggests d is 2 or 3).
    signs:  (N,) target value for each point.

    Returns ``(plane, accuracy)``: ``plane`` has ``d + 1`` entries (normal
    coefficients followed by the offset, as consumed by ``eval_plane``) and
    ``accuracy`` is the fraction of points where the sign of the plane
    evaluation equals the target.
    """
    dim = points.shape[1]
    initial_plane = jnp.ones(dim + 1)

    def eval_plane_points(candidate):
        # Evaluate one candidate plane at every point of this problem.
        return jax.vmap(eval_plane, in_axes=(None, 0))(candidate, points)

    def loss(candidate):
        # Worst-case absolute deviation between plane value and target.
        # Regularisers that were tried and left disabled:
        #   + 1e4 * (jnp.linalg.norm(plane[:d]) - 1.) ** 2
        #   + 1e-3 / jnp.linalg.norm(plane[d-1])
        #   + 1e-6 * jnp.linalg.norm(plane[d])
        return jnp.max(jnp.abs(eval_plane_points(candidate) - signs))

    optimizer = optax.adam(1e-2)

    def step(_i, carry):
        current, adam_state = carry
        gradient = jax.grad(loss)(current)
        updates, adam_state = optimizer.update(gradient, adam_state, current)
        return optax.apply_updates(current, updates), adam_state

    # Fixed budget of 1000 Adam steps starting from the all-ones plane.
    fitted, _ = jax.lax.fori_loop(
        0, 1000, step, (initial_plane, optimizer.init(initial_plane))
    )

    # Check classification accuracy on points
    hits = jnp.sign(eval_plane_points(fitted)) == signs
    return fitted, jnp.mean(hits)


# One 8-bit pattern per candidate labelling of the 8 cube corners.
bitpatterns = range(2**8)

# Every problem shares the same 8 corners (the 3-bit binary expansions of 0..7,
# shifted by -0.5); build them once and replicate them for each pattern.
points = jnp.tile(
    jnp.array([[int(b) for b in f'{bp:03b}'] for bp in range(2**3)], dtype=float)[None] - 0.5,
    (len(bitpatterns), 1, 1),
)

# Bit k of a pattern (most significant first) labels corner k with +1 or -1.
signs = jnp.array(
    [[1 if digit == '1' else -1 for digit in format(pattern, '08b')] for pattern in bitpatterns],
    dtype=float,
)

# Fit all 256 problems in one batched call.
planes, accuracys = find_plane(points, signs)


In [19]:
# Define canonical bit patterns
# If a pattern can be rotated or inverted to another, they belong to the same class
import itertools

def generate_permutations():
    """Return all 48 signed 3x3 permutation matrices as a list of arrays.

    Row ``i`` of each matrix has a single non-zero entry, ``+1`` or ``-1``, in
    the column chosen by the axis permutation. Matrices are ordered with the
    axis permutation as the outer loop and the sign choice as the inner loop.
    """
    rows = np.arange(3)
    matrices = []
    for axes in itertools.permutations([0, 1, 2]):
        for flips in itertools.product([-1, 1], repeat=3):
            matrix = np.zeros((3, 3), dtype=int)
            matrix[rows, np.array(axes)] = flips
            matrices.append(matrix)
    return matrices

def is_orthogonal(R):
    """Return True when ``R`` is a square matrix with ``R @ R.T`` equal to the identity.

    The comparison uses ``np.allclose`` tolerances. The identity is sized from
    ``R`` itself, so matrices of any dimension can be checked; anything that is
    not a square 2-D array is reported as not orthogonal.
    """
    R = np.asarray(R)
    if R.ndim != 2 or R.shape[0] != R.shape[1]:
        return False
    return np.allclose(R @ R.T, np.eye(R.shape[0]))

def generate_cube_symmetries():
    """Stack the signed permutation matrices that pass ``is_orthogonal``.

    Returns a single array whose leading axis indexes the retained matrices,
    in the order produced by ``generate_permutations``.
    """
    kept = []
    for candidate in generate_permutations():
        if is_orthogonal(candidate):
            kept.append(candidate)
    return np.array(kept)

# Build the symmetry table once and report how many matrices it contains.
cube_symmetries = generate_cube_symmetries()
print(cube_symmetries.shape[0])

def rotate_point(R, point):
    """Apply the matrix ``R`` to a single ``point`` (matrix-vector product)."""
    transformed = R @ point
    return transformed


# One matrix applied to every row of an (N, d) array of points.
rotate_points = jax.vmap(rotate_point, (None, 0))
# Every matrix of a stack applied to the same (N, d) array -> (|R|, N, d).
multirotate_points = jax.vmap(rotate_points, (0, None))

def all_to_one_comparison(many, single):
    """Compare each entry along the leading axis of ``many`` with ``single``.

    Returns a boolean array with one ``jnp.allclose`` result per entry.
    """
    def matches_single(candidate):
        return jnp.allclose(candidate, single)

    return jax.vmap(matches_single)(many)

@jax.jit
def sign_order(points, signs):
    """Return ``signs`` reordered so that their points are in lexicographic order.

    points: (N, d) coordinates; column 0 is the most significant sort key,
            followed by column 1, and so on for every column present.
    signs:  (N,) values attached to the points; also used as the final
            tie-breaker between identical points.

    The sort keys are built from however many columns ``points`` has, so the
    function is not tied to three-dimensional input.
    """
    n_coords = points.shape[1]
    # jnp.lexsort treats the LAST key as the primary one, so the columns are
    # listed from least to most significant.
    coordinate_keys = tuple(points[:, axis] for axis in range(n_coords - 1, -1, -1))
    order = jnp.lexsort((signs,) + coordinate_keys)
    return signs[order]

# Smoke test: these two points are already in lexicographic order, so the
# signs should come back unchanged.
print(sign_order(np.array([[1, 2, 3], [3, 2, 1]]), np.array([1, -1])))

# Representative sign orders collected so far, and the dataset index each
# representative came from.
seen, seen_idcs = [], []
def in_seen(ps, signs):
    """Check whether a labelled point set matches an entry of ``seen``.

    ps:    (N, d) point coordinates.
    signs: (N,) sign attached to each point.

    The points are transformed by every matrix in ``cube_symmetries``; each
    transformed copy yields a sign order, and the negation of each of those is
    added as well. Returns ``(True, None)`` if any of these sign orders is
    close to an entry of ``seen``; otherwise ``(False, order)`` where ``order``
    is the sign order obtained with the first symmetry matrix.
    """
    # (|R|, N, d)
    variants = multirotate_points(cube_symmetries, ps)

    orders = jax.vmap(sign_order, in_axes=(0, None))(variants, signs)
    # A labelling and its global sign flip are treated as the same class.
    orders = jnp.concatenate((orders, -orders), axis=0)

    already_known = any(
        bool(jnp.any(all_to_one_comparison(orders, known))) for known in seen
    )
    if already_known:
        return True, None
    return False, orders[0]

for i, (pointset, ss, acc) in enumerate(zip(points, signs, accuracys)):
    # Only keep patterns whose fitted plane classifies every point correctly.
    if acc == 1.0:
        if len(seen) == 0:
            # The very first representative is stored as-is and not printed.
            seen.append(sign_order(pointset, ss))
            seen_idcs.append(i)
        else:
            seen_before, so = in_seen(pointset, ss)
            if not seen_before:
                seen.append(so)
                seen_idcs.append(i)
                print(so, i)

# Dataset indices of one representative per equivalence class.
valid_idcs = np.array(seen_idcs)
print(valid_idcs.shape[0])

48
[ 1 -1]
[ 1. -1. -1. -1. -1. -1. -1. -1.] 1
[ 1.  1. -1. -1. -1. -1. -1. -1.] 3
[ 1.  1.  1. -1. -1. -1. -1. -1.] 7
[ 1.  1.  1.  1. -1. -1. -1. -1.] 15
[ 1.  1.  1. -1.  1. -1. -1. -1.] 23
6


In [21]:
from plotly.subplots import make_subplots

def plot_points_with_plane(points, signs, plane):
    """Show a 3D scatter of ``points`` coloured by ``signs`` together with ``plane``.

    points: (N, 3) coordinates; columns are used as x, y and z.
    signs:  per-point values used as marker colours.
    plane:  four coefficients ``(a, b, c, offset)``; the surface drawn is
            ``a*x + b*y + c*z = offset`` solved for z over the square
            [-0.5, 0.5] x [-0.5, 0.5], so ``c`` must be non-zero.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=5,
            color=signs,                # set color to an array/list of desired values
            opacity=0.8
        )
    ))
    # Add plane
    x, y = np.meshgrid(np.linspace(-0.5, 0.5, 10), np.linspace(-0.5, 0.5, 10))
    # Solve a*x + b*y + c*z = offset for z on the x/y grid
    z = (plane[3] - plane[0] * x - plane[1] * y) / plane[2]
    # A Surface trace carries its own colour bar, which the layout-level
    # `coloraxis_showscale` below does not control; switch it off on the trace.
    fig.add_trace(go.Surface(x=x, y=y, z=z, showscale=False))
    # Set bounds and aspect to 1
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[-0.5, 0.5]),
            yaxis=dict(range=[-0.5, 0.5]),
            zaxis=dict(range=[-0.5, 0.5]),
            aspectmode="manual",
            aspectratio=dict(x=1, y=1, z=1)
        )
    )
    # Hide color bar
    fig.update_layout(coloraxis_showscale=False)
    fig.show()

def plot_multiple(points_batch, signs_batch, plane_batch, acc_batch):
    """Show a two-column grid of 3D scenes, one per entry of the batch.

    points_batch: (B, N, 3) point coordinates.
    signs_batch:  (B, N) values used as marker colours.
    plane_batch:  (B, 4) plane coefficients ``(a, b, c, offset)``; each surface
                  is ``a*x + b*y + c*z = offset`` solved for z, so ``c`` must
                  be non-zero.
    acc_batch:    (B,) accuracies; a scene gets a white background when its
                  accuracy equals 1.0 and a red one otherwise.

    Every entry of the batch is drawn: the number of rows is rounded up, and
    the unused grid cell of an odd-sized batch is left empty.

    Raises:
        ValueError: if the batch is empty.
    """
    n_plots = points_batch.shape[0]
    if n_plots == 0:
        raise ValueError("plot_multiple needs at least one point set to draw")

    # Multiple plots in subplots (not in same plot!)
    nc = 2
    nr = -(-n_plots // nc)  # ceiling division so an odd batch keeps its last entry
    specs = [
        [{'type': 'scene'} if r * nc + c < n_plots else None for c in range(nc)]
        for r in range(nr)
    ]
    fig = make_subplots(rows=nr, cols=nc, specs=specs,
                        horizontal_spacing=0.05, vertical_spacing=0.05)
    fig.print_grid()

    # The x/y grid of the plane surface is the same for every scene.
    x, y = np.meshgrid(np.linspace(-0.5, 0.5, 10), np.linspace(-0.5, 0.5, 10))

    for idx in range(n_plots):
        i, j = divmod(idx, nc)
        # Add points
        fig.add_trace(
            go.Scatter3d(
                x=points_batch[idx, :, 0],
                y=points_batch[idx, :, 1],
                z=points_batch[idx, :, 2],
                mode='markers',
                marker=dict(
                    size=5,
                    color=signs_batch[idx],                # set color to an array/list of desired values
                    opacity=0.8
                )
            ),
            row=i + 1, col=j + 1
        )
        # Add plane: solve a*x + b*y + c*z = offset for z on the x/y grid
        z = (plane_batch[idx, 3] - plane_batch[idx, 0] * x - plane_batch[idx, 1] * y) / plane_batch[idx, 2]
        # A Surface trace carries its own colour bar, which the layout-level
        # `coloraxis_showscale` below does not control; switch it off here.
        fig.add_trace(
            go.Surface(x=x, y=y, z=z, showscale=False),
            row=i + 1, col=j + 1
        )
        # Scenes are named 'scene', 'scene2', 'scene3', ... in row-major order.
        scene_key = 'scene' if idx == 0 else f'scene{idx + 1}'
        # Set bounds and aspect to 1
        fig.update_layout(**{
            scene_key: dict(
                xaxis=dict(range=[-0.5, 0.5]),
                yaxis=dict(range=[-0.5, 0.5]),
                zaxis=dict(range=[-0.5, 0.5]),
                aspectmode="manual",
                aspectratio=dict(x=1, y=1, z=1),
                bgcolor='white' if acc_batch[idx] == 1.0 else 'red'
            )
        })
    # Hide color bar
    fig.update_layout(height=1000, width=1000, coloraxis_showscale=False)
    print('Showing plot')
    fig.show()



# Draw the point set, labels, fitted plane and accuracy of each index collected
# in `valid_idcs`.
plot_multiple(*(batch[valid_idcs] for batch in (points, signs, planes, accuracys)))

# To inspect patterns one at a time instead, iterate over
# zip(points, signs, planes, accuracys) and call plot_points_with_plane(p, s, pl)
# for the first few entries, printing the plane and accuracy alongside.

This is the format of your plot grid:
[ (1,1) scene  ]  [ (1,2) scene2 ]
[ (2,1) scene3 ]  [ (2,2) scene4 ]
[ (3,1) scene5 ]  [ (3,2) scene6 ]

Showing plot


In [5]:
from plotly.subplots import make_subplots
# Initialize figure with 4 3D subplots
fig = make_subplots(
    rows=2, cols=2,
    specs=[[{'type': 'surface'}, {'type': 'surface'}],
           [{'type': 'surface'}, {'type': 'surface'}]])

# Generate data
x = np.linspace(-5, 80, 10)
y = np.linspace(-5, 60, 10)
# go.Surface reads z[i][j] as the height at (x[j], y[i]); meshgrid(x, y) lays
# the grids out that way, so z lines up with the x and y vectors passed below.
xGrid, yGrid = np.meshgrid(x, y)
z = xGrid ** 3 + yGrid ** 3

# adding surfaces to subplots: same data, one colorscale per grid cell.
for (row, col), colorscale in zip(
    [(1, 1), (1, 2), (2, 1), (2, 2)],
    ['Viridis', 'RdBu', 'YlOrRd', 'YlGnBu'],
):
    fig.add_trace(
        go.Surface(x=x, y=y, z=z, colorscale=colorscale, showscale=False),
        row=row, col=col)

fig.update_layout(
    title_text='3D subplots with different colorscales',
    height=800,
    width=800
)

fig.show()