# Figures

## Obtain Relevant Permutations

In [None]:
import numpy as np


k = 3
x = np.zeros((4, k, k, 3))
h = 0.6 / k

# Centred k x k lattice with spacing h on the plane z = 1, identical for all 4 copies.
lattice_1d = -(k - 1) / 2 * h + np.arange(k) * h
x[:, :, :, 0] = lattice_1d[:, None]
x[:, :, :, 1] = lattice_1d[None, :]
x[:, :, :, 2] = 1

# Move copy q by 0.4 * (sign_x, sign_y) so the four copies sit in the four quadrants.
for q, (sign_x, sign_y) in enumerate([(-1, -1), (-1, +1), (+1, +1), (+1, -1)]):
    x[q, :, :, 0] += sign_x * 0.4
    x[q, :, :, 1] += sign_y * 0.4

# NOTE: a C-order reshape, not a spatial 2k x 2k tiling -- consecutive runs of k*k
# entries in the flattened array are still the individual quadrants.
x = x.reshape(2 * k, 2 * k, 3)

In [None]:
from poly_sphere import *


x = make_cube(x)
# Flatten to a single (num_points, 3) list; the intermediate (-1, 4 * k**2, 3) reshape
# fails loudly unless the point count is a multiple of 4 * k**2.
x = x.reshape((-1, 4 * k ** 2, 3)).reshape((-1, 3))
# Centroid of each consecutive run of k**2 points (24 runs; presumably 6 faces x 4 quadrants).
z = x.reshape(24, k ** 2, 3).mean(axis=1)


def sym(x, around_z):
    """Return the index permutation induced by a quarter-turn rotation of a point set.

    The points are rotated by pi/2 with ``rotate_3d`` about axis 2 (``around_z=True``)
    or axis 0 (``around_z=False``). ``perm[i]`` is the index ``j`` of the first rotated
    point ``y[j]`` lying within 1e-6 (Euclidean distance) of ``x[i]``.

    Args:
        x: array of points, indexed ``x[i]`` (built as an (n, 3) array by the cells here).
        around_z: selects the rotation axis as described above.

    Returns:
        list of ``len(x)`` Python ints.

    Raises:
        ValueError: if some point of ``x`` has no counterpart among the rotated points,
            i.e. the set is not closed under the rotation. (Such entries used to be left
            as ``None``, producing an invalid permutation downstream.)
    """
    y = np.asarray(rotate_3d(x, np.pi / 2, axis=2 if around_z else 0))
    x = np.asarray(x)
    n = x.shape[0]
    perm = np.empty(n, dtype=int)
    # Compare a chunk of rows at a time so the (rows, len(y), 3) difference tensor stays small.
    rows_per_chunk = max(1, (1 << 20) // max(len(y), 1))
    for start in range(0, n, rows_per_chunk):
        stop = start + rows_per_chunk
        dist = np.linalg.norm(x[start:stop, None, :] - y[None, :, :], axis=-1)
        close = dist < 1e-6
        matched = close.any(axis=1)
        if not matched.all():
            bad = start + int(np.argmin(matched))  # first row without a match
            raise ValueError('sym: point %d has no counterpart under the rotation' % bad)
        # argmax on booleans returns the first True column -- the same j the old
        # first-match loop picked.
        perm[start:stop] = np.argmax(close, axis=1)
    return perm.tolist()


# Permutations induced by a quarter turn about the z axis (perm0) and the x axis (perm1).
perm0, perm1 = (sym(x, around_z) for around_z in (True, False))
for perm in (perm0, perm1):
    print(perm)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline


fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(*x.T, s=10)

# Label the first 4 * k**2 points with their index so permutations can be read off.
for point_id in range(4 * k ** 2):
    px, py, pz = x[point_id]
    ax.text(px, py, pz, str(point_id))

plt.show()

In [None]:
# perm2: cyclic shift by one inside every consecutive run of k indices among the first
# 4 * k**2 points; the remaining 5 * 4 * k**2 points are fixed.
# Separate names are used so the point-coordinate array `x` from the cells above is not
# overwritten (the labelled scatter plot stays re-runnable).
face_size = 4 * k ** 2
shift = np.arange(face_size) + 1
shift = np.where(shift % k == 0, shift - k, shift)
perm2 = np.concatenate((shift, np.arange(face_size * 5) + face_size)).tolist()
print(perm2)

## Group Convolution

In [None]:
import autoequiv as ae
import matplotlib.pyplot as plt
%matplotlib inline


# Restrict both generators to the first 4 * k**2 points (presumably one cube face).
n_face_points = 4 * k ** 2
generators = [perm[:n_face_points] for perm in (perm0, perm2)]

group_conv = ae.LinearEquiv(
    in_generators=generators,
    out_generators=generators,
    in_channels=1,
    out_channels=1,
)
print('number of parameters in W =', group_conv.num_weights_W)
print('number of parameters in b =', group_conv.num_weights_b)

fig = plt.figure(figsize=(16, 16))
ae.draw_colored_matrix(group_conv.colors_W)
plt.show()

## S4 on Faces

In [None]:
import autoequiv as ae
import matplotlib.pyplot as plt
%matplotlib inline


# A single point at (0, 0, 1), passed through make_cube to get one point per face.
face_center = np.array([0, 0, 1], dtype=np.float32).reshape(1, 1, 3)
x = make_cube(face_center).reshape((-1, 3))
generators = [sym(x, around_z) for around_z in (True, False)]

s4_faces = ae.LinearEquiv(
    in_generators=generators,
    out_generators=generators,
    in_channels=1,
    out_channels=1,
)
print('number of parameters in W =', s4_faces.num_weights_W)
print('number of parameters in b =', s4_faces.num_weights_b)

fig = plt.figure(figsize=(6, 6))
ae.draw_colored_matrix(s4_faces.colors_W, markersize=100)
plt.show()

## S4 on Flags

In [None]:
import autoequiv as ae
import matplotlib.pyplot as plt
from poly_sphere import *
%matplotlib inline


grid = get_sampling_grid('cube', 1, center=False).reshape((-1, 4, 3))
# Pull every point 1% toward the centroid of its group of 4 (presumably the corners of
# one cell), so corners shared by neighbouring cells become distinct "flag" points.
x = (0.99 * grid + 0.01 * grid.mean(axis=1, keepdims=True)).reshape((-1, 3))
generators = [sym(x, around_z) for around_z in (True, False)]

s4_flags = ae.LinearEquiv(
    in_generators=generators,
    out_generators=generators,
    in_channels=1,
    out_channels=1,
)
print('number of parameters in W =', s4_flags.num_weights_W)
print('number of parameters in b =', s4_flags.num_weights_b)

fig = plt.figure(figsize=(16, 16))
ae.draw_colored_matrix(s4_flags.colors_W, markersize=45)
plt.show()

## S4 on Flags Oriented

In [None]:
import autoequiv as ae
import matplotlib.pyplot as plt
from poly_sphere import *
%matplotlib inline


grid = get_sampling_grid('cube', 1, center=False).reshape((-1, 4, 3))
# Same flag points as above: each point pulled 1% toward its group-of-4 centroid.
x = (0.99 * grid + 0.01 * grid.mean(axis=1, keepdims=True)).reshape((-1, 3))
# Only the z-axis quarter turn generates the group here.
generators = [sym(x, True)]

s4_flags_oriented = ae.LinearEquiv(
    in_generators=generators,
    out_generators=generators,
    in_channels=1,
    out_channels=1,
)
print('number of parameters in W =', s4_flags_oriented.num_weights_W)
print('number of parameters in b =', s4_flags_oriented.num_weights_b)

# Weight-sharing pattern of W, then of b.
for attr_name, figsize in (('colors_W', (16, 16)), ('colors_b', (16, 2))):
    fig = plt.figure(figsize=figsize)
    ae.draw_colored_matrix(getattr(s4_flags_oriented, attr_name), markersize=45)
    plt.show()

## Sphere Layer

In [None]:
import autoequiv as ae
import matplotlib.pyplot as plt
%matplotlib inline


# z-axis quarter turn, x-axis quarter turn, and the within-face cyclic shift.
generators = [perm0, perm1, perm2]

sphere_layer = ae.LinearEquiv(in_generators=generators, out_generators=generators, in_channels=1, out_channels=1)
print('number of parameters in W =', sphere_layer.num_weights_W)
print('number of parameters in b =', sphere_layer.num_weights_b)

fig = plt.figure(figsize=(24, 24))
ae.draw_colored_matrix(sphere_layer.colors_W, markersize=1)
plt.show()

## A Circle Example

In [None]:
import autoequiv as ae
import matplotlib.pyplot as plt
%matplotlib inline


k = 3
n_points = 8 * k

# Generator 1: shift every index by 2k around the cycle of 8k points.
quarter_shift = (np.arange(n_points) + 2 * k) % n_points
# Generator 2: cyclic shift by one inside each of the first two runs of k; the rest fixed.
cycle_k = (np.arange(k) + 1) % k
local_shift = np.concatenate([cycle_k, k + cycle_k, 2 * k + np.arange(6 * k)])
# Generator 3: reverse each run of 2k points while permuting the runs as (0, 3, 2, 1).
reversed_run = np.arange(2 * k)[::-1]
reflection = np.concatenate([offset + reversed_run for offset in (0, 6 * k, 4 * k, 2 * k)])

generators = [quarter_shift, local_shift, reflection]

circle_layer = ae.LinearEquiv(
    in_generators=generators,
    out_generators=generators,
    in_channels=1,
    out_channels=1,
)
print('number of parameters in W =', circle_layer.num_weights_W)
print('number of parameters in b =', circle_layer.num_weights_b)

fig = plt.figure(figsize=(8, 8))
ae.draw_colored_matrix(circle_layer.colors_W, markersize=10)
plt.show()

## Polyhedral Spheres

In [None]:
import numpy as np


def triangle_area(a, b, c):
    """Area of the triangle with vertices a, b, c.

    The previous implementation multiplied |b - a| by the distance from c to the midpoint
    of ab (a median, not the altitude), which is only correct when |ca| == |cb|.

    Args:
        a, b, c: point coordinates (array-likes of equal length).

    Returns:
        The triangle area as a float (0 for degenerate triangles).
    """
    a = np.asarray(a, dtype=float)
    u = np.asarray(b, dtype=float) - a
    v = np.asarray(c, dtype=float) - a
    if u.shape[-1] == 3:
        # Half the parallelogram area spanned by the two edge vectors.
        return np.linalg.norm(np.cross(u, v)) / 2
    # Any dimension: Gram determinant |u|^2 |v|^2 - (u.v)^2 equals (2 * area)^2.
    gram = np.dot(u, u) * np.dot(v, v) - np.dot(u, v) ** 2
    return np.sqrt(max(gram, 0.0)) / 2


def square_area(a, b, c, d):
    """Area of the quadrilateral a-b-c-d, split along the diagonal a-c.

    Returns triangle_area(a, b, c) + triangle_area(c, d, a).
    """
    first_half = triangle_area(a, b, c)
    second_half = triangle_area(c, d, a)
    return first_half + second_half


def hexa_area(a, b, c, d, e, f):
    """Area of the hexagon a-b-c-d-e-f.

    The hexagon is fanned into six triangles sharing the vertex centroid, and
    ``triangle_area`` is summed over the consecutive vertex pairs (a, b), (b, c), ..., (f, a).
    """
    ring = (a, b, c, d, e, f)
    centroid = sum(ring) / 6
    consecutive_pairs = zip(ring, ring[1:] + ring[:1])
    return sum(triangle_area(p, q, centroid) for p, q in consecutive_pairs)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import mpl_toolkits.mplot3d as a3
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from poly_sphere import *
from utils import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'


def plot_cube_sphere(ax, w, project, im=None, flag=False, **kwargs):
    """Draw a cube sphere on a 3D axis, one quadrilateral per grid cell.

    The +z face corners are mapped through ``proj_unit``, refined into a grid with
    ``refine_square(..., w, project)`` and expanded to all faces by ``make_cube``.

    Args:
        ax: 3D axis to draw on (created with ``projection='3d'``).
        w: resolution forwarded to ``refine_square``.
        project: forwarded to ``refine_square``.
        im: optional per-cell colours, read as ``im[face, :, i, j]``; when given they
            replace the uniform face colour.
        flag: if True, drop polygons ``2 * w**2`` up to ``5 * w**2`` (faces 2-4, assuming
            each face contributes ``w**2`` cells).
        **kwargs: optional settings --
            pull (default 0): added to the z coordinate of the base face grid before
                ``make_cube``;
            rotate (default False): rotate all points by pi/4 with ``rotate_3d(..., axis=0)``;
            face (default None): keep only this face;
            scale (default 1.0), center (default 0): applied as ``poly * scale + center``;
            edgecolor ('k'), facecolor ('white'), alpha (1), linewidth (1): styling;
            verbose (default False): print the std of the drawn cell areas and the cell count.
    """
    ul = proj_unit(np.array([-1, +1, +1], dtype=float))
    ur = proj_unit(np.array([+1, +1, +1], dtype=float))
    bl = proj_unit(np.array([-1, -1, +1], dtype=float))
    br = proj_unit(np.array([+1, -1, +1], dtype=float))

    x = refine_square(ul, ur, bl, br, w, project) + np.array([0, 0, kwargs.get('pull', 0)], dtype=float)
    x = make_cube(x)
    if kwargs.get('rotate', False):
        shape_bak = x.shape
        x = np.reshape(rotate_3d(x.reshape(-1, 3), np.pi / 4, axis=0), shape_bak)

    face = kwargs.get('face', None)
    if face is not None:
        x = x[face][None, ...]

    ax.set_xlim(-0.66, 0.66)
    ax.set_ylim(-0.66, 0.66)
    ax.set_zlim(-0.66, 0.66)
    ax.patch.set_alpha(0)

    n_faces, n_rows, n_cols = x.shape[:3]
    poly = []
    color_list = []
    for k in range(n_faces):
        for i in range(n_rows - 1):
            for j in range(n_cols - 1):
                poly.append(np.array([x[k, i, j], x[k, i + 1, j], x[k, i + 1, j + 1], x[k, i, j + 1]]))
                if im is not None:
                    color_list.append(im[k, :, i, j])

    if flag:
        poly = poly[:2 * w ** 2] + poly[5 * w ** 2:]
        color_list = color_list[:2 * w ** 2] + color_list[5 * w ** 2:]

    if kwargs.get('verbose', False):
        print('area std =', np.std([square_area(*cell) for cell in poly]))
        print('num pixels = %d' % len(poly))

    poly = np.array(poly) * kwargs.get('scale', 1.0) + kwargs.get('center', 0)
    tri = a3.art3d.Poly3DCollection(poly)
    tri.set_edgecolor(kwargs.get('edgecolor', 'k'))
    tri.set_facecolor(kwargs.get('facecolor', 'white'))
    if im is not None:
        tri.set_facecolor(color_list)
    tri.set_alpha(kwargs.get('alpha', 1))
    tri.set_linewidth(kwargs.get('linewidth', 1))
    tri.set_3d_properties()

    ax.add_collection3d(tri)
    # Act on the axis we were given, not on pyplot's "current" axes.
    ax.axis('off')
    ax.view_init(15, -70)

    
def _equilateral_face_corners(a):
    """Corners (ul, ur, bo) of an equilateral triangle with side ``a`` in the plane z = 1.

    The triangle's centroid is (0, 0, 1); each corner is passed through ``proj_unit``.
    """
    ul = proj_unit(np.array([-a / 2, a / (2 * np.sqrt(3)), 1], dtype=float))
    ur = proj_unit(np.array([+a / 2, a / (2 * np.sqrt(3)), 1], dtype=float))
    bo = proj_unit(np.array([0, -a / np.sqrt(3), 1], dtype=float))
    return ul, ur, bo


def _draw_triangle_faces(ax, x, project, hexa, im=None, **kwargs):
    """Render triangular face grids ``x`` (indexed ``x[face, i, j]``) onto ``ax``.

    Only grid entries with ``i + j < x.shape[2]`` are read.

    hexa=False: the upward triangles (x[i, j], x[i+1, j], x[i, j+1]) are drawn, followed by
        the downward triangles (x[i+1, j+1], x[i+1, j], x[i, j+1]) that fit in the face.
    hexa=True: each upward cell becomes one hexagon whose vertices are the cell's three
        corners interleaved with the point reflections of the cell centroid through the
        edge midpoints; the vertices are passed through ``proj_unit`` when ``project``.

    ``im`` (read as ``im[face, :, i, j]``) colours the hexagons and is ignored when
    ``hexa`` is False. kwargs: scale (1.0), center (0), alpha (1), edgecolor ('k'),
    linewidth (1), verbose (False; when True print the polygon count and area std).
    """
    ax.set_xlim(-0.66, 0.66)
    ax.set_ylim(-0.66, 0.66)
    ax.set_zlim(-0.66, 0.66)
    ax.patch.set_alpha(0)

    n_faces, n_rows, n_cols = x.shape[:3]
    poly = []
    color_list = []
    for k in range(n_faces):
        for i in range(n_rows - 1):
            for j in range(n_cols - i - 1):
                p, q, r = x[k, i, j], x[k, i + 1, j], x[k, i, j + 1]
                if hexa:
                    center = np.mean([p, q, r], axis=0)
                    hexagon = [p, p + q - center, q, q + r - center, r, r + p - center]
                    if project:
                        hexagon = [proj_unit(v) for v in hexagon]
                    poly.append(hexagon)
                    if im is not None:
                        color_list.append(im[k, :, i, j])
                else:
                    poly.append([p, q, r])
    if not hexa:
        # Downward-pointing triangles are appended after all upward ones.
        for k in range(n_faces):
            for i in range(n_rows - 1):
                for j in range(n_cols - i - 2):
                    poly.append([x[k, i + 1, j + 1], x[k, i + 1, j], x[k, i, j + 1]])

    if kwargs.get('verbose', False):
        area_fn = hexa_area if hexa else triangle_area
        print('num pixels = %d' % len(poly))
        print('area std =', np.std([area_fn(*cell) for cell in poly]))

    poly = np.array(poly) * kwargs.get('scale', 1.0) + kwargs.get('center', 0)
    tri = a3.art3d.Poly3DCollection(poly)
    tri.set_color('white')
    if im is not None and hexa:
        tri.set_color(color_list)
    tri.set_alpha(kwargs.get('alpha', 1))
    tri.set_edgecolor(kwargs.get('edgecolor', 'k'))
    tri.set_linewidth(kwargs.get('linewidth', 1))
    tri.set_3d_properties()
    ax.add_collection3d(tri)
    # Act on the axis we were given, not on pyplot's "current" axes.
    ax.axis('off')
    ax.view_init(15, -70)


def plot_icosa_sphere(ax, w, project, hexa, im=None, **kwargs):
    """Draw an icosahedral sphere on ``ax``.

    One equilateral face (side sqrt(9 tan(pi/5)^2 - 3), plane z = 1) is refined with
    ``refine_triangle(..., w, project)`` and expanded by ``make_icosa``. See
    ``_draw_triangle_faces`` for ``hexa``, ``im`` and the supported kwargs.
    """
    a = np.sqrt(9 * np.tan(np.pi / 5) ** 2 - 3)  # triangle side length
    x = make_icosa(refine_triangle(*_equilateral_face_corners(a), w, project))
    _draw_triangle_faces(ax, x, project, hexa, im=im, **kwargs)


def plot_tetra_sphere(ax, w, project, hexa, im=None, **kwargs):
    """Draw a tetrahedral sphere on ``ax``.

    One equilateral face (side 2 * sqrt(6), plane z = 1) is refined with
    ``refine_triangle(..., w, project)`` and expanded by ``make_tetra``. See
    ``_draw_triangle_faces`` for ``hexa``, ``im`` and the supported kwargs.
    """
    a = 2 * np.sqrt(6)  # triangle side length
    x = make_tetra(refine_triangle(*_equilateral_face_corners(a), w, project))
    _draw_triangle_faces(ax, x, project, hexa, im=im, **kwargs)


def plot_octa_sphere(ax, w, project, hexa, im=None, **kwargs):
    """Draw an octahedral sphere on ``ax``.

    The face with corners (0, -1, 0), (1, 0, 0), (0, 0, -1) is refined with
    ``refine_triangle(..., w, project)`` and expanded by ``make_octa``. See
    ``_draw_triangle_faces`` for ``hexa``, ``im`` and the supported kwargs.
    """
    ul = proj_unit(np.array([0, -1, 0], dtype=float))
    ur = proj_unit(np.array([+1, 0, 0], dtype=float))
    bo = proj_unit(np.array([0, 0, -1], dtype=float))
    x = make_octa(refine_triangle(ul, ur, bo, w, project))
    _draw_triangle_faces(ax, x, project, hexa, im=im, **kwargs)

In [None]:
levels = 5
# One row per polyhedron (tetra, octa, cube, icosa); column i uses refinement w = 2**i.
row_specs = [
    (plot_tetra_sphere, {'hexa': False}),
    (plot_octa_sphere, {'hexa': False}),
    (plot_cube_sphere, {}),
    (plot_icosa_sphere, {'hexa': False}),
]
n_rows = len(row_specs)

fig = plt.figure(figsize=(levels * 5, n_rows * 5))
fig.patch.set_alpha(0)
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)

for row, (plot_fn, extra_kwargs) in enumerate(row_specs):
    for i in range(levels):
        ax = fig.add_subplot(n_rows, levels, row * levels + i + 1, projection='3d')
        plot_fn(ax, w=2 ** i, project=True, alpha=1.0, **extra_kwargs)

plt.show()

## Hexagonal Tiling on Triangular Faces

In [None]:
def plot_triangle(ax, w, hexa, im=None, **kwargs):
    """Draw one flat refined triangle on ``ax``, as a triangle mesh or as hexagons.

    The corners are passed through ``proj_unit`` and refined with
    ``refine_triangle(..., w, project=False)``; the camera is set with ``view_init(0, -90)``.

    Args:
        ax: 3D axis to draw on.
        w: resolution forwarded to ``refine_triangle``.
        hexa: draw one hexagon per upward cell instead of a triangle mesh.
        im: optional integer labels read as ``im[i, j]``, used only when ``hexa``:
            0 is drawn white, a label v > 0 uses entry ``v - 1`` of the 'rainbow' colormap.
        **kwargs: alpha (1), edgecolor ('k'), linewidth (1).
    """
    ul = proj_unit(np.array([-1, 0, np.sqrt(3) / 3], dtype=float))
    ur = proj_unit(np.array([+1, 0, np.sqrt(3) / 3], dtype=float))
    bo = proj_unit(np.array([0, 0, -2 * np.sqrt(3) / 3], dtype=float))

    x = refine_triangle(ul, ur, bo, w, project=False)

    ax.set_xlim(-0.6, 0.6)
    ax.set_ylim(-0.6, 0.6)
    ax.set_zlim(-0.6, 0.6)
    poly = []
    color_list = []
    for i in range(x.shape[0] - 1):
        for j in range(x.shape[1] - i - 1):
            p, q, r = x[i, j], x[i + 1, j], x[i, j + 1]
            if hexa:
                # Corners interleaved with the reflections of the centroid through edge midpoints.
                center = np.mean([p, q, r], axis=0)
                poly.append([p, p + q - center, q, q + r - center, r, r + p - center])
                if im is not None:
                    color_list.append(im[i, j])
            else:
                poly.append([p, q, r])
    if not hexa:
        # Downward-pointing triangles fill the gaps between the upward ones.
        for i in range(x.shape[0] - 1):
            for j in range(x.shape[1] - i - 2):
                poly.append([x[i + 1, j + 1], x[i + 1, j], x[i, j + 1]])
    tri = a3.art3d.Poly3DCollection(poly)
    tri.set_color('white')
    if im is not None and hexa:
        cmap = plt.get_cmap('rainbow')
        tri.set_color(['white' if label == 0 else cmap(label - 1) for label in color_list])
    tri.set_alpha(kwargs.get('alpha', 1))
    tri.set_edgecolor(kwargs.get('edgecolor', 'k'))
    tri.set_linewidth(kwargs.get('linewidth', 1))
    tri.set_3d_properties()
    ax.add_collection3d(tri)
    # Act on the axis we were given, not on pyplot's "current" axes.
    ax.axis('off')
    ax.view_init(0, -90)

In [None]:
fig = plt.figure(figsize=(2 * 6, 6))
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)

# The right panel (hexagons) is created before the left panel (triangles).
for panel, use_hexagons in ((2, True), (1, False)):
    ax = fig.add_subplot(1, 2, panel, projection='3d')
    plot_triangle(ax, w=9, hexa=use_hexagons)

plt.show()

## Hexagonal Pooling

In [None]:
import os


os.makedirs('hexa_pool_figures', exist_ok=True)

# Coarse grid resolution and the matching fine grid (coarse cell (i, j) <-> fine cell (2i, 2j)).
COARSE_W = 5
FINE_W = 2 * COARSE_W - 1
# The fine cell itself plus the six offsets adjacent to it in this hexagonal indexing.
HEX_NEIGHBOURHOOD = [(a, b) for a in (-1, 0, 1) for b in (-1, 0, 1) if (a, b) not in ((-1, -1), (1, 1))]

cnt = 0
for i in range(COARSE_W):
    for j in range(COARSE_W - i):
        fig = plt.figure(figsize=(2 * 6, 6))
        plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)

        im0 = np.zeros((FINE_W, FINE_W), dtype=int)
        for a, b in HEX_NEIGHBOURHOOD:
            row, col = 2 * i + a, 2 * j + b
            if 0 <= row < FINE_W and 0 <= col < FINE_W:
                im0[row, col] = 1

        im1 = np.zeros((COARSE_W, COARSE_W), dtype=int)
        im1[i, j] = 1

        ax = fig.add_subplot(1, 2, 1, projection='3d')
        plot_triangle(ax, w=FINE_W, hexa=True, im=im0)

        ax = fig.add_subplot(1, 2, 2, projection='3d')
        plot_triangle(ax, w=COARSE_W, hexa=True, im=im1)

        fig.savefig('hexa_pool_figures/hexa_pool_%02d.png' % cnt, bbox_inches='tight', dpi=200)
        fig.savefig('hexa_pool_figures/hexa_pool_%02d.pdf' % cnt, bbox_inches='tight')
        # Close this specific figure to free memory across the loop.
        plt.close(fig)
        cnt += 1

In [None]:
import cv2
import glob
import imageio


files = sorted(glob.glob('hexa_pool_figures/*.png'))
# cv2 loads images as BGR; convert every frame to RGB before writing the GIF.
image_list = [cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) for path in files]
imageio.mimsave('hexa_pool.gif', image_list, fps=2)

In [None]:
from IPython.display import Image


# Display the GIF written by the previous cell, constrained to a width of 500.
Image('hexa_pool.gif', width=500)

## Spherical MNIST on Polyhedral Spheres

In [None]:
import gzip
import pickle


def _load_pickled_images(path):
    """Unpickle the gzip-compressed array at ``path`` and scale it to float32 in [0, 1]."""
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted local files.
    # Despite the "idx3-ubyte" name these files hold pickled arrays, not raw IDX data.
    with gzip.open(path, 'rb') as fh:
        images = pickle.load(fh)
    return images.astype(np.float32) / 255


X_tetra = _load_pickled_images('smnist_poly/tetra_41_rr/train-images-idx3-ubyte.gz')
X_cube = _load_pickled_images('smnist_poly/cube_24_rr/train-images-idx3-ubyte.gz')
X_octa = _load_pickled_images('smnist_poly/octa_25_rr/train-images-idx3-ubyte.gz')
X_icosa = _load_pickled_images('smnist_poly/icosa_17_rr/train-images-idx3-ubyte.gz')

In [None]:
fig = plt.figure(figsize=(4 * 6, 1 * 6))
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)

idx = 2508

# (plotter, dataset, resolution w, extra kwargs); the cube plotter takes no `hexa` argument.
panels = [
    (plot_tetra_sphere, X_tetra, 41, {'hexa': True}),
    (plot_cube_sphere, X_cube, 24, {}),
    (plot_octa_sphere, X_octa, 25, {'hexa': True}),
    (plot_icosa_sphere, X_icosa, 17, {'hexa': True}),
]
for column, (plot_fn, images, w, extra) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 4, column, projection='3d')
    # Plain sphere first, then a half-size copy coloured by sample idx, offset to the side.
    plot_fn(ax, w=w, project=True, alpha=1, **extra)
    im = np.concatenate(3 * [images[idx]], axis=1)  # repeat the channel axis 3 times
    plot_fn(ax, w=w, project=True, im=im, linewidth=0, scale=0.5, center=np.array([1.1, -1.8, 0]), alpha=1, **extra)

plt.show()

In [None]:
fig = plt.figure(figsize=(4 * 6, 1 * 6))
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)

idx = 2508

# Same panels as the previous figure, but with project=False (flat polyhedron faces).
flat_panels = [
    (plot_tetra_sphere, X_tetra, 41, {'hexa': True}),
    (plot_cube_sphere, X_cube, 24, {}),
    (plot_octa_sphere, X_octa, 25, {'hexa': True}),
    (plot_icosa_sphere, X_icosa, 17, {'hexa': True}),
]
for position, (draw, dataset, resolution, extra) in enumerate(flat_panels, start=1):
    ax = fig.add_subplot(1, 4, position, projection='3d')
    draw(ax, w=resolution, project=False, alpha=1, **extra)
    im = np.concatenate(3 * [dataset[idx]], axis=1)
    draw(ax, w=resolution, project=False, im=im, linewidth=0, scale=0.5, center=np.array([1.1, -1.8, 0]), alpha=1, **extra)

plt.show()

## Equivariance

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import mpl_toolkits.mplot3d as a3
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
from poly_sphere import *
from utils import *
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx


def get_color(cmap, i, n):
    """RGB colour (length-3 array) for value ``i`` of colormap ``cmap`` normalised over [0, n - 1]."""
    mapper = cmx.ScalarMappable(norm=colors.Normalize(vmin=0, vmax=n - 1), cmap=plt.get_cmap(cmap))
    rgba = mapper.to_rgba(i)
    return np.array(rgba[:3])

In [None]:
def process(x):
    """Outline mask of the zero pixels of a 2-D array.

    Returns an array of ones (same shape and dtype as ``x``) that is 0 exactly at the
    nonzero pixels of ``x`` having at least one zero 4-neighbour (up/down/left/right).
    Zero pixels themselves stay 1.
    """
    out = np.ones_like(x)
    rows, cols = x.shape
    four_neighbours = ((-1, 0), (0, -1), (0, 1), (1, 0))
    for r in range(rows):
        for c in range(cols):
            if x[r, c] != 0 and any(
                0 <= r + dr < rows and 0 <= c + dc < cols and x[r + dr, c + dc] == 0
                for dr, dc in four_neighbours
            ):
                out[r, c] = 0
    return out

In [None]:
# 8 x 8 glyph masks: stroke pixels are 0, background is 1.
L, F, C, I, H, U = (np.ones((8, 8)) for _ in range(6))

L[2, 2:5] = 0
L[1:3, 4] = 0

F[1:6, 4] = 0
F[5, 4:7] = 0
F[3, 4:6] = 0

C[3:5, 4] = 0
C[2, 5] = 0
C[5, 5] = 0

I[2:6, 5] = 0

H[2:5, 2] = 0
H[3, 2:4] = 0
# disabled alternative stroke: H[2:5, 4] = 0

U[3:5, 3:5] = 0
# disabled alternative strokes: U[2:5, 5] = 0; U[4, 3:6] = 0

# Outline masks of each glyph (see `process`).
L_, F_, C_, I_, H_, U_ = (process(glyph) for glyph in (L, F, C, I, H, U))

In [None]:
fig = plt.figure(figsize=(2 * 6, 2 * 6))
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0, hspace=0)
plt.axis('off')

# Grid resolution, spacing between stacked layers, layer translucency, axis half-width.
W, PULL, ALPHA, AX_LIM = 8, 0.3, 0.1, 1.2
f = 5  # face whose content is translated in im1 / im2

c = get_color('rainbow', 1, 10)
ink = (1 - c)[:, None, None]  # per-channel darkening, broadcast over the W x W cells

# Faces 0..5 show I, U, L, H, C, L (im1) and the matching outline masks (im2).
face_glyphs = (I, U, L, H, C, L)
face_outlines = (I_, U_, L_, H_, C_, L_)
im1 = np.ones((6, 3, W, W))
im2 = np.ones((6, 3, W, W))
for face_id, (glyph, outline) in enumerate(zip(face_glyphs, face_outlines)):
    im1[face_id] = 1 - (1 - glyph)[None, :, :] * ink
    im2[face_id] = 1 - (1 - outline)[None, :, :] * ink

# im3 / im4 keep the untouched images; in im1 / im2 face f is moved by 2 cells along
# axis 2, and the first two cells along that axis are blanked.
im3, im4 = im1.copy(), im2.copy()
for shifted, untouched in ((im1, im3), (im2, im4)):
    shifted[f] = 1
    shifted[f, :, 2:] = untouched[f, :, :-2]


# Bottom-left panel: im1 (face f translated) on faces 0, 1 and 5, viewed at azimuth -20.
ax = fig.add_subplot(2, 2, 3, projection='3d')

for i in [0, 1, 5]:
    plot_cube_sphere(ax, w=W, project=True, im=im1[i: i + 1], edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75, face=i)

# Disabled: white versions of the remaining faces.
# for i in [2, 3, 4]:
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 3, W, W)), edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75, face=i)

ax.set_xlim(-AX_LIM, AX_LIM)
ax.set_ylim(-AX_LIM, AX_LIM)
ax.set_zlim(-AX_LIM, AX_LIM)
ax.view_init(15, -20)
# ax.view_init(115, 90)


# Bottom-right panel: untranslated im3 on faces 0, 3 and 5, viewed at azimuth 70.
ax = fig.add_subplot(2, 2, 4, projection='3d')

for i in [0, 3, 5]:
    plot_cube_sphere(ax, w=W, project=True, im=im3[i: i + 1], edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75, face=i)

# Disabled: white versions of the remaining faces.
# for i in [1, 2, 4]:
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 3, W, W)), edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75, face=i)

ax.set_xlim(-AX_LIM, AX_LIM)
ax.set_ylim(-AX_LIM, AX_LIM)
ax.set_zlim(-AX_LIM, AX_LIM)
# ax.view_init(115, 0)
ax.view_init(15, 70)


# Top-left panel: outline image im2 on faces 0, 1 and 5, plus three translucent
# near-black layers (RGBA 0.05, alpha=ALPHA) at pull levels 2-4 -- presumably depicting
# stacked feature channels. Viewed at azimuth -20.
ax = fig.add_subplot(2, 2, 1, projection='3d')

for i in [0, 1, 5]:
    plot_cube_sphere(ax, w=W, project=True, im=im2[i: i + 1], edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75, face=i)
    plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA, flag=False, pull=2 * PULL - PULL * 0.75, face=i)
    plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA, flag=False, pull=3 * PULL - PULL * 0.75, face=i)
    plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA, flag=False, pull=4 * PULL - PULL * 0.75, face=i)

# Disabled: white + fainter layers for the remaining faces.
# for i in [2, 3, 4]:
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)), edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75, face=i)
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA * 0.33, flag=False, pull=2 * PULL - PULL * 0.75, face=i)
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA * 0.33, flag=False, pull=3 * PULL - PULL * 0.75, face=i)
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA * 0.33, flag=False, pull=4 * PULL - PULL * 0.75, face=i)
    
ax.set_xlim(-AX_LIM, AX_LIM)
ax.set_ylim(-AX_LIM, AX_LIM)
ax.set_zlim(-AX_LIM, AX_LIM)
# ax.view_init(115, 90)
ax.view_init(15, -20)


# Top-right panel: untranslated outline image im4 on faces 0, 3 and 5 with the same
# translucent layers, viewed at azimuth 70.
ax = fig.add_subplot(2, 2, 2, projection='3d')

for i in [0, 3, 5]:
    # For face 0 the first two pull levels are swapped: the image sits at level 2 and the
    # first translucent layer at level 1 (face 2 is listed too but is not drawn here).
    plot_cube_sphere(ax, w=W, project=True, im=im4[i: i + 1], edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75 if i not in [0, 2] else 2 * PULL - PULL * 0.75, face=i)
    plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA, flag=False, pull=2 * PULL - PULL * 0.75 if i not in [0, 2] else 1 * PULL - PULL * 0.75, face=i)
    plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA, flag=False, pull=3 * PULL - PULL * 0.75, face=i)
    plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA, flag=False, pull=4 * PULL - PULL * 0.75, face=i)

# Disabled: white + fainter layers for the remaining faces.
# for i in [1, 2, 4]:
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)), edgecolor='black', linewidth=0.5, alpha=1.0, flag=False, pull=1 * PULL - PULL * 0.75 if i not in [0, 2] else 2 * PULL - PULL * 0.75, face=i)
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA * 0.33, flag=False, pull=2 * PULL - PULL * 0.75 if i not in [0, 2] else 1 * PULL - PULL * 0.75, face=i)
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA * 0.33, flag=False, pull=3 * PULL - PULL * 0.75, face=i)
#     plot_cube_sphere(ax, w=W, project=True, im=np.ones((6, 4, W, W)) * 0.05, linewidth=0, alpha=ALPHA * 0.33, flag=False, pull=4 * PULL - PULL * 0.75, face=i)

ax.set_xlim(-AX_LIM, AX_LIM)
ax.set_ylim(-AX_LIM, AX_LIM)
ax.set_zlim(-AX_LIM, AX_LIM)
# ax.view_init(115, 0)
ax.view_init(15, 70)

# Save the 2 x 2 figure to champion.pdf, then display it.
fig.savefig('champion.pdf', bbox_inches='tight')
plt.show()

## Equivariant Padding

In [None]:
# Marker codes (matplotlib marker strings, presumably used by the padding figure below).
mrklist = [
    "3", "d", ",", "2", "8", "s", "|", "^", "<", "4", "H",
    "+", "o", "v", "p", "D", ">", "1", ".", "_", "*", "h",
]


def plot_grid(ax, center, w, im=None, project=False, xy_dir=None, **kwargs):
    """Draw a subdivided square as a grid of quads on a 3D axes.

    Without ``xy_dir`` the square is the axis-aligned unit square (side 1)
    centred at ``center`` in its z-plane; with ``xy_dir = (x_dir, y_dir)`` its
    corners are ``center +/- x_dir +/- y_dir``. The four corners are refined
    into a point lattice by ``refine_square`` and every lattice cell becomes
    one quad of a single Poly3DCollection added to ``ax``.

    Args:
        ax: 3D matplotlib axes to draw on.
        center: length-3 array, centre of the square.
        w: subdivision parameter forwarded to ``refine_square``.
        im: optional colour image; grid cell (i, j) is painted with
            ``tuple(im[0, :, i, j])``. When given, every cell whose colour is
            neither white (1, 1, 1) nor grey (0.75, 0.75, 0.75) also gets a
            marker from ``mrklist`` (cycled if there are more colours than
            markers) shaded with the 'Greys' colormap by its colour index.
        project: forwarded to ``refine_square``.
        xy_dir: optional ``(x_dir, y_dir)`` half-extent vectors of the square.
        **kwargs: ``edgecolor`` (default 'k'), ``alpha`` (default 0.8),
            ``facecolor`` (default 'white'; used only when ``im`` is None),
            ``linewidth`` (default 1) and ``color_to_idx`` (dict mapping a
            colour tuple to an int index; defaults to the lexicographically
            sorted distinct colours of this grid).

    Side effects: hides the axes of ``ax`` and sets its view to
    elevation 90, azimuth 0 (top-down).
    """
    if xy_dir is None:
        ul = center + np.array([-0.5, +0.5, 0], dtype=float)
        ur = center + np.array([+0.5, +0.5, 0], dtype=float)
        bl = center + np.array([-0.5, -0.5, 0], dtype=float)
        br = center + np.array([+0.5, -0.5, 0], dtype=float)
    else:
        x_dir, y_dir = xy_dir
        ul = center + (-x_dir + y_dir)
        ur = center + (+x_dir + y_dir)
        bl = center + (-x_dir - y_dir)
        br = center + (+x_dir - y_dir)

    # Leading axis of size 1 so the loops index (grid, row, col, xyz).
    # NOTE(review): assumes refine_square returns a (rows, cols, 3) lattice.
    pts = refine_square(ul, ur, bl, br, w, project=project)[None, ...]

    default_face = kwargs.get('facecolor', 'white')
    poly = []
    color_list = []
    for g in range(pts.shape[0]):
        for i in range(pts.shape[1] - 1):
            for j in range(pts.shape[2] - 1):
                poly.append([pts[g, i, j], pts[g, i + 1, j], pts[g, i + 1, j + 1], pts[g, i, j + 1]])
                color_list.append(tuple(im[g, :, i, j]) if im is not None else default_face)

    tri = a3.art3d.Poly3DCollection(poly)
    tri.set_edgecolor(kwargs.get('edgecolor', 'k'))
    tri.set_alpha(kwargs.get('alpha', 0.8))
    # Per-cell colours (already equal to `facecolor` when no image is given).
    tri.set_facecolor(color_list)
    tri.set_linewidth(kwargs.get('linewidth', 1))
    ax.add_collection3d(tri)

    if im is not None:
        # Only build the default mapping when the caller did not supply one.
        color_to_idx = kwargs.get('color_to_idx')
        if color_to_idx is None:
            color_to_idx = {c: idx for idx, c in enumerate(sorted(set(color_list)))}
        cm = plt.get_cmap('Greys')
        c_norm = colors.Normalize(vmin=0, vmax=len(color_to_idx) - 1)
        scalar_map = cmx.ScalarMappable(norm=c_norm, cmap=cm)
        for quad, color in zip(poly, color_list):
            # White (empty) and 0.75-grey (masked) cells carry no marker.
            if color == (1, 1, 1) or color == (0.75, 0.75, 0.75):
                continue
            color_idx = color_to_idx[color]
            # Cycle markers so more distinct colours than markers cannot IndexError.
            marker = mrklist[color_idx % len(mrklist)]
            cx, cy, cz = np.mean(quad, axis=0)
            ax.scatter(cx, cy, cz, marker=marker, s=20, color=scalar_map.to_rgba(color_idx), alpha=1)

    # Act on this axes explicitly; plt.axis would target whatever axes is current.
    ax.set_axis_off()
    ax.view_init(90, 0)

In [None]:
# Equivariant padding on a 13x13 grid (the plot shows the inner 11x11).
# im1..im4 are the four neighbouring faces, each carrying four coloured 3x3
# patches; im0 (the centre face) receives a one-pixel shifted copy of every
# neighbour. Afterwards every cell outside the padding strip is set to 0.25,
# which plot_grid receives as 1 - 0.25 = 0.75 and draws without a marker.
fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')
cmap = 'rainbow'

# (rows, cols) of the four 3x3 patches, in colour order within one face.
pad_patches = [
    (slice(2, 5), slice(2, 5)),
    (slice(2, 5), slice(-5, -2)),
    (slice(-5, -2), slice(2, 5)),
    (slice(-5, -2), slice(-5, -2)),
]

# Neighbour number n (0-based) uses rainbow colours 4n .. 4n + 3 out of 16.
im1, im2, im3, im4 = (np.zeros((1, 3, 13, 13)) for _ in range(4))
for pad_face, pad_img in enumerate((im1, im2, im3, im4)):
    for pad_patch, (pad_rows, pad_cols) in enumerate(pad_patches):
        pad_img[0, :, pad_rows, pad_cols] = 1 - get_color(cmap, 4 * pad_face + pad_patch, 16)[:, None, None]

# Centre face: sum of the four neighbours, each shifted by one pixel.
# (Addition order matters: the colour tuples below are matched exactly.)
im0 = np.zeros((1, 3, 13, 13))
im0[:, :, :, :-1] += im1[:, :, :, 1:]
im0[:, :, 1:, :] += im2[:, :, :-1, :]
im0[:, :, :, 1:] += im3[:, :, :, :-1]
im0[:, :, :-1, :] += im4[:, :, 1:, :]

# Grey out (0.25) everything but the padding strip: whole patches on im0;
# on im1 all but the last column, on im2 all but the first row, on im3 all
# but the first column and on im4 all but the last row of each patch.
for pad_rows, pad_cols in pad_patches:
    im0[0, :, pad_rows, pad_cols] = 0.25
    im1[0, :, pad_rows, slice(pad_cols.start, pad_cols.stop - 1)] = 0.25
    im2[0, :, slice(pad_rows.start + 1, pad_rows.stop), pad_cols] = 0.25
    im3[0, :, pad_rows, slice(pad_cols.start + 1, pad_cols.stop)] = 0.25
    im4[0, :, slice(pad_rows.start, pad_rows.stop - 1), pad_cols] = 0.25

# Shared colour -> index map over all five faces, ordered by luma
# (0.3 R + 0.59 G + 0.11 B), so markers/greys are consistent across faces.
all_faces = 1 - np.stack((im0, im1, im2, im3, im4), axis=-1)
color_list = [tuple(rgb) for rgb in all_faces.reshape((3, -1)).T]
color_to_idx = {
    c: rank
    for rank, c in enumerate(sorted(set(color_list), key=lambda c: 0.3 * c[0] + 0.59 * c[1] + 0.11 * c[2]))
}

# Centre face at the origin, neighbours left / below / right / above it.
face_centers = (
    np.array([0, 0, 0]),
    np.array([-1, 0, 0]),
    np.array([0, -1, 0]),
    np.array([+1, 0, 0]),
    np.array([0, +1, 0]),
)
for face_center, face_img in zip(face_centers, (im0, im1, im2, im3, im4)):
    plot_grid(ax, face_center, 11, (1 - face_img)[:, :, 1:-1, 1:-1], edgecolor='white', color_to_idx=color_to_idx)
# One-cell outline per face at alpha=0.
for face_center in face_centers:
    plot_grid(ax, face_center, 1, alpha=0)

ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)

plt.show()

In [None]:
# Equivariant padding, minimal 5x5 version: each neighbouring face (im1..im4)
# carries one coloured 3x3 interior patch, im0 (the centre face) receives a
# one-pixel shifted copy of every neighbour, and all non-padding cells are
# then set to 0.25 (drawn by plot_grid as marker-less 0.75 grey).
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection='3d')
cmap = 'rainbow'

# Neighbour number n (0-based) gets rainbow colour n out of 4.
im1, im2, im3, im4 = (np.zeros((1, 3, 5, 5)) for _ in range(4))
for pad_face, pad_img in enumerate((im1, im2, im3, im4)):
    pad_img[0, :, 1:4, 1:4] = 1 - get_color(cmap, pad_face, 4)[:, None, None]

# Centre face: sum of the four neighbours, each shifted by one pixel.
# (Addition order matters: the colour tuples below are matched exactly.)
im0 = np.zeros((1, 3, 5, 5))
im0[:, :, :, :-1] += im1[:, :, :, 1:]
im0[:, :, 1:, :] += im2[:, :, :-1, :]
im0[:, :, :, 1:] += im3[:, :, :, :-1]
im0[:, :, :-1, :] += im4[:, :, 1:, :]

# Grey out everything but the padding strip: the whole interior of im0, and
# all but the last column / first row / first column / last row of the
# interior of im1 / im2 / im3 / im4 respectively.
im0[0, :, 1:4, 1:4] = 0.25
im1[0, :, 1:4, 1:3] = 0.25
im2[0, :, 2:4, 1:4] = 0.25
im3[0, :, 1:4, 2:4] = 0.25
im4[0, :, 1:3, 1:4] = 0.25

# Shared colour -> index map over all five faces, ordered by luma.
all_faces = 1 - np.stack((im0, im1, im2, im3, im4), axis=-1)
color_list = [tuple(rgb) for rgb in all_faces.reshape((3, -1)).T]
color_to_idx = {
    c: rank
    for rank, c in enumerate(sorted(set(color_list), key=lambda c: 0.3 * c[0] + 0.59 * c[1] + 0.11 * c[2]))
}

# Centre face at the origin, neighbours left / below / right / above it.
face_centers = (
    np.array([0, 0, 0]),
    np.array([-1, 0, 0]),
    np.array([0, -1, 0]),
    np.array([+1, 0, 0]),
    np.array([0, +1, 0]),
)
for face_center, face_img in zip(face_centers, (im0, im1, im2, im3, im4)):
    plot_grid(ax, face_center, 5, (1 - face_img), edgecolor='white', color_to_idx=color_to_idx)
# One-cell outline per face at alpha=0.
for face_center in face_centers:
    plot_grid(ax, face_center, 1, alpha=0)

ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)

plt.show()

## Padding Graph

In [None]:
import numpy as np
from utils import rotate_3d


def sym(x, axis, tol=1e-6):
    """Permutation induced on the point set ``x`` by a 90-degree rotation.

    ``x`` is rotated by pi/2 about ``axis`` with ``rotate_3d``. For each
    original point ``x[i]`` the result holds the index ``j`` of the FIRST
    rotated point ``y[j]`` within Euclidean distance ``tol`` of it.

    NOTE: this redefinition shadows the earlier ``sym(x, around_z)`` helper.

    Args:
        x: array of points, assumed shape (n, 3) -- TODO confirm no other
            rank is ever passed.
        axis: rotation axis forwarded to ``rotate_3d`` (0, 1 or 2 here).
        tol: distance below which two points count as equal (default 1e-6).

    Returns:
        list of length n holding int indices, or ``None`` at positions where
        no rotated point lies within ``tol`` of ``x[i]``.
    """
    x_arr = np.asarray(x)
    n = x_arr.shape[0]
    # Only the first n rotated points are candidates, as in a scan over range(n).
    y = np.asarray(rotate_3d(x, np.pi / 2, axis=axis))[:n]
    # dist[i, j] = ||x[i] - y[j]||_2 for all pairs at once (no Python double loop).
    dist = np.linalg.norm(x_arr[:, None, :] - y[None, :, :], axis=-1)
    matches = dist < tol
    # argmax of a boolean row is its first True, i.e. "first matching j wins".
    return [int(np.argmax(row)) if row.any() else None for row in matches]

In [None]:
def draw(ax, G, pos, node_color, node_size, alpha=1.0):
    """Draw graph ``G`` as a 3D node-link diagram on ``ax``.

    Args:
        ax: 3D matplotlib axes.
        G: graph whose ``edges()`` are drawn; every edge must carry a
            ``'color'`` attribute, read as ``G[u][v]['color']``.
        pos: mapping node -> sequence whose first three entries are x, y, z.
            Every node in ``pos`` is scattered, and every edge endpoint must
            be a key of ``pos``.
        node_color: colour passed as ``c`` to ``ax.scatter`` for each node.
        node_size: marker size passed as ``s`` to ``ax.scatter``.
        alpha: opacity used for both nodes and edges.

    Side effects: fixes all three axis limits to [-0.6, 0.6], hides the axes
    and sets the view to elevation 10, azimuth 10.
    """
    ax.set_xlim(-0.6, +0.6)
    ax.set_ylim(-0.6, +0.6)
    ax.set_zlim(-0.6, +0.6)
    for value in pos.values():
        ax.scatter(value[0], value[1], value[2], s=node_size, c=node_color, alpha=alpha)
    # Each edge is drawn exactly once (not once per node), so overlapping
    # copies cannot compound the opacity when alpha < 1.
    for u, v in G.edges():
        xs = np.array((pos[u][0], pos[v][0]))
        ys = np.array((pos[u][1], pos[v][1]))
        zs = np.array((pos[u][2], pos[v][2]))
        ax.plot(xs, ys, zs, c=G[u][v]['color'], alpha=alpha)
    # Hide the axes of this subplot; plt.axis would act on the current axes.
    ax.set_axis_off()
    ax.view_init(10, 10)

In [None]:
import networkx as nx
from poly_sphere import *
%matplotlib inline


# Padding graph on the cube surface: nodes are the sample points; point i is
# joined to the point j that a 90-degree rotation about x (axis 0), y (1) or
# z (2) carries it onto, but only when i and j lie in different groups of
# four (presumably different faces). Edges are coloured by rotation axis.
x = get_sampling_grid('cube', 1, center=False).reshape((-1, 4, 3))
# Pull each group of four points half-way towards its own centroid.
x = (0.5 * x + 0.5 * np.mean(x, axis=1, keepdims=True)).reshape((-1, 3))
perm_x, perm_y, perm_z = (sym(x, axis=rot_axis) for rot_axis in range(3))

# Points are stored four per group, so i // 4 is the group (face) id.
same_face = lambda i, j: (i // 4) == (j // 4)
pos = dict(enumerate(x))
G = nx.Graph()
cmap = 'rainbow'
color_list = [get_color(cmap, i, 3) for i in range(3)]
for k, perm in enumerate((perm_x, perm_y, perm_z)):
    for i, j in enumerate(perm):
        if same_face(i, j):
            continue
        G.add_edge(i, j, color=color_list[k])

fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')
draw(ax, G, pos, 'k', 50, alpha=1)
plot_cube_sphere(ax, w=1, project=False, linewidth=1, alpha=0.25, edgecolor='k', facecolor=(1, 1, 1, 0.25))

plt.show()

In [None]:
import networkx as nx
from poly_sphere import *
%matplotlib inline


# Padding graph between face centres: one node per group of four sample
# points (presumably one per cube face), joined to the node that a 90-degree
# rotation about x (axis 0), y (1) or z (2) carries it onto. Edges are
# coloured by rotation axis.
x = get_sampling_grid('cube', 1, center=False)
# Collapse each group of four sample points into its centroid -> (groups, 3).
x = np.mean(x.reshape((-1, 4, 3)), axis=1)
perm_x = sym(x, axis=0)
perm_y = sym(x, axis=1)
perm_z = sym(x, axis=2)

pos = {i: x[i] for i in range(x.shape[0])}
G = nx.Graph()
cmap = 'rainbow'
color_list = [get_color(cmap, i, 3) for i in range(3)]
for k, perm in enumerate([perm_x, perm_y, perm_z]):
    for i, j in enumerate(perm):
        # A centre lying on the rotation axis maps onto itself (e.g. the two
        # faces the axis pierces, assuming the cube is origin-centred); skip
        # that self-loop, which would only add a zero-length segment.
        if i != j:
            G.add_edge(i, j, color=color_list[k])
fig = plt.figure(figsize=(14, 14))
ax = fig.add_subplot(111, projection='3d')
draw(ax, G, pos, 'k', 50, alpha=1)
plot_cube_sphere(ax, w=1, project=False, linewidth=1, alpha=0.25, edgecolor='k', facecolor=(1, 1, 1, 0.25))

plt.show()