In [1]:
using Knet
using CUDA
using Images
using MLDatasets: MNIST
import YAML

include("utils/common.jl")
include("src/networks.jl")

In [2]:
"""
    main()

Run one gradient-free DCGAN forward pass on MNIST and return both losses.

Hyper-parameters are read from `configs/dcgan_asymmetric_two_stage_mnist.yml`
(keys used: `image_size`, `batch_size`, `zdim`, `num_channel`). MNIST is loaded,
resized, normalized and minibatched as `CuArray`s; a generator and a
discriminator are built; then a single minibatch is pushed through the
discriminator step and the generator step. No gradients are computed and no
parameters are updated yet (the optimizer is still a TODO).

# Returns
The tuple `(loss_d, loss_g)`, where `loss_d = bce(real, 1) + bce(fake, 0)` and
`loss_g = bce(fake, 1)`.
"""
function main()

    # ---------------- Configure ----------------
    cfg = YAML.load_file("configs/dcgan_asymmetric_two_stage_mnist.yml")
    image_size = cfg["image_size"]
    zdim = cfg["zdim"]
    num_channels = cfg["num_channel"]

    # ---------------- Dataset ----------------
    xtrn, _ = MNIST.traindata(Float32)
    xtst, _ = MNIST.testdata(Float32)
    # Only the two leading (spatial) dimensions are resized.
    xtrn = imresize(xtrn, (image_size, image_size))
    xtst = imresize(xtst, (image_size, image_size))
    # `normalize` comes from an included project file; its exact scaling is not visible here.
    xtrn = normalize(xtrn)
    xtst = normalize(xtst)
    dtrn = minibatch(xtrn, cfg["batch_size"]; xtype = CuArray)
    dtst = minibatch(xtst, cfg["batch_size"]; xtype = CuArray)
    println("Data Summary:")
    println.(summary.((dtrn, dtst)))

    # ---------------- Network ----------------
    generator = DCGenerator(; zdim = zdim, num_channels = num_channels)
    discriminator = DCDiscriminator(; num_channels = num_channels)


    # ---------------- Optimizer ----------------
    # TODO

    # ---------------- Training ----------------
    # Latent noise in the (1, 1, zdim, N) layout that `generator` is called with below.
    # NOTE(review): noise is Float64 while the images are loaded as Float32 — confirm
    # this is what the networks expect.
    sample_noise(n) = CuArray(randn(Float64, (1, 1, zdim, n)))

    # Fixed noise, presumably for inspecting generator samples later (not used yet).
    # It was previously shaped (100, zdim, 1, 1), which does not match the layout
    # passed to `generator` everywhere else in this function.
    z_fix = sample_noise(100)

    # Only taking the first minibatch for now, will fix once there is a proper loop.
    # Drawn from the training split: this section trains, so it must not see test data.
    x = first(dtrn)
    batch_size = size(x)[end]
    # Insert the singleton channel dimension ahead of the batch dimension.
    x = reshape(x, (image_size, image_size, 1, batch_size))

    # Target labels: 1 for real images, 0 for generated ones.
    y_real = CuArray(ones(Int32, batch_size))
    y_fake = CuArray(zeros(Int32, batch_size))

    # ------------ Training Discriminator ------------
    fake = generator(sample_noise(batch_size))

    # Flatten the discriminator output to one score per sample.
    pred_real = reshape(discriminator(x), batch_size)
    pred_fake = reshape(discriminator(fake), batch_size)

    # `bce` is defined in an included file (presumably binary cross-entropy).
    loss_real = bce(pred_real, y_real)
    loss_fake = bce(pred_fake, y_fake)
    loss_d = loss_real + loss_fake

    # Loss backward, optimizer step

    # ------------ Training Generator ------------
    # Fresh noise; the generator is scored against the "real" labels.
    fake = generator(sample_noise(batch_size))
    pred_fake = reshape(discriminator(fake), batch_size)

    loss_g = bce(pred_fake, y_real)

    # Loss backward, optimizer step
    return loss_d, loss_g
end

main (generic function with 1 method)

In [3]:
# Run the pipeline once; the cell value is the `(loss_d, loss_g)` tuple returned by `main`.
main()

Data Summary:
468-element Knet.Train20.Data{CuArray}
78-element Knet.Train20.Data{CuArray}


(1.4429246974693415, 0.4739218986503221)