In [1]:
using JLD2
using Flux
include("../src/ad.jl")
include("../src/operations.jl")
include("../src/mlp.jl")

test! (generic function with 1 method)

In [2]:
X_train = load("../data/imdb_dataset_prepared.jld2", "X_train")
y_train = load("../data/imdb_dataset_prepared.jld2", "y_train")
X_test = load("../data/imdb_dataset_prepared.jld2", "X_test")
y_test = load("../data/imdb_dataset_prepared.jld2", "y_test")

@show typeof(X_train)
@show typeof(y_train)


typeof(X_train) = Adjoint{Float32, Matrix{Float32}}
typeof(y_train) = BitMatrix


BitMatrix[90m (alias for [39m[90mBitArray{2}[39m[90m)[39m

In [3]:
input_dim, n_samples = size(X_train)
output_dim, _ = size(y_train)
hidden_dim = 32
epochs = 5
lr = 0.001
β1, β2, ϵ = 0.9, 0.999, 1e-8

Wh, Wo = create_mlp(input_dim, hidden_dim, output_dim)
stateWh = AdamState(size(Wh.output))
stateWo = AdamState(size(Wo.output))

x_var = Variable(X_train[:, 1], name="x")
y_var = Variable(y_train[:, 1], name="y")
graph, x_var, y_var, ŷ = build_graph(x_var, y_var, Wh, Wo, dense, σ)

for epoch in 1:epochs
    total_loss = 0.0
    for i in 1:n_samples
        x_var.output .= X_train[:, i]
        y_var.output .= y_train[:, i]
        loss = train_step!(x_var, y_var, Wh, Wo, lr, (stateWh, stateWo), β1, β2, ϵ, graph)
        total_loss += loss
    end

    avg_train_loss = total_loss / n_samples
    test_loss, test_accuracy = test!(X_test, y_test, graph, x_var, y_var, ŷ)

    stateWh = AdamState(size(Wh.output))
    stateWo = AdamState(size(Wo.output))

    println("Epoch $epoch:")
    println("  Train Loss: $(round(avg_train_loss, digits=4))")
    println("  Test  Loss: $(round(test_loss, digits=4)), Accuracy: $(round(test_accuracy, digits=2))%")
end


Epoch 1:
  Train Loss: 0.3431
  Test  Loss: 0.3022, Accuracy: 87.55%
Epoch 2:
  Train Loss: 0.1051
  Test  Loss: 0.3802, Accuracy: 86.65%
Epoch 3:
  Train Loss: 0.0478
  Test  Loss: 0.4961, Accuracy: 84.75%
Epoch 4:
  Train Loss: 0.0256
  Test  Loss: 0.6175, Accuracy: 83.8%
Epoch 5:
  Train Loss: 0.0151
  Test  Loss: 0.7411, Accuracy: 83.3%
