In [None]:
#r "nuget: DiffSharp.Core"
#r "nuget: DiffSharp.Backends.Reference"

In [None]:
open DiffSharp

// A float tensor built from an F# range expression (start ..step.. stop).
let t1 = dsharp.tensor [ 0.0 ..0.2.. 1.0 ] // Gives [0., 0.2, 0.4, 0.6, 0.8, 1.]
// NOTE(review): F# list elements are separated by `;`, so `[ 1, 2, 3, 4, 5, 6 ]`
// is a one-element list holding a 6-tuple. This relies on DiffSharp's tensor
// constructor accepting tuples as data; confirm the resulting shape/dtype
// (`[ 1; 2; 3; 4; 5; 6 ]` is the unambiguous list form).
let t2 = dsharp.tensor [ 1, 2, 3, 4, 5, 6 ]

In [None]:
// Both operands are nested four lists deep with a single element at the three
// outer levels, so each is a rank-4 tensor whose leading sizes are 1, 1, 1.
// t3 holds the 11 values 0.0, 1.0, ..., 10.0.
let t3 = dsharp.tensor [[[[0.0 .. 10.0]]]]
// t4 (used as the filter) runs from 0.0 to 1.0 in steps of 0.1.
// NOTE(review): the element count of a float-stepped range depends on
// floating-point rounding in FSharp.Core — confirm it yields 11 values.
let t4 = dsharp.tensor [[[[0.0 ..0.1.. 1.0]]]]
// Convolve t3 with t4. Presumably default stride 1 and no padding; if t4 has
// 11 values the filter spans the whole input and the result holds one value.
t3.conv2d(t4)

In [None]:
/// Scalar test function: the sum of the element-wise exponential of x.
/// Mathematically its gradient with respect to x is exp(x), element by element.
let f (x: Tensor) =
    let expX = x.exp()
    expX.sum()

// Evaluate the gradient of f at the point built from [1.8, 2.5].
dsharp.tensor([1.8, 2.5]) |> dsharp.grad f

In [None]:
// Scalar evaluation points for the derivative examples below.
let x0 = dsharp.tensor(1.)
let y0 = dsharp.tensor(2.)

// First derivative of 2*y, evaluated at y0 so the line produces a concrete
// tensor rather than an unapplied (and discarded) function value.
let dTwoY = dsharp.diff (fun y -> 2 * y) y0

// Nested differentiation: the inner diff is d/dy (x*y) = x, so the outer
// function is x * x and its derivative at x0 is 2 * x0 (= 2, since x0 = 1).
// Getting 2 requires the inner and outer derivatives to be kept separate.
dsharp.diff (fun x -> x * dsharp.diff (fun y -> x * y) y0) x0

In [None]:
open DiffSharp.Data
open DiffSharp.Model
open DiffSharp.Compose
open DiffSharp.Util
open DiffSharp.Optim

In [None]:
// Training schedule: number of epochs, minibatch size, and (presumably) the
// number of minibatches drawn per epoch.
let epochs, batchSize, numBatches = 2, 32, 5

In [None]:
// MNIST training and validation data. The training cell below uses
// trainLoader, validLoader and validSet.length, so these must be defined for
// that cell to compile.
// NOTE(review): data is read from "../data"; presumably DiffSharp downloads it
// there on first use — confirm network/disk access in the run environment.
// `transform=id` passes images through unchanged (presumably overriding a
// default normalization). The validation loader is not shuffled, so its
// iteration order is fixed.
let trainSet = MNIST("../data", train=true, transform=id)
let trainLoader = trainSet.loader(batchSize=batchSize, shuffle=true)

let validSet = MNIST("../data", train=false, transform=id)
let validLoader = validSet.loader(batchSize=batchSize, shuffle=false)

In [None]:
/// Convolutional encoder: three convolutions with ReLU after the first two,
/// then flattening from dimension 1 onward.
/// Assuming Conv2d(inChannels, outChannels, kernelSize, stride) with no padding
/// and a 1x28x28 input, the spatial size goes 28 -> 13 -> 5 -> 1, giving 128
/// features per example after flattening — TODO confirm.
let encoder =
    // Layers are constructed in the same order as they are applied.
    let conv1 = Conv2d(1, 32, 4, 2)
    let conv2 = Conv2d(32, 64, 4, 2)
    let conv3 = Conv2d(64, 128, 4, 2)
    conv1 --> dsharp.relu --> conv2 --> dsharp.relu --> conv3 --> dsharp.flatten(1)

In [None]:
/// Convolutional decoder: reshapes each 128-feature vector to [128;1;1], then
/// applies three transposed convolutions (ReLU after the first two) and a
/// final sigmoid.
/// Assuming ConvTranspose2d(inChannels, outChannels, kernelSize, stride) with
/// no padding, the spatial size grows 1 -> 4 -> 13 -> 28 — TODO confirm.
let decoder =
    // Layers are constructed in the same order as they are applied.
    let deconv1 = ConvTranspose2d(128, 64, 4, 2)
    let deconv2 = ConvTranspose2d(64, 32, 4, 3)
    let deconv3 = ConvTranspose2d(32, 1, 4, 2)
    dsharp.unflatten(1, [128;1;1])
    --> deconv1 --> dsharp.relu
    --> deconv2 --> dsharp.relu
    --> deconv3 --> dsharp.sigmoid

In [None]:
// Variational autoencoder built from `encoder` and `decoder`, for inputs of
// shape [1;28;28] (presumably channels x height x width).
// NOTE(review): 64 is presumably the latent dimension — confirm against VAE.
let model = VAE([1;28;28], 64, encoder, decoder)

In [None]:
// Learning rate, passed to the optimizer as a tensor.
let lr = dsharp.tensor(0.001)
// Adam optimizer over `model` with learning rate `lr`.
let optimizer = Adam(model, lr=lr)  

In [None]:
// Training loop: `epochs` passes, each drawing (presumably at most)
// `numBatches` minibatches from the training loader.
for epoch = 1 to epochs do
    let batches = trainLoader.epoch(numBatches)
    for i, x, _ in batches do
        // Put the parameters in reverse-diff mode for this step, compute the
        // loss, backpropagate, then let the optimizer update the parameters.
        model.reverseDiff()
        let l = model.loss(x)
        l.reverse()
        optimizer.step()
        print $"Epoch: {epoch} minibatch: {i} loss: {l}" 

// Switch the parameters out of reverse-diff mode before evaluation. Otherwise
// each validation loss would presumably stay attached to a derivative graph,
// and Seq.sumBy would keep every batch's graph alive through the running sum.
model.noDiff()

let validLoss = 
    validLoader.epoch() 
    |> Seq.sumBy (fun (_, x, _) -> model.loss(x, normalize=false))

// normalize=false presumably yields per-batch sums rather than means, so
// dividing by the dataset size gives a per-example loss — confirm.
print $"Validation loss: {validLoss/validSet.length}"