In [1]:
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.datasets import make_classification
from typing import Tuple, List

import plotly.graph_objects as go
from plotly.figure_factory import create_trisurf

In [2]:
def stable_softmax(z):
    """Softmax applied independently to every row of a 2-D array of scores.

    Subtracting each row's maximum before exponentiating leaves the result
    mathematically unchanged but keeps np.exp from overflowing on large scores.
    Returns an array of the same shape as z whose rows sum to 1.
    """
    exps = np.exp(z - z.max(axis=1, keepdims=True))
    normalizer = exps.sum(axis=1, keepdims=True)
    return exps / normalizer

def d_theta(phi_x, y_hat, Y):
    """Average gradient phi_x^T (y_hat - Y) / n_samples.

    When y_hat = softmax(phi_x @ theta) and Y is one-hot (as in this notebook),
    this is the gradient of the mean cross-entropy loss with respect to theta.
    phi_x: (n_samples, n_features) design matrix.
    y_hat, Y: (n_samples, n_classes) predictions and targets.
    Returns an (n_features, n_classes) array.
    """
    scale = 1 / len(phi_x)
    residual = y_hat - Y
    return (scale * phi_x.T) @ residual

# Creando los datos

In [36]:
# Plot limits (shared by every 3-D scene in this notebook).
bounds = [-3, 3]

# Number of classes, used both to generate the data and to one-hot encode it.
n_classes = 3

# Synthetic 3-feature dataset. A fixed random_state makes the dataset (and so
# every later result) reproducible under "Restart & Run All".
X, Y = make_classification(n_samples=100, n_features=3, n_redundant=0, n_informative=2,
                           n_clusters_per_class=1, n_classes=n_classes, class_sep=3,
                           random_state=0)
# Standardize every feature to zero mean and unit variance.
X = (X - np.mean(X, axis=0, keepdims=True)) / np.std(X, axis=0, keepdims=True)
# Design matrix: a leading column of ones (bias term) followed by the features.
phi_x = np.hstack((np.ones((len(X), 1)), X))

# One-hot targets: row i is the row of the identity matrix selected by label Y[i].
Y_one_hot = np.eye(n_classes)[Y]

# Square scene ('cube' aspect) with equal axis ranges so all axes look the same size.
layout = go.Layout(
    width=600,
    height=600,
    scene=dict(aspectmode='cube',
               xaxis_range=bounds,
               yaxis_range=bounds,
               zaxis_range=bounds))

# Points of class 0 get a white outline, all others a black one.
border_colors = ['white' if y == 0 else 'black' for y in Y]
fig = go.Figure(data=[go.Scatter3d(x=X[:, 0], y=X[:, 1], z=X[:, 2],
                                   mode='markers',
                                   marker={'line': {'width': 1, 'color': border_colors},
                                           'size': 6, 'color': Y})],
                layout=layout)

fig.show()

# Parámetros

In [37]:
alpha = 0.1  # gradient-descent learning rate

# Seeded generator so the initial weights (and the whole training run) are reproducible.
rng = np.random.default_rng(0)
# One column of weights per class; row 0 multiplies the bias column of phi_x,
# rows 1-3 the three features.
theta = rng.standard_normal((4, 3))
accuracy = -1  # sentinel "not measured yet" so the training loop starts

k = 0            # iteration counter
max_iters = 301  # iteration budget (one stored plane set, i.e. one animation frame, per iteration)

Despejando $x_3$ del plano $\theta_0 + x_1\theta_1 + x_2\theta_2 + x_3\theta_3 = 0$ de cada clase (los puntos donde la puntuación lineal de esa clase vale 0)

se tiene $x_3 = \frac{-\theta_0-x_1\theta_1-x_2\theta_2}{\theta_3}$ (válido si $\theta_3 \neq 0$)

In [38]:
# Gradient-descent training. Uses theta, alpha, accuracy, k and max_iters from the
# parameters cell: re-running this cell without re-running that one continues from
# the current theta/k instead of starting over.
xs = []   # x1 coordinates of the plane corners, one row per iteration
ys = []   # x2 coordinates of the plane corners, one row per iteration
z1s = []  # x3 heights of class 0's zero-level plane at those corners
z2s = []  # same for class 1
z3s = []  # same for class 2

# The four corners of the square bounds x bounds in the (x1, x2) plane.
corner_x1 = np.array([bounds[0], bounds[0], bounds[1], bounds[1]])
corner_x2 = np.array([bounds[0], bounds[1], bounds[0], bounds[1]])

grad = np.ones_like(theta)
# Stop when the largest gradient entry *in absolute value* is tiny, when the
# training accuracy exceeds 0.99, or when the iteration budget is spent.
while np.abs(grad).max() > 1e-4 and accuracy <= 0.99 and k < max_iters:
    z = phi_x @ theta
    y_hat = stable_softmax(z)

    grad = d_theta(phi_x, y_hat, Y_one_hot)
    theta = theta - alpha*grad

    # Measure accuracy with the *updated* theta so it matches the planes stored
    # below (the argmax of the logits equals the argmax of the softmax).
    accuracy = accuracy_score(Y, (phi_x @ theta).argmax(axis=1))
    k += 1

    # Solve theta_0 + x1*theta_1 + x2*theta_2 + x3*theta_3 = 0 for x3 at every corner,
    # for all classes at once: plane_z[c, j] is the height of class j's plane at corner c.
    # A theta_3 close to 0 means a near-vertical plane and very large heights.
    plane_z = -(theta[0] + np.outer(corner_x1, theta[1]) + np.outer(corner_x2, theta[2])) / theta[3]

    xs.append(corner_x1)
    ys.append(corner_x2)
    z1s.append(plane_z[:, 0])
    z2s.append(plane_z[:, 1])
    z3s.append(plane_z[:, 2])

xs = np.asarray(xs)
ys = np.asarray(ys)
z1s = np.asarray(z1s)
z2s = np.asarray(z2s)
z3s = np.asarray(z3s)

print(accuracy)

1.0


# Animación

In [6]:
from scipy.spatial import Delaunay

In [39]:
# Which training iteration to draw; -1 selects the last one (the final theta).
idx = -1

# Triangulate the four (x1, x2) corners so create_trisurf can draw each plane as a mesh.
corner_points = np.hstack([xs[idx].reshape(-1, 1), ys[idx].reshape(-1, 1)])
tri = Delaunay(corner_points)
simplices = tri.simplices

data_points = go.Scatter3d(x=X[:, 0], y=X[:, 1], z=X[:, 2],
                           mode='markers',
                           marker={'line': {'width': 1, 'color': border_colors},
                                   'size': 6, 'color': Y})
fig = go.Figure(data=[data_points], layout=layout)

# One plane per class, each drawn in its own colour.
for z_plane, plane_color in ((z1s[idx], '#9467bd'),
                             (z2s[idx], '#e377c2'),
                             (z3s[idx], '#bcbd22')):
    surface = create_trisurf(x=xs[idx], y=ys[idx], z=z_plane, simplices=simplices,
                             show_colorbar=False, colormap=plane_color, plot_edges=False)
    fig.add_trace(surface.data[0])

fig.show()

In [41]:
# Triangulate the square once: every iteration stores the same four (x1, x2)
# corners, so the first iteration's triangulation is valid for all frames.
tri = Delaunay(np.hstack([xs[0].reshape(-1, 1), ys[0].reshape(-1, 1)]))
simplices = tri.simplices

# Animation settings: same square scene as the static plots, plus a Play button.
layout2 = go.Layout(
    width=600,
    height=600,
    scene=dict(aspectmode='cube',
               xaxis_range=bounds,
               yaxis_range=bounds,
               zaxis_range=bounds),
    updatemenus=[dict(buttons=[dict(args=[None, {"frame": {"duration": 10,
                                                           "redraw": True},
                                                 "fromcurrent": True,
                                                 "transition": {"duration": 0}}],
                                     label="Play",
                                     method="animate")],
                      type='buttons',
                      showactive=True)])

splot = go.Scatter3d(x=X[:, 0], y=X[:, 1], z=X[:, 2], mode='markers',
                     marker={'line': {'width': 1, 'color': border_colors},
                             'size': 6, 'color': Y}, showlegend=False)

# One colour per class plane (classes 0, 1 and 2).
plane_colors = ['#9467bd', '#e377c2', '#bcbd22']


def plane_meshes(x, y, heights):
    """Return one create_trisurf mesh trace per class plane for a single iteration."""
    return [create_trisurf(x=x, y=y, z=z, simplices=simplices, show_colorbar=False,
                           colormap=color, plot_edges=False).data[0]
            for z, color in zip(heights, plane_colors)]


# Trace 0 is the data scatter, which no frame touches. Traces 1-3 are the class
# planes, which every frame replaces explicitly via traces=[1, 2, 3]. The figure
# starts on the first iteration's planes so something is shown before pressing Play.
fig = go.Figure(data=[splot] + plane_meshes(xs[0], ys[0], (z1s[0], z2s[0], z3s[0])),
                layout=layout2,
                frames=[go.Frame(data=plane_meshes(x, y, (z1, z2, z3)), traces=[1, 2, 3])
                        for x, y, z1, z2, z3 in zip(xs, ys, z1s, z2s, z3s)])

fig.show()