In [1]:
import matplotlib
matplotlib.use('nbagg')
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from Brain import Neuron, Net, GMM
from scipy.stats import multivariate_normal
from matplotlib.lines import Line2D
matplotlib.rcParams.update({'font.size': 10})
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import Ellipse

In [2]:
# Data-generating distribution: an equal-weight mixture of two 2-D components,
# each given as [mean, scale]. The scale is a bare float (presumably an isotropic
# std or variance -- TODO confirm against Brain.GMM).
# The component list is ragged (a list next to a float), so it must be an object
# array: NumPy >= 1.24 raises ValueError for ragged input without dtype=object,
# and older NumPy only produced this same object array after a deprecation warning.
p = GMM([0.5,0.5], np.array([[[0.39,0.5],.08],[[0.61,0.5],0.07]], dtype=object))

# Layer 1: three 1-D neurons. The training loop feeds q1 with x[1] and q2, q3 with x[0].
# NOTE(review): the positional args look like (shape, initial centre, initial
# covariance) -- the loop reads centres via get_weights() and decomposes
# get_bias() like a covariance. Confirm against Brain.Neuron.
q1 = Neuron([1,1], np.array([0.4]), np.eye(1)*0.0007, decay=0.03, lr_decay=0.005)
q2 = Neuron([1,1], np.array([0.2]), np.eye(1)*np.power(0.035,2), decay=0.02, lr_decay=0.004)
q3 = Neuron([1,1], np.array([0.58]), np.eye(1)*np.power(0.07,2), decay=0.02, lr_decay=0.004)

# Layer 2: two 3-D neurons over the point (q2(x), q3(x), q1(x)).
q4 = Neuron([1,3], np.array([0.02,0.79,0.2]), np.eye(3)*0.005, decay=0.009, lr_decay=0.0015)
q5 = Neuron([1,3], np.array([0.69,0.01,0.3]), np.eye(3)*0.005, decay=0.009, lr_decay=0.0015)

# Layer 3: a 2-D neuron over (q4, q5). The loop also passes it 2*label-1,
# presumably a +/-1 supervision signal.
q6 = Neuron([1,2], np.array([0.5,0.5]), np.eye(2)*0.01, decay=0.01, lr_decay=0.001)

In [3]:
# Draw the whole stimulus stream up front so every frame can show all samples.
num_samples = 1000
samples, labels = p.sample(num_samples)

# Evaluation grids on [0, 1] for drawing the layer-1 tuning curves.
num_grid_pts = 500
t1 = np.linspace(0.0, 1.0, num_grid_pts)   # q1 curve (vertical marginal)
t2 = np.linspace(0.0, 1.0, num_grid_pts)   # q2 / q3 curves (horizontal marginal)

# Per-step response histories appended to by the training loop.
q1_hist = []
q2_hist = []
q3_hist = []
q4_pts = []
q5_pts = []

# Figure 2 holds the 3-D view of layer-1 responses; figure 3 the layer-2 plane.
fig2 = plt.figure(2)
ax = fig2.add_subplot(111, projection='3d')
fig3 = plt.figure(3)
colors = ['orange', 'black']   # point colour, indexed by sample label

# Unit-sphere mesh (100 x 100), later scaled/rotated into each 3-D neuron's ellipsoid.
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)
sin_v = np.sin(v)[np.newaxis, :]
x_pts = np.cos(u)[:, np.newaxis] * sin_v
y_pts = np.sin(u)[:, np.newaxis] * sin_v
z_pts = np.ones((np.size(u), 1)) * np.cos(v)[np.newaxis, :]

In [4]:
# Training / animation loop: feed one sample per step through the network and
# save one JPEG frame per figure per step:
#   figs/2d/     - data, current stimulus and the q1/q2/q3 tuning curves
#   figs/3d/     - layer-1 response space with the q4/q5 ellipsoids
#   figs/output/ - (q4, q5) plane with the q6 ellipse
import os

for _frame_dir in ("figs/2d", "figs/3d", "figs/output"):
    # savefig does not create missing folders
    os.makedirs(_frame_dir, exist_ok=True)


def _frame_path(subdir, step):
    """Return the zero-padded frame path ``figs/<subdir>/figNNNN.jpg`` for ``step``."""
    return f"figs/{subdir}/fig{str(step).zfill(4)}.jpg"


def _ellipsoid_mesh(neuron):
    """Map the unit-sphere mesh (x_pts, y_pts, z_pts) onto a 3-D neuron's ellipsoid.

    ``neuron.get_bias()`` is decomposed with an SVD. Its singular values are
    treated as variances (radii = sqrt), and the sphere is scaled along and
    rotated into the principal axes, then shifted to ``neuron.get_weights()[0]``.

    Returns:
        (X, Y, Z): three arrays shaped like ``x_pts`` for ``plot_wireframe``.
    """
    _, variances, principal_axes = np.linalg.svd(neuron.get_bias())
    radii = np.sqrt(variances)
    center = neuron.get_weights()[0]
    # (n_u, n_v, 3) stack of scaled sphere points. The row-vector product
    # p @ Vt rotates every point at once instead of a Python double loop.
    scaled = np.stack([radii[0] * x_pts, radii[1] * y_pts, radii[2] * z_pts], axis=-1)
    surface = scaled @ principal_axes + center
    return surface[..., 0], surface[..., 1], surface[..., 2]


def _covariance_ellipse(center, cov):
    """Return a 1-sigma matplotlib Ellipse for the 2-D matrix ``cov`` at ``center``.

    The radii are sqrt of the singular values, matching the 3-D ellipsoids. The
    ellipse is rotated onto the first principal axis instead of being drawn
    axis-aligned.
    """
    _, variances, principal_axes = np.linalg.svd(cov)
    std = np.sqrt(variances)
    # Direction of the first principal axis, in degrees from +x.
    angle = np.degrees(np.arctan2(principal_axes[0, 1], principal_axes[0, 0]))
    return Ellipse(xy=center, width=2 * std[0], height=2 * std[1], angle=angle)


ax3 = fig3.gca()

for k in range(num_samples):
    x = np.array(samples[k])
    l = labels[k]
    seen_colors = [colors[li] for li in labels[:k+1]]

    # --- Expose the neurons to the stimulus (these calls update the neurons) ---
    q1_hist.append(q1(x[1]))
    q2_hist.append(q2(x[0]))
    q3_hist.append(q3(x[0]))
    pt_3d = np.array([q2_hist[-1], q3_hist[-1], q1_hist[-1]])
    q4_pts.append(q4(pt_3d))
    q5_pts.append(q5(pt_3d))
    q6(np.array([q4_pts[-1], q5_pts[-1]]), 2*l-1)

    # --- Frame 1: samples, current stimulus and layer-1 tuning curves ---
    sb_plt = sns.JointGrid(x=samples[:,0], y=samples[:,1], xlim=(0,1), ylim=(0,1))
    x_hist = sb_plt.ax_marg_x.hist(samples[:,0], density=True, bins=100, color='blue')
    y_hist = sb_plt.ax_marg_y.hist(samples[:,1], density=True, bins=100, color='blue', orientation='horizontal')
    x_peak = x_hist[0].max()
    y_peak = y_hist[0].max()
    sb_plt.ax_joint.scatter(samples[:,0], samples[:,1])
    sb_plt.ax_joint.scatter(x[0], x[1], c='magenta')
    sb_plt.ax_joint.plot([x[0],x[0]], [0,1], c='magenta', ls='dashed')
    sb_plt.ax_joint.plot([0,1], [x[1],x[1]], c='magenta', ls='dashed')

    sb_plt.ax_joint.set_xlabel("$x_2$")
    sb_plt.ax_joint.set_xticklabels([])
    sb_plt.ax_joint.set_ylabel("$x_1$")
    sb_plt.ax_joint.set_yticklabels([])

    # Tuning curves evaluated without learning, rescaled to the histogram peak.
    q1_vals = q1(t1.reshape(-1,1), update=False)
    q2_vals = q2(t2.reshape(-1,1), update=False)
    q3_vals = q3(t2.reshape(-1,1), update=False)
    sb_plt.ax_marg_y.plot((q1_vals/q1_vals.max())*y_peak, t1, c='y', label='$q_1$')
    sb_plt.ax_marg_y.plot([0,y_peak], [x[1],x[1]], c='magenta', lw=3)
    sb_plt.ax_marg_x.plot(t2, (q2_vals/q2_vals.max())*x_peak, c='g', label='$q_2$')
    sb_plt.ax_marg_x.plot([x[0],x[0]], [0,x_peak], c='magenta', lw=3)
    sb_plt.ax_marg_x.plot(t2, (q3_vals/q3_vals.max())*x_peak, c='r', label='$q_3$')
    sb_plt.ax_marg_x.legend()
    sb_plt.ax_marg_y.legend()
    sb_plt.fig.savefig(_frame_path("2d", k))
    # A fresh JointGrid figure is built every step; close it or they pile up in memory.
    plt.close(sb_plt.fig)

    # --- Frame 2: layer-1 response space with the q4 / q5 ellipsoids ---
    ax.scatter(q2_hist, q3_hist, q1_hist, c=seen_colors)
    q4_x, q4_y, q4_z = _ellipsoid_mesh(q4)
    q5_x, q5_y, q5_z = _ellipsoid_mesh(q5)
    ax.plot_wireframe(q4_x, q4_y, q4_z, color='magenta', rcount=10, ccount=10)
    ax.plot_wireframe(q5_x, q5_y, q5_z, color='cyan', rcount=10, ccount=10)
    ax.set_xlabel('$q_2(x)$')
    ax.set_ylabel('$q_3(x)$')
    ax.set_zlabel('$q_1(x)$')
    ax.set_xlim3d([0,1])
    ax.set_ylim3d([0,1])
    ax.set_zlim3d([0,1])
    ax.view_init(azim=(45+2*k)%360)   # slowly rotate the camera over the animation
    fig2.savefig(_frame_path("3d", k))
    ax.cla()

    # --- Frame 3: binary-classifier plane (q4, q5) with the q6 ellipse ---
    ax3.scatter(q4_pts, q5_pts, c=seen_colors)
    ax3.set_xlim([-0.1,1.1])
    ax3.set_ylim([-0.1,1.1])
    # Flatten the centre so it works whether get_weights() is shaped (2,) or (1, 2).
    q6_ellipse = _covariance_ellipse(np.ravel(q6.get_weights()), q6.get_bias())
    q6_ellipse.set_alpha(0.8)
    q6_ellipse.set_facecolor((0.2,0.56,0.44))
    ax3.add_patch(q6_ellipse)
    fig3.savefig(_frame_path("output", k))
    ax3.cla()

