In [1]:
import jax.numpy as jnp
import plotly.express as px
from plotly.subplots import make_subplots
import jax
import numpy as np
from datasets import mnist
import plotly.graph_objects as go

In [2]:
train_images, train_labels, test_images, test_labels = mnist()

train_images = train_images.astype(jnp.float32)
test_images = test_images.astype(jnp.float32)

train_labels = jnp.asarray(train_labels, dtype=jnp.int32)
test_labels = jnp.asarray(test_labels, dtype=jnp.int32)

In [3]:
def visualize_images(images_tensor):
    
    img = images_tensor.reshape(-1, 28, 28)
    
    fig = px.imshow(img[:, :, :], binary_string=False, facet_col=0, facet_col_wrap=5)
    
    item_map={f'{i}':"" for i, key in enumerate(range(img.shape[0]))}
    fig.for_each_annotation(lambda a: a.update(text=item_map[a.text.split("=")[1]])) 
    
    fig.show()

In [5]:
net_parameters = {
    'w0' : np.random.randn(784, 15) * 0.1,
    'w1' : np.random.randn(15, 10) * 0.1,
}

In [6]:
def ReLU(x):
    return jnp.maximum(0,x)

def forward(parameters, x):
    x = x @ parameters['w0']
    x = ReLU(x)
    x = x @ parameters['w1']
    return x

In [7]:
def loss(parameters, x, y):
    out = forward(parameters, x)
    out = jax.nn.softmax(out)
    _loss = -(y * jnp.log(out)).sum(axis=-1).mean()
    return _loss

loss(net_parameters, test_images, test_labels)

DeviceArray(2.344119, dtype=float32)

In [8]:
(forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()

DeviceArray(0.10825, dtype=float32)

In [23]:
grad_loss = jax.grad(loss)
lr = 0.1

for epoch in range(200):

    p_grad = grad_loss(net_parameters, train_images, train_labels)

    net_parameters['w0'] -= lr * p_grad['w0']
    net_parameters['w1'] -= lr * p_grad['w1']
    
    #visualize_images(-p_grad['w0'].T)
        
    print(f"epoch {epoch}")
    print(f"validation loss: {loss(net_parameters, test_images, test_labels)}")
    print(f"train loss: {loss(net_parameters, train_images, train_labels)}")
    acc = (forward(net_parameters, train_images).argmax(axis=-1) == train_labels.argmax(axis=-1)).mean()
    print(f"accuracy: {acc}")
    print("\n")

epoch 0
validation loss: 2.325653076171875
train loss: 2.33022403717041
accuracy: 0.11901666969060898


epoch 1
validation loss: 2.309156656265259
train loss: 2.314194440841675
accuracy: 0.1305333375930786


epoch 2
validation loss: 2.294065475463867
train loss: 2.2995615005493164
accuracy: 0.141116663813591


epoch 3
validation loss: 2.2799882888793945
train loss: 2.28590989112854
accuracy: 0.15301667153835297


epoch 4
validation loss: 2.2666354179382324
train loss: 2.2729482650756836
accuracy: 0.16503334045410156


epoch 5
validation loss: 2.253762722015381
train loss: 2.260450601577759
accuracy: 0.17701667547225952


epoch 6
validation loss: 2.241189479827881
train loss: 2.2482340335845947
accuracy: 0.1901833415031433


epoch 7
validation loss: 2.228740930557251
train loss: 2.236149311065674
accuracy: 0.20288333296775818


epoch 8
validation loss: 2.216329336166382
train loss: 2.2241172790527344
accuracy: 0.21641667187213898


epoch 9
validation loss: 2.203847885131836
train loss: 

accuracy: 0.7497333288192749


epoch 78
validation loss: 0.9205339550971985
train loss: 0.9451223611831665
accuracy: 0.7518166899681091


epoch 79
validation loss: 0.910693347454071
train loss: 0.9352618455886841
accuracy: 0.753849983215332


epoch 80
validation loss: 0.9011291861534119
train loss: 0.9256772994995117
accuracy: 0.7559000253677368


epoch 81
validation loss: 0.8918290734291077
train loss: 0.9163582921028137
accuracy: 0.7578333616256714


epoch 82
validation loss: 0.882784366607666
train loss: 0.9072946906089783
accuracy: 0.7597500085830688


epoch 83
validation loss: 0.8739864230155945
train loss: 0.8984774947166443
accuracy: 0.7613666653633118


epoch 84
validation loss: 0.865426242351532
train loss: 0.8898985981941223
accuracy: 0.7631666660308838


epoch 85
validation loss: 0.8570940494537354
train loss: 0.8815484046936035
accuracy: 0.7647833228111267


epoch 86
validation loss: 0.8489822745323181
train loss: 0.8734184503555298
accuracy: 0.7667666673660278


epoch 87
v

accuracy: 0.8317166566848755


epoch 155
validation loss: 0.5652155876159668
train loss: 0.5886468887329102
accuracy: 0.8323667049407959


epoch 156
validation loss: 0.5631305575370789
train loss: 0.5865421891212463
accuracy: 0.8329333662986755


epoch 157
validation loss: 0.5610747337341309
train loss: 0.5844664573669434
accuracy: 0.8335999846458435


epoch 158
validation loss: 0.5590470433235168
train loss: 0.5824190378189087
accuracy: 0.8342666625976562


epoch 159
validation loss: 0.5570471882820129
train loss: 0.5803986191749573
accuracy: 0.8348166942596436


epoch 160
validation loss: 0.5550748109817505
train loss: 0.5784051418304443
accuracy: 0.8352833390235901


epoch 161
validation loss: 0.5531283020973206
train loss: 0.5764380097389221
accuracy: 0.8358500003814697


epoch 162
validation loss: 0.5512080788612366
train loss: 0.574496865272522
accuracy: 0.8363333344459534


epoch 163
validation loss: 0.549313485622406
train loss: 0.57258141040802
accuracy: 0.836733341217041


ep

In [25]:
# written by chatGPT
def create_distribution(data):
    """
    Create a distribution from a given dataset.

    Parameters:
    data (numpy.array): Array of data points.

    Returns:
    numpy.array: Probability distribution of the data points.
    """
    # Compute the histogram of the data
    counts, bins = np.histogram(data, bins='auto', density=True)
    
    # Compute the probability distribution
    bin_widths = np.diff(bins)
    distribution = counts * bin_widths
    
    return distribution

y0 = np.maximum(0,test_images @ net_parameters['w0'])
avs = y0.sum(axis=-1)
counts, _ = np.histogram(avs)
px.line(counts, line_shape='spline')

In [97]:
def entropy(distr):
    return -(distr*jnp.log(distr+1e-10)).sum(axis=-1)

y0 = test_images @ net_parameters['w0']

entropy(jax.nn.softmax(y0)).mean()

DeviceArray(2.3598225, dtype=float32)