In [1]:
using Flux
using MLDatasets
using DataFrames
using Statistics

In [2]:
function loss(model, features, labels_onehot)
    y_hat = model(features)
    return Flux.logitcrossentropy(y_hat, labels_onehot)
end

loss (generic function with 1 method)

In [3]:
function train_model!(f_loss, model, features, labels_onehot; learning_rate=0.01)
    dLdm, _, _ = gradient(f_loss, model, features, labels_onehot)
    @. model[1].weight = model[1].weight - learning_rate * dLdm[:layers][1][:weight]
    @. model[1].bias = model[1].bias - learning_rate * dLdm[:layers][1][:bias]
end

train_model! (generic function with 1 method)

In [4]:
function train_until_accuracy_reached!(f_loss, model, features, labels, classes; max_epochs=10000, threshold=0.98, learning_rate=0.01)
    labels_onehot = Flux.onehotbatch(labels, classes)
    accuracy(x, y) = Statistics.mean(Flux.onecold(model(x), classes) .== y)
    for epoch in 1:max_epochs
        train_model!(f_loss, model, features, labels_onehot; learning_rate=learning_rate)

        current_accuracy = accuracy(features, labels)
        if current_accuracy >= threshold
            println("Converged at epoch $epoch with accuracy $current_accuracy")
            break
        end
    end
end

train_until_accuracy_reached! (generic function with 1 method)

In [5]:
#サンプルデータ
x, y = Iris(as_df=false)[:]
x = Float32.(x)
y = vec(y)
classes = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"];

model = Chain(Dense(4 => 3), softmax)
features = x
labels = y
train_until_accuracy_reached!(loss, model, features, labels, classes; max_epochs=1000000, learning_rate=0.01)

Converged at epoch 5529 with accuracy 0.98


In [6]:
#学習したモデルで予測
new_data = [[5.1, 3.5, 1.4, 0.2], [6.2, 2.9, 4.3, 1.3], [5.9, 3.0, 5.1, 1.8]]
for sample in new_data
    x = Float32.(sample)
    y_hat = model(x)
    predicted_class = Flux.onecold(y_hat, classes)
    println("Predicted class for sample ", sample, ": ", predicted_class)
end

Predicted class for sample [5.1, 3.5, 1.4, 0.2]: Iris-setosa
Predicted class for sample [6.2, 2.9, 4.3, 1.3]: Iris-versicolor
Predicted class for sample [5.9, 3.0, 5.1, 1.8]: Iris-virginica


In [7]:
Flux.params(model)

Params([Float32[-0.03557788 1.8849722 -1.0714439 -0.968261; -0.02271377 0.10371932 0.73199344 -0.5782927; -1.1295877 -1.0471723 2.1424239 1.9715953], Float32[0.2750254, 0.25109768, -0.52612257]])