In [1]:
using MLDatasets
using Zarr
using NPZ
using Flux
using Flux: @functor, chunk, params
using Flux.Data: DataLoader
using Parameters: @with_kw
using BSON
using CUDA
using Images
using Logging: with_logger
using ProgressMeter: Progress, next!
using TensorBoardLogger: TBLogger, tb_overwrite
using Random
using Statistics
using DifferentialEquations
using Plots

In [97]:
function GaussianFourierProjection(embed_dim, scale)
    # Instantiate W once
    W = randn(Float32, embed_dim ÷ 2) .* scale
    # Return a function that always references the same W
    function GaussFourierProject(t)
        t_proj = t' .* W * Float32(2π)
        [sin.(t_proj); cos.(t_proj)]
    end
end

struct UNet
    layers::NamedTuple
end

marginal_prob_std(t, sigma=25.0f0) = sqrt.((sigma .^ (2t) .- 1.0f0) ./ 2.0f0 ./ log(sigma))

#User Facing API for UNet architecture.
function UNet(channels=[32, 64, 128, 256], embed_dim=256, scale=30.0f0)
    return UNet((
        gaussfourierproj=GaussianFourierProjection(embed_dim, scale),
        linear=Dense(embed_dim, embed_dim, swish),
        # Encoding
        conv1=Conv((3, 3), 1 => channels[1], stride=1, bias=false),
        dense1=Dense(embed_dim, channels[1]),
        gnorm1=GroupNorm(channels[1], 4, swish),
        conv2=Conv((3, 3), channels[1] => channels[2], stride=2, bias=false),
        dense2=Dense(embed_dim, channels[2]),
        gnorm2=GroupNorm(channels[2], 32, swish),
        conv3=Conv((3, 3), channels[2] => channels[3], stride=2, bias=false),
        dense3=Dense(embed_dim, channels[3]),
        gnorm3=GroupNorm(channels[3], 32, swish),
        conv4=Conv((3, 3), channels[3] => channels[4], stride=2, bias=false),
        dense4=Dense(embed_dim, channels[4]),
        gnorm4=GroupNorm(channels[4], 32, swish),
        # Decoding
        tconv4=ConvTranspose((3, 3), channels[4] => channels[3], stride=2, bias=false),
        dense5=Dense(embed_dim, channels[3]),
        tgnorm4=GroupNorm(channels[3], 32, swish),
        tconv3=ConvTranspose((3, 3), channels[3] + channels[3] => channels[2], pad=(0, -1, 0, -1), stride=2, bias=false),
        dense6=Dense(embed_dim, channels[2]),
        tgnorm3=GroupNorm(channels[2], 32, swish),
        tconv2=ConvTranspose((3, 3), channels[2] + channels[2] => channels[1], pad=(0, -1, 0, -1), stride=2, bias=false),
        dense7=Dense(embed_dim, channels[1]),
        tgnorm2=GroupNorm(channels[1], 32, swish),
        tconv1=ConvTranspose((3, 3), channels[1] + channels[1] => 1, stride=1, bias=false),
    ))
end

@functor UNet

# helper to expand dims, similar to tensorflow expand dims
expand_dims(x::AbstractVecOrMat, dims::Int=2) = reshape(x, (ntuple(i -> 1, dims)..., size(x)...))

#the UNet struct callable and shows an example of a "Functional" API for modeling in Flux. \n
function (unet::UNet)(x, t)
    # Embedding
    embed = unet.layers.gaussfourierproj(t)
    embed = unet.layers.linear(embed)
    # Encoder
    h1 = unet.layers.conv1(x)
    h1 = h1 .+ expand_dims(unet.layers.dense1(embed), 2)
    h1 = unet.layers.gnorm1(h1)
    h2 = unet.layers.conv2(h1)
    h2 = h2 .+ expand_dims(unet.layers.dense2(embed), 2)
    h2 = unet.layers.gnorm2(h2)
    h3 = unet.layers.conv3(h2)
    h3 = h3 .+ expand_dims(unet.layers.dense3(embed), 2)
    h3 = unet.layers.gnorm3(h3)
    h4 = unet.layers.conv4(h3)
    h4 = h4 .+ expand_dims(unet.layers.dense4(embed), 2)
    h4 = unet.layers.gnorm4(h4)
    # Decoder
    h = unet.layers.tconv4(h4)
    h = h .+ expand_dims(unet.layers.dense5(embed), 2)
    h = unet.layers.tgnorm4(h)
    h = unet.layers.tconv3(cat(h, h3; dims=3))
    h = h .+ expand_dims(unet.layers.dense6(embed), 2)
    h = unet.layers.tgnorm3(h)
    h = unet.layers.tconv2(cat(h, h2, dims=3))
    h = h .+ expand_dims(unet.layers.dense7(embed), 2)
    h = unet.layers.tgnorm2(h)
    h = unet.layers.tconv1(cat(h, h1, dims=3))
    # Scaling Factor
    h ./ expand_dims(marginal_prob_std(t), 3)
end

function model_loss(model, x, ϵ=1.0f-5)
    batch_size = size(x)[end]
    # (batch) of random times to approximate 𝔼[⋅] wrt. 𝘪 ∼ 𝒰(0, 𝘛)
    random_t = rand!(similar(x, batch_size)) .* (1.0f0 - ϵ) .+ ϵ
    # (batch) of perturbations to approximate 𝔼[⋅] wrt. 𝘹(0) ∼ 𝒫₀(𝘹)
    z = randn!(similar(x))
    std = expand_dims(marginal_prob_std(random_t), 3)
    # (batch) of perturbed 𝘹(𝘵)'s to approximate 𝔼 wrt. 𝘹(t) ∼ 𝒫₀ₜ(𝘹(𝘵)|𝘹(0))
    perturbed_x = x + z .* std
    # 𝘚₀(𝘹(𝘵), 𝘵)
    score = model(perturbed_x, random_t)
    # mean over batches
    mean(
        # L₂ norm over WHC dimensions
        sum((score .* std + z) .^ 2; dims=1:(ndims(x) - 1))
    )
end

model_loss (generic function with 2 methods)

In [106]:
#Helper function from DrWatson.jl to convert a struct to a dict
function struct2dict(::Type{DT}, s) where {DT<:AbstractDict}
    DT(x => getfield(s, x) for x in fieldnames(typeof(s)))
end
struct2dict(s) = struct2dict(Dict, s)

# arguments for the `train` function 
@with_kw mutable struct Args
    η = 1e-4                                        # learning rate
    batch_size = 32                                 # batch size
    epochs = 50                                     # number of epochs
    seed = 1                                        # random seed
    cuda = false                                    # use CPU
    verbose_freq = 10                               # logging for every verbose_freq iterations
    tblogger = true                                 # log training with tensorboard
    save_path = "output"                            # results path
end

function train(; kws...)
    # load hyperparamters
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    # GPU config
    if args.cuda && CUDA.has_cuda()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # load MNIST images
    xtrain = npzread("data/raw/mnist.npy")
    loader = DataLoader((xtrain), batchsize=32, shuffle=true)
    # initialize UNet model
    unet = UNet() |> device
    # ADAM optimizer
    opt = ADAM(args.η)
    # parameters
    ps = Flux.params(unet)
    !ispath(args.save_path) && mkpath(args.save_path)

    # logging by TensorBoard.jl
    if args.tblogger
        tblogger = TBLogger(args.save_path, tb_overwrite)
    end

    # Training
    train_steps = 0
    @info "Start Training, total $(args.epochs) epochs"
    for epoch = 1:args.epochs
        @info "Epoch $(epoch)"
        progress = Progress(length(loader))

        for x in loader
            x = device(x)
            loss, grad = Flux.withgradient(ps) do
                model_loss(unet, x)
            end
            Flux.Optimise.update!(opt, ps, grad)
            # progress meter
            next!(progress; showvalues=[(:loss, loss)])

            # logging with TensorBoard
            if args.tblogger && train_steps % args.verbose_freq == 0
                with_logger(tblogger) do
                    @info "train" loss = loss
                end
            end
            train_steps += 1
        end
    end

    # save model
    model_path = joinpath(args.save_path, "model.bson")
    let unet = cpu(unet), args = struct2dict(args)
        BSON.@save model_path unet args
        @info "Model saved: $(model_path)"
    end
end

#if abspath(PROGRAM_FILE) == @__FILE__
    #train()
#end

train (generic function with 1 method)

In [107]:
train()

┌ Info: Training on CPU
└ @ Main /Users/jamesfranke/Documents/julia/tc/diffusion_model.ipynb:30
┌ Info: Start Training, total 50 epochs
└ @ Main /Users/jamesfranke/Documents/julia/tc/diffusion_model.ipynb:51


┌ Info: Epoch 1
└ @ Main /Users/jamesfranke/Documents/julia/tc/diffusion_model.ipynb:53


[32mProgress:   0%|                                         |  ETA: 5:15:11[39m[K
[34m  loss:  1209.433[39m[K[A


[K[A[32mProgress:   0%|▏                                        |  ETA: 3:40:37[39m[K
[34m  loss:  1117.2347[39m[K[A


[K[A[32mProgress:   0%|▏                                        |  ETA: 2:52:57[39m[K
[34m  loss:  1148.6041[39m[K[A


[K[A[32mProgress:   0%|▏                                        |  ETA: 2:21:31[39m[K
[34m  loss:  1097.9457[39m[K[A


[K[A[32mProgress:   0%|▏                                        |  ETA: 2:00:36[39m[K
[34m  loss:  992.6916[39m[K[A


[K[A[32mProgress:   0%|▏                                        |  ETA: 1:46:27[39m[K
[34m  loss:  916.779[39m[K[A


[K[A[32mProgress:   0%|▏                                        |  ETA: 1:37:57[39m[K
[34m  loss:  970.44196[39m[K[A


[K[A[32mProgress:   0%|▎                                        |  ETA: 1:28:59[39m[K
[34m  loss:  904.6869[39m[K[A


[K[A[32mProgress:   1%|▎                                        |  ETA: 1:23:03[39m[K
[34m  loss:  913.7043[39m[K[A


[K[A[32mProgress:   1%|▎                                        |  ETA: 1:16:55[39m[K
[34m  loss:  829.19977[39m[K[A


[K[A[32mProgress:   1%|▎                                        |  ETA: 1:11:57[39m[K
[34m  loss:  813.6649[39m[K[A


[K[A[32mProgress:   1%|▎                                        |  ETA: 1:07:30[39m[K
[34m  loss:  782.1341[39m[K[A


[K[A[32mProgress:   1%|▎                                        |  ETA: 1:04:42[39m[K
[34m  loss:  771.20404[39m[K[A


[K[A[32mProgress:   1%|▍                                        |  ETA: 1:01:33[39m[K
[34m  loss:  740.91675[39m[K[A


[K[A[32mProgress:   1%|▍                                        |  ETA: 0:58:56[39m[K
[34m  loss:  738.5127[39m[K[A


[K[A[32mProgress:   1%|▍                                        |  ETA: 0:56:15[39m[K
[34m  loss:  722.12787[39m[K[A


[K[A[32mProgress:   1%|▍                                        |  ETA: 0:53:49[39m

[K
[34m  loss:  713.568[39m[K[A


[K[A[32mProgress:   1%|▍                                        |  ETA: 0:51:47[39m[K
[34m  loss:  695.28455[39m[K[A


[K[A[32mProgress:   1%|▍                                        |  ETA: 0:50:14[39m[K
[34m  loss:  629.6488[39m[K[A


[K[A[32mProgress:   1%|▌                                        |  ETA: 0:48:36[39m[K
[34m  loss:  606.98645[39m[K[A


[K[A[32mProgress:   1%|▌                                        |  ETA: 0:47:02[39m[K
[34m  loss:  614.9315[39m[K[A


[K[A[32mProgress:   1%|▌                                        |  ETA: 0:45:39[39m[K
[34m  loss:  643.95197[39m[K[A


[K[A[32mProgress:   1%|▌                                        |  ETA: 0:44:19[39m[K
[34m  loss:  628.9923[39m[K[A


[K[A[32mProgress:   1%|▌                                        |  ETA: 0:43:02[39m[K
[34m  loss:  597.2251[39m[K[A


[K[A[32mProgress:   1%|▋                                        |  ETA: 0:41:54[39m[K
[34m  loss:  582.2612[39m[K[A


[K[A[32mProgress:   1%|▋                                        |  ETA: 0:40:54[39m[K
[34m  loss:  553.5307[39m[K[A


[K[A[32mProgress:   1%|▋                                        |  ETA: 0:40:10[39m[K
[34m  loss:  565.2458[39m[K[A


[K[A[32mProgress:   2%|▋                                        |  ETA: 0:39:13[39m[K
[34m  loss:  537.68384[39m[K[A


[K[A[32mProgress:   2%|▋                                        |  ETA: 0:38:18[39m[K
[34m  loss:  515.405[39m[K[A


[K[A[32mProgress:   2%|▋                                        |  ETA: 0:37:27[39m[K
[34m  loss:  526.2335[39m[K[A


[K[A[32mProgress:   2%|▊                                        |  ETA: 0:36:40[39m[K
[34m  loss:  523.2399[39m[K[A


[K[A[32mProgress:   2%|▊                                        |  ETA: 0:36:01[39m[K
[34m  loss:  493.09097[39m[K[A


[K[A[32mProgress:   2%|▊                                        |  ETA: 0:35:21[39m[K
[34m  loss:  494.51514[39m[K[A


[K[A[32mProgress:   2%|▊                                        |  ETA: 0:34:46[39m[K
[34m  loss:  516.01086[39m[K[A


[K[A[32mProgress:   2%|▊                                        |  ETA: 0:34:09[39m[K
[34m  loss:  474.1456[39m[K[A


[K[A[32mProgress:   2%|▊                                        |  ETA: 0:33:38[39m[K
[34m  loss:  477.0858[39m[K[A


[K[A[32mProgress:   2%|▉                                        |  ETA: 0:33:06[39m[K
[34m  loss:  512.4031[39m[K[A


[K[A[32mProgress:   2%|▉                                        |  ETA: 0:32:36[39m

[K
[34m  loss:  485.2993[39m[K[A


[K[A[32mProgress:   2%|▉                                        |  ETA: 0:32:08[39m[K
[34m  loss:  451.96362[39m[K[A


[K[A[32mProgress:   2%|▉                                        |  ETA: 0:31:38[39m[K
[34m  loss:  466.40662[39m[K[A


[K[A[32mProgress:   2%|▉                                        |  ETA: 0:31:13[39m[K
[34m  loss:  470.36688[39m[K[A


[K[A[32mProgress:   2%|█                                        |  ETA: 0:30:45[39m[K
[34m  loss:  391.21432[39m[K[A


[K[A[32mProgress:   2%|█                                        |  ETA: 0:30:20[39m[K
[34m  loss:  458.78183[39m[K[A


[K[A[32mProgress:   2%|█                                        |  ETA: 0:30:11[39m[K
[34m  loss:  401.74557[39m[K[A


[K[A[32mProgress:   2%|█                                        |  ETA: 0:29:49[39m[K
[34m  loss:  413.5223[39m[K[A


[K[A[32mProgress:   3%|█                                        |  ETA: 0:29:32[39m[K
[34m  loss:  421.00958[39m[K[A


[K[A[32mProgress:   3%|█                                        |  ETA: 0:29:15[39m[K
[34m  loss:  440.43854[39m[K[A


[K[A[32mProgress:   3%|█▏                                       |  ETA: 0:28:53[39m[K
[34m  loss:  398.06705[39m[K[A


[K[A[32mProgress:   3%|█▏                                       |  ETA: 0:28:39[39m[K
[34m  loss:  408.57312[39m[K[A


[K[A[32mProgress:   3%|█▏                                       |  ETA: 0:28:20[39m[K
[34m  loss:  394.7592[39m[K[A


[K[A[32mProgress:   3%|█▏                                       |  ETA: 0:28:02[39m[K
[34m  loss:  366.50354[39m[K[A


[K[A[32mProgress:   3%|█▏                                       |  ETA: 0:27:44[39m[K
[34m  loss:  385.62244[39m[K[A


[K[A[32mProgress:   3%|█▏                                       |  ETA: 0:27:29[39m[K
[34m  loss:  422.5066[39m[K[A


[K[A[32mProgress:   3%|█▎                                       |  ETA: 0:27:14[39m[K
[34m  loss:  377.45737[39m[K[A


[K[A[32mProgress:   3%|█▎                                       |  ETA: 0:26:57[39m[K
[34m  loss:  361.8951[39m[K[A


[K[A[32mProgress:   3%|█▎                                       |  ETA: 0:26:44[39m[K
[34m  loss:  382.53058[39m[K[A


[K[A[32mProgress:   3%|█▎                                       |  ETA: 0:26:29[39m[K
[34m  loss:  402.56528[39m[K[A


[K[A[32mProgress:   3%|█▎                                       |  ETA: 0:26:16[39m[K
[34m  loss:  393.73923[39m[K[A


[K[A[32mProgress:   3%|█▎                                       |  ETA: 0:26:11[39m[K
[34m  loss:  351.73862[39m[K[A


[K[A[32mProgress:   3%|█▍                                       |  ETA: 0:25:57[39m[K
[34m  loss:  332.40994[39m[K[A


[K[A[32mProgress:   3%|█▍                                       |  ETA: 0:25:45[39m[K
[34m  loss:  324.59314[39m[K[A


[K[A[32mProgress:   3%|█▍                                       |  ETA: 0:25:34[39m[K
[34m  loss:  344.2359[39m[K[A


[K[A[32mProgress:   3%|█▍                                       |  ETA: 0:25:22[39m[K
[34m  loss:  375.58124[39m[K[A


[K[A[32mProgress:   3%|█▍                                       |  ETA: 0:25:12[39m[K
[34m  loss:  346.89783[39m[K[A


[K[A[32mProgress:   4%|█▌                                       |  ETA: 0:25:01[39m[K
[34m  loss:  341.33072[39m[K[A


[K[A[32mProgress:   4%|█▌                                       |  ETA: 0:24:51[39m[K
[34m  loss:  303.13605[39m[K[A


[K[A[32mProgress:   4%|█▌                                       |  ETA: 0:24:40[39m[K
[34m  loss:  326.2411[39m[K[A


[K[A[32mProgress:   4%|█▌                                       |  ETA: 0:24:30[39m[K
[34m  loss:  327.3444[39m[K[A


[K[A[32mProgress:   4%|█▌                                       |  ETA: 0:24:20[39m[K
[34m  loss:  349.51764[39m[K[A


[K[A[32mProgress:   4%|█▌                                       |  ETA: 0:24:12[39m[K
[34m  loss:  388.54144[39m[K[A


[K[A[32mProgress:   4%|█▋                                       |  ETA: 0:24:04[39m[K
[34m  loss:  360.48868[39m[K[A


[K[A[32mProgress:   4%|█▋                                       |  ETA: 0:23:53[39m[K
[34m  loss:  362.451[39m[K[A


[K[A[32mProgress:   4%|█▋                                       |  ETA: 0:23:45[39m[K
[34m  loss:  356.642[39m[K[A


[K[A[32mProgress:   4%|█▋                                       |  ETA: 0:23:36[39m[K
[34m  loss:  377.27155[39m[K[A


[K[A[32mProgress:   4%|█▋                                       |  ETA: 0:23:29[39m[K
[34m  loss:  360.02155[39m[K[A


[K[A[32mProgress:   4%|█▋                                       |  ETA: 0:23:22[39m[K
[34m  loss:  300.83633[39m[K[A


[K[A[32mProgress:   4%|█▊                                       |  ETA: 0:23:16[39m[K
[34m  loss:  328.12396[39m[K[A


[K[A[32mProgress:   4%|█▊                                       |  ETA: 0:23:08[39m[K
[34m  loss:  323.3532[39m[K[A


[K[A[32mProgress:   4%|█▊                                       |  ETA: 0:23:01[39m[K
[34m  loss:  299.4128[39m[K[A


[K[A[32mProgress:   4%|█▊                                       |  ETA: 0:22:54[39m[K
[34m  loss:  301.84933[39m[K[A


[K[A[32mProgress:   4%|█▊                                       |  ETA: 0:22:46[39m[K
[34m  loss:  361.20593[39m[K[A


[K[A[32mProgress:   4%|█▉                                       |  ETA: 0:22:40[39m[K
[34m  loss:  339.53375[39m[K[A


[K[A[32mProgress:   4%|█▉                                       |  ETA: 0:22:36[39m[K
[34m  loss:  312.4355[39m[K[A


[K[A[32mProgress:   5%|█▉                                       |  ETA: 0:22:28[39m[K
[34m  loss:  335.00406[39m[K[A


[K[A[32mProgress:   5%|█▉                                       |  ETA: 0:22:20[39m[K
[34m  loss:  347.59186[39m[K[A


[K[A[32mProgress:   5%|█▉                                       |  ETA: 0:22:13[39m[K
[34m  loss:  306.30594[39m[K[A


[K[A[32mProgress:   5%|█▉                                       |  ETA: 0:22:07[39m[K
[34m  loss:  325.16077[39m[K[A


[K[A[32mProgress:   5%|██                                       |  ETA: 0:22:01[39m[K
[34m  loss:  302.80145[39m[K[A


[K[A[32mProgress:   5%|██                                       |  ETA: 0:21:55[39m[K
[34m  loss:  315.90134[39m[K[A


[K[A[32mProgress:   5%|██                                       |  ETA: 0:21:50[39m[K
[34m  loss:  297.84998[39m[K[A


[K[A[32mProgress:   5%|██                                       |  ETA: 0:21:47[39m[K
[34m  loss:  297.3059[39m[K[A


[K[A[32mProgress:   5%|██                                       |  ETA: 0:21:41[39m[K
[34m  loss:  329.12704[39m[K[A


[K[A[32mProgress:   5%|██                                       |  ETA: 0:21:35[39m[K
[34m  loss:  287.7255[39m[K[A


[K[A[32mProgress:   5%|██▏                                      |  ETA: 0:21:29[39m[K
[34m  loss:  284.52542[39m[K[A


[K[A[32mProgress:   5%|██▏                                      |  ETA: 0:21:25[39m[K
[34m  loss:  280.16806[39m[K[A


[K[A[32mProgress:   5%|██▏                                      |  ETA: 0:21:20[39m[K
[34m  loss:  306.90875[39m[K[A


[K[A[32mProgress:   5%|██▏                                      |  ETA: 0:21:14[39m[K
[34m  loss:  285.20233[39m[K[A


[K[A[32mProgress:   5%|██▏                                      |  ETA: 0:21:09[39m[K
[34m  loss:  323.9697[39m[K[A


[K[A[32mProgress:   5%|██▏                                      |  ETA: 0:21:03[39m[K
[34m  loss:  315.60718[39m[K[A


[K[A[32mProgress:   5%|██▎                                      |  ETA: 0:20:58[39m[K
[34m  loss:  299.42065[39m[K[A


[K[A[32mProgress:   5%|██▎                                      |  ETA: 0:20:54[39m[K
[34m  loss:  274.30035[39m[K[A


[K[A[32mProgress:   5%|██▎                                      |  ETA: 0:20:49[39m[K
[34m  loss:  327.78046[39m[K[A


[K[A[32mProgress:   6%|██▎                                      |  ETA: 0:20:45[39m[K
[34m  loss:  328.1895[39m[K[A


[K[A[32mProgress:   6%|██▎                                      |  ETA: 0:20:39[39m[K
[34m  loss:  247.76385[39m[K[A


[K[A[32mProgress:   6%|██▍                                      |  ETA: 0:20:35[39m[K
[34m  loss:  242.64917[39m[K[A


[K[A[32mProgress:   6%|██▍                                      |  ETA: 0:20:31[39m[K
[34m  loss:  299.13492[39m[K[A


[K[A[32mProgress:   6%|██▍                                      |  ETA: 0:20:27[39m[K
[34m  loss:  300.62738[39m[K[A


[K[A[32mProgress:   6%|██▍                                      |  ETA: 0:20:22[39m[K
[34m  loss:  303.31296[39m[K[A


[K[A[32mProgress:   6%|██▍                                      |  ETA: 0:20:17[39m[K
[34m  loss:  252.79395[39m[K[A


[K[A[32mProgress:   6%|██▍                                      |  ETA: 0:20:13[39m[K
[34m  loss:  278.21304[39m[K[A


[K[A[32mProgress:   6%|██▌                                      |  ETA: 0:20:08[39m[K
[34m  loss:  254.67186[39m[K[A


[K[A[32mProgress:   6%|██▌                                      |  ETA: 0:20:03[39m[K
[34m  loss:  307.56793[39m[K[A


[K[A[32mProgress:   6%|██▌                                      |  ETA: 0:19:58[39m[K
[34m  loss:  333.43546[39m[K[A


[K[A[32mProgress:   6%|██▌                                      |  ETA: 0:19:54[39m[K
[34m  loss:  283.70398[39m[K[A


[K[A[32mProgress:   6%|██▌                                      |  ETA: 0:19:51[39m[K
[34m  loss:  266.74796[39m[K[A


[K[A[32mProgress:   6%|██▌                                      |  ETA: 0:19:47[39m[K
[34m  loss:  258.26456[39m[K[A


[K[A[32mProgress:   6%|██▋                                      |  ETA: 0:19:42[39m[K
[34m  loss:  268.49188[39m[K[A


[K[A[32mProgress:   6%|██▋                                      |  ETA: 0:19:38[39m[K
[34m  loss:  260.99823[39m[K[A


[K[A[32mProgress:   6%|██▋                                      |  ETA: 0:19:49[39m[K
[34m  loss:  282.38446[39m[K[A


[K[A[32mProgress:   6%|██▋                                      |  ETA: 0:19:48[39m[K
[34m  loss:  216.22693[39m[K[A


[K[A[32mProgress:   7%|██▋                                      |  ETA: 0:19:46[39m[K
[34m  loss:  255.56464[39m[K[A


[K[A[32mProgress:   7%|██▊                                      |  ETA: 0:19:45[39m[K
[34m  loss:  234.04504[39m[K[A


[K[A[32mProgress:   7%|██▊                                      |  ETA: 0:19:41[39m[K
[34m  loss:  276.03476[39m[K[A


[K[A[32mProgress:   7%|██▊                                      |  ETA: 0:19:38[39m[K
[34m  loss:  275.33072[39m[K[A


[K[A[32mProgress:   7%|██▊                                      |  ETA: 0:19:36[39m[K
[34m  loss:  288.09204[39m[K[A


[K[A[32mProgress:   7%|██▊                                      |  ETA: 0:19:37[39m[K
[34m  loss:  262.5788[39m[K[A


[K[A[32mProgress:   7%|██▊                                      |  ETA: 0:19:39[39m[K
[34m  loss:  252.13496[39m[K[A


[K[A[32mProgress:   7%|██▉                                      |  ETA: 0:19:38[39m[K
[34m  loss:  250.42351[39m[K[A


[K[A[32mProgress:   7%|██▉                                      |  ETA: 0:19:35[39m[K
[34m  loss:  247.65315[39m[K[A


[K[A[32mProgress:   7%|██▉                                      |  ETA: 0:19:32[39m[K
[34m  loss:  292.78827[39m[K[A


[K[A[32mProgress:   7%|██▉                                      |  ETA: 0:19:28[39m[K
[34m  loss:  273.6877[39m[K[A


[K[A[32mProgress:   7%|██▉                                      |  ETA: 0:19:26[39m[K
[34m  loss:  321.15204[39m[K[A


[K[A[32mProgress:   7%|██▉                                      |  ETA: 0:19:24[39m[K
[34m  loss:  284.08942[39m[K[A


[K[A[32mProgress:   7%|███                                      |  ETA: 0:19:22[39m[K
[34m  loss:  311.89114[39m[K[A


[K[A[32mProgress:   7%|███                                      |  ETA: 0:19:19[39m[K
[34m  loss:  250.8934[39m[K[A


[K[A[32mProgress:   7%|███                                      |  ETA: 0:19:16[39m[K
[34m  loss:  250.53795[39m[K[A


[K[A[32mProgress:   7%|███                                      |  ETA: 0:19:13[39m[K
[34m  loss:  231.5242[39m[K[A


[K[A[32mProgress:   7%|███                                      |  ETA: 0:19:10[39m[K
[34m  loss:  246.42018[39m[K[A


[K[A[32mProgress:   7%|███                                      |  ETA: 0:19:07[39m[K
[34m  loss:  273.61157[39m[K[A


[K[A[32mProgress:   8%|███▏                                     |  ETA: 0:19:03[39m[K
[34m  loss:  246.75816[39m[K[A


[K[A[32mProgress:   8%|███▏                                     |  ETA: 0:19:02[39m[K
[34m  loss:  223.82574[39m[K[A


[K[A[32mProgress:   8%|███▏                                     |  ETA: 0:18:59[39m[K
[34m  loss:  223.90025[39m[K[A


[K[A[32mProgress:   8%|███▏                                     |  ETA: 0:18:56[39m[K
[34m  loss:  235.66302[39m[K[A


[K[A[32mProgress:   8%|███▏                                     |  ETA: 0:18:57[39m[K
[34m  loss:  287.79172[39m[K[A


[K[A[32mProgress:   8%|███▎                                     |  ETA: 0:18:55[39m[K
[34m  loss:  280.33954[39m[K[A


[K[A[32mProgress:   8%|███▎                                     |  ETA: 0:18:52[39m[K
[34m  loss:  242.81027[39m[K[A


[K[A[32mProgress:   8%|███▎                                     |  ETA: 0:18:52[39m[K
[34m  loss:  265.59732[39m[K[A


[K[A[32mProgress:   8%|███▎                                     |  ETA: 0:18:50[39m[K
[34m  loss:  232.03589[39m[K[A


[K[A[32mProgress:   8%|███▎                                     |  ETA: 0:18:48[39m[K
[34m  loss:  277.63538[39m[K[A


[K[A[32mProgress:   8%|███▎                                     |  ETA: 0:18:45[39m[K
[34m  loss:  265.32175[39m[K[A


[K[A

[32mProgress:   8%|███▍                                     |  ETA: 0:18:42[39m[K
[34m  loss:  252.95781[39m[K[A


[K[A[32mProgress:   8%|███▍                                     |  ETA: 0:18:41[39m[K
[34m  loss:  289.66394[39m[K[A


[K[A[32mProgress:   8%|███▍                                     |  ETA: 0:18:38[39m[K
[34m  loss:  218.35185[39m[K[A


[K[A[32mProgress:   8%|███▍                                     |  ETA: 0:18:35[39m[K
[34m  loss:  249.04756[39m[K[A


[K[A[32mProgress:   8%|███▍                                     |  ETA: 0:18:35[39m[K
[34m  loss:  272.88315[39m[K[A


[K[A[32mProgress:   8%|███▍                                     |  ETA: 0:18:33[39m[K
[34m  loss:  227.27098[39m[K[A


[K[A[32mProgress:   8%|███▌                                     |  ETA: 0:18:31[39m[K
[34m  loss:  237.5046[39m[K[A


[K[A[32mProgress:   8%|███▌                                     |  ETA: 0:18:28[39m[K
[34m  loss:  225.61313[39m[K[A


[K[A[32mProgress:   9%|███▌                                     |  ETA: 0:18:26[39m[K
[34m  loss:  229.37128[39m[K[A


[K[A[32mProgress:   9%|███▌                                     |  ETA: 0:18:24[39m[K
[34m  loss:  238.71323[39m[K[A


[K[A[32mProgress:   9%|███▌                                     |  ETA: 0:18:21[39m[K
[34m  loss:  265.73016[39m[K[A


[K[A[32mProgress:   9%|███▋                                     |  ETA: 0:18:19[39m[K
[34m  loss:  271.69873[39m[K[A


[K[A[32mProgress:   9%|███▋                                     |  ETA: 0:18:16[39m[K
[34m  loss:  243.60526[39m[K[A


[K[A[32mProgress:   9%|███▋                                     |  ETA: 0:18:14[39m[K
[34m  loss:  233.30832[39m[K[A


[K[A[32mProgress:   9%|███▋                                     |  ETA: 0:18:12[39m[K
[34m  loss:  226.7424[39m[K[A


[K[A[32mProgress:   9%|███▋                                     |  ETA: 0:18:09[39m[K
[34m  loss:  243.74077[39m[K[A


[K[A[32mProgress:   9%|███▋                                     |  ETA: 0:18:08[39m[K
[34m  loss:  310.10803[39m[K[A


[K[A[32mProgress:   9%|███▊                                     |  ETA: 0:18:06[39m[K
[34m  loss:  248.57416[39m[K[A


[K[A[32mProgress:   9%|███▊                                     |  ETA: 0:18:03[39m[K
[34m  loss:  227.31787[39m[K[A


[K[A[32mProgress:   9%|███▊                                     |  ETA: 0:18:01[39m[K
[34m  loss:  255.58424[39m[K[A


[K[A[32mProgress:   9%|███▊                                     |  ETA: 0:17:59[39m[K
[34m  loss:  215.04239[39m[K[A


[K[A[32mProgress:   9%|███▊                                     |  ETA: 0:17:56[39m[K
[34m  loss:  309.8025[39m[K[A


[K[A[32mProgress:   9%|███▊                                     |  ETA: 0:17:54[39m[K
[34m  loss:  217.17563[39m[K[A


[K[A[32mProgress:   9%|███▉                                     |  ETA: 0:17:52[39m[K
[34m  loss:  276.30606[39m[K[A


[K[A[32mProgress:   9%|███▉                                     |  ETA: 0:17:50[39m[K
[34m  loss:  230.06114[39m[K[A


[K[A[32mProgress:   9%|███▉                                     |  ETA: 0:17:48[39m[K
[34m  loss:  228.50629[39m[K[A


[K[A[32mProgress:   9%|███▉                                     |  ETA: 0:17:46[39m[K
[34m  loss:  233.68796[39m[K[A


[K[A[32mProgress:  10%|███▉                                     |  ETA: 0:17:44[39m[K
[34m  loss:  227.52588[39m[K[A


[K[A[32mProgress:  10%|███▉                                     |  ETA: 0:17:42[39m[K
[34m  loss:  242.84108[39m[K[A


[K[A[32mProgress:  10%|████                                     |  ETA: 0:17:39[39m[K
[34m  loss:  254.48213[39m[K[A


[K[A[32mProgress:  10%|████                                     |  ETA: 0:17:37[39m[K
[34m  loss:  197.70232[39m[K[A


[K[A[32mProgress:  10%|████                                     |  ETA: 0:17:35[39m[K
[34m  loss:  243.00919[39m[K[A


[K[A[32mProgress:  10%|████                                     |  ETA: 0:17:32[39m[K
[34m  loss:  218.99002[39m[K[A


[K[A[32mProgress:  10%|████                                     |  ETA: 0:17:30[39m[K
[34m  loss:  235.80836[39m[K[A


[K[A

[32mProgress:  10%|████▏                                    |  ETA: 0:17:27[39m[K
[34m  loss:  253.97264[39m[K[A


[K[A[32mProgress:  10%|████▏                                    |  ETA: 0:17:25[39m[K
[34m  loss:  224.93333[39m[K[A


[K[A[32mProgress:  10%|████▏                                    |  ETA: 0:17:23[39m[K
[34m  loss:  237.95456[39m[K[A


[K[A[32mProgress:  10%|████▏                                    |  ETA: 0:17:22[39m[K
[34m  loss:  254.91446[39m[K[A


[K[A[32mProgress:  10%|████▏                                    |  ETA: 0:17:20[39m[K
[34m  loss:  219.26436[39m[K[A


[K[A[32mProgress:  10%|████▏                                    |  ETA: 0:17:18[39m[K
[34m  loss:  231.89688[39m[K[A


[K[A[32mProgress:  10%|████▎                                    |  ETA: 0:17:19[39m[K
[34m  loss:  216.63466[39m[K[A


[K[A[32mProgress:  10%|████▎                                    |  ETA: 0:17:17[39m[K
[34m  loss:  218.32556[39m[K[A


[K[A[32mProgress:  10%|████▎                                    |  ETA: 0:17:15[39m[K
[34m  loss:  219.06195[39m[K[A


[K[A[32mProgress:  10%|████▎                                    |  ETA: 0:17:14[39m[K
[34m  loss:  237.36206[39m[K[A


[K[A[32mProgress:  10%|████▎                                    |  ETA: 0:17:14[39m[K
[34m  loss:  270.89133[39m[K[A


[K[A[32mProgress:  11%|████▎                                    |  ETA: 0:17:13[39m[K
[34m  loss:  197.24146[39m[K[A


[K[A[32mProgress:  11%|████▍                                    |  ETA: 0:17:12[39m[K
[34m  loss:  266.391[39m[K[A


[K[A[32mProgress:  11%|████▍                                    |  ETA: 0:17:10[39m[K
[34m  loss:  224.48318[39m[K[A


[K[A[32mProgress:  11%|████▍                                    |  ETA: 0:17:09[39m[K
[34m  loss:  265.13754[39m[K[A


[K[A[32mProgress:  11%|████▍                                    |  ETA: 0:17:08[39m[K
[34m  loss:  229.78603[39m[K[A


[K[A[32mProgress:  11%|████▍                                    |  ETA: 0:17:06[39m[K
[34m  loss:  232.21295[39m[K[A


[K[A[32mProgress:  11%|████▌                                    |  ETA: 0:17:05[39m[K
[34m  loss:  254.02402[39m[K[A


[K[A[32mProgress:  11%|████▌                                    |  ETA: 0:17:04[39m[K
[34m  loss:  199.71013[39m[K[A


[K[A[32mProgress:  11%|████▌                                    |  ETA: 0:17:03[39m[K
[34m  loss:  264.74997[39m[K[A


[K[A[32mProgress:  11%|████▌                                    |  ETA: 0:17:01[39m[K
[34m  loss:  209.51794[39m[K[A


[K[A[32mProgress:  11%|████▌                                    |  ETA: 0:16:59[39m[K
[34m  loss:  230.59361[39m[K[A


[K[A[32mProgress:  11%|████▌                                    |  ETA: 0:16:58[39m[K
[34m  loss:  231.85463[39m[K[A


[K[A[32mProgress:  11%|████▋                                    |  ETA: 0:16:58[39m[K
[34m  loss:  234.04688[39m[K[A


[K[A[32mProgress:  11%|████▋                                    |  ETA: 0:16:59[39m[K
[34m  loss:  211.07843[39m[K[A


[K[A[32mProgress:  11%|████▋                                    |  ETA: 0:16:57[39m[K
[34m  loss:  218.7608[39m[K[A


[K[A[32mProgress:  11%|████▋                                    |  ETA: 0:16:56[39m[K
[34m  loss:  219.69537[39m[K[A


[K[A[32mProgress:  11%|████▋                                    |  ETA: 0:16:55[39m[K
[34m  loss:  196.11159[39m[K[A


[K[A[32mProgress:  11%|████▋                                    |  ETA: 0:16:53[39m[K
[34m  loss:  203.58356[39m[K[A


[K[A[32mProgress:  11%|████▊                                    |  ETA: 0:16:53[39m[K
[34m  loss:  266.68[39m[K[A


[K[A[32mProgress:  12%|████▊                                    |  ETA: 0:16:52[39m[K
[34m  loss:  226.64072[39m[K[A


[K[A[32mProgress:  12%|████▊                                    |  ETA: 0:16:50[39m[K
[34m  loss:  213.95406[39m[K[A

In [100]:
xtrain = npzread("data/raw/mnist.npy")
size(xtrain)

(28, 28, 1, 60000)

In [None]:
#Helper function yielding the diffusion coefficient from a SDE.
diffusion_coeff(t, sigma=convert(eltype(t), 25.0f0)) = sigma .^ t

#Helper function that produces images from a batch of images.
function convert_to_image(x, y_size)
    Gray.(permutedims(vcat(reshape.(chunk(x |> cpu, y_size), 28, :)...), (2, 1)))
end

# Helper to make an animation from a batch of images.
function convert_to_animation(x)
    frames = size(x)[end]
    batches = size(x)[end-1]
    animation = @animate for i = 1:frames+frames÷4
        if i <= frames
            heatmap(
                convert_to_image(x[:, :, :, :, i], batches),
                title="Iteration: $i out of $frames"
            )
        else
            heatmap(
                convert_to_image(x[:, :, :, :, end], batches),
                title="Iteration: $frames out of $frames"
            )
        end
    end
    return animation
end

#Helper function that generates inputs to a sampler.
function setup_sampler(device, num_images=5, num_steps=500, ϵ=1.0f-3)
    t = ones(Float32, num_images) |> device
    init_x = (
        randn(Float32, (28, 28, 1, num_images)) .*
        expand_dims(marginal_prob_std(t), 3)
    ) |> device
    time_steps = LinRange(1.0f0, ϵ, num_steps)
    Δt = time_steps[1] - time_steps[2]
    return time_steps, Δt, init_x
end

function DifferentialEquations_problem(model, init_x, time_steps, Δt)
    function f(u, p, t)
        batch_time_step = fill!(similar(u, size(u)[end]), 1) .* t
        return (
            -expand_dims(diffusion_coeff(batch_time_step), 3) .^ 2 .*
            model(u, batch_time_step)
        )
    end

    function g(u, p, t)
        batch_time_step = fill!(similar(u), 1) .* t
        diffusion_coeff(batch_time_step)
    end
    tspan = (time_steps[begin], time_steps[end])
    SDEProblem(f, g, init_x, tspan), ODEProblem(f, init_x, tspan)
end

function plot_result(unet, args)
    args = Args(; args...)
    args.seed > 0 && Random.seed!(args.seed)
    device = args.cuda && CUDA.has_cuda() ? gpu : cpu
    unet = unet |> device
    time_steps, Δt, init_x = setup_sampler(device)
    
    # Setup an SDEProblem and ODEProblem to input to `solve()`.
    # Use dt=Δt to make the sample paths comparable to calculating "by hand".
    sde_problem, ode_problem = DifferentialEquations_problem(unet, init_x, time_steps, Δt)

    @info "Probability Flow ODE Sampling w/ DifferentialEquations.jl"
    diff_eq_ode = solve(ode_problem, dt=Δt, adaptive=false)
    diff_eq_ode_end = diff_eq_ode[:, :, :, :, end]
    diff_eq_ode_images = convert_to_image(diff_eq_ode_end, size(diff_eq_ode_end)[end])
    save(joinpath(args.save_path, "diff_eq_ode_images.jpeg"), diff_eq_ode_images)
    diff_eq_ode_animation = convert_to_animation(diff_eq_ode)
    gif(diff_eq_ode_animation, joinpath(args.save_path, "diff_eq_ode.gif"), fps=50)
    ode_plot = plot(diff_eq_ode, title="Probability Flow ODE", legend=false, ylabel="x", la=0.25)
    plot!(time_steps, diffusion_coeff(time_steps), xflip=true, ls=:dash, lc=:red)
    plot!(time_steps, -diffusion_coeff(time_steps), xflip=true, ls=:dash, lc=:red)
    savefig(ode_plot, joinpath(args.save_path, "diff_eq_ode_plot.png"))
end

if abspath(PROGRAM_FILE) == @__FILE__
    ############################################################################
    # Issue loading function closures with BSON:
    # https://github.com/JuliaIO/BSON.jl/issues/69
    #
    BSON.@load "output/model.bson" unet args
    #
    # BSON.@load does not work if defined inside plot_result(⋅) because
    # it contains a function closure, GaussFourierProject(⋅), containing W.
    ###########################################################################
    plot_result(unet, args)
end

# OLD

In [None]:
    # Predictor Corrector
    pc = predictor_corrector_sampler(unet, init_x, time_steps, Δt)
    pc_images = convert_to_image(pc, size(pc)[end])
    save(joinpath(args.save_path, "pc_images.jpeg"), pc_images)

    # Euler-Maruyama
    euler_maruyama = Euler_Maruyama_sampler(unet, init_x, time_steps, Δt)
    sampled_noise = convert_to_image(init_x, size(init_x)[end])
    save(joinpath(args.save_path, "sampled_noise.jpeg"), sampled_noise)
    em_images = convert_to_image(euler_maruyama, size(euler_maruyama)[end])
    save(joinpath(args.save_path, "em_images.jpeg"), em_images)

    
    @info "Euler-Maruyama Sampling w/ DifferentialEquations.jl"
    diff_eq_em = solve(sde_problem, EM(), dt=Δt)
    diff_eq_em_end = diff_eq_em[:, :, :, :, end]
    diff_eq_em_images = convert_to_image(diff_eq_em_end, size(diff_eq_em_end)[end])
    save(joinpath(args.save_path, "diff_eq_em_images.jpeg"), diff_eq_em_images)
    diff_eq_em_animation = convert_to_animation(diff_eq_em)
    gif(diff_eq_em_animation, joinpath(args.save_path, "diff_eq_em.gif"), fps=50)
    em_plot = plot(diff_eq_em, title="Euler-Maruyama", legend=false, ylabel="x", la=0.25)
    plot!(time_steps, diffusion_coeff(time_steps), xflip=true, ls=:dash, lc=:red)
    plot!(time_steps, -diffusion_coeff(time_steps), xflip=true, ls=:dash, lc=:red)
    savefig(em_plot, joinpath(args.save_path, "diff_eq_em_plot.png"))
    
"""
Sample from a diffusion model using the Euler-Maruyama method.
# References
https://yang-song.github.io/blog/2021/score/#how-to-solve-the-reverse-sde
"""
function Euler_Maruyama_sampler(model, init_x, time_steps, Δt)
    x = mean_x = init_x
    @showprogress "Euler-Maruyama Sampling" for time_step in time_steps
        batch_time_step = fill!(similar(init_x, size(init_x)[end]), 1) .* time_step
        g = diffusion_coeff(batch_time_step)
        mean_x = x .+ expand_dims(g, 3) .^ 2 .* model(x, batch_time_step) .* Δt
        x = mean_x .+ sqrt(Δt) .* expand_dims(g, 3) .* randn(Float32, size(x))
    end
    return mean_x
end

"""
Sample from a diffusion model using the Predictor-Corrector method.
# References
https://yang-song.github.io/blog/2021/score/#how-to-solve-the-reverse-sde
"""
function predictor_corrector_sampler(model, init_x, time_steps, Δt, snr=0.16f0)
    x = mean_x = init_x
    @showprogress "Predictor Corrector Sampling" for time_step in time_steps
        batch_time_step = fill!(similar(init_x, size(init_x)[end]), 1) .* time_step
        # Corrector step (Langevin MCMC)
        grad = model(x, batch_time_step)
        num_pixels = prod(size(grad)[1:end-1])
        grad_batch_vector = reshape(grad, (size(grad)[end], num_pixels))
        grad_norm = mean(sqrt, sum(abs2, grad_batch_vector, dims=2))
        noise_norm = Float32(sqrt(num_pixels))
        langevin_step_size = 2 * (snr * noise_norm / grad_norm)^2
        x += (
            langevin_step_size .* grad .+
            sqrt(2 * langevin_step_size) .* randn(Float32, size(x))
        )
        # Predictor step (Euler-Maruyama)
        g = diffusion_coeff(batch_time_step)
        mean_x = x .+ expand_dims((g .^ 2), 3) .* model(x, batch_time_step) .* Δt
        x = mean_x + sqrt.(expand_dims((g .^ 2), 3) .* Δt) .* randn(Float32, size(x))
    end
    return mean_x
end

"""
Helper to create a SDEProblem with DifferentialEquations.jl
# Notes
The reverse-time SDE is given by:  
𝘥x = -σ²ᵗ 𝘚₀(𝙭, 𝘵)𝘥𝘵 + σᵗ𝘥𝘸  
⟹ `f(u, p, t)` = -σ²ᵗ 𝘚₀(𝙭, 𝘵)  
⟹ `g(u, p, t` = σᵗ
"""