# Logistic Regression Example in Flux

Based on [this tutorial](https://fluxml.ai/Flux.jl/stable/tutorials/logistic_regression/)

In [1]:
using Flux, Statistics, MLDatasets, DataFrames, OneHotArrays

In [2]:
Iris()

dataset Iris:
  metadata   =>    Dict{String, Any} with 4 entries
  features   =>    150×4 DataFrame
  targets    =>    150×1 DataFrame
  dataframe  =>    150×5 DataFrame

In [3]:
x, y = Iris(as_df=false)[:];

In [4]:
y

1×150 Matrix{InlineStrings.String15}:
 "Iris-setosa"  "Iris-setosa"  …  "Iris-virginica"  "Iris-virginica"

In [5]:
x |> summary

"4×150 Matrix{Float64}"

In [6]:
x[:,23]

4-element Vector{Float64}:
 4.6
 3.6
 1.0
 0.2

In [7]:
x = Float32.(x);

In [8]:
y = vec(y);

In [9]:
y

150-element Vector{InlineStrings.String15}:
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 ⋮
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"

In [10]:
const classes = unique(y)

3-element Vector{InlineStrings.String15}:
 "Iris-setosa"
 "Iris-versicolor"
 "Iris-virginica"

In [11]:
flux_y_onehot = onehotbatch(y, classes)

3×150 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 1  1  1  1  1  1  1  1  1  1  1  1  1  …  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅     1  1  1  1  1  1  1  1  1  1  1  1

In [12]:
flux_model = Chain(Dense(4 => 100, relu),
    Dense(100 => 20, relu),
    Dense(20 => 3, relu),
    softmax)


Chain(
  Dense(4 => 100, relu),                [90m# 500 parameters[39m
  Dense(100 => 20, relu),               [90m# 2_020 parameters[39m
  Dense(20 => 3, relu),                 [90m# 63 parameters[39m
  NNlib.softmax,
) [90m                  # Total: 6 arrays, [39m2_583 parameters, 10.395 KiB.

In [13]:
function flux_loss(flux_model, features, labels_onehot)
    ŷ = flux_model(features)
    Flux.logitcrossentropy(ŷ, labels_onehot)
end;

In [14]:
flux_loss(flux_model, x, flux_y_onehot)

1.101367f0

In [15]:
flux_accuracy(x, y) = mean(Flux.onecold(flux_model(x), classes) .== y);

In [16]:
function train_flux_model!(f_loss, model, features, labels_onehot)
    dLdm, _, _ = gradient(f_loss, model, features, labels_onehot)
    @. model[1].weight = model[1].weight - 0.1 * dLdm[:layers][1][:weight]
    @. model[1].bias = model[1].bias - 0.1 * dLdm[:layers][1][:bias]
end;

In [17]:
for i = 1:500
    train_flux_model!(flux_loss, flux_model, x, flux_y_onehot);
    flux_accuracy(x, y) >= 0.98 && break
end

In [18]:
@show flux_accuracy(x, y);

flux_accuracy(x, y) = 0.8333333333333334


In [19]:
flux_loss(flux_model, x, flux_y_onehot)

0.8960455f0

In [20]:
println(Flux.onecold(flux_model(x), classes))

InlineStrings.String15["Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-setosa", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-versicolor", "Iris-versicol

In [21]:
y

150-element Vector{InlineStrings.String15}:
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 "Iris-setosa"
 ⋮
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"
 "Iris-virginica"