In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive, FloatSlider, HBox, VBox, Dropdown, Output, Label

#gates
def xor(inputs):
    return np.array([-1 if (x == 1 and y == -1) or (x == -1 and y == 1) else 1 for x, y in inputs])

def and_gate(inputs):
    return np.array([1 if (x == 1 and y == 1) else -1 for x, y in inputs])

def nand_gate(inputs):
    return np.array([-1 if (x == 1 and y == 1) else 1 for x, y in inputs])

def or_gate(inputs):
    return np.array([1 if (x == 1 or y == 1) else -1 for x, y in inputs])

# 2-layer, relu
def nn(inputs, weights_layer1, weights_layer2):
    def relu(x):
        return np.maximum(0, x)
    layer1_output = relu(np.dot(inputs, weights_layer1))
    output = np.dot(layer1_output, weights_layer2)
    return output

#plot stuff
def plot_decision_boundary_colored(weights_layer1, weights_layer2, logic_gate, ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=(5, 5))
    ax.clear()

    x_min, x_max = -2, 2
    y_min, y_max = -2, 2
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 1000), np.linspace(y_min, y_max, 1000))
    grid = np.c_[xx.ravel(), yy.ravel()]
    
    # compute all possible outputs
    output = nn(grid, weights_layer1, weights_layer2)
    
    # normalize and plot
    normalized_output = (output - output.min()) / (output.max() - output.min())
    normalized_output = normalized_output.reshape(xx.shape)
    ax.contourf(xx, yy, normalized_output, alpha=0.7, cmap='RdBu', levels=100)
    logic_gate_output = logic_gate(inputs)
    for i, input_data in enumerate(inputs):
        if logic_gate_output[i] == 1:
            ax.scatter(input_data[0], input_data[1], color='blue', edgecolor='k')
        else:
            ax.scatter(input_data[0], input_data[1], color='red', edgecolor='k')
    
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Decision Boundary')
    
    return ax

def interactive_plot(w11, w12, w21, w22, w31, w32, gate):
    weights_layer1 = np.array([[w11, w12], [w21, w22]])
    weights_layer2 = np.array([w31, w32])
    
    logic_gate = {'XOR': xor, 'AND': and_gate, 'NAND': nand_gate, 'OR': or_gate}[gate]
    
    fig, ax = plt.subplots(figsize=(5, 5))
    plot_decision_boundary_colored(weights_layer1, weights_layer2, logic_gate, ax=ax)
    plt.show()
    
    success = check_success(weights_layer1, weights_layer2, logic_gate)
    success_label.value = "Success!" if success else "Try again..."

#win condition
def check_success(weights_layer1, weights_layer2, logic_gate):
    outputs = np.where(nn(inputs, weights_layer1, weights_layer2) == 0, 1, -1)
    expected_outputs = logic_gate(inputs)
    success = np.all(outputs == expected_outputs)
    return success

# all possible (discrete) inputs. since we dont use bias we need data to be normalized, so -1 instead of 0
inputs = np.array([[-1, -1], [-1, 1], [1, -1], [1, 1]])

#GUI
w11_slider = FloatSlider(value=-1, min=-1, max=1, step=0.01, description='w11')
w11_slider.value = np.random.randn()
w12_slider = FloatSlider(value=1, min=-1, max=1, step=0.01, description='w12')
w12_slider.value = np.random.randn()
w21_slider = FloatSlider(value=1, min=-1, max=1, step=0.01, description='w21')
w21_slider.value = np.random.randn()
w22_slider = FloatSlider(value=-1, min=-1, max=1, step=0.01, description='w22')
w22_slider.value = np.random.randn()
w31_slider = FloatSlider(value=1, min=-1, max=1, step=0.01, description='w31')
w31_slider.value = np.random.randn()
w32_slider = FloatSlider(value=1, min=-1, max=1, step=0.01, description='w32')
w32_slider.value = np.random.randn()

gate_selector = Dropdown(
    options=['XOR', 'AND', 'NAND', 'OR'],
    value='XOR',
    description='Gate:',
)

success_label = Label(value="")

interactive_plot = interactive(interactive_plot, w11=w11_slider, w12=w12_slider, w21=w21_slider, w22=w22_slider, w31=w31_slider, w32=w32_slider, gate=gate_selector)
widgets = VBox([interactive_plot, HBox([w11_slider, w12_slider, w21_slider, w22_slider, w31_slider, w32_slider]), gate_selector, success_label])

display(widgets)


VBox(children=(interactive(children=(FloatSlider(value=0.5014701238550621, description='w11', max=1.0, min=-1.…