Activating the Julia environment

In [1]:
using Pkg;
Pkg.activate(".");

[32m[1m Activating[22m[39m environment at `~/Documents/git/XP36VPD-presentations/MIL-presentation/demo/Project.toml`


Installing packages: this only needs to be run the first time you run the notebook

In [2]:
# Pkg.instantiate();

Loading the packages used

In [3]:
using Flux;
using Flux: @epochs, logitcrossentropy, onehotbatch, throttle;
using JLD2;
using Mill;
using MilDatasets;
using Random;
using Statistics;

Loading the data: the Musk2 dataset

In [4]:
@load "Musk2.jld2" X bags y;

`X` is the matrix of instance features, with one column per instance

In [5]:
X

166×6598 Array{Float64,2}:
   46.0    41.0    46.0    41.0    41.0  …    44.0    44.0    51.0    51.0
 -108.0  -188.0  -194.0  -188.0  -188.0     -104.0  -102.0  -121.0  -122.0
  -60.0  -145.0  -145.0  -145.0  -145.0      -19.0   -19.0   -23.0   -23.0
  -69.0    22.0    28.0    22.0    22.0     -105.0  -104.0  -106.0  -106.0
 -117.0  -117.0  -117.0  -117.0  -117.0     -117.0  -117.0  -117.0  -117.0
   49.0    -6.0    73.0    -7.0    -7.0  …   142.0    72.0    63.0   190.0
   38.0    57.0    57.0    57.0    57.0     -165.0  -165.0  -161.0  -161.0
 -161.0  -171.0  -168.0  -170.0  -170.0       68.0    65.0    79.0    80.0
   -8.0   -39.0   -39.0   -39.0   -39.0     -225.0  -219.0  -224.0  -227.0
    5.0  -100.0   -22.0   -99.0   -99.0      -32.0   -12.0   -30.0   -52.0
 -323.0  -319.0  -319.0  -319.0  -319.0  …  -124.0  -107.0  -129.0  -139.0
 -220.0  -111.0  -111.0  -111.0  -111.0      -77.0   -66.0   -54.0   -63.0
 -113.0  -228.0  -104.0  -228.0  -228.0      -43.0   -58.0   -60.0   -51.0

`bags` assigns each instance to a bag

In [6]:
bags

6598-element Array{Int64,1}:
   1
   1
   1
   1
   1
   1
   1
   1
   1
   1
   1
   1
   1
   ⋮
 102
 102
 102
 102
 102
 102
 102
 102
 102
 102
 102
 102
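
The bag assignment above stores one bag id per instance, with the instances of each bag stored contiguously. As a minimal sketch of what that encoding means, the following plain-Julia snippet turns such a vector into one index range per bag. The helper `bag_ranges` and the toy vector `toy_ids` are hypothetical names introduced here for illustration; they are not part of `Mill.jl`, which handles this conversion internally.

```julia
# Hypothetical helper (not part of Mill.jl): convert a vector of bag ids,
# where the instances of each bag are stored contiguously, into one index
# range per bag.
function bag_ranges(bag_ids::AbstractVector{<:Integer})
    ranges = UnitRange{Int}[]
    isempty(bag_ids) && return ranges
    start = 1
    for i in 2:length(bag_ids)
        # A change of id marks the start of a new bag.
        if bag_ids[i] != bag_ids[i - 1]
            push!(ranges, start:(i - 1))
            start = i
        end
    end
    push!(ranges, start:length(bag_ids))
    return ranges
end

# Toy example: three bags with 3, 2 and 1 instances.
toy_ids = [1, 1, 1, 2, 2, 3]
toy_ranges = bag_ranges(toy_ids)
```

Each range selects the columns of the feature matrix that belong to one bag, so `X[:, r]` for a range `r` would be the instances of that bag.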

`y` is the vector of class labels, available only at the bag level

In [7]:
y

102-element Array{Int64,1}:
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 ⋮
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0

In [8]:
ds = BagNode(ArrayNode(X), bags, y)

BagNode with 102 bag(s)
  └── ArrayNode(166, 6598)

Alternatively, we load the data from a package of commonly used MIL datasets

In [9]:
ds = MilDatasets.Musk2()

BagNode with 102 bag(s)
  └── ArrayNode(166, 6598)

Splitting into training and test data

In [10]:
perm = randperm(length(ds.bags));            # random permutation of the bag indices
delim = round(Int, 0.8 * length(ds.bags));   # the first 80 % of the permuted bags are for training
train = ds[perm[1:delim]];
test = ds[perm[delim + 1:end]];

Setting the model parameters
- Hidden layer width
- Number of epochs

In [11]:
hidden_layer_width = 10;
epochs = 10;

Defining the model using the `Mill.jl` package
- Instance model: one dense layer of width `hidden_layer_width` with the ReLU activation function
- Aggregation: `mean` and `max`
- Bag model: one dense layer of width `hidden_layer_width` with the ReLU activation function, followed by a linear output layer of width 2 (one output per class)

In [12]:
model = BagModel(
    Dense(size(train.data.data, 1), hidden_layer_width, relu),
    SegmentedMeanMax(hidden_layer_width),
    Chain(Dense(2 * hidden_layer_width, hidden_layer_width, relu), Dense(hidden_layer_width, 2))
)

BagModel ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Chain(Dense(20, 10, relu), Dense(10, 2)))
  └── ArrayModel(Dense(166, 10, relu))
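
To illustrate the aggregation step of the model above, here is a minimal plain-Julia sketch of a combined mean and max aggregation. It is not the `Mill.jl` implementation, which may differ in details such as the handling of empty bags; the function `mean_max_aggregate`, the toy matrix `H` and the argument `ranges` (one column range per bag) are assumptions made for this example. The sketch also shows why the bag model's first layer takes `2 * hidden_layer_width` inputs: the mean and the max are concatenated.

```julia
using Statistics

# Hypothetical helper, not the Mill.jl implementation: for every bag,
# concatenate the feature-wise mean and the feature-wise maximum of its
# instance embeddings. `H` holds one instance embedding per column.
function mean_max_aggregate(H::AbstractMatrix, ranges)
    cols = map(ranges) do r
        Hb = H[:, r]                       # embeddings of one bag
        vcat(vec(mean(Hb; dims = 2)),      # mean over the bag's instances
             vec(maximum(Hb; dims = 2)))   # max over the bag's instances
    end
    return reduce(hcat, cols)              # one column per bag
end

# Toy example: 2 features, 3 instances, bags {1, 2} and {3}.
H = [1.0 3.0 5.0;
     2.0 4.0 6.0]
Z = mean_max_aggregate(H, [1:2, 3:3])
```

The result `Z` has twice as many rows as `H` and one column per bag, regardless of how many instances each bag contains.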

Defining the loss function: cross-entropy.

In [13]:
loss(x) = logitcrossentropy(model(x).data, onehotbatch(x.metadata, 0:1));
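
The loss above applies cross-entropy directly to the model's raw outputs (logits) and one-hot encoded bag labels. As a rough sketch of that computation in plain Julia, assuming the loss is averaged over bags, the snippet below computes a numerically stable log-softmax and the resulting cross-entropy. The names `log_softmax_cols`, `logit_cross_entropy` and `one_hot` are hypothetical stand-ins written for this example, not the Flux functions themselves.

```julia
using Statistics

# Numerically stable log-softmax over each column (one column per bag).
function log_softmax_cols(logits::AbstractMatrix)
    shifted = logits .- maximum(logits; dims = 1)
    return shifted .- log.(sum(exp.(shifted); dims = 1))
end

# Cross-entropy between one-hot targets and softmax(logits),
# averaged over columns (an assumption of this sketch).
function logit_cross_entropy(logits::AbstractMatrix, onehot::AbstractMatrix)
    return mean(-sum(onehot .* log_softmax_cols(logits); dims = 1))
end

# Hypothetical helper: one-hot encode labels against an ordered list of classes.
one_hot(labels, classes) = Float64[c == l for c in classes, l in labels]

# Toy example: two bags, two classes (0 and 1).
logits = [2.0 0.0;
          0.0 0.0]
targets = one_hot([0, 1], 0:1)
toy_loss = logit_cross_entropy(logits, targets)
```

Working on logits rather than on probabilities avoids taking the logarithm of values that have underflowed to zero.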

We train using the ADAM optimizer with learning rate `η = 0.05`.

In [14]:
opt = ADAM(0.05);
evalcb() = @show(loss(test));

@epochs epochs Flux.train!(loss, params(model), (train,), opt, cb = throttle(evalcb, 10));

┌ Info: Epoch 1
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 237.47441f0


┌ Info: Epoch 2
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 84.63966f0


┌ Info: Epoch 3
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 17.265991f0


┌ Info: Epoch 4
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 18.48794f0


┌ Info: Epoch 5
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 7.2142744f0


┌ Info: Epoch 6
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 2.4359815f0


┌ Info: Epoch 7
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 2.1777444f0


┌ Info: Epoch 8
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 2.269794f0


┌ Info: Epoch 9
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 1.8330243f0


┌ Info: Epoch 10
└ @ Main /home/marekdedic/.julia/packages/Flux/q3zeA/src/optimise/train.jl:136


loss(test) = 1.3226635f0
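
For reference, the ADAM update used for training can be sketched in a few lines of plain Julia. This is a textbook version of the update rule for a single parameter vector, not the Flux implementation; the type `AdamState`, the function `adam_step!` and the default values of `β1`, `β2` and `ϵ` are assumptions of this sketch, and only `η = 0.05` is taken from the notebook.

```julia
# State of a hypothetical, minimal Adam optimizer for one parameter vector.
mutable struct AdamState
    m::Vector{Float64}   # running mean of the gradients
    v::Vector{Float64}   # running mean of the squared gradients
    t::Int               # number of steps taken so far
end

AdamState(n::Integer) = AdamState(zeros(n), zeros(n), 0)

# One in-place Adam step on parameters `θ` given the gradient `g`.
function adam_step!(θ, g, s::AdamState; η = 0.05, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2
    m_hat = s.m ./ (1 - β1^s.t)   # bias-corrected first moment
    v_hat = s.v ./ (1 - β2^s.t)   # bias-corrected second moment
    θ .-= η .* m_hat ./ (sqrt.(v_hat) .+ ϵ)
    return θ
end

# Toy example: a single step moves each parameter by roughly η against
# the sign of its gradient, independent of the gradient's magnitude.
θ = [1.0, -1.0]
state = AdamState(length(θ))
adam_step!(θ, [0.5, -2.0], state)
```

Because the step size is normalized by the running gradient magnitude, the early updates have a size close to `η` even when raw gradients are large.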
