In [1]:
using NBInclude
using BSON:@save, @load
using Dates: now
using Flux, CMBLensing
using PyPlot
@nbinclude("wienercode.ipynb")

│ For performance reasons, it is recommended to upgrade to a driver that supports CUDA 11.2 or higher.
└ @ CUDA /global/home/users/sguns/.julia/packages/CUDA/02Kjq/src/initialization.jl:42


WienerNet (generic function with 1 method)

In [2]:
# Load `mask` from BSON, add two trailing singleton dimensions so it becomes a
# 128×128×1×1 array, and move it to the GPU in one step.
@load "mask_128.bson" mask
mask = Flux.gpu(reshape(mask, 128, 128, 1, 1))
println("Loaded mask")


@load "training_data_nsims1000_CMBLensing.bson" sims
println("Loaded data")
batchsize = 8

# Extract the input / target maps of one simulation as Float32 matrices
# (the [:, :, 1, 1] slice drops the trailing dimensions of the stored field data).
# NOTE(review): sim[1] is presumably the data map and sim[3] the target map —
# confirm against the script that wrote the BSON file.
sim_pair(sim) = (Float32.(sim[1][:Ix][:, :, 1, 1]), Float32.(sim[3][:Ix][:, :, 1, 1]))

# Stack every simulation of one batch along dimension 4, giving an
# (inputs, targets) tuple of H×W×1×N arrays, where N is the actual batch length.
function batch_pair(batch)
    fields = map(sim_pair, batch)
    xs = cat(first.(fields)...; dims=4)
    ys = cat(last.(fields)...; dims=4)
    return xs, ys
end

# Split into consecutive batches of `batchsize` simulations. Unlike the previous
# float-count `Flux.chunk` + hard-coded `1:8`, this honours `batchsize` and keeps a
# smaller final batch when length(sims) is not a multiple of it (was a BoundsError).
sims = collect(Iterators.partition(sims, batchsize))
training = [batch_pair(s) for s in sims]   # concretely typed, unlike `[]`
training = Flux.gpu(training)
println("Batched data")

Loaded mask
Loaded data
Batched data


In [3]:
# Network instance on the GPU (`WienerNet` is defined by the included wienercode.ipynb).
wn = Flux.gpu(WienerNet())

# Forward pass: an (input, mask) tuple goes straight to the network; a bare
# array is paired with the global `mask` first.
model(x::Tuple{AbstractArray,AbstractArray}) = wn(x)
function model(x::AbstractArray)
    inputs = (x, mask)
    return wn(inputs)
end

# Mean-squared error between the network output for `x` and the target `d`.
function loss_J2(x, d)
    prediction = model(x)
    return Flux.mse(prediction, d)
end
println("Defined model")

Defined model


In [None]:
opt = ADAM(1e-3)

# Single-sample monitoring pair: the first map of the first training batch
# (range indexing keeps all four dimensions).
# NOTE(review): this sample is part of the training set, so `testloss` is not a
# held-out validation loss.
testx = training[1][1][:, :, 1:1, 1:1]
testy = training[1][2][:, :, 1:1, 1:1]

# Create the checkpoint directory up front so the first `@save` cannot fail.
checkpoint_dir = "weights/J1"
mkpath(checkpoint_dir)

# Runs at most once every 300 s during training: evaluate the monitoring loss once,
# print it, and save a CPU copy of the model together with the optimiser and loss.
evalcb = Flux.throttle(300) do
    savem = Flux.cpu(wn)
    testloss = loss_J2(testx, testy)
    # Same text `@show(loss_J2(testx, testy))` produced, without a second forward pass.
    println("loss_J2(testx, testy) = ", repr(testloss))
    @save "$(checkpoint_dir)/J1_model-$(now()).bson" savem opt testloss
end

Flux.@epochs 100 Flux.Optimise.train!(loss_J2, Flux.params(wn), training, opt, cb=evalcb);


┌ Info: Epoch 1
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135
└ @ NNlibCUDA /global/home/users/sguns/.julia/packages/NNlibCUDA/806f6/src/cudnn/cudnn.jl:10


loss_J2(testx, testy) = 5071.188f0
loss_J2(testx, testy) = 1898.1241f0
loss_J2(testx, testy) = 1785.6174f0
loss_J2(testx, testy) = 1576.6389f0
loss_J2(testx, testy) = 1432.3202f0
loss_J2(testx, testy) = 1127.3846f0
loss_J2(testx, testy) = 768.9429f0
loss_J2(testx, testy) = 737.6905f0
loss_J2(testx, testy) = 697.72504f0


┌ Info: Epoch 2
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 519.5831f0
loss_J2(testx, testy) = 518.4003f0
loss_J2(testx, testy) = 548.07684f0
loss_J2(testx, testy) = 581.6485f0
loss_J2(testx, testy) = 462.4315f0
loss_J2(testx, testy) = 474.18164f0


┌ Info: Epoch 3
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 470.58698f0
loss_J2(testx, testy) = 415.1751f0
loss_J2(testx, testy) = 397.18237f0
loss_J2(testx, testy) = 426.7444f0
loss_J2(testx, testy) = 372.60892f0
loss_J2(testx, testy) = 390.0063f0
loss_J2(testx, testy) = 322.62747f0


┌ Info: Epoch 4
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 332.9165f0
loss_J2(testx, testy) = 354.25095f0
loss_J2(testx, testy) = 331.031f0
loss_J2(testx, testy) = 298.08612f0
loss_J2(testx, testy) = 273.06552f0
loss_J2(testx, testy) = 300.2168f0


┌ Info: Epoch 5
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 291.40762f0
loss_J2(testx, testy) = 253.7258f0
loss_J2(testx, testy) = 252.82295f0
loss_J2(testx, testy) = 264.59186f0
loss_J2(testx, testy) = 224.02197f0
loss_J2(testx, testy) = 247.5953f0


┌ Info: Epoch 6
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 262.1327f0
loss_J2(testx, testy) = 232.68451f0
loss_J2(testx, testy) = 207.04909f0
loss_J2(testx, testy) = 232.16278f0
loss_J2(testx, testy) = 220.68256f0
loss_J2(testx, testy) = 207.74481f0


┌ Info: Epoch 7
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 224.6424f0
loss_J2(testx, testy) = 201.45197f0
loss_J2(testx, testy) = 193.83728f0
loss_J2(testx, testy) = 213.9926f0
loss_J2(testx, testy) = 222.13705f0
loss_J2(testx, testy) = 204.68204f0
loss_J2(testx, testy) = 195.34468f0


┌ Info: Epoch 8
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 189.6658f0
loss_J2(testx, testy) = 199.3237f0
loss_J2(testx, testy) = 206.06075f0
loss_J2(testx, testy) = 208.12163f0
loss_J2(testx, testy) = 171.32196f0
loss_J2(testx, testy) = 215.27852f0


┌ Info: Epoch 9
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 177.12119f0
loss_J2(testx, testy) = 177.34944f0
loss_J2(testx, testy) = 157.52965f0
loss_J2(testx, testy) = 186.01273f0
loss_J2(testx, testy) = 173.7105f0
loss_J2(testx, testy) = 150.87332f0


┌ Info: Epoch 10
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 198.1005f0
loss_J2(testx, testy) = 162.95097f0
loss_J2(testx, testy) = 165.87177f0
loss_J2(testx, testy) = 164.24863f0
loss_J2(testx, testy) = 153.18146f0
loss_J2(testx, testy) = 199.29384f0


┌ Info: Epoch 11
└ @ Main /global/home/users/sguns/.julia/packages/Flux/0c9kI/src/optimise/train.jl:135


loss_J2(testx, testy) = 147.01888f0
loss_J2(testx, testy) = 172.89136f0
loss_J2(testx, testy) = 149.74112f0
loss_J2(testx, testy) = 136.13669f0
