We consider a simple neural network, built with the package Lux.jl, that learns the function $y = \sin x$ on $[-\pi, \pi]$ by regression.

In [None]:
using Lux, Random, Optimization, OptimizationOptimisers, ComponentArrays, Zygote, Plots, LinearAlgebra

# Define data for regression: n equally spaced sample points on [-π, π] and
# the corresponding values of sin. Float32 is used so the inputs match the
# Float32 parameters Lux layers are initialised with, avoiding mixed-precision
# (Float64 input × Float32 weight) products in every forward pass.
n = 100
x = range(-Float32(π), Float32(π); length = n)
y = sin.(x)

# Define a neural network model: a 1 → 10 → 10 → 1 multilayer perceptron with
# two ReLU hidden layers of width 10 and an output layer given no activation.
model = Chain(
    Dense(1 => 10, relu),
    Dense(10 => 10, relu),
    Dense(10 => 1),
)

"""
    regression_loss(ps, (model, st, (x, y)))

Return the Euclidean 2-norm of the residual between the network's predictions
on `x` and the targets `y`.

`ps` are the trainable parameters being optimised. The second argument bundles
the fixed data: the `model`, its state `st`, and the training pair `(x, y)`.
The model is called as `model(input, ps, st)` and only the first element of the
returned tuple (the prediction) is used.
"""
function regression_loss(ps, (model, st, (x, y)))
    # Inputs are passed as a 1×n row (one feature, n samples).
    output, _ = model(x', ps, st)
    residual = vec(output) - y
    return norm(residual)
end

# Initialise the network's parameters and state. A fixed seed makes the
# initial weights, and hence the training run, reproducible across sessions.
rng = MersenneTwister(0)
ps, st = Lux.setup(rng, model)

# Define the optimization problem: minimise `regression_loss` over the
# parameters (wrapped in a ComponentArray so the optimiser works on a single
# flat array), with gradients computed by Zygote. The model, its state and the
# training data are passed as the problem's fixed parameters.
prob = OptimizationProblem(
    OptimizationFunction(regression_loss, Optimization.AutoZygote()),
    ComponentArray(ps),
    (model, st, (x, y)),
)

# Train the neural network: 250 iterations of Adam with learning rate 0.03
@time ret = solve(prob, Adam(0.03), maxiters = 250)

# Plot the results: the true function against the trained network's prediction
plot(x, y, label = "True")
plot!(x, vec(model(x', ret.u, st)[1]), label = "Predicted")

We now consider another simple neural network, built with a different package, Flux.jl, that recognizes handwritten digits from the MNIST dataset.

[Reference: https://github.com/piotrek124-1/Simple_MNIST_Julia/tree/main]

In [None]:
using Flux, MLDatasets

# Import the MNIST dataset. `MNIST(; Tx, split)[:]` returns the image features
# (as Float32) and the integer digit labels; it replaces the
# `MNIST.traindata` / `MNIST.testdata` functions deprecated in MLDatasets v0.7.
x_train, y_train = MLDatasets.MNIST(Tx = Float32, split = :train)[:]
x_test, y_test = MLDatasets.MNIST(Tx = Float32, split = :test)[:]

# Use one-hot encoding for the training labels: each digit 0–9 becomes a
# column with a single `true` in the row for that digit (10 rows in total).
y_train = Flux.onehotbatch(y_train, 0:9)

# Define a neural network model: 784 input pixels → 256 ReLU units → 10 outputs.
# Names are qualified with `Flux.` because the first cell loaded Lux (which
# also exports `Chain` and `Dense`) and OptimizationOptimisers (which exports
# `Adam`) and already used those names in `Main`. Unqualified, they would
# resolve to those packages' versions rather than Flux's.
# The output layer has no activation and no trailing `softmax`: the model
# returns raw logits, which is what `logitcrossentropy` expects (it applies the
# softmax internally). Use `softmax(model(x))` to obtain class probabilities.
model = Flux.Chain(Flux.Dense(784 => 256, relu), Flux.Dense(256 => 10))

# Define the loss function based on cross entropy of the logits.
# The 3-argument form is what explicit-style `Flux.train!` calls; the
# 2-argument form evaluates the global `model`.
loss(m, x, y) = Flux.Losses.logitcrossentropy(m(x), y)
loss(x, y) = loss(model, x, y)

# Train the neural network: 300 full-batch epochs over the flattened images.
# The optimiser state is created once, outside the loop, so Adam's moment
# estimates carry over between epochs instead of being reset every epoch.
train_data = [(Flux.flatten(x_train), y_train)]
opt_state = Flux.setup(Flux.Adam(0.003), model)
for epoch in 1:300
    Flux.train!(loss, model, train_data, opt_state)
end

# Test the neural network on the held-out set with a single batched forward
# pass over all flattened test images, instead of one model call per image.
test_data = [(Flux.flatten(x_test), y_test)]
# `onecold` maps each output column to the label in 0:9 of its largest entry.
predicted = Flux.onecold(model(test_data[1][1]), 0:9)
# Number of correctly classified test images
accuracy = count(predicted .== y_test)

# Display the accuracy (fraction of test images classified correctly)
accuracy / length(y_test)