In [1]:
# Example from this question: 
# https://discourse.julialang.org/t/why-does-binarycrossentropy-needs-an-index-in-a-denoising-autoencoder/41013

In [2]:
using Flux, Random

In [3]:
# Synthetic data set: 2000 features (rows) × 100 samples (columns), uniform in [0, 1).
# NOTE: `data` is drawn from the global RNG, so its values differ between sessions;
# only the corruption pattern below is seeded.
data = rand(2000, 100)
data_corrupted = copy(data)

# Corrupt data: zero out a random subset of the features of every sample.
rng = MersenneTwister(1234)
for sample_index in axes(data_corrupted, 2)
    # One random Bool per feature; `true` marks an entry to be dropped.
    # The mask length follows the data rather than repeating the literal 2000.
    dropped = bitrand(rng, size(data_corrupted, 1))

    # Logical indexing sets every marked entry of this column to 0 in one step.
    data_corrupted[dropped, sample_index] .= 0
end

In [4]:
# Display the clean data matrix (rows are features, columns are samples).
data

2000×100 Matrix{Float64}:
 0.11351     0.512045   0.711521   …  0.497645  0.42522    0.850795
 0.927832    0.027396   0.0661235     0.317897  0.587738   0.536914
 0.396891    0.723363   0.296006      0.99137   0.729346   0.331776
 0.317445    0.677793   0.813511      0.378295  0.0405941  0.334668
 0.71864     0.826319   0.647217      0.25561   0.357937   0.0170603
 0.00975673  0.454618   0.641941   …  0.470681  0.959315   0.004981
 0.0800379   0.0524156  0.361736      0.872288  0.576295   0.321588
 0.889318    0.208279   0.418621      0.296998  0.0237296  0.98878
 0.171517    0.468157   0.207138      0.901629  0.205819   0.383007
 0.6382      0.563001   0.374413      0.85366   0.770146   0.895557
 ⋮                                 ⋱                       
 0.10427     0.3411     0.0248435     0.950551  0.315407   0.987709
 0.327329    0.262026   0.282734      0.712335  0.183549   0.600097
 0.287979    0.421628   0.0111928     0.057634  0.830809   0.539893
 0.220182    0.219428   0.1398

In [5]:
# Display the corrupted copy; dropped entries show up as 0.0.
data_corrupted

2000×100 Matrix{Float64}:
 0.11351    0.0        0.711521   0.764365  …  0.0       0.0        0.0
 0.927832   0.0        0.0661235  0.0          0.317897  0.587738   0.536914
 0.0        0.723363   0.0        0.219799     0.99137   0.729346   0.331776
 0.317445   0.677793   0.0        0.736131     0.378295  0.0405941  0.0
 0.71864    0.0        0.647217   0.0          0.0       0.0        0.0170603
 0.0        0.0        0.0        0.0       …  0.0       0.0        0.0
 0.0800379  0.0524156  0.0        0.866262     0.0       0.0        0.0
 0.0        0.208279   0.418621   0.0          0.296998  0.0237296  0.0
 0.171517   0.0        0.207138   0.0          0.901629  0.0        0.383007
 0.6382     0.0        0.0        0.812615     0.0       0.0        0.0
 ⋮                                          ⋱                       
 0.0        0.0        0.0248435  0.0          0.950551  0.315407   0.0
 0.0        0.262026   0.282734   0.718394     0.0       0.183549   0.600097
 0.0        0.0

In [6]:
# Split both matrices column-wise into mini-batches of 10 samples each.
# `Iterators.partition` yields consecutive index ranges of at most 10 columns, so a
# trailing batch would simply be shorter if the sample count were not a multiple of 10.
data_partitioned = [data[:, cols] for cols in Iterators.partition(1:size(data, 2), 10)]
data_corrupted_partitioned = [
    data_corrupted[:, cols] for
    cols in Iterators.partition(1:size(data_corrupted, 2), 10)
]

10-element Vector{Matrix{Float64}}:
 [0.1135097940037636 0.0 … 0.7304379749388124 0.26078544265830106; 0.9278323016035726 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.6922027958164487 0.0; 0.0 0.7612376853085423 … 0.3976698840769388 0.25343508034444906]
 [0.0 0.0 … 0.0 0.42019421283858915; 0.10924015393968323 0.0 … 0.0 0.16402214917104352; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.133277946077911 … 0.13495477438148262 0.025302681635060664]
 [0.1992246690859628 0.0 … 0.0 0.0; 0.0 0.4738134905769972 … 0.4811916366602632 0.3804090220222316; … ; 0.11457163433824302 0.0 … 0.18999145873861523 0.11737659209016293; 0.46106176448341996 0.6758246238955057 … 0.5063864615455899 0.7310601149694442]
 [0.8413880633208418 0.0 … 0.0 0.3939576706853526; 0.0 0.0 … 0.12494320725250163 0.0; … ; 0.0 0.0 … 0.0 0.19862443004228758; 0.0 0.5544977893800244 … 0.8279362920022149 0.0]
 [0.44053506229449335 0.0 … 0.2179508744854679 0.0; 0.0 0.6619669811760553 … 0.0 0.0; … ; 0.17141636007065275 0.7323948168825462 … 0.47576030375148604 0.99488

In [7]:
# Display the clean mini-batches (the training targets).
data_partitioned

10-element Vector{Matrix{Float64}}:
 [0.1135097940037636 0.5120449892480132 … 0.7304379749388124 0.26078544265830106; 0.9278323016035726 0.027396001205986842 … 0.3232769232617778 0.45784441734009584; … ; 0.27993798235931044 0.7748252016194123 … 0.6922027958164487 0.823855162463647; 0.24210058444620852 0.7612376853085423 … 0.3976698840769388 0.25343508034444906]
 [0.8349576218004329 0.609235842707443 … 0.9951783080626506 0.42019421283858915; 0.10924015393968323 0.7765302919061131 … 0.20789509562543518 0.16402214917104352; … ; 0.8835015085933419 0.3413812586781684 … 0.8710155537075489 0.8476867542075555; 0.5064129054028446 0.133277946077911 … 0.13495477438148262 0.025302681635060664]
 [0.1992246690859628 0.2910537990732267 … 0.6354674501102784 0.2463821301745045; 0.08559836855059666 0.4738134905769972 … 0.4811916366602632 0.3804090220222316; … ; 0.11457163433824302 0.833782529901623 … 0.18999145873861523 0.11737659209016293; 0.46106176448341996 0.6758246238955057 … 0.5063864615455899 0.7

In [8]:
# Display the corrupted mini-batches (the training inputs).
data_corrupted_partitioned

10-element Vector{Matrix{Float64}}:
 [0.1135097940037636 0.0 … 0.7304379749388124 0.26078544265830106; 0.9278323016035726 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.6922027958164487 0.0; 0.0 0.7612376853085423 … 0.3976698840769388 0.25343508034444906]
 [0.0 0.0 … 0.0 0.42019421283858915; 0.10924015393968323 0.0 … 0.0 0.16402214917104352; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.133277946077911 … 0.13495477438148262 0.025302681635060664]
 [0.1992246690859628 0.0 … 0.0 0.0; 0.0 0.4738134905769972 … 0.4811916366602632 0.3804090220222316; … ; 0.11457163433824302 0.0 … 0.18999145873861523 0.11737659209016293; 0.46106176448341996 0.6758246238955057 … 0.5063864615455899 0.7310601149694442]
 [0.8413880633208418 0.0 … 0.0 0.3939576706853526; 0.0 0.0 … 0.12494320725250163 0.0; … ; 0.0 0.0 … 0.0 0.19862443004228758; 0.0 0.5544977893800244 … 0.8279362920022149 0.0]
 [0.44053506229449335 0.0 … 0.2179508744854679 0.0; 0.0 0.6619669811760553 … 0.0 0.0; … ; 0.17141636007065275 0.7323948168825462 … 0.47576030375148604 0.99488

In [9]:
# Define model: a single-hidden-layer autoencoder.
# The encoder maps each 2000-feature sample to a 50-dimensional code and the decoder
# maps it back to 2000 values; both layers use the sigmoid activation σ, so the
# reconstruction lies in (0, 1).
# Layer sizes are written in the `in => out` pair form, which is how Flux prints them.
# NOTE(review): the printed model size suggests 4-byte (Float32) parameters while the
# data is Float64 — consider generating the data as Float32; confirm before changing.
encoder = Dense(2000 => 50, σ)
decoder = Dense(50 => 2000, σ)
m = Chain(encoder, decoder)

Chain(
  Dense(2000 => 50, σ),                 # 100_050 parameters
  Dense(50 => 2000, σ),                 # 102_000 parameters
)                   # Total: 4 arrays, 202_050 parameters, 789.508 KiB.

In [10]:
# Defining the loss function: binary cross-entropy between the model's reconstruction
# of the corrupted input `x` and the clean target `y`.
# Variants tried earlier (see the linked question), kept here for reference:
#   Flux.crossentropy(m(x), y)              -- categorical cross-entropy
#   Flux.binarycrossentropy(m(x)[1], y[1])  -- scores only the first element
function loss(x, y)
    reconstruction = m(x)
    return Flux.binarycrossentropy(reconstruction, y)
end

loss (generic function with 1 method)

In [11]:
# Optimiser with its default hyper-parameters
opt = ADAM()

# Parameters of the model that training will update
ps = Flux.params(m)

# Each training item pairs a corrupted batch (input) with its clean batch (target)
train_batches = zip(data_corrupted_partitioned, data_partitioned)

# Train: a single pass over the batches
Flux.train!(loss, ps, train_batches, opt)

In [12]:
# First 10 entries of the first corrupted batch. Linear indexing walks the matrix
# column by column, so these are rows 1–10 of the first sample.
data_corrupted_partitioned[1][1:10]

10-element Vector{Float64}:
 0.1135097940037636
 0.9278323016035726
 0.0
 0.3174448977386146
 0.7186403626847977
 0.0
 0.08003786229039134
 0.0
 0.17151748252115284
 0.6382002871339625

In [13]:
# The same 10 entries from the clean batch, for comparison with the corrupted ones.
data_partitioned[1][1:10]

10-element Vector{Float64}:
 0.1135097940037636
 0.9278323016035726
 0.39689086658864714
 0.3174448977386146
 0.7186403626847977
 0.009756725253390996
 0.08003786229039134
 0.8893175014411373
 0.17151748252115284
 0.6382002871339625

In [14]:
# Run the first corrupted batch through the model and keep the first 10 entries
# (rows 1–10 of the first sample) to compare with the clean values.
# The values shown are all close to 0.5, i.e. after the single training pass the
# reconstruction does not yet follow the input.
data_restored = m(data_corrupted_partitioned[1])[1:10]

10-element Vector{Float64}:
 0.5015006687151287
 0.4918030610143283
 0.5025799877360617
 0.5030997505284625
 0.497593531243816
 0.49668303267762665
 0.4949695432189623
 0.5020046004136498
 0.4951998066681611
 0.5029165954410296

In [15]:
# Categorical cross-entropy of the first batch's reconstruction against the clean
# batch, shown for comparison with the binary variant; the result is on a very
# different scale (≈ 695 here versus ≈ 0.69 for binary cross-entropy).
# NOTE(review): `crossentropy` is presumably meant for class-probability targets, not
# independent values in [0, 1) — confirm this cell is for comparison only.
entropy = Flux.crossentropy(m(data_corrupted_partitioned[1]),data_partitioned[1])

695.3510418677977

In [16]:
# Binary cross-entropy on the same batch — the quantity that `loss` computes.
# Note that `entropy` is reassigned here. A result near ln(2) ≈ 0.693 is exactly what
# predictions of ≈ 0.5 produce, whatever the targets are.
entropy = Flux.binarycrossentropy(m(data_corrupted_partitioned[1]),data_partitioned[1])

0.6923299688909956