In [1]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import random
import scipy
from math import pi, sqrt

# Small constant added to vector norms before dividing, so a (near-)zero
# random axis never triggers a division by zero.
EPSILON = 1e-8

# Eight 3D tetris pieces, each given as the integer (x, y, z) coordinates of
# its 4 unit cubes. The position of a piece in this list is its class label.
tetris = [[(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],  # chiral_shape_1
          [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)], # chiral_shape_2 (mirror image of chiral_shape_1 under y -> -y)
          [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],  # square
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],  # line
          [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],  # corner (three arms from the origin)
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],  # L (3-cube bar with a cube beside one end)
          [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],  # T (3-cube bar with a cube beside the middle)
          [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)]]  # S / skew (two offset 2-cube bars)

# One (4, 3) integer array per piece: rows are cubes, columns are x, y, z.
dataset = [np.array(points_) for points_ in tetris]
num_classes = len(dataset)

# Every piece must have exactly 4 landmarks in 3D for the pre-shape space below.
assert all(points_.shape == (4, 3) for points_ in dataset)

In [2]:
%%html
<center><img src='images/data.png' width=800 height=800></center>

In [3]:
from geomstats.geometry.pre_shape import PreShapeSpace
# Stack the reference pieces into one (num_classes, k_landmarks, m_ambient)
# array and read the landmark count / ambient dimension off it, so they can
# never drift from the data (for the tetris pieces: k_landmarks=4, m_ambient=3).
tetris_landmarks = np.array(tetris)
k_landmarks, m_ambient = tetris_landmarks.shape[1:]
preshape = PreShapeSpace(m_ambient=m_ambient, k_landmarks=k_landmarks)

# Project every reference piece into the pre-shape space
# (per the geomstats docs: centred and scaled to unit Frobenius norm).
tetris_preshape = preshape.projection(tetris_landmarks)

INFO: Using numpy backend


In [4]:
def random_rotation_matrix(numpy_random_state):
    """
    Build a random 3x3 rotation matrix from a random axis and a random angle.

    The axis is a standard-normal 3-vector divided by (its norm + EPSILON),
    i.e. approximately unit length; the angle is drawn uniformly from
    [0, 2*pi).
    Args:
        numpy_random_state: numpy random state object; must provide ``randn``
            and ``uniform`` (e.g. ``np.random.RandomState``).
    Returns:
        Random rotation matrix, as produced by ``rotation_matrix``.
    """
    direction = numpy_random_state.randn(3)
    # EPSILON keeps the division finite for a (near-)zero draw.
    unit_axis = direction / (np.linalg.norm(direction) + EPSILON)
    angle = 2 * np.pi * numpy_random_state.uniform(0.0, 1.0)
    return rotation_matrix(unit_axis, angle)

def rotation_matrix(axis, theta):
    """
    Rotation matrix for a rotation by ``theta`` radians about ``axis``.

    Computed as the matrix exponential of the skew-symmetric cross-product
    matrix of the rotation vector ``axis * theta``.
    Args:
        axis: 3-element array-like rotation axis; it should be unit length
            for ``theta`` to equal the rotation angle.
        theta: rotation angle in radians.
    Returns:
        (3, 3) rotation matrix.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not load
    # ``scipy.linalg`` on older SciPy releases.
    from scipy.linalg import expm

    # asarray lets callers pass lists/tuples (``list * float`` would raise).
    rotation_vector = np.asarray(axis, dtype=float) * theta
    # Row i of cross(eye(3), v) is e_i x v, which is exactly the skew matrix [v]_x.
    return expm(np.cross(np.eye(3), rotation_vector))

In [5]:
%%time

# Seeded so that Restart & Run All reproduces the reported accuracy.
rng = np.random.RandomState(0)
test_set_size = 25  # rounds; each round tests every piece once -> 25 x 8 test shapes
# predictions[label] collects the predicted class of every test copy of that piece.
predictions = [list() for _ in range(len(dataset))]

correct_predictions = 0
total_predictions = 0
for _ in range(test_set_size):
    for label, shape in enumerate(dataset):
        # Random rigid motion: rotation (acting on row vectors), then translation.
        rotation = random_rotation_matrix(rng)
        rotated_shape = np.dot(shape, rotation)
        # Draw from the same seeded rng rather than the global np.random state.
        translation = rng.uniform(low=-3., high=3., size=(1, 3))
        translated_shape = rotated_shape + translation
        # Project the translated shape to the preshape space.
        translated_shape_projection = preshape.projection(translated_shape)
        # Align every reference piece to the query, then pick the closest one
        # by summed squared coordinate difference over landmarks and axes.
        tetris_preshape_align = preshape.align(
            point=tetris_preshape, base_point=translated_shape_projection)
        squared_distances = np.sum(
            (tetris_preshape_align - translated_shape_projection) ** 2, axis=(1, 2))
        output_label = np.argmin(squared_distances)
        predictions[label].append(output_label)
        total_predictions += 1
        if output_label == label:
            correct_predictions += 1
print('Test accuracy: %f' % (float(correct_predictions) / total_predictions))

 2.12033368e-16 3.09315064e-16 1.67261431e-16 2.96189651e-17]
 7.56846051e-17 1.33827028e-16 7.52185003e-17 4.86233347e-17]
 6.83688674e-17 1.70591243e-16 7.53659695e-17 4.73381597e-17]
 7.32173240e-17 3.30373041e-16 1.13466545e-16 4.71614063e-17]
 9.16686595e-17 2.29480429e-16 9.82391922e-17 2.81641166e-17]
 5.74605492e-17 2.55351708e-16 7.12707389e-17 5.82109816e-17]
 4.75809934e-17 8.65985787e-17 7.57179146e-17 0.00000000e+00]
 4.30561563e-17 1.56107759e-16 8.86987389e-17 1.46511717e-17]


 1.34847539e-16 1.32262806e-16 4.55428847e-17 7.29336852e-17]
 8.84199033e-17 1.47275025e-16 6.85908777e-17 4.98835657e-17]
 1.14473256e-16 2.52159809e-16 1.20407908e-16 0.00000000e+00]
 7.29554669e-17 1.00908062e-16 1.21167714e-16 4.77460740e-17]
 8.70499375e-17 1.19103740e-16 7.05107913e-17 2.67424241e-17]
 9.14807292e-17 2.91083281e-16 8.16672186e-17 0.00000000e+00]
 7.44216580e-17 9.95355732e-17 4.83881854e-17 1.45301207e-17]
 5.04346228e-17 1.92394021e-16 5.35759757e-17 2.46347085e-17]


 8.32297517e-17 2.41701590e-16 5.02732336e-17 2.57444985e-17]
 1.95406886e-16 4.43723690e-16 2.89090424e-16 8.27836450e-17]
 7.24913138e-17 1.18651065e-16 1.27644156e-16 4.55297688e-17]
 4.27996197e-17 9.92318381e-17 7.86734788e-17 8.05760484e-17]
 2.61984419e-17 2.13620673e-16 1.01210439e-16 8.82042556e-17]
 1.28583268e-16 2.24031238e-16 1.13117896e-16 3.66651746e-17]
 1.13436197e-16 2.19616970e-16 8.34348068e-17 6.24071183e-17]
 1.73919799e-16 2.70213023e-16 1.07939651e-16 2.42239997e-17]
 1.06336021e-16 1.76325327e-16 7.41858293e-17 5.37114756e-17]




Test accuracy: 1.000000
CPU times: user 1.4 s, sys: 454 ms, total: 1.85 s
Wall time: 656 ms
